In [1]:
# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported before each cell runs (no kernel restart needed while the
# project package is being developed).
%load_ext autoreload
%autoreload 2

In [46]:
import os
import sys
import pickle
import numpy as np
import pandas as pd
sys.path.append("/home/wangrr/Analysis/epiRNA/src")
import epiRNA as erna

# Dedicated working directory for this analysis: create it if it does not
# exist yet, then make it the process's current directory.
workdir = "/data/wangrr/Analysis/epiRNA/workdir2/RBP_token"
os.makedirs(workdir, exist_ok=True)
os.chdir(workdir)

## <span style="font-family: Arial, sans-serif;">Step 1: Load the trained model</span>

In [3]:
# Checkpoint to analyse. Its file name records epoch 18, step 215327 and
# val_ePCC_mixed=0.7823.
checkpoints = "/data/wangrr/Analysis/epiRNA/models/basic_fold/version_0/checkpoints/epoch=18-step=215327-val_ePCC_mixed=0.7823.ckpt"

# YAML configuration handed to the loader together with the checkpoint.
config_file = "/home/wangrr/Analysis/epiRNA/src/experiments/model_v2/configs/basic_model.yaml"

# Load the model through the project's helper.
# NOTE(review): whether the returned model is already in eval mode is not
# visible here — confirm before relying on it for inference/attribution.
model = erna.evals.load_model(ckpt=checkpoints, config_yaml=config_file)

## <span style="font-family: Arial, sans-serif;">Step 2: Load validation/test datasets</span>

In [4]:
# Sample IDs (column "SID") to keep from the "m6A" sheet of the sample-info
# workbook.
m6A_train_samples = [
    "Adipose-1", "Aorta-1", "Appendix-1", "Esophagus-1", "GOS-1",
    "Hela-1", "HT29-1", "Hypothalamus-1", "Jejunum-1", "Liver-1",
    "Prostate-1", "Skin-1", "Testis-1", "Tongue-1", "Trachea-1",
    "WPMY-1", "MT4-1", "GSC11-1", "iSLK-1", "TIME-1",
    "mHeart-1", "mSpleen-1", "mLung-1", "mLiver-1", "mCerebrum-1",
    "mCerebellum-1", "mBrainstem-1", "mHypothalamus-1", "mBMDC-1", "mEF-1",
    "mNPC-1", "mESC-1", "mStriatum-1", "mPeritoneal_macrophages-1", "mB16-OVA-1",
    "miPSC-1", "mNSC-1", "mForebrain-1", "mKidney-1", "mHypothalamus-2",
]

# Read the sheet, keep only the listed samples and renumber the rows from 0.
datainfo = (
    pd.read_excel(
        "/home/wangrr/Analysis/epiRNA/data/mRNA_multiomic_sample_info.xlsx",
        sheet_name="m6A",
    )
    .loc[lambda sheet: sheet["SID"].isin(m6A_train_samples)]
    .reset_index(drop=True)
)

# Reference genome files per species (GRCh38 for human, GRCm38 for mouse),
# all located directly under ref_dir.
ref_dir = "/data/wangrr/Analysis/epiRNA/Dataset/Reference"
genome = {
    "human": {
        "fasta": f"{ref_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Homo_sapiens.GRCh38.84.gtf",
    },
    "mouse": {
        # Single separator, like the other three entries (this path used to
        # be built with a doubled "//").
        "fasta": f"{ref_dir}/Mus_musculus.GRCm38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Mus_musculus.GRCm38.84.gtf",
    },
}

# Load the human and mouse GTF annotations (in that order) through the
# project's GTF loader, with zero_based=True for both.
human_gtf_df, mouse_gtf_df = (
    erna.pp.GTFLoader(gtf_path=genome[species]["gtf"], zero_based=True)
    for species in ("human", "mouse")
)

