In [1]:
import os
import pdb
import h5py
import pickle
from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.offsetbox import AnchoredText
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.stats import pearsonr, spearmanr
from tqdm import tqdm

In [2]:
# Let IPython pick the default matplotlib backend, then switch interactive mode
# off. Presumably this keeps the shared figure from being drawn after every
# plotting cell; it is written to disk by the final cell -- TODO confirm intent.
%matplotlib auto
plt.ioff()

Using matplotlib backend: <object object at 0x14816106b200>


<contextlib.ExitStack at 0x1480d94e2010>

In [3]:
# One shared 3x2 grid of axes: the analysis cells below each draw into a single
# panel (ax[row][col]) and the last cell saves the whole figure.
fig, ax = plt.subplots(3, 2, figsize=(10, 15), tight_layout=True)

# Config

In [4]:
# --- Input / output locations -------------------------------------------------
# NOTE(review): these are absolute, user- and cluster-specific paths; the
# notebook only runs unchanged where these directories are mounted.
geuvadis_eQTL_dir = "/clusterfs/nilah/Geuvadis/E-GEUV-1/analysis_results"
geuvadis_genotypes_dir = "/clusterfs/nilah/Geuvadis/E-GEUV-1/genotypes/"
data_dir = "/clusterfs/nilah/ruchir/src/finetuning-enformer/finetuning/data/h5_bins_384_chrom_split/"
enformer_data_dir = "/global/scratch/users/aniketh/enformer_data/"
root_save_dir = "/global/scratch/users/aniketh/finetune-enformer/"
code_dir = "/global/home/users/aniketh/finetuning-enformer/"
fasta_path = "/clusterfs/nilah/aniketh/hg19/hg19.fa"
class_path = "../finetuning/data/h5_bins_384_chrom_split/class.csv"

# Sub-directories of root_save_dir.
models_dir = os.path.join(root_save_dir, "saved_models")
test_preds_dir = os.path.join(root_save_dir, "test_preds_final")
rest_unseen_preds_dir = os.path.join(root_save_dir, "rest_unseen_preds_final")
ISM_preds_dir = os.path.join(root_save_dir, "ISM")

# HDF5 files for each data split; later cells read their "genes" and "seqs" datasets.
train_h5_path = os.path.join(data_dir, "train.h5")
val_h5_path = os.path.join(data_dir, "val.h5")
test_h5_path = os.path.join(data_dir, "test.h5")
rest_unseen_h5_path = os.path.join(data_dir, "rest_unseen.h5")

# Paths relative to the notebook's working directory.
GEUVADIS_COUNTS_PATH = "../process_geuvadis_data/log_tpm/corrected_log_tpm.annot.csv.gz"
FUSION_PREDS_DIR = "../fusion/preds"
FUSION_SCALING_PREDS_DIR = "../fusion/preds_scaling/"

# Run name -> training command template. The "--lr" and "--weight_decay" values
# are parsed back out of each template when the saved model names are rebuilt.
# "baseline" has no training command. Commented-out entries are run types that
# are not included in this notebook's analysis.
all_main_run_names = {
    #     "classification": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_classification_parallel_h5_dataset_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --data_seed {data_seed} --resume_from_checkpoint",
    "regression": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    #     "single_regression_counts": "NCCL_P2P_DISABLE=1 python finetuning/train_single_counts_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 2 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    # #     "single_regression": "NCCL_P2P_DISABLE=1 python finetuning/train_single_parallel_h5_dataset.py {train_h5_path} {val_h5_path} {run_name} {models_dir} --batch_size 2 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint", ### DOESN'T PERFORM WELL, NOT USED
    #     "joint_classification": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_classification_with_enformer_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {enformer_data_dir} {run_name} {models_dir} --batch_size 1 --lr 0.0005 --weight_decay 0.005 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    #     "joint_regression": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_with_enformer_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {enformer_data_dir} {run_name} {models_dir} --batch_size 1 --lr 0.0005 --weight_decay 0.005 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    #     "joint_regression_with_Malinois_MPRA": "NCCL_P2P_DISABLE=1 python finetuning/train_pairwise_regression_with_MPRA_data_parallel_h5_dynamic_sampling_dataset.py {train_h5_path} {val_h5_path} {malinois_data_path} {run_name} {models_dir} --batch_size 1 --lr 0.0001 --weight_decay 0.001 --use_scheduler --warmup_steps 1000 --data_seed {data_seed} --resume_from_checkpoint",
    "baseline": "",
}
# Data seeds of the training replicates; one saved model per (run, seed).
all_seeds = [42, 97, 7]
# NOTE(review): subsample_fracs and all_afs are not used by any cell shown in
# this notebook -- possibly left over from a sibling notebook.
subsample_fracs = [0.2, 0.4, 0.6, 0.8]
# Grid of fractional values built from three np.arange ranges with different
# step sizes. Because the steps are floats, whether each range's upper bound
# is included depends on floating-point rounding.
all_afs = (
    list(np.arange(0.01, 0.1, 0.01).round(2))
    + list(np.arange(0.1, 0.4, 0.05).round(2))
    + list(np.arange(0.41, 0.49, 0.01).round(2))
)

