In [1]:
# Enable IPython's autoreload extension in mode 2, so edited modules (e.g. the
# local epiRNA sources imported below) are re-imported before each cell executes.
%load_ext autoreload
%autoreload 2

In [20]:
import os
import sys
import pickle
import numpy as np
import torch
import pandas as pd
sys.path.append("/home/wangrr/Analysis/epiRNA/src")
import epiRNA as erna

# Output directory for this analysis. It is created on first use and made the
# current working directory, so relative paths written later in the notebook
# (such as the saved .npz file) end up here.
workdir = "/data/wangrr/Analysis/epiRNA/workdir2/attribution"
os.makedirs(workdir, exist_ok=True)
os.chdir(workdir)

In [3]:
# YAML configuration passed to the loader together with the checkpoint file.
config_file = "/home/wangrr/Analysis/epiRNA/src/experiments/model_v2/configs/basic_model.yaml"
checkpoints = "/data/wangrr/Analysis/epiRNA/models/basic_fold/version_0/checkpoints/epoch=18-step=215327-val_ePCC_mixed=0.7823.ckpt"

# Restore the model through the project helper.
model = erna.evals.load_model(
    ckpt=checkpoints,
    config_yaml=config_file,
)

In [5]:
# Sample IDs (values of the "SID" column) to attribute.
samples = ["Cerebrum-1", "Cerebellum-1"]

# Read the "m6A" sheet of the sample table, keep only the rows whose SID is one
# of the selected samples, and renumber the index from zero.
datainfo = pd.read_excel(
    "/home/wangrr/Analysis/epiRNA/data/mRNA_multiomic_sample_info.xlsx",
    sheet_name="m6A",
)
keep_row = datainfo["SID"].isin(samples)
datainfo = datainfo.loc[keep_row].reset_index(drop=True)

# Directory holding the reference genome files.
ref_dir = "/data/wangrr/Analysis/epiRNA/Dataset/Reference"

# Per-species reference files: genome sequence (FASTA) and gene annotation (GTF).
# Every path is built as "<ref_dir>/<file name>" with a single separator; the
# mouse FASTA entry previously contained an accidental double slash.
genome = {
    "human": {
        "fasta": f"{ref_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Homo_sapiens.GRCh38.84.gtf",
    },
    "mouse": {
        "fasta": f"{ref_dir}/Mus_musculus.GRCm38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Mus_musculus.GRCm38.84.gtf",
    },
}

# Build one GTF loader per species (human first, then mouse) from the
# annotation paths in `genome`; both are created with zero_based=True.
human_gtf_df, mouse_gtf_df = (
    erna.pp.GTFLoader(gtf_path=genome[species]["gtf"], zero_based=True)
    for species in ("human", "mouse")
)

# Human chromosomes "1" .. "22" as strings. All of them are assigned to the
# validation split; every other split (and every mouse split) is left empty.
human_autosomes = list(map(str, range(1, 23)))

chrom_kwargs = {
    "human": {
        "train_chroms": [],
        "val_chroms": human_autosomes,
        "test_chroms": [],
    },
    "mouse": {
        "train_chroms": [],
        "val_chroms": [],
        "test_chroms": [],
    },
}

INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']
INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']


In [7]:
# Pickled object whose "rbp_dict" entry is handed to the dataset later on.
# NOTE(review): pickle.load executes arbitrary code from the file; only use
# pickles produced by this project.
rbp_dict_file = "/home/wangrr/Analysis/epiRNA/data/rbp_expression_dict.pkl"
with open(rbp_dict_file, "rb") as pkl_handle:
    rbp_dict = pickle.load(pkl_handle)
    
# Keyword arguments forwarded to the dataset class by the data module.
dataset_kwargs = {
    # Reference sequences.
    "human_fasta_path": genome["human"]["fasta"],
    "mouse_fasta_path": genome["mouse"]["fasta"],
    # Gene annotation loaders.
    "human_gtf_df": human_gtf_df,
    "mouse_gtf_df": mouse_gtf_df,
    # Sample table, coverage tracks and RBP expression.
    "datainfo": datainfo,
    "bigwig_dir": "/data/wangrr/Analysis/epiRNA/Dataset/processed/4-coverage",
    "rbp_dict": rbp_dict["rbp_dict"],
    # Modality name -> integer index.
    "modality_to_index": {"m6A_human": 0, "m6A_mouse": 1},
    # Input construction options.
    "return_bin_input": True,
    "seq_len": 65536,
    "mask_type": None,
}

