In [None]:
import sys
sys.path.append("/software/path/prefix/NvTK/")
import h5py, os, argparse, logging, time

import numpy as np
import pandas as pd

import torch
from torch import nn
from torch.optim import Adam
from torch.utils.data import DataLoader

import NvTK
from NvTK import Trainer
from NvTK.Evaluator import calculate_correlation, calculate_pr, calculate_roc
from NvTK.Explainer import get_activate_W, meme_generate, save_activate_seqlets, calc_frequency_W

import matplotlib.pyplot as plt
from NvTK.Explainer import seq_logo, plot_seq_logo

#from NvTK import resnet18
from NvTK.Modules import BasicPredictor
# set_all_random_seed
# Call NvTK's seeding and benchmark helpers with their default arguments
# before the model or any data is created.
NvTK.set_random_seed()
NvTK.set_torch_seed()
NvTK.set_torch_benchmark()
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Extend the import path so the model-definition module below can be found
# (presumably it lives in this directory — adjust to the local checkout).
sys.path.append("/file/path/prefix/")
n_tasks=50040 #cell num
# Import the constructor by name rather than with `import *`, so none of the
# module's other names can shadow what this notebook has already bound.
from ResNeXt_conv1_128_btnk_2dense import resnext34
model = resnext34(num_classes = n_tasks)

# define criterion
# NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes the model
# ends in a sigmoid — TODO confirm against the model definition.
criterion = nn.BCELoss().to(device)

# define optimizer
# The visible cells only reload a checkpoint and run prediction; the optimizer
# is constructed because Trainer takes one as an argument.
optimizer = Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999), eps=1e-08, weight_decay=0)

# define trainer
# `use_tensorbord` is spelled exactly as passed in the original call.
trainer = Trainer(model, criterion, optimizer, device,
                    patience=10, tasktype='binary_classification', metric_sample=100,
                    use_tensorbord=False)
## reload best model
# load_best_model's return value replaces `model`; it is then moved to the
# selected device and switched to evaluation mode for inference.
model = trainer.load_best_model('./Log/best_model.pth')
model.to(device)
model.eval()

In [None]:
# Load the test data

In [None]:
# unpack h5file
# Everything is read inside a `with` block so the HDF5 handle is closed even if
# one of the reads raises.
with h5py.File('/file/path/prefix/Femalemus_5wCells_17wPeaks_8wnegative.shuffled.noimpute.500bp.20230822.h5', 'r') as h5file:
    # Input array: swap axis 1 with the last axis and cast to float32.
    X = h5file["pmat"]["X"][:].swapaxes(-1,1).astype(np.float32)
    # Label matrix stored as triplets: row index `i`, column index `j`, value `x`,
    # plus the full matrix dimensions in `dim`.
    peak_idx = h5file['pmat']['pmat_sc']['i'][:]
    cell_idx = h5file['pmat']['pmat_sc']['j'][:]
    x = h5file['pmat']['pmat_sc']['x'][:]
    dim = h5file['pmat']['pmat_sc']['dim'][:]
    features = h5file["pmat"]["peak"][:]

# Scatter the triplets into a dense (dim[0], dim[1]) float32 matrix; entries
# not listed stay 0.
# NOTE(review): assumes the stored i/j indices are 0-based — TODO confirm.
y = np.zeros((dim[0], dim[1]), dtype = np.float32)
y[peak_idx, cell_idx] = x


# unpack anno
n_tasks = y.shape[-1]
# The last column of `features` carries the split label of each row
# ('train' / 'val' / 'test'); build one boolean row mask per split.
mask = features[:,-1].astype(str)
train_idx, val_idx, test_idx = mask=='train', mask=='val', mask=='test'
x_train, x_val, x_test = X[train_idx], X[val_idx], X[test_idx]
y_train, y_val, y_test = y[train_idx], y[val_idx], y[test_idx]
# np.sum(train_idx), np.sum(val_idx), np.sum(test_idx)


# define data loader
# shuffle=False and drop_last=False keep every test sample, in its original
# order, so predictions line up with the rows of x_test / y_test.
batch_size =100
test_loader = DataLoader(list(zip(x_test, y_test)), batch_size=batch_size,
                            shuffle=False, num_workers=0, drop_last=False, pin_memory=True)

