In [61]:
import matplotlib.pyplot as plt
import pandas as pd
import polars as pl
import seaborn as sns
from tqdm import tqdm
import warnings
warnings.filterwarnings("ignore")

In [62]:
def format_number(num):
    """
    Converts a number into a more readable format, using K for thousands, M for millions, etc.
    Args:
    - num: The number to format.
    
    Returns:
    - A formatted string representing the number.
    """
    if num >= 1e9:
        return f'{num/1e9:.1f}B'
    elif num >= 1e6:
        return f'{num/1e6:.1f}M'
    elif num >= 1e3:
        return f'{num/1e3:.1f}K'
    else:
        return str(num)

def make_plots():
    metric = df.columns[0]
    
    for subset in subsets:
        n_pos, n_neg = sample_size[subset]
        if metric == "AUROC":
            baseline = 0.5
        elif metric == "AUPRC":
            baseline = n_pos / (n_pos + n_neg)
        elif metric == "Odds ratio":
            baseline = 1
        g = sns.catplot(
            data=df[df.subset==subset],
            y="feature",
            x=metric,
            row="Model",
            sharex=True,
            sharey=False,
            kind="bar",
            color="C0",
            height=3,
            aspect=2,
        )
        g.set(
            xlim=baseline,
            ylabel="",
        )
        g.set_titles("{row_name}")
        sample_size_str = f"n={format_number(n_pos)} vs. {format_number(n_neg)}"
        plt.suptitle(
            f"{subset_renaming.get(subset, subset)}\n{sample_size_str}",
            x=1.0,
            y=1.05,
        )
        #plt.savefig("top_features.pdf", bbox_inches="tight")
        plt.show()

In [146]:
subset_renaming = {
    "non_coding_AND_600886": "Hyperferritinemia",
    "non_coding_AND_613985": "Beta-thalassemia",
    "non_coding_AND_614743": "Pulmonary fibrosis",
    "non_coding_AND_306900": "Hemophilia B",
    "non_coding_AND_250250": "Cartilage-hair hypoplasia",
    "non_coding_AND_174500": "Preaxial polydactyly II",
    "non_coding_AND_143890": "Hypercholesterolemia-1",
    "non_coding_AND_210710": "Dwarfism (MOPD1)",

    "non_coding_AND_Mono": "Monocyte count",
    "non_coding_AND_HbA1c": "Hemoglobin A1c",
    "non_coding_AND_HDLC": "High density lipoprotein cholesterol",
}

In [123]:
enformer_metadata = pd.read_csv(
    "../../results/metadata/Enformer.csv",
    usecols=["name", "description", "assay", "sample"],
)
enformer_metadata

Unnamed: 0,name,description,assay,sample
0,ENCFF833POA,DNASE:cerebellum male adult (27 years) and mal...,DNASE,cerebellum male adult (27 years) and male adul...
1,ENCFF110QGM,DNASE:frontal cortex male adult (27 years) and...,DNASE,frontal cortex male adult (27 years) and male ...
2,ENCFF880MKD,DNASE:chorion,DNASE,chorion
3,ENCFF463ZLQ,DNASE:Ishikawa treated with 0.02% dimethyl sul...,DNASE,Ishikawa treated with 0.02% dimethyl sulfoxide...
4,ENCFF890OGQ,DNASE:GM03348,DNASE,GM03348
...,...,...,...,...
5308,CNhs14239,CAGE:epithelioid sarcoma cell line:HS-ES-2R,CAGE,epithelioid sarcoma cell line:HS-ES-2R
5309,CNhs14240,CAGE:squamous cell lung carcinoma cell line:RE...,CAGE,squamous cell lung carcinoma cell line:RERF-LC-AI
5310,CNhs14241,CAGE:gastric cancer cell line:GSS,CAGE,gastric cancer cell line:GSS
5311,CNhs14244,CAGE:carcinoid cell line:NCI-H727,CAGE,carcinoid cell line:NCI-H727


In [98]:
borzoi_metadata = pd.read_csv(
    "../../results/metadata/Borzoi.csv", usecols=["name", "description", "assay", "sample"]
)
borzoi_metadata

Unnamed: 0,name,description,assay,sample
0,CNhs10608+,CAGE:Clontech Human Universal Reference Total ...,CAGE,"Clontech Human Universal Reference Total RNA, ..."
1,CNhs10608-,CAGE:Clontech Human Universal Reference Total ...,CAGE,"Clontech Human Universal Reference Total RNA, ..."
2,CNhs10610+,CAGE:SABiosciences XpressRef Human Universal T...,CAGE,SABiosciences XpressRef Human Universal Total ...
3,CNhs10610-,CAGE:SABiosciences XpressRef Human Universal T...,CAGE,SABiosciences XpressRef Human Universal Total ...
4,CNhs10612+,CAGE:Universal RNA - Human Normal Tissues Bioc...,CAGE,"Universal RNA - Human Normal Tissues Biochain,..."
...,...,...,...,...
7606,GTEX-13FTX-1026-SM-5J2O5.1,RNA:uterus,RNA,uterus
7607,GTEX-1MA7W-1526-SM-DHXKS.1,RNA:uterus,RNA,uterus
7608,GTEX-11EMC-1926-SM-5A5JU.1,RNA:vagina,RNA,vagina
7609,GTEX-12WSB-2426-SM-5EGJC.1,RNA:vagina,RNA,vagina


In [142]:
#dataset = "mendelian_matched_9"
#subsets = [
#    "non_coding_AND_613985",  # Beta-thalassemia
#    #"non_coding_AND_614743",
#    "non_coding_AND_306900",  # Hemophilia B
#    #"non_coding_AND_250250",
#    "non_coding_AND_143890",  # Hypercholesterolemia-1
#]