# Chromosome-level split used for both species: chromosome 8 is the
# validation set, chromosome 10 the test set, and every other numbered
# chromosome (1-22 for human, 1-19 for mouse) is used for training.
chrom_kwargs = {
    species: {
        'train_chroms': [
            str(chrom)
            for chrom in range(1, last_chrom + 1)
            if chrom not in (8, 10)
        ],
        'val_chroms': ["8"],
        'test_chroms': ["10"],
    }
    for species, last_chrom in (('human', 22), ('mouse', 19))
}

INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']
INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']


In [71]:
# Load the pickled RBP expression bundle used by the datasets below.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file —
# only load this from a trusted location.
rbp_dict_file = "/home/wangrr/Analysis/epiRNA/data/rbp_expression_dict.pkl"
with open(rbp_dict_file, "rb") as f:
    rbp_dict = pickle.load(f)

In [6]:
# Inspect the top-level entries of the loaded bundle.
rbp_dict.keys()

dict_keys(['rbp_dict', 'human_rbp', 'mouse_rbp'])

In [7]:
# One sample table per species (rows whose "Species" is "Human" / "Mouse"),
# each renumbered from 0.
human_datainfo, mouse_datainfo = (
    datainfo.loc[datainfo["Species"].eq(species)].reset_index(drop=True)
    for species in ("Human", "Mouse")
)

In [67]:
# Dataset configuration for the human samples. The mouse configuration below
# is identical except for the sample table.
human_dataset_kwargs = {
    'human_fasta_path': genome["human"]["fasta"],
    'mouse_fasta_path': genome["mouse"]["fasta"],
    "modality_to_index": {"m6A_human": 0, "m6A_mouse": 1},
    'bigwig_dir': "/data/wangrr/Analysis/epiRNA/Dataset/processed/4-coverage",
    'datainfo': human_datainfo,
    'rbp_dict': rbp_dict['rbp_dict'],
    'human_gtf_df': human_gtf_df,
    'mouse_gtf_df': mouse_gtf_df,
    'return_bin_input': True,
    'seq_len': 65536,
    'mask_type': None,
}

# Shallow copy with only 'datainfo' swapped for the mouse subset.
mouse_dataset_kwargs = {**human_dataset_kwargs, 'datainfo': mouse_datainfo}

# Build both datamodules first (human, then mouse) with identical loader
# settings ...
human_datamodule, mouse_datamodule = (
    erna.ds.GeneDataModule(
        dataset_cls=erna.ds.GeneDataset,
        dataset_kwargs=species_kwargs,
        chrom_kwargs=chrom_kwargs,
        batch_size=32,
        num_workers=8,
    )
    for species_kwargs in (human_dataset_kwargs, mouse_dataset_kwargs)
)

# ... then run setup on each, in the same order.
for species_datamodule in (human_datamodule, mouse_datamodule):
    species_datamodule.setup(stage="test")


Total samples: 350180
Omics types: ['m6A']
Total samples: 13380
Omics types: ['m6A']
Total samples: 14580
Omics types: ['m6A']
Total samples: 375140
Omics types: ['m6A']
Total samples: 21040
Omics types: ['m6A']
Total samples: 20480
Omics types: ['m6A']


## <span style="font-family: Arial, sans-serif;">Step 3: RBP importance analysis</span>
### <span style="font-family: Arial, sans-serif;">3.1: Rank RBPs by importance</span>

In [9]:
import torch
from tqdm import tqdm

In [10]:
# Gradient x input attribution for every RBP expression feature, computed over
# the human test dataloader.
#
# Per batch, the objective is the sum of (pred - bin_input) over the positions
# where the prediction exceeds the observed binned input. Its gradient with
# respect to the RBP expression tensor, multiplied element-wise by that
# tensor, gives one attribution row per item in the batch.
#
# Outputs:
#   rbp_grad_input : (num_samples, N_rbp) CPU tensor of gradient x input rows
#   sample_list    : batch["sample"] entries, aligned with the rows
#   gene_list      : batch["gene"] entries, aligned with the rows
sample_list = []
gene_list = []
# Collect per-batch results and concatenate once after the loop; calling
# torch.cat inside the loop re-copies everything accumulated so far on every
# batch.
grad_input_chunks = []

