In [None]:
%matplotlib inline

import numpy as np
from argparse import Namespace
import plotly.express as px
import matplotlib.pyplot as plt
import time
import yaml
import copy
import pandas as pd
import pickle

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from vida.data_processing.load_data import *
from vida.data_processing.convertor import *
from vida.data_processing.misc import *
from vida.data_processing.split_data import *

from vida.model.scatter_transform import transform_dataset, get_normalized_moments

import networkx as nx

from sklearn.decomposition import PCA
import phate

import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import matplotlib.pyplot as plt

# Load CTMC data

In [None]:
# 1k trajectories, each have multiple states
# NOTE: pickle.load can execute arbitrary code -- only open trusted files here.
with open('./data/jordan/traj_data.pkl', 'rb') as fp:
    trajs = pickle.load(fp)
print('loaded trajectories dictionary')

## Data I
# keys = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,15,17,18,19,21,22,25,27,30,31,33,34,37,39,41,42]   
# trajs_states = []; trajs_times = []; trajs_energies = []
# num = 2
# for k in keys:
#     trajs_states += trajs[k]['trajs_states'][:num]
#     trajs_times += trajs[k]['trajs_times'][:num]
#     trajs_energies += trajs[k]['trajs_energies'][:num]
    
# len(trajs_states), len(trajs_times), len(trajs_energies)

## Data II
# All three lists must come from the SAME reaction key: they are indexed in
# lock-step further down (trajectory i -> states[i], times[i], energies[i]).
k = 27
trajs_states = trajs[k]['trajs_states']
trajs_times = trajs[k]['trajs_times']
trajs_energies = trajs[k]['trajs_energies']

# one entry per trajectory in each list
assert len(trajs_states) == len(trajs_times) == len(trajs_energies)

len(trajs_states), len(trajs_times), len(trajs_energies)

In [None]:
# merge the strands of every structure into one separator-free string
def concat_helix_structures(dp):
    """Strip the strand separators from dot-parenthesis structures.

    Args:
        dp: sequence of dot-parenthesis strings. It is deep-copied first, so
            the caller's object is never modified.
    Returns:
        tuple of two numpy arrays:
            - the structures with every "&" and "+" removed
            - one flag per separator kind found, in encounter order
              (0 for "&", 1 for "+"); a structure without a separator
              contributes no flag
    """
    stripped = copy.deepcopy(dp)
    pair_flags = []

    for pos in range(len(stripped)):
        # "&" is handled before "+", each at most once per structure
        for separator, flag in (("&", 0), ("+", 1)):
            if separator in stripped[pos]:
                stripped[pos] = stripped[pos].replace(separator, "")
                pair_flags.append(flag)

    return np.array(stripped), np.array(pair_flags)


# flatten all trajectories into long per-state arrays
def concat_all(states, times, energies):
    """Concatenate per-trajectory records into flat arrays.

    Args:
        states: list of trajectories, each a sequence of dot-parenthesis strings
        times: per-trajectory time stamps, indexed like `states`
        energies: per-trajectory energies, indexed like `states`
    Returns:
        tuple (SIMS_dp, SIMS_dp_og, SIMS_pair, SIMS_G, SIMS_T) of numpy arrays:
        separator-free structures, original structures, separator flags,
        energies and times, each concatenated over all trajectories.
    """
    collected = {"dp": [], "dp_og": [], "pair": [], "G": [], "T": []}

    for traj_no, traj_states in enumerate(states):
        merged, flags = concat_helix_structures(traj_states)

        collected["dp"].append(merged)
        collected["dp_og"].append(traj_states)
        collected["pair"].append(flags)
        # times/energies are looked up by trajectory index, as in `states`
        collected["G"].append(energies[traj_no])
        collected["T"].append(times[traj_no])

    return tuple(np.concatenate(collected[key])
                 for key in ("dp", "dp_og", "pair", "G", "T"))

In [None]:
# flatten every trajectory into one long array per quantity
SIMS_dp, SIMS_dp_og, SIMS_pair, SIMS_G, SIMS_T = concat_all(trajs_states, trajs_times, trajs_energies)
# display the resulting shapes as the cell output
SIMS_dp.shape, SIMS_dp_og.shape, SIMS_pair.shape, SIMS_G.shape, SIMS_T.shape

In [None]:
# remove duplicate data: keep the first occurrence of every distinct structure
indices_S = np.unique(SIMS_dp, return_index=True)[1]

# the same row selection is applied to every per-state array
SIMS_dp_uniq, SIMS_dp_og_uniq, SIMS_pair_uniq, SIMS_G_uniq = (
    full_array[indices_S]
    for full_array in (SIMS_dp, SIMS_dp_og, SIMS_pair, SIMS_G)
)

SIMS_dp_uniq.shape, SIMS_pair_uniq.shape, SIMS_G_uniq.shape

In [None]:
# find index to recover to all data from unique data
# np.unique returns the sorted distinct values; SIMS_dp_uniq was selected with
# np.unique's return_index on the same array, so it has that same order and the
# "inverse" indices are exactly the row of SIMS_dp_uniq for every state.
# This replaces one full-array comparison per unique state with a single pass.
coord_id_S = np.unique(SIMS_dp, return_inverse=True)[1].reshape(-1).astype(int)

# the unique arrays indexed by coord_id_S must rebuild the full data
assert np.array_equal(SIMS_dp_uniq[coord_id_S], SIMS_dp)

coord_id_S.shape, coord_id_S.max()

In [None]:
# consistency check: count the positions where the arrays rebuilt from the
# unique data match the originals (both counts should equal len(SIMS_dp))
print(np.where(SIMS_G_uniq[coord_id_S] == SIMS_G)[0].shape)
print(np.where(SIMS_dp_uniq[coord_id_S] == SIMS_dp)[0].shape)

#### Adjacency matrix

In [None]:
# build the adjacency matrix of one dot-parenthesis structure
def dot2adj(db_str,hairpin=False,helix=True):
    """Convert a dot-bracket string to a dense adjacency matrix.

    Args:
        db_str (str): N-character dot-bracket string
        hairpin (bool): accepted but currently has no effect
        helix (bool): when True the length must be even and the edge between
            the two middle positions is removed unless those two characters
            are "()" (presumably the junction between the two strands)

    Returns:
        [np array]: NxN adjacency matrix with 1 for every edge
    """
    n_nodes = len(str(db_str))

    # (2, E) index array holding every edge in both directions
    edges = symmetrized_edges(dot2pairs(db_str))

    adj_mat = np.zeros((n_nodes, n_nodes))
    adj_mat[edges[0], edges[1]] = 1

    if helix == True:
        assert n_nodes % 2 == 0, "Not a valid helix sequence."
        mid = int(np.ceil(n_nodes / 2))

        # keep the middle edge only when the two middle bases pair with each other
        if db_str[mid-1:mid+1] != "()":
            adj_mat[mid-1, mid] = 0
            adj_mat[mid, mid-1] = 0

    return adj_mat

