In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
# --- standard library ---
import os
import sys
import warnings

# --- third-party ---
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# --- project package: imported from its source tree, which is put on sys.path ---
sys.path.append("/home/wangrr/Analysis/epiRNA/src")
import epiRNA as erna

# Create the benchmark working directory if needed and make it the current
# directory, so relative paths used later resolve inside it.
workdir = "/data/wangrr/Analysis/epiRNA/workdir2/benchmark"
os.makedirs(workdir, exist_ok=True)
os.chdir(workdir)

# Suppress UserWarning messages for the rest of the session.
warnings.filterwarnings("ignore", category=UserWarning)



## 1. Load models

In [5]:
# YAML config passed to the loader together with every checkpoint below.
config = "/home/wangrr/Analysis/epiRNA/src/experiments/model_v2/configs/basic_model.yaml"

# Root directory of the trained model runs.
ckpt_dir = "/data/wangrr/Analysis/epiRNA/models"

# Variant name -> checkpoint file. The variant names are reused later to pair
# each model with its datamodule.
checkpoints = {
    "complete": f"{ckpt_dir}/basic_fold/version_0/checkpoints/epoch=18-step=215327-val_ePCC_mixed=0.7823.ckpt",
    "maskrna": f"{ckpt_dir}/v2maskrna/version_0/checkpoints/epoch=11-step=135996-val_ePCC_mixed=0.4895.ckpt",
    "maskseq": f"{ckpt_dir}/v2maskseq/version_0/checkpoints/epoch=02-step=33999-val_ePCC_mixed=0.4829.ckpt",
    "maskrbp": f"{ckpt_dir}/v2maskrbp/version_0/checkpoints/epoch=10-step=124663-val_ePCC_mixed=0.6686.ckpt",
}

# Load every variant and switch it to evaluation mode before storing it.
model_zoo = {}
for variant, ckpt_path in checkpoints.items():
    loaded = erna.evals.load_model(ckpt_path, config)
    loaded.eval()
    model_zoo[variant] = loaded

## 2. Load samples

In [22]:
import pickle

# Sample sheet: the "m6A" worksheet of the multi-omic sample info workbook.
datainfo = pd.read_excel("/home/wangrr/Analysis/epiRNA/data/mRNA_multiomic_sample_info.xlsx", sheet_name="m6A")

# Samples to benchmark on. The commented list is the larger human + mouse
# selection (presumably the full training set -- kept for reference); the
# active list restricts the benchmark to a single sample.
# m6A_train_samples = ["Adipose-1", "Aorta-1", "Appendix-1", "Esophagus-1", "GOS-1", "Hela-1", "HT29-1", "Hypothalamus-1", "Jejunum-1", "Liver-1", "Prostate-1", "Skin-1", "Testis-1", "Tongue-1", "Trachea-1", "WPMY-1", "MT4-1", "GSC11-1", "iSLK-1", "TIME-1", "mHeart-1", "mSpleen-1", "mLung-1", "mLiver-1", "mCerebrum-1", "mCerebellum-1", "mBrainstem-1", "mHypothalamus-1", "mBMDC-1", "mEF-1", "mNPC-1", "mESC-1", "mStriatum-1", "mPeritoneal_macrophages-1", "mB16-OVA-1", "miPSC-1", "mNSC-1", "mForebrain-1", "mKidney-1", "mHypothalamus-2"]
m6A_train_samples = ["Adipose-1"]

# Keep only the rows whose sample ID ("SID") was requested.
datainfo = datainfo[datainfo["SID"].isin(m6A_train_samples)].reset_index(drop=True)

# Guard against a silent mismatch between the requested IDs and the sheet:
# an empty table would otherwise only surface much later as empty datasets.
missing_sids = sorted(set(m6A_train_samples) - set(datainfo["SID"]))
if datainfo.empty:
    raise ValueError(
        f"None of the requested samples {m6A_train_samples} were found in the 'SID' column."
    )
if missing_sids:
    print(f"Requested samples missing from the sample sheet: {missing_sids}")

# Reference genome FASTA and GTF annotation files per species.
ref_dir = "/data/wangrr/Analysis/epiRNA/Dataset/Reference"
genome = {
    "human": {
        "fasta": f"{ref_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Homo_sapiens.GRCh38.84.gtf"
    },
    "mouse": {
        # Fixed: the path previously contained a doubled "/" after ref_dir.
        "fasta": f"{ref_dir}/Mus_musculus.GRCm38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Mus_musculus.GRCm38.84.gtf"
    }
}

# Load the gene annotations for both species through the project's GTF loader.
human_gtf_df = erna.pp.GTFLoader(gtf_path=genome["human"]["gtf"], zero_based=True)
mouse_gtf_df = erna.pp.GTFLoader(gtf_path=genome["mouse"]["gtf"], zero_based=True)