device = "cuda:0"
model = model.to(device)
# NOTE(review): the model's train/eval mode is left exactly as it arrives in
# this cell — confirm it is in eval mode if it contains dropout/batch-norm.
for batch in tqdm(human_datamodule.test_dataloader()):
    x = batch["input"].to(device)
    rbp = batch["rbp_expr"].to(device)
    head_idx = batch["modality_idx"].to(device, dtype=torch.long)
    # Track gradients on the RBP input so backward() populates rbp.grad.
    rbp.requires_grad = True

    # Clear parameter gradients left over from the previous batch.
    model.zero_grad()

    pred, _ = model(x, rbp, head_idx)  # (B, len)

    # Observed binned input (same length as pred).
    bin_input = batch["bin_input"].to(device)  # (B, len)

    # Only positions where the model over-predicts contribute to the objective.
    mask = pred > bin_input  # (B, len)
    scaler = ((pred - bin_input) * mask).sum()
    scaler.backward()

    grads = rbp.grad.detach().cpu()  # (B, N_rbp)
    rbp = rbp.detach().cpu()
    gi = grads * rbp  # (B, N_rbp)
    grad_input_chunks.append(gi)

    sample_list.extend(batch["sample"])
    gene_list.extend(batch["gene"])

# (num_samples, N_rbp); an empty dataloader yields an empty tensor, as before.
rbp_grad_input = (
    torch.cat(grad_input_chunks, dim=0) if grad_input_chunks else torch.Tensor([])
)

100%|██████████| 456/456 [09:40<00:00,  1.27s/it]


In [11]:
# Collapse the per-row gradient x input values into one mean score per RBP
# column and pair each score with its RBP name.
# NOTE(review): assumes rbp_dict['human_rbp'] is ordered like the columns of
# batch["rbp_expr"] — TODO confirm.
attribution = torch.mean(rbp_grad_input, dim=0).numpy()  # (N_rbp,)
attribution_df = pd.DataFrame(
    dict(
        RBP=list(rbp_dict['human_rbp']),
        Mean_Attribution=attribution,
    )
)

In [12]:
# Show the per-RBP mean attribution table (pandas truncates the middle rows).
attribution_df 

Unnamed: 0,RBP,Mean_Attribution
0,LAS1L,-0.025367
1,LASP1,-0.053205
2,RBM5,-0.028496
3,ARF5,-0.013096
4,RBM6,-0.025666
...,...,...
2135,CT45A6,-0.000037
2136,C11orf98,-0.001516
2137,MRM1,-0.021904
2138,MRPL45,0.008301


In [17]:
# 1. Top positive: order RBPs from the largest to the smallest mean
# attribution and number them so that Rank 1 is the most positive RBP.
attribution_df_pos = (
    attribution_df
    .sort_values(by="Mean_Attribution", ascending=False)
    .assign(Rank=lambda ranked: np.arange(1, len(ranked) + 1))
)
attribution_df_pos.head(100)

Unnamed: 0,RBP,Mean_Attribution,Rank
1933,FLNA,0.104455,1
876,RNASE1,0.102880,2
602,NNT,0.060817,3
953,RBM38,0.060795,4
586,MRPL18,0.059545,5
...,...,...,...
2030,ARL2,0.030790,96
729,TMPO,0.030785,97
641,MOGS,0.030724,98
887,MRPL34,0.030721,99


In [16]:

# 2. Top negative: order RBPs from the smallest to the largest mean
# attribution and number them so that Rank 1 is the most negative RBP.
attribution_df_neg = (
    attribution_df
    .sort_values(by="Mean_Attribution", ascending=True)
    .assign(Rank=lambda ranked: np.arange(1, len(ranked) + 1))
)
attribution_df_neg.head(300)