In [5]:
# Model name -> directory of saved predictions, for the test split and for the
# "rest unseen" split respectively.
ALL_PREDS_PATHS = {}
ALL_REST_UNSEEN_PREDS_PATHS = {}

# Fixed values that are part of every fine-tuned model's saved name.
rcprob_used_during_training = 0.5
rsmax_used_during_training = 3

# MAIN TEST RUNS
for run, train_cmd_template in all_main_run_names.items():
    if run == "baseline":
        # The baseline is stored under its plain run name (no seed / hyperparameter suffix).
        model_names = [run]
    else:
        # Read the learning rate and weight decay back out of the command
        # template: the token right after the flag, up to the next space.
        lr_used_during_training = train_cmd_template.split("--lr ")[-1].split(" ")[0]
        wd_used_during_training = train_cmd_template.split("--weight_decay ")[-1].split(
            " "
        )[0]
        # One saved model per data seed.
        model_names = [
            f"{run}_data_seed_{seed}_lr_{lr_used_during_training}_wd_{wd_used_during_training}_rcprob_{rcprob_used_during_training}_rsmax_{rsmax_used_during_training}"
            for seed in all_seeds
        ]

    for model_name in model_names:
        ALL_PREDS_PATHS[model_name] = os.path.join(test_preds_dir, model_name)
        ALL_REST_UNSEEN_PREDS_PATHS[model_name] = os.path.join(
            rest_unseen_preds_dir, model_name
        )

# Read-only handles kept open at module level; a later cell reads their
# "genes" and "seqs" datasets.
test_h5 = h5py.File(test_h5_path, "r")
rest_unseen_h5 = h5py.File(rest_unseen_h5_path, "r")

# Read cached correlations

In [6]:
# Cached performance table written by figure_1c+3a+3b+S3+S4+S5.ipynb.
# Later cells rely on its "gene", "class", "Chr" and "model" columns; the file
# must already exist in the working directory.
gene_perf = pd.read_csv("all_gene_perf.csv")

### Compute the percent of all variants seen during training

In [7]:
# Per-gene statistic columns, in the order they are written to
# gene_variant_stats.csv (after the "gene" and "class" columns).
VARIANT_STAT_COLUMNS = [
    "Percent of all variants seen during training",
    "Num variants seen during training",
    "Percent of all variants only seen during validation",
    "Num variants only seen during validation",
    "Percent of all variants only seen during testing",
    "Num variants only seen during testing",
    "Num variants",
]

