In [1]:
%matplotlib inline

import numpy as np
import pickle
import json
from argparse import Namespace
import plotly.express as px
import matplotlib.pyplot as plt
import time
import yaml
from sklearn.preprocessing import MinMaxScaler
from sklearn.neighbors import NearestNeighbors
from sklearn.preprocessing import StandardScaler

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

import networkx as nx
from pyvis.network import Network

from sklearn.decomposition import PCA
import phate
from umap import UMAP
from sklearn.manifold import TSNE

# Load simulated data

In [2]:
"""load saved trajectories data for npz file
"""
SEQ = "PT4"
# SEQ = "PT4_hairpin"

# laod pre-training data
fnpz_data = "./data/pretraining/pretraining_{}.npz".format(SEQ)
data_npz = np.load(fnpz_data)

# asssign data to variables
for var in data_npz.files:
     locals()[var] = data_npz[var]

# recover full data based on coord_id, indices, and unique data
SIMS_adj = SIMS_adj_uniq[coord_id_S]
SIMS_scar = SIMS_scar_uniq[coord_id_S]
SIMS_G = SIMS_G_uniq[coord_id_S]
SIMS_pair = SIMS_pair_uniq[coord_id_S]

print(SIMS_T.shape,SIMS_HT.shape,SIMS_HT_uniq.shape)
print(SIMS_adj.shape,SIMS_scar.shape,SIMS_G.shape,SIMS_HT.shape,SIMS_pair.shape)
print(SIMS_adj_uniq.shape,SIMS_scar_uniq.shape,SIMS_G_uniq.shape,SIMS_pair_uniq.shape) 
print(SIMS_dict.shape,SIMS_dict_uniq.shape)
print(coord_id_S.shape,indices_S.shape,trj_id.shape,occ_density_S.shape)

(621984,) (621984,) (46606,)
(621984, 50, 50) (621984, 4000) (621984,) (621984,) (621984,)
(46606, 50, 50) (46606, 4000) (46606,) (46606,)
(621984, 5) (46606, 4)
(621984,) (46606,) (100,) (46606,)


# Construct graph

### Construct weights (expected holding time)

In [3]:
print(SIMS_HT_uniq.max(), SIMS_HT_uniq[SIMS_HT_uniq!=0].min(), SIMS_HT_uniq.std(), SIMS_HT_uniq.shape)

3.602668547999989e-08 5.3379999818823536e-14 1.648221154642124e-09 (46606,)


In [None]:
print("Possible kinetic trap: \n", 
      SIMS_dict_uniq[SIMS_HT_uniq.argmax()])

### Construct edges

In [5]:
all_nodes = np.array(SIMS_dict[:,-1], dtype=int)
print(all_nodes.shape, all_nodes, all_nodes.max())
print("Initial node:", all_nodes[0], "  Final node:", all_nodes[-1])

(621984,) [  0   1   2 ... 288 289 291] 46605
Initial node: 0   Final node: 291


In [6]:
np.all(coord_id_S == np.array(SIMS_dict[:,-1], dtype=int))

True

In [118]:
# pairwise nodes
# note: this step connect final->initial node,
# which will be remove later
all_edges_temp = []
for previous, current in zip(all_nodes, all_nodes[1:]):
    all_edges_temp.append((previous, current))

In [136]:
# # remove all edges that connect final->initial node
# all_edges = list(filter((all_nodes[-1],all_nodes[0]).__ne__, all_edges_temp))
import copy

indices_to_delete = trj_id[:-1]
# # Sort the indices in reverse order so that deleting elements won't affect subsequent indices
indices_to_delete = sorted(indices_to_delete, reverse=True)

all_edges = copy.deepcopy(all_edges_temp)
for index in indices_to_delete:
    del all_edges[index]



In [139]:
# sanity check
if (all_nodes[-1],all_nodes[0]) in all_edges:
    print("yes")
else:
    print(f"There are no edges from final node {all_nodes[-1]} to initial node {all_nodes[0]}")
    
print("before prune: ", len(all_edges))
print("after prune: ",len(all_edges_temp))
print("difference: ", len(all_edges_temp) - len(all_edges))

There are no edges from final node 291 to initial node 0
before prune:  621884
after prune:  621983
difference:  99


In [180]:
len(all_edges)

621884

In [191]:
all_edges[200]

(92, 93)

In [207]:
SIMS_HT_uniq[all_edges[0][0]]

6.146118170353792e-10

In [208]:
SIMS_HT_uniq

array([6.14611817e-10, 9.86611864e-10, 3.16825110e-09, ...,
       1.41202200e-10, 2.80204730e-09, 9.62795700e-10])

### Construct directed weight graph

In [823]:
DG = nx.DiGraph()
for i in range(len(all_edges)):
    weight = SIMS_HT_uniq[all_edges[i][0]]  
    DG.add_edge(int(all_edges[i][0]), int(all_edges[i][1]), weight = float(weight))


### Construct modified undirected weight graph

In [141]:
MUG = nx.Graph()

for  i in range(len(all_edges)):
    idx0 = all_edges[i][0]
    idx1 = all_edges[i][1]
    
    if SIMS_HT_uniq[idx0] < SIMS_HT_uniq[idx1]:
        weight = SIMS_HT_uniq[idx0]
    else:
        weight = SIMS_HT_uniq[idx1]
        
    if SIMS_HT_uniq[idx0] == 0 or SIMS_HT_uniq[idx1] == 0:
        weight = SIMS_HT_uniq[idx0] + SIMS_HT_uniq[idx1]
        
    MUG.add_edge(int(all_edges[i][0]), int(all_edges[i][1]), weight = float(weight))
    
    
MUG.get_edge_data(98,99), MUG.get_edge_data(97,98), MUG.get_edge_data(98,97)

(None, None, None)

### Collect and save KNN=n neighbor's shortest paths

In [None]:
# %%script false --no-raise-error
start_time = time.time()
# collect the shortest path for each node
knn = 100
X_j = []
D_ij = []

# for i in range(len(SIMS_HT_uniq)):
for i in range(200):

    # length = nx.single_source_dijkstra_path_length(DG, i)
    length = nx.single_source_shortest_path_length(DG, i, cutoff=100)
    
    length_arr = np.array(list(length.items()), dtype=object)

    
    # # find the shortest path that is less than cutoff
    # # and get the index of the node
    # pos = np.where(length_arr[:,1] < cutoff)
    # Xj = length_arr[pos][:,0]
    # dij = length_arr[pos][:,1]
    # shortestpath_dict[i] = Xj, dij
    
    # get the nearest three nodes
    # X_j.append(length_arr[1:knn+1,0])
    # D_ij.append(length_arr[1:knn+1,1])
    
    # X_j.append(length_arr[:100,0])
    # D_ij.append(length_arr[:100,1])
    
    
    X_j.append(length_arr[:,0])
    D_ij.append(length_arr[:,1])
    
   
    
    if i % 5000 == 0:
        print(i)
        
