In [None]:
%matplotlib inline

import numpy as np
from argparse import Namespace
import plotly.express as px
import matplotlib.pyplot as plt
import time
import yaml
import copy
import pandas as pd
import pickle
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from vida.data_processing.load_data import *
from vida.data_processing.convertor import *
from vida.data_processing.misc import *
from vida.data_processing.split_data import *

from vida.model.scatter_transform import transform_dataset, get_normalized_moments

import networkx as nx

from sklearn.preprocessing import MinMaxScaler
from sklearn.decomposition import PCA
import phate

import plotly.express as px
import plotly.io as pio
import plotly.graph_objects as go
import matplotlib.pyplot as plt

## Load first-step data

In [None]:
# Load Jordan's first-step trajectory data for one reaction
rxn_name = "npstates5m_39"
fnpz_data = f"./data/jordan/multistrand_firststep/bins_{rxn_name}.npz"

# NOTE(review): allow_pickle=True can execute arbitrary code stored in the file; only use trusted data.
# The NpzFile is used as a context manager so its file handle is closed once the arrays are read
# (indexing an NpzFile reads the array fully into memory).
with np.load(fnpz_data, allow_pickle=True) as trajs:
    print(trajs.files)

    trajs_states = trajs['trajs_states']
    trajs_times = trajs['trajs_times']
    trajs_energies = trajs['trajs_energies']
    # stored as a 0-d object array; .item() unwraps it (used below as a structure -> type lookup)
    trajs_types = trajs['trajs_types'].item()

# total number of trajectories (each trajectory is a sequence of states) and number of type entries
trajs_states.shape, len(trajs_types)

In [None]:
# Collect successful and unsuccessful trajectories: walk the trajectories in file order,
# keeping every one, until `success_threshold` successful ones (last state == final_states)
# have been seen.
num_success_traj = 0
success_threshold = 50 # 3 for npstates5m_46, 50 for others
select_trajs = []
select_trajs_id = []   # positions of the successful trajectories within select_trajs
select_times = []
select_energies = []
# structure a trajectory must end in to count as successful
final_states = '(((((((((((((((((((((((+)))))))))))))))))))))))'

i = 0
# Also stop at the end of the data: without this bound, a reaction with fewer than
# `success_threshold` successes would index past the last trajectory (IndexError).
while num_success_traj < success_threshold and i < len(trajs_states):
    select_trajs.append(trajs_states[i])
    select_times.append(trajs_times[i])
    select_energies.append(trajs_energies[i])
    if trajs_states[i][-1] == final_states:
        num_success_traj += 1
        select_trajs_id.append(i)
    i += 1

if num_success_traj < success_threshold:
    print(f"Warning: only {num_success_traj} successful trajectories found "
          f"(requested {success_threshold}).")

select_trajs = np.array(select_trajs, dtype=object)
select_times = np.array(select_times, dtype=object)
select_energies = np.array(select_energies, dtype=object)

print(select_trajs.shape)
print(len(select_trajs_id))

In [None]:
import numpy as np
import pickle
import gzip

# Define the filename for the saved data (a gzip-compressed pickle)
file_name = "Hata-39.pkl.gz"

# Bundle the selected trajectories (object arrays) and the state-type labels.
# NOTE(review): SIMS_type_uniq is only created further down (in the "label data's type" cell);
# on a fresh kernel this cell raises NameError unless that cell has already been run.
data_to_save = {
    "select_trajs": select_trajs,
    "select_times": select_times,
    "select_energies": select_energies,
    "select_types": SIMS_type_uniq,
}

# Write through gzip.open so the file content matches its ".gz" extension;
# plain open() would store an uncompressed pickle under a .gz name.
with gzip.open(file_name, 'wb') as file:
    pickle.dump(data_to_save, file)

# # To load the data back later (only unpickle files you trust):
# with gzip.open(file_name, 'rb') as file:
#     loaded_data = pickle.load(file)
# # (files written by plain open() are uncompressed: load those with open() instead)

# # You can access the saved arrays as follows:
# select_trajs = loaded_data["select_trajs"]
# select_times = loaded_data["select_times"]
# select_energies = loaded_data["select_energies"]
# SIMS_type_uniq = loaded_data["select_types"]


In [None]:
# Inspect the per-trajectory energies collected above (object array, one entry per trajectory)
select_energies

In [None]:
%%script false --no-raise-error
# # for reaction 45
## select 1000 failed trajectories randomly
## then add 50 successful trajectories
# NOTE(review): this cell replaces select_trajs / select_times / select_energies, so it must
# run exactly once, right after the collection cell. np.random is unseeded here, so the
# sampled subset of failed trajectories changes from run to run.
all_idx = np.arange(len(select_trajs))

failed_idx = np.random.choice(np.delete(all_idx, select_trajs_id), size=1000, replace=False)

# add 50 successful trajectories
all_new_idx = np.concatenate((failed_idx, select_trajs_id))

select_trajs = select_trajs[all_new_idx]
select_times = select_times[all_new_idx]
select_energies = select_energies[all_new_idx]

print(select_trajs.shape)

# make the new select_trajs_id: positions, in the re-ordered arrays, of the first
# `success_threshold` trajectories that end in the final state (computed once).
success_threshold = 50
select_trajs_id = [
    k for k, traj in enumerate(select_trajs) if traj[-1] == final_states
][:success_threshold]
num_success_traj = len(select_trajs_id)

## Preprocess data

#### SIMS_data

In [None]:
# Merge the strands of each dot-bracket structure into a single string
def concat_helix_structures(dp):
    """Strip strand separators from every structure of one trajectory.

    Args:
        dp: sequence of dot-bracket strings, each possibly containing "&" and/or "+".
            The input is deep-copied and left unmodified.
    Returns:
        (np.ndarray, np.ndarray): the structures with separators removed, and one flag
        per separator found, in scan order: 0 for each "&", 1 for each "+". A structure
        with no separator contributes no flag, so the flag array can be shorter than the
        structure array.
    """
    merged = copy.deepcopy(dp)
    pair_flags = []
    for pos in range(len(merged)):
        structure = merged[pos]
        changed = False
        # "&" is checked before "+", so a structure containing both yields flags 0, 1
        for separator, flag in (("&", 0), ("+", 1)):
            if separator in structure:
                structure = structure.replace(separator, "")
                pair_flags.append(flag)
                changed = True
        if changed:
            merged[pos] = structure
    return np.array(merged), np.array(pair_flags)


# Concatenate the structures of all trajectories into flat per-visit arrays
def concat_all(states, times, energies):
    """Flatten a collection of trajectories into per-visit arrays.

    Args:
        states: per-trajectory sequences of dot-bracket strings.
        times: per-trajectory sequences of times, indexed like `states`.
        energies: per-trajectory sequences of energies, indexed like `states`.

    Returns:
        tuple: (SIMS_dp, SIMS_dp_og, SIMS_pair, SIMS_G, SIMS_T) — structures with strand
        separators removed (via concat_helix_structures), the original structures, the
        separator flags, the energies and the times, each concatenated over all trajectories.
    """
    merged_parts, original_parts, flag_parts, energy_parts, time_parts = [], [], [], [], []

    for t in range(len(states)):
        merged, flags = concat_helix_structures(states[t])
        merged_parts.append(merged)
        original_parts.append(states[t])
        flag_parts.append(flags)
        energy_parts.append(energies[t])
        time_parts.append(times[t])

    return tuple(
        np.concatenate(parts)
        for parts in (merged_parts, original_parts, flag_parts, energy_parts, time_parts)
    )

In [None]:
# Flatten all selected trajectories into per-visit arrays and show their shapes
SIMS_dp, SIMS_dp_og, SIMS_pair, SIMS_G, SIMS_T = concat_all(select_trajs, select_times, select_energies)
SIMS_dp.shape, SIMS_dp_og.shape, SIMS_pair.shape, SIMS_G.shape, SIMS_T.shape

In [None]:
# Remove duplicate states. np.unique sorts the structures; return_index gives the position
# of the first occurrence of each unique structure in SIMS_dp.
indices_S = np.unique(SIMS_dp, return_index=True)[-1]

# Keep that first occurrence in each per-visit array. SIMS_T is not reduced here:
# the holding-time cells below use the full per-visit SIMS_T.
SIMS_dp_uniq, SIMS_dp_og_uniq, SIMS_pair_uniq, SIMS_G_uniq = (
    per_visit[indices_S] for per_visit in (SIMS_dp, SIMS_dp_og, SIMS_pair, SIMS_G)
)

SIMS_dp_uniq.shape, SIMS_pair_uniq.shape, SIMS_G_uniq.shape

In [None]:
# Find the index that recovers all (per-visit) data from the unique data:
# coord_id_S[k] is the row of SIMS_dp_uniq holding visit k, so SIMS_dp_uniq[coord_id_S]
# reproduces SIMS_dp. SIMS_dp_uniq is the sorted np.unique output of SIMS_dp (previous
# cell), so np.unique's inverse index is exactly this mapping. One O(n log n) pass
# replaces comparing SIMS_dp against every unique structure (O(n * n_unique)).
coord_id_S = np.unique(SIMS_dp, return_inverse=True)[1]
coord_id_S = np.asarray(coord_id_S).reshape(-1).astype(int)
assert np.array_equal(SIMS_dp_uniq[coord_id_S], SIMS_dp), "coord_id_S does not map SIMS_dp_uniq back to SIMS_dp"

coord_id_S.shape, coord_id_S.max()

In [None]:
# Label each unique state with its type by looking up its original (separator-containing)
# structure in trajs_types
SIMS_type_uniq = np.array([trajs_types[structure] for structure in SIMS_dp_og_uniq])
SIMS_type_uniq.shape

#### Adjacency matrix

In [None]:
# Convert a dot-bracket structure into an adjacency matrix
def dot2adj(db_str,hairpin=False,helix=True):
    """Build the NxN adjacency matrix of a dot-bracket structure.

    Edges are the base pairs plus the backbone (i, i+1) links returned by dot2pairs,
    added in both directions.

    Args:
        db_str (str): N-length dot-bracket string.
        hairpin (bool): accepted but currently has no effect.
        helix (bool): if True, treat the string as two halves of length N/2 and remove
            the backbone link between positions N/2-1 and N/2, unless those two positions
            read "()" (i.e. they are paired with each other).

    Returns:
        np.ndarray: NxN float matrix with 1 where an edge exists, 0 elsewhere.

    Raises:
        AssertionError: if helix is True and N is odd.
    """
    n = len(str(db_str))
    edges = symmetrized_edges(dot2pairs(db_str))

    adjacency = np.zeros((n, n))
    adjacency[edges[0, :], edges[1, :]] = 1

    if hairpin == True:
        pass  # hairpin-specific handling is not implemented

    if helix == True:
        assert n % 2 == 0, "Not a valid helix sequence."
        junction = np.ceil(n/2).astype(int)
        # cut the link between the two halves unless it is itself a base pair
        if db_str[junction-1:junction+1] != "()":
            adjacency[junction-1, junction] = 0
            adjacency[junction, junction-1] = 0

    return adjacency

