In [1]:
import numpy as np
import pandas as pd

In [3]:
def _flat_values(csv_path, **read_kwargs):
    """Read a header-less CSV and return all of its cells as one flat list.

    Extra keyword arguments are forwarded to ``pd.read_csv``.
    """
    table = pd.read_csv(csv_path, header=None, **read_kwargs)
    return table.values.ravel().tolist()


# Trait identifiers; the Mendelian file is read with dtype=str so its
# values stay strings.
mendelian_traits = _flat_values("../../config/omim/filtered_traits.txt", dtype=str)
complex_traits = _flat_values("../../config/gwas/independent_traits_filtered.csv")
complex_traits_n30 = _flat_values("../../config/gwas/independent_traits_filtered_n30.csv")

# Lookup from the "trait" column to the "description" column of the traits table.
trait_descriptions = pd.read_csv(
    "../../results/gwas/raw/release1.1/UKBB_94traits_release1.traits",
    sep="\t",
    usecols=["trait", "description"],
)
complex_trait_renaming = (
    trait_descriptions.set_index("trait")["description"].to_dict()
)

# Display names for the Mendelian trait identifiers
# (presumably OMIM IDs, given the config path they are read from -- TODO confirm).
mendelian_trait_renaming = {
    "600886": "Hyperferritinemia",
    "613985": "Beta-thalassemia",
    "614743": "Pulmonary fibrosis",
    "306900": "Hemophilia B",
    "250250": "Cartilage-hair hypoplasia",
    "174500": "Preaxial polydactyly II",
    "143890": "Hypercholesterolemia-1",
    "210710": "Dwarfism (MOPD1)",
}

# Combined lookup; Mendelian entries win if a key appears in both mappings.
trait_renaming = dict(complex_trait_renaming)
trait_renaming.update(mendelian_trait_renaming)

# Dataset directory name -> label used in the tables.
dataset_renaming = dict(
    mendelian_matched_9="Mendelian traits",
    gwas_matched_9="Complex traits",
)

# Same display names, keyed by the "non_coding_AND_<trait>" subset name.
subset_renaming = dict(
    (f"non_coding_AND_{trait}", name) for trait, name in trait_renaming.items()
)

In [21]:
def _to_subsets(traits):
    """Prefix every trait with ``non_coding_AND_`` to form its subset name."""
    return [f"non_coding_AND_{trait}" for trait in traits]


# Datasets to tabulate, in display order.
datasets = ["mendelian_matched_9", "gwas_matched_9"]

# Subsets evaluated for each dataset (one per trait).
subsets = {
    "mendelian_matched_9": _to_subsets(mendelian_traits),
    "gwas_matched_9": _to_subsets(complex_traits),
}

# Only these subsets are looked up for the "Linear probing" modality.
linear_probing_subsets = _to_subsets(complex_traits_n30)

modalities = ["Zero-shot", "Linear probing"]

models = ["CADD", "GPN-MSA", "Enformer", "Borzoi"]