X_j = np.array(X_j)
D_ij = np.array(D_ij)      

end_time = time.time()  
print(f"--- {end_time - start_time} seconds ---")
  
# # save npz file for shortest path
# with open(f'./data/graph/{SEQ}/shortestpath_knn={knn}.npz', 'wb') as f:
#     np.savez(f,
#              X_j = X_j,
#              D_ij = D_ij,
#          )

In [14]:
xj = np.array(np.load(f'./data/graph/{SEQ}/shortestpath_knn=100.npz',allow_pickle=True)["X_j"], dtype=int)
dij = np.array(np.load(f'./data/graph/{SEQ}/shortestpath_knn=100.npz',allow_pickle=True)["D_ij"], dtype=float)

print(xj.shape, dij.shape)
print(dij.max(), dij.min(), dij.mean(), dij.std())

(46606, 100) (46606, 100)
4.666732692179992e-08 5.3379999818823536e-14 4.69795412390628e-09 3.3059767747225673e-09


In [283]:
X_j[9].shape

(37755,)

In [284]:
len(set(X_j[9]))

37755

In [727]:
import networkx as nx
import heapq

def dijkstra_n_shortest_paths(graph, source, n_neigh=100):
    # Initialize data structures
    visited = set()
    distances = {node: float('infinity') for node in graph}
    distances[source] = 0
    priority_queue = [(0, source)]
    paths = []

    # Main loop
    while priority_queue and len(paths) < n_neigh:
                
        _, current_node = heapq.heappop(priority_queue)
        
        if current_node in visited:
            continue

        visited.add(current_node)

        for neighbor, weight in graph[current_node].items():
            if neighbor not in visited:
                tentative_distance = distances[current_node] + weight['weight']

                if tentative_distance < distances[neighbor]:
                    distances[neighbor] = tentative_distance
                    heapq.heappush(priority_queue, (tentative_distance, neighbor))
                            
        paths.append((source, current_node, distances[current_node]))
        
    return np.array(paths)

In [832]:
x_j, d_ij = [], []
n_neigh = 100

for i in range(len(DG.nodes)):
    shortest_path_i = dijkstra_n_shortest_paths(DG, i)
    x_j.append(shortest_path_i[:,1].astype(int))
    d_ij.append(shortest_path_i[:,2].astype(float))
    
    if len(x_j[i]) < n_neigh:
        x_j[i] = np.pad(x_j[i], (0, n_neigh-len(x_j[i])), 'constant', constant_values=i)
        d_ij[i] = np.pad(d_ij[i], (0, n_neigh-len(d_ij[i])), 'constant', constant_values=0)
            
            

In [827]:
np.array(x_j, dtype=int)

array([[    0,     1,   292, ..., 37257,  3035, 36833],
       [    1,     2,  8474, ..., 34214, 13588, 42054],
       [    2,     3,  6084, ...,    28,    29, 17792],
       ...,
       [46603, 46602,  4917, ...,  4972, 22639,  8378],
       [46604, 11522,   282, ..., 25239, 30666,   291],
       [46605,  7753,  3031, ..., 46605, 46605, 46605]])

In [829]:
np.array(d_ij, dtype=float)

array([[0.00000000e+00, 6.14611817e-10, 6.14611817e-10, ...,
        1.14994725e-09, 1.17423519e-09, 1.18058478e-09],
       [0.00000000e+00, 9.86611864e-10, 9.86611864e-10, ...,
        5.35741198e-09, 5.36544467e-09, 5.41582160e-09],
       [0.00000000e+00, 3.16825110e-09, 3.16825110e-09, ...,
        5.95488306e-08, 6.04428456e-08, 6.05665658e-08],
       ...,
       [0.00000000e+00, 1.41202200e-10, 4.03455995e-09, ...,
        9.50789295e-09, 9.50789295e-09, 9.52487436e-09],
       [0.00000000e+00, 2.80204730e-09, 3.97413663e-09, ...,
        1.48185491e-08, 1.49415307e-08, 1.50258665e-08],
       [0.00000000e+00, 9.62795700e-10, 3.14380751e-09, ...,
        0.00000000e+00, 0.00000000e+00, 0.00000000e+00]])

In [842]:
x_j

