In [15]:
import numpy as np
import pandas as pd

In [16]:
def _read_trait_list(path):
    """Read a headerless trait file and return all of its values as a flat list.

    Every trait file is read with ``dtype=str`` so numeric-looking trait ids are
    kept verbatim (no lost leading zeros, no int -> float upcasting such as
    "600886.0") when they are embedded in subset names like
    ``non_coding_AND_{trait}`` and matched against the string keys of the
    renaming dictionaries below.
    """
    return pd.read_csv(path, header=None, dtype=str).values.ravel().tolist()


mendelian_traits = _read_trait_list("../../config/omim/filtered_traits.txt")
complex_traits = _read_trait_list("../../config/gwas/independent_traits_filtered.csv")
# Subset of complex traits that also get linear-probing results.
complex_traits_n30 = _read_trait_list("../../config/gwas/independent_traits_filtered_n30.csv")

# Complex trait id -> description, from the UKBB trait table.
complex_trait_renaming = pd.read_csv(
    "../../results/gwas/raw/release1.1/UKBB_94traits_release1.traits", sep="\t",
    usecols=["trait", "description"], dtype=str,
).set_index("trait")["description"].to_dict()

# Display names for the Mendelian trait ids.
mendelian_trait_renaming = {
    "600886": "Hyperferritinemia",
    "613985": "Beta-thalassemia",
    "614743": "Pulmonary fibrosis",
    "306900": "Hemophilia B",
    "250250": "Cartilage-hair hypoplasia",
    "174500": "Preaxial polydactyly II",
    "143890": "Hypercholesterolemia-1",
    "210710": "Dwarfism (MOPD1)",
}

# Mendelian names take precedence if an id appears in both mappings.
trait_renaming = {**complex_trait_renaming, **mendelian_trait_renaming}

dataset_renaming = {
    "mendelian_matched_9": "Mendelian traits",
    "gwas_matched_9": "Complex traits",
}

# Subset id (as used in the results directory layout) -> display name.
subset_renaming = {
    f"non_coding_AND_{trait}": name for trait, name in trait_renaming.items()
}

In [17]:
# Datasets evaluated, in display order.
datasets = ["mendelian_matched_9", "gwas_matched_9"]

# Per-dataset evaluation subsets: the non-coding variants of each trait.
subsets = {
    dataset_name: [f"non_coding_AND_{trait}" for trait in trait_list]
    for dataset_name, trait_list in (
        ("mendelian_matched_9", mendelian_traits),
        ("gwas_matched_9", complex_traits),
    )
}

# Only these subsets are reported for the "Linear probing" modality.
linear_probing_subsets = list(map("non_coding_AND_{}".format, complex_traits_n30))

modalities = ["Zero-shot", "Linear probing"]

models = ["Borzoi", "GPN-MSA", "CADD"]

def get_model_path(model, modality, dataset, subset):
    """Return the metrics CSV path for one (model, modality, dataset, subset).

    The predictor part of the file name is chosen as follows:

    * ``"Linear probing"``: ``{model}.LogisticRegression.chrom``.
    * ``"Zero-shot"``:
        - ``CADD`` -> ``CADD.plus.RawScore``;
        - ``Enformer`` / ``Borzoi`` -> ``{model}_L2_L2.plus.all``;
        - any other model -> ``{model}_LLR.minus.score`` when ``dataset``
          contains "mendelian", else ``{model}_absLLR.plus.score`` when it
          contains "gwas".

    Args:
        model: Model name, e.g. "Borzoi", "GPN-MSA", "CADD".
        modality: "Zero-shot" or "Linear probing".
        dataset: Dataset directory name, e.g. "gwas_matched_9".
        subset: Subset directory name, e.g. "non_coding_AND_<trait>".

    Returns:
        Relative path under ``../../results/dataset/``.

    Raises:
        ValueError: if ``modality`` is not recognised, or if an LLR-scored
            zero-shot model is requested for a dataset whose name contains
            neither "mendelian" nor "gwas".
    """
    if modality == "Linear probing":
        predictor = f"{model}.LogisticRegression.chrom"
    elif modality == "Zero-shot":
        if model == "CADD":
            predictor = "CADD.plus.RawScore"
        elif model in ("Enformer", "Borzoi"):
            predictor = f"{model}_L2_L2.plus.all"
        elif "mendelian" in dataset:
            # "mendelian" is checked first, matching the original precedence.
            predictor = f"{model}_LLR.minus.score"
        elif "gwas" in dataset:
            predictor = f"{model}_absLLR.plus.score"
        else:
            # Without this branch the LLR variables would be unbound.
            raise ValueError(
                f"Cannot choose an LLR predictor for dataset {dataset!r}: "
                "expected 'mendelian' or 'gwas' in its name"
            )
    else:
        raise ValueError(f"Unknown modality: {modality!r}")
    return f"../../results/dataset/{dataset}/metrics_by_chrom_weighted_average/{subset}/{predictor}.csv"

