In [None]:
import numpy as np
import matplotlib.pyplot as plt
import pickle
from misc import *

from sklearn.decomposition import PCA
import phate


## 1. Import/generate Data

### 1. Load simulated data from Multistrand

In [None]:
# Select the simulated trajectory to analyse. STRAND_NAME is the stem of the
# input text file and of the artifacts saved further down, so the input path is
# built from it instead of being typed a second time.
# STRAND_NAME = "assos_PT3_1sim_20C_21"  # PT3
STRAND_NAME = "assos_PT0_1sim_20C_51"  # PT0

# Dimensions of the SIM list
#   SIM:   [[sim1], [sim2], ...]
#   sim:   [[state1], [state2], ...]
#   state: [structure, time, energy]

# define absorbing (final) state structure
FINAL_STRUCTURE = "(((((((((((((((((((((((((+)))))))))))))))))))))))))"

# Read the trajectory inside a context manager so the file handle is closed as
# soon as loading finishes (previously it stayed open for the kernel's lifetime).
with open('./data/helix_assos/{}.txt'.format(STRAND_NAME), 'r') as f:
    SIM = loadtrj(f, FINAL_STRUCTURE, type="Multiple")

# Array view of the raw list, and the flattened form used by later cells.
SIM_retrieve = np.array(SIM)
SIM_concat = concat_helix_structures(SIM)

print("SIM: ", len(SIM))
print("SIM_retrieve: ", SIM_retrieve.shape)
print("SIM_concat: ", len(SIM_concat))

### 2. Convert dot-paren to adjacency matrix

In [None]:
""" Dimenstions of SIM_adj list 
SIM_adj: N*m*m
    N: number of states in the trajectory
    m: number of nucleotides in the state (strand)
"""
# get adjacency matrix, energy, and holding time for each state
SIM_adj,SIM_G,SIM_T,SIM_HT = sim_adj(SIM_concat)

In [None]:
# Shapes of the four per-state arrays, shown side by side as a sanity check.
tuple(arr.shape for arr in (SIM_adj, SIM_G, SIM_T, SIM_HT))

In [None]:
# Collapse repeated states. get_unique hands back the index array, the
# occupancy density, and the de-duplicated adjacency / energy / time /
# holding-time arrays.
(
    indices,
    occ_density,
    SIM_adj_uniq,
    SIM_G_uniq,
    SIM_T_uniq,
    SIM_HT_uniq,
) = get_unique(SIM_concat, SIM_adj, SIM_G, SIM_T, SIM_HT)

# Shapes of the de-duplicated arrays.
tuple(arr.shape for arr in (SIM_adj_uniq, SIM_G_uniq, SIM_T_uniq, SIM_HT_uniq))

### 3. Get labeled trajectory data

In [None]:
# Label every state of the trajectory. Column 3 of the result, cast to int,
# is kept as coord_id and later used to index the unique-state embeddings.
SIM_dict = label_structures(SIM_concat, indices)
coord_id = SIM_dict[:, 3].astype(int)
(SIM_dict.shape, coord_id.shape)

In [None]:
# find the structure having the largest occupancy density
# NOTE(review): `indices` is produced by get_unique(SIM_concat, ...) but is
# used here to index SIM_retrieve (= np.array(SIM)) — confirm that both share
# the same per-state indexing.
SIM_retrieve[indices[occ_density.argmax()]]

### 4. Convert adjacency matrix to scattering coefficients

In [None]:
# Scattering coefficients for every state of the trajectory (duplicates
# included), followed by moment normalisation and removal of singleton axes.
scat_coeff_array = transform_dataset(SIM_adj)
SIM_scar = norm_scat_coeffs = get_normalized_moments(scat_coeff_array).squeeze()
SIM_scar.shape

In [None]:
# Same transform as for the full trajectory, applied to the unique states only.
scat_coeff_array = transform_dataset(SIM_adj_uniq)
SIM_scar_uniq = norm_scat_coeffs = get_normalized_moments(scat_coeff_array).squeeze()
SIM_scar_uniq.shape

### 5. Split data into training and test sets

In [None]:
"""Shape of split data
    train_data: [tr_adjs, tr_coeffs, tr_energies]
    test_data: [te_adjs, te_coeffs, te_energies]
"""
train_data,test_data = split_data(SIM_adj_uniq,SIM_scar_uniq,SIM_G_uniq)

In [None]:
# Number of distinct coefficient rows in the training split, next to the
# train and test coefficient shapes.
(
    np.unique(train_data[1], axis=0).shape,
    train_data[1].shape,
    test_data[1].shape,
)

### 6. Train and test dataloader

In [None]:
"""Structure of train_tup when gnn=False
    train_tup: [train_coeffs,train_energy] 
"""
train_loader, train_tup, test_tup, valid_loader,early_stop_callback = load_trte(train_data,test_data,
                                              batch_size=8)