def dot2pairs(dp_str):
    """List the edges encoded by a dot-bracket string.

    The string is scanned once from left to right. Positions of "(" are
    pushed on a stack and positions of ")" are queued; as soon as both are
    non-empty the most recent "(" is paired with the oldest queued ")".
    For a balanced string this is ordinary bracket matching. A ")" seen
    while no "(" is open stays queued and is paired with the next "(".

    Args:
        dp_str (str): dot bracket string (ex. "((..))")

    Returns:
        [list]: (i, j) tuples -- the base pairs in the order they are closed,
        followed by the backbone edges (k, k+1) for k = 0 .. len-2
    """
    dim = len(str(dp_str))

    open_stack = []      # indices of "(" still waiting for a partner
    pending_close = []   # indices of ")" not yet matched
    pair_list = []

    for indx in range(dim):
        symbol = dp_str[indx]
        if symbol == "(":
            open_stack.append(indx)
        elif symbol == ")":
            pending_close.append(indx)

        # pair as soon as an opening and a closing index are both available
        if open_stack and pending_close:
            pair_list.append((open_stack.pop(), pending_close.pop(0)))

    # backbone: consecutive positions are always connected (a path graph);
    # no graph object is needed just to enumerate them
    backbone_pairs = [(k, k + 1) for k in range(dim - 1)]

    return pair_list + backbone_pairs

def symmetrized_edges(pairs_list):
    """Return every edge in both directions as a (2, 2*E) index array.

    Args:
        pairs_list: sequence of (i, j) edge tuples
    Returns:
        numpy array whose first row holds source indices and second row
        target indices: the given edges followed by their reversed copies.
    """
    forward = np.array(pairs_list)

    # reversed copy: swap the two endpoint columns
    backward = forward.copy()
    backward[:, 0] = forward[:, 1]
    backward[:, 1] = forward[:, 0]

    return np.vstack((forward, backward)).T

# adjacency matrices for a whole list of dot-parenthesis structures
def sim_adj_new(dps):
    """Convert each dot-parenthesis string to its adjacency matrix.

    Args:
        dps: iterable of dot-parenthesis strings
            eg. ['...............', ...]
    Returns:
        numpy array stacking one adjacency matrix (from `dot2adj`) per
        structure, in input order
    """
    return np.array([dot2adj(structure) for structure in dps])

In [None]:
# adjacency matrices are computed once per unique structure ...
SIMS_adj_uniq = sim_adj_new(SIMS_dp_uniq)
print(SIMS_adj_uniq.shape)

# ... and expanded back to every visited state through coord_id_S
SIMS_adj = SIMS_adj_uniq[coord_id_S]
print(SIMS_adj.shape)

#### Holding time

In [None]:
def sim_ht_new(sim_T, sim_dp):
    """Compute holding times and trajectory end indices.

    A new trajectory is assumed to start wherever the time stamp equals 0.
    Inside a trajectory the holding time of a state is the difference to the
    next time stamp; the last state of each trajectory gets holding time 0.
    Entries before the first zero time stamp (if any) produce no output.

    Args:
        sim_T: 1-D array of time stamps of all concatenated trajectories
        sim_dp: 1-D array of the structures visited, aligned with `sim_T`
    Returns:
        sim_HT: float array of holding times
        trj_id: indices at which `sim_dp` equals its last entry, i.e. the
            positions of the final state (used as trajectory end markers)
    """
    sim_T = np.asarray(sim_T)
    sim_dp = np.asarray(sim_dp)

    # [start, stop) index range of every trajectory
    starts = np.where(sim_T == 0)[0]
    stops = np.append(starts[1:], len(sim_T))

    # per trajectory: consecutive differences, then 0 for the final state
    segments = [np.append(np.diff(sim_T[lo:hi]), 0) for lo, hi in zip(starts, stops)]
    # the leading empty float array keeps the result float (and valid when
    # there are no segments at all)
    sim_HT = np.concatenate([np.empty(0)] + segments)

    # get each individual trajectory's index -- from the argument, not from a
    # notebook-level variable
    trj_id = np.where(sim_dp == sim_dp[-1])[0]

    return sim_HT, trj_id

In [None]:
# holding time of every visited state and the index of each trajectory's end
SIMS_HT, trj_id = sim_ht_new(SIMS_T, SIMS_dp)
print(SIMS_HT.shape)
print(trj_id.shape)

# get unique holding time of unique states
# NOTE(review): mean_holdingtime is not defined in this notebook (it comes from
# one of the `vida.data_processing` star imports); presumably it averages the
# holding times of all visits of each unique state -- confirm in the package.
SIMS_HT_uniq = mean_holdingtime(SIMS_HT, indices_S, coord_id_S)
print(SIMS_HT_uniq.shape)

#### Scattering coefficient

In [None]:
# # Multiple trajectories
# scattering features are computed once per unique adjacency matrix;
# transform_dataset / get_normalized_moments come from vida.model.scatter_transform
scat_coeff_array_S = transform_dataset(SIMS_adj_uniq)
# squeeze() drops any length-1 axes from the returned moments
SIMS_scar_uniq = get_normalized_moments(scat_coeff_array_S).squeeze()

In [None]:
# expand the per-unique-state features to every visited state
SIMS_scar = SIMS_scar_uniq[coord_id_S]
SIMS_scar.shape

### Save data for training

In [None]:
# # save npz file for shortest path
# with open('./data/jordan/pretraining/reaction_27.npz', 'wb') as f:
#     np.savez(f,
#              SIMS_dp_uniq = SIMS_dp_uniq,
#              SIMS_dp_og_uniq = SIMS_dp_og_uniq,
#              SIMS_pair_uniq = SIMS_pair_uniq,
#              SIMS_G_uniq = SIMS_G_uniq,
#              SIMS_T = SIMS_T,  
#              SIMS_HT = SIMS_HT,
#              SIMS_HT_uniq = SIMS_HT_uniq,
#              SIMS_adj_uniq = SIMS_adj_uniq,
#              SIMS_scar_uniq = SIMS_scar_uniq,
#              trj_id = trj_id,
#              indices_S = indices_S,
#              coord_id_S = coord_id_S,
#              )


# # load data
# NOTE: this overwrites any same-named variables computed by the cells above
# with the cached versions stored in the archive.
fnpz_data = "./data/jordan/pretraining/reaction_27.npz"
# the context manager closes the archive once all arrays have been read
with np.load(fnpz_data) as data_npz:
    # asssign data to variables: every stored array becomes a notebook-level
    # variable of the same name (globals() is the explicit module namespace)
    for var in data_npz.files:
        globals()[var] = data_npz[var]
        print(var, globals()[var].shape)
# recover full data based on coord_id, indices, and unique data
SIMS_dp = SIMS_dp_uniq[coord_id_S]
SIMS_dp_og = SIMS_dp_og_uniq[coord_id_S]
SIMS_pair = SIMS_pair_uniq[coord_id_S]
SIMS_G = SIMS_G_uniq[coord_id_S]
SIMS_adj = SIMS_adj_uniq[coord_id_S]
SIMS_scar = SIMS_scar_uniq[coord_id_S]

# Construct graph

### Construct weights (expected holding time)

In [None]:
# summary of the per-state holding times: max, smallest non-zero value, std, shape
print(SIMS_HT_uniq.max(), SIMS_HT_uniq[SIMS_HT_uniq!=0].min(), SIMS_HT_uniq.std(), SIMS_HT_uniq.shape)

### Construct edges

In [None]:
# node id of every visited state, in visiting order (ids index the unique arrays)
all_nodes = coord_id_S
print("Initial node:", all_nodes[0], "  Final node:", all_nodes[-1])

In [None]:
# pairwise nodes: every pair of consecutive entries of all_nodes is one edge
# note: the trajectories are stored back to back, so this also links the
# last state of one trajectory to the first state of the next one
# (final->initial node); those edges are removed afterwards
all_edges_temp = list(zip(all_nodes, all_nodes[1:]))

In [None]:
# remove all edges that connect final->initial node
wraparound_edge = (all_nodes[-1], all_nodes[0])
all_edges = [edge for edge in all_edges_temp if wraparound_edge != edge]