Unnamed: 0,RBP,Mean_Attribution,Rank
1429,ATP1A1,-0.095269,1
1871,MKL2,-0.090907,2
1720,RAB6A,-0.082059,3
178,RAB7A,-0.080789,4
1443,TIPARP,-0.079042,5
...,...,...,...
170,IGF2BP2,-0.028264,296
1532,PATL1,-0.028238,297
1951,MYO5A,-0.028194,298
984,PUM1,-0.028099,299


### <span style="font-family: Arial, sans-serif;">3.2: In silico knockout</span>

In [76]:
import warnings

# In silico knockout of the listed m6A writer gene(s): set their expression
# entry to zero in every RBP vector, then build a human datamodule on top of
# the perturbed vectors.
ko_gene_list = ["METTL3"]

rbp_list = np.array(rbp_dict['human_rbp'])

# A gene that is absent from the RBP list cannot be zeroed. Say so explicitly:
# skipping it silently can leave the "knockout" data identical to the control.
missing_genes = [gene for gene in ko_gene_list if gene not in rbp_list]
if missing_genes:
    warnings.warn(
        f"Knockout genes not found in rbp_dict['human_rbp'] (skipped): {missing_genes}"
    )

# Position of each knockout gene in the RBP list (first match).
gene_pos = [np.where(rbp_list == gene)[0][0] for gene in ko_gene_list if gene in rbp_list]

# Work on clones so the vectors in rbp_dict['rbp_dict'] (used by the control
# datamodule) are not modified.
# NOTE(review): positions come from 'human_rbp' but are applied to every entry
# of rbp_dict['rbp_dict'] — confirm any non-human entries share that ordering
# or are never read by the human datamodule.
ko_rbp_dict = {k: v.clone() for k, v in rbp_dict['rbp_dict'].items()}
for v in ko_rbp_dict.values():
    v[gene_pos] = 0.0

# construct datamodule: same settings as the control, only 'rbp_dict' differs
human_dataset_kwargs_ko = human_dataset_kwargs.copy()
human_dataset_kwargs_ko['rbp_dict'] = ko_rbp_dict

human_datamodule_ko = erna.ds.GeneDataModule(
    dataset_cls=erna.ds.GeneDataset,
    dataset_kwargs=human_dataset_kwargs_ko,
    chrom_kwargs=chrom_kwargs,
    batch_size=32,
    num_workers=8
)
human_datamodule_ko.setup(stage="test")

Total samples: 350180
Omics types: ['m6A']
Total samples: 13380
Omics types: ['m6A']
Total samples: 14580
Omics types: ['m6A']


In [84]:
import warnings

# In silico knockout of FTO and ALKBH5: set their expression entry to zero in
# every RBP vector, then build a human datamodule on top of the perturbed
# vectors.
#
# NOTE(review): this cell rebinds ko_gene_list, ko_rbp_dict,
# human_dataset_kwargs_ko and human_datamodule_ko, which the METTL3 knockout
# cell also assigns. Any cell that reads human_datamodule_ko therefore sees
# whichever knockout cell ran last; on a top-to-bottom run that is this one.
# Consider distinct names per knockout.
ko_gene_list = ["FTO", "ALKBH5"]

rbp_list = np.array(rbp_dict['human_rbp'])

# A gene that is absent from the RBP list cannot be zeroed. Say so explicitly:
# skipping it silently can leave the "knockout" data identical to the control.
missing_genes = [gene for gene in ko_gene_list if gene not in rbp_list]
if missing_genes:
    warnings.warn(
        f"Knockout genes not found in rbp_dict['human_rbp'] (skipped): {missing_genes}"
    )

# Position of each knockout gene in the RBP list (first match).
gene_pos = [np.where(rbp_list == gene)[0][0] for gene in ko_gene_list if gene in rbp_list]

# Work on clones so the vectors in rbp_dict['rbp_dict'] (used by the control
# datamodule) are not modified.
# NOTE(review): positions come from 'human_rbp' but are applied to every entry
# of rbp_dict['rbp_dict'] — confirm any non-human entries share that ordering
# or are never read by the human datamodule.
ko_rbp_dict = {k: v.clone() for k, v in rbp_dict['rbp_dict'].items()}
for v in ko_rbp_dict.values():
    v[gene_pos] = 0.0