train_tup[0].shape, test_tup[0].shape, train_loader.batch_size

## 2.1 Load Model

In [None]:
# set up hyperparameters

input_dim = train_tup[0].shape[-1]  # size of the last axis of the training coefficients
len_epoch = len(train_loader)

# Same keys, values and insertion order as before, written with dict(...).
hparams = dict(
    input_dim=input_dim,
    bottle_dim=25,
    hidden_dim=400,  # not used in model
    len_epoch=len_epoch,
    learning_rate=0.0001,
    max_epochs=60,  # PT0 --> 1985 epochs; PT3 --> 60, 100, 150, 756 (overfit) epochs
    n_gpus=0,
    batch_size=8,  # not used in model
    alpha=1.0,
    beta=0.0001,
)

In [None]:
# Turn the hyperparameter dict into an attribute-style Namespace (later cells
# read hparams.max_epochs, hparams.n_gpus, ...).
# The isinstance guard makes the cell safe to re-run: once hparams is already
# a Namespace, `**hparams` would raise a TypeError.
if isinstance(hparams, dict):
    hparams = argparse.Namespace(**hparams)

In [None]:
# Instantiate the GSAE model from the hyperparameter namespace.
model = GSAE(hparams)

In [None]:
# Print the model's textual representation before training.
print(model)

## 2.2 Train Model

In [None]:
# Build the Lightning trainer from the hyperparameter namespace; epochs and GPU
# count are passed explicitly as well. Early stopping is left disabled.
trainer = pl.Trainer.from_argparse_args(
    hparams,
    max_epochs=hparams.max_epochs,
    gpus=hparams.n_gpus,
    # callbacks=[early_stop_callback],
)

# Fit on the training loader, validating on the validation loader.
trainer.fit(
    model=model,
    train_dataloader=train_loader,
    val_dataloaders=valid_loader,
)

In [None]:
# Display the model object after the training cell has run.
model

In [None]:
# Load the TensorBoard notebook extension and launch it on the lightning_logs/
# directory; host and port are fixed, so the dashboard is at the URL below.
%load_ext tensorboard
%tensorboard --logdir lightning_logs/ --host localhost --port 8088
#  http://localhost:8088

In [None]:
# save the trained model
# The file name encodes the strand name and the number of training epochs.
# NOTE(review): an earlier comment here mentioned "I1_10000sim.txt SIM[1]";
# it looks stale for the current STRAND_NAME — confirm and drop.
fname_model = "models/{}_model_{}epoch.pickle".format(STRAND_NAME, hparams.max_epochs)

# Write through a context manager so the file is flushed and closed before
# anything tries to read it back (the original never closed the handle).
with open(fname_model, 'wb') as model_file:
    pickle.dump(model, model_file)
print('Trained model saved.')

In [None]:
# Show the path the model was written to.
fname_model

## 3. Load Pretrained Models

In [None]:
# Path of the pretrained model to load (this replaces any `model` trained above).
# NOTE(review): this is the PT3 model, while the data cell at the top currently
# selects the PT0 trajectory — confirm the pairing is intended.
fname_model = "models/assos_PT3_1sim_20C_21_model_60epoch.pickle"
# fname_model = "models/assos_PT0_1sim_20C_51_model_1985epoch.pickle"

# pickle.load can execute arbitrary code: only load model files you created or
# otherwise trust. The context manager closes the handle after reading (the
# original left it open).
with open(fname_model, 'rb') as model_file:
    model = pickle.load(model_file)
model

## 4. Get Embeddings

In [None]:
# Embed the unique (de-duplicated) scattering coefficients. Gradients are
# disabled for the call; only the first element of embed()'s output is kept.
with torch.no_grad():
    data_embed = model.embed(torch.Tensor(SIM_scar_uniq))[0]

In [None]:
# Reduce the GSAE embedding of the unique states to three principal components.
pca_model = PCA(n_components=3)
pca_coords = pca_model.fit_transform(data_embed)

# Expand to one row per trajectory state by indexing with coord_id.
pca_all_coords = pca_coords[coord_id]

(pca_coords.shape, pca_all_coords.shape)

In [None]:
# Run PHATE on the GSAE embedding of the unique states.
phate_operator = phate.PHATE(n_jobs=-2)
phate_coords = phate_operator.fit_transform(data_embed)

# Expand to one row per trajectory state by indexing with coord_id.
phate_all_coords = phate_coords[coord_id]

(phate_coords.shape, phate_all_coords.shape)

