In [28]:
import os
import re
import warnings

import pandas as pd
import numpy as np
from scipy.stats import pearsonr, spearmanr, ConstantInputWarning
from tqdm.notebook import tqdm

warnings.filterwarnings("ignore", category=ConstantInputWarning)

In [50]:
PREDIXCAN_DIRS = {
    (
        "random_split",
        "384_bins",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/predixcan/random_split.384_bins.no_cv",
    (
        "population_split",
        "384_bins",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/predixcan/population_split.384_bins.no_cv",
    (
        "random_split",
        "1Mb",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/predixcan/random_split.1Mb.no_cv",
    (
        "population_split",
        "1Mb",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/predixcan/population_split.1Mb.no_cv",
}

MANUAL_PRIOR_DIRS = {
    (
        "random_split",
        "384_bins",
        "LCL_weights",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/random_split.384_bins.LCL_weights",
    (
        "population_split",
        "384_bins",
        "LCL_weights",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/population_split.384_bins.LCL_weights",
    (
        "random_split",
        "384_bins",
        "uniform_weights",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/random_split.384_bins.uniform_weights",
    (
        "population_split",
        "384_bins",
        "uniform_weights",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/population_split.384_bins.uniform_weights",
    (
        "random_split",
        "1Mb",
        "LCL_weights",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/random_split.1Mb.LCL_weights",
    (
        "population_split",
        "1Mb",
        "LCL_weights",
    ): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/population_split.1Mb.LCL_weights",
    # ("random_split", "1Mb", "uniform_weights"): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/random_split.1Mb.uniform_weights",
    # ("population_split", "1Mb", "uniform_weights"): "/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/population_split.1Mb.uniform_weights",
}

LEARNED_PRIOR_DIRS = {
    (
        "random_split",
        "384_bins",
    ): "/data/yosef3/scratch/ruchir/finetuning-enformer/meta_feature_prior/r1_sgd_random_split/r1_random_split_C_0.01_lr_0.001/preds/",
    (
        "population_split",
        "384_bins",
    ): "/data/yosef3/scratch/ruchir/finetuning-enformer/meta_feature_prior/r1_sgd_population_split/r1_population_split_C_0.01_lr_0.001/preds/",
}

COUNTS_PATH = "/data/yosef3/users/ruchir/finetuning-enformer/process_geuvadis_data/log_tpm/corrected_log_tpm.annot.csv.gz"

In [11]:
def load_counts_df():
    counts_df = pd.read_csv(COUNTS_PATH, index_col="our_gene_name")
    counts_df = counts_df[~counts_df.index.isna()]
    assert counts_df.shape[0] == 3259
    return counts_df


def rename_gene(gene):
    return gene.replace(".", "_")

In [9]:
counts_df = load_counts_df()
genes = set(counts_df.index)
assert all("_" not in g for g in genes)

In [35]:
def compute_predixcan_and_manual_prior_corrs(
    preds_dir: str, counts_df: pd.DataFrame, genes: set[str]
):
    preds_path = os.path.join(preds_dir, "preds.csv")
    preds_df = pd.read_csv(preds_path, index_col=0)

    pearsons = {}
    spearmans = {}
    for g in tqdm(genes):
        Y_hat = preds_df.loc[g].dropna()
        Y = counts_df.loc[g, Y_hat.index]
        assert Y_hat.shape[0] == Y.shape[0] == 77, (Y_hat.shape, Y.shape)
        pearson = pearsonr(Y, Y_hat)[0]
        spearman = spearmanr(Y, Y_hat)[0]
        pearsons[g] = pearson if not np.isnan(pearson) else 0
        spearmans[g] = spearman if not np.isnan(spearman) else 0
    return {"pearsons": pearsons, "spearmans": spearmans}


def compute_learned_prior_corrs(
    preds_dir: str, counts_df: pd.DataFrame, genes: set[str]
):
    # Get file in preds_dir with *.preds.csv extension
    preds_files = os.listdir(preds_dir)
    preds_files = [f for f in preds_files if f.endswith(".preds.csv")]
    assert len(preds_files) == 1

    preds_path = os.path.join(preds_dir, preds_files[0])
    preds_df = pd.read_csv(preds_path, index_col=0)

    pearsons = {}
    spearmans = {}
    for g in tqdm(genes):
        g_prime = rename_gene(g)
        if g_prime not in preds_df.index:
            pearsons[g] = 0.0
            spearmans[g] = 0.0
            continue
        Y_hat = preds_df.loc[g_prime].dropna()
        Y = counts_df.loc[g, Y_hat.index]
        assert Y_hat.shape[0] == Y.shape[0] == 77
        pearson = pearsonr(Y, Y_hat)[0]
        spearman = spearmanr(Y, Y_hat)[0]
        pearsons[g] = pearson if not np.isnan(pearson) else 0
        spearmans[g] = spearman if not np.isnan(spearman) else 0
    return {"pearsons": pearsons, "spearmans": spearmans}

# PrediXcan

In [42]:
predixcan_corrs = {}
for k in PREDIXCAN_DIRS:
    predixcan_corrs[k] = compute_predixcan_and_manual_prior_corrs(
        PREDIXCAN_DIRS[k], counts_df, genes
    )
    mean_pearson = np.mean(list(predixcan_corrs[k]["pearsons"].values()))
    mean_spearman = np.mean(list(predixcan_corrs[k]["spearmans"].values()))
    print(f"{k}: Pearson: {mean_pearson}, Spearman: {mean_spearman}")

  0%|          | 0/3259 [00:00<?, ?it/s]

('random_split', '384_bins'): Pearson: 0.26528616802148847, Spearman: 0.2609664036394531


  0%|          | 0/3259 [00:00<?, ?it/s]

('population_split', '384_bins'): Pearson: 0.14322784223149096, Spearman: 0.14005147219856467


  0%|          | 0/3259 [00:00<?, ?it/s]

('random_split', '1Mb'): Pearson: 0.2626795218170747, Spearman: 0.2565304449110364


  0%|          | 0/3259 [00:00<?, ?it/s]

('population_split', '1Mb'): Pearson: 0.14529631901405088, Spearman: 0.14097625423811044


# Manual prior

In [51]:
manual_prior_corrs = {}
for k in MANUAL_PRIOR_DIRS:
    manual_prior_corrs[k] = compute_predixcan_and_manual_prior_corrs(
        MANUAL_PRIOR_DIRS[k], counts_df, genes
    )
    mean_pearson = np.mean(list(manual_prior_corrs[k]["pearsons"].values()))
    mean_spearman = np.mean(list(manual_prior_corrs[k]["spearmans"].values()))
    print(f"{k}: Pearson: {mean_pearson}, Spearman: {mean_spearman}")

  0%|          | 0/3259 [00:00<?, ?it/s]

('random_split', '384_bins', 'LCL_weights'): Pearson: 0.2633094851828583, Spearman: 0.25669468398016415


  0%|          | 0/3259 [00:00<?, ?it/s]

('population_split', '384_bins', 'LCL_weights'): Pearson: 0.14386493976814954, Spearman: 0.1402554654454205


  0%|          | 0/3259 [00:00<?, ?it/s]

('random_split', '384_bins', 'uniform_weights'): Pearson: 0.26442439738556894, Spearman: 0.25764647338101304


  0%|          | 0/3259 [00:00<?, ?it/s]

('population_split', '384_bins', 'uniform_weights'): Pearson: 0.1441908744390086, Spearman: 0.14081572974998954


  0%|          | 0/3259 [00:00<?, ?it/s]

('random_split', '1Mb', 'LCL_weights'): Pearson: 0.2502518155117472, Spearman: 0.24399976245069185


FileNotFoundError: [Errno 2] No such file or directory: '/data/yosef3/users/ruchir/finetuning-enformer/linear_methods/manual_prior/population_split.1Mb.LCL_weights/preds.csv'

# Learned prior

In [30]:
for k in LEARNED_PRIOR_DIRS:
    learned_prior_corrs = compute_learned_prior_corrs(
        LEARNED_PRIOR_DIRS[k], counts_df, genes
    )
    mean_pearson = np.mean(list(learned_prior_corrs["pearsons"].values()))
    mean_spearman = np.mean(list(learned_prior_corrs["spearmans"].values()))
    print(f"{k}: Pearson: {mean_pearson}, Spearman: {mean_spearman}")

  0%|          | 0/3259 [00:00<?, ?it/s]

('random_split', '384_bins'): Pearson: 0.2646572586843869, Spearman: 0.25484699332935273


  0%|          | 0/3259 [00:00<?, ?it/s]

('population_split', '384_bins'): Pearson: 0.15390101729315434, Spearman: 0.14896245577837516
