In [1]:
# Enable IPython's autoreload extension in mode 2, which re-imports modules
# before each cell execution so that edits to imported source files are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [17]:
import os
import sys
import pandas as pd
import numpy as np
import warnings
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Make the epiRNA source tree importable. The membership check keeps sys.path
# free of duplicate entries when this cell is executed more than once.
EPIRNA_SRC = "/home/wangrr/Analysis/epiRNA/src"
if EPIRNA_SRC not in sys.path:
    sys.path.append(EPIRNA_SRC)
import epiRNA as erna

# Create the benchmark working directory if needed and switch into it, so that
# relative paths used by later cells resolve inside this directory.
workdir = "/data/wangrr/Analysis/epiRNA/workdir2/benchmark"
os.makedirs(workdir, exist_ok=True)
os.chdir(workdir)

# Suppress UserWarning messages only; other warning categories are unaffected.
warnings.filterwarnings("ignore", category=UserWarning)

In [6]:
# Configuration file handed to the loader together with every checkpoint.
config = "/home/wangrr/Analysis/epiRNA/src/experiments/model_v2/configs/basic_model.yaml"

# Checkpoint files, keyed by the species named in each checkpoint directory.
ckpt_dir = "/data/wangrr/Analysis/epiRNA/models"
checkpoints = dict(
    human=f"{ckpt_dir}/v2human_only/version_0/checkpoints/epoch=13-step=76594-val_ePCC_mixed=0.7087.ckpt",
    mouse=f"{ckpt_dir}/v2mouse_only/version_0/checkpoints/epoch=08-step=52749-val_ePCC_mixed=0.7269.ckpt",
)

# Restore each checkpoint, call .eval() on it (presumably the torch-style
# inference mode switch -- confirm against erna.evals.load_model's return type),
# and register the model under its species key.
model_zoo = dict()
for key in checkpoints:
    ckpt = checkpoints[key]
    model = erna.evals.load_model(ckpt, config)
    model.eval()
    model_zoo[key] = model

In [7]:
import pickle

# Sample metadata: only the "m6A" sheet of the workbook is read.
datainfo = pd.read_excel("/home/wangrr/Analysis/epiRNA/data/mRNA_multiomic_sample_info.xlsx", sheet_name="m6A")

# Reference genome sequence (FASTA) and annotation (GTF) file paths per species.
ref_dir = "/data/wangrr/Analysis/epiRNA/Dataset/Reference"
genome = {
    "human": {
        "fasta": f"{ref_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Homo_sapiens.GRCh38.84.gtf"
    },
    "mouse": {
        # Single path separator, consistent with the other reference paths
        # (the original had a stray double slash after ref_dir).
        "fasta": f"{ref_dir}/Mus_musculus.GRCm38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Mus_musculus.GRCm38.84.gtf"
    }
}

# Parse both annotations. zero_based=True presumably requests 0-based
# coordinates -- confirm against erna.pp.GTFLoader.
human_gtf_df = erna.pp.GTFLoader(gtf_path=genome["human"]["gtf"], zero_based=True)
mouse_gtf_df = erna.pp.GTFLoader(gtf_path=genome["mouse"]["gtf"], zero_based=True)