In [18]:
# Gather one record per (dataset, subset, modality, model) from the per-model
# metrics CSVs; only the first row's "score" and "se" columns are used.
rows = []
for dataset in datasets:
    for subset in subsets[dataset]:
        for modality in modalities:
            # Linear probing is only reported for the subsets in
            # linear_probing_subsets (built from complex_traits_n30).
            if modality == "Linear probing" and subset not in linear_probing_subsets:
                continue
            for model in models:
                path = get_model_path(model, modality, dataset, subset)
                metrics_table = pd.read_csv(path)
                if metrics_table.empty:
                    # Otherwise .iloc[0] fails with an IndexError that hides the file.
                    raise ValueError(f"No metric rows in {path}")
                # Kept separate from `df`, which holds the combined results below.
                metrics = metrics_table.iloc[0]
                rows.append({
                    "dataset": dataset_renaming.get(dataset, dataset),
                    "subset": subset_renaming.get(subset, subset),
                    "modality": modality,
                    "model": model,
                    "score": metrics["score"],
                    "se": metrics["se"],
                })
df = pd.DataFrame(rows, columns=["dataset", "subset", "modality", "model", "score", "se"])
df

Unnamed: 0,dataset,subset,modality,model,score,se
0,Mendelian traits,Hyperferritinemia,Zero-shot,Borzoi,0.131633,2.776946e-17
1,Mendelian traits,Hyperferritinemia,Zero-shot,GPN-MSA,0.964481,0.000000e+00
2,Mendelian traits,Hyperferritinemia,Zero-shot,CADD,0.956952,2.221557e-16
3,Mendelian traits,Beta-thalassemia,Zero-shot,Borzoi,0.624800,1.110779e-16
4,Mendelian traits,Beta-thalassemia,Zero-shot,GPN-MSA,0.802821,0.000000e+00
...,...,...,...,...,...,...
106,Complex traits,Balding Type 4,Zero-shot,GPN-MSA,0.375868,8.847811e-02
107,Complex traits,Balding Type 4,Zero-shot,CADD,0.388687,1.033617e-01
108,Complex traits,Blood clot in the leg,Zero-shot,Borzoi,0.497024,8.991410e-02
109,Complex traits,Blood clot in the leg,Zero-shot,GPN-MSA,0.550735,1.103819e-01


In [19]:
def _percent_strings(x):
    """Scale a numeric Series by 100, round, and render each value zero-padded to 2 digits.

    Values whose rounded scaled value is >= 100 render with three digits
    (e.g. 1.0 -> "100"). Rounding follows pandas ``Series.round`` (half to even).
    ``astype(int)`` fails if ``x`` contains NaN.
    """
    return (x * 100).round().astype(int).apply(lambda y: f"{y:02d}")


def format_score(x):
    """Format scores in [0, 1] as two-digit percentages, e.g. 0.123 -> "12"."""
    return _percent_strings(x)


def format_se(x):
    """Format standard errors like ``format_score``, requiring each to fit in two digits.

    Raises:
        AssertionError: if any value rounds to 100 or more after scaling by 100.
            The check is done on the *rounded* value, so e.g. 0.996 is rejected
            instead of being rendered as the three-digit "100".
    """
    too_large = (x * 100).round() >= 100
    # any() is False on an empty Series, so empty input formats to an empty result.
    assert not too_large.any(), (
        f"standard errors too large for two-digit formatting: {x[too_large].tolist()}"
    )
    return _percent_strings(x)

# Table cell text: each score rounded to three decimals.
# Earlier alternatives, kept for reference:
#   format_score(df.score) + "$\pm$" + format_se(df.se)
#   format_score(df.score)
#   df.score.apply(lambda x: f"{x:.2f}") + "$\pm$" + df.se.apply(lambda x: f"{x:.2f}")
df["value"] = df["score"].map("{:.3f}".format)
df