dataset = "gwas_matched_9"
subsets = [
    "non_coding_AND_Mono",
    "non_coding_AND_HbA1c",
    "non_coding_AND_HDLC",
    #"non_coding_AND_Alb",
    #"non_coding_AND_DVT",
]

base_dir = f"../../results/dataset/{dataset}"
V = pd.read_parquet(f"{base_dir}/test.parquet")

## Marginal performance

In [147]:
models = pd.DataFrame(
    [
        ["Enformer", "Enformer_L2", "C0"],
        ["Borzoi", "Borzoi_L2", "C1"],
    ],
    columns=["Model", "path", "color"]
).set_index("Model")
models

Unnamed: 0_level_0,path,color
Model,Unnamed: 1_level_1,Unnamed: 2_level_1
Enformer,Enformer_L2,C0
Borzoi,Borzoi_L2,C1


In [148]:
dfs = []
sample_size = {}

for subset in tqdm(subsets):
    s = pd.read_parquet(f"{base_dir}/subset/{subset}.parquet")
    V_s = s.merge(V, on=["chrom", "pos", "ref", "alt"], how="left")
    sample_size[subset] = V_s.label.sum(), (~V_s.label).sum()
    for model in models.index:
        df = pd.read_csv(f"{base_dir}/unsupervised_metrics/{subset}/{models.loc[model, 'path']}.csv")
        if model == "Enformer":
            df = df.merge(enformer_metadata, left_on="feature", right_on="name", how="inner")
            df = df[df.assay.isin(["RNA", "CAGE"])]
            df = df.drop_duplicates("sample")
        elif model == "Borzoi":
            df = df.merge(borzoi_metadata, left_on="feature", right_on="name", how="inner")
            df = df[df.assay.isin(["RNA", "CAGE"])]
            df = df.drop_duplicates("sample")
        df = df.head(30)
        df["subset"] = subset_renaming.get(subset, subset)
        df["Model"] = model
        dfs.append(df)
df = pd.concat(dfs)
df

100%|██████████████████████████████████████████████| 3/3 [00:00<00:00, 18.51it/s]


Unnamed: 0,AUPRC,feature,name,description,assay,sample,subset,Model
5,0.557506,CNhs10852,CNhs10852,"CAGE:CD14+ Monocytes,",CAGE,"CD14+ Monocytes,",Monocyte count,Enformer
7,0.553407,CNhs13475,CNhs13475,"CAGE:CD14+ monocytes - treated with BCG,",CAGE,"CD14+ monocytes - treated with BCG,",Monocyte count,Enformer
8,0.551956,CNhs13483,CNhs13483,CAGE:CD14+ monocytes - treated with Trehalose ...,CAGE,CD14+ monocytes - treated with Trehalose dimyc...,Monocyte count,Enformer
10,0.548486,CNhs13488,CNhs13488,"CAGE:CD14+ monocytes - treated with Candida,",CAGE,"CD14+ monocytes - treated with Candida,",Monocyte count,Enformer
11,0.547704,CNhs13532,CNhs13532,CAGE:CD14+ monocytes - treated with Group A st...,CAGE,CD14+ monocytes - treated with Group A strepto...,Monocyte count,Enformer
...,...,...,...,...,...,...,...,...
573,0.309830,ENCFF643UKE-,ENCFF643UKE-,RNA:spleen tissue male adult (37 years),RNA,spleen tissue male adult (37 years),High density lipoprotein cholesterol,Borzoi
597,0.308383,CNhs12328-,CNhs12328-,CAGE:hepatocellular carcinoma cell line: HepG2...,CAGE,hepatocellular carcinoma cell line: HepG2 ENCO...,High density lipoprotein cholesterol,Borzoi
608,0.307394,CNhs11062-,CNhs11062-,CAGE:Dendritic Cells - monocyte immature deriv...,CAGE,"Dendritic Cells - monocyte immature derived, ,...",High density lipoprotein cholesterol,Borzoi
614,0.306492,CNhs12195-,CNhs12195-,"CAGE:Dendritic Cells - monocyte immature derived,",CAGE,"Dendritic Cells - monocyte immature derived,",High density lipoprotein cholesterol,Borzoi


In [150]:
for subset in df.subset.unique():
    print(subset)
    for model in df.Model.unique():
        print(model)
        df2 = df[(df.subset==subset) & (df.Model==model)]
        print(df2[["AUPRC", "sample"]].values)

Monocyte count
Enformer
[[0.5575056900181379 'CD14+ Monocytes,']
 [0.553407430896893 'CD14+ monocytes - treated with BCG,']
 [0.5519562307877873
  'CD14+ monocytes - treated with Trehalose dimycolate (TDM),']
 [0.5484858468486936 'CD14+ monocytes - treated with Candida,']
 [0.5477039854067974
  'CD14+ monocytes - treated with Group A streptococci,']
 [0.5441841878348003 'CD14+CD16+ Monocytes,']
 [0.5414147891748046 'Basophils,']
 [0.5413346969887541 'CD14+CD16- Monocytes,']
 [0.5400793643610309 'CD14+ monocytes - treated with lipopolysaccharide,']
 [0.5377562561485341 'CD14+ monocytes - treated with Cryptococcus,']
 [0.5356152004566095 'CD14+ monocytes - treated with B-glucan,']
 [0.5292521666142148 'CD14-CD16+ Monocytes,']
 [0.5277392387183522 'CD14+ monocytes - treated with Salmonella,']
 [0.5271750254239144 'CD14+ monocytes - treated with IFN + N-hexane,']
 [0.524208618956415 'CD14+ monocytes - mock treated,']
 [0.5138705446091504 'Peripheral Blood Mononuclear Cells,']
 [0.503465870

In [128]:
# Mendelian
make_plots()

ValueError: min() arg is an empty sequence