# Markov State Models
Must be run with `validation_mode = False` in `config.py`, inside the `md_sims` environment.

In [4]:
import torch
import matplotlib.pyplot as plt
from global_utils import CommittorNet, mpath, cnmsam
from sklearn.cluster import DBSCAN
import numpy as np
from config import V
import json

  X, Y = torch.meshgrid(torch.tensor(x), torch.tensor(y), indexing='ij')


In [None]:
# Checkpoint prefix of the awful training run on the 5 linear wells that is
# shown in the presentation.
run_name = "wells_linear_5_a0_b4_K2"
# dim and K are the constructor arguments later passed to CommittorNet.
dim, K = 2, 2

In [None]:
# Build the network, then restore its weights from "<run_name>.pt".
net2 = CommittorNet(dim=dim, K=K)
checkpoint_weights = torch.load(mpath(f"{run_name}.pt"))
net2.load_state_dict(checkpoint_weights)

# Stored samples, cast to float32 for the network.
samples_path = mpath(f"{run_name}_rxs.pt")
running_xs_all = torch.load(samples_path).float()

# Remaining artefacts saved alongside the run ("_k" and "_ctrs" files).
out_means_k = torch.load(mpath(f"{run_name}_k.pt"))
centers_k = torch.load(mpath(f"{run_name}_ctrs.pt"))

In [None]:
# Evaluate cnmsam on every stored sample with gradient tracking disabled,
# then bring the result to host memory as a NumPy array.
with torch.no_grad():
    q_tensor = cnmsam(net2, running_xs_all, torch.ones(K))
q_nn = q_tensor.detach().cpu().numpy()
# NumPy version of the sample coordinates.
running_xs = running_xs_all.cpu().numpy()

In [None]:
print(q_nn.shape)
# A bare expression in the middle of a cell is never displayed, so report the
# row sums explicitly. For K = 2, both the minimum and the maximum should be ~1.
q_row_sums = np.sum(q_nn, axis=1)
print("row-sum range:", q_row_sums.min(), q_row_sums.max())
# First output column of the network, used below as the committor q_A.
q_A = q_nn[:, 0].squeeze()

In [None]:
# Cluster the samples by their committor value alone (one feature per sample).
db = DBSCAN(eps=0.01, min_samples=10)
labels = db.fit_predict(q_A.reshape(-1, 1))
unique_labels = set(labels)
# plt.cm.get_cmap was deprecated in Matplotlib 3.7 and removed in 3.9;
# plt.get_cmap accepts the same (name, lut) arguments and is still supported.
colors = plt.get_cmap('tab10', len(unique_labels))
print(unique_labels)

In [None]:
# Scatter q_A against the x coordinate, one colour per DBSCAN label, and mark
# one representative point per cluster with a star.
k_centers_out = []  # representative sample of each cluster, in label order
fig, ax = plt.subplots(figsize=(6, 6))
for k in sorted(unique_labels):  # sorted -> reproducible legend and print order
    class_member_mask = (labels == k)
    q_k = q_A[class_member_mask]
    xy_k = running_xs_all[class_member_mask]

    if k == -1:
        # Noise: points DBSCAN did not assign to any cluster.
        ax.scatter(xy_k[:, 0], q_k, c='k', s=10, label='noise')
        continue

    # Representative of the cluster: the member at which V is smallest.
    # (V comes from config; presumably the potential -- confirm.)
    V_xy_k = V(xy_k)
    xy_ck = xy_k[torch.argmin(V_xy_k)]
    k_centers_out.append(xy_ck)
    k_mean = np.mean(q_k)
    print(k, xy_ck, k_mean)

    # A single RGBA colour must be wrapped in a list: a bare 4-tuple passed as
    # `c` is ambiguous with an array of values to colour-map and triggers a
    # Matplotlib warning.
    cluster_color = [colors(k)]
    ax.scatter(xy_k[:, 0], q_k, c=cluster_color, s=10, label=f'Cluster {k}')
    ax.scatter(xy_ck[0], k_mean, c=cluster_color, marker="*", s=300)

ax.set_ylabel(r"$q_A$ Committor", fontsize=16)
ax.set_xlabel("X Coordinate", fontsize=16)
ax.set_title("Bad Rate Estimator, Good Metastability Estimator", fontsize=18)
ax.legend(fontsize=14)
ax.tick_params(axis="both", labelsize=14)
plt.show()

# torch.stack raises on an empty list, which happens when every point is noise.
if k_centers_out:
    print(torch.stack(k_centers_out))
else:
    print("DBSCAN found no clusters (every point was labelled noise)")

## Markov State Model

In [5]:
# Rebind run_name to the run used for the Markov state model; the cells below
# load "<run_name>_exits.json" and "<run_name>_k.pt" for it.
# (The "K5" suffix presumably means K = 5 -- confirm.)
run_name = "wells_linear_5_a0_b4_K5"

In [22]:
# Recorded exit destinations, read from "<run_name>_exits.json".
exits_path = mpath(run_name + "_exits.json")
with open(exits_path, encoding="utf-8") as f:
    transfers_dict = json.load(f)
# One entry per key, in the order the keys appear in the file.
transfers = list(transfers_dict.values())

# Keep only the last row of the array saved in "<run_name>_k.pt".
saved_rates = torch.load(mpath(run_name + "_k.pt")).cpu().numpy()
rates = saved_rates[-1, :]

In [21]:
# `rates` is already a single row (the load cell indexes [-1, :]), so indexing
# it with [-1, :] again raises IndexError on a top-to-bottom run. Display it as is.
rates

array([2.34161626e-05, 1.53143666e-05, 3.93428392e-03, 1.47087876e-03,
       1.53457274e-03])

In [13]:
# Hard-coded exit rates kept for reference only. They live under their own
# name so that a top-to-bottom run does not overwrite the `rates` loaded from
# disk, which are the values the Q matrix printed further down was built from.
rates_hardcoded = [3.52506437e-05, 3.70972579e-05, 2.27186130e-03, 6.92297970e-04, 2.00196078e-03]
# transfers = [[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2], [4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4], [0, 0, 3, 3, 3, 3, 3, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, 3, 3, 0, 0, 0, 3, 0, 0, 0, 0, 3, 0, 3, 0, 0, 0, 0, 0, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0], [2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 4, 2, 2, 2, 2, 2, 4, 4, 4, 4, 2, 4, 2, 2, 2, 4, 2, 4, 4, 4, 2, 4, 2, 2], [1, 1, 1, 1, 1, 3, 3, 1, 3, 3, 1, 1, 1, 3, 1, 3, 1, 1, 3, 1, 3, 1, 3, 3, 1, 1, 1, 1, 1, 1, 3, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 3, 1, 1, 3, 3, 1]]

In [23]:
import numpy as np

def build_Q(exit_rates, exit_indices):
    """Assemble a K x K transition-rate (generator) matrix whose rows sum to zero.

    Row ``i`` distributes ``exit_rates[i]`` over the destination states in
    proportion to how often each destination occurs in ``exit_indices[i]``,
    and the diagonal entry is set to ``-exit_rates[i]``.

    Parameters
    ----------
    exit_rates : sequence of float, length K
        Total rate of leaving each state.
    exit_indices : sequence of sequences of int
        ``exit_indices[i]`` lists the destination state of each recorded exit
        from state ``i``.

    Returns
    -------
    numpy.ndarray of shape (K, K)

    Raises
    ------
    ValueError
        If a destination index is negative or not smaller than K.

    Notes
    -----
    Destinations equal to the source state are ignored, because counting them
    would leave the row with a non-zero sum. A state with no usable
    destination gets an all-zero row (it is treated as absorbing), which is
    also what rows missing from ``exit_indices`` already get.
    """
    K = len(exit_rates)
    Q = np.zeros((K, K))
    for i, dests in enumerate(exit_indices):
        if len(dests) == 0:
            continue  # nothing observed: leave the row at zero
        counts = np.bincount(dests, minlength=K)
        if counts.size > K:
            raise ValueError(
                f"state {i} has a destination index >= K={K} in exit_indices"
            )
        counts[i] = 0  # a jump back into the same state is not an exit
        n_exits = counts.sum()
        if n_exits == 0:
            continue  # only self-transitions were recorded
        Q[i] = exit_rates[i] * counts / n_exits
        Q[i, i] = -exit_rates[i]
    return Q

def mfpt_all_pairs(Q):
    """Solve for the mean first-passage time between every pair of states.

    For each target state ``j`` the linear system ``A m = b`` is solved, where
    ``A`` is ``Q`` with row ``j`` replaced by the ``j``-th unit vector and
    ``b`` is ``-1`` everywhere except ``b[j] = 0``. The replaced row pins
    ``m[j]`` to zero; the other rows impose ``(Q m)[i] = -1``.

    Parameters
    ----------
    Q : numpy.ndarray of shape (K, K)
        Transition-rate matrix.

    Returns
    -------
    numpy.ndarray of shape (K, K)
        Entry ``[i, j]`` is the solution for start state ``i`` and target
        state ``j``; the diagonal is zero.

    Raises
    ------
    numpy.linalg.LinAlgError
        If one of the modified systems is singular.
    """
    n_states = Q.shape[0]
    passage_times = np.zeros((n_states, n_states))
    for target in range(n_states):
        rhs = np.full(n_states, -1.0)
        rhs[target] = 0
        system = Q.copy()
        system[target, :] = 0
        system[target, target] = 1
        passage_times[:, target] = np.linalg.solve(system, rhs)
    return passage_times


In [24]:
# Rate matrix for this run, built from the exit rates and exit destinations.
Q_run = build_Q(exit_rates=rates, exit_indices=transfers)
print(Q_run)

[[-2.34161626e-05  0.00000000e+00  2.34161626e-05  0.00000000e+00
   0.00000000e+00]
 [ 0.00000000e+00 -1.53143666e-05  0.00000000e+00  0.00000000e+00
   1.53143666e-05]
 [ 2.74207667e-03  0.00000000e+00 -3.93428392e-03  1.19220725e-03
   0.00000000e+00]
 [ 0.00000000e+00  0.00000000e+00  8.27369300e-04 -1.47087876e-03
   6.43509456e-04]
 [ 0.00000000e+00  9.59107965e-04  0.00000000e+00  5.75464779e-04
  -1.53457274e-03]]


In [25]:
# Every row of a rate matrix should sum to 0. Enforce that (up to round-off
# relative to the largest entry) instead of relying on a visual check, then
# display the sums as before.
Q_row_sums = Q_run.sum(axis=1)
np.testing.assert_allclose(
    Q_row_sums,
    0.0,
    atol=1e-9 * np.abs(Q_run).max(),
    err_msg="rows of Q_run do not sum to zero",
)
Q_row_sums

array([ 0.00000000e+00,  0.00000000e+00, -2.16840434e-19,  0.00000000e+00,
        0.00000000e+00])

In [26]:
# Keep the unrounded matrix in `mfpt`, because later cells invert its entries
# and a value that rounds to 0 would cause a division by zero. Round only for
# display.
mfpt = mfpt_all_pairs(Q_run)
mfpt.round()

array([[     0., 349080.,  42706., 141767., 270686.],
       [301352.,      0., 263072., 175866.,  65298.],
       [ 38280., 306374.,      0.,  99062., 227980.],
       [125486., 207313.,  87206.,      0., 128919.],
       [236054.,  78394., 197774., 110568.,      0.]])

In [27]:
# Effective rate between states 0 and 1 in each direction, taken as 1 / MFPT.
for src, dst in ((0, 1), (1, 0)):
    print(f"effective rate {src}->{dst}:", f"{1 / mfpt[src, dst]:.2e}")

effective rate 0->1: 2.86e-06
effective rate 1->0: 3.32e-06


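Very good agreement between the forward and backward directions: the effective 0→1 and 1→0 rates (2.86e-06 and 3.32e-06) differ by only about 15%.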