def dot2pairs(dp_str):
    """Return the edges of a dot-bracket structure: base pairs followed by backbone links.

    The string is scanned once, left to right:
    - indices of "(" are pushed onto l1_indcs,
    - indices of ")" are appended to l2_indcs,
    - whenever both lists are non-empty, the most recent "(" (l1_indcs[-1]) is paired
      with the oldest pending ")" (l2_indcs[0]) and both are removed.
    For a balanced string this is ordinary stack-based bracket matching.
    NOTE(review): no validation is done; on unbalanced input an unmatched ")" stays in
    l2_indcs and is paired with the next "(" found to its right.

    Args:
        dp_str (str): dot bracket string (ex. "((..))")

    Returns:
        list[tuple[int, int]]: the (open_index, close_index) base pairs in the order they
        are matched, followed by the (i, i + 1) edges of nx.path_graph(len(dp_str)).
    """ 
    dim = len(str(dp_str))

    # pairing indices lists
    l1_indcs = []   # stack of unmatched "(" indices
    l2_indcs = []   # pending ")" indices
    pair_list = []

    for indx in range(dim):
        
        # checking stage: record the bracket at this position
        if dp_str[indx] == "(":

            l1_indcs.append(indx)
 
        if dp_str[indx] == ")":
            l2_indcs.append(indx)

        # pairing stage
        # only when both lists are non-empty
        if len(l2_indcs) * len(l1_indcs) > 0:
            pair = (l1_indcs[-1], l2_indcs[0])
            pair_list.append(pair)
        
            # cleaning stage: drop the two matched indices
            l1_indcs.pop(-1)
            l2_indcs.pop(0)
    
    # get path graph pairs: backbone edges (0, 1), (1, 2), ..., (dim-2, dim-1)
    G = nx.path_graph(dim)
    path_graph_pairs = G.edges()
    
    return pair_list + list(path_graph_pairs)

def symmetrized_edges(pairs_list):
    """Return a (2, 2E) index array that holds every edge in both directions.

    Args:
        pairs_list: E (i, j) pairs.
    Returns:
        np.ndarray of shape (2, 2E): row 0 holds sources and row 1 targets; the first E
        columns are the input edges, the last E columns their reversals.
    """
    forward = np.array(pairs_list)
    backward = forward.copy()
    # swap the two endpoint columns (right-hand side is read from `forward`, not `backward`)
    backward[:, 0], backward[:, 1] = forward[:, 1], forward[:, 0]
    return np.vstack((forward, backward)).T

# Convert every dot-bracket structure in a sequence into its adjacency matrix
def sim_adj_new(dps):
    """Stack the adjacency matrices of a sequence of dot-bracket structures.

    Args:
        dps: iterable of dot-bracket strings; they are expected to share one length N
            so that the matrices stack into a single array.
    Returns:
        np.ndarray: array of shape (len(dps), N, N), one dot2adj matrix (default flags)
        per structure.
    """
    return np.array([dot2adj(structure) for structure in dps])

In [None]:
# Adjacency matrix of every unique structure (dot2adj with its default flags)
SIMS_adj_uniq = sim_adj_new(SIMS_dp_uniq)
print(SIMS_adj_uniq.shape)

# Per-visit matrices can be recovered on demand via coord_id_S (not materialized here):
# SIMS_adj = SIMS_adj_uniq[coord_id_S]
# print(SIMS_adj.shape)

#### Expected holding time

In [None]:
def sim_ht_new(sim_T, sim_dp):
    """Compute the holding time of every visited state, trajectory by trajectory.

    A new trajectory is taken to start wherever ``sim_T == 0``. Within a trajectory the
    holding time of a state is the time until the next state; the last state of each
    trajectory gets 0. Entries before the first zero of sim_T are dropped.

    Args:
        sim_T (np.ndarray): 1-D array of concatenated per-visit times.
        sim_dp: unused; kept for interface compatibility.

    Returns:
        (np.ndarray, np.ndarray): float holding times (one per visit from the first
        trajectory start on), and the indices where the holding time is 0, which mark
        trajectory ends.
        NOTE(review): a visit whose successor has the same time also gets 0 and is then
        reported as a trajectory end.
    """
    starts = np.where(sim_T == 0)[0]
    # trajectory k spans [starts[k], starts[k+1]); the last one runs to the end of sim_T
    stops = list(starts[1:]) + [None]

    pieces = [np.concatenate([np.diff(sim_T[lo:hi]), [0]]) for lo, hi in zip(starts, stops)]
    # Concatenate once instead of growing the result with np.append in the loop
    # (which copies the whole array every iteration). The leading empty float array
    # keeps the float64 result dtype, also for integer times and for no trajectories.
    sim_HT = np.concatenate([np.array([]), *pieces])

    # get each individual trajectory's end index
    trj_id = np.where(sim_HT == 0)[0]

    return sim_HT, trj_id

In [None]:
# Per-visit holding times; trj_id marks the visits with zero holding time (trajectory ends)
SIMS_HT, trj_id = sim_ht_new(SIMS_T, SIMS_dp)
print(SIMS_HT.shape)
print(trj_id.shape)

# get unique holding time of unique states
# NOTE(review): mean_holdingtime comes from the vida.data_processing star imports;
# presumably it averages SIMS_HT over all visits of each unique state — confirm there.
SIMS_HT_uniq = mean_holdingtime(SIMS_HT, indices_S, coord_id_S)
print(SIMS_HT_uniq.shape)

#### Cumulative holding time

In [None]:
# Cumulative holding time per unique state.
# NOTE(review): cumu_holdingtime comes from the vida.data_processing star imports; the
# meaning of cumu_account_uniq (presumably a per-state count) should be confirmed there.
SIMS_cumu_HT_uniq,cumu_account_uniq = cumu_holdingtime(SIMS_HT, indices_S, coord_id_S)
print(SIMS_cumu_HT_uniq.shape, cumu_account_uniq.shape)

#### Scattering coefficient

In [None]:
# Scattering-transform features of every unique state's adjacency matrix
# (transform_dataset / get_normalized_moments from vida.model.scatter_transform);
# squeeze() drops singleton axes from the returned moments.
scat_coeff_array_S = transform_dataset(SIMS_adj_uniq)
SIMS_scar_uniq = get_normalized_moments(scat_coeff_array_S).squeeze()

In [None]:
# Expand the unique-state features back to one row per visit
SIMS_scar = SIMS_scar_uniq[coord_id_S]
SIMS_scar.shape

### Save data for training

In [None]:
%%script false --no-raise-error

# Save the preprocessed arrays so the "Load saved preprocessed data" cell can skip the steps above.
# open() does not create missing directories, so make sure the target directory exists first.
os.makedirs('./data/jordan/pretraining', exist_ok=True)
with open(f'./data/jordan/pretraining/{rxn_name}.npz', 'wb') as f:
    np.savez(f,
             SIMS_dp_uniq = SIMS_dp_uniq,
             SIMS_dp_og_uniq = SIMS_dp_og_uniq,
             SIMS_pair_uniq = SIMS_pair_uniq,
             SIMS_G_uniq = SIMS_G_uniq,
             SIMS_T = SIMS_T,  
             SIMS_HT = SIMS_HT,
             SIMS_HT_uniq = SIMS_HT_uniq,
             SIMS_cumu_HT_uniq = SIMS_cumu_HT_uniq,
             SIMS_adj_uniq = SIMS_adj_uniq,
             SIMS_scar_uniq = SIMS_scar_uniq,
             SIMS_type_uniq=SIMS_type_uniq,
             cumu_account_uniq=cumu_account_uniq,
             trj_id = trj_id,
             indices_S = indices_S,
             coord_id_S = coord_id_S,
             )


## Load saved preprocessed data

In [None]:
# # load data
fnpz_data = f"./data/jordan/pretraining/{rxn_name}.npz"
# # assign every stored array to a notebook global of the same name.
# globals() states the intent explicitly (at notebook top level locals() is the same dict).
# The context manager closes the .npz file; each array is read into memory when accessed.
with np.load(fnpz_data) as data_npz:
    for var in data_npz.files:
        globals()[var] = data_npz[var]
        print(var, globals()[var].shape)

# recover full (per-visit) data based on coord_id, indices, and unique data
SIMS_dp = SIMS_dp_uniq[coord_id_S]
SIMS_dp_og = SIMS_dp_og_uniq[coord_id_S]
SIMS_pair = SIMS_pair_uniq[coord_id_S]
SIMS_G = SIMS_G_uniq[coord_id_S]
SIMS_type = SIMS_type_uniq[coord_id_S]
SIMS_cumu_HT = SIMS_cumu_HT_uniq[coord_id_S]
cumu_account = cumu_account_uniq[coord_id_S]

SIMS_adj = SIMS_adj_uniq[coord_id_S]
SIMS_scar = SIMS_scar_uniq[coord_id_S]


## Construct Graph

### Construct weights (expected holding time)

In [None]:
# Summary of the unique-state holding times: max, smallest non-zero value, std, and shape
print(SIMS_HT_uniq.max(), SIMS_HT_uniq[SIMS_HT_uniq!=0].min(), SIMS_HT_uniq.std(), SIMS_HT_uniq.shape)

### Construct edges

In [None]:
# Graph nodes are unique-state ids listed in visit order (one entry per visit)
all_nodes = coord_id_S
# print("Initial node:", all_nodes[0], "  Final node:", all_nodes[-1])
# The first visit is not reported as a single initial node (presumably trajectories do not
# all start from the same state — confirm); only the last visit's node is printed.
print("Initial node: not unique", "  Final node:", all_nodes[-1])


In [None]:
# Consecutive-visit edges (node_k, node_{k+1}) over the concatenated trajectories.
# This also links the last state of each trajectory to the first state of the next one;
# those cross-trajectory edges are removed in the next cell.
all_edges_temp = list(zip(all_nodes, all_nodes[1:]))

In [None]:
# Remove the cross-trajectory edges. Edge k joins visit k to visit k+1, and each trj_id
# entry is the last visit of a trajectory, so edge trj_id[j] (for every trajectory but the
# final one) links the end of one trajectory to the start of the next.
cross_traj_edges = set(np.asarray(trj_id[:-1]).tolist())

# Filter in a single pass; deleting list entries one index at a time costs
# O(len(all_edges)) per deletion.
all_edges = [edge for k, edge in enumerate(all_edges_temp) if k not in cross_traj_edges]

In [None]:
# Sanity check: the overall last visit should not be linked back to the overall first visit
if (all_nodes[-1],all_nodes[0]) in all_edges:
    print("yes")
else:
    print(f"There are no edges from final node {all_nodes[-1]} to initial node {all_nodes[0]}")

# all_edges_temp is the edge list before pruning, all_edges the list after pruning
print("before prune: ", len(all_edges_temp))
print("after prune: ", len(all_edges))
print("difference: ", len(all_edges_temp) - len(all_edges))

### Construct modified undirected weighted graph

In [None]:
# Undirected graph over unique states. Each edge weight is the smaller mean holding time
# of its two endpoints; if either endpoint has zero holding time, the weight is the sum
# of the two (i.e. the other endpoint's holding time).
MUG = nx.Graph()

for src, dst in all_edges:
    ht_src = SIMS_HT_uniq[src]
    ht_dst = SIMS_HT_uniq[dst]

    if ht_src == 0 or ht_dst == 0:
        weight = ht_src + ht_dst
    else:
        weight = ht_src if ht_src < ht_dst else ht_dst

    MUG.add_edge(int(src), int(dst), weight = float(weight))

In [None]:
# Graph size and, as a spot check, the data of edge (0, 1) (None if that edge does not exist)
len(MUG.edges), len(MUG.nodes), MUG.get_edge_data(0,1)

### Collect and save shortest paths to the KNN=100 nearest neighbors

In [None]:
%%script false --no-raise-error

# # create a directory to store the shortest path
directory = f'./data/jordan/graph/{rxn_name}'
if not os.path.exists(directory):
    os.makedirs(directory)
    print("Directory created successfully!")
