In [1]:
import numpy as np
import pandas as pd

In [2]:
# Trait lists. Both are read as strings, so trait IDs compare equal to the
# string keys of the renaming dictionaries below (OMIM IDs such as "600886"
# would otherwise be parsed as integers).
mendelian_traits = (
    pd.read_csv("../../config/omim/filtered_traits.txt", header=None, dtype=str)
    .values.ravel().tolist()
)
complex_traits = (
    pd.read_csv("../../config/gwas/independent_traits_filtered.csv", header=None, dtype=str)
    .values.ravel().tolist()
)

# Display names for complex traits, keyed by trait ID (from the UKBB trait table).
complex_trait_renaming = pd.read_csv(
    "../../results/gwas/raw/release1.1/UKBB_94traits_release1.traits", sep="\t",
    usecols=["trait", "description"], dtype=str,
).set_index("trait")["description"].to_dict()

# Display names for Mendelian traits, keyed by OMIM ID.
mendelian_trait_renaming = {
    "600886": "Hyperferritinemia",
    "613985": "Beta-thalassemia",
    "614743": "Pulmonary fibrosis",
    "306900": "Hemophilia B",
    "250250": "Cartilage-hair hypoplasia",
    "174500": "Preaxial polydactyly II",
    "143890": "Hypercholesterolemia-1",
    "210710": "Dwarfism (MOPD1)",
}

trait_renaming = {**complex_trait_renaming, **mendelian_trait_renaming}

# Fail early if any evaluated trait lacks a display name; the relabelling
# step later indexes trait_renaming for every trait.
missing_names = [
    trait for trait in mendelian_traits + complex_traits if trait not in trait_renaming
]
assert not missing_names, f"traits without a display name: {missing_names}"

In [3]:
# Evaluation grid: each dataset is scored per non-coding trait subset, for
# every (modality, model) pair.
datasets = ["mendelian_matched_9", "gwas_matched_9"]

# Subset names are "non_coding_AND_<trait>"; Mendelian datasets use the OMIM
# trait list, GWAS datasets the complex-trait list.
subsets = {
    dataset: [f"non_coding_AND_{trait}" for trait in traits]
    for dataset, traits in zip(datasets, (mendelian_traits, complex_traits))
}

modalities = ["Zero-shot"]

models = ["CADD", "GPN-MSA", "GPN", "Enformer", "Borzoi"]

def get_model_path(model, modality, dataset, subset):
    """Return the metrics CSV path for one (model, modality, dataset, subset) cell.

    Parameters
    ----------
    model : str
        "CADD", "Enformer", "Borzoi" or "Ensemble"; any other name (e.g.
        "GPN-MSA", "GPN") is treated as an LLR-scored language model.
    modality : str
        "Zero-shot" selects the zero-shot score; any other value selects the
        supervised "LogisticRegression.chrom" predictor.
    dataset : str
        Dataset name. Whether it contains "mendelian" (or "gwas") picks the
        LLR variant for language models and the ensemble prefix.
    subset : str
        Subset directory under ``metrics/``.

    Returns
    -------
    str
        ``../../results/dataset/{dataset}/metrics/{subset}/{predictor}.csv``.

    Raises
    ------
    ValueError
        For a language model on a dataset whose name contains neither
        "mendelian" nor "gwas" (no LLR variant is defined for it).
    """
    supervised_suffix = "LogisticRegression.chrom"
    zero_shot = modality == "Zero-shot"
    if model == "CADD":
        predictor = "CADD.plus.RawScore" if zero_shot else f"CADD.{supervised_suffix}"
    elif model in ("Enformer", "Borzoi"):
        predictor = f"{model}_L2_L2.plus.all" if zero_shot else f"{model}.{supervised_suffix}"
    elif model == "Ensemble":
        # Only a supervised ensemble predictor exists; modality is ignored here.
        prefix = "OMIM_Ensemble_v2" if "mendelian" in dataset else "Enformer+GPN-MSA+CADD"
        predictor = f"{prefix}.{supervised_suffix}"
    else:
        # Language models: Mendelian datasets use the signed LLR with the
        # "minus" sign, GWAS datasets the absolute LLR with the "plus" sign.
        if "mendelian" in dataset:
            llr_version, sign = "LLR", "minus"
        elif "gwas" in dataset:
            llr_version, sign = "absLLR", "plus"
        else:
            raise ValueError(
                f"Cannot choose an LLR variant for dataset {dataset!r}: "
                "expected 'mendelian' or 'gwas' in its name."
            )
        if zero_shot:
            predictor = f"{model}_{llr_version}.{sign}.score"
        else:
            predictor = f"{model}_{llr_version}+InnerProducts.{supervised_suffix}"
    return f"../../results/dataset/{dataset}/metrics/{subset}/{predictor}.csv"