# Compute the per-gene variant statistics once and cache them as a CSV; when
# the cache already exists, the expensive HDF5 pass below is skipped entirely.
if not os.path.exists("gene_variant_stats.csv"):
    test_genes = test_h5["genes"][:].astype(str)
    rest_unseen_genes = rest_unseen_h5["genes"][:].astype(str)

    gene_variant_stats = (
        gene_perf[["gene", "class"]].drop_duplicates().reset_index(drop=True)
    )

    # Exactly one record per gene. Building whole records, rather than seven
    # parallel lists, guarantees every column ends up with one value per gene.
    # The previous version appended to the "only test" list twice for unseen
    # genes and never to the "only validation" list.
    variant_stat_records = []

    # The train/val files are only needed for this pass, so close them afterwards.
    with h5py.File(train_h5_path, "r") as train_h5, h5py.File(
        val_h5_path, "r"
    ) as val_h5:
        train_genes = train_h5["genes"][:].astype(str)
        val_genes = val_h5["genes"][:].astype(str)

        for i in tqdm(range(len(gene_variant_stats))):
            row = gene_variant_stats.iloc[i]
            gene_name = row["gene"]
            gene_class = row["class"]
            print(gene_name, gene_class)

            train_gene_seqs = train_h5["seqs"][train_genes == gene_name]
            val_gene_seqs = val_h5["seqs"][val_genes == gene_name]
            test_gene_seqs = test_h5["seqs"][test_genes == gene_name]
            rest_unseen_gene_seqs = rest_unseen_h5["seqs"][
                rest_unseen_genes == gene_name
            ]

            # Collapse the first two axes: which (position, base) entries are
            # non-zero in at least one sequence of the split.
            # assumes seqs is (samples, haplotypes, seqlen, 4) -- TODO confirm
            train_variants = np.any(train_gene_seqs > 0, axis=(0, 1))  # (seqlen, 4)
            val_variants = np.any(val_gene_seqs > 0, axis=(0, 1))
            # "test" is the union of the test and rest_unseen files.
            test_variants = np.any(test_gene_seqs > 0, axis=(0, 1)) | np.any(
                rest_unseen_gene_seqs > 0, axis=(0, 1)
            )

            # Union over all splits.
            all_variants = train_variants | val_variants | test_variants
            # For every position with more than one observed base, count the
            # observed bases minus one.
            multi_base_positions = all_variants.sum(-1) > 1
            num_all_variants = (all_variants[multi_base_positions].sum(-1) - 1).sum()

            if gene_class == "unseen":
                # Unseen genes are attributed entirely to testing: nothing is
                # counted for training or validation.
                record = (0, 0, 0, 0, 100.0, num_all_variants, num_all_variants)
            else:
                # train_variants is a subset of all_variants, so this XOR counts
                # the entries found in val/test but absent from the train set.
                num_not_seen_in_train = np.logical_xor(
                    all_variants, train_variants
                ).sum()
                num_seen_in_train = num_all_variants - num_not_seen_in_train
                percent_seen_during_training = (
                    num_seen_in_train / num_all_variants
                ) * 100.0
                assert percent_seen_during_training >= 0
                print(f"Num all variants = {num_all_variants}")
                print(
                    f"Num variants seen during training = {num_seen_in_train} ({percent_seen_during_training}%)"
                )

                # Entries present in one held-out split and in neither of the
                # other two splits.
                num_only_seen_in_test = (
                    test_variants & np.logical_not(train_variants | val_variants)
                ).sum()
                num_only_seen_in_validation = (
                    val_variants & np.logical_not(train_variants | test_variants)
                ).sum()

                percent_only_seen_during_validation = (
                    num_only_seen_in_validation / num_all_variants
                ) * 100.0
                print(
                    f"Num variants only seen during validation = {num_only_seen_in_validation} ({percent_only_seen_during_validation}%)"
                )

                percent_only_seen_during_testing = (
                    num_only_seen_in_test / num_all_variants
                ) * 100.0
                print(
                    f"Num variants only seen during testing = {num_only_seen_in_test} ({percent_only_seen_during_testing}%)"
                )

                record = (
                    percent_seen_during_training,
                    num_seen_in_train,
                    percent_only_seen_during_validation,
                    num_only_seen_in_validation,
                    percent_only_seen_during_testing,
                    num_only_seen_in_test,
                    num_all_variants,
                )

            variant_stat_records.append(dict(zip(VARIANT_STAT_COLUMNS, record)))

    # Both frames carry a default 0..n-1 index, so a column-wise concat lines
    # each record up with its gene.
    gene_variant_stats = pd.concat(
        [
            gene_variant_stats,
            pd.DataFrame(variant_stat_records, columns=VARIANT_STAT_COLUMNS),
        ],
        axis=1,
    )

    gene_variant_stats.to_csv("gene_variant_stats.csv", index=False)

# Always reload from the cache so downstream cells see the same frame whether
# or not it was just computed.
gene_variant_stats = pd.read_csv("gene_variant_stats.csv")

## Find strongest eQTL within context window

