In [None]:
import os
import sys
import pandas as pd
import warnings


# Make the in-repo epiRNA sources importable from this notebook.
sys.path.append("/disk1/gaochch_group/wangrr/Analysis/epiRNA/src")

# Create the benchmark working directory if needed and switch into it, so
# relative paths used later in the notebook resolve against it.
workdir = "/disk1/gaochch_group/wangrr/Analysis/epiRNA/workdir/benchmark"
os.makedirs(workdir, exist_ok=True)
os.chdir(workdir)

# Suppress every warning of category UserWarning for the rest of the session.
warnings.filterwarnings(action="ignore", category=UserWarning)

In [None]:
import epiRNA as erna
# Checkpoint to evaluate (relative path) and the YAML config passed alongside
# it to the project's model loader.
ckpt = (
    "../train_test/basic_fold/version_0/checkpoints/"
    "epoch=18-step=215327-val_ePCC_mixed=0.7823.ckpt"
)
yaml = "/home/wangrr/Analysis/epiRNA/src/experiments/model_v2/configs/basic_model.yaml"
model = erna.evals.load_model(ckpt, yaml)

# Sample IDs used for training.
# NOTE(review): entries prefixed with "m" look like mouse samples despite the
# variable name -- confirm before relying on the name.
train_human_samples = [
    "Adipose-1",
    "Aorta-1",
    "Appendix-1",
    "Esophagus-1",
    "GOS-1",
    "Hela-1",
    "HT29-1",
    "Hypothalamus-1",
    "Jejunum-1",
    "Liver-1",
    "Prostate-1",
    "Skin-1",
    "Testis-1",
    "Tongue-1",
    "Trachea-1",
    "WPMY-1",
    "MT4-1",
    "GSC11-1",
    "iSLK-1",
    "TIME-1",
    "mHeart-1",
    "mSpleen-1",
    "mLung-1",
    "mLiver-1",
    "mCerebrum-1",
    "mCerebellum-1",
    "mBrainstem-1",
    "mHypothalamus-1",
    "mBMDC-1",
    "mEF-1",
    "mNPC-1",
    "mESC-1",
    "mStriatum-1",
    "mPeritoneal_macrophages-1",
    "mB16-OVA-1",
    "miPSC-1",
    "mNSC-1",
    "mForebrain-1",
    "mKidney-1",
    "mHypothalamus-2",
]

# Held-out sample IDs evaluated by this benchmark.
test_human_samples = [
    "293T-1",
    "Cerebellum-1",
    "Cerebrum-1",
    "Heart-1",
    "Jurkat-1",
    "K562-1",
    "Lung-1",
    "Rectum-1",
    "Spleen-1",
    "Thyroid_gland-1",
    "MONO-MAC-6-1",
    "MSC-1",
    "HEK293A-TOA-1",
]

# Report any sample ID that appears in both lists (an empty set means the
# two lists are disjoint). This only prints; it does not stop execution.
train_set = set(train_human_samples)
test_set = set(test_human_samples)
intersect = train_set & test_set
print(f"Intersected samples between train and test: {intersect}")

# Sample metadata table, read from the "m6A" sheet of the sample workbook.
datainfo = pd.read_excel(
    "/home/wangrr/Analysis/epiRNA/data/mRNA_multiomic_sample_info.xlsx",
    sheet_name="m6A",
)

# Reference genome resources, keyed by species and then by file kind
# ("fasta" for the genome sequence, "gtf" for the annotation).
ref_dir = "/data/wangrr/Analysis/epiRNA/Dataset/Reference"
genome = {
    "human": {
        "fasta": f"{ref_dir}/Homo_sapiens.GRCh38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Homo_sapiens.GRCh38.84.gtf",
    },
    "mouse": {
        # Exactly one path separator after ref_dir, consistent with the
        # other three entries (avoids a stray "//" in the path string).
        "fasta": f"{ref_dir}/Mus_musculus.GRCm38.dna.primary_assembly.84.fa",
        "gtf": f"{ref_dir}/Mus_musculus.GRCm38.84.gtf",
    },
}

# Load the annotation for each species through the project's GTF loader,
# human first and then mouse, both with zero_based=True.
human_gtf_df, mouse_gtf_df = (
    erna.pp.GTFLoader(gtf_path=genome[species]["gtf"], zero_based=True)
    for species in ("human", "mouse")
)