# Chromosome-level split: in both species chromosome 8 is held out for
# validation and chromosome 10 for testing; the listed autosomes are training.
chrom_kwargs = {
    'human': {
        'train_chroms': ["1", "2", "3", "4", "5", "6", "7", "9", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20", "21", "22"],
        'val_chroms': ["8"],
        'test_chroms': ["10"]
    },
    'mouse': {
        'train_chroms': ["1", "2", "3", "4", "5", "6", "7", "9", "11", "12", "13", "14", "15", "16", "17", "18", "19"],
        'val_chroms': ["8"],
        'test_chroms': ["10"]
    }
}

INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']
INFO:root:Extracted GTF attributes: ['gene_id', 'gene_version', 'gene_name', 'gene_source', 'gene_biotype', 'havana_gene', 'havana_gene_version', 'transcript_id', 'transcript_version', 'transcript_name', 'transcript_source', 'transcript_biotype', 'havana_transcript', 'havana_transcript_version', 'tag', 'transcript_support_level', 'exon_number', 'exon_id', 'exon_version', 'ccds_id', 'protein_id', 'protein_version']


In [23]:
def _read_rbp_dict(pkl_path):
    """Return the ``'rbp_dict'`` entry of the pickled object stored at ``pkl_path``."""
    # NOTE(review): pickle.load can execute arbitrary code -- only point this
    # at files from a trusted source.
    with open(pkl_path, "rb") as handle:
        return pickle.load(handle)['rbp_dict']


# RBP expression dictionary used by the unmasked dataset ...
rbp_dict = _read_rbp_dict("/home/wangrr/Analysis/epiRNA/data/rbp_expression_dict.pkl")

# ... and the "_zero" counterpart loaded from the companion file.
rbp_dict_zero = _read_rbp_dict("/home/wangrr/Analysis/epiRNA/data/rbp_expression_dict_zero.pkl")

In [24]:
# Baseline dataset arguments: no masking, real RBP expression dictionary.
fulldataset_kwargs = {
    "human_fasta_path": genome["human"]["fasta"],
    "mouse_fasta_path": genome["mouse"]["fasta"],
    "modality_to_index": {"m6A_human": 0, "m6A_mouse": 1},
    "bigwig_dir": "/data/wangrr/Analysis/epiRNA/Dataset/processed/4-coverage",
    "datainfo": datainfo,
    "rbp_dict": rbp_dict,
    "human_gtf_df": human_gtf_df,
    "mouse_gtf_df": mouse_gtf_df,
    "return_bin_input": True,
    "seq_len": 65536,
    "mask_type": None,
}

# Each ablation is a shallow copy of the baseline with exactly one entry replaced.
maskrnadataset_kwargs = {**fulldataset_kwargs, "mask_type": "expression"}
maskseqdataset_kwargs = {**fulldataset_kwargs, "mask_type": "sequence"}
maskrbpdataset_kwargs = {**fulldataset_kwargs, "rbp_dict": rbp_dict_zero}


def _make_datamodule(dataset_kwargs):
    """Build a GeneDataModule over GeneDataset with the shared split and loader settings."""
    return erna.ds.GeneDataModule(
        dataset_cls=erna.ds.GeneDataset,
        dataset_kwargs=dataset_kwargs,
        chrom_kwargs=chrom_kwargs,
        batch_size=32,
        num_workers=8,
    )


# The four datamodules differ only in the dataset arguments they receive.
fulldatamodule = _make_datamodule(fulldataset_kwargs)
maskrnadatamodule = _make_datamodule(maskrnadataset_kwargs)
maskseqdatamodule = _make_datamodule(maskseqdataset_kwargs)
maskrbpdatamodule = _make_datamodule(maskrbpdataset_kwargs)

In [25]:
# The benchmarking cell requests train/val/test dataloaders from all four
# datamodules, so each of them must be set up here -- previously only
# `fulldatamodule` was, leaving the masked variants unprepared.
# NOTE(review): setup(stage='test') appears to build all three splits (it
# prints three "Total samples" lines) -- confirm against GeneDataModule.
for _datamodule in (fulldatamodule, maskrnadatamodule, maskseqdatamodule, maskrbpdatamodule):
    _datamodule.setup(stage='test')

Total samples: 17509
Omics types: ['m6A']
Total samples: 669
Omics types: ['m6A']
Total samples: 729
Omics types: ['m6A']


## 3. Start benchmarking

In [None]:
# Results keyed by (model_type, phase); filled incrementally so partial
# results survive an interrupted run.
summaries = {}

# Pairing of each model variant with the datamodule it is evaluated on.
datamodule_by_model = {
    "complete": fulldatamodule,
    "maskrna": maskrnadatamodule,
    "maskseq": maskseqdatamodule,
    "maskrbp": maskrbpdatamodule,
}
phase_names = ["train", "val", "test"]

for model_type, model in model_zoo.items():
    # Refuse to benchmark a model that has no datamodule paired with it.
    if model_type not in datamodule_by_model:
        raise ValueError("Unknown model type.")
    datamodule = datamodule_by_model[model_type]

    # Build all three loaders up front, in the same order as `phase_names`.
    dataloaders = [
        datamodule.train_dataloader(),
        datamodule.val_dataloader(),
        datamodule.test_dataloader(),
    ]
    for phase, dataloader in zip(phase_names, dataloaders):
        print(f"Evaluating Model: {model_type} on data during {phase} phase")
        summaries[(model_type, phase)] = erna.evals.benchmark_summary(
            model=model,
            dataloader=dataloader,
            n_channels=2,
            data_device="cuda:2",
            device_ids="2,3",
        )

Total samples: 17509
Omics types: ['m6A']
Total samples: 669
Omics types: ['m6A']
Total samples: 729
Omics types: ['m6A']
Evaluating Model: complete on Data: complete during train phase


Benchmarking PCC:   0%|          | 0/547 [00:00<?, ?it/s]

Benchmarking PCC: 100%|██████████| 547/547 [02:06<00:00,  4.34it/s]


Evaluating Model: complete on Data: complete during val phase


Benchmarking PCC: 100%|██████████| 21/21 [00:07<00:00,  2.86it/s]


Evaluating Model: complete on Data: complete during test phase


Benchmarking PCC:  13%|█▎        | 3/23 [00:02<00:13,  1.47it/s]Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9f5c8b8a40>
Traceback (most recent call last):
  File "/nvme/biosoft/miniforge3/envs/wrr_m6a_env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1604, in __del__
    self._shutdown_workers()
  File "/nvme/biosoft/miniforge3/envs/wrr_m6a_env/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 1587, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/nvme/biosoft/miniforge3/envs/wrr_m6a_env/lib/python3.11/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
AssertionError: can only test a child process
Exception ignored in: Exception ignored in: Exception ignored in: Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7f9f5c8b8a40><function _MultiProcessingDataLoaderIter.__del__ at 0x7f9f5c8b8a40><fu

Total samples: 17509
Omics types: ['m6A']
Total samples: 669
Omics types: ['m6A']
Total samples: 729
Omics types: ['m6A']
Evaluating Model: complete on Data: maskrna during train phase


Benchmarking PCC: 100%|██████████| 547/547 [02:08<00:00,  4.26it/s]


Evaluating Model: complete on Data: maskrna during val phase


Benchmarking PCC: 100%|██████████| 21/21 [00:07<00:00,  2.88it/s]


Evaluating Model: complete on Data: maskrna during test phase


Benchmarking PCC: 100%|██████████| 23/23 [00:07<00:00,  2.94it/s]


Total samples: 17509
Omics types: ['m6A']
Total samples: 669
Omics types: ['m6A']
Total samples: 729
Omics types: ['m6A']
Evaluating Model: complete on Data: maskseq during train phase


Benchmarking PCC:  69%|██████▉   | 378/547 [01:29<00:39,  4.24it/s]


KeyboardInterrupt: 

In [27]:
# Display the collected benchmark results. Keys are the (model_type, phase)
# tuples written by the benchmarking loop; values are whatever
# erna.evals.benchmark_summary returned for that pair.
summaries

{('complete', 'complete', 'train'): {'avg_loss': 0.014126870758198134,
  'global_pcc': array([0.92665637, 0.        , 0.92665637], dtype=float32)},
 ('complete', 'complete', 'val'): {'avg_loss': 0.031025115417258088,
  'global_pcc': array([0.765631, 0.      , 0.765631], dtype=float32)},
 ('complete', 'complete', 'test'): {'avg_loss': 0.03055985040441462,
  'global_pcc': array([0.78595364, 0.        , 0.78595364], dtype=float32)},
 ('complete', 'maskrna', 'train'): {'avg_loss': 0.0714938507723525,
  'global_pcc': array([0.67057407, 0.        , 0.67057407], dtype=float32)},
 ('complete', 'maskrna', 'val'): {'avg_loss': 0.06871554634139855,
  'global_pcc': array([0.3864124, 0.       , 0.3864124], dtype=float32)},
 ('complete', 'maskrna', 'test'): {'avg_loss': 0.07447547057262828,
  'global_pcc': array([0.39170244, 0.        , 0.39170244], dtype=float32)}}