In [7]:
import h5py
from functools import cache, partial

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import r2_score
from tqdm.notebook import tqdm

import evaluation_utils

In [8]:
# fmt: off
# --- Input data: gene classes, observed expression, sample metadata ----------
GENE_CLASS_PATH = "../finetuning/data/h5_bins_384_chrom_split/gene_class.csv"
GEUVADIS_COUNTS_PATH = "../process_geuvadis_data/log_tpm/corrected_log_tpm.annot.csv.gz"
# NOTE(review): absolute, machine-specific path; not portable to other hosts.
METADATA_PATH = "/data/yosef3/users/ruchir/pgp_uq/data/E-GEUV-1.sdrf.txt"

# --- Prediction files compared against the observed expression ----------------
BASELINE_PREDS_PATH = "../baseline/baseline_enformer.384_bins.rc.csv"
PREDIXCAN_PREDS_384_BINS_PATH = "../predixcan_lite/h5_bins_384_chrom_split/384_bins_no_cv/preds.csv"
PREDIXCAN_PREDS_1MB_PATH = "../predixcan_lite/h5_bins_384_chrom_split/1Mb_no_cv/preds.csv"
# fmt: on

In [9]:
def load_model_preds(path: str) -> pd.DataFrame:
    """Pivot flat (pred, gene, sample) records from an ``.npz`` archive into a matrix.

    The archive must contain three equal-length arrays ``preds``, ``genes`` and
    ``samples``; record ``i`` is the prediction ``preds[i]`` for gene ``genes[i]``
    in sample ``samples[i]``.

    Args:
        path: Path to the ``.npz`` archive.

    Returns:
        DataFrame indexed by the sorted unique genes, with one column per sorted
        unique sample. (gene, sample) pairs that have no record are NaN.

    Raises:
        ValueError: if the three arrays differ in length, or if a (gene, sample)
            pair occurs more than once (previously the last record silently won).
    """
    # NpzFile keeps the underlying file open; the context manager closes it.
    # Indexing an NpzFile materializes each array, so they remain valid afterwards.
    with np.load(path) as data:
        # ravel() also accepts a (n, 1) column of predictions.
        preds = np.ravel(data["preds"])
        genes = np.asarray(data["genes"])
        samples = np.asarray(data["samples"])

    if not (len(preds) == len(genes) == len(samples)):
        raise ValueError(
            f"{path}: preds/genes/samples lengths differ "
            f"({len(preds)}, {len(genes)}, {len(samples)})"
        )

    # np.unique returns the sorted unique labels plus each record's position in
    # them, replacing the per-record dict lookups of a Python loop.
    unique_genes, gene_idx = np.unique(genes, return_inverse=True)
    unique_samples, sample_idx = np.unique(samples, return_inverse=True)
    gene_idx = np.ravel(gene_idx)
    sample_idx = np.ravel(sample_idx)

    # With fancy-index assignment the winner among duplicate targets is undefined,
    # so reject duplicates explicitly instead of keeping an arbitrary one.
    flat_idx = gene_idx * len(unique_samples) + sample_idx
    if len(np.unique(flat_idx)) != len(flat_idx):
        raise ValueError(f"{path}: duplicate (gene, sample) prediction records")

    preds_mtx = np.full((len(unique_genes), len(unique_samples)), np.nan)
    preds_mtx[gene_idx, sample_idx] = preds
    return pd.DataFrame(preds_mtx, index=unique_genes, columns=unique_samples)


def check_sample_split_consistency(df1: pd.DataFrame, df2: pd.DataFrame):
    """Assert that both frames observe the same samples for every shared gene.

    Genes present in only one frame are ignored. For each gene in both indexes,
    the set of columns holding a non-missing value must be identical in the two
    frames; otherwise an AssertionError naming the first offending gene is raised.
    """
    for gene in df1.index.intersection(df2.index):
        observed_left = set(df1.loc[gene].dropna().index)
        observed_right = set(df2.loc[gene].dropna().index)
        assert observed_left == observed_right, f"Samples for gene {gene} are not consistent"

In [10]:
# Prediction matrices: first CSV column is the gene ID (index), remaining columns
# are per-sample predictions as consumed by compute_correlations.
predixcan_preds_384_bins_df = pd.read_csv(PREDIXCAN_PREDS_384_BINS_PATH, index_col=0)
predixcan_preds_1Mb_bins_df = pd.read_csv(PREDIXCAN_PREDS_1MB_PATH, index_col=0)
baseline_preds_df = pd.read_csv(BASELINE_PREDS_PATH, index_col=0)