else :
    print("Directory already exists")

# # collect the shortest path for each node
knn = 100

for i in range(len(SIMS_HT_uniq)):
    # Weighted shortest-path lengths from node i to every reachable node.
    # NOTE: the [1:knn+1] slice relies on networkx returning nodes in the order they are
    # settled (non-decreasing distance), so entry 0 is i itself at distance 0.
    # NOTE(review): raises nx.NodeNotFound if unique state i never appears in an edge.
    length = nx.single_source_dijkstra_path_length(MUG, i)
    length_arr = np.array(list(length.items()), dtype=object)
    with open(f'{directory}/{i}.npz', 'wb') as f:
        np.savez(f,
                 X_j = length_arr[1:knn+1,0],
                 D_ij = length_arr[1:knn+1,1],
             )
    if i % 5000 == 0:
        print(i)

# # load saved distance matrix: open each file once and read both arrays
# (allow_pickle is required because the slices above are object arrays)
X_j = []
D_ij = []
for i in range(len(SIMS_HT_uniq)):
    with np.load(f'{directory}/{i}.npz', allow_pickle=True) as saved:
        X_j.append(saved["X_j"])
        D_ij.append(saved["D_ij"])

# # pad nodes with fewer than knn reachable neighbours: extra slots point back to the
# # node itself with distance 0
for i in range(len(X_j)):
    if len(X_j[i]) < knn:
        X_j[i] = np.pad(X_j[i], (0, knn-len(X_j[i])), 'constant', constant_values=i)
        D_ij[i] = np.pad(D_ij[i], (0, knn-len(D_ij[i])), 'constant', constant_values=0)

# # convert to numpy array
X_j = np.array(X_j, dtype=int)
D_ij = np.array(D_ij, dtype=float)

# # save npz file for shortest path
with open(f'./data/jordan/graph/{rxn_name}_shortestpath_knn={knn}.npz', 'wb') as f:
    np.savez(f,
             X_j = X_j,
             D_ij = D_ij,
         )

### Load saved graph info

In [None]:
# Load the saved k-nearest-neighbour shortest paths (knn=100): open the file once and
# read both arrays; the context manager closes it afterwards.
with np.load(f'./data/jordan/graph/{rxn_name}_shortestpath_knn=100.npz',allow_pickle=True) as shortest_paths:
    X_j = np.array(shortest_paths["X_j"], dtype=int)
    D_ij = np.array(shortest_paths["D_ij"], dtype=float)

X_j.shape, D_ij.shape

### Normalize distance 

In [None]:
from sklearn.preprocessing import MinMaxScaler

# Rescale the shortest-path distances to the range [0, 3]
scaler = MinMaxScaler(feature_range=(0,3))  # range [0,3]
# scaler = MinMaxScaler(feature_range=(1,10))  # range [1,10]

# NOTE(review): MinMaxScaler scales every column independently, so each neighbour rank
# (column j of D_ij) gets its own min/max. If one scale over all distances was intended,
# fit on D_ij.reshape(-1, 1) instead — confirm.
norm_Dij = scaler.fit_transform(D_ij) 
print(norm_Dij.max(), norm_Dij.min(), norm_Dij.mean(), norm_Dij.std())

### Find the importance weight 

In [None]:
# Importance weight of each unique state: the fraction of trajectories that visit it at
# least once. trj_id holds the index of each trajectory's last visit, so trajectory k
# covers coord_id_S[split_id[k-1]:split_id[k]] (starting at 0 for k = 0).
split_id = trj_id + 1 # index for split to each trajectory
P_tot = np.zeros(len(SIMS_dp_uniq))

for i in range(len(split_id)):
    start = 0 if i == 0 else split_id[i-1]
    trj = set(coord_id_S[start:split_id[i]])
    P_tot[list(trj)] += 1  # each state counted at most once per trajectory

# Divide by the number of trajectories so P_tot is a probability in [0, 1]; a fixed
# divisor (such as 100) is only correct when there are exactly that many trajectories.
# NOTE(review): the distance-loss weights are P_tot[i] * P_tot[j], so config.delta tuned
# against a /100 scale may need re-tuning.
n_trajectories = max(len(split_id), 1)
P_tot = P_tot / n_trajectories

P_tot.shape, P_tot.max(), P_tot.min()

## Edit Distance

In [None]:
%%script false --no-raise-error

def edit_distance_old(str1, str2):
    """Levenshtein distance between two sequences (unit cost insert/delete/substitute).

    Dynamic programming that keeps only the previous row of the (m+1) x (n+1) table.
    """
    # distances from the empty prefix of str1 to every prefix of str2
    previous = list(range(len(str2) + 1))
    for i, ch1 in enumerate(str1, start=1):
        current = [i]
        for j, ch2 in enumerate(str2, start=1):
            substitution = previous[j - 1] + (0 if ch1 == ch2 else 1)
            deletion = previous[j] + 1
            insertion = current[j - 1] + 1
            current.append(min(deletion, insertion, substitution))
        previous = current
    return previous[-1]

In [None]:
def edit_distance(adj1, adj2):
    """L1 distance between two adjacency matrices, accumulated as an integer.

    For 0/1 matrices this is the number of differing entries; for symmetric matrices
    every added or removed undirected edge contributes 2.
    """
    return np.abs(adj1 - adj2).sum(dtype=int)

In [None]:
# Edit distance between each unique state X_i and each of its k shortest-path neighbours X_j
ED_ij = np.array(
    [[edit_distance(SIMS_adj_uniq[i], SIMS_adj_uniq[j]) for j in neighbours]
     for i, neighbours in enumerate(X_j)],
    dtype=int,
)

with open(f'./data/jordan/graph/{rxn_name}_edit_distance.npz', 'wb') as f:
    np.savez(f,
             ED_ij = ED_ij,
         )

In [None]:
# Summary statistics (max, min, mean, std, shape) of the neighbour edit distances
print(ED_ij.max(), ED_ij.min(), ED_ij.mean(), ED_ij.std(), ED_ij.shape)

### Load edit distance

In [None]:
# Load the saved neighbour edit distances (as float) and summarize them.
# The context manager closes the .npz file once the array has been read.
with np.load(f'./data/jordan/graph/{rxn_name}_edit_distance.npz',allow_pickle=True) as saved_ed:
    ED_ij = np.array(saved_ed["ED_ij"], dtype=float)
print(ED_ij.max(), ED_ij.min(), ED_ij.mean(), ED_ij.std(), ED_ij.shape)

## ViDa Model

### Make dataset

In [None]:
# Dataset of (scattering features, energy, unique-state index) triples, one per unique state;
# the index lets the training/validation loops look up X_j, Dij, ED_ij and P_tot per sample.
data_tup = (torch.Tensor(SIMS_scar_uniq),
            torch.Tensor(SIMS_G_uniq),
            torch.arange(len(SIMS_scar_uniq)))
data_dataset = torch.utils.data.TensorDataset(*data_tup)

In [None]:
# split the dataset into train (70%) and validation (30%) with a fixed generator seed
data_size = len(data_dataset)
train_size = int(0.7 * data_size)
val_size = data_size - train_size

train_data, val_data = torch.utils.data.random_split(data_dataset, [train_size, val_size], 
                                                     generator=torch.Generator().manual_seed(42))

print(data_size, len(train_data), len(val_data))

### Set up configurations

In [None]:
# set up hyperparameters
input_dim = data_tup[0].shape[-1]

config = Namespace(
    reaction = rxn_name,
    select_trajs_id = select_trajs_id,
    
    type = 'vida',
    knn = 100, # must equal X_j.shape[1]: neighbour embeddings are reshaped to (-1, knn, latent)
    edit_distance = "100_neighbors",
    # Use whichever accelerator this machine has instead of hard-coding 'mps',
    # which fails on machines without Apple-silicon GPU support.
    device = ('cuda' if torch.cuda.is_available()
              else 'mps' if torch.backends.mps.is_available()
              else 'cpu'),
    log_dir = f'{time.strftime("%m%d-%H%M")}', # log directory
    batch_size = 256,
    input_dim = input_dim,
    output_dim = input_dim,
    latent_dim = 25, # bottleneck dimension
    hidden_dim = 400,
    n_epochs = 60,
    # learning_rate = 1e-4, # learning rate
    learning_rate = 5e-5, # learning rate
    
    log_interval = 10, # how many batches to wait before logging training status    
    # patience = 5, # how many epochs to wait before early stopping #3
    patience = 10, # how many epochs to wait before early stopping #3
    
    
    # hyperparameters for loss function
    alpha = 1, # reconstruction loss
    
    beta = 1e-4, # kl divergence
    
    gamma = 0.3, # energy loss
    
    # delta = 0.04, # distance loss
    delta = 0.0004, # distance loss
    # delta = 1e-4, # distance loss
    
    # delta = 0, # distance loss
    
    epsilon = 1e-4, # edit distance loss
    # epsilon = 5e-5, # edit distance loss
    
)

### Make dataloaders

In [None]:
# data_loader: all unique states in fixed order; train_loader: shuffled training split;
# val_loader: validation split in fixed order
data_loader = torch.utils.data.DataLoader(data_dataset, batch_size=config.batch_size,
                                          shuffle=False, num_workers=0)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, 
                                           shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=config.batch_size,
                                         shuffle=False, num_workers=0)

In [None]:
# dataset size, number of batches and batch size of each loader
print('data_loader: ', len(data_loader.dataset), len(data_loader), data_loader.batch_size)
print('train_loader: ', len(train_loader.dataset), len(train_loader), train_loader.batch_size)
print('val_loader: ', len(val_loader.dataset), len(val_loader), val_loader.batch_size)

### Encoder