# Chromosome split per species: the train and test lists are empty and every
# numbered chromosome (human 1-22, mouse 1-19) goes into the validation list.
# Original author's note on the validation lists: "No shuffle".
chrom_kwargs = {
    species: {
        "train_chroms": [],
        "val_chroms": [str(number) for number in range(1, last_numbered + 1)],
        "test_chroms": [],
    }
    for species, last_numbered in (("human", 22), ("mouse", 19))
}
import pickle
# Read the pickled payload and keep its 'rbp_dict' entry.
# NOTE: pickle.load executes arbitrary code from the file; only use this with
# a trusted, locally produced pickle.
with open(
    "/home/wangrr/Analysis/epiRNA/data/rbp_expression_dict.pkl", mode="rb"
) as pkl_file:
    rbp_dict = pickle.load(pkl_file)["rbp_dict"]

# Keep only the metadata rows whose SID is one of the test samples, with a
# fresh 0..n-1 index.
datainfo = (
    datainfo.loc[datainfo["SID"].isin(test_human_samples)]
    .reset_index(drop=True)
    .copy()
)

# Keyword arguments shared by every dataset built in this notebook. The
# 'datainfo' entry holds the test-sample metadata table here and is replaced
# with a single-sample table when per-sample datasets are created.
datasetkwargs = dict(
    human_fasta_path=genome["human"]["fasta"],
    mouse_fasta_path=genome["mouse"]["fasta"],
    modality_to_index={"m6A_human": 0, "m6A_mouse": 1},
    bigwig_dir="/data/wangrr/Analysis/epiRNA/Dataset/processed/4-coverage",
    datainfo=datainfo,
    rbp_dict=rbp_dict,
    human_gtf_df=human_gtf_df,
    mouse_gtf_df=mouse_gtf_df,
    seq_len=65536,
    mask_type=None,
)

from epiRNA.evals import pcc_across_samples

# NOTE(review): `batchsize` has never been applied to the loaders below; they
# have always been built with batch_size=32, and that effective value is kept.
# Confirm which one is intended before wiring `batchsize` through.
batchsize = 128
num_workers = 8


def _build_val_datamodule(sample_id, batch_size=32, loader_workers=num_workers):
    """Build and set up the data module restricted to one sample.

    Args:
        sample_id: Value of the "SID" column selecting the sample's rows in
            ``datainfo``.
        batch_size: Batch size handed to the data module.
        loader_workers: Worker count handed to the data module.

    Returns:
        A ``GeneDataModule`` on which ``setup()`` has already been called.

    Raises:
        ValueError: If ``datainfo`` has no row for ``sample_id``. Failing here
            avoids starting a long evaluation on an empty sample.
    """
    sample_info = datainfo[datainfo["SID"] == sample_id].reset_index(drop=True).copy()
    if sample_info.empty:
        raise ValueError(f"Sample {sample_id!r} has no rows in datainfo")

    # Same shared dataset arguments, with only the metadata table swapped for
    # the single-sample one; the shared dict itself is left untouched.
    sample_kwargs = {**datasetkwargs, "datainfo": sample_info}

    datamodule = erna.ds.GeneDataModule(
        dataset_cls=erna.ds.GeneDataset,
        dataset_kwargs=sample_kwargs,
        chrom_kwargs=chrom_kwargs,
        batch_size=batch_size,
        num_workers=loader_workers,
    )
    datamodule.setup()
    return datamodule


# Build each sample's data module once. A module depends only on the sample,
# so constructing it inside the pairwise loop repeated the same setup for
# every (obs, pred) combination.
# NOTE(review): this keeps one data module per sample alive at the same time,
# and a sample's module serves both the obs and pred roles -- assumes the
# dataset holds no per-iteration mutable state; confirm.
datamodules = {sid: _build_val_datamodule(sid) for sid in test_human_samples}

# One row per ordered (observed sample, predicted sample) pair.
results_list = []
for obs in test_human_samples:
    for pred in test_human_samples:
        print(f"Evaluating sample pair: obs={obs}, pred={pred}")

        # val_dataloader() is still called once per role and per pair, as
        # before, so the two roles never share a loader object.
        pcc_arr = pcc_across_samples(
            model=model,
            dataloader_obs=datamodules[obs].val_dataloader(),
            dataloader_pred=datamodules[pred].val_dataloader(),
            n_channels=2,
            data_device="cuda:2",
            device_ids="2,3",
        )
        # Only the last element of the returned array is recorded.
        # NOTE(review): presumably the summary value -- confirm against
        # pcc_across_samples.
        pcc = pcc_arr[-1]

        row = {"SID_obs": obs, "SID_pred": pred, "PCC": pcc}
        print(row)
        results_list.append(row)

# Long-format table (SID_obs, SID_pred, PCC) written to the current working
# directory.
# NOTE(review): the file name says "allfold" while the checkpoint path says
# "basic_fold" -- confirm the name is intended.
result_entry_df = pd.DataFrame(results_list)
result_entry_df.to_csv("allfold_sample_level_pcc_matrix.csv", index=False)