In [13]:
# Collect one metrics row per (dataset, subset, modality, model) combination.
# The first row of each metrics CSV is used, with its "score" and "se" columns.
rows = []
for dataset in datasets:
    for subset in subsets[dataset]:
        for modality in modalities:
            for model in models:
                # get_model_path has no zero-shot variant for the ensemble
                # (it always returns the supervised predictor), so skip it.
                if model == "Ensemble" and modality == "Zero-shot":
                    continue
                path = get_model_path(model, modality, dataset, subset)
                metrics = pd.read_csv(path).iloc[0]
                rows.append([dataset, subset, modality, model, metrics["score"], metrics["se"]])
df = pd.DataFrame(rows, columns=["dataset", "subset", "modality", "model", "score", "se"])
df

Unnamed: 0,dataset,subset,modality,model,score,se
0,mendelian_matched_9,non_coding_AND_600886,Zero-shot,CADD,0.956952,0.032780
1,mendelian_matched_9,non_coding_AND_600886,Zero-shot,GPN-MSA,0.964481,0.028176
2,mendelian_matched_9,non_coding_AND_600886,Zero-shot,GPN,0.950000,0.034265
3,mendelian_matched_9,non_coding_AND_600886,Zero-shot,Enformer,0.095493,0.007456
4,mendelian_matched_9,non_coding_AND_600886,Zero-shot,Borzoi,0.131633,0.024743
...,...,...,...,...,...,...
120,gwas_matched_9,non_coding_AND_DVT,Zero-shot,CADD,0.277169,0.118544
121,gwas_matched_9,non_coding_AND_DVT,Zero-shot,GPN-MSA,0.319829,0.105694
122,gwas_matched_9,non_coding_AND_DVT,Zero-shot,GPN,0.092867,0.020430
123,gwas_matched_9,non_coding_AND_DVT,Zero-shot,Enformer,0.319249,0.093419


In [14]:
def format_score(x):
    """Render fractional scores as zero-padded integer percentages.

    Each value of the Series ``x`` is multiplied by 100, rounded, cast to int
    and formatted with ``02d`` (0.08 -> "08", 1.0 -> "100"). Returns a Series
    of strings with the same index.
    """
    as_percent = (x * 100).round().astype(int)
    return as_percent.apply("{:02d}".format)

def format_se(x):
    """Render standard errors as zero-padded two-digit percentages.

    Each value of the Series ``x`` is multiplied by 100, rounded, cast to int
    and formatted with ``02d`` (0.0328 -> "03"). Returns a Series of strings
    with the same index.

    Raises AssertionError if any rounded percentage reaches 100, because the
    table layout assumes the standard-error field has at most two digits.
    """
    percent = (x * 100).round().astype(int)
    # Check after rounding: 0.996 passes a pre-rounding "< 100" test but would
    # still render as the three-digit "100". An empty input has nothing to check.
    assert percent.empty or percent.max() < 100, (
        f"standard error rounds to {percent.max()}%, which does not fit two digits"
    )
    return percent.apply(lambda y: f"{y:02d}")

