<font color='red'>TODO: double-check that the signs and the overall iteration formula are correct</font>

Imposing a Gaussian smoothness prior of the form

$$\prod_k \exp\left(-\frac{\left(D_k - D_{k+1}\right)^2}{2\gamma^2}\right)$$

on the diffusion coefficients

$$D_i := \Delta Q_i^2 \sqrt{p_{i,i+1} p_{i+1,i}}$$

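obtained from a Markov State Model with transition matrix $p_{ij}$ (satisfying detailed balance), where $\Delta Q_i := Q_{i+1} - Q_i$ is the spacing of the coordinate $\mathbf{Q}$ between neighbouring states, is equivalent to solving the convex constrained optimisation problem: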

$$
\text{maximise} \sum_{i,j} c_{ij} \left(\log{x_{ij}}-\log{x_i}\right) - \sum_{k} \frac{\left(D_k - D_{k+1}\right)^2}{2 \gamma^2}
$$

subject to 

$$
x_{ij} = x_{ji} \\
\sum_{j} x_{ij} > 0 \\
x_{ij} \geq 0
$$

where $c_{ij}$ is the number of observed transitions from state $i$ to state $j$, $c_i = \sum_j c_{ij}$, $x_{ij} = \pi_i p_{ij}$ and $x_i = \sum_j x_{ij}$. This can be solved by fixed-point iteration:

$$
x_{ij}^{(k+1)} = \frac{c_{ij}+c_{ji}}{\frac{c_i}{x_i^{(k)}}+\frac{c_j}{x_j^{(k)}} + f_{ij}(\gamma, \mathbf{Q} \vert x^{(k)}_{ij})}
$$

where 

$$
f_{ij}(\gamma, \mathbf{Q} \vert x^{(k)}_{ij}) = \frac{1}{\gamma^2}\left(g_{ij} + g_{ji}\right)
$$

and

$$
\begin{aligned}
g_{ij}(\mathbf{Q}) ={}& \frac{\Delta Q_i^2 \, \Delta Q_{i-1}^2 \, x_{i-1,i} \, \delta_{i+1,j}}{x_{i-1} x_i^2 x_{i+1}}
- \frac{2 \, \Delta Q_i^4 \, x_{i,i+1} \, \delta_{i+1,j}}{x_i^2 x_{i+1}^2}
+ \frac{\Delta Q_i^2 \, \Delta Q_{i+1}^2 \, x_{i+1,i+2} \, \delta_{i+1,j}}{x_i x_{i+1}^2 x_{i+2}} \\
&+ \frac{2 \, \Delta Q_{i-1}^4 \, x_{i-1,i}^2}{x_{i-1}^2 x_i^3}
- \frac{\Delta Q_{i-2}^2 \, \Delta Q_{i-1}^2 \, x_{i-2,i-1} \, x_{i-1,i}}{x_{i-2} x_{i-1}^2 x_i^2}
- \frac{2 \, \Delta Q_i^2 \, \Delta Q_{i-1}^2 \, x_{i-1,i} \, x_{i,i+1}}{x_{i-1} x_i^3 x_{i+1}} \\
&+ \frac{2 \, \Delta Q_i^4 \, x_{i,i+1}^2}{x_i^3 x_{i+1}^2}
- \frac{\Delta Q_i^2 \, \Delta Q_{i+1}^2 \, x_{i,i+1} \, x_{i+1,i+2}}{x_i^2 x_{i+1}^2 x_{i+2}}.
\end{aligned}
$$

Terms that refer to a state index outside the range of states are omitted.

The inequality constraints 

$$
\sum_{j} x_{ij} > 0 \\
x_{ij} \geq 0
$$

and symmetry of $x_{ij}$ are guaranteed when using a suitable starting iterate e.g.

$$
x_{ij}^{(0)} = \frac{c_{ij}+c_{ji}}{\sum_{k,l} \left(c_{kl} + c_{lk}\right)}.
$$

In [None]:
import numpy as np 
import matplotlib.pyplot as plt

In [None]:
# Toy problem: four states on an evenly spaced coordinate grid.
Q = np.array([0, 0.1, 0.2, 0.3])

# Transition counts: row i, column j holds the counts c_ij used below
# (c_i is taken as the row sum).
C = np.array([
    [1000, 50, 20, 10],
    [48, 50, 1, 1],
    [1, 1, 600, 40],
    [9, 7, 33, 300],
])

# Symmetric, normalised starting iterate x^(0) = (c + c^T) / sum(c + c^T).
X0 = (C + C.T) / np.sum(C + C.T)

# Width of the smoothness prior.
gamma = 0.01

print("Q", Q)
print("C", C)
print("X0", X0)
print("gamma", gamma)

In [None]:
def dQ2(Q, i):
    """Return the squared spacing (Q[i+1] - Q[i])**2 between states i and i+1."""
    return (Q[i+1] - Q[i])**2


def compute_G(Q, X, x):
    """Assemble the matrix g_ij(Q) of the prior term in the fixed-point update.

    Implements the eight-term expression for g_ij from the notes. The five
    terms without delta_{i+1,j} are the same for every column j of row i.
    The three delta_{i+1,j} terms only touch the super-diagonal entry
    (i, i+1). Terms that would reference an index outside 0..n-1 are omitted.

    NOTE(review): the expression appears to equal minus the gradient of
    sum_k (D_k - D_{k+1})**2 / 2 for D_k = dQ_k**2 * x_{k,k+1} / (x_k * x_{k+1})
    (no square root). That differs from the sqrt-based D used in
    compute_diffusion_coefficients. TODO confirm the derivation and overall sign.

    Parameters
    ----------
    Q : array-like
        State coordinates; needs at least X.shape[0] entries.
    X : (n, n) array
        Current iterate x_ij.
    x : length-n array
        Row sums of X (x_i = sum_j x_ij).

    Returns
    -------
    float64 array of X.shape with G[i, j] = g_ij.
    """
    n = X.shape[0]
    dq = [dQ2(Q, k) for k in range(n - 1)]  # dq[k] = (Q[k+1] - Q[k])**2
    G = np.zeros(X.shape, dtype='float64')
    for i in range(n):
        # Terms without delta_{i+1,j}: shared by every column of row i.
        row = 0.0
        if i > 0:
            row += 2 * dq[i-1]**2 * X[i-1][i]**2 / (x[i-1]**2 * x[i]**3)
        if i > 1:
            # The numerator carries BOTH x_{i-2,i-1} and x_{i-1,i}.
            row -= (dq[i-2] * dq[i-1] * X[i-2][i-1] * X[i-1][i]
                    / (x[i-2] * x[i-1]**2 * x[i]**2))
        if i < n - 1:
            row += 2 * dq[i]**2 * X[i][i+1]**2 / (x[i]**3 * x[i+1]**2)
        if i < n - 2:
            row -= (dq[i] * dq[i+1] * X[i][i+1] * X[i+1][i+2]
                    / (x[i]**2 * x[i+1]**2 * x[i+2]))
        if 0 < i < n - 1:
            row -= (2 * dq[i] * dq[i-1] * X[i-1][i] * X[i][i+1]
                    / (x[i-1] * x[i]**3 * x[i+1]))
        G[i, :] = row

        # Terms carrying delta_{i+1,j}: only the entry (i, i+1) receives them.
        if i < n - 1:
            delta = -2 * dq[i]**2 * X[i][i+1] / (x[i]**2 * x[i+1]**2)
            if i < n - 2:
                delta += dq[i] * dq[i+1] * X[i+1][i+2] / (x[i] * x[i+1]**2 * x[i+2])
            if i > 0:
                delta += dq[i] * dq[i-1] * X[i-1][i] / (x[i-1] * x[i]**2 * x[i+1])
            G[i, i+1] += delta
    return G


def compute_F(Q, X, x, gamma):
    """Return the prior contribution f_ij = (g_ij + g_ji) / gamma_eff**2.

    The symmetric variable x_ij = x_ji receives both one-sided contributions,
    so g and its transpose are summed (not multiplied).

    gamma_eff = gamma * min_k (Q[k+1] - Q[k])**2. In other words, ``gamma``
    is expressed in units of the smallest squared spacing. This normalisation
    is an implementation choice on top of the 1 / gamma**2 in the notes.

    Raises ValueError if Q has fewer than two entries (no spacing to take the
    minimum of).
    """
    G = compute_G(Q, X, x)
    min_deltaQ2 = min(dQ2(Q, i) for i in range(len(Q) - 1))
    return (G + G.T) / (min_deltaQ2 * gamma)**2

In [None]:
def update_X(old_X, C, Q, gamma):
    """Return the next fixed-point iterate x^(k+1) computed from x^(k) = old_X.

    Entry (i, j) is (c_ij + c_ji) / (c_i / x_i + c_j / x_j + F_ij), where
    c_i and x_i are the row sums of C and old_X and F is the prior term
    returned by compute_F.
    """
    x_rows = np.sum(old_X, axis=1, dtype='float64')
    prior = compute_F(Q, old_X, x_rows, gamma)
    counts = np.asarray(C)
    # c_i / x_i per state; broadcast once along rows and once along columns.
    weights = np.sum(counts, axis=1) / x_rows
    denominator = weights[:, np.newaxis] + weights[np.newaxis, :] + prior
    return (counts + counts.T) / denominator

def compute_error(old_X, new_X):
    """Return the Euclidean distance between the row sums of two iterates."""
    delta = np.sum(new_X, axis=1) - np.sum(old_X, axis=1)
    return np.sqrt(np.dot(delta, delta))

def fit_markov_state_model(counts, coordinates, gamma, error, max_iterations=10_000, initial_X=None):
    """Fit a reversible Markov State Model with a smoothness prior on D_i.

    Runs the fixed-point iteration x^(k+1) = update_X(x^(k)) until the change
    in the row sums between two iterates (compute_error) is at most ``error``.

    Parameters
    ----------
    counts : (n, n) array-like
        Transition counts c_ij.
    coordinates : array-like of length n
        Coordinate Q_i of each state.
    gamma : float
        Width of the prior (passed through to update_X).
    error : float
        Convergence tolerance on compute_error.
    max_iterations : int, optional
        Upper bound on the number of iterations, so that a non-converging run
        still terminates. A RuntimeWarning is issued if it is reached.
    initial_X : (n, n) array-like, optional
        Starting iterate. Defaults to the deterministic, symmetric start
        (c + c^T) / sum(c + c^T).

    Returns
    -------
    stationary_distribution : length-n array normalised to sum to 1.
    transition_matrix : (n, n) array with p_ij = x_ij / x_i.
    """
    import warnings

    counts = np.asarray(counts)
    if initial_X is None:
        # Deterministic start: reproducible, symmetric and non-negative.
        symmetric_counts = counts + counts.T
        old_X = symmetric_counts / np.sum(symmetric_counts)
    else:
        old_X = np.asarray(initial_X, dtype='float64')

    current_err = float('inf')
    iterations = 0
    while current_err > error and iterations < max_iterations:
        new_X = update_X(old_X, counts, coordinates, gamma)
        current_err = compute_error(old_X, new_X)
        old_X = new_X
        iterations += 1

    # NaN compares False against `error`, so the loop also stops on NaN;
    # surface that instead of silently returning a meaningless result.
    if not np.isfinite(current_err):
        warnings.warn(f'Iteration produced a non-finite error ({current_err}) '
                      f'after {iterations} iteration(s).', RuntimeWarning)
    elif current_err > error:
        warnings.warn(f'No convergence within {max_iterations} iteration(s); '
                      f'last error {current_err}.', RuntimeWarning)

    stationary_distribution = np.sum(old_X, axis=1)
    # p_ij = x_ij / x_i: divide each row by its (unnormalised) row sum.
    transition_matrix = old_X / stationary_distribution[:, np.newaxis]
    stationary_distribution = stationary_distribution / np.sum(stationary_distribution)

    print(f'Finished in {iterations} iteration(s). Error {round(current_err,7)}.')

    return stationary_distribution, transition_matrix

In [None]:
def compute_diffusion_coefficients(X, rate_matrix):
    """Return (D, Xd): diffusion coefficients between neighbouring states and their midpoints.

    ``X`` holds the state coordinates and ``rate_matrix`` is indexed as
    rate_matrix[i][j]. For every edge k between states k and k+1:
        D[k]  = (X[k+1] - X[k])**2 * sqrt(rate_matrix[k][k+1] * rate_matrix[k+1][k])
        Xd[k] = (X[k] + X[k+1]) / 2
    """
    n_edges = len(X) - 1
    D = np.zeros(n_edges)
    Xd = np.zeros(n_edges)
    for k, (left, right) in enumerate(zip(X[:-1], X[1:])):
        geometric_mean = np.sqrt(rate_matrix[k][k+1] * rate_matrix[k+1][k])
        D[k] = ((right - left)**2) * geometric_mean
        Xd[k] = 0.5 * (left + right)
    return D, Xd

# Fit the model for a range of prior widths and, for each, show p_ij, pi(i)
# and the diffusion coefficients D_i.
for gamma in [10, 0.1, 0.01, 0.001, 0.00005]:
    print("gamma", gamma)
    stationary_distribution, transition_matrix = fit_markov_state_model(counts=C, coordinates=Q, gamma=gamma, error=0.01)

    fig = plt.figure(figsize=(7,7))
    plt.imshow(transition_matrix)
    plt.title(r'$p_{ij}$', fontsize=16)
    plt.xlabel('j', fontsize=16)
    plt.ylabel('i', fontsize=16)
    plt.colorbar()
    plt.show()

    plt.plot(stationary_distribution)
    plt.xlabel('i', fontsize=16)
    plt.ylabel(r'$\pi(i)$', fontsize=16)

    print(stationary_distribution)
    print(np.sum(stationary_distribution))

    plt.show()

    # D is the (D_i, midpoint coordinates) tuple; plot each D_i at the
    # midpoint of its edge rather than against the array index.
    D = compute_diffusion_coefficients(Q, transition_matrix)
    plt.plot(D[1], D[0], marker='o')
    plt.xlabel('Q', fontsize=16)
    plt.ylabel(r'$D_i$', fontsize=16)
    # Show now so each D plot is rendered together with its own gamma.
    plt.show()