## Summary

---

## Imports

In [None]:
import functools
from pathlib import Path

import elaspic2 as el2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from scipy import stats
from sklearn import metrics, model_selection
from tqdm.auto import tqdm

In [None]:
# Use fully-qualified option names: the abbreviated "max_columns" / "max_rows" are resolved by
# pattern matching and become ambiguous (OptionError) in pandas versions that also define
# "styler.render.max_columns" / "styler.render.max_rows".
pd.set_option("display.max_columns", 1000)
pd.set_option("display.max_rows", 1000)

## Parameters

In [None]:
# Output directory of this notebook, resolved against the current working directory
# and created on first run.
NOTEBOOK_DIR = Path("37_cagi6_sherloc_combine_results").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

## Load results

In [None]:
# Name of the dataset sub-directory inside each upstream result directory.
DATASET_NAME = "cagi6-sherloc"
# Number of result shards expected per upstream step (shards are numbered 1..TASK_COUNT).
TASK_COUNT = 4182

DATASET_NAME, TASK_COUNT

In [None]:
def get_result_files(result_dir):
    """Split the expected shard files of ``result_dir`` into present and missing ones.

    Shard files are named ``{prefix}-{i}-of-{TASK_COUNT}.parquet`` for ``i`` in
    ``1..TASK_COUNT``. The prefix is ``"result"`` when the directory path contains
    ``"msa_analysis"`` and ``"shard"`` otherwise.

    Args:
        result_dir: Directory (a ``pathlib.Path``) that should contain the shard files.

    Returns:
        A ``(present_files, missing_files)`` tuple of path lists, each in shard order.
    """
    prefix = "result" if "msa_analysis" in str(result_dir) else "shard"

    present_files, missing_files = [], []
    for shard_idx in tqdm(range(1, TASK_COUNT + 1)):
        candidate = result_dir.joinpath(f"{prefix}-{shard_idx}-of-{TASK_COUNT}.parquet")
        # Route the path into whichever list matches its on-disk status.
        bucket = present_files if candidate.is_file() else missing_files
        bucket.append(candidate)
    return present_files, missing_files

In [None]:
def read_files(files):
    """Read every parquet file in ``files`` and stack the rows into one DataFrame.

    Integer columns containing nulls are loaded as object columns
    (``integer_object_nulls=True``) instead of being cast to float.

    Args:
        files: Iterable of parquet file paths.

    Returns:
        A single DataFrame with a fresh ``RangeIndex``.
    """
    frames = [
        pq.read_table(path).to_pandas(integer_object_nulls=True) for path in tqdm(files)
    ]
    return pd.concat(frames, ignore_index=True)

### ProteinSolver

In [None]:
# ProteinSolver results live next to this notebook's directory, one sub-directory per dataset.
ps_result_dir = NOTEBOOK_DIR.parent / "31_run_proteinsolver" / DATASET_NAME

In [None]:
present_files, missing_files = get_result_files(ps_result_dir)

# Every ProteinSolver shard must exist before the results are combined.
assert not missing_files
(len(present_files), len(missing_files))

In [None]:
# Concatenate all ProteinSolver shards into a single DataFrame.
result_ps_df = read_files(present_files)

In [None]:
# Preview the first two rows and report the row count; the trailing number is the count
# noted by the author.
display(result_ps_df.head(2))
print(len(result_ps_df))  # 221816

### ProtBert

In [None]:
# ProtBert results live next to this notebook's directory, one sub-directory per dataset.
pb_result_dir = NOTEBOOK_DIR.parent / "31_run_protbert" / DATASET_NAME

In [None]:
present_files, missing_files = get_result_files(pb_result_dir)

# Every ProtBert shard must exist before the results are combined.
assert not missing_files
(len(present_files), len(missing_files))

In [None]:
# Concatenate all ProtBert shards into a single DataFrame.
result_pb_df = read_files(present_files)

In [None]:
# Preview the first two rows and report the row count; the trailing number is the count
# noted by the author.
display(result_pb_df.head(2))
print(len(result_pb_df))  # 221793

### MSA

In [None]:
# MSA-analysis results; this directory name selects the "result-" file prefix
# in get_result_files.
msa_result_dir = NOTEBOOK_DIR.parent / "31_run_msa_analysis" / DATASET_NAME

In [None]:
present_files, missing_files = get_result_files(msa_result_dir)

# Every MSA shard must exist before the results are combined.
assert not missing_files
(len(present_files), len(missing_files))

In [None]:
# Concatenate all MSA-analysis shards into a single DataFrame.
result_msa_df = read_files(present_files)