In [None]:
class Encoder(nn.Module):
    """MLP encoder mapping a feature vector to the mean and log-variance of a diagonal
    Gaussian over the latent space."""
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        '''
        Args:
        ----
            - input_dim: the dimension of the input node feature
            - hidden_dim: the width of both hidden layers
            - latent_dim: the dimension of the latent space (bottleneck layer)
        '''
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        # Every hidden layer uses hidden_dim so the layer widths always agree with the
        # constructor argument (a hard-coded 400 breaks any config with hidden_dim != 400).
        self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
        self.bn1 = nn.BatchNorm1d(self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.bn2 = nn.BatchNorm1d(self.hidden_dim)
        
        # Split the result into mu and var components
        # of the latent Gaussian distribution, note how we only output
        # diagonal values of covariance matrix. Here we assume
        # they are conditionally independent
        self.hid2mu = nn.Linear(self.hidden_dim, self.latent_dim)
        self.hid2logvar = nn.Linear(self.hidden_dim, self.latent_dim)
        
    def forward(self, x):
        """Return (mu, logvar), each of shape (batch, latent_dim)."""
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        mu = self.hid2mu(x)
        logvar = self.hid2logvar(x)
        return mu, logvar


### Decoder

In [None]:
class Decoder(nn.Module):
    """MLP decoder mapping a latent vector back to the feature space."""
    
    def __init__(self, latent_dim, hidden_dim, output_dim):
        '''
        Args:
        ----
            - latent_dim: the dimension of the latent space (bottleneck layer)
            - hidden_dim: the width of both hidden layers
            - output_dim: the dimension of the output node feature
        '''
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        # fc1 must output hidden_dim features because bn1 normalizes hidden_dim channels;
        # a hard-coded 400 here mismatched bn1 whenever hidden_dim != 400.
        self.fc1 = nn.Linear(self.latent_dim, self.hidden_dim)
        self.bn1 = nn.BatchNorm1d(self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.hidden_dim)
        self.bn2 = nn.BatchNorm1d(self.hidden_dim)
        self.fc3 = nn.Linear(self.hidden_dim, self.output_dim)
        
    def forward(self, z):
        """Map z of shape (batch, latent_dim) to a reconstruction of shape (batch, output_dim)."""
        x = self.bn1(F.relu(self.fc1(z)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

### Regressor

In [None]:
class Regressor(nn.Module):
    """Small MLP head that predicts a scalar energy from a latent vector."""
    
    def __init__(self, latent_dim, hidden_dim=15):
        '''
        The regressor is used to predict the energy of the node
        
        Args:
        ----
            - latent_dim: the dimension of the latent space (bottleneck layer)
            - hidden_dim: width of the single hidden layer (default 15)
        '''
        super(Regressor, self).__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        
        self.regfc1 = nn.Linear(self.latent_dim, self.hidden_dim)
        self.regfc2 = nn.Linear(self.hidden_dim, 1)
        
    def forward(self, z):
        """Map z of shape (batch, latent_dim) to predictions of shape (batch, 1)."""
        y = F.relu(self.regfc1(z))
        y = self.regfc2(y)
        return y

### VIDA model

In [None]:
class VIDA(nn.Module):
    """Variational autoencoder with an auxiliary regression head on the latent sample."""
    
    def __init__(self, encoder, decoder, regressor):
        '''
        Args:
        ----
            - encoder: module mapping x -> (mu, logvar)
            - decoder: module mapping a latent sample z -> reconstruction of x
            - regressor: module mapping a latent sample z -> predicted energy
        '''
        super(VIDA, self).__init__()
        # registered in this order: encoder, decoder, regressor
        self.encoder, self.decoder, self.regressor = encoder, decoder, regressor
        
    def reparameterize(self, mu, logvar):
        """Sample z = mu + eps * sigma with eps ~ N(0, I) and sigma = exp(0.5 * logvar)."""
        sigma = torch.exp(0.5*logvar)
        noise = torch.randn_like(sigma)
        return mu + noise*sigma
        
    def forward(self, x):
        """Return (x_recon, y_pred, z, mu, logvar) for a batch x."""
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        return self.decoder(z), self.regressor(z), z, mu, logvar

### Loss functions

In [None]:
def vae_loss(x_recon, x, mu, logvar):
    '''
    Compute the two VAE loss terms.
    
    Args:
        - x_recon: the reconstructed node feature
        - x: the original node feature
        - mu: the mean of the latent distribution
        - logvar: the log variance of the latent distribution
    
    Returns:
        - (recon, kld): the mean squared reconstruction error over all elements (an L2
          loss, despite the historical "BCE" name) and the KL divergence of
          N(mu, exp(logvar)) from N(0, I), summed over every element.
    '''
    recon = F.mse_loss(x_recon.flatten(), x.flatten()) # L2 loss
    # recon = F.l1_loss(x_recon.flatten(), x.flatten()) # L1 alternative
    kld = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp()).sum()
    return recon, kld


def pred_loss(y_pred, y):
    '''
    Compute the energy prediction loss (mean squared error over all elements).
    
    Args:
    ----
        - y_pred: the predicted energy of the node
        - y: the true energy of the node
    
    Returns:
        - loss: scalar PyTorch Tensor
    '''
    flat_pred, flat_true = y_pred.flatten(), y.flatten()
    return F.mse_loss(flat_pred, flat_true)


def distance_knn_loss_torch(config, zi, zj, Dij, P_tot, idx, batchXj_id):
    '''
    Importance-weighted squared error between embedding distances and target graph
    distances for each node i and its neighbours j:

        loss = sum_{i,j} P_tot[i] * P_tot[j] * (||z_i - z_j||_2 - Dij[i, j])^2
    
    Args:
        - config: provides `.device`, where Dij[idx] and the weights are moved
        - zi: embeddings of the nodes i, reshaped to (batch, 1, latent)
        - zj: embeddings of the neighbours j, broadcast against zi
          (presumably (batch, k, latent))
        - Dij: target distance between each node and its neighbours, indexed by node
        - P_tot: per-node importance weight
        - idx: indices of the nodes i into Dij / P_tot
        - batchXj_id: indices of the neighbours j into P_tot
    
    Returns:
        - scalar PyTorch Tensor: the summed weighted loss
    '''
    anchor = zi.reshape(-1, 1, zi.shape[-1])
    emb_dist = ((anchor - zj)**2).sum(dim=-1).sqrt()
    sq_err = (emb_dist - Dij[idx].to(config.device))**2

    # importance weight of each (i, j) pair
    weights = (P_tot[idx].reshape(-1, 1) * P_tot[batchXj_id]).to(config.device)
    return (sq_err * weights).sum()

def edit_distance_loss(config, zi, zj, ED_ij, idx):
    '''
    Unweighted squared error between embedding distances and edit distances:

        loss = sum_{i,j} (||z_i - z_j||_2 - ED_ij[i, j])^2
    
    Args:
        - config: provides `.device`, where ED_ij[idx] is moved
        - zi: embeddings of the nodes i, reshaped to (batch, 1, latent)
        - zj: embeddings of the neighbours j, broadcast against zi
          (presumably (batch, k, latent))
        - ED_ij: edit distance between each node and its neighbours, indexed by node
        - idx: indices of the nodes i into ED_ij
    
    Returns:
        - scalar PyTorch Tensor: the summed loss
    '''
    anchor = zi.reshape(-1, 1, zi.shape[-1])
    emb_dist = ((anchor - zj)**2).sum(dim=-1).sqrt()
    target = ED_ij[idx].to(config.device)
    return ((emb_dist - target)**2).sum()
    

### Early-stopping function

In [None]:
def early_stop(val_loss, epoch, patience):
    """
    Check if validation loss has not improved for a certain number of epochs.

    State lives on the function object: early_stop.best_loss and
    early_stop.num_epochs_without_improvement. Epoch 0 (re)initializes both, and missing
    attributes fall back to (inf, 0).
    
    Args:
        val_loss (float): current validation loss
        epoch (int): current epoch number
        patience (int): number of epochs to wait before stopping if validation loss does not improve
        
    Returns:
        bool: True once `patience` consecutive epochs have failed to improve on the best
        loss seen so far, False otherwise
    """
    if epoch == 0:
        # The first epoch's loss becomes the baseline; otherwise it would never count
        # and epoch 1 would always register as an improvement.
        early_stop.best_loss = val_loss
        early_stop.num_epochs_without_improvement = 0
        return False

    best_loss = getattr(early_stop, "best_loss", float("inf"))
    # `not (val_loss < best_loss)` also treats a NaN loss as "no improvement"
    # (with `>=`, NaN would be recorded as the new best and training would never stop).
    if not (val_loss < best_loss):
        early_stop.num_epochs_without_improvement = (
            getattr(early_stop, "num_epochs_without_improvement", 0) + 1
        )
        if early_stop.num_epochs_without_improvement >= patience:
            print("Stopping early")
            return True
        return False

    # Validation loss improved, reset counter
    early_stop.best_loss = val_loss
    early_stop.num_epochs_without_improvement = 0
    return False

### Train VIDA

In [None]:
# validation function
def validate(config, model, data_loader, val_loader, P_tot, Dij, ED_ij, X_j,
             vae_loss, pred_loss, distance_knn_loss, edit_distance_loss):
    '''
    Evaluate the weighted VIDA loss on the validation set.

    For each validation batch the model first embeds the k neighbours of every node
    (features looked up through X_j in data_loader.dataset), then the batch itself; the
    five loss terms are scaled by config.alpha, beta, gamma, delta and epsilon.

    Args:
    ----
        - config: experiment configuration (device, knn, input_dim, loss weights)
        - model: called as model(x) -> (x_recon, y_pred, z, mu, logvar)
        - data_loader: loader whose dataset is the full TensorDataset (neighbour features)
        - val_loader: loader over the validation split, yielding (x, y, idx)
        - P_tot, Dij, ED_ij: passed through to the distance losses
        - X_j: neighbour indices, indexed by node id
        - vae_loss, pred_loss, distance_knn_loss, edit_distance_loss: loss callables

    Returns:
        tuple of six floats: (total, reconstruction, KL, energy, distance, edit-distance)
        loss, each summed over batches and divided by the number of validation samples.
    '''
    model.to(config.device)
    model.eval()

    n_samples = len(val_loader.dataset)
    # running sums, in order: total, recon, kl, pred, dist, edit
    sums = [0, 0, 0, 0, 0, 0]

    # Disable gradient calculation to speed up inference
    with torch.no_grad():
        for x, y, idx in val_loader:
            x = x.to(config.device)
            y = y.to(config.device)

            # embed the k neighbours of each node: (batch*k, input_dim) -> (batch, k, latent)
            batchXj_id = X_j[idx]
            neighbor_input = data_loader.dataset.tensors[0][batchXj_id].reshape(-1, config.input_dim).to(config.device)
            _, _, neighbor_embed, _, _ = model(neighbor_input)
            neighbor_embed = neighbor_embed.reshape(-1, config.knn, neighbor_embed.shape[-1])

            # embed the batch itself
            x_recon, y_pred, z, mu, logvar = model(x)

            recon_raw, kl_raw = vae_loss(x_recon, x, mu, logvar)
            weighted = [
                config.alpha * recon_raw.item(),
                config.beta * kl_raw.item(),
                config.gamma * pred_loss(y_pred, y).item(),
                config.delta * distance_knn_loss(config, z, neighbor_embed, Dij, P_tot, idx, batchXj_id).item(),
                config.epsilon * edit_distance_loss(config, z, neighbor_embed, ED_ij, idx).item(),
            ]
            batch_total = weighted[0] + weighted[1] + weighted[2] + weighted[3] + weighted[4]
            for slot, value in enumerate([batch_total] + weighted):
                sums[slot] += value

    averages = tuple(total / n_samples for total in sums)
    print('Validation Loss: {:.4f}'.format(averages[0]))

    # Clear the cache
    torch.cuda.empty_cache()
    # torch.mps.empty_cache()

    return averages
            

In [None]:
def train(config, model, data_loader, train_loader, val_loader, P_tot, Dij, ED_ij, X_j,
          optimizer, vae_loss, pred_loss, distance_knn_loss, edit_distance_loss, early_stop,):
    '''
    Train VIDA!

    For every mini-batch the k nearest neighbours of the batch nodes are first embedded
    with the model in eval mode and without gradients; then one optimisation step is taken
    on the weighted sum of five losses (reconstruction, KL, energy prediction, k-NN
    distance, edit distance). After every epoch the model is evaluated with `validate`,
    all losses are logged to TensorBoard, and `early_stop` decides whether to stop.
    The config is written to ./model_config/<config.log_dir>/hparams.yaml and the final
    weights to ./model_config/<config.log_dir>/model.pt.

    Args:
    ----
        - config: Experiment configurations (reads device, log_dir, n_epochs, input_dim, knn,
          log_interval, patience and the loss weights alpha, beta, gamma, delta, epsilon)
        - model: Pytorch VIDA model; model(x) returns (x_recon, y_pred, z, mu, logvar)
        - data_loader: Pytorch DataLoader for all data; its dataset.tensors[0] is indexed
          to fetch the inputs of neighbour nodes
        - train_loader: Pytorch DataLoader for training set, yielding (x, y, idx)
        - val_loader: Pytorch DataLoader for validation set
        - P_tot: numpy array, the total probability of each node
        - Dij: numpy array, the shortest path distance between k nearest pairs of nodes
        - ED_ij: numpy array (cast to int), presumably the edit distance between k nearest
          pairs of nodes -- confirm against edit_distance_loss
        - X_j: numpy array, the index of k nearest neighbours of each node
        - optimizer: Pytorch optimizer
        - vae_loss: the VAE loss function, returning (recon_loss, kl_loss)
        - pred_loss: the energy prediction loss function
        - distance_knn_loss: the k nearest neighbour distance loss function
        - edit_distance_loss: the edit distance loss function
        - early_stop: early-stopping object with `best_loss` and
          `num_epochs_without_improvement` attributes; early_stop(val_loss, epoch, patience)
          returns True when training should stop
    '''

    model.to(config.device)

    log_dir = f'./model_config/{config.log_dir}'
    writer = SummaryWriter(log_dir=log_dir)

    # Make sure the TensorBoard writer is closed even if training raises part-way through.
    try:
        # save the config file next to the TensorBoard logs
        with open(f'{log_dir}/hparams.yaml', 'w') as f:
            yaml.dump(config, f)

        # Reset the early-stop state so every call starts fresh.
        # np.inf rather than the np.Inf alias, which NumPy 2.0 removed.
        early_stop.best_loss = np.inf
        early_stop.num_epochs_without_improvement = 0

        # convert the numpy arrays to (CPU) tensors
        P_tot = torch.from_numpy(P_tot.astype(np.float32))
        Dij = torch.from_numpy(Dij.astype(np.float32))
        ED_ij = torch.from_numpy(ED_ij.astype(int))
        X_j = torch.from_numpy(X_j)

        print('\n ------- Start Training -------')
        for epoch in range(config.n_epochs):
            start_time = time.time()
            training_loss = 0

            for batch_idx, (x, y, idx) in enumerate(train_loader):  # mini batch

                # Configure input
                x = x.to(config.device)
                y = y.to(config.device)

                # Embed the k nearest neighbours of every batch node with the current
                # weights; no_grad means these embeddings act as fixed targets.
                model.eval()
                with torch.no_grad():
                    batchXj_id = X_j[idx]
                    neighbor_input = data_loader.dataset.tensors[0][batchXj_id].reshape(-1, config.input_dim).to(config.device)
                    _, _, neighbor_embed, _, _ = model(neighbor_input)
                    neighbor_embed = neighbor_embed.reshape(-1, config.knn, neighbor_embed.shape[-1])

                # ------------------------------------------
                #  Train VIDA
                # ------------------------------------------
                model.train()
                optimizer.zero_grad()

                # get the reconstructed nodes, predicted energy, and the embeddings
                x_recon, y_pred, z, mu, logvar = model(x)

                # individual losses
                recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)
                p_loss = pred_loss(y_pred, y)
                dist_loss = distance_knn_loss(config, z, neighbor_embed, Dij, P_tot, idx, batchXj_id)
                edit_loss = edit_distance_loss(config, z, neighbor_embed, ED_ij, idx)

                # scaling the loss
                recon_loss = config.alpha * recon_loss
                kl_loss = config.beta * kl_loss
                p_loss = config.gamma * p_loss
                dist_loss = config.delta * dist_loss
                edit_loss = config.epsilon * edit_loss

                # total loss
                loss = recon_loss + kl_loss + p_loss + dist_loss + edit_loss

                # backpropagation and optimization
                loss.backward()
                optimizer.step()

                training_loss += loss.item()

                # ------------------------------------------
                # Log Progress
                # ------------------------------------------
                if batch_idx % config.log_interval == 0:
                    print('Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                        epoch, config.n_epochs, batch_idx * len(x), len(train_loader.dataset),
                        100. * batch_idx / len(train_loader),
                        loss.item()))

                    global_step = epoch * len(train_loader) + batch_idx
                    for tag, value in (('training loss', loss),
                                       ('recon loss', recon_loss),
                                       ('kl loss', kl_loss),
                                       ('pred loss', p_loss),
                                       ('dist loss', dist_loss),
                                       ('edit loss', edit_loss)):
                        writer.add_scalar(tag, value.item(), global_step)

            print('====> Epoch: {} Average loss: {:.4f}'.format(epoch, training_loss/len(train_loader.dataset)))
            writer.add_scalar('epoch training loss', training_loss/len(train_loader.dataset), epoch)

            # validation
            val_loss, val_bce, val_kld, val_pred, val_dist, val_edit = validate(config, model, data_loader, val_loader, P_tot, Dij, ED_ij, X_j,
                                                                      vae_loss, pred_loss, distance_knn_loss, edit_distance_loss)
            writer.add_scalar('validation loss', val_loss, epoch)
            writer.add_scalar('val_recon loss', val_bce, epoch)
            writer.add_scalar('val_kl loss', val_kld, epoch)
            writer.add_scalar('val_pred loss', val_pred, epoch)
            writer.add_scalar('val_dist loss', val_dist, epoch)
            writer.add_scalar('val_edit loss', val_edit, epoch)

            epoch_time = time.time() - start_time
            print(f'Epoch {epoch} train+val time: {epoch_time:.2f} seconds \n')

            # Check if validation loss has not improved for `patience` epochs
            if early_stop(val_loss, epoch, config.patience):
                break

            # Clear the cache
            # torch.mps.empty_cache()
            torch.cuda.empty_cache()
    finally:
        writer.close()

    print('\n ------- Finished Training -------')

    # save the model (weights after the last completed epoch)
    torch.save(model.state_dict(), f'{log_dir}/model.pt')
    

