In [None]:
# AIR: use the mdanalysis environment

import pandas as pd
import numpy as np
import deeptime as dt
import matplotlib.pyplot as plt
import hexagonal_grid
from tqdm import tqdm
import os

In [None]:
# Switch the working directory to the project folder under $HOME so that the
# relative data path used when loading the classification table resolves.
%cd {os.environ['HOME']}/Sync/work_in_progress/sh2-som

In [None]:
# Load the per-frame SOM neuron classification table and derive two columns:
#   RFrame - running frame counter within each replica (0, 1, 2, ...)
#   State  - the neuron assignment shifted down by one, i.e. 0-based
d = pd.read_table("SOM.neuron.classification.dat.xz").assign(
    RFrame=lambda frames: frames.groupby(['Replica']).cumcount(),
    State=lambda frames: frames["Neuron.classif"] - 1,
)

In [None]:
# One discrete trajectory per replica: traj_list is a list of numpy arrays,
# each holding the 0-based State sequence of a single replica.

dg = d.groupby("Replica")
traj_list = [replica_frames["State"].to_numpy() for _, replica_frames in dg]

In [None]:
# Fit one maximum-likelihood MSM per lag time, in the order of `lagtimes`.
# Alternative explicit lag list:
#lagtimes = [1,10,20,30,60,100,200,300]
lagtimes = np.arange(1, 400, 20)

models = []
for lagtime in tqdm(lagtimes):
    # Sliding-window transition counts over all replica trajectories.
    count_estimator = dt.markov.TransitionCountEstimator(lagtime=lagtime, count_mode='sliding')
    counts = count_estimator.fit_fetch(traj_list)

    # Alternative (Bayesian) estimator:
    #mod=dt.markov.msm.BayesianMSM(n_samples=50).fit(counts,ignore_counting_mode=True).fetch()
    msm_estimator = dt.markov.msm.MaximumLikelihoodMSM()
    mod = msm_estimator.fit_fetch(counts)
    models.append(mod)


In [None]:
# Implied-timescale check: plot the 12 leading timescales of the fitted
# models against their lag time.
its_data = dt.util.validation.implied_timescales(models)

fig, ax = plt.subplots()
dt.plots.plot_implied_timescales(its_data, n_its=12, ax=ax)
# Optional: switch the timescale axis to a log scale.
#ax.set_yscale('log')
ax.set(title='Implied timescales', xlabel='lag time (steps)')
ax.set_ylabel('timescale (steps)')



In [None]:
# Hexagonal grid object handed to hexagonal_grid.state_plot by the plotting cells.
# NOTE(review): the (20, 20) dimensions presumably match the SOM lattice - confirm.
grid_dims = (20, 20)
hgrid = hexagonal_grid.hexagonal_grid(*grid_dims)

In [None]:
# Select which fitted model to analyse: presumably the position in `lagtimes`
# of the lag nearest to the target (verify against index_of_closest).
# The bare name on the last line displays the chosen index.
target_lag = 250
chosen_model = hexagonal_grid.index_of_closest(lagtimes, target_lag)
chosen_model

In [None]:
# Coarse-grain the selected MSM into n_macro metastable sets with PCCA.
n_macro = 10
pccamodel = models[chosen_model]
pcca = pccamodel.pcca(n_metastable_sets=n_macro)


In [None]:
# One figure per macrostate: draw each column of the PCCA membership matrix
# on the hexagonal grid.
memberships = pcca.memberships
for macro_idx in range(n_macro):
    hexagonal_grid.state_plot(hgrid, memberships[:, macro_idx])
    plt.show()

In [None]:
# Print the coarse-grained stationary probabilities and plot the matching
# free energies, dG = -kT * ln(p), one bar per macrostate.
# NOTE(review): kT = 0.6 is presumably kcal/mol near room temperature - confirm.
kT = .6
macro_populations = pcca.coarse_grained_stationary_probability
print(macro_populations)

free_energies = -kT * np.log(macro_populations)
plt.bar(range(n_macro), free_energies)
plt.ylabel("dG (kcal/mol)")
plt.xlabel("Macrostate")

In [None]:
# Draw the PCCA assignments on the hexagonal grid and add a colorbar
# so the plotted values can be read off.
macro_assignments = pcca.assignments
hexagonal_grid.state_plot(hgrid, macro_assignments)
plt.colorbar()