In [None]:
# Preview the first two rows and report the row count; the trailing number is the count
# noted by the author.
display(result_msa_df.head(2))
print(len(result_msa_df))  # 221793

### Rosetta

In [None]:
# Rosetta ddG results live next to this notebook's directory, one sub-directory per dataset.
ra_result_dir = NOTEBOOK_DIR.parent / "31_run_rosetta_ddg" / DATASET_NAME

In [None]:
present_files, missing_files = get_result_files(ra_result_dir)

# Every Rosetta shard must exist before the results are combined.
assert not missing_files
(len(present_files), len(missing_files))

In [None]:
# Concatenate all Rosetta shards into a single DataFrame.
result_ra_df = read_files(present_files)

In [None]:
# Preview the first two rows and report the row count; the trailing number is the count
# noted by the author.
display(result_ra_df.head(2))
print(len(result_ra_df))  # 216899

### AlphaFold

In [None]:
# AlphaFold results live next to this notebook's directory, one sub-directory per dataset.
af_result_dir = NOTEBOOK_DIR.parent / "31_run_alphafold" / DATASET_NAME

In [None]:
# Unlike the other result sets, no completeness assertion is made here: missing shards are
# tolerated and only counted. NOTE(review): presumably because mutations without AlphaFold
# results are backfilled later from wild-type predictions — confirm.
present_files, missing_files = get_result_files(af_result_dir)

len(present_files), len(missing_files)

In [None]:
# Concatenate the AlphaFold shards that are present into a single DataFrame.
result_af_df = read_files(present_files)

In [None]:
# Preview the first two rows and report the row count; the trailing number is the count
# noted by the author.
display(result_af_df.head(2))
print(len(result_af_df))  # 110554

### AlphaFold WT

In [None]:
# Wild-type-only AlphaFold results live next to this notebook's directory.
afwt_result_dir = NOTEBOOK_DIR.parent / "31_run_alphafold_wt" / DATASET_NAME

In [None]:
# No completeness assertion here either: missing wild-type shards are only counted.
present_files, missing_files = get_result_files(afwt_result_dir)

len(present_files), len(missing_files)

In [None]:
# Index by protein_id so that a protein's wild-type predictions can be looked up with `.loc`.
result_afwt_df = read_files(present_files).set_index("protein_id")

In [None]:
# Preview the first two rows and report the row count.
display(result_afwt_df.head(2))
print(len(result_afwt_df))

In [None]:
# NOTE: despite the "_protein_ids" names, both sets hold
# (protein_id, mutation, mutation_id, effect) tuples, i.e. one entry per mutation.

# Mutations that already have AlphaFold results.
af_finished_protein_ids = set(
    result_af_df[["protein_id", "mutation", "mutation_id", "effect"]].apply(tuple, axis=1)
)

# Mutations present in the ProtBert results that have no AlphaFold result.
af_missing_protein_ids = (
    set(result_pb_df[["protein_id", "mutation", "mutation_id", "effect"]].apply(tuple, axis=1))
    - af_finished_protein_ids
)

len(af_missing_protein_ids)

In [None]:
def get_mutation_embeddings(idx, predictions):
    """Extract per-residue and per-protein scores/features from one set of predictions.

    Args:
        idx: Non-negative position used to index the first axis of the per-residue arrays.
        predictions: Mapping with keys ``plddt``, ``max_predicted_aligned_error``, ``ptm``,
            ``experimentally_resolved``, ``predicted_lddt``, ``msa_first_row``, ``single``
            and ``structure_module``. The embedding entries must support ``x[idx]``,
            ``x.mean(axis=0)`` and ``.astype``.

    Returns:
        Dict with four ``scores_*`` entries followed by ``features_residue_*`` (row ``idx``
        as float32) and ``features_protein_*`` (mean over axis 0 as float32) entries for
        each embedding.

    Raises:
        AssertionError: If ``idx`` is negative.
    """
    assert idx >= 0

    embedding_names = (
        "experimentally_resolved",
        "predicted_lddt",
        "msa_first_row",
        "single",
        "structure_module",
    )
    embeddings = {name: predictions[name] for name in embedding_names}

    # NOTE: the "scores_proten_ptm" key (sic) is kept as-is because it becomes a column name.
    output = {
        "scores_residue_plddt": predictions["plddt"][idx],
        "scores_protein_plddt": np.mean(predictions["plddt"]),
        "scores_protein_max_predicted_aligned_error": predictions["max_predicted_aligned_error"],
        "scores_proten_ptm": predictions["ptm"],
    }
    # Features of the residue at position `idx`.
    for name, array in embeddings.items():
        output[f"features_residue_{name}"] = array[idx].astype(np.float32)
    # Features averaged over the first axis (all residues).
    for name, array in embeddings.items():
        output[f"features_protein_{name}"] = array.mean(axis=0).astype(np.float32)

    return output