def get_model_path(model, modality, dataset, subset, metrics_dir="metrics_by_chrom_weighted_average"):
    """Return the path of the metrics CSV for one model/modality/dataset/subset.

    Parameters
    ----------
    model : str
        "CADD", "Enformer", "Borzoi" and "Ensemble" have dedicated predictor
        names; any other value (e.g. "GPN-MSA") is treated as an LLR-based
        model whose predictor name depends on the dataset.
    modality : str
        "Zero-shot" selects the unsupervised score; any other value selects
        the supervised ``LogisticRegression.chrom`` predictor.
    dataset : str
        Dataset directory name. For "Ensemble" and LLR-based models the
        substrings "mendelian" / "gwas" decide which predictor is used.
    subset : str
        Subset directory name.
    metrics_dir : str, optional
        Name of the metrics directory below the dataset directory
        (e.g. "metrics"). Defaults to "metrics_by_chrom_weighted_average".

    Returns
    -------
    str
        Relative path ``../../results/dataset/<dataset>/<metrics_dir>/<subset>/<predictor>.csv``.

    Raises
    ------
    ValueError
        If an LLR-based model is requested for a dataset whose name contains
        neither "mendelian" nor "gwas".
    """
    zero_shot = modality == "Zero-shot"
    supervised_suffix = "LogisticRegression.chrom"

    if model == "CADD":
        predictor = "CADD.plus.RawScore" if zero_shot else f"CADD.{supervised_suffix}"
    elif model in ("Enformer", "Borzoi"):
        predictor = f"{model}_L2_L2.plus.all" if zero_shot else f"{model}.{supervised_suffix}"
    elif model == "Ensemble":
        # The modality is ignored here: the ensemble predictor is always the supervised one.
        prefix = "OMIM_Ensemble_v2" if "mendelian" in dataset else "Enformer+GPN-MSA+CADD"
        predictor = f"{prefix}.{supervised_suffix}"
    else:
        # LLR-based models: the score variant and its sign depend on the dataset family.
        if "mendelian" in dataset:
            llr_version, sign = "LLR", "minus"
        elif "gwas" in dataset:
            llr_version, sign = "absLLR", "plus"
        else:
            # Previously this fell through and failed with an UnboundLocalError.
            raise ValueError(
                f"Cannot choose an LLR variant for model {model!r}: dataset "
                f"{dataset!r} contains neither 'mendelian' nor 'gwas'"
            )
        if zero_shot:
            predictor = f"{model}_{llr_version}.{sign}.score"
        else:
            predictor = f"{model}_{llr_version}+InnerProducts.{supervised_suffix}"

    return f"../../results/dataset/{dataset}/{metrics_dir}/{subset}/{predictor}.csv"

In [30]:
# Collect one (score, se) record per dataset / subset / modality / model
# by reading the first row of the corresponding metrics CSV.
result_columns = ["dataset", "subset", "modality", "model", "score", "se"]
rows = []
for dataset in datasets:
    dataset_label = dataset_renaming.get(dataset, dataset)
    for subset in subsets[dataset]:
        subset_label = subset_renaming.get(subset, subset)
        has_linear_probing = subset in linear_probing_subsets
        for modality in modalities:
            # Linear probing is only looked up for the subsets listed for it.
            if modality == "Linear probing" and not has_linear_probing:
                continue
            for model in models:
                # The ensemble has no zero-shot entry.
                if (model, modality) == ("Ensemble", "Zero-shot"):
                    continue
                path = get_model_path(model, modality, dataset, subset)
                first_row = pd.read_csv(path).iloc[0]
                rows.append((
                    dataset_label,
                    subset_label,
                    modality,
                    model,
                    first_row["score"],
                    first_row["se"],
                ))
df = pd.DataFrame(rows, columns=result_columns)
df

Unnamed: 0,dataset,subset,modality,model,score,se
0,Mendelian traits,Hyperferritinemia,Zero-shot,CADD,0.956952,0.000000
1,Mendelian traits,Hyperferritinemia,Zero-shot,GPN-MSA,0.964481,0.000000
2,Mendelian traits,Hyperferritinemia,Zero-shot,Enformer,0.095493,0.000000
3,Mendelian traits,Hyperferritinemia,Zero-shot,Borzoi,0.131633,0.000000
4,Mendelian traits,Beta-thalassemia,Zero-shot,CADD,0.825912,0.000000
...,...,...,...,...,...,...
143,Complex traits,Balding Type 4,Zero-shot,Borzoi,0.246771,0.038382
144,Complex traits,Blood clot in the leg,Zero-shot,CADD,0.481151,0.099264
145,Complex traits,Blood clot in the leg,Zero-shot,GPN-MSA,0.550735,0.110382
146,Complex traits,Blood clot in the leg,Zero-shot,Enformer,0.454497,0.084802


In [31]:
def format_score(x):
    """Format each value of ``x`` as a rounded integer percentage string.

    Values are multiplied by 100, rounded, cast to int and zero-padded to at
    least two digits (e.g. 0.04 -> "04", 0.956 -> "96").
    """
    percent = (x * 100).round().astype(int)
    two_digits = "{:02d}".format
    return percent.apply(two_digits)