Unnamed: 0,dataset,subset,modality,model,score,se,value
0,Mendelian traits,Hyperferritinemia,Zero-shot,Borzoi,0.131633,2.776946e-17,0.132
1,Mendelian traits,Hyperferritinemia,Zero-shot,GPN-MSA,0.964481,0.000000e+00,0.964
2,Mendelian traits,Hyperferritinemia,Zero-shot,CADD,0.956952,2.221557e-16,0.957
3,Mendelian traits,Beta-thalassemia,Zero-shot,Borzoi,0.624800,1.110779e-16,0.625
4,Mendelian traits,Beta-thalassemia,Zero-shot,GPN-MSA,0.802821,0.000000e+00,0.803
...,...,...,...,...,...,...,...
106,Complex traits,Balding Type 4,Zero-shot,GPN-MSA,0.375868,8.847811e-02,0.376
107,Complex traits,Balding Type 4,Zero-shot,CADD,0.388687,1.033617e-01,0.389
108,Complex traits,Blood clot in the leg,Zero-shot,Borzoi,0.497024,8.991410e-02,0.497
109,Complex traits,Blood clot in the leg,Zero-shot,GPN-MSA,0.550735,1.103819e-01,0.551


In [20]:
#df.modality = df.modality.map({
#    "Zero-shot": r"\textbf{Zero-shot}",
#    "Linear probing": r"\textbf{Linear probing}",
#})

In [21]:
# "dataset" is not part of the pivot index, so each (subset, modality, model)
# must be unique across datasets. Check first: aggfunc="first" would otherwise
# silently keep one value and discard the rest.
assert not df.duplicated(["subset", "modality", "model"]).any(), (
    "duplicate (subset, modality, model) rows; pivot_table(aggfunc='first') "
    "would silently drop all but one"
)
# Wide table: one row per trait, one column per (modality, model).
df = df.pivot_table(
    columns=["modality", "model"],
    index=["subset"],
    values="value",
    aggfunc="first",
    sort=False,
)
# Combinations without results (e.g. linear probing on Mendelian traits) show as "-".
df = df.fillna("-")
df

modality,Zero-shot,Zero-shot,Zero-shot,Linear probing,Linear probing,Linear probing
model,Borzoi,GPN-MSA,CADD,Borzoi,GPN-MSA,CADD
subset,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2
Hyperferritinemia,0.132,0.964,0.957,-,-,-
Beta-thalassemia,0.625,0.803,0.826,-,-,-
Pulmonary fibrosis,0.356,0.863,0.944,-,-,-
Hemophilia B,0.905,0.753,0.895,-,-,-
Cartilage-hair hypoplasia,0.361,0.833,0.778,-,-,-
Preaxial polydactyly II,0.185,0.929,0.916,-,-,-
Hypercholesterolemia-1,0.799,0.789,0.766,-,-,-
Dwarfism (MOPD1),0.319,0.967,0.893,-,-,-
Adult height,0.292,0.313,0.407,0.281,0.341,0.383
Platelet count,0.355,0.234,0.281,0.372,0.268,0.315


In [22]:
# Drop the "modality"/"model" column-level names and the "subset" index name
# so they do not appear as header labels in the exported table.
df.columns.names = [None] * 2
df.index.name = None

In [24]:
# Print the LaTeX source for the results table.
print(
    df.to_latex(
        escape=False,            # do not escape LaTeX special characters in cells/headers
        multicolumn_format="c",  # center each modality header over its model columns
    )
)

\begin{tabular}{lllllll}
\toprule
 & \multicolumn{3}{c}{Zero-shot} & \multicolumn{3}{c}{Linear probing} \\
 & Borzoi & GPN-MSA & CADD & Borzoi & GPN-MSA & CADD \\
\midrule
Hyperferritinemia & 0.132 & 0.964 & 0.957 & - & - & - \\
Beta-thalassemia & 0.625 & 0.803 & 0.826 & - & - & - \\
Pulmonary fibrosis & 0.356 & 0.863 & 0.944 & - & - & - \\
Hemophilia B & 0.905 & 0.753 & 0.895 & - & - & - \\
Cartilage-hair hypoplasia & 0.361 & 0.833 & 0.778 & - & - & - \\
Preaxial polydactyly II & 0.185 & 0.929 & 0.916 & - & - & - \\
Hypercholesterolemia-1 & 0.799 & 0.789 & 0.766 & - & - & - \\
Dwarfism (MOPD1) & 0.319 & 0.967 & 0.893 & - & - & - \\
Adult height & 0.292 & 0.313 & 0.407 & 0.281 & 0.341 & 0.383 \\
Platelet count & 0.355 & 0.234 & 0.281 & 0.372 & 0.268 & 0.315 \\
Mean corpuscular volume & 0.383 & 0.288 & 0.342 & 0.448 & 0.306 & 0.299 \\
Estimated heel bone mineral density & 0.292 & 0.378 & 0.418 & 0.412 & 0.394 & 0.400 \\
Monocyte count & 0.468 & 0.358 & 0.375 & 0.611 & 0.312 & 0.320 \\
H