In [None]:
# sanity check: the pruned list must not contain a final->initial edge
if (all_nodes[-1],all_nodes[0]) in all_edges:
    print("yes")
else:
    print(f"There are no edges from final node {all_nodes[-1]} to initial node {all_nodes[0]}")
    
# all_edges_temp is the list BEFORE pruning, all_edges the list AFTER pruning
print("before prune: ", len(all_edges_temp))
print("after prune: ",len(all_edges))
print("difference: ", len(all_edges_temp) - len(all_edges))

### Construct modified undirected weight graph

In [None]:
# undirected graph over unique states; every observed transition is an edge
MUG = nx.Graph()

for src, dst in all_edges:
    ht_src = SIMS_HT_uniq[src]
    ht_dst = SIMS_HT_uniq[dst]

    if ht_src == 0 or ht_dst == 0:
        # a zero holding time would give a zero-weight edge: use the sum,
        # i.e. the holding time of the other endpoint
        weight = ht_src + ht_dst
    else:
        # otherwise the edge weight is the smaller of the two holding times
        weight = ht_src if ht_src < ht_dst else ht_dst

    MUG.add_edge(int(src), int(dst), weight = float(weight))

In [None]:
# spot check: show the attributes stored on the edge between nodes 1000 and 1001
MUG.get_edge_data(1000,1001)

### Collect and save the shortest paths to the KNN=100 nearest neighbors

In [None]:
%%script false --no-raise-error

# collect the shortest path for each node
# (this cell sits under the `%%script false` magic, so it is not executed)
knn = 100
X_j = []
D_ij = []

for i in range(len(SIMS_HT_uniq)):
    # weighted shortest-path length from node i to every reachable node
    length = nx.single_source_dijkstra_path_length(MUG, i)
    length_arr = np.array(list(length.items()), dtype=object)
    
    # # find the shortest path that is less than cutoff
    # # and get the index of the node
    # pos = np.where(length_arr[:,1] < cutoff)
    # Xj = length_arr[pos][:,0]
    # dij = length_arr[pos][:,1]
    # shortestpath_dict[i] = Xj, dij
    
    # get the nearest `knn` nodes: rows 1..knn of the result (row 0 is skipped)
    # NOTE(review): this relies on the returned dict being ordered by
    # increasing distance with the source node first -- confirm for the
    # installed networkx version.
    X_j.append(length_arr[1:knn+1,0])
    D_ij.append(length_arr[1:knn+1,1])
    
    if i % 5000 == 0:
        print(i)
        
# save npz file for shortest path
# NOTE(review): the file name says "reaction28_hortestpath" while the rest of
# the notebook uses reaction 27 / "shortestpath" -- confirm which is intended.
with open(f'./data/jordan/graph/reaction28_hortestpath_knn={knn}.npz', 'wb') as f:
    np.savez(f,
             X_j = X_j,
             D_ij = D_ij,
         )

X_j = np.array(np.load(f'./data/jordan/graph/reaction28_hortestpath_knn=100.npz',allow_pickle=True)["X_j"], dtype=int)
D_ij = np.array(np.load(f'./data/jordan/graph/reaction28_hortestpath_knn=100.npz',allow_pickle=True)["D_ij"], dtype=float)

In [None]:
%%script false --no-raise-error

# collect the shortest path for each node
# (this cell sits under the `%%script false` magic, so it is not executed)
# Variant that writes one small file per node and merges them afterwards.
knn = 100

for i in range(len(SIMS_HT_uniq)):
    # weighted shortest-path length from node i to every reachable node
    length = nx.single_source_dijkstra_path_length(MUG, i)
    length_arr = np.array(list(length.items()), dtype=object)
    with open(f'./data/jordan/graph/reaction27/{i}.npz', 'wb') as f:
        # rows 1..knn: node ids and distances (row 0 is skipped)
        np.savez(f,
                 X_j = length_arr[1:knn+1,0],
                 D_ij = length_arr[1:knn+1,1],
             )
    if i % 5000 == 0:
        print(i)


# read the per-node files back and stack them into two arrays
X_j = []
D_ij = []
for i in range(len(SIMS_HT_uniq)):
    fpath = f'./data/jordan/graph/reaction27/{i}.npz'
    # NOTE(review): each file is opened twice here, once per key
    X_j.append(np.load(fpath,allow_pickle=True)["X_j"])
    D_ij.append(np.load(fpath,allow_pickle=True)["D_ij"])
    
X_j = np.array(X_j, dtype=int)
D_ij = np.array(D_ij, dtype=float)


# save npz file for shortest path
with open(f'./data/jordan/graph/reaction27_shortestpath_knn={knn}.npz', 'wb') as f:
    np.savez(f,
             X_j = X_j,
             D_ij = D_ij,
         )

In [None]:
# load the cached k-nearest-neighbour ids (X_j) and shortest-path distances (D_ij)
knn_path_file = './data/jordan/graph/reaction27_shortestpath_knn=100.npz'
# open the archive once and close it when done
# NOTE: allow_pickle=True unpickles object arrays -- only use on trusted files.
with np.load(knn_path_file, allow_pickle=True) as knn_npz:
    X_j = np.array(knn_npz["X_j"], dtype=int)
    D_ij = np.array(knn_npz["D_ij"], dtype=float)

X_j.shape, D_ij.shape

#### Normalize distance 

In [None]:
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler(feature_range=(1,3))  # range [1,3]
# scaler = MinMaxScaler(feature_range=(1,10))  # range [1,10]

# NOTE(review): MinMaxScaler rescales every COLUMN independently, so each
# column of D_ij is stretched to [1, 3] on its own rather than the matrix as a
# whole -- confirm this per-column scaling is intended.
norm_Dij = scaler.fit_transform(D_ij) 
print(norm_Dij.max(), norm_Dij.min(), norm_Dij.mean(), norm_Dij.std())
# print(np.where(norm_Dij == 3))

# alternative: log distances shifted so that the smallest value becomes 1
# NOTE(review): a zero distance in D_ij would give -inf here -- verify D_ij > 0.
log_Dij = np.log(D_ij)+np.abs(np.log(D_ij).min())+1
print(log_Dij.max(), log_Dij.min(), log_Dij.mean(), log_Dij.std())

### Find the importance weight 

In [None]:
# calculate the probability of being visited during a simulated trajectory 
# from the initial state: the fraction of trajectories that visit each state
split_id = trj_id + 1 # index for split to each trajectory (exclusive end)
n_trajectories = len(split_id)
assert n_trajectories > 0, "no trajectory end markers found"

P_tot = np.zeros(len(SIMS_dp_uniq))

start = 0
for stop in split_id:
    # every distinct state of this trajectory is counted once
    visited = np.unique(coord_id_S[start:stop])
    P_tot[visited] += 1
    start = stop

# normalise by the actual number of trajectories instead of a fixed 100
P_tot = P_tot / n_trajectories

P_tot.shape, P_tot.max(), P_tot.min()

# ViDa Model

### Make dataset

In [None]:
# one sample per unique state: (scattering features, energy, state index)
# the index is carried along so a batch can look up its neighbours and weights
data_tup = (torch.Tensor(SIMS_scar_uniq),
            torch.Tensor(SIMS_G_uniq),
            torch.arange(len(SIMS_scar_uniq)))
data_dataset = torch.utils.data.TensorDataset(*data_tup)

In [None]:
# split the dataset into train and validation (70% / remaining 30%)
data_size = len(data_dataset)
train_size = int(0.7 * data_size)
val_size = data_size - train_size