In [None]:
# Build the three VIDA sub-networks from the experiment configuration.
encoder = Encoder(
    input_dim=config.input_dim,
    latent_dim=config.latent_dim,
    hidden_dim=config.hidden_dim,
)
decoder = Decoder(
    output_dim=config.output_dim,
    latent_dim=config.latent_dim,
    hidden_dim=config.hidden_dim,
)
regressor = Regressor(latent_dim=config.latent_dim)

# Assemble the full model from its parts.
vida = VIDA(encoder, decoder, regressor)

# A single Adam optimizer over every VIDA parameter.
optimizer = torch.optim.Adam(params=vida.parameters(), lr=config.learning_rate)

In [None]:
# train VIDA; the normalised shortest-path distances (norm_Dij) play the role of Dij
train(
    config, vida, data_loader, train_loader, val_loader,
    P_tot=P_tot,
    Dij=norm_Dij,
    ED_ij=ED_ij,
    X_j=X_j,
    optimizer=optimizer,
    vae_loss=vae_loss,
    pred_loss=pred_loss,
    distance_knn_loss=distance_knn_loss_torch,
    edit_distance_loss=edit_distance_loss,
    early_stop=early_stop,
)

### Load saved trained model

In [None]:
# model = VIDA(encoder, decoder, regressor)
# model.load_state_dict(torch.load('./model_config/0805-0325/model.pt'))

### Get embeddings

In [None]:
# Wrap the encoder/decoder/regressor objects in a VIDA container for inference.
# NOTE(review): these are the same sub-module objects that were passed to `vida` for training,
# so this presumably shares the trained weights -- confirm VIDA.__init__ does not re-initialise them.
model = VIDA(encoder, decoder, regressor)

In [None]:
# do inference: embed every state of the full dataset in one forward pass, without gradients
model.to(config.device)
model.eval()

with torch.no_grad():
    _, _, z, _, _ = model(
        data_loader.dataset.tensors[0].to(config.device)
    )

In [None]:
# Bring the latent embeddings back to host memory as a NumPy array.
data_embed = z.cpu().numpy()
data_embed.shape

In [None]:
# Quick summary of the embedding values: max, min, mean, std
print(*(stat(data_embed) for stat in (np.max, np.min, np.mean, np.std)))

### PCA

In [None]:
# Project the VIDA embeddings onto their first two principal components.
pca_coords = PCA(n_components=2).fit_transform(data_embed)

# Sanity check: total number of points vs. number of distinct 2-D coordinates
(pca_coords.shape, np.unique(pca_coords, axis=0).shape)

### PHATE

In [None]:
from sklearn.preprocessing import StandardScaler

# Standardise every latent dimension (zero mean, unit variance) before running PHATE.
# Note: data_embed is overwritten here, so later cells see the scaled array.
scaler = StandardScaler().fit(data_embed)
data_embed = scaler.transform(data_embed)

In [None]:
# Embed the (standardised) VIDA embeddings into 2-D with PHATE.
# PHATE's MDS step is randomly initialised, so fix random_state to make the
# coordinates reproducible across Restart & Run All.
phate_operator = phate.PHATE(n_jobs=-2, random_state=42)
phate_coords = phate_operator.fit_transform(data_embed)

# Sanity check: total number of points vs. number of distinct 2-D coordinates
phate_coords.shape, (np.unique(phate_coords, axis=0)).shape

### Save dimensionality reductions (DRs)

In [None]:
# %%script false --no-raise-error

# Save the embeddings and both 2-D dimensionality reductions (DRs) for plotting.
# Note: if the StandardScaler cell has run, data_embed is the standardised array.
fnpz_data = f"data/jordan/plotdata/plot_{config.log_dir}_{rxn_name}.npz"
# open()/np.savez do not create missing folders, so make sure the target directory exists.
os.makedirs(os.path.dirname(fnpz_data), exist_ok=True)
with open(fnpz_data, 'wb') as f:
    np.savez(f,
             data_embed=data_embed,
             # plotting data
             pca_coords=pca_coords,
             phate_coords=phate_coords,
             )

## Load saved embedding

In [None]:
# load plot data and expose every stored array as a notebook variable of the same name
# fnpz_data = f"data/jordan/plotdata/plot_{config.log_dir}_{rxn_name}.npz"
fnpz_data = f"data/jordan/plotdata/plot_0825-2043_{rxn_name}.npz"

# `with` closes the underlying zip file once every array has been read into memory.
with np.load(fnpz_data) as data_npz:
    # assign data to variables
    for var in data_npz.files:
        # globals() is the notebook namespace; writing through locals() only works by
        # accident at top level and silently does nothing inside a function.
        globals()[var] = data_npz[var]
        print(var, globals()[var].shape)

# 2-D coordinates for every state in the concatenated trajectories.
# NOTE(review): coord_id_S presumably maps each trajectory state to its row in the
# unique-state coordinates; it must come from the archive or an earlier cell -- confirm.
pca_all_coords = pca_coords[coord_id_S]
phate_all_coords = phate_coords[coord_id_S]

## Sort Trajectories

In [None]:
# Value ranges: cumulative holding time (ignoring zeros) and cumulative appearance count
print(SIMS_cumu_HT[SIMS_cumu_HT != 0].min(), np.max(SIMS_cumu_HT))
print(cumu_account.min(), np.max(cumu_account))

# Per-state 2-D coordinates (same mapping as in the loading cell)
pca_all_coords, phate_all_coords = pca_coords[coord_id_S], phate_coords[coord_id_S]


In [None]:
# Log-transform the cumulative holding times and cumulative appearance counts.

# States with zero cumulative holding time get the smallest non-zero value, so log() is finite.
SIMS_cumu_HT_temp = SIMS_cumu_HT.copy()
SIMS_cumu_HT_temp[SIMS_cumu_HT == 0] = np.min(SIMS_cumu_HT[SIMS_cumu_HT != 0])
# Rescale by a power of ten so the smallest value becomes >= 1 (log >= 0).
exponent = np.ceil(-np.log10(np.min(SIMS_cumu_HT_temp))).astype(int)
# Float base: an integer base raises for negative exponents (min > 1) and
# silently overflows int64 for exponents above 18.
SIMS_cumu_HT_log = np.log(SIMS_cumu_HT_temp * 10.0**exponent)