In [8]:
def _load_eqtls_within_context(eqtl_path, max_distance):
    """Read one Geuvadis "all eQTLs" table and keep the rows close to the gene.

    Args:
        eqtl_path: path to a tab-separated eQTL table with at least the
            "distance", "rvalue" and "GENE_ID" columns.
        max_distance: rows are kept only when their "distance" is strictly
            smaller than this value.

    Returns:
        The filtered table with two added columns ("abs_rvalue" = |rvalue|,
        "stable_id" = GENE_ID up to its first "."), sorted by GENE_ID and then
        by decreasing abs_rvalue, with a fresh 0..n-1 index.
    """
    eqtls = pd.read_csv(eqtl_path, sep="\t")
    eqtls = eqtls[eqtls["distance"] < max_distance].reset_index(drop=True)
    eqtls["abs_rvalue"] = np.abs(eqtls["rvalue"])
    eqtls["stable_id"] = eqtls["GENE_ID"].map(lambda gene_id: gene_id.split(".")[0])
    return eqtls.sort_values(
        by=["GENE_ID", "abs_rvalue"], ascending=[True, False]
    ).reset_index(drop=True)


# Per-population source table (inside geuvadis_eQTL_dir) and cached output.
EQTL_SOURCE_FILES = {
    "EUR": "EUR373.gene.cis.FDR5.all.rs137.txt.gz",
    "YRI": "YRI89.gene.cis.FDR5.all.rs137.txt.gz",
}
EQTL_CACHE_PATHS = {
    "EUR": os.path.join(data_dir, "EUR_eQTLs_within_context_for_selected_genes.csv"),
    "YRI": os.path.join(data_dir, "YRI_eQTLs_within_context_for_selected_genes.csv"),
}

# Rebuild when EITHER cache is missing: both files are read back below, so
# checking only one of them could end in a read of a non-existent file.
if not all(os.path.exists(cache_path) for cache_path in EQTL_CACHE_PATHS.values()):
    # NOTE(review): `context_size` was used here without being defined anywhere
    # in this notebook (NameError on a fresh kernel). 128 * 384 is inferred from
    # the "h5_bins_384" data directory and the "24.5kb" half-window quoted in
    # the axis labels, assuming 128 bp bins -- TODO confirm against the
    # data-generation code.
    context_size = 128 * 384

    # Geuvadis variant id <-> dbSNP id conversion table (read without a header row).
    geuvadis_id_to_dbSNP_id = pd.read_csv(
        os.path.join(
            geuvadis_genotypes_dir, "Phase1.Geuvadis_dbSnp137_idconvert.txt.gz"
        ),
        names=["dbSNP_id", "geuvadis_id"],
        sep="\t",
    )
    # Site-level VCF; skiprows=78 presumably skips the "##" meta lines so that
    # the "#CHROM ..." line becomes the header -- TODO confirm for this file.
    geuvadis_id_annotations = pd.read_csv(
        os.path.join(
            geuvadis_genotypes_dir,
            "ALL.phase1_release_v3.20101123.snps_indels_sv.sites.gdid.gdannot.v2.vcf.gz",
        ),
        skiprows=78,
        sep="\t",
    )
    geuvadis_id_annotations = geuvadis_id_annotations.merge(
        geuvadis_id_to_dbSNP_id, left_on="ID", right_on="geuvadis_id", how="inner"
    )
    variant_annotations = geuvadis_id_annotations[
        ["dbSNP_id", "geuvadis_id", "#CHROM", "POS", "REF", "ALT", "INFO"]
    ]

    GEUVADIS_COUNTS = pd.read_csv(GEUVADIS_COUNTS_PATH)
    gene_name_lookup = GEUVADIS_COUNTS[["stable_id", "our_gene_name"]]

    for population, source_file in EQTL_SOURCE_FILES.items():
        population_eqtls = _load_eqtls_within_context(
            os.path.join(geuvadis_eQTL_dir, source_file), context_size // 2
        )
        # Attach variant coordinates/alleles, then this project's gene names.
        population_eqtls = population_eqtls.merge(
            variant_annotations, left_on="SNP_ID", right_on="dbSNP_id", how="inner"
        )
        population_eqtls = population_eqtls.merge(
            gene_name_lookup, on="stable_id", how="inner"
        )

        # Keep only eQTLs of the genes analysed in this notebook.
        selected_gene_eqtls = gene_variant_stats.merge(
            population_eqtls, left_on="gene", right_on="our_gene_name", how="inner"
        )
        # Drop the first 8 merged columns (these come from gene_variant_stats).
        # NOTE(review): gene_variant_stats has 9 columns, so its last column
        # survives this slice -- confirm whether 8 or 9 was intended.
        selected_gene_eqtls = selected_gene_eqtls[selected_gene_eqtls.columns[8:]]
        selected_gene_eqtls.to_csv(EQTL_CACHE_PATHS[population], index=False)