# Arguments for the data module: which dataset class to instantiate, how to
# configure it, how chromosomes are split, and the loader settings.
datamodule_args = {
    "dataset_cls": erna.ds.GeneDataset,
    "dataset_kwargs": dataset_kwargs,
    "chrom_kwargs": chrom_kwargs,
    "batch_size": 32,
    "num_workers": 8,
}
datamodule = erna.ds.GeneDataModule(**datamodule_args)

# Build the datasets; the captured output reports the sample count of each
# dataset that gets created.
datamodule.setup(stage="test")

Total samples: 0
Omics types: ['m6A']
Total samples: 37814
Omics types: ['m6A']
Total samples: 0
Omics types: ['m6A']


In [14]:
# The validation split is used because `chrom_kwargs` places every selected
# human chromosome in `val_chroms`; the other splits are empty.
datasets, dataloader = datamodule.val_dataset, datamodule.val_dataloader()

In [22]:
from tqdm import tqdm

# Single place to change the device used for the model and every batch tensor.
device = "cuda:0"

# Input-gradient attribution of the model output, reduced with the "sum"
# manipulation, with respect to the input tensor of each batch.
# NOTE(review): this relies on `load_model` returning the model in the intended
# (eval) mode; confirm, since train-mode layers would add noise to the gradients.
igattributor = erna.ex.attribution.InputGradientAttrbution(model=model.to(device=device))
target_manipulator = erna.ex.attribution.Manipulation(method="sum")

# One attribution tensor per batch, stored on the CPU.
ig_score = []
for batch in tqdm(dataloader):
    x = batch["input"].to(dtype=torch.float32, device=device)
    rbp_expr = batch['rbp_expr'].to(dtype=torch.float32, device=device)
    modality_idx = batch["modality_idx"].to(dtype=torch.int64, device=device)
    # The batch target is not needed for attribution, so it is no longer copied
    # to the device.
    ig_scores_batch = igattributor.attribute(
        inputs=x,
        target_manipulation=target_manipulator,
        additional_forward_args=(rbp_expr, modality_idx),
    )
    # Detach and move each result off the device straight away. Otherwise the
    # attribution tensors of every batch (and any autograd history attached to
    # them) stay resident on the device until the loop has finished.
    ig_score.append(ig_scores_batch.detach().cpu())
    

  0%|          | 0/1182 [00:00<?, ?it/s]

100%|██████████| 1182/1182 [30:43<00:00,  1.56s/it]


In [23]:
# Number of leading input channels that hold the nucleotide (A, C, G, U) tracks;
# the channels after them are treated as expression tracks.
n_seq_channels = 4

# Concatenate all batches along the first (sample) axis. After the first run
# `ig_score` is already a tensor and torch.cat would reject it, so the guard
# keeps this cell re-runnable without repeating the attribution loop.
if not torch.is_tensor(ig_score):
    ig_score = torch.cat(ig_score, dim=0)

# Sequence attribution scores only. `.detach()` guarantees the NumPy conversion
# works even if the tensor still tracks gradients.
seq_ig_score = ig_score[:, :n_seq_channels, :].detach().cpu().numpy()

# Expression attribution scores only.
expr_ig_score = ig_score[:, n_seq_channels:, :].detach().cpu().numpy()

# Sum of the sequence attribution scores over the ACGU channel axis.
sum_seq_ig_score = np.sum(seq_ig_score, axis=1)  # shape: (num_samples, seq_len)

# Save all three arrays into one compressed archive in the working directory.
np.savez_compressed(
    "cns_m6A_ig_attribution.npz",
    seq_ig_score=seq_ig_score,
    expr_ig_score=expr_ig_score,
    sum_seq_ig_score=sum_seq_ig_score
)