# Counts of 1 would map to log(1) = 0; replace them with 2 so every log value is > 0.
cumu_account_temp = cumu_account.copy()
cumu_account_temp[cumu_account == 1] = 2
cumu_account_log = np.log(cumu_account_temp)

print((SIMS_cumu_HT_log).min(), SIMS_cumu_HT_log.max())
print((cumu_account_log).min(), cumu_account_log.max())

# restrict to the unique states
SIMS_cumu_HT_uniq_log = SIMS_cumu_HT_log[indices_S]
cumu_account_uniq_log = cumu_account_log[indices_S]

SIMS_cumu_HT_uniq_log.shape, cumu_account_uniq_log.shape

In [None]:
# get each trajectory: split every per-state array at the trajectory boundaries

# List of arrays to split (this order is also the unpacking order below)
arrays_to_split = [SIMS_dp_og, SIMS_T, SIMS_HT, SIMS_G, SIMS_pair, SIMS_type, SIMS_cumu_HT, cumu_account, SIMS_cumu_HT_log, cumu_account_log, pca_all_coords, phate_all_coords]

# The same split points are applied to every array, so they must all have the same length.
assert len({len(arr) for arr in arrays_to_split}) == 1, \
    f"per-state arrays have different lengths: {[len(arr) for arr in arrays_to_split]}"

# trj_id presumably holds the index of the last state of each trajectory: +1 turns these
# into split points, and the final one (the end of the data) is dropped.
subtrj_id = (trj_id+1)[:-1]
sub_arrays = [np.split(arr, subtrj_id) for arr in arrays_to_split]

# Unpack the sub-arrays into separate variables
sub_SIMS_dp_og, sub_SIMS_T, sub_SIMS_HT, sub_SIMS_G, sub_SIMS_pair, sub_SIMS_type, sub_SIMS_cumu_HT, sub_cumu_account, sub_SIMS_cumu_HT_log, sub_cumu_account_log, sub_pca_all_coords, sub_phate_all_coords = sub_arrays

# Every array must yield the same number of trajectories (checks all 12, incl. sub_cumu_account_log)
assert len({len(sub) for sub in sub_arrays}) == 1, \
    f"inconsistent trajectory counts: {[len(sub) for sub in sub_arrays]}"

# Print the number of trajectories for verification
print(len(sub_SIMS_T))


In [None]:
# # separate the successful and failed trajectories

# Per-trajectory lists, in the same order as the unpacking targets below
arrays_to_extract = [sub_SIMS_dp_og, sub_SIMS_T, sub_SIMS_HT, sub_SIMS_G, sub_SIMS_pair, sub_SIMS_type, sub_SIMS_cumu_HT, sub_cumu_account, sub_SIMS_cumu_HT_log, sub_cumu_account_log, sub_pca_all_coords, sub_phate_all_coords]

# Successful trajectory ids come from the selection cell at the top; every other id is a failure.
succ_id = select_trajs_id
fail_id = np.setdiff1d(np.arange(len(sub_SIMS_T)), succ_id)

succ_sub_arrays = [[per_traj[t] for t in succ_id] for per_traj in arrays_to_extract]
fail_sub_arrays = [[per_traj[t] for t in fail_id] for per_traj in arrays_to_extract]

(succ_sub_SIMS_dp_og, succ_sub_SIMS_T, succ_sub_SIMS_HT, succ_sub_SIMS_G,
 succ_sub_SIMS_pair, succ_sub_SIMS_type, succ_sub_SIMS_cumu_HT, succ_sub_cumu_account,
 succ_sub_SIMS_cumu_HT_log, succ_sub_cumu_account_log,
 succ_sub_pca_all_coords, succ_sub_phate_all_coords) = succ_sub_arrays

(fail_sub_SIMS_dp_og, fail_sub_SIMS_T, fail_sub_SIMS_HT, fail_sub_SIMS_G,
 fail_sub_SIMS_pair, fail_sub_SIMS_type, fail_sub_SIMS_cumu_HT, fail_sub_cumu_account,
 fail_sub_SIMS_cumu_HT_log, fail_sub_cumu_account_log,
 fail_sub_pca_all_coords, fail_sub_phate_all_coords) = fail_sub_arrays

# Successes and failures together must account for every trajectory.
assert len(succ_sub_SIMS_dp_og) + len(fail_sub_SIMS_dp_og) == len(sub_SIMS_dp_og)
print(len(succ_sub_SIMS_dp_og), len(fail_sub_SIMS_dp_og))

In [None]:
# Sort the successful and failed trajectories by total reaction time (longest first)
sorted_succ_indices = np.argsort([sub_array[-1] for sub_array in succ_sub_SIMS_T])[::-1]
sorted_fail_indices = np.argsort([sub_array[-1] for sub_array in fail_sub_SIMS_T])[::-1]


def _as_object_array(items):
    """Pack a list of per-trajectory arrays into a 1-D object array (one element per trajectory).

    ``np.array(items, dtype=object)`` silently builds a 2-D/3-D array when every trajectory
    has the same length (e.g. only one failed trajectory), which breaks per-trajectory
    indexing and the DataFrame columns built later. Filling an empty array avoids that.
    """
    packed = np.empty(len(items), dtype=object)
    for k, item in enumerate(items):
        packed[k] = item
    return packed


# List of arrays to sort in descending order based on sorted_succ_indices or sorted_fail_indices
succ_arrays_to_sort = [succ_sub_SIMS_dp_og, succ_sub_SIMS_T, succ_sub_SIMS_HT, succ_sub_SIMS_G, succ_sub_SIMS_pair, succ_sub_SIMS_type, succ_sub_SIMS_cumu_HT, succ_sub_cumu_account, succ_sub_SIMS_cumu_HT_log, succ_sub_cumu_account_log, succ_sub_pca_all_coords, succ_sub_phate_all_coords]
fail_arrays_to_sort = [fail_sub_SIMS_dp_og, fail_sub_SIMS_T, fail_sub_SIMS_HT, fail_sub_SIMS_G, fail_sub_SIMS_pair, fail_sub_SIMS_type, fail_sub_SIMS_cumu_HT, fail_sub_cumu_account, fail_sub_SIMS_cumu_HT_log, fail_sub_cumu_account_log, fail_sub_pca_all_coords, fail_sub_phate_all_coords]

# Reorder every per-trajectory list with the same permutation
sorted_succ_arrays = [_as_object_array(arr)[sorted_succ_indices] for arr in succ_arrays_to_sort]
sorted_fail_arrays = [_as_object_array(arr)[sorted_fail_indices] for arr in fail_arrays_to_sort]

# Unpack the sorted arrays into separate variables
sorted_succ_sub_SIMS_dp_og, sorted_succ_sub_SIMS_T, sorted_succ_sub_SIMS_HT, sorted_succ_sub_SIMS_G, sorted_succ_sub_SIMS_pair, sorted_succ_sub_SIMS_type, sorted_succ_sub_SIMS_cumu_HT, sorted_succ_sub_cumu_account, sorted_succ_sub_SIMS_cumu_HT_log, sorted_succ_sub_cumu_account_log, sorted_succ_sub_pca_all_coords, sorted_succ_sub_phate_all_coords = sorted_succ_arrays
sorted_fail_sub_SIMS_dp_og, sorted_fail_sub_SIMS_T, sorted_fail_sub_SIMS_HT, sorted_fail_sub_SIMS_G, sorted_fail_sub_SIMS_pair, sorted_fail_sub_SIMS_type, sorted_fail_sub_SIMS_cumu_HT, sorted_fail_sub_cumu_account, sorted_fail_sub_SIMS_cumu_HT_log, sorted_fail_sub_cumu_account_log, sorted_fail_sub_pca_all_coords, sorted_fail_sub_phate_all_coords = sorted_fail_arrays

sorted_succ_sub_SIMS_dp_og.shape, sorted_fail_sub_SIMS_dp_og.shape

In [None]:
# # for traj 42
# print(sorted_succ_sub_SIMS_HT[0][sorted_succ_sub_SIMS_type[0]=="SMH"].sum())
# print(sorted_succ_sub_SIMS_HT[0][sorted_succ_sub_SIMS_type[0]=="0MH"].sum())
# print(sorted_succ_sub_SIMS_HT[0][sorted_succ_sub_SIMS_type[0]=="SM0"].sum())

## Interactive Plot

In [None]:
# make dataframes for plotting
# df: one row per unique state
df = pd.DataFrame({
    "Energy": SIMS_G_uniq,
    "Pair": SIMS_pair_uniq,
    "DP": SIMS_dp_og_uniq,
    "HT": SIMS_HT_uniq,
    "Type": SIMS_type_uniq,
    "HT_cumu": SIMS_cumu_HT_uniq,
    "Account_cumu": cumu_account_uniq,
    "HT_cumu_log": SIMS_cumu_HT_uniq_log,
    "Account_cumu_log": cumu_account_uniq_log,
    "PCA 1": pca_coords[:, 0],
    "PCA 2": pca_coords[:, 1],
    "PHATE 1": phate_coords[:, 0],
    "PHATE 2": phate_coords[:, 1],
})

# dfsucc / dffail: one row per trajectory; each cell holds that trajectory's per-step array
dfsucc = pd.DataFrame({
    "Energy": sorted_succ_sub_SIMS_G,
    "Pair": sorted_succ_sub_SIMS_pair,
    "DP": sorted_succ_sub_SIMS_dp_og,
    "HT": sorted_succ_sub_SIMS_HT,
    "TotalT": sorted_succ_sub_SIMS_T,
    "HT_cumu": sorted_succ_sub_SIMS_cumu_HT,
    "Account_cumu": sorted_succ_sub_cumu_account,
    "HT_cumu_log": sorted_succ_sub_SIMS_cumu_HT_log,
    "Account_cumu_log": sorted_succ_sub_cumu_account_log,
    "PCA": sorted_succ_sub_pca_all_coords,
    "PHATE": sorted_succ_sub_phate_all_coords,
    "IDX": sorted_succ_indices,
    "Type": sorted_succ_sub_SIMS_type,
})

dffail = pd.DataFrame({
    "Energy": sorted_fail_sub_SIMS_G,
    "Pair": sorted_fail_sub_SIMS_pair,
    "DP": sorted_fail_sub_SIMS_dp_og,
    "HT": sorted_fail_sub_SIMS_HT,
    "TotalT": sorted_fail_sub_SIMS_T,
    "HT_cumu": sorted_fail_sub_SIMS_cumu_HT,
    "Account_cumu": sorted_fail_sub_cumu_account,
    "HT_cumu_log": sorted_fail_sub_SIMS_cumu_HT_log,
    "Account_cumu_log": sorted_fail_sub_cumu_account_log,
    "PCA": sorted_fail_sub_pca_all_coords,
    "PHATE": sorted_fail_sub_phate_all_coords,
    "IDX": sorted_fail_indices,
    "Type": sorted_fail_sub_SIMS_type,
})

In [None]:
###############################################################################
# plot 2D energy landscape
###############################################################################

def _padded_range(values, pad=0.1):
    """Axis range [lo, hi] extended by `pad` * |bound| beyond the data on each side.

    Identical to the old ``[min*1.1, max*1.1]`` when the data straddles zero, but unlike
    that formula it never clips points when both bounds have the same sign.
    """
    lo, hi = min(values), max(values)
    return [lo - pad * abs(lo), hi + pad * abs(hi)]