# compute_correlations looks rows up with .loc[gene]; that only yields a single
# row when gene IDs are unique, so fail fast here instead of deep in that loop.
assert predixcan_preds_384_bins_df.index.is_unique, "duplicate genes in 384-bin PrediXcan preds"
assert predixcan_preds_1Mb_bins_df.index.is_unique, "duplicate genes in 1Mb PrediXcan preds"
assert baseline_preds_df.index.is_unique, "duplicate genes in baseline preds"

In [11]:
# Observed expression (corrected log-TPM, per the file path), indexed by gene name.
geuvadis_counts_df = pd.read_csv(GEUVADIS_COUNTS_PATH, index_col="our_gene_name")
# compute_correlations looks rows up with .loc[gene], which needs unique gene names.
assert geuvadis_counts_df.index.is_unique, "duplicate gene names in GEUVADIS counts"

In [12]:
def compute_correlations(
    df1: pd.DataFrame,
    df2: pd.DataFrame,
    metric: str = "pearson",
    expected_sample_counts: "tuple[int, ...] | None" = (77, 421),
    nan_value: float = 0.0,
) -> dict[str, float]:
    """Score every gene shared by ``df1`` and ``df2`` on the samples both observe.

    Rows are genes and columns are samples. For each gene in both indexes, the
    samples with a non-missing value in *both* frames are collected, and the two
    rows are compared on exactly those samples.

    Args:
        df1: Gene x sample matrix. For ``metric == "r2_score"`` this must hold the
            true values, because ``r2_score`` is not symmetric.
        df2: Gene x sample matrix; the predictions when ``metric == "r2_score"``.
        metric: ``"pearson"``, ``"spearman"`` or ``"r2_score"``.
        expected_sample_counts: Allowed numbers of shared samples per gene; any
            other count fails an assertion. The default keeps the original
            hard-coded check (77 or 421). Pass ``None`` to disable the check.
        nan_value: Score recorded when the metric is NaN (e.g. a correlation with
            a constant row). Defaults to 0.0, as before.

    Returns:
        Mapping gene -> score, in the order of ``df1.index.intersection(df2.index)``.

    Raises:
        ValueError: if ``metric`` is unknown (checked before any work is done).
        AssertionError: if a gene's shared-sample count is not allowed.
    """
    # Validate up front so a typo fails immediately, even for empty inputs.
    if metric not in ("pearson", "spearman", "r2_score"):
        raise ValueError(f"Unknown metric {metric}")

    common_genes = df1.index.intersection(df2.index)
    correlations = {}
    for g in tqdm(common_genes, desc=metric):
        common_samples = df1.loc[g].dropna().index.intersection(
            df2.loc[g].dropna().index
        )
        if expected_sample_counts is not None:
            assert len(common_samples) in expected_sample_counts, (
                f"Gene {g}: {len(common_samples)} shared samples, "
                f"expected one of {tuple(expected_sample_counts)}"
            )
        truth = df1.loc[g, common_samples]
        pred = df2.loc[g, common_samples]
        if metric == "spearman":
            corr, _ = spearmanr(truth, pred)
        elif metric == "pearson":
            corr, _ = pearsonr(truth, pred)
        else:
            corr = r2_score(truth, pred)
        correlations[g] = nan_value if np.isnan(corr) else corr
    return correlations


def get_mean_class_correlation(
    gene_to_corrs: dict[str, float],
    class_: str,
    abs_corr: bool = False,
    gene_to_class_map: "dict[str, str] | None" = None,
    expected_counts: "dict[str, int] | None" = None,
) -> float:
    """Average the per-gene scores over all genes belonging to one gene class.

    Args:
        gene_to_corrs: Per-gene score (e.g. the output of ``compute_correlations``).
            Must contain every gene of class ``class_``.
        class_: Gene class to average over; must be a key of ``expected_counts``
            (by default ``"random_split"``, ``"yri_split"`` or ``"unseen"``).
        abs_corr: Average ``|score|`` instead of the signed score.
        gene_to_class_map: gene -> class mapping. Defaults to
            ``evaluation_utils.get_gene_to_class_map()``; pass it explicitly to
            avoid rebuilding it on every call or to evaluate another gene set.
        expected_counts: class -> exact number of genes expected in that class.
            Defaults to the original sanity check (200 / 200 / 100).

    Returns:
        Mean (optionally absolute) score over the genes of ``class_``.

    Raises:
        AssertionError: unknown ``class_`` or unexpected number of genes.
        KeyError: a gene of ``class_`` is missing from ``gene_to_corrs``.
    """
    if expected_counts is None:
        expected_counts = {"random_split": 200, "yri_split": 200, "unseen": 100}
    assert class_ in expected_counts, (
        f"Unknown gene class {class_!r}; expected one of {list(expected_counts)}"
    )
    if gene_to_class_map is None:
        gene_to_class_map = evaluation_utils.get_gene_to_class_map()

    class_corrs = np.array(
        [gene_to_corrs[g] for g, c in gene_to_class_map.items() if c == class_],
        dtype=float,
    )
    if abs_corr:
        class_corrs = np.abs(class_corrs)

    n_expected = expected_counts[class_]
    assert len(class_corrs) == n_expected, (
        f"Expected {n_expected} genes of class {class_!r}, found {len(class_corrs)}"
    )
    return float(np.mean(class_corrs))