def format_se(x):
    """Format each standard error in ``x`` as a two-digit integer percentage string.

    Asserts that the largest value, once multiplied by 100, is below 100.
    """
    scaled = x * 100
    assert scaled.max() < 100
    two_digits = "{:02d}".format
    return scaled.round().astype(int).apply(two_digits)

# Alternative renderings of the table cell, currently disabled:
#   df["value"] = format_score(df.score) + r"$\pm$" + format_se(df.se)
#   df["value"] = format_score(df.score)
#   df["value"] = df.score.apply(lambda x: f"{x:.2f}") + r"$\pm$" + df.se.apply(lambda x: f"{x:.2f}")
# Active rendering: the score with three decimals.
three_decimals = "{:.3f}".format
df["value"] = df["score"].apply(three_decimals)
df

Unnamed: 0,dataset,subset,modality,model,score,se,value
0,Mendelian traits,Hyperferritinemia,Zero-shot,CADD,0.956952,0.000000,0.957
1,Mendelian traits,Hyperferritinemia,Zero-shot,GPN-MSA,0.964481,0.000000,0.964
2,Mendelian traits,Hyperferritinemia,Zero-shot,Enformer,0.095493,0.000000,0.095
3,Mendelian traits,Hyperferritinemia,Zero-shot,Borzoi,0.131633,0.000000,0.132
4,Mendelian traits,Beta-thalassemia,Zero-shot,CADD,0.825912,0.000000,0.826
...,...,...,...,...,...,...,...
143,Complex traits,Balding Type 4,Zero-shot,Borzoi,0.246771,0.038382,0.247
144,Complex traits,Blood clot in the leg,Zero-shot,CADD,0.481151,0.099264,0.481
145,Complex traits,Blood clot in the leg,Zero-shot,GPN-MSA,0.550735,0.110382,0.551
146,Complex traits,Blood clot in the leg,Zero-shot,Enformer,0.454497,0.084802,0.454


In [32]:
#df.loc[(df.model.isin(["Enformer", "Borzoi", "GPN"])) & (df.subset!="non_missense"), "value"] = "-"

In [33]:
"""
bold_values = [
    ("mendelian_matched_9", "all", "Zero-shot", "CADD"),
    ("mendelian_matched_9", "all", "Zero-shot", "GPN-MSA"),
    ("mendelian_matched_9", "missense_variant", "Zero-shot", "CADD"),
    ("mendelian_matched_9", "missense_variant", "Zero-shot", "GPN-MSA"),
    ("mendelian_matched_9", "non_missense", "Zero-shot", "CADD"),
    ("mendelian_matched_9", "non_missense", "Zero-shot", "GPN-MSA"),

    ("mendelian_matched_9", "all", "Linear probing", "CADD"),
    ("mendelian_matched_9", "all", "Linear probing", "Ensemble"),
    ("mendelian_matched_9", "missense_variant", "Linear probing", "CADD"),
    ("mendelian_matched_9", "missense_variant", "Linear probing", "Ensemble"),
    ("mendelian_matched_9", "non_missense", "Linear probing", "CADD"),
    ("mendelian_matched_9", "non_missense", "Linear probing", "Ensemble"),

    ("gwas_matched_9", "all", "Zero-shot", "CADD"),
    ("gwas_matched_9", "all", "Zero-shot", "GPN-MSA"),
    ("gwas_matched_9", "missense_variant", "Zero-shot", "CADD"),
    #("gwas_matched_9", "missense_variant", "Zero-shot", "GPN-MSA"),
    ("gwas_matched_9", "non_missense", "Zero-shot", "Enformer"),
    ("gwas_matched_9", "non_missense", "Zero-shot", "Borzoi"),

    ("gwas_matched_9", "all", "Linear probing", "Ensemble"),
    ("gwas_matched_9", "missense_variant", "Linear probing", "Ensemble"),
    ("gwas_matched_9", "non_missense", "Linear probing", "Ensemble"),
]

for dataset, subset, modality, model in bold_values:
    mask = (
        (df.dataset==dataset) & (df.subset==subset) &
        (df.modality==modality) & (df.model==model)
    )
    df.loc[mask, "value"] = r"\textbf{" + df.loc[mask, "value"] + "}"
""";