# get_mutation_embeddings(0, result_afwt_df.loc["Q8N5M1"])

In [None]:
# Backfill wild-type AlphaFold features for every mutation that has no AlphaFold result.
# `missing` counts mutations whose protein has no wild-type prediction at all.
results = []
missing = 0
proteins_with_multiple_rows = set()
for protein_id, mutation, mutation_id, effect in tqdm(af_missing_protein_ids):
    try:
        predictions = result_afwt_df.loc[protein_id]
    except KeyError:
        # No wild-type prediction for this protein: count it and skip the mutation.
        missing += 1
        continue

    # `.loc` returns a DataFrame (not a Series) when the protein_id occurs more than once
    # in the index; in that case remember the protein and keep only the first row.
    if isinstance(predictions, pd.DataFrame):
        #         print(f"Multiple rows encountered for {protein_id=} ({len(predictions)=}).")
        proteins_with_multiple_rows.add(protein_id)
        predictions = predictions.iloc[0]

    # The mutation string is stripped of its first and last character and the remainder is
    # parsed as an integer, then shifted to a 0-based index.
    # NOTE(review): assumes a format like "A123G" with a 1-based position — confirm.
    features = {
        f"alphafold_core_{key}_wt": value
        for key, value in get_mutation_embeddings(int(mutation[1:-1]) - 1, predictions).items()
    }

    # Dict union with `|` requires Python 3.9+.
    results.append(
        {
            "protein_id": protein_id,
            "mutation": mutation,
            "mutation_id": mutation_id,
            "effect": effect,
        }
        | features
    )

result_af2_df = pd.DataFrame(results)
len(result_af2_df), missing

## Combine results

In [None]:
# True when every row of the ProteinSolver results has a distinct mutation_id.
result_ps_df["mutation_id"].nunique(dropna=False) == len(result_ps_df)

In [None]:
# Left-join every result set onto the ProteinSolver rows, so the combined frame keeps
# exactly the ProteinSolver mutations. The MSA results are joined without the "effect" key.
# AlphaFold rows are the full results stacked with the wild-type-only backfill.
result_df = (
    result_ps_df.merge(
        result_pb_df, on=["protein_id", "mutation", "mutation_id", "effect"], how="left"
    )
    .merge(result_msa_df, on=["protein_id", "mutation", "mutation_id"], how="left")
    .merge(result_ra_df, on=["protein_id", "mutation", "mutation_id", "effect"], how="left")
    .merge(
        pd.concat([result_af_df, result_af2_df], ignore_index=True),
        on=["protein_id", "mutation", "mutation_id", "effect"],
        how="left",
    )
)

# The joins must not have duplicated any row ...
assert len(result_df) == len(result_df["mutation_id"].unique())
# ... and every mutation_id seen in the ProtBert, Rosetta or AlphaFold results must be
# present in the combined frame.
assert not (
    set(result_pb_df["mutation_id"])
    | set(result_ra_df["mutation_id"])
    | set(result_af_df["mutation_id"])
) - set(result_df["mutation_id"])

In [None]:
# Preview the first two rows of the combined frame and report its row count.
display(result_df.head(2))
print(len(result_df))

## Calculate EL2 score

In [None]:
# Instantiate the ELASPIC2 model with its default arguments.
# NOTE(review): construction may load pretrained weights — confirm cost and requirements.
model = el2.ELASPIC2()

In [None]:
# Input features for ELASPIC2: the ProteinSolver score columns and all ProtBert core columns.
# (The "columnms" spelling is kept because the name is reused further down.)
proteinsolver_columnms = [c for c in result_df if c.startswith("proteinsolver_core_score")]
protbert_columns = [c for c in result_df if c.startswith("protbert_core_")]

# Rows with any missing input feature cannot be scored and keep a NaN el2_score.
el2_missing = result_df[proteinsolver_columnms + protbert_columns].isnull().any(axis=1)
result_df["el2_score"] = np.nan
# Each complete row is passed to the model as a dict keyed by column name.
result_df.loc[~el2_missing, "el2_score"] = model.predict_mutation_effect(
    [
        t._asdict()
        for t in result_df.loc[~el2_missing, proteinsolver_columnms + protbert_columns].itertuples(
            index=False
        )
    ]
)

## Calculate deltas