# Chromosome-level data split per species: chromosome 8 for validation,
# chromosome 10 for testing, and the remaining listed autosomes for training.
# Sex chromosomes and the mitochondrial genome are not listed in any split.
chrom_kwargs = {
    'human': {
        'train_chroms': ["1", "2", "3", "4", "5", "6", "7", "9", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22"],
        'val_chroms': ["8"],
        'test_chroms': ["10"]
    },
    'mouse': {
        'train_chroms': ["1", "2", "3", "4", "5", "6", "7", "9", "11", "12", "13", "14", "15", "16", "17", "18", "19"],
        'val_chroms': ["8"],
        'test_chroms': ["10"]
    }
}

INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']
INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']


In [8]:
def _rows_for_samples(table, sample_ids):
    """Return a re-indexed copy of the rows of ``table`` whose "SID" is in ``sample_ids``.

    Rows keep the order they have in ``table``; IDs absent from ``table`` are
    simply not matched.
    """
    keep = table["SID"].isin(sample_ids)
    return table[keep].reset_index(drop=True).copy()


# Sample IDs (matched against the "SID" column) benchmarked for each species.
human_samples = [
    "Adipose-1", "Aorta-1", "Appendix-1", "Esophagus-1", "GOS-1",
    "Hela-1", "HT29-1", "Hypothalamus-1", "Jejunum-1", "Liver-1",
    "Prostate-1", "Skin-1", "Testis-1", "Tongue-1", "Trachea-1",
    "WPMY-1", "MT4-1", "GSC11-1", "iSLK-1", "TIME-1",
]
mouse_samples = [
    "mHeart-1", "mSpleen-1", "mLung-1", "mLiver-1", "mCerebrum-1",
    "mCerebellum-1", "mBrainstem-1", "mHypothalamus-1", "mBMDC-1", "mEF-1",
    "mNPC-1", "mESC-1", "mStriatum-1", "mPeritoneal_macrophages-1", "mB16-OVA-1",
    "miPSC-1", "mNSC-1", "mForebrain-1", "mKidney-1", "mHypothalamus-2",
]

# Species-specific slices of the sample sheet.
human_datainfo = _rows_for_samples(datainfo, human_samples)
mouse_datainfo = _rows_for_samples(datainfo, mouse_samples)

In [9]:
# Read the pickled payload and keep only its 'rbp_dict' entry.
# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only point this at a file from a trusted source.
rbp_pickle_path = "/home/wangrr/Analysis/epiRNA/data/rbp_expression_dict.pkl"
with open(rbp_pickle_path, mode="rb") as f:
    rbp_dict = pickle.load(f)['rbp_dict']

In [10]:
# Keyword arguments forwarded to the dataset class for the human samples.
human_datasetkwargs = dict(
    human_fasta_path=genome["human"]["fasta"],
    mouse_fasta_path=genome["mouse"]["fasta"],
    modality_to_index={"m6A_human": 0, "m6A_mouse": 1},
    bigwig_dir="/data/wangrr/Analysis/epiRNA/Dataset/processed/4-coverage",
    datainfo=human_datainfo,
    rbp_dict=rbp_dict,
    human_gtf_df=human_gtf_df,
    mouse_gtf_df=mouse_gtf_df,
    return_bin_input=True,
    seq_len=65536,
    mask_type=None,
)

# The mouse configuration differs only in its sample sheet. This is a shallow
# copy, so nested objects (e.g. rbp_dict) are shared between the two dicts.
mouse_datasetkwargs = {**human_datasetkwargs, 'datainfo': mouse_datainfo}


def _build_datamodule(dataset_kwargs):
    """Create a GeneDataModule over GeneDataset with the shared split and loader settings."""
    return erna.ds.GeneDataModule(
        dataset_cls=erna.ds.GeneDataset,
        dataset_kwargs=dataset_kwargs,
        chrom_kwargs=chrom_kwargs,
        batch_size=32,
        num_workers=8,
    )


# Build both datamodules first, then run their setup in the same order.
human_datamodule = _build_datamodule(human_datasetkwargs)
mouse_datamodule = _build_datamodule(mouse_datasetkwargs)

human_datamodule.setup(stage='test')
mouse_datamodule.setup(stage='test')

Total samples: 350180
Omics types: ['m6A']
Total samples: 13380
Omics types: ['m6A']
Total samples: 14580
Omics types: ['m6A']
Total samples: 375140
Omics types: ['m6A']
Total samples: 21040
Omics types: ['m6A']
Total samples: 20480
Omics types: ['m6A']


In [None]:
# Output file, relative to the current working directory.
SUMMARY_PATH = "benchmark_species_summaries.pkl"


def _save_summaries(results, path=SUMMARY_PATH):
    """Pickle ``results`` to ``path`` without ever leaving a half-written file.

    The data is written to a sibling temporary file first and then moved over
    ``path`` with os.replace, so an interruption mid-write cannot corrupt a
    previously saved result file.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "wb") as f:
        pickle.dump(results, f)
    os.replace(tmp_path, path)


# Each model is benchmarked on the datamodule of the species it is keyed by.
datamodule_by_species = {
    "human": human_datamodule,
    "mouse": mouse_datamodule,
}

# Maps (model_type, phase) -> summary returned by erna.evals.benchmark_summary.
summaries = {}
for model_type, model in model_zoo.items():
    datamodule = datamodule_by_species.get(model_type)
    if datamodule is None:
        # Report the skip instead of dropping the model silently.
        print(f"Skipping model: {model_type} (no matching datamodule)")
        continue
    # Same creation and evaluation order as before: val, train, test.
    phase_loaders = [
        ("val", datamodule.val_dataloader()),
        ("train", datamodule.train_dataloader()),
        ("test", datamodule.test_dataloader()),
    ]
    for phase, dataloader in phase_loaders:
        print(f"Evaluating Model: {model_type} on data during {phase} phase")
        summary = erna.evals.benchmark_summary(
            model=model,
            dataloader=dataloader,
            n_channels=2,
            data_device="cuda:0",
            device_ids="0,1"
        )
        summaries[(model_type, phase)] = summary
        print(summary)
        # A single phase can take tens of minutes; persist after each one so a
        # later failure or interruption does not discard finished results.
        _save_summaries(summaries)

# Final write, so the file exists even when no model was evaluated.
_save_summaries(summaries)

Evaluating Model: mouse on data during val phase


Benchmarking PCC:   0%|          | 0/658 [00:00<?, ?it/s]

Benchmarking PCC: 100%|██████████| 658/658 [02:30<00:00,  4.38it/s]


{'avg_loss': 0.050114666274483154, 'global_pcc': array([0.       , 0.7264283, 0.7264283], dtype=float32)}
Evaluating Model: mouse on data during train phase


Benchmarking PCC: 100%|██████████| 11723/11723 [44:17<00:00,  4.41it/s]


{'avg_loss': 0.02222799146243643, 'global_pcc': array([0.       , 0.8568593, 0.8568593], dtype=float32)}
Evaluating Model: mouse on data during test phase


Benchmarking PCC: 100%|██████████| 640/640 [02:25<00:00,  4.39it/s]

{'avg_loss': 0.043096711458019854, 'global_pcc': array([0.       , 0.7361933, 0.7361933], dtype=float32)}