[array([    0,     1,   292,  3033,  3034,  3254,  5110,  5648,  5650,
         5652,  5660,  6080,  7306,  7756,  7757, 11527, 13235, 13439,
        13585, 14003, 14112, 14114, 14266, 14267, 14554, 15135, 15438,
        15439, 16014, 16145, 16974, 17240, 17360, 17948, 21797, 22688,
        22784, 22996, 23365, 24914, 26533, 27360, 32595, 33469, 33740,
        36067, 36830, 37255, 37894, 38564, 39354, 39708, 40233, 42311,
        44559, 44746, 45850,  5111, 36831, 39355, 26534, 26535, 42050,
        14115, 14137, 40641, 36832, 37256, 17949, 21461, 22785, 13440,
        13441, 14004, 45851, 40234, 36068, 27361, 15440, 20799, 45146,
        13442, 13586, 14454, 16850, 19251, 28546, 34213, 34918, 42499,
        11528, 11529, 22786, 17361, 40507, 43009,  3255, 37257,  3035,
        36833]),
 array([    1,     2,  8474, 10282, 10283, 10286, 13586, 15440, 15442,
        28148, 36347, 39716, 40321, 41975, 42051, 45783, 40322, 45784,
        41976, 45785, 40323, 36348, 10284, 45786, 39717, 154

In [834]:
# normalize the distance
min_val = np.min(d_ij)
max_val = np.max(d_ij)
normalized_minmax = (d_ij - min_val) / (max_val - min_val)
normalized_minmax

array([[0.        , 0.00326542, 0.00326542, ..., 0.00610964, 0.00623868,
        0.00627242],
       [0.        , 0.00524184, 0.00524184, ..., 0.02846379, 0.02850647,
        0.02877412],
       [0.        , 0.01683284, 0.01683284, ..., 0.3163814 , 0.32113128,
        0.32178861],
       ...,
       [0.        , 0.0007502 , 0.02143551, ..., 0.05051519, 0.05051519,
        0.05060541],
       [0.        , 0.01488721, 0.02111449, ..., 0.07873057, 0.07938397,
        0.07983204],
       [0.        , 0.00511531, 0.01670297, ..., 0.        , 0.        ,
        0.        ]])

In [843]:
np.where(normalized_minmax == normalized_minmax.max())

(array([12, 12]), array([98, 99]))

In [None]:
from sklearn.preprocessing import MinMaxScaler

scaler = MinMaxScaler(feature_range=(1,3))  # range [1,3]
# scaler = MinMaxScaler()  # range [1,3]

norm_Dij = scaler.fit_transform(D_ij.reshape(-1,1)) 
# norm_Dij = scaler.fit_transform(D_ij) 

print(norm_Dij.max(), norm_Dij.min(), norm_Dij.mean(), norm_Dij.std())

In [None]:
norm_Dij.reshape(10,-1)

In [None]:
norm_Dij.reshape(10,-1)[0][:30]

In [None]:
norm_Dij[0][:30]

In [None]:
D_ij[0][:30]

In [651]:
X_j = np.array(np.load(f'./data/graph/{SEQ}/shortestpath_knn=100.npz',allow_pickle=True)["X_j"], dtype=int)
D_ij = np.array(np.load(f'./data/graph/{SEQ}/shortestpath_knn=100.npz',allow_pickle=True)["D_ij"], dtype=float)

print(X_j.shape, D_ij.shape)
print(D_ij.max(), D_ij.min(), D_ij.mean(), D_ij.std())

(46606, 100) (46606, 100)
4.666732692179992e-08 5.3379999818823536e-14 4.69795412390628e-09 3.3059767747225673e-09


#### Normalize distance 

In [None]:
from sklearn.preprocessing import MinMaxScaler

# # standardize the distance matrix
# std_Dij = (D_ij - D_ij.mean()) / D_ij.std()
# print(std_Dij.max(), std_Dij.min(), std_Dij.mean(), std_Dij.std())
# print()

## normalize the distance matrix
# norm_Dij = (D_ij - D_ij.min()) / (D_ij.max() - D_ij.min()) # range [0,1]
scaler = MinMaxScaler(feature_range=(1,3))  # range [1,3]
# scaler = MinMaxScaler(feature_range=(1,10))  # range [1,10]

norm_Dij = scaler.fit_transform(D_ij) 
print(norm_Dij.max(), norm_Dij.min(), norm_Dij.mean(), norm_Dij.std())
# print(np.where(norm_Dij == 3))

# log_Dij = np.log(D_ij)+np.abs(np.log(D_ij).min())+1
# print(log_Dij.max(), log_Dij.min(), log_Dij.mean(), log_Dij.std())

### Find the importance weight 

In [None]:
# calculate the probability of being visited during a simulated trajectory 
# from the initial state
split_id = trj_id + 1 # index for split to each trajectory
P_tot = np.zeros(len(SIMS_dict_uniq))

for i in range(len(split_id)):
    if i == 0:
        trj = set(SIMS_dict[0:split_id[i],4].astype(int))
    else:
        trj = set(SIMS_dict[split_id[i-1]:split_id[i],4].astype(int))

    P_tot[list(trj)] += 1

P_tot = P_tot / 100

P_tot.shape, P_tot.max(), P_tot.min()

# Edit Distance

In [None]:
%%script false --no-raise-error

def edit_distance(adj1, adj2):
    # Calculate edit distance based on its adjacency matrix
    edit_dist = np.sum(np.abs(adj1-adj2),dtype=int)
    
    return edit_dist


# Calculate the edit distance between X_i and X_j'
ED_ij = []
for i in range(X_j.shape[0]):
    ed_ij = []
    for j in range(X_j.shape[1]):
        ed_ij.append(edit_distance(SIMS_adj_uniq[i], SIMS_adj_uniq[X_j[i,j]]))
    ED_ij.append(ed_ij)
ED_ij = np.array(ED_ij, dtype=int)
    
# save npz file for shortest path
with open(f'./data/graph/{SEQ}_edit_distance.npz', 'wb') as f:
    np.savez(f,
             ED_ij = ED_ij,
         )

### Load edit distance

In [None]:
ED_ij = np.array(np.load(f'./data/graph/{SEQ}/{SEQ}_edit_distance.npz',allow_pickle=True)["ED_ij"], dtype=float)
X_j = np.array(np.load(f'./data/graph/{SEQ}/shortestpath_knn=100.npz',allow_pickle=True)["X_j"], dtype=int)
print(ED_ij.max(), ED_ij.min(), ED_ij.mean(), ED_ij.std(), ED_ij.shape)

# ViDa Model

### Make dataset

In [None]:
data_tup = (torch.Tensor(SIMS_scar_uniq),
            torch.Tensor(SIMS_G_uniq),
            torch.arange(len(SIMS_scar_uniq)))
data_dataset = torch.utils.data.TensorDataset(*data_tup)

In [None]:
# split the dataset into train and validation
data_size = len(data_dataset)
train_size = int(0.7 * data_size)

val_size = data_size - train_size

train_data, val_data = torch.utils.data.random_split(data_dataset, [train_size, val_size], 
                                                     generator=torch.Generator().manual_seed(42))

print(data_size, len(train_data), len(val_data))

### Set up configurations

In [None]:
# set up hyperparameters
input_dim = data_tup[0].shape[-1]

config = Namespace(
    seq = SEQ,
    type = 'vida',
    knn = 100,
    edit_distance = "100_neighbors",
    device = 'mps', # change to cuda if using Nvida GPU
    log_dir = f'{time.strftime("%m%d-%H%M")}', # log directory
    batch_size = 256,
    input_dim = input_dim,
    output_dim = input_dim,
    latent_dim = 25, # bottleneck dimension
    hidden_dim = 400,
    n_epochs = 150, #(try 60 for pt4-hairpin)
    
    learning_rate = 1e-4, # learning rate #try 6e-5 for pt4-hairpin
    
    log_interval = 10, # how many batches to wait before logging training status    
    patience = 20, # how many epochs to wait before early stopping    
      
    # hyperparameters for loss function
    alpha = 1.0, # reconstruction loss
    
    beta = 1e-4, # kl divergence
    # beta = 1e-5, # kl divergence
    
    gamma = 0.3, # energy loss   # 0.3 for PT4; 
    # gamma = 1, # energy loss

    delta = 1e-4, # distance loss # 0.04 for PT4;
    # delta = 4e-4, # distance loss # 0.04 for PT4;
  
    epsilon = 1e-4, # edit distance loss
    # epsilon = 5e-5, # edit distance loss
    
)

### Make datalaoder

In [None]:
data_loader = torch.utils.data.DataLoader(data_dataset, batch_size=config.batch_size,
                                          shuffle=False, num_workers=0)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=config.batch_size, 
                                           shuffle=True, num_workers=0)
val_loader = torch.utils.data.DataLoader(val_data, batch_size=config.batch_size,
                                         shuffle=False, num_workers=0)

In [None]:
print('data_loader: ', len(data_loader.dataset), len(data_loader), data_loader.batch_size)
print('train_loader: ', len(train_loader.dataset), len(train_loader), train_loader.batch_size)
print('val_loader: ', len(val_loader.dataset), len(val_loader), val_loader.batch_size)

In [None]:
data_loader.dataset.tensors[0]

In [None]:
data_loader.dataset.tensors[1]

### Encoder

In [None]:
class Encoder(nn.Module):
    
    def __init__(self, input_dim, hidden_dim, latent_dim):
        '''
        Args:
        ----
            - input_dim: the dimension of the input node feature
            - hiddent_dim: the dimension of the hidden layer
            - latent_dim: the dimension of the latent space (bottleneck layer)
        '''
        super(Encoder, self).__init__()
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.latent_dim = latent_dim
        
        self.fc1 = nn.Linear(self.input_dim, self.hidden_dim)
        self.bn1 = nn.BatchNorm1d(self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, 400)
        self.bn2 = nn.BatchNorm1d(400)
        
        # Split the result into mu and var components
        # of the latent Gaussian distribution, note how we only output
        # diagonal values of covariance matrix. Here we assume
        # they are conditionally independent
        self.hid2mu = nn.Linear(400, self.latent_dim)
        self.hid2logvar = nn.Linear(400, self.latent_dim)
        
    def forward(self, x):
        x = self.bn1(F.relu(self.fc1(x)))
        x = self.bn2(F.relu(self.fc2(x)))
        mu = self.hid2mu(x)
        logvar = self.hid2logvar(x)
        return mu, logvar


### Decoder

In [None]:
class Decoder(nn.Module):
    
    def __init__(self, latent_dim, hidden_dim, output_dim):
        '''
        Args:
        ----
            - latent_dim: the dimension of the latent space (bottleneck layer)
            - hiddent_dim: the dimension of the hidden layer
            - output_dim: the dimension of the output node feature
        '''
        super(Decoder, self).__init__()
        self.latent_dim = latent_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        
        self.fc1 = nn.Linear(self.latent_dim, self.hidden_dim)
        self.bn1 = nn.BatchNorm1d(self.hidden_dim)
        self.fc2 = nn.Linear(400, 400)
        self.bn2 = nn.BatchNorm1d(400)
        self.fc3 = nn.Linear(400, self.output_dim)
        
    def forward(self, z):
        x = self.bn1(F.relu(self.fc1(z)))
        x = self.bn2(F.relu(self.fc2(x)))
        x = self.fc3(x)
        return x

### Regressor

In [None]:
class Regressor(nn.Module):
    
    def __init__(self, latent_dim):
        '''
        The regressor is used to predict the energy of the node
        
        Args:
        ----
            - latent_dim: the dimension of the latent space (bottleneck layer)
        '''
        super(Regressor, self).__init__()
        self.latent_dim = latent_dim
        
        self.regfc1 = nn.Linear(self.latent_dim, 15)
        self.regfc2 = nn.Linear(15, 1)
        
    def forward(self, z):
        y = F.relu(self.regfc1(z))
        y = self.regfc2(y)
        return y

### VIDA model

In [None]:
class VIDA(nn.Module):
    
    def __init__(self, encoder, decoder, regressor):
        '''
        Args:
        ----
            - input_dim: the dimension of the input node feature
            - hiddent_dim: the dimension of the hidden layer
            - latent_dim: the dimension of the latent space (bottleneck layer)
            - output_dim: the dimension of the output node feature (same as input_dim)
        '''
        super(VIDA, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.regressor = regressor
        
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        z = mu + eps*std
        return z
        
    def forward(self, x):
        mu, logvar = self.encoder(x)
        z = self.reparameterize(mu, logvar)
        x_recon = self.decoder(z)
        y_pred = self.regressor(z)
        return x_recon, y_pred, z, mu, logvar

### Loss functions

In [None]:
def vae_loss(x_recon, x, mu, logvar):
    '''
    Compute the VAE loss
    
    Args:
        - x_recon: the reconstructed node feature
        - x: the original node feature
        - mu: the mean of the latent space
        - logvar: the log variance of the latent space
    
    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the VAE
    '''
    BCE = F.mse_loss(x_recon.flatten(), x.flatten()) # L2 loss
    # BCE = F.l1_loss(x_recon.flatten(), x.flatten()) # L1 loss
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return BCE, KLD


def pred_loss(y_pred, y):
    '''
    Compute the energy prediction loss
    
    Args:
    ----
        - y_pred: the predicted energy of the node
        - y: the true energy of the node
    
    Returns:
        - loss: PyTorch Tensor containing (scalar) the loss for the prediction
    '''
    return F.mse_loss(y_pred.flatten(), y.flatten())


def distance_knn_loss_torch(config, zi, zj, Dij, P_tot, idx, batchXj_id):
    '''
    Compute the distance loss between embeddings 
    and the minimum expected holding time
    
    Args:
        - zi: the embedding of the node i
        - zj: the embedding of the nodes j's
        - Dij: the post-processing distance between nodes i and j's
        - P_tot: the total probability of the nodes i and j's
        - idx: the index of the node i
        - batchXj_id: the index of the nodes j's
    
    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the embedding distance
    '''
    zi = zi.reshape(-1,1,zi.shape[-1])
    l2_zizj = torch.sqrt(torch.sum((zi-zj)**2, dim=-1))
    dist_diff = (l2_zizj - (Dij[idx]).to(config.device))**2

    wij = (P_tot[idx].reshape(-1,1) * P_tot[batchXj_id]).to(config.device) # importance weight of nodes i and j       
    dist_loss = torch.sum(dist_diff * wij)
    return dist_loss

def edit_distance_loss(config, zi, zj, ED_ij, idx):
    '''
    Compute the edit distance loss between embeddings 
    
    Args:
        - zi: the embedding of the node i
        - zj: the embedding of the nodes j's
        - ED_ij: the post-processing distance between nodes i and j's
        - idx: the index of the node i
        - batchXj_id: the index of the nodes j's
    
    Returns:
    - loss: PyTorch Tensor containing (scalar) the loss for the embedding distance
    '''
    zi = zi.reshape(-1,1,zi.shape[-1])
    l2_zizj = torch.sqrt(torch.sum((zi-zj)**2, dim=-1))
    dist_diff = (l2_zizj - (ED_ij[idx]).to(config.device))**2
    # wij = (P_tot[idx].reshape(-1,1) * P_tot[batchXj_id]).to(config.device) # importance weight of nodes i and j       
    # dist_loss = torch.sum(dist_diff * wij)
    editdist_loss = torch.sum(dist_diff)
    return editdist_loss
    

### Early_stopping function

In [None]:
def early_stop(val_loss, epoch, patience):
    """
    Check if validation loss has not improved for a certain number of epochs
    
    Args:
        val_loss (float): current validation loss
        epoch (int): current epoch number
        patience (int): number of epochs to wait before stopping if validation loss does not improve
        
    Returns:
        bool: True if validation loss has not improved for the last `patience` epochs, False otherwise
    """
    if epoch == 0:
        # First epoch, don't stop yet
        return False
    else:
        # Check if validation loss has not improved for `patience` epochs
        if val_loss >= early_stop.best_loss:
            early_stop.num_epochs_without_improvement += 1
            if early_stop.num_epochs_without_improvement >= patience:
                print("Stopping early")
                return True
            else:
                return False
        else:
            # Validation loss improved, reset counter
            early_stop.best_loss = val_loss
            early_stop.num_epochs_without_improvement = 0
            return False

### Train VIDA

In [None]:
# validation function
def validate(config, model, data_loader, val_loader, P_tot, Dij, ED_ij, X_j,
             vae_loss, pred_loss, distance_knn_loss, edit_distance_loss):
    model.to(config.device)
    model.eval()
    
    val_loss = 0; val_bce = 0; val_kld = 0; val_pred = 0; val_dist = 0; val_edit = 0
    
    # Disable gradient calculation to speed up inference
    with torch.no_grad():
        for x, y, idx in val_loader:
            
            # Configure input
            x = x.to(config.device)
            y = y.to(config.device)
            
            # forward X_j
            model.eval()
            with torch.no_grad():
                batchXj_id = X_j[idx]
                neighbor_input = data_loader.dataset.tensors[0][batchXj_id].reshape(-1, config.input_dim).to(config.device)
                _, _, neighbor_embed, _, _ = model(neighbor_input)
                neighbor_embed = neighbor_embed.reshape(-1, config.knn, neighbor_embed.shape[-1])
                
            # embedding
            x_recon, y_pred, z, mu, logvar = model(x)
            
            # # compute the total loss
            # vae loss
            recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)
            recon_loss = recon_loss.item()
            kl_loss = kl_loss.item()
            
            # energy prediction loss
            p_loss = pred_loss(y_pred, y).item()
                
            # distance loss
            dist_loss = distance_knn_loss(config, z, neighbor_embed, Dij, P_tot, idx, batchXj_id).item()
            
            # edit distance loss
            edit_loss = edit_distance_loss(config, z, neighbor_embed, ED_ij, idx).item()
            
            # scaling the loss
            recon_loss = config.alpha * recon_loss
            kl_loss = config.beta * kl_loss
            p_loss = config.gamma * p_loss
            dist_loss = config.delta * dist_loss
            edit_loss = config.epsilon * edit_loss
            
            # total loss
            loss = recon_loss + kl_loss + p_loss + dist_loss + edit_loss
            
            val_loss += loss
            val_bce += recon_loss
            val_kld += kl_loss
            val_pred += p_loss
            val_dist += dist_loss
            val_edit += edit_loss
                        
    print('Validation Loss: {:.4f}'.format(val_loss/len(val_loader.dataset)))
    
    # Clear the cache
    torch.cuda.empty_cache()
    # torch.mps.empty_cache()
    
    return val_loss/len(val_loader.dataset), val_bce/len(val_loader.dataset), val_kld/len(val_loader.dataset), val_pred/len(val_loader.dataset), val_dist/len(val_loader.dataset), val_edit/len(val_loader.dataset)
            

In [None]:
def train(config, model, data_loader, train_loader, val_loader, P_tot, Dij, ED_ij, X_j,
          optimizer, scheduler, vae_loss, pred_loss, distance_knn_loss, edit_distance_loss, early_stop,):
    '''
    Train VIDA!
    
    Args:
    ----
        - config: Experiment configurations
        - model: Pytorch VIDA model
        - data_loader: Pytorch DataLoader for all data
        - train_loader: Pytorch DataLoader for training set
        - val_loader: Pytorch DataLoader for validation set
        - P_tot: the total probability of each node
        - Dij: the shortest path distance between k nearest pairs of nodes
        - X_j: the index of k nearest neighbours of each node
        - optimizer: Pytorch optimizer
        - scheduler: Pytorch learning rate scheduler
        - vae_loss: the VAE loss function
        - pred_loss: the energy prediction loss function
        - distant_knn_loss: the k nearest neighbour distance loss function
    '''
    
    model.to(config.device)
    
    log_dir = f'./model_config/{config.log_dir}'
    writer = SummaryWriter(log_dir=log_dir)
    
    # save the config file
    with open(f'{log_dir}/hparams.yaml','w') as f:
        yaml.dump(config,f)
    
    # Initialize early stop object
    early_stop.best_loss = np.Inf
    early_stop.num_epochs_without_improvement = 0

    # convert the numpy array to tensor
    P_tot = torch.from_numpy(P_tot.astype(np.float32))
    Dij = torch.from_numpy(Dij.astype(np.float32))
    ED_ij = torch.from_numpy(ED_ij.astype(int))
    X_j = torch.from_numpy(X_j)
    
    print('\n ------- Start Training -------')
    for epoch in range(config.n_epochs):
        start_time = time.time()
        training_loss = 0
        
        for batch_idx, (x, y, idx) in enumerate(train_loader):  # mini batch
            
            # Configure input
            x = x.to(config.device)
            y = y.to(config.device)
            
            # forward X_j
            model.eval()
            with torch.no_grad():
                batchXj_id = X_j[idx]
                neighbor_input = data_loader.dataset.tensors[0][batchXj_id].reshape(-1, config.input_dim).to(config.device)
                _, _, neighbor_embed, _, _ = model(neighbor_input)
                neighbor_embed = neighbor_embed.reshape(-1, config.knn, neighbor_embed.shape[-1])
                
            # ------------------------------------------
            #  Train VIDA
            # ------------------------------------------
            model.train()
            optimizer.zero_grad()
            
            # get the reconstructed nodes, predicted energy, and the embeddings
            x_recon, y_pred, z, mu, logvar = model(x)
        
            ## compute the total loss
            # vae loss
            recon_loss, kl_loss = vae_loss(x_recon, x, mu, logvar)
            
            # energy prediction loss
            p_loss = pred_loss(y_pred, y)

            # distance loss
            dist_loss = distance_knn_loss(config, z, neighbor_embed, Dij, P_tot, idx, batchXj_id)

            # edit distance loss
            edit_loss = edit_distance_loss(config, z, neighbor_embed, ED_ij, idx)
            
            # scaling the loss
            recon_loss = config.alpha * recon_loss
            kl_loss = config.beta * kl_loss
            p_loss = config.gamma * p_loss
            dist_loss = config.delta * dist_loss
            edit_loss = config.epsilon * edit_loss
            
            # total loss
            loss = recon_loss + kl_loss + p_loss + dist_loss + edit_loss
            
            # backpropagation and optimization
            loss.backward()
            optimizer.step()
            
            training_loss += loss.item()
            
            # ------------------------------------------
            # Log Progress
            # ------------------------------------------
            if batch_idx % config.log_interval == 0:
                print('Train Epoch: {}/{} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, config.n_epochs, batch_idx * len(x), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader),
                    loss.item()))
                
                writer.add_scalar('training loss',
                                  loss.item(),
                                  epoch * len(train_loader) + batch_idx)
                writer.add_scalar('recon loss',
                                  recon_loss.item(),
                                  epoch * len(train_loader) + batch_idx)
                writer.add_scalar('kl loss',
                                  kl_loss.item(),
                                  epoch * len(train_loader) + batch_idx)
                writer.add_scalar('pred loss',
                                  p_loss.item(),
                                  epoch * len(train_loader) + batch_idx)
                writer.add_scalar('dist loss',
                                  dist_loss.item(),
                                  epoch * len(train_loader) + batch_idx)
                writer.add_scalar('edit loss',
                                  edit_loss.item(),
                                  epoch * len(train_loader) + batch_idx)
    
        print ('====> Epoch: {} Average loss: {:.4f}'.format(epoch, training_loss/len(train_loader.dataset)))
        writer.add_scalar('epoch training loss', training_loss/len(train_loader.dataset), epoch)
        
        # validation
        val_loss, val_bce, val_kld, val_pred, val_dist, val_edit = validate(config, model, data_loader, val_loader, P_tot, Dij, ED_ij, X_j, 
                                                                  vae_loss, pred_loss, distance_knn_loss, edit_distance_loss)
        writer.add_scalar('validation loss', val_loss, epoch)
        writer.add_scalar('val_recon loss', val_bce, epoch)
        writer.add_scalar('val_kl loss', val_kld, epoch)
        writer.add_scalar('val_pred loss', val_pred, epoch)
        writer.add_scalar('val_dist loss', val_dist, epoch)
        writer.add_scalar('val_edit loss', val_edit, epoch)

        # log the learning rate
        current_lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('learning rate', current_lr, epoch)
        
        # update the learning rate
        scheduler.step(val_loss)
        
        # timing
        end_time = time.time()
        epoch_time = end_time - start_time
        print (f'Epoch {epoch} train+val time: {epoch_time:.2f} seconds \n')
        
        # Check if validation loss has not improved for `patience` epochs
        if early_stop(val_loss, epoch, config.patience):
            break
    
        # Clear the cache
        # torch.mps.empty_cache()
        torch.cuda.empty_cache()
        
        
    writer.close()
    print('\n ------- Finished Training -------')
    
    # save the model
    torch.save(model.state_dict(), f'{log_dir}/model.pt')
    

In [None]:
# define models
encoder = Encoder(input_dim=config.input_dim, hidden_dim=config.hidden_dim, latent_dim=config.latent_dim)
decoder = Decoder(latent_dim=config.latent_dim, hidden_dim=config.hidden_dim, output_dim=config.output_dim)
regressor = Regressor(latent_dim=config.latent_dim)

# Initialize ViDa 
vida = VIDA(encoder, decoder, regressor)

# define optimizer
optimizer = torch.optim.Adam(vida.parameters(), lr=config.learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)

In [None]:
# train VIDA
train(config, vida, data_loader, train_loader, val_loader, P_tot, norm_Dij, ED_ij, X_j,
      optimizer, scheduler, vae_loss, pred_loss, distance_knn_loss_torch, edit_distance_loss)

### Load trained model

In [None]:
%%script false --no-raise-error

model = VIDA(encoder, decoder, regressor)
model.load_state_dict(torch.load('./model_config/0827-1520/model.pt',map_location=torch.device('cpu')))
# model.load_state_dict(torch.load('./model_config/0303-0320/model.pt'))

### Get embeddings

In [None]:
model = VIDA(encoder, decoder, regressor)

In [None]:
# do inference
model.to(config.device).eval()

with torch.no_grad():
        _, _, z, _, _ = model(data_loader.dataset.tensors[0].to(config.device))        

In [None]:
data_embed = z.to('cpu').numpy()
data_embed.shape

In [None]:
print(data_embed.max(), data_embed.min(), data_embed.mean(), data_embed.std())

### PCA

In [None]:
# # do PCA for GSAE embeded data
pca_coords = PCA(n_components=3).fit_transform(data_embed)

# # get all pca embedded states coordinates
pca_all_coords = pca_coords[coord_id_S]  # multiple trj

pca_coords.shape, pca_all_coords.shape

In [None]:
(np.unique(pca_coords,axis=0)).shape, (np.unique(pca_all_coords,axis=0)).shape

### PHATE

In [None]:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
data_embed = scaler.fit_transform(data_embed)
data_embed

In [None]:
# # do PHATE for GSAE embeded data
phate_operator = phate.PHATE(n_jobs=-2)
phate_coords = phate_operator.fit_transform(data_embed)

# # get all phate embedded states coordinates
phate_all_coords = phate_coords[coord_id_S]

phate_coords.shape, phate_all_coords.shape

In [None]:
(np.unique(phate_coords,axis=0)).shape, (np.unique(phate_all_coords,axis=0)).shape

In [None]:
config.log_dir

### Direct PCA/PHATE

In [None]:
SIMS_scar_uniq.shape

In [None]:
# # do PCA for GSAE embeded data
pca_coords = PCA(n_components=2).fit_transform(SIMS_scar_uniq)

# # get all pca embedded states coordinates
pca_all_coords = pca_coords[coord_id_S]  # multiple trj

pca_coords.shape, pca_all_coords.shape

In [None]:
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
data_embed = scaler.fit_transform(SIMS_scar_uniq)
# # do PHATE for GSAE embeded data
phate_operator = phate.PHATE(n_jobs=-2)
phate_coords = phate_operator.fit_transform(data_embed)

# # get all phate embedded states coordinates
phate_all_coords = phate_coords[coord_id_S]

phate_coords.shape, phate_all_coords.shape

### Save DRs

In [None]:
# %%script false --no-raise-error

""" Save all DRs
"""
# save for python 
# fnpz_data = f"data/vida_data/{SEQ}_{config.log_dir}.npz"
fnpz_data = f"../output_files/saved_ViDa_plots/plot_dna29/dir_PCA_PHATE.npz"


with open(fnpz_data, 'wb') as f:
    np.savez(f,
            # embed data
            data_embed=data_embed,
            # plotting data
            pca_coords=pca_coords, pca_all_coords=pca_all_coords,
            phate_coords=phate_coords, phate_all_coords=phate_all_coords,
            # umap_coord_2d=umap_coord_2d, umap_all_coord_2d=umap_all_coord_2d,
            )

### MDS

#### MDS for embeded data

In [None]:
%%script false --no-raise-error

"""Both ways cause kernel crashed
This MDS is designed for after-VAE embedding or direct for scattering transformed data
"""
from sklearn.metrics import pairwise_distances
from sklearn.manifold import MDS

# method 1
X_eucl = pairwise_distances(data_embed, metric='euclidean')
mds = MDS(n_components=2, dissimilarity='precomputed')
mds_coords = mds.fit_transform(X_eucl)

# method 2
mds = MDS(n_components=2)
mds_coords = mds.fit_transform(data_embed)

#### MDS for distance matrix

In [None]:
# make precomputed distance matrix for MDS
MDS_dist = np.ones((D_ij.shape[0],D_ij.shape[0]))
for i in range(len(D_ij)):
    MDS_dist[i,X_j[i]] = D_ij[i]

In [None]:
def makeSymmetric(mat):
    # Loop to traverse lower triangular
    # elements of the given matrix
    for i in range(0, len(mat)):
        for j in range(0, len(mat)):
            if (j < i):
                mat[i][j] = mat[j][i] = min(mat[i][j], mat[j][i])
    return mat

In [None]:
MDS_dist_symm = makeSymmetric(MDS_dist)

In [None]:
from sklearn.manifold import MDS
mds = MDS(n_components=2, dissimilarity='precomputed')
mds_coords = mds.fit_transform(MDS_dist_symm)

In [None]:
# fnpz_data_embed = f"./data/vida_data/{SEQ}_0305-2258.npz"
# # fnpz_data_embed = f"./data/vida_data/{SEQ}_usePT4_03040216.npz"
# data_npz_embed = np.load(fnpz_data_embed,allow_pickle=True)
# # asssign data to variables
# for var in data_npz_embed.files:
#     globals()[var] = data_npz_embed[var]

### Evaluate embedding

#### Metric distance

In [None]:
def metric_dist(X_j, D_ij, P_tot, z):
    """
    Metric to calculate the distance 
    """
    z_re = z.reshape(-1,1,z.shape[-1])
    zj = z[X_j]
    l2_zizj = np.sqrt(np.sum((z_re-zj)**2, axis=-1))
    
    # normalize the distance
    scaler = MinMaxScaler(feature_range=(1,3)) 
    l2_zizj = scaler.fit_transform(l2_zizj)
    D_ij = scaler.fit_transform(D_ij)
    
    dist_diff = (l2_zizj - D_ij)**2
    wij = (P_tot.reshape(-1,1) * P_tot[X_j])
    dist_loss = np.sum(wij * dist_diff)/len(dist_diff)
    return dist_loss

In [None]:
pca_dist = metric_dist(X_j, D_ij, P_tot, pca_coords[:,:2])
phate_dist = metric_dist(X_j, D_ij, P_tot, phate_coords)
# umap_dist = metric_dist(X_j, D_ij, P_tot, umap_coord_2d)
print (f'PCA distance loss: {pca_dist:.4f}')
print (f'PHATE distance loss: {phate_dist:.4f}')
# print (f'UMAP distance loss: {umap_dist:.4f}')

#### Neighboring preservation rate

In [None]:
from sklearn.neighbors import NearestNeighbors

def neighboring_preservation_rate(X, X_j, k):
    """
    Metric to calculate the neighboring preservation rate 
    """
    # Compute the k-nearest neighbors for both X and Y.
    nn_X = NearestNeighbors(n_neighbors=k+1).fit(X) # k+1 because we don't want to include the point itself
    indices_X = nn_X.kneighbors(X,return_distance=False)[:,1:] # exclude the point itself
    
    # compute the rate of each point
    rate_list = []
    for i in range(len(indices_X)):
        count = len(np.intersect1d(indices_X[i], X_j[i,:k]))
        rate_i = count/k
        rate_list.append(rate_i)
        
    # Compute the overall neighboring preservation rate
    return np.mean(rate_list)

In [None]:
knnn = 100
print("PCA rate: ", neighboring_preservation_rate(pca_coords[:,:2], X_j, k=knnn))
print("PHATE rate: ", neighboring_preservation_rate(phate_coords, X_j, k=knnn))

#### PCA explained variance

In [None]:
cm = PCA(n_components=25)
cm.fit(data_embed)

PC_values = np.arange(cm.n_components_) + 1
plt.plot(PC_values, np.cumsum(cm.explained_variance_ratio_), 'ro-', linewidth=2)
plt.title('Scree Plot: PCA')
plt.xlabel('Number of principal components')
plt.ylabel('Cumulative explained variance');
# plt.xticks(np.arange(0, data_embed.shape[-1]+1, 1))

plt.show()

print(np.cumsum(cm.explained_variance_ratio_))

# Visualize

In [None]:
SEQ

### PCA Vis

In [None]:
%%script false --no-raise-error

%matplotlib inline
X = pca_all_coords[:,0]
Y = pca_all_coords[:,1]
Z = pca_all_coords[:,2]

# PCA: 2 components
fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X, Y, 
          c=SIMS_G,
          cmap='plasma',
          s=20
        )

plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[-1]]
y = [Y[0],Y[-1]]
plt.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i],y[i]*0.95),fontsize=15,c="yellow", horizontalalignment='center')

In [None]:
%matplotlib inline
X = pca_coords[:,0]
Y = pca_coords[:,1]
Z = pca_coords[:,2]

# PCA: 2 components
fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X, Y, 
          c=SIMS_G_uniq, 
          cmap='plasma',
        )

plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[all_nodes[-1]]]

y = [Y[0],Y[all_nodes[-1]]]
plt.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i]-0.3,y[i]-0.3),fontsize=15,c="yellow")

In [None]:
%%script false --no-raise-error

X = pca_coords[:,0]
Y = pca_coords[:,1]
Z = pca_coords[:,2]

# PCA: 3 components
fig,ax = plt.subplots(figsize=(8,6))
ax = plt.axes(projection ="3d")

im = ax.scatter3D(X,Y,Z,
          c=SIMS_G_uniq,      
          cmap='plasma')
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[-1]]
y = [Y[0],Y[-1]]
z = [Z[0], Z[-1]]
ax.scatter(x,y,z,s=100,c="green",alpha=1)

In [None]:
%%script false --no-raise-error

X = pca_coords[:,0]
Y = pca_coords[:,1]
Z = pca_coords[:,2]