def _left_colorbar(title):
    """Colorbar placed left of the plot with its title on top (replaces deprecated `titleside`)."""
    return dict(
        title=dict(text=title, side="top"),
        x=-0.2,
        len=1.065,
        y=0.5,
    )


def _trajectory_trace(dftraj, i, vis, color_mapping, line_color, line_width, name):
    """Line+marker Scattergl trace for row `i` of a per-trajectory DataFrame (dfsucc/dffail layout)."""
    n_steps = len(dftraj["DP"][i])
    # Step numbers are only shown for trajectories shorter than 2000 states; longer ones get blank labels.
    if n_steps < 2000:
        step = np.arange(n_steps)
    else:
        step = np.full(n_steps, None, dtype=object)
    coords = dftraj[vis][i]
    return go.Scattergl(
        x=coords[:, 0],
        y=coords[:, 1],
        mode='lines+markers',
        line=dict(
            color=line_color,
            width=line_width,
        ),
        marker=dict(
            sizemode='area',
            size=5,
            color=[color_mapping[type_val] for type_val in dftraj["Type"][i]],
            opacity=1,
            colorbar=dict(
                x=-0.2,
                y=0.5,
                tickvals=[],
                len=1,
            ),
            line=dict(width=0),
        ),
        text=step,
        # columns: 0 energy, 1 total time, 2 DP, 3 type, 4 cumulative HT, 5 appearance count, 6 holding time
        customdata=np.stack((dftraj['Energy'][i],
                             dftraj['TotalT'][i],
                             dftraj['DP'][i],
                             dftraj['Type'][i],
                             dftraj['HT_cumu'][i],
                             dftraj['Account_cumu'][i],
                             dftraj['HT'][i],
                             ), axis=-1),
        hovertemplate=(
            "Step:  <b>%{text}</b><br><br>" +
            "DP notation: <br> <b>%{customdata[2]}</b><br>" +
            "X: %{x}   " + "   Y: %{y} <br>" +
            "Energy:  %{customdata[0]} kcal/mol<br>" +
            # the marker size is a constant 5, so read the holding time from the HT column
            "Holding time for last step:  %{customdata[6]:.3e} s<br>" +
            "Total time until current state:  %{customdata[1]:.3e} s<br>" +
            "Type: %{customdata[3]}"
        ),
        visible='legendonly',
        name=name,
    )


def interactive_plotly_2D(df, dfsucc, dffail, rxn_name, vis,
                          succ_ids=(0, 47, 49),
                          succ_colors=("darkorchid", "goldenrod", "black"),
                          fail_ids=()):
    '''
    Build an interactive 2-D Plotly figure of the reaction's state landscape.

    Background layers (hidden until toggled in the legend) show every unique state coloured
    by free energy, by structure type (fixed size, size ~ cumulative holding time, size ~
    appearance frequency), by the "Pair" column, and in plain grey. Selected successful
    (and optionally failed) trajectories are drawn on top, and the final state of the first
    successful trajectory is marked with "F".

    Args:
    ----
        - df: one row per unique state; needs columns "<vis> 1", "<vis> 2", "Energy", "HT",
          "HT_cumu", "Account_cumu", "Pair", "DP", "Type"
        - dfsucc: one row per successful trajectory; each cell is a per-step array
          ("DP", "Type", "Energy", "TotalT", "HT", "HT_cumu", "Account_cumu") plus "IDX"
          and "<vis>", an (n_steps, 2) coordinate array
        - dffail: same layout as dfsucc, for failed trajectories
        - rxn_name: reaction name shown in the title
        - vis: name of the 2-D embedding, e.g. "PCA" or "PHATE"
        - succ_ids: row positions of dfsucc to draw; positions outside dfsucc are skipped
        - succ_colors: line colours of the drawn successful trajectories (cycled)
        - fail_ids: row positions of dffail to draw (none by default)

    Returns:
    -------
        plotly.graph_objects.Figure
    '''
    x_col, y_col = "{} 1".format(vis), "{} 2".format(vis)
    fig = go.Figure()

    # Colour for each structure "Type" label
    color_mapping = {
        "000": "green",
        "00H": "blue",
        "0M0": "pink",
        "0MH": "purple",
        "SMH": "red",
        "SM0": "orange",
        "S0H": "yellow",
        "S00": "grey",
    }
    type_colors = [color_mapping[type_val] for type_val in df["Type"]]

    # Shared hover payload for the per-state layers:
    # [cumulative holding time, energy, type, appearance frequency]
    landscape_customdata = np.stack((df["HT_cumu"],
                                     df["Energy"],
                                     df["Type"],
                                     df["Account_cumu"],
                                     ), axis=-1)
    landscape_hover = (
        "DP notation: <br> <b>%{text}</b><br>" +
        "X: %{x}   " + "   Y: %{y} <br>" +
        "Energy:  %{customdata[1]:.3f} kcal/mol<br>" +
        "Cumulative holding time:  %{customdata[0]:.3e} s<br>" +
        "Appearance frequency:  %{customdata[3]:d} <br>" +
        "Type: %{customdata[2]}"
    )

    # plot energy landscape background: marker area ~ expected holding time, colour = free energy
    fig.add_trace(go.Scattergl(
        x=df[x_col],
        y=df[y_col],
        mode='markers',
        marker=dict(
            sizemode='area',
            size=df["HT"],
            sizeref=1e-5,
            color=df["Energy"],
            colorscale="Plasma",
            showscale=True,
            colorbar=_left_colorbar("Free energy (kcal/mol)"),
            line=dict(width=0),
        ),
        text=df['DP'],
        customdata=np.stack((df["Type"],), axis=-1),
        hovertemplate=(
            "DP notation: <br> <b>%{text}</b><br>" +
            "X: %{x}   " + "   Y: %{y} <br>" +
            "Energy:  %{marker.color:.3f} kcal/mol<br>" +
            "Expected holding time:  %{marker.size:.3e} s<br>" +
            "Type: %{customdata[0]}"
        ),
        name="Energy landscape",
        visible='legendonly',
    ))

    # plot type backgrounds: colour = type; marker area fixed / ~ cumulative time / ~ frequency
    type_layers = (
        ("Type", 4, None),
        ("Type-cumuTime", df["HT_cumu"], 5e-5),
        ("Type-freq", df["Account_cumu"], 1),
    )
    for layer_name, size, sizeref in type_layers:
        marker = dict(
            sizemode='area',
            size=size,
            color=type_colors,
            showscale=False,
            line=dict(width=0),
        )
        if sizeref is not None:
            marker["sizeref"] = sizeref
        fig.add_trace(go.Scattergl(
            x=df[x_col],
            y=df[y_col],
            mode='markers',
            marker=marker,
            text=df['DP'],
            customdata=landscape_customdata,
            hovertemplate=landscape_hover,
            name=layer_name,
            visible='legendonly',
        ))

    # plot bound/unbound background, coloured by the "Pair" column
    fig.add_trace(go.Scattergl(
        x=df[x_col],
        y=df[y_col],
        mode='markers',
        marker=dict(
            sizemode='area',
            size=4,
            color=df["Pair"],
            showscale=True,
            colorbar=_left_colorbar("Pair"),
            line=dict(width=0),
        ),
        text=df['DP'],
        customdata=landscape_customdata,
        hovertemplate=landscape_hover,
        name="Bound/Unbound",
        visible='legendonly',
    ))

    # plot plain grey background
    fig.add_trace(go.Scattergl(
        x=df[x_col],
        y=df[y_col],
        mode='markers',
        marker=dict(
            sizemode='area',
            size=4,
            color='lightgrey',
            showscale=True,
            line=dict(width=0),
        ),
        text=df['DP'],
        hovertemplate="DP notation: <br> <b>%{text}</b><br>",
        name="Background",
        visible='legendonly',
    ))

    # successful trajectories; requested rows beyond dfsucc (e.g. fewer than 50
    # successful trajectories were collected) are skipped instead of raising
    valid_succ_ids = [i for i in succ_ids if 0 <= i < len(dfsucc)]
    for n, i in enumerate(valid_succ_ids):
        fig.add_trace(_trajectory_trace(
            dfsucc, i, vis, color_mapping,
            line_color=succ_colors[n % len(succ_colors)],
            line_width=2,
            name="Succ {}".format(dfsucc["IDX"][i] + 1),
        ))

    # failed trajectories (none by default)
    for i in fail_ids:
        if 0 <= i < len(dffail):
            fig.add_trace(_trajectory_trace(
                dffail, i, vis, color_mapping,
                line_color="rgba(0, 0, 0, 0.5)",
                line_width=0.5,
                name="Fail {}".format(dffail["IDX"][i] + 1),
            ))

    # label the final state (last step of the first successful trajectory)
    if len(dfsucc) > 0:
        final_xy = dfsucc[vis][0][-1]
        fig.add_trace(
            go.Scattergl(
                x=[final_xy[0]],
                y=[final_xy[1]],
                mode='markers+text',
                marker_color="lime",
                marker_size=20,
                text=["F"],
                textposition="middle center",
                textfont=dict(
                    family="sans serif",
                    size=16,
                    color="black"
                ),
                hoverinfo='skip',
                showlegend=False,
            )
        )

    fig.update_xaxes(range=_padded_range(df[x_col]))
    fig.update_yaxes(range=_padded_range(df[y_col]))

    fig.update_layout(
        title="{}: {} Vis".format(rxn_name, vis),
        title_x=0.5,
        xaxis=dict(title=x_col),
        yaxis=dict(title=y_col),
        legend=dict(
            title_font=dict(size=10),
            font=dict(
                size=10,
                color="black"
            )
        )
    )

    return fig

In [None]:
# Render and save one interactive HTML figure per visualisation method.
# VIS_METHOD = ["PCA", "PHATE"]
VIS_METHOD = ["PHATE"]

for vis in VIS_METHOD:
    fig = interactive_plotly_2D(df, dfsucc, dffail, rxn_name, vis)
    # pio.write_html(fig, file=f"./data/jordan/plots/FirstStep_{vis}_{rxn_name}_{config.log_dir}.html", auto_open=True)
    # pio.write_html(fig, file=f"./data/jordan/plots/FirstStep_{vis}_{rxn_name}_0825-2043.html", auto_open=True)
    out_html = f"../output_files/saved_ViDa_plots/plot_dna29/FirstStep_{vis}_{rxn_name}_0825-2043_trj.html"
    # write_html does not create missing folders
    os.makedirs(os.path.dirname(out_html), exist_ok=True)
    pio.write_html(fig, file=out_html, auto_open=True)
    print("DONE: ", vis)

In [None]:
# Name of the experiment log directory (used in the saved file names above)
getattr(config, "log_dir")

## Feature Engineering to Find Kinetic Traps

In [None]:
## filter out data points with low frequency
# Keep only the states whose total probability P_tot reaches the threshold.
filter_threshold = 0.1  # reaction 8

filter_idx = np.nonzero(P_tot >= filter_threshold)[0]
filter_idx.shape

In [None]:
## load PCA for DBSCAN
# PCA coordinates of the high-probability states selected above
filter_comb_pca_coords = np.take(pca_coords, filter_idx, axis=0)
filter_comb_pca_coords.shape

### DBSCAN

In [None]:
## Elbow method to find eps for DBSCAN
from sklearn.neighbors import NearestNeighbors

