In [1]:
# Automatically re-import edited modules (e.g. the local scbasset package) before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
import anndata
import h5py
import matplotlib.pyplot as plt
import os
import seaborn as sns
import scanpy as sc
import pandas as pd
import torch

from scbasset.scbasset_utils import motif_score
from scbasset.utils import *
from scbasset.model_class import ModelClass
from scbasset.config import Config

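### Download data

This notebook does not download anything itself; the motif FASTA files and the processed dataset referenced below must already exist under `data/`.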

### Path to motif FASTA files

In [None]:
# Repository root relative to this notebook. The commented-out '' alternative is
# presumably for running from the repository root itself.
start_directory = '../../'
# start_directory = ''

# Folder of Homo sapiens motif FASTA files (used by the motif_score call further down).
motif_fasta_folder = f'{start_directory}data/download/Homo_sapiens_motif_fasta'

### Paths to the processed dataset

In [None]:
# Sequence length encoded in the processed file names (presumably in bp; TODO confirm).
seq_len = 768

# Dataset selection: keep exactly one of these lines uncommented.
# type_data, file_name = 'multiome_example', 'pbmc_multiome'
type_data, file_name = 'TF_to_region', 'TF_to_region_hvg'
# type_data, file_name = 'TF_to_region', 'TF_to_region_marker_genes'

data_path = f'{start_directory}data/{type_data}/processed/'

# Both processed files share the "<data_path><file_name>-<seq_len>" prefix.
file_prefix = f'{data_path}{file_name}-{seq_len}'
ad_file = f'{file_prefix}-ad.h5ad'
h5_file = f'{file_prefix}-train_val_test.h5'

### Load the data and trained model

In [None]:
# Read the full X / Y arrays into memory as float32, then close the HDF5 file immediately
# so no file handle stays open for the lifetime of the kernel.
with h5py.File(h5_file, 'r') as h5:
    X = h5['X'][:].astype('float32')
    Y = h5['Y'][:].astype('float32')

# Columns of Y index TFs; rows presumably index genomic regions (TODO confirm).
n_TFs = Y.shape[1]
# NOTE(review): `ic` is not imported explicitly above; presumably it is re-exported by
# `from scbasset.utils import *` - confirm.
ic(n_TFs, Y.shape[0])

In [None]:
# read h5ad file
ad = anndata.read_h5ad(ad_file)
ad

In [None]:
print(torch.cuda.is_available())

# Preferred device; falls back to CPU when CUDA is requested but not available.
device = "cuda"
if "cuda" in device and not torch.cuda.is_available():
    device = "cpu"
elif "cuda" in device:
    # Only touch the CUDA runtime when a CUDA device is actually selected. For
    # device = "cpu" the call is unnecessary and may raise on machines without CUDA.
    torch.cuda.set_device(0)

print(device)

In [None]:
# Model hyperparameters; these presumably have to match the trained weights loaded below.
config = Config()
config_overrides = {
    'h5_file': h5_file,
    'bottleneck_size': 32,
    'activation_fct': 'relu',
    'model_name': 'tfbanformer',
    'num_heads': 8,
    'num_transforms': 7,
    'repeat': 5,
}
# Apply the overrides in the listed order.
for attr_name, attr_value in config_overrides.items():
    setattr(config, attr_name, attr_value)

print(config)

In [None]:
# load model
dashboard_model = ModelClass(config, n_TFs=n_TFs)
dashboard_model.activate_analysis()
dashboard_model.load_weights(device=device)
# dashboard_model.get_model_summary()
model = dashboard_model.model

### Latent representations, weights and scoring a TF of interest

In [None]:
# Obtain the model's latent representation and weight matrix for inputs X / targets Y.
# Later cells align rows of `latent_representation` with the regions (obs of ad.T) and
# rows of `weights` with the TFs (obs of ad).
latent_representation, weights = get_latent_representation_and_weights(model, X, Y)
ic(latent_representation.shape, weights.shape)

In [None]:
# TF_act = 'CTCF'
# scores = motif_score(TF_act, model, motif_fasta_folder=motif_fasta_folder, n_TFs=n_TFs)
# ad.obs[TF_act + '_activity'] = scores
# print(ad)

In [None]:
# Extract the TF embedding from the model, save it to CSV and attach it to `ad` as
# obsm['projection'] (one row per observation of `ad`).
proj = get_TF_embedding(model)
print(len(proj))
assert len(proj) == ad.n_obs, f'embedding rows ({len(proj)}) != ad.n_obs ({ad.n_obs})'

# DataFrame.to_csv does not create missing parent directories.
os.makedirs('results', exist_ok=True)
pd.DataFrame(proj).to_csv('results/projection_atac.csv')
# Read back from disk so the stored values match the saved file exactly.
ad.obsm['projection'] = pd.read_csv('results/projection_atac.csv', index_col=0).values

In [None]:
# Per-variable (column) annotations of `ad`.
ad.var

In [None]:
# Per-observation (row) annotations of `ad`. The rows appear to be TFs: later cells align
# the model weights with these rows, and color ad.T by a TF identifier.
ad.obs

### Latent representation

In [None]:
# Regions view: transpose `ad` so regions become observations, then let the project helper
# compute the clustering/embedding that later cells read ('leiden' labels, UMAP).
ad_regions = prepare_leiden_representation(adata=ad.transpose())

In [None]:
# Latent view: one observation per region, features are the learned latent dimensions.
ad_latent = sc.AnnData(latent_representation)

# The latent rows must line up 1:1 with the regions before reusing their index.
assert ad_latent.n_obs == ad_regions.n_obs, (
    f'latent rows ({ad_latent.n_obs}) != regions ({ad_regions.n_obs})'
)
ad_latent.obs.index = ad_regions.obs.index
ad_latent = prepare_leiden_representation(adata=ad_latent)

# Cross-annotate both clusterings so each embedding can be colored by the other one.
# `.values` copies positionally; this assumes the helper neither reorders nor drops rows.
ad_latent.obs['leiden_original'] = ad_regions.obs['leiden'].values
ad_regions.obs['leiden_learned'] = ad_latent.obs['leiden'].values

In [None]:
# Thresholds of 0 keep every row and every column. These calls presumably exist only to add
# the per-row `n_genes` / per-column `n_cells` count annotations (confirm).
sc.pp.filter_cells(ad_latent, min_genes=0)
sc.pp.filter_genes(ad_latent, min_cells=0)

#### Plot latent representation

In [None]:
# Four stacked UMAP panels:
#   0: region clusters from the original data
#   1: one TF (a variable of ad_regions) across regions
#   2: clusters in the learned latent space
#   3: the latent space colored by the original clusters
tf_of_interest = 'EXP030880.CD4_T-cells.CTCF.MA0139.1'

fig, axs = plt.subplots(ncols=1, nrows=4, figsize=(10, 20))
sc.pl.umap(ad_regions, color='leiden', ax=axs[0], show=False)
sc.pl.umap(ad_regions, color=tf_of_interest, ax=axs[1], show=False)
sc.pl.umap(ad_latent, color='leiden', ax=axs[2], show=False)
sc.pl.umap(ad_latent, color='leiden_original', ax=axs[3], show=False)
plt.show()

#### Jaccard index

In [None]:
# Jaccard matrix computed by the project helper on the latent-space AnnData
# (presumably comparing its 'leiden' and 'leiden_original' clusters; TODO confirm).
df_jaccard_matrix_latent = compute_jaccard_matrix(ad_latent)

In [None]:
# Heatmap of the latent-space Jaccard matrix; titled so the figure stands alone when skimmed.
ax = sns.heatmap(df_jaccard_matrix_latent)
ax.set_title('Jaccard index - latent representation');

### TF Representation

In [None]:
# TF view: the observations of `ad` are the TFs. `ad` is passed directly (no copy), so any
# in-place changes made by the helper are also visible through `ad`.
ad_TF = prepare_leiden_representation(ad)
ad_TF

In [None]:
# Weights view: one observation per TF (rows of `weights`), clustered like the other views.
ad_weights = sc.AnnData(weights)

# The weight rows must line up 1:1 with the TFs before reusing their index.
assert ad_weights.n_obs == ad_TF.n_obs, (
    f'weight rows ({ad_weights.n_obs}) != TFs ({ad_TF.n_obs})'
)
ad_weights.obs.index = ad_TF.obs.index
ad_weights = prepare_leiden_representation(ad_weights)

# Cross-annotate both clusterings (positional copy; assumes no reordering by the helper).
ad_weights.obs['leiden_original'] = ad_TF.obs['leiden'].values
ad_TF.obs['leiden_learned'] = ad_weights.obs['leiden'].values
ad_weights

In [None]:
# Raw weight matrix as a DataFrame for inspection (rows follow the TF observations).
data = pd.DataFrame(data=ad_weights.X)
data

#### Plot TF representation

In [None]:
# Four stacked panels:
#   0: TF clusters from `ad`
#   1: the same UMAP colored by the clusters learned from the weights
#   2: UMAP of the weights
#   3: graph-layout drawing of the weights
# All four axes are used, so four rows must be allocated (axs[3] would not exist with 3).
fig, axs = plt.subplots(ncols=1, nrows=4, figsize=(10, 20))
sc.pl.umap(ad_TF, color='leiden', ax=axs[0], show=False)
sc.pl.umap(ad_TF, color='leiden_learned', ax=axs[1], show=False)
sc.pl.umap(ad_weights, color='leiden', ax=axs[2], show=False)

# sc.pl.draw_graph needs a layout computed by sc.tl.draw_graph. Compute it here if it is
# missing. This relies on the neighbour graph, presumably built by
# prepare_leiden_representation (TODO confirm).
if 'draw_graph' not in ad_weights.uns:
    sc.tl.draw_graph(ad_weights)
sc.pl.draw_graph(ad_weights, color='leiden', ax=axs[3], show=False)
plt.show()

#### Jaccard index of the TF representation

In [None]:
# Jaccard matrix computed by the project helper on the weights AnnData
# (presumably comparing its 'leiden' and 'leiden_original' clusters; TODO confirm).
df_jaccard_matrix_TF = compute_jaccard_matrix(ad_weights)

In [None]:
# Heatmap of the TF-representation Jaccard matrix; titled so the figure stands alone.
ax = sns.heatmap(df_jaccard_matrix_TF)
ax.set_title('Jaccard index - TF representation');