In [None]:
import pandas as pd
import numpy as np
import deeptime as dt
import matplotlib.pyplot as plt
import hexagonal_grid
from tqdm import tqdm



In [None]:
# Load the per-frame SOM neuron classification of every replica and derive
# a per-replica frame counter plus a 0-based discrete state label.
d = pd.read_table("SOM.neuron.classification.dat.xz")
d = d.assign(
    # Running frame index inside each replica, counted in file order
    # (assumes a replica's rows are stored in time order — TODO confirm).
    RFrame=d.groupby("Replica").cumcount(),
    # Neuron labels appear to be 1-based; shift to 0-based state indices.
    State=d["Neuron.classif"] - 1,
)

In [None]:
# One discrete trajectory (numpy array of 0-based states) per replica,
# ordered by replica key; these are the inputs to the MSM estimation.
dg = d.groupby("Replica")
traj_list = [frames["State"].to_numpy() for _, frames in dg]

In [None]:
# Lag times (in trajectory steps) at which an MSM is estimated; the resulting
# list of models feeds the implied-timescale plot and the PCCA+ step.
lagtimes = [1, 10, 20, 30, 60, 100, 200, 300]


def estimate_msm(tau):
    """Count transitions of ``traj_list`` at lag ``tau`` and fit a maximum-likelihood MSM.

    Counts use a sliding window. A Bayesian alternative that was tried is
    ``dt.markov.msm.BayesianMSM(n_samples=50).fit(counts, ignore_counting_mode=True).fetch()``
    on the same count model.
    """
    count_model = dt.markov.TransitionCountEstimator(
        lagtime=tau, count_mode='sliding'
    ).fit_fetch(traj_list)
    return dt.markov.msm.MaximumLikelihoodMSM().fit_fetch(count_model)


# models[i] is the MSM estimated at lagtimes[i].
models = [estimate_msm(tau) for tau in tqdm(lagtimes)]


In [None]:
# Implied timescales of every MSM in `models`, plotted against lag time.
its_data = dt.util.validation.implied_timescales(models)

fig, ax = plt.subplots()
# Draw the first 12 implied timescales.
dt.plots.plot_implied_timescales(its_data, n_its=12, ax=ax)
# Linear y-axis; ax.set_yscale('log') is an option if the timescales span decades.
ax.set(
    title='Implied timescales',
    xlabel='lag time (steps)',
    ylabel='timescale (steps)',
)



In [None]:
# Lag time (steps) of the MSM that is coarse-grained with PCCA+. The model is
# selected by lag-time value rather than by list position, so editing
# `lagtimes` cannot silently switch to a different model (a ValueError is
# raised if this lag time is no longer estimated).
PCCA_LAGTIME = 20
# Number of metastable sets requested from PCCA+.
N_METASTABLE_SETS = 6

pccamodel = models[lagtimes.index(PCCA_LAGTIME)]
pcca = pccamodel.pcca(n_metastable_sets=N_METASTABLE_SETS)


In [None]:
# Fuzzy PCCA+ memberships: one row per MSM state (SOM neuron), one column per
# metastable set. Shown as a labelled frame rather than a raw array dump.
memberships = pd.DataFrame(pcca.memberships).add_prefix("set_")
# deeptime documents that every membership row sums to 1; check it.
assert np.allclose(memberships.sum(axis=1), 1.0, atol=1e-6), "membership rows do not sum to 1"
memberships.round(3)

In [None]:
# Hexagonal layout of the SOM; hgrid[:, 0] / hgrid[:, 1] are used as the x/y
# positions of the neurons in the scatter plots below.
SOM_GRID_SHAPE = (20, 20)
hgrid = hexagonal_grid.hexagonal_grid(*SOM_GRID_SHAPE)
# Each MSM state is coloured onto one grid point, so the grid must have exactly
# one point per state (fails if, e.g., some neurons never appear in the data).
assert len(hgrid) == pcca.memberships.shape[0], (
    f"hexagonal grid has {len(hgrid)} points but the MSM has "
    f"{pcca.memberships.shape[0]} states"
)

In [None]:
# Membership of every SOM neuron in metastable set 0, drawn on the hexagonal grid.
fig, ax = plt.subplots()
points = ax.scatter(hgrid[:, 0], hgrid[:, 1], s=100, c=pcca.memberships[:, 0])
fig.colorbar(points, ax=ax, label="membership")
ax.set_aspect("equal")  # keep the hexagonal layout undistorted
ax.set_title("PCCA+ membership, metastable set 0");

In [None]:
# Set-0 memberships as a function of state (SOM neuron) index; the dashed line
# marks the 0.5 threshold used to select states in the next cell.
fig, ax = plt.subplots()
ax.plot(pcca.memberships[:, 0])
ax.axhline(0.5, color="grey", linestyle="--", linewidth=1)
ax.set(
    xlabel="state (SOM neuron index)",
    ylabel="membership in set 0",
    title="PCCA+ membership, metastable set 0",
);

In [None]:
# Indices of the states whose membership in metastable set 0 exceeds the
# threshold. numpy discourages single-argument np.where; flatnonzero returns
# the 1-D index array directly instead of a 1-tuple of arrays.
MEMBERSHIP_THRESHOLD = 0.5
np.flatnonzero(pcca.memberships[:, 0] > MEMBERSHIP_THRESHOLD)

In [None]:
# Membership of every SOM neuron in metastable set 5, drawn on the hexagonal grid.
fig, ax = plt.subplots()
points = ax.scatter(hgrid[:, 0], hgrid[:, 1], s=100, c=pcca.memberships[:, 5])
fig.colorbar(points, ax=ax, label="membership")
ax.set_aspect("equal")  # keep the hexagonal layout undistorted
ax.set_title("PCCA+ membership, metastable set 5");

In [None]:
# PCCA+ coarse-grained stationary distribution (one entry per metastable set).
cg_stationary = pcca.coarse_grained_stationary_probability
cg_stationary

In [None]:
# Crisp PCCA+ assignment (index of the metastable set) of every neuron, drawn
# on the hexagonal grid by the project helper hexagonal_grid.state_plot.
hexagonal_grid.state_plot(hgrid, pcca.assignments)
# NOTE(review): plt.colorbar() needs state_plot to leave a current mappable — confirm.
plt.colorbar(label="metastable set")
plt.title("PCCA+ crisp assignment");