In [2]:
import sklearn.preprocessing as pre, scipy, numpy as np, matplotlib.pyplot as plt, glob, sys, os
import pandas as pd, seaborn as sns, argparse, pyemma as py, pickle, copy
from sklearn.model_selection import train_test_split

import warnings
warnings.filterwarnings('ignore')

from hde import HDE, analysis

sys.path.insert(1, '../msms/')
from utils import * 

Using TensorFlow backend.


In [None]:
tstep = 100  # ps; presumably the time between saved frames — not used below in this cell

# trajectory subsampling: keep at most max_trajs trajectories and the LAST
# max_frames frames of each (frames are dropped from the start, see the
# `-max_frames:` slice in the loading code)
max_frames = 100000
max_trajs = 25
max_epochs = 10  # for training SRVs

lag = 10  # in trajectory steps; converged for srv, tica not quite
dim = 5   # number of SRV/TICA components kept; 3 captures most pertinent modes

# k-means clustering of the SRV basis
cluster_stride = 100
n_clustercenters = 200

# implied-timescale analysis and MSM estimation
nits = 5        # number of implied timescales passed to py.msm.its
its_lags = 100  # `lags` argument passed to py.msm.its
msm_lag = lag

# each committor boundary bin collects microstates until its summed
# stationary probability exceeds this value
committer_quota = 0.25