In [None]:
""" Save all obtained data to npz file
"""
fname_data = "data/helix_assos/{}_{}epoch.npz".format(STRAND_NAME,hparams.max_epochs)
with open(fname_data, 'wb') as f:
    np.savez(f,
            SIM_adj=SIM_adj,SIM_scar=SIM_scar,SIM_G=SIM_G,SIM_HT=SIM_HT,
            SIM_adj_uniq=SIM_adj_uniq, SIM_scar_uniq=SIM_scar_uniq,
            SIM_G_uniq=SIM_G_uniq, SIM_HT_uniq=SIM_HT_uniq,
            # SIM_dict=SIM_dict, 
            occp=occ_density,
            data_embed=data_embed, coord_id=coord_id,
            pca_coords=pca_coords, pca_all_coords=pca_all_coords,
            phate_coords=phate_coords, phate_all_coords=phate_all_coords,
            )

In [None]:
# """ Save all obtained data to hf5 file
# """
# fname_data_h5 = "data/helix_assos/{}_{}epoch.h5".format(STRAND_NAME,hparams.max_epochs)
# save_h5(fname_data_h5,
#             SIM_adj, SIM_scar, SIM_G, SIM_HT,
#             SIM_adj_uniq, SIM_scar_uniq, SIM_G_uniq, SIM_HT_uniq,
#             # SIM_dict, 
#             occ_density, data_embed, coord_id,
#             pca_coords, pca_all_coords,
#             phate_coords, phate_all_coords)


## 5. Visualize

In [None]:
# Archive to visualise. The path is relative to the notebook's working
# directory, matching the location the save cell writes to, rather than an
# absolute path that only exists on one machine.
# fname_data = "data/helix_assos/assos_PT3_1sim_20C_21_60epoch.npz"
fname_data = "data/helix_assos/assos_PT0_1sim_20C_51_1985epoch.npz"

# np.load on an .npz returns a lazy archive object; arrays are read on access
# in the plotting cells below, so it is intentionally left open.
npyfile = np.load(fname_data)
npyfile.files

### 1. PCA Vis

In [None]:
# PCA coordinates for every trajectory state; the columns are the three
# principal components.
pca_all = npyfile["pca_all_coords"]
X, Y, Z = pca_all[:, 0], pca_all[:, 1], pca_all[:, 2]

# PCA: 2 components, coloured by SIM_G
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(X, Y, c=npyfile["SIM_G"], cmap='plasma')
plt.colorbar(im)

# Highlight the first and last rows as the initial (I) and final (F) states.
annotations = ["I", "F"]
x, y = [X[0], X[-1]], [Y[0], Y[-1]]
plt.scatter(x, y, s=150, c="green", alpha=1)
for label, px, py in zip(annotations, x, y):
    plt.annotate(label, (px - 0.3, py - 0.3), fontsize=15, c="yellow")

In [None]:
# 2-D PCA of the GSAE embedding, one point per *unique* state, coloured by the
# matching SIM_G_uniq values.
pca_uniq = npyfile["pca_coords"]
X, Y, Z = pca_uniq[:, 0], pca_uniq[:, 1], pca_uniq[:, 2]

# PCA: 2 components
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(X, Y, c=npyfile["SIM_G_uniq"], cmap='plasma')
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
plt.colorbar(im)

# Rows here follow the unique-state ordering, so rows 0 and -1 are not
# guaranteed to be the trajectory's first and last states. coord_id maps each
# trajectory step to its unique-state row (pca_all_coords = pca_coords[coord_id]),
# so its first and last entries locate the initial (I) and final (F) states.
traj_to_uniq = npyfile["coord_id"]
first_id, last_id = traj_to_uniq[0], traj_to_uniq[-1]

annotations = ["I", "F"]
x = [X[first_id], X[last_id]]
y = [Y[first_id], Y[last_id]]
plt.scatter(x, y, s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i] - 0.3, y[i] - 0.3), fontsize=15, c="yellow")

In [None]:
# 3-D PCA of the GSAE embedding, one point per *unique* state.
pca_uniq = npyfile["pca_coords"]
X, Y, Z = pca_uniq[:, 0], pca_uniq[:, 1], pca_uniq[:, 2]

# PCA: 3 components
# Create the 3-D axes directly on a fresh figure. Calling plt.subplots() first
# and then plt.axes(projection="3d") leaves an empty 2-D axes underneath.
fig = plt.figure(figsize=(8, 6))
ax = fig.add_subplot(111, projection="3d")

im = ax.scatter3D(X, Y, Z, c=npyfile["SIM_G_uniq"], cmap='plasma')
ax.set_xlabel("X")
ax.set_ylabel("Y")
ax.set_zlabel("Z")
plt.colorbar(im)

# Rows follow the unique-state ordering, so the initial (I) and final (F)
# states are located through coord_id (trajectory step -> unique-state row)
# instead of assuming they sit at rows 0 and -1.
traj_to_uniq = npyfile["coord_id"]
first_id, last_id = traj_to_uniq[0], traj_to_uniq[-1]