# Combine score and standard error into one LaTeX cell, e.g. "96$\pm$03".
# A raw string keeps the backslash literal; the plain "$\pm$" relied on an
# invalid escape sequence (SyntaxWarning on Python 3.12+).
df["value"] = format_score(df["score"]) + r"$\pm$" + format_se(df["se"])

In [15]:
#df.loc[(df.model.isin(["Enformer", "Borzoi", "GPN"])) & (df.subset!="non_missense"), "value"] = "-"

In [16]:
# NOTE(review): this whole cell is a bare string literal (ended by `;`), so the
# bolding loop inside it never runs. The subsets it names ("all",
# "missense_variant", "non_missense") and the "Linear probing" modality do not
# occur in `subsets` / `modalities` defined above, so re-enabling it as-is
# would match no rows. TODO: update the keys to the current subsets or delete
# the cell.
"""
bold_values = [
    ("mendelian_matched_9", "all", "Zero-shot", "CADD"),
    ("mendelian_matched_9", "all", "Zero-shot", "GPN-MSA"),
    ("mendelian_matched_9", "missense_variant", "Zero-shot", "CADD"),
    ("mendelian_matched_9", "missense_variant", "Zero-shot", "GPN-MSA"),
    ("mendelian_matched_9", "non_missense", "Zero-shot", "CADD"),
    ("mendelian_matched_9", "non_missense", "Zero-shot", "GPN-MSA"),

    ("mendelian_matched_9", "all", "Linear probing", "CADD"),
    ("mendelian_matched_9", "all", "Linear probing", "Ensemble"),
    ("mendelian_matched_9", "missense_variant", "Linear probing", "CADD"),
    ("mendelian_matched_9", "missense_variant", "Linear probing", "Ensemble"),
    ("mendelian_matched_9", "non_missense", "Linear probing", "CADD"),
    ("mendelian_matched_9", "non_missense", "Linear probing", "Ensemble"),

    ("gwas_matched_9", "all", "Zero-shot", "CADD"),
    ("gwas_matched_9", "all", "Zero-shot", "GPN-MSA"),
    ("gwas_matched_9", "missense_variant", "Zero-shot", "CADD"),
    #("gwas_matched_9", "missense_variant", "Zero-shot", "GPN-MSA"),
    ("gwas_matched_9", "non_missense", "Zero-shot", "Enformer"),
    ("gwas_matched_9", "non_missense", "Zero-shot", "Borzoi"),

    ("gwas_matched_9", "all", "Linear probing", "Ensemble"),
    ("gwas_matched_9", "missense_variant", "Linear probing", "Ensemble"),
    ("gwas_matched_9", "non_missense", "Linear probing", "Ensemble"),
]

for dataset, subset, modality, model in bold_values:
    mask = (
        (df.dataset==dataset) & (df.subset==subset) &
        (df.modality==modality) & (df.model==model)
    )
    df.loc[mask, "value"] = r"\textbf{" + df.loc[mask, "value"] + "}"
""";

In [17]:
# Replace internal identifiers with the labels shown in the LaTeX table.
# `.map` with a dict silently turns any unmapped value into NaN, so the result
# is checked below. The check also catches re-running this cell on labels
# that were already mapped.
df["dataset"] = df["dataset"].map({
    "mendelian_matched_9": r"\textbf{Mendelian traits}",
    "gwas_matched_9": r"\textbf{Complex traits}",
})
df["subset"] = df["subset"].map({
    f"non_coding_AND_{trait}": trait_renaming[trait] for trait in mendelian_traits + complex_traits
})
df["modality"] = df["modality"].map({
    "Zero-shot": r"\textbf{Zero-shot}",
    "Linear probing": r"\textbf{Linear probing}",
})
assert df[["dataset", "subset", "modality"]].notna().to_numpy().all(), (
    "some dataset/subset/modality values have no display label (or the cell was re-run)"
)