EUR_eQTLs_within_context_for_selected_genes = pd.read_csv(EQTL_CACHE_PATHS["EUR"])
YRI_eQTLs_within_context_for_selected_genes = pd.read_csv(EQTL_CACHE_PATHS["YRI"])

# First row per gene. This is the strongest eQTL only if the cached rows are
# still ordered by decreasing |r| within each gene, as produced by the sort above.
top_EUR_eQTL_for_every_gene = (
    EUR_eQTLs_within_context_for_selected_genes.groupby("our_gene_name")
    .head(1)
    .reset_index(drop=True)
)
top_YRI_eQTL_for_every_gene = (
    YRI_eQTLs_within_context_for_selected_genes.groupby("our_gene_name")
    .head(1)
    .reset_index(drop=True)
)

In [9]:
def _eur_minor_allele_frequency(info_field):
    """Return min(AF, 1 - AF) for the value following ";EUR_AF=" in an INFO string."""
    eur_af = float(info_field.split(";EUR_AF=")[-1].split(";")[0])
    return min(eur_af, 1 - eur_af)


# One frequency per top European eQTL, in row order.
top_eQTL_AFs = [
    _eur_minor_allele_frequency(info_field)
    for info_field in tqdm(top_EUR_eQTL_for_every_gene["INFO"])
]

top_EUR_eQTL_for_every_gene["EUR_AF"] = top_eQTL_AFs

100%|██████████| 383/383 [00:00<00:00, 14261.15it/s]


# Compute mean performance across replicates

In [10]:
# The three fine-tuned regression replicates (one per data seed).
finetuned_replicate_models = [
    "regression_data_seed_42_lr_0.0001_wd_0.001_rcprob_0.5_rsmax_3",
    "regression_data_seed_97_lr_0.0001_wd_0.001_rcprob_0.5_rsmax_3",
    "regression_data_seed_7_lr_0.0001_wd_0.001_rcprob_0.5_rsmax_3",
]
gene_key_columns = ["gene", "class", "Chr"]

# Average every remaining column over the replicates, per gene.
replicate_mean_perf = (
    gene_perf[gene_perf["model"].isin(finetuned_replicate_models)]
    .drop("model", axis=1)
    .groupby(gene_key_columns)
    .mean()
)
# Turn the group keys back into ordinary columns, placed after the averaged
# columns, on a fresh 0..n-1 index.
averaged_columns = replicate_mean_perf.columns.tolist()
replicate_mean_perf = replicate_mean_perf.reset_index()[
    averaged_columns + gene_key_columns
]

# Pair each gene's replicate mean with its baseline row; shared column names
# get the "_finetuned" / "_baseline" suffixes.
baseline_perf = gene_perf[(gene_perf["model"] == "baseline")]
summarized_gene_perf = (
    replicate_mean_perf.merge(
        baseline_perf,
        how="inner",
        on=gene_key_columns,
        suffixes=("_finetuned", "_baseline"),
    )
    .reset_index(drop=True)
    .drop("model", axis=1)
)
summarized_gene_perf

Unnamed: 0,Pearson_finetuned,|Pearson|_finetuned,gene,class,Chr,Pearson_baseline,|Pearson|_baseline
0,-0.057469,0.074269,a1bg,unseen,19,0.151831,0.151831
1,0.137963,0.178966,a4galt,unseen,22,0.231653,0.231653
2,-0.072919,0.072919,aanat,unseen,17,0.029424,0.029424
3,0.082937,0.082937,aasdh,unseen,4,-0.148702,0.148702
4,-0.035450,0.058455,abcb6,unseen,2,-0.058654,0.058654
...,...,...,...,...,...,...,...
3254,0.080985,0.080985,zscan21,unseen,7,0.063867,0.063867
3255,-0.110306,0.212780,zscan23,unseen,6,0.376974,0.376974
3256,0.217700,0.217700,zswim4,unseen,19,0.386495,0.386495
3257,-0.146993,0.146993,zxdc,unseen,3,0.180485,0.180485


# Random split genes analyses

In [11]:
# Independent copy of the rows whose class is "random_split", re-indexed from 0.
is_random_split = summarized_gene_perf["class"] == "random_split"
random_split_gene_perf = (
    summarized_gene_perf.loc[is_random_split].copy().reset_index(drop=True)
)

