# Estimating a dynamic random field using the graphtime library

Simon Olsson 2018, simon.olsson / at / fu-berlin.de or [@smnlssn](http://www.twitter.com/smnlssn)

In this notebook we illustrate the use of the graphtime library to estimate dMRFs. The notebook reproduces Fig. 2 of our manuscript.

We cover:
 - the generation of simulation data, here with the Ising model
 - the estimation of dMRFs 
 - rudimentary visualization of results
 
Imports and some function definitions

In [None]:
%matplotlib inline
from matplotlib import pyplot as plt
import matplotlib as mpl
import numpy as np
import matplotlib.gridspec as gridspec

from sklearn.preprocessing import LabelBinarizer

# Global matplotlib font settings: Arial from the sans-serif family, size 8
font = {
    'sans-serif': "Arial",
    'family': "sans-serif",
    'size': 8,
}

mpl.rc('font', **font)

from graphtime import markov_random_fields
from graphtime import utils as _ut
from graphtime import ising_utils as ising

import msmtools as mt
import pyemma as pe

In [None]:
def is_positive_state(x):
    """
        Flag every row of `x` whose mean over axis 1 is larger than -0.01.

        Arguments:
        ---------------
        x (ndarray) : two-dimensional array, one configuration per row

        Returns a boolean array with one entry per row of `x`.
    """
    row_means = x.mean(axis=1)
    return row_means > -0.01

def generate_biased_data(Tmat, ntrajs, subsys_configurations, 
                         maxlen=1000, minlen=10, truncation_condition = is_positive_state):
    """
        Generate a biased data-set by truncating simulated trajectories before they reach a frame for 
        which `truncation_condition` evaluates to True.
        
        Candidate trajectories are drawn with `_ut.simulate_MSM(Tmat, maxlen)` until `ntrajs` of them have 
        been accepted:
          * if no frame satisfies the condition, the full trajectory is accepted;
          * otherwise the trajectory is cut at index `first_hit - 1` (so the frame directly preceding the 
            first hit is dropped as well) and accepted only if that cut index is larger than `minlen`.
        
        Arguments:
        ---------------
        Tmat (ndarray) : N times N matrix of transition probabilities (a Markov state model)
        ntrajs (int) : number of trajectories to generate
        subsys_configurations (ndarray) : N times K matrix mapping each Markov state in `Tmat` to sub-system 
                                            configurations
        maxlen (int=1000) : maximum length (steps) of a trajectory
        minlen (int=10) : truncated trajectories are kept only if they are longer than this many steps
        truncation_condition (function) : a callable which is given the configurations of a whole trajectory 
                                            (one K dimensional row per frame) and returns one boolean per frame.
        
        Returns:
        ---------------
        list of `ntrajs` arrays, each holding the rows of `subsys_configurations` visited by one trajectory.
    """
    accepted = []
    while len(accepted) < ntrajs:
        states = _ut.simulate_MSM(Tmat, maxlen)
        # indices of all frames for which the truncation condition holds
        hits = np.nonzero(truncation_condition(subsys_configurations[states]))[0]

        if len(hits) == 0:
            # condition never met: keep the whole trajectory
            accepted.append(subsys_configurations[states[:]])
            continue

        cut = hits[0] - 1
        if cut > minlen:
            accepted.append(subsys_configurations[states[:cut]])

    return accepted

def featurize(X):
    """
        Convert trajectories of spin configurations into trajectories of integer state labels.
        
        Every frame is read as a binary number: -1 entries are first rewritten as 0, the entries are 
        concatenated into a string of digits and that string is parsed in base 2.
        
        NOTE: only the container is copied (`X.copy()`). If `X` is a list of arrays the copy is shallow, 
        so the -1 -> 0 rewrite is applied in place to the caller's arrays.
        
        Arguments:
        ---------------
        X (list of ndarray or ndarray) : trajectories, one configuration per row
        
        Returns a list with one list of integer labels per trajectory.
    """
    state_trajs = []
    for traj in X.copy():
        traj[np.where(traj == -1)] = 0
        labels = []
        for frame in traj:
            bits = ''.join(str(spin) for spin in frame)
            labels.append(int(bits, 2))
        state_trajs.append(labels)
    return state_trajs

Initiate reference Ising model

In [None]:
# number of spins (sub-systems) of Ising model to simulate;
# the figure cell draws a reference line at -0.111 (= -1/9), which presumably assumes nspin = 9
nspin = 9

# generate reference transition matrix (discretized Glauber rate model);
# used below both to simulate data and as the ground truth the estimates are compared against
ising_tmatrix = ising.Ising_tmatrix(nspin)

# generate map from Markov state to Ising system configuration, global to local encoding [-1,1];
# row i is the spin configuration of Markov state i (it is indexed with simulated state trajectories below)
ising_configurations = np.array(ising.all_Ising_states(nspin))


## Generate data
Generates 16 data-sets of equilibrium and non-equilibrium data. The non-equilibrium data only see global states with a net-negative Ising magnetization. The Ising magnetization is simply the average of the local state encoding. Conversely, for the equilibrium data all states are in principle allowed. All simulations are initialized in an all-negative configuration.  

In [None]:
# 16 independent non-equilibrium data-sets, each made of 300 truncated trajectories
# simulated for at most 2000 steps
noneq_data = [
    generate_biased_data(ising_tmatrix, 300, ising_configurations, maxlen=2000)
    for _ in range(16)
]

# Total number of frames in each non-equilibrium data-set, reduced by a constant offset.
# NOTE(review): the offset 2*(len(noneq_data)-1) depends on the number of data-sets,
# not on the number of trajectories inside `bd` -- confirm this is the intended correction.
nframes_neqd = [
    len(np.vstack(bd)) - 2*(len(noneq_data)-1)
    for bd in noneq_data
]

# One equilibrium data-set per non-equilibrium data-set: a single unbiased trajectory
# of matching size, wrapped in a one-element list
eq_data = [
    [ising_configurations[_ut.simulate_MSM(ising_tmatrix, nf)]]
    for nf in nframes_neqd
]

## Estimate models

Estimate dMRFs for equilibrium data-sets

In [None]:
# One fitted model per equilibrium data-set, without and with local fields
eq_data_dmrfs_no_intercepts = []
eq_data_dmrfs = []

# L1-penalised logistic-regression settings; the two variants differ only in `fit_intercept`
logreg_kwargs_no_intercepts = dict(fit_intercept=False,
                                   penalty='l1', C=50., tol=1e-4, solver='saga')

logreg_kwargs_intercepts = dict(logreg_kwargs_no_intercepts, fit_intercept=True)

for _gen_data in eq_data:
    # model with the fields (bias) forced to zero
    eq_data_dmrfs_no_intercepts.append(
        markov_random_fields.estimate_dMRF(
            _gen_data,
            lag=1,
            logistic_regression_kwargs=logreg_kwargs_no_intercepts,
            Encoder=LabelBinarizer(neg_label=-1, pos_label=1),
        )
    )

    # model including the fields (bias)
    eq_data_dmrfs.append(
        markov_random_fields.estimate_dMRF(
            _gen_data,
            lag=1,
            logistic_regression_kwargs=logreg_kwargs_intercepts,
            Encoder=LabelBinarizer(neg_label=-1, pos_label=1),
        )
    )


Estimate dMRFs for non-equilibrium data-sets

In [None]:
# One fitted model per non-equilibrium data-set, without and with local fields
neq_data_dmrfs_no_intercepts = []
neq_data_dmrfs = []

for bd in noneq_data:
    # model with the fields (bias) forced to zero
    neq_data_dmrfs_no_intercepts.append(
        markov_random_fields.estimate_dMRF(
            bd,
            lag=1,
            logistic_regression_kwargs=logreg_kwargs_no_intercepts,
            Encoder=LabelBinarizer(neg_label=-1, pos_label=1),
        )
    )

    # model including the fields (bias)
    neq_data_dmrfs.append(
        markov_random_fields.estimate_dMRF(
            bd,
            lag=1,
            logistic_regression_kwargs=logreg_kwargs_intercepts,
            Encoder=LabelBinarizer(neg_label=-1, pos_label=1),
        )
    )

## Reconstruct transition matrices from estimated dMRFs

In [None]:
# Transition matrices implied by the models estimated on equilibrium data
tmats_eq_no_intercepts = [model.generate_transition_matrix()
                          for model in eq_data_dmrfs_no_intercepts]
tmats_eq = [model.generate_transition_matrix()
            for model in eq_data_dmrfs]

# Transition matrices implied by the models estimated on non-equilibrium data
tmats_neq_no_intercepts = [model.generate_transition_matrix()
                           for model in neq_data_dmrfs_no_intercepts]
tmats_neq = [model.generate_transition_matrix()
             for model in neq_data_dmrfs]


In [None]:
# Entries 1-3 of the timescales of every zero-field dMRF transition matrix
# (entry 0 is skipped)
ts_eq = [mt.analysis.timescales(tmat)[1:4]
         for tmat in tmats_eq_no_intercepts]
ts_neq = [mt.analysis.timescales(tmat)[1:4]
          for tmat in tmats_neq_no_intercepts]

In [None]:
# Non-reversible MSM at lag 1 for every non-equilibrium data-set, estimated on the
# integer-encoded global states.
# NOTE: each data-set is a list of arrays, so `.copy()` is a shallow copy and `featurize`
# rewrites -1 as 0 in place: after this cell `noneq_data` holds 0/1 encoded spins.
MSM = [
    pe.msm.estimate_markov_model(featurize(dataset.copy()), lag=1, reversible=False)
    for dataset in noneq_data
]

## Build the figure (Fig. 2 of the manuscript)

In [None]:
double_column_width = 6.968
single_column_width= 3.307


def _magnetization(frames):
    """Return the mean spin of every frame (row) of `frames`.

    Works for both the -1/1 and the 0/1 spin encoding: zeros are mapped to -1 before
    averaging. The non-equilibrium data may have been recoded in place to 0/1 by
    `featurize`, so the figure must not depend on which encoding is currently stored.
    """
    frames = np.asarray(frames)
    return np.where(frames == 0, -1, frames).mean(axis=1)


fig = plt.figure(figsize=(single_column_width, 1.6*single_column_width))

# Layout grid: rows 0-50 hold panel A, rows 60-70 the legend, rows 70-170 the 2x2 panels B-E
gs = gridspec.GridSpec(170, 100)
ax = np.array([[plt.subplot(gs[70+i*50:70+(i+1)*50, j*50:(j+1)*50]) for i in range(2)] for j in range(2)]).T

# ---- Panel B: couplings estimated from non-equilibrium (x) vs equilibrium (y) data ----
neq_couplings = [_dmrf.get_subsystem_couplings().ravel() for _dmrf in neq_data_dmrfs_no_intercepts]
eq_couplings = [_dmrf.get_subsystem_couplings().ravel() for _dmrf in eq_data_dmrfs_no_intercepts]

ax[0,0].grid(True)
ax[0,0].plot([-0.5,4], [-0.5,4], color='k', ls = '--', lw=0.5)
ax[0,0].errorbar(np.mean(neq_couplings, axis=0),
                 np.mean(eq_couplings, axis=0),
                 xerr=np.std(neq_couplings, axis=0),
                 yerr=np.std(eq_couplings, axis=0), fmt=".")
ax[0,0].set_xlim([-0.5,4])
ax[0,0].set_ylim([-0.5,4])
# the arrow points at the largest coupling of the first model of each kind
ax[0,0].annotate("Self coupl.",
                 xy=(neq_couplings[0].max(), eq_couplings[0].max()),
                 xytext=(0, 0.8*neq_couplings[0].max()),
                 arrowprops=dict(arrowstyle="->"))
ax[0,0].set_xlabel(r'NED $J_{ij}$', fontsize=8)
ax[0,0].set_ylabel(r'ED $J_{ij}$', fontsize=8)


# ---- Panel C: distribution of the magnetization ----
# `density=True` replaces the `normed=True` keyword, which newer matplotlib no longer accepts.
state_magnetization = ising_configurations.mean(axis=1)

ax[0,1].hist(state_magnetization,
             weights = mt.analysis.statdist(np.mean(tmats_eq_no_intercepts, axis=0)),
             density=True, histtype='step', bins=10, log = False, label='dMRF ED')

ax[0,1].hist(state_magnetization,
             weights = mt.analysis.statdist(np.mean(tmats_neq_no_intercepts, axis=0)),
             density=True, histtype='step', bins=10, log = False, label='dMRF NED')

ax[0,1].hist(np.concatenate([state_magnetization[m.active_set] for m in MSM]),
             weights = np.concatenate([m.stationary_distribution for m in MSM]),
             density=True, histtype='step', bins=5, log = False, label='MSM NED')

ax[0,1].hist(state_magnetization,
             weights = mt.analysis.statdist(ising_tmatrix),
             density=True, histtype='step', bins=10, log = False, label='True')

# Raw data histograms use the last data-set of each kind. This is the data the original
# cell picked up through loop variables leaked from the estimation cells; it is now
# referenced explicitly so the cell does not depend on leftover kernel state.
ax[0,1].hist(_magnetization(np.vstack(noneq_data[-1])),
             density=True, histtype='step', bins=5, log = False, label='NED', ls='--')

ax[0,1].hist(np.vstack(eq_data[-1]).mean(axis=1),
             density=True, histtype='step', bins=10, log = True, label='ED', ls='--')

ax[0,1].set_ylim([2e-1, 5])
ax[0,1].set_xlabel(r'$\langle M \rangle$', fontsize=8)
ax[0,1].set_ylabel(r'$p(\langle M \rangle$)', fontsize=8)

# ---- Panel D: true transition probabilities vs those of the NED dMRFs ----
# NOTE(review): the original plotted `tmats_eq_no_intercepts` under the 'NED dMRF' axis
# label. The data now matches the label; confirm against the manuscript figure.
avgp = np.mean(tmats_neq_no_intercepts, axis=0).ravel()
lo,up = mt.util.statistics.confidence_interval([d.ravel() for d in tmats_neq_no_intercepts])

ax[1,0].errorbar(ising_tmatrix.ravel(),
                 avgp,
                 yerr=(avgp - lo, up - avgp),
                 fmt='.')

ax[1,0].plot([-0,1],[-0,1], lw=0.5, ls='--', color='k')
ax[1,0].set_xlim([-0,1])
ax[1,0].set_ylim([-0,1])
ax[1,0].set_xlabel(r'True $T_{ij}$', fontsize=8 )
ax[1,0].set_ylabel(r'NED dMRF $T_{ij}$', fontsize=8 )


# ---- Panel E: three timescales per estimator, mean and confidence interval over data-sets ----
w=0.27
avgp = np.mean(ts_eq, axis=0)
lo,up = mt.util.statistics.confidence_interval(ts_eq)
ax[1, 1].bar(np.arange(1, 4) - w, avgp, label = 'ED dMRF', yerr = (avgp - lo, up - avgp),
             width = w, fill = False, edgecolor = "C0", ecolor = "C0")

avgp = np.mean(ts_neq, axis=0)
lo,up = mt.util.statistics.confidence_interval(ts_neq)
ax[1, 1].bar(np.arange(1, 4), avgp, label = 'NED dMRF', yerr = (avgp - lo, up - avgp),
             width = w, fill = False, edgecolor = "C1", ecolor = "C1" )

# The interval must be taken over the per-model timescales, like for the dMRFs above;
# the original computed it over the three mean values, giving meaningless error bars.
msm_ts = [m.timescales(k = 3) for m in MSM]
avgp = np.mean(msm_ts, axis = 0)
lo,up = mt.util.statistics.confidence_interval(msm_ts)
ax[1, 1].bar(np.arange(1, 4) + w, avgp, yerr = (avgp - lo, up - avgp), label = 'NED MSM' ,
             width = w, fill = False, edgecolor = "C2", ecolor = "C2")


# reference: timescales of the true transition matrix, drawn across each group of bars
for _i,_ts in enumerate(mt.analysis.timescales(ising_tmatrix)[1:4]):
    ax[1,1].hlines(_ts, (_i+1)-w, (_i+1)+w , label='True', color='red')

ax[1, 1].set_xlabel('Process', fontsize=8)
ax[1, 1].set_ylabel('Time-scale / step', fontsize=8)
ax[1, 1].set_xticks(range(1,4))  # one tick per plotted process
ax[1, 1].set_xlim([1-w*2.2,3+w*2.2])

# panel letters for the 2x2 block
for a,lbl in zip(ax.ravel(), ('B', 'C', 'D', 'E')):
    a.text(-0.5, 1.15, lbl, transform=a.transAxes,
           fontsize=12, va='top')

# ---- Shared legend, built from the step-histogram artists of panel C ----
legend_ax = plt.subplot(gs[60:70,:])
hist_patches = [child for child in ax[0,1].get_children() if isinstance(child, mpl.patches.Polygon)]
legend_ax.legend(hist_patches,
                 [patch.get_label() for patch in hist_patches],
                 loc=(0.12,1.1), ncol=2)
legend_ax.axis('off')

# ---- Panel A: example magnetization traces ----
ax2 = []
for i in range(2):
    if i==0:
        ax2.append(plt.subplot(gs[:50, i*50:(i+1)*50]))
    else:
        ax2.append(plt.subplot(gs[:50, i*50:(i+1)*50], sharey=ax2[0]))

ax2[0].plot(eq_data[0][0].mean(axis=1), color='darkorchid', alpha=1, lw=0.5)
ax2[0].set_ylabel(r'$\langle M \rangle$')
ax2[0].set_xlabel(r'time / step')
ax2[1].set_xlabel(r'time / step')
ax2[0].set_title('Equilibrium data')
ax2[1].set_title('Non-equilibrium data')
for traj in noneq_data[0]:
    ax2[1].plot(_magnetization(traj), color='darkorchid', alpha=0.1, lw=0.5)
# -0.111 is approximately -1/9; presumably the largest net-negative magnetization for 9 spins
ax2[1].plot([0,1000], [-0.111,-0.111], color='k', ls=':')
ax2[1].annotate("Max. net mag.", xy=(80, -0.111), xytext=(80,0.2),
                arrowprops=dict(arrowstyle="->"))
ax2[1].set_xlim((-20,1000))
ax2[1].set_ylim(ax2[0].get_ylim())


ax2[0].text(-0.5, 1.11, "A", transform=ax2[0].transAxes,
            fontsize=12, va='top')

fig.tight_layout()

In [None]:
# Write the figure assembled in the previous cell to 'Fig1.png' (relative path) at 300 dpi
fig.savefig('Fig1.png', dpi=300)