annotations = ["I", "F"]
x = [X[first_id], X[last_id]]
y = [Y[first_id], Y[last_id]]
z = [Z[first_id], Z[last_id]]
ax.scatter(x, y, z, s=100, c="green", alpha=1)
# Label the two markers, as the 2-D plots do (annotations was previously unused).
for i, label in enumerate(annotations):
    ax.text(x[i], y[i], z[i], label, fontsize=15, color="black")

#### Try using PCA directly, without the AE

In [None]:
# Baseline: PCA applied straight to the unique scattering coefficients,
# skipping the autoencoder.
pca_coords1 = PCA(n_components=3).fit_transform(npyfile["SIM_scar_uniq"])

X = pca_coords1[:, 0]
Y = pca_coords1[:, 1]
Z = pca_coords1[:, 2]

# PCA: 2 components
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(X, Y, c=npyfile["SIM_G_uniq"], cmap='plasma')
ax.set_xlabel("PC 1")
ax.set_ylabel("PC 2")
plt.colorbar(im)

# Rows follow the unique-state ordering, so the initial (I) and final (F)
# states are located through coord_id (trajectory step -> unique-state row)
# instead of assuming they sit at rows 0 and -1.
traj_to_uniq = npyfile["coord_id"]
first_id, last_id = traj_to_uniq[0], traj_to_uniq[-1]

annotations = ["I", "F"]
x = [X[first_id], X[last_id]]
y = [Y[first_id], Y[last_id]]
plt.scatter(x, y, s=150, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i] - 0.3, y[i] - 0.3), fontsize=15, c="black")

### 2. PHATE Vis

In [None]:
# PHATE coordinates for every trajectory state, coloured by SIM_G.
phate_all = npyfile["phate_all_coords"]
X_phate, Y_phate = phate_all[:, 0], phate_all[:, 1]

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(X_phate, Y_phate, c=npyfile["SIM_G"], cmap='plasma')
plt.colorbar(im)

# Highlight the first and last rows as the initial (I) and final (F) states.
annotations = ["I", "F"]
x, y = [X_phate[0], X_phate[-1]], [Y_phate[0], Y_phate[-1]]
plt.scatter(x, y, s=50, c="green", alpha=1)
for label, px, py in zip(annotations, x, y):
    plt.annotate(label, (px, py), fontsize=30, c="black")

In [None]:
# PHATE of the GSAE embedding, one point per *unique* state, coloured by the
# matching SIM_G_uniq values.
phate_uniq = npyfile["phate_coords"]
X_phate = phate_uniq[:, 0]
Y_phate = phate_uniq[:, 1]

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(X_phate, Y_phate, c=npyfile["SIM_G_uniq"], cmap='plasma')
ax.set_xlabel("PHATE 1")
ax.set_ylabel("PHATE 2")
plt.colorbar(im)

# Rows follow the unique-state ordering, so the initial (I) and final (F)
# states are located through coord_id (trajectory step -> unique-state row,
# cf. phate_all_coords = phate_coords[coord_id]) instead of assuming they sit
# at rows 0 and -1.
traj_to_uniq = npyfile["coord_id"]
first_id, last_id = traj_to_uniq[0], traj_to_uniq[-1]

annotations = ["I", "F"]
x = [X_phate[first_id], X_phate[last_id]]
y = [Y_phate[first_id], Y_phate[last_id]]
plt.scatter(x, y, s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i], y[i]), fontsize=30, c="black")

#### PHATE without AE

In [None]:
# Baseline: PHATE applied straight to the unique scattering coefficients,
# skipping the autoencoder.
phate_operator = phate.PHATE(n_jobs=-2)
phate1 = phate_operator.fit_transform(npyfile["SIM_scar_uniq"])

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.scatter(phate1[:, 0], phate1[:, 1], c=npyfile["SIM_G_uniq"], cmap='plasma')
ax.set_xlabel("PHATE 1")
ax.set_ylabel("PHATE 2")
plt.colorbar(im)

# Rows follow the unique-state ordering, so the initial (I) and final (F)
# states are located through coord_id (trajectory step -> unique-state row)
# instead of assuming they sit at rows 0 and -1.
traj_to_uniq = npyfile["coord_id"]
first_id, last_id = traj_to_uniq[0], traj_to_uniq[-1]

annotations = ["I", "F"]
x = [phate1[first_id, 0], phate1[last_id, 0]]
y = [phate1[first_id, 1], phate1[last_id, 1]]
plt.scatter(x, y, s=50, c="green", alpha=1)
for i, label in enumerate(annotations):
    plt.annotate(label, (x[i], y[i]), fontsize=20, c="black")