for seq in ['CGCATATATAT', 'CCTATATATCC', 'TATAGCGATAT', 'TTTTTTTTTTT']:

    # base labels for this sequence (abasic_configs comes from `from utils import *`)
    base_list = list(abasic_configs[seq].keys())

    SRV_list = []
    TICA_list = []
    srv_indv_list = []
    tica_indv_list = []
    all_dists = []

    # fit an SRV and a TICA model to each base individually
    for base in base_list:

        # locate this base's Tm distance file; sort the matches so the choice does
        # not depend on filesystem listing order, and fail with a clear message
        pattern = f'../abasic_dists/{seq}_msm_add-Tms/{base}*'
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f'no distance file matches {pattern!r}')
        npy_name = matches[0]

        # keep the first max_trajs trajectories and the LAST max_frames frames of each
        # (3-D array; presumably (n_trajs, n_frames, n_distances) — confirm)
        dists = np.load(npy_name)[:max_trajs, -max_frames:, :]
        all_dists.append(dists)

        # both models are fit on inverse distances, one array per trajectory
        SRV = fit_SRV([1/d for d in dists], dim, max_epochs, lag)
        srv_basis = [SRV.transform(1/d) for d in dists]
        srv_indv_list.append(srv_basis)
        SRV_list.append(SRV)

        TICA = py.coordinates.tica([1/d for d in dists], dim=dim, lag=lag)
        tica_basis = TICA.get_output()
        tica_indv_list.append(tica_basis)
        TICA_list.append(TICA)

    committor_list = []
    dtraj_list = []
    cluster_list = []
    its_list = []
    msm_list = []
    H_list, D_list = [], []

    # one committor panel per base; size the grid from base_list rather than a
    # hard-coded 4 so zip() below cannot silently drop bases
    fig, axes = plt.subplots(len(base_list), 1, figsize=(10, 5), sharey=True,
                             squeeze=False)
    for base, base_dists, basis, ax in zip(base_list, all_dists, srv_indv_list, axes[:, 0]):

        print(base, np.shape(base_dists))
        # k-means microstates in this base's SRV basis
        cluster = py.coordinates.cluster_kmeans(basis,
                    stride=cluster_stride, k=n_clustercenters, max_iter=50)
        cluster_list.append(cluster)

        # construct two-state pcca and determine microstate probabilities;
        # memberships and pi are indexed by the MSM's active set, not raw labels
        dtrajs = cluster.dtrajs  # one discrete trajectory per simulation
        msm = py.msm.bayesian_markov_model(dtrajs, lag=msm_lag)
        pcca = msm.pcca(2)
        pi = msm.stationary_distribution
        sorted_micros = np.argsort(msm.metastable_memberships[:, 0])

        # add microstates to each bin until quota is full
        A, pi_sum_A = [], 0
        for i in sorted_micros:
            A.append(i)
            pi_sum_A += pi[i]
            if pi_sum_A > committer_quota:
                break

        # repeat in reverse direction for second bin
        B, pi_sum_B = [], 0
        for i in reversed(sorted_micros):
            B.append(i)
            pi_sum_B += pi[i]
            if pi_sum_B > committer_quota:
                break

        print(A, pi_sum_A)
        print(B, pi_sum_B)
        # committors need disjoint boundary sets
        assert not set(A) & set(B), f'{base}: bins A and B overlap; lower committer_quota'

        # determine which bin is H and D from the mean raw distance over frames in A.
        # A holds active-set indices, so map them to cluster labels before masking.
        base_dists = np.concatenate(base_dists)
        dtraj = np.concatenate(dtrajs)
        A_dtraj = np.isin(dtraj, msm.active_set[A])
        A_mean = np.mean(base_dists[A_dtraj])
        # NOTE(review): cutoff 2 is presumably in the distance file's units (nm?) — confirm
        if A_mean < 2: H, D = A, B
        else: D, H = A, B

        # evaluate committors for all intermediate micros and plot
        qf = msm.committor_forward(H, D)
        qb = msm.committor_backward(H, D)
        committor_list.append(qb)
        ax.scatter(qf, pi)
        ax.set_title(base)
        ax.set_ylabel('$\\pi$')

        # implied timescales from the per-trajectory dtrajs: concatenating them
        # would add spurious transitions between the end of one run and the
        # start of the next
        its = py.msm.its(dtrajs, lags=its_lags, nits=nits)
        # the concatenated dtraj is what gets saved, so keep that format
        dtraj_list.append(dtraj)
        its_list.append(its)
        msm_list.append(msm)
        H_list.append(H)
        D_list.append(D)

    axes[-1, 0].set_xlabel('forward committor $q^+$')
    fig.suptitle(seq)
    fig.subplots_adjust(hspace=0.35)
    
    # add in low temp data and cluster in same set of microstates defined above
    skip_low = 2001  # frames dropped from the start of each low-T trajectory
    dtraj_low_list = []

    # plot to look for differences in the bases: left column Tm, right column low T
    fig, axes = plt.subplots(len(base_list), 2, figsize=(5, 10), sharex=True,
                             sharey=True, squeeze=False)
    pstride = 100
    # finer stride for the low-T scatter; never allow a zero slice step
    pstride_low = max(1, pstride // 10)

    for base, cluster, basis, SRV, ax_row in \
        zip(base_list, cluster_list, srv_indv_list, SRV_list, axes):

        # load low temp data; sorted so the chosen match is deterministic
        pattern = f'../abasic_dists/{seq}_msm_Tm-15/{base}*'
        matches = sorted(glob.glob(pattern))
        if not matches:
            raise FileNotFoundError(f'no low-T distance file matches {pattern!r}')
        npy_name = matches[0]
        dists = np.load(npy_name)[:, skip_low:, :]

        # project inverse distances into the SRV basis fitted at Tm
        srv_low_basis = [SRV.transform(1/d) for d in dists]

        # cluster into pre-defined microstates based on srv basis
        dtraj_low = cluster.assign(srv_low_basis)
        dtraj_low_list.append(dtraj_low)

        # plot Tm SRV distribution compared to low T (first two SRV components)
        srv_plot_tm = np.concatenate(basis)[::pstride]
        srv_plot_low = np.concatenate(srv_low_basis)[::pstride_low]

        ax_row[0].scatter(srv_plot_tm[:, 0], srv_plot_tm[:, 1])
        ax_row[1].scatter(srv_plot_low[:, 0], srv_plot_low[:, 1])
        ax_row[0].set_ylabel(f'{base}\nSRV 2')

    axes[0, 0].set_title('SRV Tm')
    axes[0, 1].set_title('SRV T Low')
    for bottom_ax in axes[-1]:
        bottom_ax.set_xlabel('SRV 1')
    fig.suptitle(seq)

    # TODO: also do some structural comparisons at low temp that we can report

    # pickle outputs for later loading; make sure the output directory exists
    save_name = f'./save_outputs/{seq}_lag-{msm_lag}_dict'
    os.makedirs(os.path.dirname(save_name), exist_ok=True)
    save_dict = {'base_list':base_list,
                 'all_dists':all_dists,
                 'srv_indv_list':srv_indv_list,
                 'tica_indv_list':tica_indv_list,
                 'SRV_list':SRV_list,
                 'committor_list':committor_list,
                 'dtraj_list':dtraj_list,
                 'dtraj_low_list':dtraj_low_list,
                 'msm_list':msm_list,
                 'its_list':its_list,
                 'H_list':H_list,
                 'D_list':D_list,
                 'cluster_list':cluster_list
                }
    # the context manager flushes and closes the file before the next sequence;
    # only unpickle these outputs from trusted locations
    with open(save_name, 'wb') as fh:
        pickle.dump(save_dict, fh)