In [13]:
# Per-gene Pearson r of each prediction set against the observed GEUVADIS
# expression (observed values go first, as compute_correlations expects).
score_against_observed = partial(compute_correlations, geuvadis_counts_df)
baseline_pearsons = score_against_observed(baseline_preds_df)
predixcan_384_bin_pearsons = score_against_observed(predixcan_preds_384_bins_df)
predixcan_1Mb_pearsons = score_against_observed(predixcan_preds_1Mb_bins_df)

  0%|          | 0/500 [00:00<?, ?it/s]

  corr, _ = pearsonr(df1.loc[g, common_samples], df2.loc[g, common_samples])


  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

In [18]:
# Mean signed per-gene Pearson r of the baseline, reported per gene class.
baseline_class_mean = partial(get_mean_class_correlation, baseline_pearsons)
baseline_random_split_mean_pearson = baseline_class_mean("random_split")
baseline_yri_split_mean_pearson = baseline_class_mean("yri_split")
baseline_unseen_mean_pearson = baseline_class_mean("unseen")

for label, mean_r in (
    ("random split", baseline_random_split_mean_pearson),
    ("YRI split", baseline_yri_split_mean_pearson),
    ("unseen", baseline_unseen_mean_pearson),
):
    print(f"Baseline {label} mean pearson: {mean_r:.4f}")

Baseline random split mean pearson: 0.055
Baseline YRI split mean pearson: 0.025
Baseline unseen mean pearson: 0.056


In [15]:
# Mean |per-gene Pearson r| of the baseline, reported per gene class.
baseline_abs_class_mean = partial(
    get_mean_class_correlation, baseline_pearsons, abs_corr=True
)
baseline_abs_random_split_mean_pearson = baseline_abs_class_mean("random_split")
baseline_abs_yri_split_mean_pearson = baseline_abs_class_mean("yri_split")
baseline_abs_unseen_mean_pearson = baseline_abs_class_mean("unseen")

for label, mean_abs_r in (
    ("random split", baseline_abs_random_split_mean_pearson),
    ("YRI split", baseline_abs_yri_split_mean_pearson),
    ("unseen", baseline_abs_unseen_mean_pearson),
):
    print(f"Baseline abs {label} mean pearson: {mean_abs_r:.4f}")

Baseline abs random split mean pearson: 0.147703
Baseline abs YRI split mean pearson: 0.130
Baseline abs unseen mean pearson: 0.138


In [19]:
# Only the random-split and YRI-split classes are scored here. The correlation pass
# covered 400 genes for PrediXcan vs 500 for the baseline, so presumably PrediXcan
# has no predictions for the unseen genes. TODO confirm.
predixcan_384_bin_random_split_mean_pearson = get_mean_class_correlation(
    predixcan_384_bin_pearsons, "random_split"
)
predixcan_384_bin_yri_split_mean_pearson = get_mean_class_correlation(
    predixcan_384_bin_pearsons, "yri_split"
)
print(
    f"PrediXcan (384 bins context) random split mean pearson: {predixcan_384_bin_random_split_mean_pearson:.4f}"
)
print(
    f"PrediXcan (384 bins context) YRI split mean pearson: {predixcan_384_bin_yri_split_mean_pearson:.4f}"
)

Predixcan (384 bins context) random spit mean pearson: 0.2702
PrediXcan (384 bins context) YRI split mean pearson : 0.1394


In [20]:
# Only the random-split and YRI-split classes are scored here. The correlation pass
# covered 400 genes for PrediXcan vs 500 for the baseline, so presumably PrediXcan
# has no predictions for the unseen genes. TODO confirm.
predixcan_1Mb_random_split_mean_pearson = get_mean_class_correlation(
    predixcan_1Mb_pearsons, "random_split"
)
predixcan_1Mb_yri_split_mean_pearson = get_mean_class_correlation(
    predixcan_1Mb_pearsons, "yri_split"
)
print(
    f"PrediXcan (1Mb context) random split mean pearson: {predixcan_1Mb_random_split_mean_pearson:.4f}"
)
print(
    f"PrediXcan (1Mb context) YRI split mean pearson: {predixcan_1Mb_yri_split_mean_pearson:.4f}"
)

Predixcan (1Mb context) random spit mean pearson: 0.2633
PrediXcan (1Mb context) YRI split mean pearson : 0.1355