# PCA: 2 components
fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X, Y,
          c=SIMS_pair_uniq,
          cmap='plasma',
          s=15
        )

plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[-1]]
y = [Y[0],Y[-1]]
plt.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i]-0.3,y[i]-0.3),fontsize=15,c="yellow")

#### Try use PCA directly without AE

In [None]:
# %%script false --no-raise-error

pca_coords_direct = PCA(n_components=3).fit_transform(SIMS_scar_uniq)   # multiple trj

X = pca_coords_direct[:,0]
Y = pca_coords_direct[:,1]
Z = pca_coords_direct[:,2]

# PCA: 2 components
fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X, Y, 
          c=SIMS_G_uniq, 
          cmap='plasma',
        )

plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[-1]]
y = [Y[0],Y[-1]]
plt.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i]-0.3,y[i]-0.3),fontsize=15,c="black")

In [None]:
cm = PCA(n_components=25)
cm.fit(SIMS_scar_uniq)

PC_values = np.arange(cm.n_components_) + 1
plt.plot(PC_values, np.cumsum(cm.explained_variance_ratio_), 'ro-', linewidth=2)
plt.title('Scree Plot: PCA')
plt.xlabel('Number of principal components')
plt.ylabel('Cumulative explained variance');
# plt.xticks(np.arange(0, data_embed.shape[-1]+1, 1))