In [34]:
#df.modality = df.modality.map({
#    "Zero-shot": r"\textbf{Zero-shot}",
#    "Linear probing": r"\textbf{Linear probing}",
#})

In [35]:
# Reshape to one row per subset and one column per (modality, model).
# "dataset" is not part of the index, so rows are keyed by subset name only.
# Combinations without a value are shown as "-".
df = (
    df.pivot_table(
        index=["subset"],
        columns=["modality", "model"],
        values="value",
        aggfunc="first",
        sort=False,
    )
    .fillna("-")
)
df

modality,Zero-shot,Zero-shot,Zero-shot,Zero-shot,Linear probing,Linear probing,Linear probing,Linear probing
model,CADD,GPN-MSA,Enformer,Borzoi,CADD,GPN-MSA,Enformer,Borzoi
subset,Unnamed: 1_level_2,Unnamed: 2_level_2,Unnamed: 3_level_2,Unnamed: 4_level_2,Unnamed: 5_level_2,Unnamed: 6_level_2,Unnamed: 7_level_2,Unnamed: 8_level_2
Hyperferritinemia,0.957,0.964,0.095,0.132,-,-,-,-
Beta-thalassemia,0.826,0.803,0.476,0.625,-,-,-,-
Pulmonary fibrosis,0.944,0.863,0.451,0.356,-,-,-,-
Hemophilia B,0.895,0.753,0.948,0.905,-,-,-,-
Cartilage-hair hypoplasia,0.786,0.844,0.509,0.363,-,-,-,-
Preaxial polydactyly II,0.916,0.929,0.177,0.185,-,-,-,-
Hypercholesterolemia-1,0.765,0.791,0.79,0.801,-,-,-,-
Dwarfism (MOPD1),0.893,0.967,0.683,0.319,-,-,-,-
Adult height,0.407,0.313,0.267,0.292,0.383,0.330,0.294,0.281
Platelet count,0.281,0.234,0.352,0.355,0.315,0.267,0.385,0.372


In [36]:
# Clear the axis labels ("subset" on the rows, "modality"/"model" on the
# two column levels).
df = df.rename_axis(index=None, columns=[None, None])

In [37]:
# Emit the table as LaTeX; escape=False leaves the cell text unescaped,
# and multi-column headers are centred.
latex_table = df.to_latex(escape=False, multicolumn_format="c")
print(latex_table)

\begin{tabular}{lllllllll}
\toprule
 & \multicolumn{4}{c}{Zero-shot} & \multicolumn{4}{c}{Linear probing} \\
 & CADD & GPN-MSA & Enformer & Borzoi & CADD & GPN-MSA & Enformer & Borzoi \\
\midrule
Hyperferritinemia & 0.957 & 0.964 & 0.095 & 0.132 & - & - & - & - \\
Beta-thalassemia & 0.826 & 0.803 & 0.476 & 0.625 & - & - & - & - \\
Pulmonary fibrosis & 0.944 & 0.863 & 0.451 & 0.356 & - & - & - & - \\
Hemophilia B & 0.895 & 0.753 & 0.948 & 0.905 & - & - & - & - \\
Cartilage-hair hypoplasia & 0.786 & 0.844 & 0.509 & 0.363 & - & - & - & - \\
Preaxial polydactyly II & 0.916 & 0.929 & 0.177 & 0.185 & - & - & - & - \\
Hypercholesterolemia-1 & 0.765 & 0.791 & 0.790 & 0.801 & - & - & - & - \\
Dwarfism (MOPD1) & 0.893 & 0.967 & 0.683 & 0.319 & - & - & - & - \\
Adult height & 0.407 & 0.313 & 0.267 & 0.292 & 0.383 & 0.330 & 0.294 & 0.281 \\
Platelet count & 0.281 & 0.234 & 0.352 & 0.355 & 0.315 & 0.267 & 0.385 & 0.372 \\
Mean corpuscular volume & 0.342 & 0.288 & 0.340 & 0.383 & 0.299 & 0.316 & 0.4