In [18]:
# Wide table: one row per trait display name, one column per model.
# aggfunc="first" would silently keep only one entry if two rows shared a
# (subset, model) pair, e.g. if two traits had the same display name, because
# "dataset" is not part of the index. Require uniqueness first.
assert not df.duplicated(["subset", "model"]).any(), (
    "duplicate (subset, model) rows would be collapsed by the pivot"
)
# sort=False keeps rows and columns in order of first appearance.
df = df.pivot_table(
    columns=["model"],
    index=["subset"],
    values="value",
    aggfunc="first",
    sort=False,
)
df

model,CADD,GPN-MSA,GPN,Enformer,Borzoi
subset,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
Hyperferritinemia,96$\pm$03,96$\pm$03,95$\pm$03,10$\pm$01,13$\pm$02
Beta-thalassemia,83$\pm$07,80$\pm$09,57$\pm$07,48$\pm$09,62$\pm$10
Pulmonary fibrosis,94$\pm$04,86$\pm$06,16$\pm$04,45$\pm$09,36$\pm$07
Hemophilia B,89$\pm$06,75$\pm$08,30$\pm$08,95$\pm$04,91$\pm$04
Cartilage-hair hypoplasia,79$\pm$09,84$\pm$06,08$\pm$01,51$\pm$10,36$\pm$10
Preaxial polydactyly II,92$\pm$07,93$\pm$06,10$\pm$03,18$\pm$06,18$\pm$06
Hypercholesterolemia-1,76$\pm$10,79$\pm$10,42$\pm$12,79$\pm$08,80$\pm$08
Dwarfism (MOPD1),89$\pm$09,97$\pm$03,86$\pm$08,68$\pm$12,32$\pm$08
Adult height,36$\pm$05,24$\pm$05,11$\pm$02,20$\pm$03,22$\pm$04
Platelet count,20$\pm$03,17$\pm$03,10$\pm$01,26$\pm$04,24$\pm$04


In [20]:
# Clear the axis names left by the pivot ("subset" on the rows, "model" on the
# columns) before the table is exported.
df.index.name, df.columns.name = None, None

In [21]:
# Emit the raw LaTeX source. escape=False keeps the embedded $\pm$ markup intact.
latex_table = df.to_latex(escape=False, multicolumn_format='c')
print(latex_table)

\begin{tabular}{llllll}
\toprule
 & CADD & GPN-MSA & GPN & Enformer & Borzoi \\
\midrule
Hyperferritinemia & 96$\pm$03 & 96$\pm$03 & 95$\pm$03 & 10$\pm$01 & 13$\pm$02 \\
Beta-thalassemia & 83$\pm$07 & 80$\pm$09 & 57$\pm$07 & 48$\pm$09 & 62$\pm$10 \\
Pulmonary fibrosis & 94$\pm$04 & 86$\pm$06 & 16$\pm$04 & 45$\pm$09 & 36$\pm$07 \\
Hemophilia B & 89$\pm$06 & 75$\pm$08 & 30$\pm$08 & 95$\pm$04 & 91$\pm$04 \\
Cartilage-hair hypoplasia & 79$\pm$09 & 84$\pm$06 & 08$\pm$01 & 51$\pm$10 & 36$\pm$10 \\
Preaxial polydactyly II & 92$\pm$07 & 93$\pm$06 & 10$\pm$03 & 18$\pm$06 & 18$\pm$06 \\
Hypercholesterolemia-1 & 76$\pm$10 & 79$\pm$10 & 42$\pm$12 & 79$\pm$08 & 80$\pm$08 \\
Dwarfism (MOPD1) & 89$\pm$09 & 97$\pm$03 & 86$\pm$08 & 68$\pm$12 & 32$\pm$08 \\
Adult height & 36$\pm$05 & 24$\pm$05 & 11$\pm$02 & 20$\pm$03 & 22$\pm$04 \\
Platelet count & 20$\pm$03 & 17$\pm$03 & 10$\pm$01 & 26$\pm$04 & 24$\pm$04 \\
Mean corpuscular volume & 21$\pm$04 & 16$\pm$02 & 11$\pm$02 & 30$\pm$05 & 28$\pm$04 \\
Estimated