plt.show()

print(np.cumsum(cm.explained_variance_ratio_))

### PHATE Vis

In [None]:
%%script false --no-raise-error

X_phate = phate_all_coords[:,0]
Y_phate = phate_all_coords[:,1]

fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X_phate,Y_phate,
                c=SIMS_G,   # multiple trj               
                cmap='plasma',
               )

plt.colorbar(im)

annotations=["I","F"]
x = [X_phate[0],X_phate[-1]]
y = [Y_phate[0],Y_phate[-1]]
plt.scatter(x,y,s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i],y[i]),fontsize=30,c="black")

In [None]:
X_phate = phate_coords[:,0]
Y_phate = phate_coords[:,1]

fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X_phate,Y_phate,
                c=SIMS_G_uniq,            
                cmap='plasma',
               )

plt.colorbar(im)

annotations=["I","F"]
x = [X_phate[0],X_phate[all_nodes[-1]]]
y = [Y_phate[0],Y_phate[all_nodes[-1]]]
plt.scatter(x,y,s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i],y[i]),fontsize=30,c="black")

#### PHATE without AE

In [None]:
%%script false --no-raise-error

phate_operator = phate.PHATE(n_jobs=-2)
phate1 = phate_operator.fit_transform(SIMS_scar_uniq)   # multiple trj

fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(phate1[:,0],
          phate1[:,1],
          c=SIMS_G_uniq, 
          cmap='plasma',
        )

plt.colorbar(im)

annotations=["I","F"]
x = [phate1[:,0][0],phate1[:,0][-1]]
y = [phate1[:,1][0],phate1[:,1][-1]]
plt.scatter(x,y,s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i],y[i]),fontsize=20,c="black")

### UMAP Vis

In [None]:
X = umap_coord_2d[:,0]
Y = umap_coord_2d[:,1]
cmap = plt.cm.plasma
cmap_r = plt.cm.get_cmap('plasma_r')

fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X, Y, 
          c = SIMS_G_uniq,
          cmap=cmap,
          s=10
        )
 
plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[all_nodes[-1]]]
y = [Y[0],Y[all_nodes[-1]]]
plt.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i]-0.3,y[i]-0.3),fontsize=15,c="yellow")

In [None]:
%%script false --no-raise-error

# directly UMAP 2D
umap_coord_2dscar = umap_2d.fit_transform(SIMS_scar_uniq)

fig_2d = px.scatter(
    umap_coord_2dscar, x=0, y=1,color=SIMS_G_uniq
)
fig_2d.update_traces(marker_size=3)
fig_2d.show()



In [None]:
%%script false --no-raise-error

fig_2d = px.scatter(
    umap_coord_2d, x=0, y=1,color=SIMS_G_uniq
)
fig_2d.update_traces(marker_size=3)


fig_3d = px.scatter_3d(
    umap_coord_3d, x=0, y=1, z=2,color=SIMS_G_uniq
)

fig_3d.update_traces(marker_size=2)

fig_2d.show()
fig_3d.show()



### MDS Vis

In [None]:
X = mds_coords[:,0]
Y = mds_coords[:,1]
cmap = plt.cm.plasma
cmap_r = plt.cm.get_cmap('plasma_r')

fig,ax = plt.subplots(figsize=(8,6))
im = ax.scatter(X, Y, 
          c = SIMS_G_uniq,
          cmap=cmap,
          s=10
        )
 
plt.colorbar(im)

annotations=["I","F"]
x = [X[0],X[all_nodes[-1]]]
y = [Y[0],Y[all_nodes[-1]]]
plt.scatter(x,y,s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i]-0.3,y[i]-0.3),fontsize=15,c="yellow")