# construct datamodule: same settings as the control, only 'rbp_dict' differs
human_dataset_kwargs_ko = human_dataset_kwargs.copy()
human_dataset_kwargs_ko['rbp_dict'] = ko_rbp_dict

human_datamodule_ko = erna.ds.GeneDataModule(
    dataset_cls=erna.ds.GeneDataset,
    dataset_kwargs=human_dataset_kwargs_ko,
    chrom_kwargs=chrom_kwargs,
    batch_size=32,
    num_workers=8
)
human_datamodule_ko.setup(stage="test")

Total samples: 350180
Omics types: ['m6A']
Total samples: 13380
Omics types: ['m6A']
Total samples: 14580
Omics types: ['m6A']


In [77]:
# Control run: benchmark the model on the human validation dataloader built
# from the unperturbed RBP vectors. Note that this is the validation split,
# whereas the attribution loop above iterates the test dataloader.
# NOTE(review): presumably count_sum=True is what populates the 'pred_sum'
# entry read later — TODO confirm against erna.evals.benchmark_summary.
ctrl_results = erna.evals.benchmark_summary(
    model=model,
    dataloader=human_datamodule.val_dataloader(),
    n_channels=2,
    loss_fn=torch.nn.MSELoss(),
    count_pcc=False,
    count_gene=False,
    count_sum=True,
    device_ids="0,1"
)

Benchmarking PCC:   0%|          | 0/419 [00:00<?, ?it/s]

Benchmarking PCC: 100%|██████████| 419/419 [01:36<00:00,  4.34it/s]


In [78]:
# Knockout run: same benchmark settings as the control, but on the validation
# dataloader of human_datamodule_ko.
# NOTE(review): human_datamodule_ko is assigned by both knockout cells above
# (METTL3, then FTO/ALKBH5). On a top-to-bottom run this call sees the
# FTO/ALKBH5 knockout; it evaluates METTL3 only when executed right after the
# METTL3 cell — confirm which one is intended here.
ko_results = erna.evals.benchmark_summary(
    model=model,
    dataloader=human_datamodule_ko.val_dataloader(),
    n_channels=2,
    loss_fn=torch.nn.MSELoss(),
    count_pcc=False,
    count_gene=False,
    count_sum=True,
    device_ids="0,1"
)

Benchmarking PCC:   3%|▎         | 11/419 [00:04<01:37,  4.18it/s]

Benchmarking PCC: 100%|██████████| 419/419 [01:36<00:00,  4.33it/s]


In [85]:
# Second knockout run with the same benchmark settings.
# NOTE(review): this reads the same human_datamodule_ko variable as the
# ko_results cell, so the knockout evaluated here is whichever knockout cell
# assigned that variable most recently.
ko_results2 = erna.evals.benchmark_summary(
    model=model,
    dataloader=human_datamodule_ko.val_dataloader(),
    n_channels=2,
    loss_fn=torch.nn.MSELoss(),
    count_pcc=False,
    count_gene=False,
    count_sum=True,
    device_ids="0,1"
)

Benchmarking PCC:   0%|          | 0/419 [00:00<?, ?it/s]

Benchmarking PCC: 100%|██████████| 419/419 [01:35<00:00,  4.37it/s]


In [79]:
# Mean of the 'pred_sum' values from the first knockout benchmark (ko_results).
np.array(ko_results['pred_sum']).mean()

48.160819951942685

In [80]:
# Mean of the 'pred_sum' values from the control benchmark (ctrl_results),
# for comparison with the knockout means.
np.array(ctrl_results['pred_sum']).mean()


57.74833131584886

In [86]:
# Mean of the 'pred_sum' values from the second knockout benchmark (ko_results2).
np.array(ko_results2['pred_sum']).mean()

48.12989700635274