## Performance vs. distance of strongest eQTL within context window to TSS

In [12]:
# Column names applied to both "best eQTL per gene" tables. Because the names
# are passed explicitly, the first line of each file is read as data.
top_eQTL_table_columns = [
    "SNP_ID",
    "ID",
    "GENE_ID",
    "PROBE_ID",
    "CHR_SNP",
    "CHR_GENE",
    "SNPpos",
    "TSSpos",
    "Absolute_Distance_from_TSS",
    "rvalue",
    "pvalue",
    "log10pvalue",
]

EUR_top_eQTLs_path = os.path.join(
    geuvadis_eQTL_dir, "EUR373.gene.cis.FDR5.best.rs137.txt.gz"
)
YRI_top_eQTLs_path = os.path.join(
    geuvadis_eQTL_dir, "YRI89.gene.cis.FDR5.best.rs137.txt.gz"
)

EUR_top_eQTLs = pd.read_csv(EUR_top_eQTLs_path, sep="\t", names=top_eQTL_table_columns)
YRI_top_eQTLs = pd.read_csv(YRI_top_eQTLs_path, sep="\t", names=top_eQTL_table_columns)

# Expression table; provides the stable_id -> our_gene_name mapping.
GEUVADIS_COUNTS = pd.read_csv(GEUVADIS_COUNTS_PATH)

In [13]:
# Columns of the top European eQTL table to attach to each random-split gene.
top_eQTL_feature_columns = [
    "our_gene_name",
    "abs_rvalue",
    "pvalue",
    "log10pvalue",
    "distance",
    "#CHROM",
    "POS",
    "REF",
    "ALT",
    "EUR_AF",
]

# Inner join: genes without a top European eQTL are dropped. The join key
# from the right-hand table is redundant with "gene" and is removed.
random_split_gene_perf = pd.merge(
    random_split_gene_perf,
    top_EUR_eQTL_for_every_gene[top_eQTL_feature_columns],
    how="inner",
    left_on="gene",
    right_on="our_gene_name",
).drop(columns="our_gene_name")

In [14]:
# Top-left panel: fine-tuned performance vs. distance of the top eQTL to the TSS.
panel = ax[0][0]
eqtl_distance = random_split_gene_perf["distance"]
finetuned_pearson = random_split_gene_perf["Pearson_finetuned"]

# Both coefficients are computed; only the Spearman value appears in the title.
spr = spearmanr(eqtl_distance, finetuned_pearson)[0]
pr = pearsonr(eqtl_distance, finetuned_pearson)[0]

sns.regplot(
    x="distance",
    y="Pearson_finetuned",
    data=random_split_gene_perf,
    ax=panel,
    scatter_kws={"alpha": 0.5},
    line_kws={"color": "r"},
)
panel.set_xlabel(
    "Absolute distance of top European eQTL \nwithin context window (24.5kb) to TSS"
)
panel.set_ylabel("Fine-tuned Enformer Pearson Correlation")
panel.set_title(
    f"Fine-tuned model performance \non random-split genes vs.\nthe distance of the top European eQTL to the TSS\n(Spearman Correlation = {spr.round(3)})"
)

Text(0.5, 1.0, 'Fine-tuned model performance \non random-split genes vs.\nthe distance of the top European eQTL to the TSS\n(Spearman Correlation = -0.073)')

## Performance vs. log10p-val of strongest eQTL within context window

In [15]:
# Top-right panel: fine-tuned performance vs. the "log10pvalue" of the top eQTL.
panel = ax[0][1]
eqtl_log10pvalue = random_split_gene_perf["log10pvalue"]
finetuned_pearson = random_split_gene_perf["Pearson_finetuned"]

# Both coefficients are computed; only the Spearman value appears in the title.
spr = spearmanr(eqtl_log10pvalue, finetuned_pearson)[0]
pr = pearsonr(eqtl_log10pvalue, finetuned_pearson)[0]

sns.regplot(
    x="log10pvalue",
    y="Pearson_finetuned",
    data=random_split_gene_perf,
    ax=panel,
    scatter_kws={"alpha": 0.5},
    line_kws={"color": "r"},
)
panel.set_xlabel(
    "-log10(p-value) of top European eQTL \nwithin context window (24.5kb)"
)
# The x axis is switched to a log scale only after the regression line has been drawn.
panel.set_xscale("log")
panel.set_ylabel("Fine-tuned Enformer Pearson Correlation")
panel.set_title(
    f"Fine-tuned model performance \non random-split genes vs.\nthe p-value of top European eQTL within context window\n(Spearman Correlation = {spr.round(3)})"
)