In [None]:
# Echo the DataLoader object as the cell output (quick check that it was built).
test_loader 

In [None]:
# Run the trainer's predict() over the test loader. The result is unpacked into
# four values; the first two are discarded and only the predictions and the
# targets are kept.
# NOTE(review): what the two discarded values hold is not visible here — check
# Trainer.predict in NvTK.
_, _, test_predictions, test_targets = trainer.predict(test_loader)

In [None]:
# Show the shape of the prediction array.
# presumably (n_test_samples, n_tasks) — TODO confirm against Trainer.predict
test_predictions.shape

In [None]:
# Filter analysis: extract conv1 feature maps and derive motifs from them

In [None]:
from NvTK.Explainer import get_fmap, meme_generate, calc_frequency_W,get_activate_W_from_fmap
from NvTK.Explainer.Featuremap import ActivateFeaturesHook

In [None]:
# Run NvTK's get_fmap on the model's `conv1` layer over the test loader and keep
# both arrays it returns; their shapes are shown as the cell output.
# NOTE: this rebinds `X`, the name the data-loading cell uses for the full input
# array. Re-running that cell's train/val/test split after this cell would index
# this new X instead — re-run the data-loading cell from its top in that case.
# NOTE(review): the returned X is presumably the inputs seen by the loader — confirm.
fmap, X = get_fmap(model, model.conv1, test_loader)
fmap.shape, X.shape

In [None]:
# Drop every size-1 axis from the feature maps.
# NOTE(review): squeeze() without an axis would also remove a batch or channel
# axis that happens to have length 1 — consider naming the axis explicitly.
fmap=fmap.squeeze()

In [None]:
# Show the feature-map shape as the cell output.
fmap.shape

In [None]:
# Make sure the ./Motif output directory exists; exist_ok=True makes re-running
# this cell harmless.
os.makedirs('./Motif',exist_ok=True)

In [None]:
from NvTK.Explainer import get_activate_W_from_fmap

In [None]:
# Build the array W from the conv1 feature maps and the matching inputs with
# NvTK's get_activate_W_from_fmap, then show its shape.
# NOTE(review): pool=1, threshold=0.9, motif_width=7 and pad=3 are passed
# straight through; their exact meaning is defined by NvTK — presumably the
# threshold selects strongly activating positions and motif_width/pad set the
# extracted window — TODO confirm against the NvTK documentation.
W= get_activate_W_from_fmap(fmap, X, pool=1, threshold=0.9, motif_width=7,pad=3)
W.shape

In [None]:
from NvTK.Explainer.MotifVisualize import plot_filter_heatmap,filter_heatmap
import  matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable
# Draw W with NvTK's plot_filter_heatmap at a 10x4 figure size.
# save=False is passed, so presumably nothing is written to disk — confirm.
plot_filter_heatmap(W, factor=None, fig_size=(10,4), save=False)

In [None]:
# Hand W to NvTK's meme_generate with ./Motif/meme_conv1.txt as the output file
# (presumably MEME motif format, judging by the helper's name — confirm).
meme_generate(W, output_file="./Motif/meme_conv1.txt")

In [None]:
# Calculate motif frequency and information content (IC)

In [None]:
from NvTK.Explainer.Motif import calc_frequency_W,calc_motif_IC,calc_motif_frequency

In [None]:
# calc_frequency_W returns two values for W, kept here as W_freq and W_IC.
# NOTE(review): background=0.25 is presumably a uniform frequency over the four
# nucleotides — confirm against the NvTK documentation.
W_freq,W_IC = calc_frequency_W(W, background=0.25)

In [None]:
# Number of entries in W_freq (presumably one per filter in W — confirm).
len(W_freq)

In [None]:
# Peek at the first five entries of W_freq instead of dumping the whole sequence.
W_freq[:5]

In [None]:
# Gather W_freq and W_IC into a two-column table ("freq", "IC") and write it to
# ./Motif/W_IC_freq.csv with pandas' default CSV settings.
motif_summary = pd.DataFrame({"freq":W_freq, "IC":W_IC})
motif_summary.to_csv("./Motif/W_IC_freq.csv")