# Number of neighbors to find; since the query points are the fitted points, the first
# neighbour is normally the point itself (distance 0).
n_neighbors = 4
nbrs = NearestNeighbors(n_neighbors=n_neighbors).fit(filter_comb_pca_coords)
distances, indices = nbrs.kneighbors(filter_comb_pca_coords)
# Sum of the distances to the n_neighbors nearest points, sorted in descending order.
# NOTE(review): the classic k-distance heuristic uses distances[:, -1] (distance to the
# k-th neighbour) rather than the sum -- keep in mind when reading eps off this curve.
four_dist = np.sum(distances, axis=1)
sorted_four_dist = np.sort(four_dist)[::-1]

# Create a figure
fig = go.Figure()
# x is the rank in the sorted curve; indices[:, 0] (nearest-neighbour ids) is not aligned
# with the sorted distances and is not monotone when duplicate points exist.
fig.add_trace(go.Scatter(x=np.arange(len(sorted_four_dist)), y=sorted_four_dist,
                         mode='lines', name='Line Plot'))
# Set labels and title
fig.update_layout(xaxis_title='points', yaxis_title='4-dist', title='Elbow')
# Show the plot
fig.show()


In [None]:
from sklearn.cluster import DBSCAN

# DBSCAN hyper-parameters. eps should be read off the elbow of the k-distance plot,
# and min_samples should match that plot's n_neighbors.
# NOTE(review): 0.71 was also noted here previously (presumably an alternative eps) — confirm which is intended.
DBSCAN_EPS = 2.2
DBSCAN_MIN_SAMPLES = 4

X = filter_comb_pca_coords
clusters = DBSCAN(eps=DBSCAN_EPS, min_samples=DBSCAN_MIN_SAMPLES).fit(X)
# Cluster label per point; -1 marks noise points that could not be assigned to any cluster.
labels = clusters.labels_

# Number of clusters in labels, ignoring noise if present.
n_clusters_ = len(set(labels)) - (1 if -1 in labels else 0)
n_noise_ = int(np.sum(labels == -1))

print("Estimated number of clusters: %d" % n_clusters_)
print("Estimated number of noise points: %d" % n_noise_)

# Unique cluster labels (-1 = noise).
set(labels)

### Remove clusters that are not kinetic traps

In [None]:
# Decide, per DBSCAN cluster, whether it is a kinetic trap. The cluster's
# lowest-energy state is inspected: if it contains ten consecutive '(' (ten
# consecutive paired bases on the first strand), the cluster is relabelled as
# noise (-1) and not treated as a trap.
# NOTE(review): presumably such a state is already largely hybridised — confirm the criterion.
NOT_TRAP_MOTIF = "(" * 10

G_filt = SIMS_G_uniq[filter_idx]        # energies of the filtered states
dp_filt = SIMS_dp_og_uniq[filter_idx]   # dot-parenthesis strings of the filtered states

real_labels = labels.copy()
for k_clust in np.unique(labels):
    if k_clust == -1:
        # Noise is not a cluster, so there is nothing to classify.
        continue
    member_idx = np.where(labels == k_clust)[0]
    min_index = np.argmin(G_filt[member_idx])
    plausible_trap = dp_filt[member_idx][min_index]
    if NOT_TRAP_MOTIF in plausible_trap:
        real_labels[labels == k_clust] = -1
        print("Cluster {} is NOT a trap".format(k_clust))
    else:
        print("Cluster {} is a trap".format(k_clust))

trap_cluster_ids = np.setdiff1d(np.unique(real_labels), [-1])
print("\nClusters with trap are: {}".format(trap_cluster_ids))
    

In [None]:
# Scatter the filtered states in PCA space, coloured by trap cluster.
# The final state ("F") and the lowest-energy state of each trap cluster are highlighted.
# (numpy and plotly.graph_objects are imported in the first cell.)

X = filter_comb_pca_coords[:, 0]
Y = filter_comb_pca_coords[:, 1]
clusters = real_labels  # labels after removing non-trap clusters (-1 = noise / not a trap)

# Get unique cluster labels
unique_clusters = np.unique(clusters)

# Palettes. They are wrapped/extended below so more than ten clusters do not raise IndexError.
noise = 'grey'
colors = ['red', 'blue', 'green', 'orange', 'purple', 'yellow', 'cyan', 'magenta', 'lime', 'teal']
name_list = ['A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J']
trap_shape = ['star', 'x', 'triangle-up', 'cross', 'pentagon', 'diamond', 'square', 'triangle-down', 'triangle-left', 'triangle-right']


def _cluster_name(idx):
    """Display name of the idx-th trap cluster: a letter from name_list, else the number."""
    return name_list[idx] if idx < len(name_list) else str(idx)


def _padded_range(values, frac=0.1):
    """Axis range [lo, hi] widened outward by `frac` of |lo| and |hi|.

    Identical to [lo*1.1, hi*1.1] when lo < 0 < hi, but never clips data when
    both bounds have the same sign (where multiplying by 1.1 shrinks the range).
    """
    lo, hi = float(np.min(values)), float(np.max(values))
    return [lo - frac * abs(lo), hi + frac * abs(hi)]


# Per-state hover data, indexed once instead of inside every loop iteration.
G_filt = SIMS_G_uniq[filter_idx]
HT_filt = SIMS_HT_uniq[filter_idx]
P_filt = P_tot[filter_idx]
dp_filt = SIMS_dp_og_uniq[filter_idx]

hover_template = (
    "X: %{x}   " + "   Y: %{y} <br>" +
    "DP notation: <br> <b>%{text}</b><br>" +
    "Energy:  %{customdata[0]:.3f} kcal/mol<br>" +
    "Average holding time:  %{customdata[1]:.5g} s<br>" +
    "Probability:  %{customdata[2]:.2g} <br>"
)

# One scatter trace per cluster
traces = []
i = 0
for cluster_label in unique_clusters:
    mask = clusters == cluster_label
    if cluster_label == -1:
        color = noise
        name = 'Noise'
    else:
        # Palette is indexed by label as before, wrapping around for labels >= len(colors).
        color = colors[cluster_label % len(colors)]
        name = f'Cluster {_cluster_name(i)}'
        i += 1

    traces.append(go.Scattergl(
        x=X[mask],
        y=Y[mask],
        mode='markers',
        marker=dict(color=color, size=5, sizemode='diameter'),
        name=name,
        showlegend=True,
        customdata=np.stack((G_filt[mask], HT_filt[mask], P_filt[mask]), axis=-1),
        text=dp_filt[mask],
        hovertemplate=hover_template,
    ))

# Final state. It is looked up among all unique states (unfiltered pca_coords).
final_idx = np.where(SIMS_dp_og_uniq == final_states)[0][0]
final_xy = pca_coords[final_idx]
traces.append(go.Scattergl(
    x=[final_xy[0]],
    y=[final_xy[1]],
    mode='markers+text',
    marker_color="lime",
    marker_size=12,
    text=["F"],
    textposition="middle center",
    textfont=dict(family="sans serif", size=10, color="black"),
    hoverinfo='skip',
    showlegend=False,
))

# Kinetic trap of each cluster = its lowest-energy state
i = 0
for k_clust in unique_clusters:
    if k_clust == -1:
        continue
    member_idx = np.where(clusters == k_clust)[0]
    min_index = np.argmin(G_filt[member_idx])
    trap_point = member_idx[min_index]
    traces.append(go.Scattergl(
        # One-element lists (as in the final-state trace) rather than a 0-d array, which is not a sequence.
        x=[X[trap_point]],
        y=[Y[trap_point]],
        mode='markers',
        marker=dict(color="black", symbol=trap_shape[i % len(trap_shape)], size=10),
        name=f"Trap {_cluster_name(i)}",
        showlegend=True,
    ))
    i += 1

# The range also covers the final-state marker, which may lie outside the filtered points.
layout = go.Layout(
    legend=dict(itemsizing='constant'),
    xaxis=dict(range=_padded_range(np.append(X, final_xy[0]))),
    yaxis=dict(range=_padded_range(np.append(Y, final_xy[1]))),
    title=f"DBSCAN finding Kinetic Traps for sample {rxn_name}",
)

fig = go.Figure(data=traces, layout=layout)
fig.show()


In [None]:
# print out the kinetic trap in each cluster
for k_clust in np.unique(real_labels):
    if k_clust == -1:
        continue
    min_index = np.argmin(SIMS_G_uniq[filter_idx][np.where(labels==k_clust)[0]])
    print("Kinetic trap in cluster {} is:".format(k_clust))
    print(SIMS_dp_og_uniq[filter_idx][np.where(labels==k_clust)[0]][min_index])

### Trajectory analysis

In [None]:
## exact each trajectory
split_id = trj_id + 1 # index for split to each trajectory
traj_in_clust = np.zeros(len(np.unique(labels)), dtype=int)
avg_time_in_clust = np.zeros(len(np.unique(labels)), dtype=float)

for i in range(len(split_id)):
    if i == 0:
        trj_dp = SIMS_dp_og[0:split_id[i]]
    else:
        trj_dp = SIMS_dp_og[split_id[i-1]:split_id[i]]

    for j, k_clust in enumerate(np.unique(labels)):
        mask = labels == k_clust
        if np.size(np.intersect1d(trj_dp, SIMS_dp_og_uniq[filter_idx][mask])) != 0:
            traj_in_clust[j] += 1
            avg_time_in_clust[j] += SIMS_T[trj_id[i]]

print(f"{rxn_name}:")
for i in range(len(traj_in_clust)):
    print("{} trajs in cluster {}. Average time: {:.3e}.".format(traj_in_clust[i], np.unique(labels)[i], avg_time_in_clust[i]/traj_in_clust[i]))


## Check: pairwise MSE between `SIMS_scar_uniq` rows of selected unique states

In [None]:
# First three rows of SIMS_scar_uniq (compared pairwise by MSE below).
SIMS_scar_uniq[0:3]

In [None]:
# DP-notation strings of the states used in the MSE checks:
# the first three, the last three, and indices 1000 and 1050.
(SIMS_dp_og_uniq[0:3],
 SIMS_dp_og_uniq[-3:],
 SIMS_dp_og_uniq[1000],
 SIMS_dp_og_uniq[1050])

In [None]:
from sklearn.metrics import mean_squared_error

# MSE between SIMS_scar_uniq rows of a few state pairs.
# Each print is labelled with its index pair so the results can be told apart.
mse_pairs_head = [(0, 1), (0, 2), (1, 2), (0, 1000)]
mse01, mse02, mse12, mse01000 = [
    mean_squared_error(SIMS_scar_uniq[a], SIMS_scar_uniq[b]) for a, b in mse_pairs_head
]
for (a, b), mse in zip(mse_pairs_head, (mse01, mse02, mse12, mse01000)):
    print(f"Mean Squared Error [{a} vs {b}]:", mse)


In [None]:
# MSE between SIMS_scar_uniq rows involving the last states and indices 1000/1050.
# Each print is labelled with its index pair so the results can be told apart.
# mean_squared_error is imported in the previous cell.
mse_pairs_tail = [(0, -1), (-1, -2), (0, -2), (0, -3), (1000, 1050)]
mse0m1, msem1m2, mse0m2, mse0m3, mse10001050 = [
    mean_squared_error(SIMS_scar_uniq[a], SIMS_scar_uniq[b]) for a, b in mse_pairs_tail
]
for (a, b), mse in zip(mse_pairs_tail, (mse0m1, msem1m2, mse0m2, mse0m3, mse10001050)):
    print(f"Mean Squared Error [{a} vs {b}]:", mse)