In [None]:
# Replace each "<name>_mut" column that has a "<name>_wt" partner with "<name>_change",
# computed as wild-type minus mutant. "_mut" columns without a partner are left in place and
# the name of the missing "_wt" column is printed.
for column in list(result_df):
    if column.endswith("_mut"):
        stem = column.removesuffix("_mut")
        column_wt = f"{stem}_wt"
        if column_wt in result_df:
            column_change = f"{stem}_change"
            result_df[column_change] = result_df[column_wt] - result_df[column]
            del result_df[column]
        else:
            print(column_wt)

## Encode mutation

In [None]:
# The 20 standard amino acids (one-letter codes).
amino_acids = list("ARNDCEQGHILKMFPSTWYV")

# Pin the one-hot layout to all 20 residues. Without explicit categories `pd.get_dummies`
# only creates columns for residues that occur in the data, so the vector length and the
# meaning of each position would depend on the dataset and could differ between the wt and
# mut encodings. The categories are sorted alphabetically because that is the column order
# `pd.get_dummies` produces on its own when every residue is present, so the layout is
# unchanged in that case. Characters outside the alphabet are encoded as an all-zero vector.
amino_acid_dtype = pd.CategoricalDtype(categories=sorted(amino_acids))

# One-hot encode the first character (wild-type residue) and the last character
# (mutant residue) of the mutation string; each cell holds the vector as a list.
result_df["aa_wt_onehot"] = pd.get_dummies(
    result_df["mutation"].str[0].astype(amino_acid_dtype)
).apply(list, axis=1)
result_df["aa_mut_onehot"] = pd.get_dummies(
    result_df["mutation"].str[-1].astype(amino_acid_dtype)
).apply(list, axis=1)

## Save results

In [None]:
# Destination of the combined results inside this notebook's output directory.
output_file = NOTEBOOK_DIR / "combined-results.parquet"

output_file

In [None]:
# Write the combined frame to parquet without the pandas index, in row groups of 10,000 rows.
pq.write_table(
    pa.Table.from_pandas(result_df, preserve_index=False),
    output_file,
    row_group_size=10_000,
)

## Exploratory data analysis

In [None]:
# Group the feature columns by the tool-name prefix. This rebinds `proteinsolver_columnms`
# and `protbert_columns` with broader prefixes than the ones used for the EL2 score inputs.
proteinsolver_columnms = [c for c in result_df if c.startswith("proteinsolver_")]
protbert_columns = [c for c in result_df if c.startswith("protbert_")]
rosetta_columns = [c for c in result_df if c.startswith("rosetta_")]
alphafold_columns = [c for c in result_df if c.startswith("alphafold_")]

In [None]:
# Number of rows per effect label.
result_df["effect"].value_counts()

In [None]:
# Ordinal encoding of the effect label: negative = benign, positive = pathogenic,
# 0 = uncertain significance.
effect_map = {
    "Uncertain significance": 0,
    "Likely benign": -1,
    "Benign": -2,
    "Likely pathogenic": 1,
    "Pathogenic": 2,
}

# Labels that are not keys of `effect_map` are mapped to NaN.
result_df["effect_score"] = result_df["effect"].map(effect_map)

In [None]:
# Scores to evaluate against the effect label.
score_columns = [
    "el2_score",
    "proteinsolver_core_score_change",
    "protbert_core_score_change",
    "rosetta_dg_change",
    "alphafold_core_scores_residue_plddt_wt",
    # "alphafold_core_scores_protein_plddt_wt",
    # "alphafold_core_scores_protein_max_predicted_aligned_error_wt",
    # "alphafold_core_scores_proten_ptm_wt",
]

# Keep rows where every score and the effect score are known, then drop the rows with an
# effect score of 0 (uncertain significance).
df = (
    result_df.dropna(subset=score_columns + ["effect_score"])
    .loc[lambda frame: frame["effect_score"] != 0]
    .reset_index(drop=True)
)

# Per score: Spearman correlation with the effect score, and ROC AUC for separating
# positive effect scores from negative ones.
for col in score_columns:
    corr = stats.spearmanr(df["effect_score"], df[col])
    auc = metrics.roc_auc_score(df["effect_score"] > 0, df[col])
    print(col, corr[0], auc)

In [None]:
# Scores whose coverage is reported below.
score_columns = [
    "el2_score",
    "proteinsolver_core_score_change",
    "protbert_core_score_change",
    "rosetta_dg_change",
    "alphafold_core_scores_residue_plddt_wt",
    "alphafold_core_scores_residue_plddt_change",
    # "alphafold_core_scores_protein_plddt_wt",
    # "alphafold_core_scores_protein_max_predicted_aligned_error_wt",
    # "alphafold_core_scores_proten_ptm_wt",
]

# Print the number of missing values for each score column.
for column in score_columns:
    null_count = result_df[column].isnull().sum()
    print(f"{column} {null_count}")

In [None]:
rosetta_