# fixed generator seed so the split is the same on every run
train_data, val_data = torch.utils.data.random_split(data_dataset, [train_size, val_size], 
                                                     generator=torch.Generator().manual_seed(42))

print(data_size, len(train_data), len(val_data))

### Set up configurations

In [None]:
# set up hyperparameters
input_dim = data_tup[0].shape[-1]

# pick the best available accelerator instead of hard-coding one:
# CUDA first, then Apple MPS, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

config = Namespace(
    # NOTE(review): the data loaded above is reaction 27 while this label says
    # 'reaction0' -- confirm which label should be stored with the run.
    reaction = 'reaction0',
    type = 'dist_knn',
    knn = 100, # number of neighbours per node; must match the columns of X_j / D_ij
    device = device,
    log_dir = f'{time.strftime("%m%d-%H%M")}', # log directory
    batch_size = 256,
    input_dim = input_dim,
    output_dim = input_dim,
    latent_dim = 25, # bottleneck dimension
    hidden_dim = 400,
    n_epochs = 60,
    learning_rate = 1e-4, # learning rate
    # learning_rate = 8e-5, # learning rate
    
    log_interval = 10, # how many batches to wait before logging training status    
    patience = 40, # how many epochs to wait before early stopping
    
    # hyperparameters for loss function
    alpha = 1.0,  # weight of the reconstruction loss
    beta = 1e-4,  # weight of the KL loss
    # beta = 1e-5,
    gamma = 1.0,  # weight of the energy prediction loss
    delta = 0.05, # weight of the distance loss
    
)

### Make dataloader

In [None]:
# loader over ALL unique states, in fixed order (used to look up neighbour
# features and to embed the whole data set)
data_loader = torch.utils.data.DataLoader(data_dataset, batch_size=config.batch_size,
                                          shuffle=False, num_workers=0)
# training batches are reshuffled every epoch
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, 
                                           shuffle=True, num_workers=0)
# validation batches keep a fixed order
val_loader = torch.utils.data.DataLoader(val_data, batch_size=config.batch_size,
                                         shuffle=False, num_workers=0)

In [None]:
# for each loader: number of samples, number of batches, batch size
print('data_loader: ', len(data_loader.dataset), len(data_loader), data_loader.batch_size)
print('train_loader: ', len(train_loader.dataset), len(train_loader), train_loader.batch_size)
print('val_loader: ', len(val_loader.dataset), len(val_loader), val_loader.batch_size)

### Encoder