Text(0.5, 1.0, 'Fine-tuned model performance \non random-split genes vs.\nthe p-value of top European eQTL within context window\n(Spearman Correlation = 0.602)')

## Performance vs. EUR allele frequency of strongest eQTL within context window

In [16]:
# spr = spearmanr(random_split_gene_perf["EUR_AF"],
#                 random_split_gene_perf["Pearson_finetuned"])[0]
# pr = pearsonr(random_split_gene_perf["EUR_AF"],
#               random_split_gene_perf["Pearson_finetuned"])[0]

# sns.regplot(data=random_split_gene_perf,
#             x="EUR_AF",
#             y="Pearson_finetuned")
# plt.xlabel("EUR minor allele frequency of top European eQTL \nwithin context window (24.5kb)")
# plt.ylabel("Pearson Correlation")
# plt.title(f"Relationship between fine-tuned model performance \non random-split genes \nand the EUR minor AF of top European eQTL within context window\n(Spearman Correlation = {spr.round(3)})")
# plt.show()

## Performance vs. percentage of variants observed during training

In [17]:
# Attach the last seven columns of gene_variant_stats (the per-gene variant
# statistics) to every random-split gene.
variant_stat_columns = gene_variant_stats.columns[-7:].tolist()
random_split_gene_perf = pd.merge(
    random_split_gene_perf,
    gene_variant_stats[["gene"] + variant_stat_columns],
    how="inner",
    on="gene",
)

In [18]:
# Middle-left panel: random-split genes, performance vs. share of variants seen in training.
panel = ax[1][0]
percent_seen = random_split_gene_perf["Percent of all variants seen during training"]
finetuned_pearson = random_split_gene_perf["Pearson_finetuned"]

# Both coefficients are computed; only the Spearman value appears in the title.
spr = spearmanr(percent_seen, finetuned_pearson)[0]
pr = pearsonr(percent_seen, finetuned_pearson)[0]

sns.regplot(
    x="Percent of all variants seen during training",
    y="Pearson_finetuned",
    data=random_split_gene_perf,
    ax=panel,
    scatter_kws={"alpha": 0.5},
    line_kws={"color": "r"},
)

panel.set_ylabel("Fine-tuned Enformer Pearson Correlation")
panel.set_title(f"Random-split genes\n(Spearman Correlation = {spr.round(3)})")

Text(0.5, 1.0, 'Random-split genes\n(Spearman Correlation = 0.112)')

# Population split genes analyses

In [19]:
# Independent copy of the rows whose class is "yri_split", re-indexed from 0.
is_yri_split = summarized_gene_perf["class"] == "yri_split"
yri_split_gene_perf = (
    summarized_gene_perf.loc[is_yri_split].copy().reset_index(drop=True)
)

## Performance vs. percentage of variants observed during training

In [20]:
# Attach the last seven columns of gene_variant_stats (the per-gene variant
# statistics) to every population-split gene.
variant_stat_columns = gene_variant_stats.columns[-7:].tolist()
yri_split_gene_perf = pd.merge(
    yri_split_gene_perf,
    gene_variant_stats[["gene"] + variant_stat_columns],
    how="inner",
    on="gene",
)
yri_split_gene_perf

Unnamed: 0,Pearson_finetuned,|Pearson|_finetuned,gene,class,Chr,Pearson_baseline,|Pearson|_baseline,Percent of all variants seen during training,Num variants seen during training,Percent of all variants only seen during testing,Num variants only seen during testing,Num variants,Percent of all variants only seen during validation,Num variants only seen during validation
0,0.103805,0.103805,abhd10,yri_split,3,0.042268,0.042268,70.168067,334,29.831933,142,476,0.000000,0
1,0.242679,0.242679,ac017104.1,yri_split,2,0.159504,0.159504,55.803571,250,43.526786,195,448,0.446429,2
2,0.382086,0.382086,ac079921.2,yri_split,4,-0.132472,0.132472,57.739558,235,41.523342,169,407,0.737101,3
3,0.079437,0.079437,ac104653.1,yri_split,2,-0.119865,0.119865,74.492099,330,25.507901,113,443,0.000000,0
4,0.476652,0.476652,ac108206.1,yri_split,4,0.508323,0.508323,61.111111,341,38.530466,215,558,0.358423,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
195,0.229185,0.229185,znf22-as1,yri_split,10,-0.054358,0.054358,58.503401,258,40.589569,179,441,0.907029,4
196,0.098031,0.098031,znf354a,yri_split,5,0.208737,0.208737,56.060606,296,43.560606,230,528,0.189394,1
197,0.288661,0.288661,znf354b,yri_split,5,0.271102,0.271102,73.103448,530,26.758621,194,725,0.137931,1
198,0.262740,0.262740,znf438,yri_split,10,0.353091,0.353091,56.103286,239,43.427230,185,426,0.469484,2


