In [None]:
import sys
sys.path.append("/software/path/prefix/NvTK/")
import h5py, os, argparse, logging, time

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import NvTK
from NvTK import Trainer
from NvTK.Evaluator import calculate_correlation, calculate_pr, calculate_roc
from NvTK.Explainer import get_activate_W, meme_generate, save_activate_seqlets, calc_frequency_W

import matplotlib.pyplot as plt
from NvTK.Explainer import seq_logo, plot_seq_logo

#from NvTK import resnet18
from NvTK.Modules import BasicPredictor
# Seed the RNGs via NvTK's helpers and pick the compute device.
NvTK.set_random_seed()
NvTK.set_torch_seed()
NvTK.set_torch_benchmark()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The model definition lives in a local module outside the package tree.
sys.path.append("/file/path/prefix/")
# Number of output tasks (= number of cells). It must match the checkpoint loaded
# below and the cell dimension of the data loaded in the next cell.
n_tasks = 50029
# Import only the constructor we use: `import *` could silently shadow names
# imported above (np, torch, nn, ...).
from ResNeXt_conv1_128_btnk_2dense import resnext34
model = resnext34(num_classes=n_tasks)

# define criterion
criterion = nn.BCELoss().to(device)

# define optimizer (needed to construct the Trainer even though we only run inference here)
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0,)

# define trainer
trainer = Trainer(model, criterion, optimizer, device,
                    patience=10, tasktype='binary_classification', metric_sample=100,
                    use_tensorbord=False)
## reload best model
model = trainer.load_best_model('./Log/best_model.pth')
# Trailing `;` suppresses the (very long) module repr in the notebook output.
model.eval();

In [None]:
# unpack h5file: sequence array (X), the sparse peak x cell matrix (densified
# into y) and the per-peak feature table whose last column holds the split label.
h5_path = '/file/path/prefix/Gecko_5wCells_13wPeaks_6wnegative.shuffled.noimpute.500bp.20230901.h5'
# `with` guarantees the file handle is released even if one of the reads fails.
with h5py.File(h5_path, 'r') as h5file:
    X = h5file["pmat"]["X"][:].swapaxes(-1, 1).astype(np.float32)
    # Sparse triplets: entry k says y[i[k], j[k]] = x[k] in an array of shape `dim`.
    # NOTE(review): assumes i/j are 0-based; a matrix exported from R would be
    # 1-based -- confirm against whatever wrote this file.
    peak_idx = h5file['pmat']['pmat_sc']['i'][:]
    cell_idx = h5file['pmat']['pmat_sc']['j'][:]
    x = h5file['pmat']['pmat_sc']['x'][:]
    dim = h5file['pmat']['pmat_sc']['dim'][:]
    features = h5file["pmat"]["peak"][:]

# NOTE(review): dense float32 array of dim[0] x dim[1]; can be very large.
y = np.zeros((dim[0], dim[1]), dtype=np.float32)
y[peak_idx, cell_idx] = x


# unpack anno
# The model above was built with num_classes=n_tasks. Its outputs must line up
# with the cell (column) dimension of y, or predictions and targets will disagree.
assert y.shape[-1] == n_tasks, f"model has {n_tasks} outputs but the data has {y.shape[-1]} cells"
n_tasks = y.shape[-1]
# Split label per peak, taken from the last column of the feature table.
mask = features[:, -1].astype(str)
train_idx, val_idx, test_idx = mask == 'train', mask == 'val', mask == 'test'
# Fail loudly instead of evaluating on nothing (e.g. labels decoded as "b'test'").
assert test_idx.any(), "no peaks labelled 'test' in the split column"
x_train, x_val, x_test = X[train_idx], X[val_idx], X[test_idx]
y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]


# define data loader
batch_size = 100
train_loader = DataLoader(list(zip(x_train, y_train)), batch_size=batch_size,
                            shuffle=True, num_workers=0, drop_last=False, pin_memory=True)
# Evaluation loaders keep file order so outputs stay aligned with the peak order.
validate_loader = DataLoader(list(zip(x_val, y_val)), batch_size=batch_size,
                            shuffle=False, num_workers=0, drop_last=False, pin_memory=True)
test_loader = DataLoader(list(zip(x_test, y_test)), batch_size=batch_size,
                            shuffle=False, num_workers=0, drop_last=False, pin_memory=True)

In [None]:
# The DataLoader repr is just an object address; show the number of test
# samples and the number of batches instead.
len(test_loader.dataset), len(test_loader)

In [None]:
# Run the reloaded model over the test split. Only the last two return values
# (predictions, targets) are used below. The first two are discarded; their
# meaning is defined by NvTK's Trainer.predict, which is not visible here.
_, _, test_predictions, test_targets = trainer.predict(test_loader)

In [None]:
# Shape of the test-set predictions. Presumably (n_test_peaks, n_cells), since
# later cells transpose it and label the rows by cell -- TODO confirm against Trainer.predict.
test_predictions.shape

In [None]:
# Predictions and targets must line up element-for-element for the
# per-subcluster comparison below; check that before displaying the shape.
assert test_predictions.shape == test_targets.shape, (test_predictions.shape, test_targets.shape)
test_targets.shape

In [None]:
import pandas as pd

# Per-cell annotation table (tab-separated). Later cells use its `id` and
# `lineage` columns.
anno = pd.read_csv("./Gecko_anno_5w.txt", sep="\t")
anno

In [None]:
# Number of cells per subcluster id. `Series.value_counts` replaces the
# top-level `pd.value_counts`, which is deprecated since pandas 2.1.
anno["id"].value_counts()