In [None]:
class Encoder(nn.Module):
    """Two-layer MLP encoder producing the mean and log-variance of the latent code."""
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        '''
        Args:
        ----
            - input_dim: the dimension of the input node feature
            - hidden_dim: the dimension of the hidden layers
            - latent_dim: the dimension of the latent space (bottleneck layer)
        '''
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        # every hidden layer uses hidden_dim, so any width works (not only 400)
        self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
        self.bn1 = nn.BatchNorm1d(self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.bn2 = nn.BatchNorm1d(self.hidden_dim)
        
        # Split the result into mu and var components
        # of the latent Gaussian distribution, note how we only output
        # diagonal values of covariance matrix. Here we assume
        # they are conditionally independent
        self.hid2mu = nn.Linear(self.hidden_dim, self.latent_dim)
        self.hid2logvar = nn.Linear(self.hidden_dim, self.latent_dim)
        
    def forward(self, x):
        '''
        Args:
            - x: input features, last dimension of size input_dim
        Returns:
            - (mu, logvar): two tensors with last dimension latent_dim
        '''
        # batch norm is applied after the ReLU activation
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        mu = self.hid2mu(x)
        logvar = self.hid2logvar(x)
        return mu, logvar


### Decoder

In [None]:
class Decoder(nn.Module):
    """Two-hidden-layer MLP decoder mapping a latent code back to node features."""
    
    def __init__(self, latent_dim, hidden_dim, output_dim):
        '''
        Args:
        ----
            - latent_dim: the dimension of the latent space (bottleneck layer)
            - hidden_dim: the dimension of the hidden layers
            - output_dim: the dimension of the output node feature
        '''
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        # every hidden layer uses hidden_dim, so fc1 and bn1 always agree
        self.fc1 = nn.Linear(self.latent_dim, self.hidden_dim)
        self.bn1 = nn.BatchNorm1d(self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.bn2 = nn.BatchNorm1d(self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)
        
    def forward(self, z):
        '''
        Args:
            - z: latent codes, last dimension of size latent_dim
        Returns:
            - reconstructed features, last dimension of size output_dim
        '''
        # batch norm is applied after the ReLU activation
        x = self.bn1(F.relu(self.fc1(z)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

### Regressor

In [None]:
class Regressor(nn.Module):
    '''
    Small MLP head that predicts the energy of a node from its latent code.

    Args:
    ----
        - latent_dim: the dimension of the latent space (bottleneck layer)
    '''

    def __init__(self, latent_dim):
        super().__init__()
        self.latent_dim = latent_dim

        # latent code -> 15 hidden units -> one scalar
        self.regfc1 = nn.Linear(latent_dim, 15)
        self.regfc2 = nn.Linear(15, 1)

    def forward(self, z):
        hidden = torch.relu(self.regfc1(z))
        return self.regfc2(hidden)

### VIDA model

In [None]:
class VIDA(nn.Module):
    '''
    Variational auto-encoder with an additional energy regression head.

    Args:
    ----
        - encoder: module mapping features to (mu, logvar)
        - decoder: module mapping a latent code to reconstructed features
        - regressor: module mapping a latent code to a predicted energy
    '''

    def __init__(self, encoder, decoder, regressor):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.regressor = regressor

    def reparameterize(self, mu, logvar):
        # draw z = mu + eps * sigma with eps ~ N(0, I)
        std = (0.5*logvar).exp()
        noise = torch.randn_like(std)
        return mu + noise*std

    def forward(self, x):
        # returns (x_recon, y_pred, z, mu, logvar); z is a random sample
        mu, logvar = self.encoder(x)
        latent = self.reparameterize(mu, logvar)
        return self.decoder(latent), self.regressor(latent), latent, mu, logvar

### Loss functions

In [None]:
def vae_loss(x_recon, x, mu, logvar):
    '''
    Compute the two VAE loss terms.

    Args:
        - x_recon: the reconstructed node feature
        - x: the original node feature
        - mu: the mean of the latent space
        - logvar: the log variance of the latent space

    Returns:
        - (recon, kld): scalar tensors; `recon` is the mean squared error over
          all elements, `kld` is the KL divergence to a standard normal summed
          over all elements
    '''
    recon_term = F.mse_loss(torch.flatten(x_recon), torch.flatten(x))
    kl_inner = 1 + logvar - mu.pow(2) - logvar.exp()
    kl_term = -0.5 * kl_inner.sum()
    return recon_term, kl_term


def pred_loss(y_pred, y):
    '''
    Mean squared error between predicted and true energies.

    Args:
    ----
        - y_pred: the predicted energy of the node
        - y: the true energy of the node

    Returns:
        - loss: scalar PyTorch tensor; both inputs are flattened first, so a
          (B, 1) prediction can be compared with a (B,) target
    '''
    predicted = y_pred.reshape(-1)
    target = y.reshape(-1)
    return F.mse_loss(predicted, target)


def distance_knn_loss_torch(config, zi, zj, Dij, P_tot, idx, batchXj_id):
    '''
    Weighted squared mismatch between latent distances and target distances.

    Args:
        - config: object with a `device` attribute
        - zi: the embedding of the node i, one row per batch element
        - zj: the embedding of the nodes j's (neighbours of each node i)
        - Dij: the post-processing distance between nodes i and j's,
          one row per node (indexed with idx)
        - P_tot: the total probability of the nodes i and j's
        - idx: the index of the node i
        - batchXj_id: the index of the nodes j's

    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the embedding distance
    '''
    # add a neighbour axis so zi broadcasts against zj
    anchors = zi.reshape(-1, 1, zi.size(-1))
    latent_dist = torch.sqrt(((anchors - zj) ** 2).sum(dim=-1))

    target_dist = Dij[idx].to(config.device)
    sq_error = (latent_dist - target_dist) ** 2

    # importance weight of nodes i and j: product of their probabilities
    weight_i = P_tot[idx].reshape(-1, 1)
    weight_j = P_tot[batchXj_id]
    wij = (weight_i * weight_j).to(config.device)

    return torch.sum(sq_error * wij)

### Early_stopping function

In [None]:
def early_stop(val_loss, epoch, patience):
    """
    Check if validation loss has not improved for a certain number of epochs

    The running state is kept on the function object itself
    (`early_stop.best_loss` and `early_stop.num_epochs_without_improvement`).
    A call with `epoch == 0` (or the very first call) starts a fresh run and
    records that loss as the baseline, so later epochs are compared with it.
    
    Args:
        val_loss (float): current validation loss
        epoch (int): current epoch number
        patience (int): number of epochs to wait before stopping if validation loss does not improve
        
    Returns:
        bool: True if validation loss has not improved for the last `patience` epochs, False otherwise
    """
    if epoch == 0 or not hasattr(early_stop, "best_loss"):
        # First epoch, don't stop yet -- but remember its loss as the baseline
        early_stop.best_loss = val_loss
        early_stop.num_epochs_without_improvement = 0
        return False

    if val_loss >= early_stop.best_loss:
        # no improvement: count the epoch and stop once patience is exhausted
        early_stop.num_epochs_without_improvement += 1
        if early_stop.num_epochs_without_improvement >= patience:
            print("Stopping early")
            return True
        return False

    # Validation loss improved, reset counter
    early_stop.best_loss = val_loss
    early_stop.num_epochs_without_improvement = 0
    return False

### Train VIDA

In [None]:
# validation function
def validate(config, model, data_loader, val_loader, P_tot, Dij, X_j,
             vae_loss, pred_loss, distance_knn_loss):
    '''
    Evaluate the model on the validation set without updating it.

    Args:
    ----
        - config: experiment configuration (device, input_dim, knn and the
          loss weights alpha, beta, gamma, delta are read)
        - model: the VIDA model
        - data_loader: loader over all data; its first dataset tensor is used
          to look up the neighbour features
        - val_loader: loader yielding (x, y, idx) validation batches
        - P_tot, Dij, X_j: tensors indexed by node id (weights, target
          distances, neighbour ids)
        - vae_loss, pred_loss, distance_knn_loss: the loss functions

    Returns:
        tuple (total, recon, kl, pred, dist): the weighted losses summed over
        all batches and divided by the number of validation samples
    '''
    model.to(config.device)
    model.eval()

    all_features = data_loader.dataset.tensors[0]
    n_samples = len(val_loader.dataset)
    totals = {"loss": 0, "recon": 0, "kl": 0, "pred": 0, "dist": 0}

    # Disable gradient calculation to speed up inference
    with torch.no_grad():
        for x, y, idx in val_loader:

            # Configure input
            x = x.to(config.device)
            y = y.to(config.device)

            # embed the neighbours X_j of every node in the batch
            batchXj_id = X_j[idx]
            neighbor_input = all_features[batchXj_id].reshape(-1, config.input_dim).to(config.device)
            neighbor_embed = model(neighbor_input)[2]
            neighbor_embed = neighbor_embed.reshape(-1, config.knn, neighbor_embed.shape[-1])

            # embed the batch itself
            x_recon, y_pred, z, mu, logvar = model(x)

            # raw loss terms as Python floats
            recon_raw, kl_raw = vae_loss(x_recon, x, mu, logvar)
            recon_raw = recon_raw.item()
            kl_raw = kl_raw.item()
            pred_raw = pred_loss(y_pred, y).item()
            dist_raw = distance_knn_loss(config, z, neighbor_embed, Dij, P_tot, idx, batchXj_id).item()

            # scaling the loss
            recon_loss = config.alpha * recon_raw
            kl_loss = config.beta * kl_raw
            p_loss = config.gamma * pred_raw
            dist_loss = config.delta * dist_raw

            totals["loss"] += recon_loss + kl_loss + p_loss + dist_loss
            totals["recon"] += recon_loss
            totals["kl"] += kl_loss
            totals["pred"] += p_loss
            totals["dist"] += dist_loss

    print('Validation Loss: {:.4f}'.format(totals["loss"]/n_samples))

    # Clear the cache
    torch.cuda.empty_cache()
    # torch.mps.empty_cache()

    return (totals["loss"]/n_samples, totals["recon"]/n_samples, totals["kl"]/n_samples,
            totals["pred"]/n_samples, totals["dist"]/n_samples)
            

In [None]:
def train(config, model, data_loader, train_loader, val_loader, P_tot, Dij, X_j,
          optimizer, vae_loss, pred_loss, distance_knn_loss, early_stop, ):
    '''
    Train VIDA!

    Logs to TensorBoard under ./model_config/<config.log_dir>, writes the
    configuration to hparams.yaml there and, when training ends (normally or
    by early stopping), saves the model weights of the last epoch as model.pt.
    
    Args:
    ----
        - config: Experiment configurations
        - model: Pytorch VIDA model
        - data_loader: Pytorch DataLoader for all data
        - train_loader: Pytorch DataLoader for training set
        - val_loader: Pytorch DataLoader for validation set
        - P_tot: the total probability of each node (numpy array)
        - Dij: the shortest path distance between k nearest pairs of nodes (numpy array)
        - X_j: the index of k nearest neighbours of each node (numpy array)
        - optimizer: Pytorch optimizer
        - vae_loss: the VAE loss function
        - pred_loss: the energy prediction loss function
        - distance_knn_loss: the k nearest neighbour distance loss function
        - early_stop: callable (val_loss, epoch, patience) -> bool that keeps
          its state in the attributes `best_loss` and
          `num_epochs_without_improvement`
    '''
    
    model.to(config.device)
    
    log_dir = f'./model_config/{config.log_dir}'
    writer = SummaryWriter(log_dir=log_dir)
    
    # save the config file
    with open(f'{log_dir}/hparams.yaml','w') as f:
        yaml.dump(config,f)
    
    # Initialize early stop object
    # (np.inf: the np.Inf alias no longer exists in NumPy 2.0)
    early_stop.best_loss = np.inf
    early_stop.num_epochs_without_improvement = 0

    # convert the numpy array to tensor
    P_tot = torch.from_numpy(P_tot.astype(np.float32))
    Dij = torch.from_numpy(Dij.astype(np.float32))
    X_j = torch.from_numpy(X_j)
    
    print('\n ------- Start Training -------')
    for epoch in range(config.n_epochs):
        start_time = time.time()
        training_loss = 0
        
        for batch_idx, (x, y, idx) in enumerate(train_loader):  # mini batch
            
            # Configure input
            x = x.to(config.device)
            y = y.to(config.device)
            
            # forward X_j: embed the neighbours in eval mode and without
            # gradients, so they act as fixed targets for this step
            model.eval()
            with torch.no_grad():
                batchXj_id = X_j[idx]
                neighbor_input = data_loader.dataset.tensors[0][batchXj_id].reshape(-1, config.input_dim).to(config.device)
                _, _, neighbor_embed, _, _ = model(neighbor_input)
                neighbor_embed = neighbor_embed.reshape(-1, config.knn, neighbor_embed.shape[-1])
                
            # ------------------------------------------
            #  Train VIDA
            # ------------------------------------------
            model.train()
            optimizer.zero_grad()
            
            # get the reconstructed nodes, predicted energy, and the embeddings
            x_recon, y_pred, z, mu, logvar = model(x)
        
            ## compute the total loss
            # vae loss
            recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)
            
            # energy prediction loss
            p_loss = pred_loss(y_pred, y)

            # distance loss
            dist_loss = distance_knn_loss(config, z, neighbor_embed, Dij, P_tot, idx, batchXj_id)

            # scaling the loss
            recon_loss = config.alpha * recon_loss
            kl_loss = config.beta * kl_loss
            p_loss = config.gamma * p_loss
            dist_loss = config.delta * dist_loss
            
            # total loss
            loss = recon_loss + kl_loss + p_loss + dist_loss 
            
            # backpropagation and optimization
            loss.backward()
            optimizer.step()
            
            training_loss += loss.item()
            
            # ------------------------------------------
            # Log Progress
            # ------------------------------------------
            if batch_idx % config.log_interval == 0:
                print('Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, config.n_epochs, batch_idx * len(x), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item()))
                
                # x axis of the batch-level curves: number of batches seen so far
                global_step = epoch * len(train_loader) + batch_idx
                writer.add_scalar('training loss', loss.item(), global_step)
                writer.add_scalar('recon loss', recon_loss.item(), global_step)
                writer.add_scalar('kl loss', kl_loss.item(), global_step)
                writer.add_scalar('pred loss', p_loss.item(), global_step)
                writer.add_scalar('dist loss', dist_loss.item(), global_step)
    
        print ('====> Epoch: {} Average loss: {:.4f}'.format(epoch, training_loss/len(train_loader.dataset)))
        writer.add_scalar('epoch training loss', training_loss/len(train_loader.dataset), epoch)
        
        # validation
        val_loss, val_bce, val_kld, val_pred, val_dist = validate(config, model, data_loader, val_loader, P_tot, Dij, X_j, 
                                                                  vae_loss, pred_loss, distance_knn_loss)
        writer.add_scalar('validation loss', val_loss, epoch)
        writer.add_scalar('val_recon loss', val_bce, epoch)
        writer.add_scalar('val_kl loss', val_kld, epoch)
        writer.add_scalar('val_pred loss', val_pred, epoch)
        writer.add_scalar('val_dist loss', val_dist, epoch)
        
        end_time = time.time()
        epoch_time = end_time - start_time
        print (f'Epoch {epoch} train+val time: {epoch_time:.2f} seconds \n')
        
        # Check if validation loss has not improved for `patience` epochs
        if early_stop(val_loss, epoch, config.patience):
            break
    
        # Clear the cache
        # torch.mps.empty_cache()
        torch.cuda.empty_cache()
        
        
    writer.close()
    print('\n ------- Finished Training -------')
    
    # save the model
    torch.save(model.state_dict(), f'{log_dir}/model.pt')
    

In [None]:
# define models: layer sizes are taken from the experiment configuration
encoder = Encoder(input_dim=config.input_dim, hidden_dim=config.hidden_dim, latent_dim=config.latent_dim)
decoder = Decoder(latent_dim=config.latent_dim, hidden_dim=config.hidden_dim, output_dim=config.output_dim)
regressor = Regressor(latent_dim=config.latent_dim)

# the three sub-modules are combined into one model
vida = VIDA(encoder, decoder, regressor)

# define optimizer: Adam over the parameters of all three sub-modules
optimizer = torch.optim.Adam(vida.parameters(), lr=config.learning_rate)

In [None]:
# train VIDA
# the min-max normalised distances (norm_Dij) are used as distance targets
train(config, vida, data_loader, train_loader, val_loader, P_tot, norm_Dij, X_j,
      optimizer, vae_loss, pred_loss, distance_knn_loss_torch, early_stop)

### Load saved trained model

In [None]:
# model = VIDA(encoder, decoder, regressor)
# model.load_state_dict(torch.load('./model_config/0501-0147/model.pt'))

### Get embeddings

In [None]:
# wrap the existing encoder / decoder / regressor objects in a new VIDA
# container; no weights are copied, the same module instances are reused
model = VIDA(encoder, decoder, regressor)

In [None]:
# do inference: embed the complete data set in a single forward pass
model.to(config.device).eval()

with torch.no_grad():
        # the third output of VIDA.forward is the SAMPLED latent code z
        # NOTE(review): z contains random sampling noise even in eval mode;
        # consider using mu (fourth output) for a deterministic embedding.
        _, _, z, _, _ = model(data_loader.dataset.tensors[0].to(config.device))        

In [None]:
# move the embedding to the CPU and convert it to a numpy array
data_embed = z.to('cpu').numpy()
data_embed.shape

In [None]:
# summary statistics of the embedding: max, min, mean, std
print(data_embed.max(), data_embed.min(), data_embed.mean(), data_embed.std())

### PCA

In [None]:
# PCA of the latent embedding down to 2 components
pca_coords = PCA(n_components=2).fit_transform(data_embed)

# shape of the projection and of its distinct rows
pca_coords.shape, (np.unique(pca_coords,axis=0)).shape

### PHATE

In [None]:
from sklearn.preprocessing import StandardScaler

# Standardize each embedding dimension before PHATE.
# NOTE: `data_embed` is overwritten here, so every later cell sees the scaled
# values (the PCA cell above ran on the unscaled embedding).
scaler = StandardScaler().fit(data_embed)
data_embed = scaler.transform(data_embed)

In [None]:
# Run PHATE on the embedded data (as left in `data_embed` by the preceding cell)
phate_operator = phate.PHATE(n_jobs=-2)
phate_coords = phate_operator.fit_transform(data_embed)

# shape of the PHATE coordinates, and their shape after removing duplicate rows
phate_coords.shape, np.unique(phate_coords, axis=0).shape

### Save DRs

In [None]:
""" Save all DRs
"""
# save for python: one .npz archive holding the embedding and both 2-D projections
fnpz_data = f"data/jordan/plotdata/plot_{config.log_dir}_rxn0.npz"
arrays_to_save = {
    "data_embed": data_embed,
    # plotting data
    "pca_coords": pca_coords,
    "phate_coords": phate_coords,
}
with open(fnpz_data, 'wb') as f:
    np.savez(f, **arrays_to_save)

In [None]:
# load plot data
fnpz_data = "data/jordan/plotdata/plot_0511-1609.npz"
data_npz = np.load(fnpz_data)

# assign every array stored in the archive to a notebook-level variable of the same name
for var in data_npz.files:
    globals()[var] = data_npz[var]
    print(var, globals()[var].shape)

# Index the 2-D coordinates with `coord_id_S`
# (presumably loaded from the archive above and mapping each trajectory step to a coordinate row — confirm)
pca_all_coords = pca_coords[coord_id_S]
phate_all_coords = phate_coords[coord_id_S]

## Interactive plot

In [None]:
# Plotting table built from the *_uniq arrays together with the PCA / PHATE coordinates
df = pd.DataFrame(
    data={
        "Energy": SIMS_G_uniq,
        "Pair": SIMS_pair_uniq,
        "DP": SIMS_dp_og_uniq,
        "HT": SIMS_HT_uniq,
        "PCA 1": pca_coords[:, 0],
        "PCA 2": pca_coords[:, 1],
        "PHATE 1": phate_coords[:, 0],
        "PHATE 2": phate_coords[:, 1],
    }
)

# Plotting table built from the full SIMS_* arrays and the coord_id_S-indexed coordinates
dfall = pd.DataFrame(
    data={
        "Energy": SIMS_G,
        "Pair": SIMS_pair,
        "DP": SIMS_dp_og,
        "HT": SIMS_HT,
        "TotalT": SIMS_T,
        "PCA 1": pca_all_coords[:, 0],
        "PCA 2": pca_all_coords[:, 1],
        "PHATE 1": phate_all_coords[:, 0],
        "PHATE 2": phate_all_coords[:, 1],
    }
)

# Split each table on the value of the "Pair" column (0 / 1)
df0 = df.loc[df["Pair"] == 0]
df1 = df.loc[df["Pair"] == 1]

dfall0 = dfall.loc[dfall["Pair"] == 0]
dfall1 = dfall.loc[dfall["Pair"] == 1]

In [None]:
###############################################################################
# get the ith single trajectory
###############################################################################

def plot_trj(trj_id,dfall,i,vis,dim):
    """Extract the i-th single trajectory from the table of all trajectory steps.

    Args:
        trj_id: integer array (or list); ``trj_id + 1`` is used as the exclusive
            end row of each trajectory inside ``dfall``.
            Presumably each entry is the row of a trajectory's last step — confirm
            against the code that builds it.
        dfall: DataFrame with columns "Energy", "Pair", "DP", "HT", "TotalT" and
            the coordinate columns "<vis> 1"/"<vis> 2" (for ``dim="2D"``) or
            "<vis> X"/"<vis> Y"/"<vis> Z" (for ``dim="3D"``).
        i: index of the trajectory to extract.
        vis: prefix of the coordinate columns, e.g. "PCA" or "PHATE".
        dim: "2D" or "3D".

    Returns:
        ``(subdf, (Xi, Xf), (Yi, Yf), (Zi, Zf))`` where ``subdf`` holds the rows of
        the trajectory plus the columns "Step", "sub X", "sub Y" (and "sub Z" in
        3D), and the tuples hold the first / last coordinate along each axis
        (``(None, None)`` for Z in 2D).

    Raises:
        ValueError: if ``dim`` is neither "2D" nor "3D".
    """
    if dim not in ("2D", "3D"):
        raise ValueError('dim must be "2D" or "3D", got {!r}'.format(dim))

    # exclusive end row of every trajectory; np.asarray lets plain lists work too
    ends = np.asarray(trj_id) + 1

    if i == 0:
        s, s_prime = 0, ends[i]
    elif i == len(trj_id):
        s, s_prime = ends[i - 1], len(dfall)
    else:
        s, s_prime = ends[i - 1], ends[i]

    # rows of this trajectory (positional slice), with energy, pair, DP, HT, TotalT
    rows = dfall.iloc[s:s_prime]
    subdf = rows[["Energy", "Pair", "DP", "HT", "TotalT"]].copy()

    # For trajectories shorter than 2000 steps, record for every row the list of
    # 1-based step numbers at which the same DP value occurs in this trajectory.
    if len(subdf) < 2000:
        dps = subdf["DP"].tolist()
        # group step numbers by DP in one pass (assumes DP values are hashable, e.g. strings)
        visits = {}
        for step_no, dp in enumerate(dps, start=1):
            visits.setdefault(dp, []).append(step_no)
        # Fill a 1-D object array element by element: np.array(list_of_lists, dtype=object)
        # would turn equal-length lists into a 2-D array, which cannot be stored as one column.
        step_col = np.empty(len(dps), dtype=object)
        for row, dp in enumerate(dps):
            step_col[row] = list(visits[dp])
        subdf["Step"] = step_col
    else:
        subdf["Step"] = None

    # X, Y (and Z if applicable) coordinates of the trajectory
    if dim == "2D":
        subdf["sub X"] = rows["{} 1".format(vis)]
        subdf["sub Y"] = rows["{} 2".format(vis)]
        Zi = None
        Zf = None
    else:
        subdf["sub X"] = rows["{} X".format(vis)]
        subdf["sub Y"] = rows["{} Y".format(vis)]
        subdf["sub Z"] = rows["{} Z".format(vis)]
        Zi = subdf["sub Z"].iloc[0]
        Zf = subdf["sub Z"].iloc[-1]

    # initial and final coordinates
    Xi = subdf["sub X"].iloc[0]; Xf = subdf["sub X"].iloc[-1]
    Yi = subdf["sub Y"].iloc[0]; Yf = subdf["sub Y"].iloc[-1]

    return subdf,(Xi,Xf),(Yi,Yf),(Zi,Zf)


In [None]:
###############################################################################
# plot 2D energy landscape
###############################################################################

def _padded_axis_range(values, pad_frac=0.05):
    """Return ``[low, high]`` enclosing ``values`` with a margin on both sides.

    The margin is ``pad_frac`` times the data span, so points are never clipped
    whatever the sign of the minimum / maximum (scaling the extremes by a
    constant factor moves a positive minimum *into* the data).
    """
    lo = float(np.min(values))
    hi = float(np.max(values))
    pad = pad_frac * (hi - lo)
    if pad == 0:
        # constant column / single point: fall back to a unit margin
        pad = 1.0
    return [lo - pad, hi + pad]


def interactive_plotly_2D(n_trace,df,dfall,trj_id,vis):
    """Build the interactive 2-D landscape figure with trajectories overlaid.

    Args:
        n_trace: number of trajectories to add (indices ``0 .. n_trace-1``).
        df: DataFrame for the background scatter; the columns "<vis> 1",
            "<vis> 2", "HT", "Energy" and "DP" are read.
        dfall: DataFrame of all trajectory steps, forwarded to ``plot_trj``.
        trj_id: trajectory boundary indices, forwarded to ``plot_trj``.
        vis: prefix of the coordinate columns, e.g. "PCA" or "PHATE".

    Returns:
        plotly ``go.Figure`` with one "states" trace, one trace per trajectory
        (initially hidden, toggled from the legend) and an "I"/"F" marker trace
        for the first trajectory's initial and final points.
    """
    x_col = "{} 1".format(vis)
    y_col = "{} 2".format(vis)

    fig = go.Figure()

    # plot energy landscape background
    fig.add_trace(go.Scattergl(
            x=df[x_col],
            y=df[y_col],
            mode='markers',
            marker=dict(
                sizemode='area',
                size=df["HT"],
                sizeref=1e-6,
                color=df["Energy"],
                colorscale="Plasma",
                showscale=True,
                colorbar=dict(
                    # nested title dict instead of the legacy `titleside` attribute
                    title=dict(
                        text="Free energy (kcal/mol)",
                        side="top",
                    ),
                    x=-0.2,
                    len=1.065,
                    y=0.5,
                ),
                line=dict(width=0.2
                ),
            ),
            text=df['DP'],
            hovertemplate=
                "DP notation: <br> <b>%{text}</b><br>" +
                "X: %{x}   " + "   Y: %{y} <br>"+
                "Energy:  %{marker.color:.3f} kcal/mol<br><br>"+
                "Average holding time:  %{marker.size:.5g} s<br>",

            name="states",
        )
    )

    # layout trajectory on top of energy landscape
    for i in range(n_trace):
        subdf = plot_trj(trj_id,dfall,i,vis,dim="2D")[0]
        fig.add_trace(
            go.Scattergl(
                x=subdf["sub X"],
                y=subdf["sub Y"],
                mode='lines+markers',
                line=dict(
                    color="black",
                    width=2,
                ),
                marker=dict(
                    sizemode='area',
                    size=subdf["HT"],
                    sizeref=1e-6,
                    color=subdf["Energy"],
                    colorscale="Plasma",
                    colorbar=dict(
                        x=-0.2,
                        tickvals=[],
                        y=0.5,
                        len=1,
                        ),
                ),

                text=subdf["Step"],
                # columns: [0] Pair, [1] TotalT, [2] DP — referenced by index in the hover template
                customdata = np.stack((subdf['Pair'],
                                       subdf["TotalT"],
                                       subdf["DP"],
                                       ),axis=-1),
                hovertemplate=
                    "Step:  <b>%{text}</b><br><br>"+
                    "DP notation: <br> <b>%{customdata[2]}</b><br>" +
                    "X: %{x}   " + "   Y: %{y} <br>"+
                    "Energy:  %{marker.color:.3f} kcal/mol<br><br>"+
                    "Holding time for last step:  %{marker.size:.5g} s<br>"+
                    "Total time until current state:  %{customdata[1]:.5e} s<br>",
                # hidden until the user enables it from the legend
                visible='legendonly'
                        )
        )

    # label initial and final states (taken from trajectory 0; extracted once)
    _, x_ends, y_ends, _ = plot_trj(trj_id,dfall,0,vis,dim="2D")
    fig.add_trace(
        go.Scattergl(
            x=x_ends,
            y=y_ends,
            mode='markers+text',
            marker_color="lime",
            marker_size=15,
            text=["I", "F"],
            textposition="middle center",
            textfont=dict(
                family="sans serif",
                size=16,
                color="black"
            ),
            hoverinfo='skip',
            showlegend=False,
        )
    )

    # axis limits: full data extent plus a symmetric margin
    fig.update_xaxes(range=_padded_axis_range(df[x_col]))
    fig.update_yaxes(range=_padded_axis_range(df[y_col]))

    fig.update_layout(
        title="{} Vis".format(vis),
        xaxis=dict(
                title=x_col,
            ),
        yaxis=dict(
                title=y_col,
            ),
        legend=dict(
            title="Single Trajectory",
            title_font=dict(size=10),
            font=dict(
                size=10,
                color="black"
            )
        )
    )

    return fig

In [None]:
# Visualization methods to export (use ["PCA", "PHATE"] to export both)
VIS_METHOD = ["PHATE"]

# one overlay trace per entry of trj_id
n_trace = len(trj_id)

for vis in VIS_METHOD:
    out_html = f"./data/jordan/plots/CTMC_{vis}_reaction27_0511-1609.html"
    fig = interactive_plotly_2D(n_trace, df, dfall, trj_id, vis)
    # write a standalone HTML file and open it in the browser
    pio.write_html(fig, file=out_html, auto_open=True)
    print("DONE: ", vis)

## Visualize

### PCA Vis

In [None]:
%matplotlib inline
X = pca_coords[:,0]
Y = pca_coords[:,1]

# PCA: 2 components
fig,ax = plt.subplots(figsize=(8,6))
# `pca_coords` is row-aligned with the *_uniq arrays (they share the `df` table),
# whereas SIMS_G is aligned with `pca_all_coords`; colour by SIMS_G_uniq so that
# `c` has the same length as X / Y.
im = ax.scatter(X, Y,
          c=SIMS_G_uniq,
          cmap='plasma',
        )

fig.colorbar(im, ax=ax)
ax.set_xlabel("PCA 1")
ax.set_ylabel("PCA 2")

# mark the initial ("I", row 0) and final ("F", row f_indx_new) states
annotations=["I","F"]
x = [X[0],X[f_indx_new]]
y = [Y[0],Y[f_indx_new]]
ax.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    ax.annotate(label, (x[i]-0.3,y[i]-0.3),fontsize=15,c="black")

#### Try use PCA directly without AE

In [None]:
# PCA applied directly to SIMS_scar (multiple trajectories), keeping 3 components
pca_raw = PCA(n_components=3)
pca_coords1 = pca_raw.fit_transform(SIMS_scar)

X, Y, Z = pca_coords1[:, 0], pca_coords1[:, 1], pca_coords1[:, 2]

# scatter of the first two components, coloured by SIMS_G
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(X, Y, c=SIMS_G, cmap='plasma')

plt.colorbar(im)

# mark the initial ("I", row 0) and final ("F") states
# NOTE(review): this cell indexes the final state with `f_indx_new`, while the
# PHATE-on-SIMS_scar cell uses `f_indx` for the same input — confirm which one
# matches the rows of SIMS_scar.
annotations = ["I", "F"]
x = [X[0], X[f_indx_new]]
y = [Y[0], Y[f_indx_new]]
plt.scatter(x, y, s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i] - 0.3, y[i] - 0.3), fontsize=15, c="black")

In [None]:
# Cumulative explained variance of a 25-component PCA fitted on the embedded data
cm = PCA(n_components=25)
cm.fit(data_embed)

PC_values = np.arange(1, cm.n_components_ + 1)
cum_explained = np.cumsum(cm.explained_variance_ratio_)

plt.plot(PC_values, cum_explained, 'ro-', linewidth=2)
plt.title('Scree Plot: PCA')
plt.xlabel('Number of principal components')
plt.ylabel('Cumulative explained variance')
# plt.xticks(np.arange(0, data_embed.shape[-1]+1, 1))

plt.show()

# display the cumulative ratios as the cell output
cum_explained

### PHATE Vis

In [None]:
X_phate = phate_coords[:,0]
Y_phate = phate_coords[:,1]

fig,ax = plt.subplots(figsize=(8,6))
# `phate_coords` is row-aligned with the *_uniq arrays (they share the `df` table),
# whereas SIMS_G is aligned with `phate_all_coords`; colour by SIMS_G_uniq so that
# `c` has the same length as the plotted coordinates.
im = ax.scatter(X_phate,Y_phate,
                c=SIMS_G_uniq,
                cmap='plasma',
               )
fig.colorbar(im, ax=ax)
ax.set_xlabel("PHATE 1")
ax.set_ylabel("PHATE 2")

# mark the initial ("I", row 0) and final ("F", row f_indx_new) states
annotations=["I","F"]
x = [X_phate[0],X_phate[f_indx_new]]
y = [Y_phate[0],Y_phate[f_indx_new]]
ax.scatter(x,y,s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    ax.annotate(label, (x[i],y[i]),fontsize=30,c="black")

#### PHATE without AE

In [None]:
# PHATE applied directly to SIMS_scar (multiple trajectories)
phate_operator = phate.PHATE(n_jobs=-2)
phate1 = phate_operator.fit_transform(SIMS_scar)

# scatter of the two PHATE coordinates, coloured by SIMS_G
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(phate1[:, 0], phate1[:, 1], c=SIMS_G, cmap='plasma')

plt.colorbar(im)

# mark the initial ("I", row 0) and final ("F", row f_indx) states
annotations = ["I", "F"]
x = [phate1[0, 0], phate1[f_indx, 0]]
y = [phate1[0, 1], phate1[f_indx, 1]]
plt.scatter(x, y, s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i], y[i]), fontsize=20, c="black")