In [21]:
# Middle-right panel: population-split genes, performance vs. share of variants seen in training.
panel = ax[1][1]
percent_seen = yri_split_gene_perf["Percent of all variants seen during training"]
finetuned_pearson = yri_split_gene_perf["Pearson_finetuned"]

# Both coefficients are computed; only the Spearman value appears in the title.
spr = spearmanr(percent_seen, finetuned_pearson)[0]
pr = pearsonr(percent_seen, finetuned_pearson)[0]

sns.regplot(
    x="Percent of all variants seen during training",
    y="Pearson_finetuned",
    data=yri_split_gene_perf,
    ax=panel,
    scatter_kws={"alpha": 0.5},
    line_kws={"color": "r"},
)
panel.set_ylabel("Fine-tuned Enformer Pearson Correlation")
panel.set_title(f"Population-split genes\n(Spearman Correlation = {spr.round(3)})")

Text(0.5, 1.0, 'Population-split genes\n(Spearman Correlation = 0.102)')

# Unseen genes analyses

In [22]:
# Independent copy of the rows whose class is "unseen", re-indexed from 0.
is_unseen = summarized_gene_perf["class"] == "unseen"
unseen_split_gene_perf = (
    summarized_gene_perf.loc[is_unseen].copy().reset_index(drop=True)
)

## Performance vs. Baseline performance

In [23]:
# Bottom-left panel: unseen genes, fine-tuned vs. baseline per-gene correlation.
panel = ax[2][0]
baseline_pearson = unseen_split_gene_perf["Pearson_baseline"]
finetuned_pearson = unseen_split_gene_perf["Pearson_finetuned"]

# Computed for reference; with the title disabled below, neither value is shown on the panel.
spr = spearmanr(baseline_pearson, finetuned_pearson)[0]
pr = pearsonr(baseline_pearson, finetuned_pearson)[0]

# Disabled alternative (scatter plus fitted regression line):
# sns.regplot(data=unseen_split_gene_perf,
#             x="Pearson_baseline",
#             y="Pearson_finetuned",
#             scatter_kws=dict(alpha=0.5),
#             line_kws=dict(color="r"),
#             ax=ax[2][0])

sns.scatterplot(
    x="Pearson_baseline",
    y="Pearson_finetuned",
    data=unseen_split_gene_perf,
    ax=panel,
    alpha=0.5,
)
# Disabled title:
# ax[2][0].set_title(f"Fine-tuned model performance vs.\nbaseline Enformer performance \non unseen genes\n(Spearman Correlation = {spr.round(3)})")
panel.set_xlabel("Baseline Enformer Pearson Correlation")
panel.set_ylabel("Fine-tuned Enformer Pearson Correlation")

Text(0, 0.5, 'Fine-tuned Enformer Pearson Correlation')

In [24]:
# The bottom-right panel of the 3x2 grid is not drawn into by any cell, so hide it.
ax[2][1].set_visible(False)

In [25]:
# Write the assembled multi-panel figure to disk as PDF and SVG.
# - Create the output directory first: savefig does not create missing
#   directories and would otherwise fail on a fresh checkout.
# - Save through the explicit `fig` handle rather than pyplot's implicit
#   "current figure", so the output cannot silently switch to another figure.
figure_output_dir = "figures"
os.makedirs(figure_output_dir, exist_ok=True)

for figure_format in ("pdf", "svg"):
    fig.savefig(
        f"{figure_output_dir}/fig_S6+S7_correlations_with_gene_features.{figure_format}",
        dpi=600,
        bbox_inches="tight",
    )