In [None]:
# Predictions with one row per cell (transpose), each row tagged with its cell's
# subcluster id, then averaged within every subcluster.
pred_data = pd.DataFrame(test_predictions).T
pred_data["subcluster"] = anno["id"].to_numpy()
pred_data_mean = pred_data.groupby(by=["subcluster"]).mean()
pred_data_mean

In [None]:
# Same per-subcluster averaging as for the predictions, applied to the observed
# targets: one row per cell, tagged with its subcluster id, then averaged.
target_data = pd.DataFrame(test_targets).T
target_data["subcluster"] = anno["id"].to_numpy()
target_data_mean = target_data.groupby(by=["subcluster"]).mean()
target_data_mean

In [None]:
# Both frames are indexed by subcluster (groupby sorts the keys). corrcoef drops
# the labels, so verify the two orderings agree before stacking them.
assert pred_data_mean.index.equals(target_data_mean.index), \
    "subcluster order differs between predictions and targets"
# After transposing, columns are subclusters. With rowvar=False every predicted
# and every observed subcluster profile becomes one variable in the result.
corr = np.corrcoef(pred_data_mean.T, target_data_mean.T, rowvar=False)
corr.shape

In [None]:
# `corr` stacks the predicted subclusters first and the observed ones second.
# The upper-right quadrant is therefore the predicted (rows) x observed (columns) block.
corr_pt = pd.DataFrame(
    corr[: corr.shape[0] // 2, corr.shape[1] // 2 :],
    index=pred_data_mean.index,
    columns=target_data_mean.index,
)

In [None]:
# Predicted (rows) vs. observed (columns) per-subcluster correlation matrix.
corr_pt

In [None]:
# Mean of the diagonal: the average correlation between each subcluster's
# predicted and observed profile.
corr_pt.to_numpy().diagonal().mean()

In [None]:
import seaborn as sns

# Predicted-vs-observed correlation heatmap, kept in subcluster order (no
# clustering). z_score=1 standardises each column before colouring.
# NOTE(review): vmin=0/vmax=1 on z-scores saturates everything above 1 SD --
# confirm that is the intended colour range.
g2 = sns.clustermap(
    corr_pt,
    cmap="viridis",
    row_cluster=False,
    col_cluster=False,
    z_score=1,
    vmin=0,
    vmax=1,
)
g2.savefig("./PTheatmap_pred-targ.pdf")

In [None]:
# Fixed palette for the ten cell lineages: regions[i] is drawn with color[i].
color = (
    "#CB99CC", "#ED6245", "#6D6CF5", "#CCCD67", "#FCE28D",
    "#E9F297", "#BEE6A0", "#2D9687", "#3288BD", "#83D4D8",
)
regions = (
    "Endothelial", "Epithelial", "Erythroid", "Hepatocyte", "Immune",
    "Muscle", "Neural", "Secretory", "Stromal", 'Reproductive',
)
color_regions = dict(zip(regions, color))
color_regions

In [None]:
# Annotation table plus one colour per cell, derived from its lineage. `assign`
# returns a new frame, so `anno` itself is not mutated (the original aliased it).
# Indexing the dict directly, rather than Series.map(dict), keeps the KeyError
# for any lineage missing from `color_regions` instead of producing NaN.
# Series.map also avoids the deprecated DataFrame.applymap.
anno_color = anno.assign(
    colors_lineage=anno["lineage"].map(lambda lineage: color_regions[lineage])
)
anno_color

In [None]:
# One lineage colour per subcluster id. The ids come from np.unique (sorted),
# matching the sorted subcluster index produced by the groupby cells above.
id_lineage_colors = (
    anno_color[["id", "lineage", "colors_lineage"]]
    .drop_duplicates()
    .set_index("id")
)
# An id spanning several lineages would leave duplicate index labels, which
# seaborn cannot align against corr_pt. Fail early with a clear message instead.
assert id_lineage_colors.index.is_unique, "some subcluster ids map to more than one lineage"
celltype_anno = id_lineage_colors.loc[np.unique(anno.id.values)].colors_lineage
celltype_anno

In [None]:
# Legend lookup: each lineage present in the annotation -> its colour
# (None when the lineage has no entry in color_regions).
lut = dict(
    (lineage, color_regions.get(lineage))
    for lineage in anno_color["lineage"].unique()
)
lut

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
from matplotlib.patches import Patch

In [None]:
# Same predicted-vs-observed heatmap as above, now with lineage colour bars on
# both axes and a lineage legend, saved to `output_fname`.
output_fname = './PTheatmap_pred-targ_legend_gecko.pdf'
# clustermap creates its own figure, so the size must be passed to it directly.
# A preceding plt.figure() call only leaves an empty, never-closed figure.
g = sns.clustermap(
    corr_pt,
    cmap='viridis',
    figsize=(25, 25),
    row_cluster=False,
    col_cluster=False,
    z_score=1,
    vmin=0,
    vmax=1,
    col_colors=celltype_anno,
    row_colors=celltype_anno,
    colors_ratio=0.02,
)

# One colour swatch per lineage, in the same order as the labels.
handles = [Patch(facecolor=lut[name]) for name in lut]
# Anchor the legend in figure coordinates so it sits outside the heatmap.
g.figure.legend(
    handles,
    list(lut),
    title='CellLineage',
    bbox_to_anchor=(0.15, 0.75),
    bbox_transform=g.figure.transFigure,
    loc='upper right',
)

g.figure.savefig(output_fname)
plt.show()
# Release the clustermap figure explicitly.
plt.close(g.figure)