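## Summary

This notebook evaluates AlphaFold-derived features as predictors of the humsavar variant effect labels (LB/B, US, LP/P). Each feature is expressed as a wild-type minus mutant difference. Every feature is scored by Spearman ρ, ROC AUC and average precision. The notebook then compares AlphaFold runs with and without structural templates.

---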

## Imports

In [None]:
import contextlib
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from IPython.display import SVG, Image
from scipy import stats
from sklearn import metrics, model_selection
from tqdm.auto import tqdm

In [None]:
# Show up to 1000 columns / rows in DataFrame reprs.
# Use the fully-qualified option names: since pandas 1.4 the bare patterns
# "max_columns" / "max_rows" also match `styler.render.max_columns` /
# `styler.render.max_rows`, and `set_option` raises OptionError on ambiguity.
pd.set_option("display.max_columns", 1000)
pd.set_option("display.max_rows", 1000)

## Parameters

In [None]:
# Output directory of this notebook (figures and aggregated scores are written
# here), resolved against the current working directory and created if missing.
NOTEBOOK_DIR = Path.cwd().joinpath("33_analyze_alphafold").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

In [None]:
DATASET_NAME = "humsavar"

# Both input tables live in the sibling `30_humsavar` directory; the second one
# is the variant with alignments ("waln").
DATASET_PATH, DATASET_ALN_PATH = (
    str(NOTEBOOK_DIR.parent / "30_humsavar" / filename)
    for filename in (
        "humsavar-gby-protein.parquet",
        "humsavar-gby-protein-waln.parquet",
    )
)

In [None]:
# Directory holding the `features-shard-*.parquet` tables of the AlphaFold run
# without templates (the `-template` run is the commented-out alternative below).
ALPHAFOLD_FEATURES_DIR = (
    NOTEBOOK_DIR.parent / "32_process_alphafold" / DATASET_NAME / "run-alphafold-wt"
)

ALPHAFOLD_FEATURES_DIR

In [None]:
# Alternative input: the AlphaFold run *with* templates. Uncomment to analyze it
# instead of `run-alphafold-wt`. Scores are saved as `<run name>.parquet`, and the
# "Compare with and without templates" section reads both
# `run-alphafold-wt.parquet` and `run-alphafold-wt-template.parquet`. The notebook
# therefore has to be run once with each directory before that section works.
# ALPHAFOLD_FEATURES_DIR = NOTEBOOK_DIR.parent.joinpath("32_process_alphafold", DATASET_NAME, "run-alphafold-wt-template")

# ALPHAFOLD_FEATURES_DIR

## Load data

### AlphaFold features

In [None]:
# Disabled alternative loader, kept for reference: discovers shards by globbing
# (recursively) instead of assuming a fixed shard count, and skips files that
# pyarrow cannot parse.
# NOTE(review): it assigns `result_df`, whereas the cells below expect
# `af_results_df`; rename before re-enabling.
# dfs = []
# for filepath in tqdm(list(ALPHAFOLD_FEATURES_DIR.glob("**/features-shard-*.parquet"))):
#     try:
#         df = pq.read_table(filepath).to_pandas()
#     except pa.ArrowInvalid as error:
#         print(f"Encountered error for file {filepath}.")
#         continue
#     dfs.append(df)

# result_df = pd.concat(dfs, ignore_index=True)
# del dfs

In [None]:
# Feature tables are split into `total` shards named
# `features-shard-XXXX-of-YYYY.parquet`. Shards that pyarrow cannot parse are
# skipped and reported together with the error; any other read error (e.g. a
# missing shard) is raised.
total = 26

dfs = []
skipped_shards = []
for i in tqdm(range(1, total + 1)):
    filepath = ALPHAFOLD_FEATURES_DIR.joinpath(
        f"features-shard-{i:04d}-of-{total:04d}.parquet"
    )
    try:
        df = pq.read_table(filepath).to_pandas()
    except pa.ArrowInvalid as error:
        print(f"Skipping unreadable shard {filepath}: {error}")
        skipped_shards.append(filepath)
        continue
    dfs.append(df)

# Fail with a clear message instead of pandas' "No objects to concatenate".
if not dfs:
    raise ValueError(
        f"None of the {total} feature shards in {ALPHAFOLD_FEATURES_DIR} could be read."
    )

print(f"Loaded {len(dfs)} of {total} shards ({len(skipped_shards)} skipped).")
af_results_df = pd.concat(dfs, ignore_index=True)
del dfs

In [None]:
af_results_df.head(n=5)

### Mutation effects

In [None]:
# One row per protein; `mutation` and `effect` hold parallel per-protein sequences.
mutation_columns = ["protein_id", "mutation", "effect"]
protein_mutations_df = pq.read_table(
    DATASET_ALN_PATH, columns=mutation_columns
).to_pandas()

In [None]:
protein_mutations_df.head(n=5)

In [None]:
# Flatten the per-protein `mutation` / `effect` sequences (which must have equal
# lengths) into one row per (protein_id, mutation, effect).
# All rows are collected into one list and turned into a single DataFrame. This
# avoids building one small DataFrame per protein and concatenating them. It also
# yields an empty frame, instead of a concat error, when there are no proteins.
rows = []
for tup in tqdm(protein_mutations_df.itertuples(), total=len(protein_mutations_df)):
    assert len(tup.mutation) == len(tup.effect), (
        tup.protein_id,
        len(tup.mutation),
        len(tup.effect),
    )
    rows.extend(
        (tup.protein_id, mutation, effect)
        for mutation, effect in zip(tup.mutation, tup.effect)
    )

protein_mutation_effects_df = pd.DataFrame(
    rows, columns=["protein_id", "mutation", "effect"]
)
del rows

In [None]:
# Preview the first rows rather than dumping the whole table (up to 1000 rows
# with the display options set above).
protein_mutation_effects_df.head()

## Process data

### Combine datasets

In [None]:
@contextlib.contextmanager
def tracker(original_df):
    """Yield a row-count guard for frames derived from `original_df`.

    The yielded callable takes a derived frame and asserts that it has exactly as
    many rows as `original_df`; the assertion message is the pair of row counts.
    It returns the frame unchanged, so it can wrap an expression inline:

        with tracker(df) as track:
            out = track(df.merge(other, on="key"))
    """

    def check_same_length(df):
        expected, actual = len(original_df), len(df)
        assert expected == actual, (expected, actual)
        return df

    yield check_same_length

In [None]:
# Attach the humsavar effect label to every AlphaFold feature row. `tracker`
# asserts that the inner join neither drops nor duplicates rows.
with tracker(af_results_df) as track:
    result_df = track(
        af_results_df.merge(
            protein_mutation_effects_df,
            how="inner",
            on=["protein_id", "mutation"],
        )
    )

In [None]:
# Ordinal encoding of the effect categories: LB/B -> -1, US -> 0, LP/P -> 1.
# Labels missing from the map become NaN in `effect_score`.
effect_map = dict(zip(("LB/B", "US", "LP/P"), range(-1, 2)))

result_df["effect_score"] = result_df["effect"].map(effect_map)

In [None]:
result_df.head(n=5)

### Calculate deltas

In [None]:
# Replace every `<feature>_mut` column with `<feature>_change`, computed as
# `<feature>_wt - <feature>_mut`. A mutant column without a wild-type
# counterpart is kept as-is, and the name of the missing wild-type column is
# printed.
mut_columns = [c for c in result_df.columns if c.endswith("_mut")]
for mut_col in mut_columns:
    stem = mut_col.removesuffix("_mut")
    wt_col = f"{stem}_wt"
    if wt_col in result_df:
        result_df[f"{stem}_change"] = result_df[wt_col] - result_df[mut_col]
        del result_df[mut_col]
    else:
        print(wt_col)

In [None]:
result_df.head(n=5)

## Exploratory data analysis

In [None]:
result_df["effect"].value_counts(dropna=True)

In [None]:
result_df["effect_score"].value_counts(dropna=True)

In [None]:
# Expand every array-valued column (cells holding lists, tuples or numpy arrays)
# into one scalar column per element, named `<column>_<i>`. `column_group_map`
# records which original column each new column came from, so per-element scores
# can be grouped back together later.
#
# Array-valued columns are detected from the first row. Every row of such a
# column must have the same length, because np.vstack requires it.
# The first row is read with `iloc[0]` rather than `itertuples()._asdict()`:
# namedtuples rename fields that are not valid identifiers to positional names
# (`_1`, ...), which would break the `result_df[column]` lookup below.
# NOTE(review): re-running this cell on an already-expanded frame finds no array
# columns and resets `column_group_map` to {}.
first_row = result_df.iloc[0]

column_group_map = {}
array_columns = []
expanded_dfs = []
for column, data in first_row.items():
    if not isinstance(data, (list, tuple, np.ndarray)):
        continue
    new_columns = [f"{column}_{i}" for i in range(len(data))]
    expanded_dfs.append(
        pd.DataFrame(
            np.vstack(result_df[column].values),
            index=result_df.index,
            columns=new_columns,
        )
    )
    array_columns.append(column)
    column_group_map |= {nc: column for nc in new_columns}

# Build the new frame with a single concat instead of inserting columns one at a
# time, which fragments the DataFrame. The column order is the same: scalar
# columns first, then the expanded columns in the order of their source columns.
result_df = pd.concat(
    [result_df.drop(columns=array_columns), *expanded_dfs], axis=1
)

In [None]:
# Deep copy, presumably to consolidate the frame's memory layout after the
# column insertions/deletions above.
result_df = result_df.copy(deep=True)

In [None]:
# Numeric feature columns to evaluate against the label.
score_columns = list(result_df.select_dtypes(exclude=["object"]).columns)
score_columns.remove("effect_score")

# Evaluation set: rows where every feature and the label are present, restricted
# to effect_score != 0 so that the label is binary.
df = (
    result_df.dropna(subset=[*score_columns, "effect_score"])
    .loc[lambda frame: frame["effect_score"] != 0]
    .reset_index(drop=True)
)
is_positive = df["effect_score"] > 0

# Score each feature on its own: rank correlation with the label, plus ROC AUC
# and average precision for the positive class (effect_score > 0).
scores = []
for col in tqdm(score_columns):
    values = df[col]
    rho = stats.spearmanr(df["effect_score"], values)[0]
    scores.append(
        {
            "column": col,
            "spearman_rho": rho,
            "auc": metrics.roc_auc_score(is_positive, values),
            "average_precision": metrics.average_precision_score(is_positive, values),
        }
    )
score_df = pd.DataFrame(scores)

score_df

In [None]:
# Direction-agnostic versions of each metric. A feature that ranks positive
# (`effect_score > 0`) variants below negative ones is as informative as one
# that does the opposite, so each feature is scored in its better direction.
score_df["spearman_rho_adj"] = score_df["spearman_rho"].abs()
# For ROC AUC, negating the feature gives exactly 1 - AUC.
score_df["auc_adj"] = np.where(
    score_df["auc"] > 0.5,
    score_df["auc"],
    1 - score_df["auc"],
)
# Average precision has no such symmetry: AP(-x) != 1 - AP(x), and its chance
# level is the fraction of positives rather than 0.5. Compute the AP of the
# negated feature explicitly and keep the better direction.
# This reuses `df` (the filtered, labelled frame built in the scoring cell
# above) and the same label definition (`effect_score > 0`).
is_positive = df["effect_score"] > 0
score_df["average_precision_adj"] = [
    max(ap, metrics.average_precision_score(is_positive, -df[col].astype(float)))
    for col, ap in zip(score_df["column"], score_df["average_precision"])
]

In [None]:
# Feature class of each scored column: the source array column for expanded
# `<column>_<i>` columns, otherwise the column name itself.
mapped_class = score_df["column"].map(column_group_map)
score_df["column_class"] = np.where(
    mapped_class.notnull(), mapped_class, score_df["column"]
)
assert score_df["column_class"].notnull().all()

In [None]:
# One row per feature class. Every column is aggregated independently with
# `max` over the class's member columns, so the metrics may come from different
# members and `column` holds the lexicographically largest member name.
# Classes are ranked by direction-agnostic ROC AUC.
grouped_scores = score_df.groupby("column_class").agg("max")
score_agg_df = grouped_scores.sort_values(by="auc_adj", ascending=False)
score_agg_df = score_agg_df.reset_index()

score_agg_df

In [None]:
# Aggregated scores of this run go to <NOTEBOOK_DIR>/<DATASET_NAME>/<run dir name>.parquet.
output_file = NOTEBOOK_DIR / DATASET_NAME / f"{ALPHAFOLD_FEATURES_DIR.name}.parquet"
output_file.parent.mkdir(exist_ok=True)

output_file

In [None]:
# Persist the aggregated scores (without the pandas index) for the comparison below.
score_agg_table = pa.Table.from_pandas(score_agg_df, preserve_index=False)
pq.write_table(score_agg_table, output_file)

## Compare with and without templates

In [None]:
# Aggregated scores of the run without templates, indexed by feature class.
score_agg_otemplates_df = pq.read_table(
    NOTEBOOK_DIR / DATASET_NAME / "run-alphafold-wt.parquet"
).to_pandas()
score_agg_otemplates_df = score_agg_otemplates_df.set_index("column_class")

In [None]:
# Aggregated scores of the run with templates, indexed by feature class.
score_agg_xtemplates_df = pq.read_table(
    NOTEBOOK_DIR / DATASET_NAME / "run-alphafold-wt-template.parquet"
).to_pandas()
score_agg_xtemplates_df = score_agg_xtemplates_df.set_index("column_class")

In [None]:
score_agg_otemplates_df.head(n=5)

In [None]:
score_agg_xtemplates_df.head(n=5)

### Make plots

In [None]:
# Metric to plot below; must be a key of both `name_dict` and `xlim_dict`
# ("spearman_rho_adj", "auc_adj" or "average_precision_adj").
column = "spearman_rho_adj"

In [None]:
# Axis label for each plottable metric.
name_dict = dict(
    spearman_rho_adj="Spearman ρ",
    auc_adj="ROC AUC",
    average_precision_adj="Average precision",
)

In [None]:
# Fixed x-axis range for each plottable metric (bars beyond it are clipped).
xlim_dict = dict(
    spearman_rho_adj=(0, 0.62),
    auc_adj=(0.4, 0.86),
    average_precision_adj=(0.4, 0.86),
)

In [None]:
# Scalar features only. For these the aggregated `column` equals the class name
# in the index, whereas expanded array features aggregate `<class>_<i>` columns.
is_scalar_feature = score_agg_otemplates_df.index == score_agg_otemplates_df["column"]
df = score_agg_otemplates_df[is_scalar_feature].sort_values(column).dropna()

ind = np.arange(len(df))
width = 0.35
bar_series = [
    # (vertical offset from the tick, bar lengths, legend label)
    (+width / 2, df[column], "No structure"),
    (-width / 2, score_agg_xtemplates_df[column].loc[df.index], "Structure"),
]

fig, ax = plt.subplots(figsize=(12, 1 + len(df) * 0.3))
rects1, rects2 = (
    ax.barh(ind + offset, lengths, width, label=label)
    for offset, lengths, label in bar_series
)

ax.set_xlabel(name_dict[column])
ax.set_xlim(*xlim_dict[column])
ax.set_yticks(ind)
ax.set_yticklabels(df.index)
ax.set_ylim(-0.7, len(df) - 0.3)
ax.legend(loc="lower right")
fig.subplots_adjust(left=0.38, bottom=0.06, right=0.99, top=0.99)

metric_slug = column.replace("_adj", "").replace("_", "-")
output_file_stem = NOTEBOOK_DIR.joinpath(f"{metric_slug}-scores-ox-template.png")
for suffix in (".svg", ".pdf", ".png"):
    fig.savefig(output_file_stem.with_suffix(suffix), dpi=300)

In [None]:
svg_path = output_file_stem.with_suffix(".svg")
SVG(svg_path)

In [None]:
# Expanded array features only. Their aggregated `column` is the largest
# `<class>_<i>` member name, which differs from the class name in the index.
is_array_feature = score_agg_otemplates_df.index != score_agg_otemplates_df["column"]
df = score_agg_otemplates_df[is_array_feature].sort_values(column).dropna()

ind = np.arange(len(df))
width = 0.35
bar_series = [
    # (vertical offset from the tick, bar lengths, legend label)
    (+width / 2, df[column], "No structure"),
    (-width / 2, score_agg_xtemplates_df[column].loc[df.index], "Structure"),
]

fig, ax = plt.subplots(figsize=(12, 1 + len(df) * 0.3))
rects1, rects2 = (
    ax.barh(ind + offset, lengths, width, label=label)
    for offset, lengths, label in bar_series
)

ax.set_xlabel(name_dict[column])
ax.set_xlim(*xlim_dict[column])
ax.set_yticks(ind)
ax.set_yticklabels(df.index)
ax.set_ylim(-0.7, len(df) - 0.3)
ax.legend(loc="lower right")
fig.subplots_adjust(left=0.38, bottom=0.06, right=0.99, top=0.99)

metric_slug = column.replace("_adj", "").replace("_", "-")
output_file_stem = NOTEBOOK_DIR.joinpath(f"{metric_slug}-features-ox-template.png")
for suffix in (".svg", ".pdf", ".png"):
    fig.savefig(output_file_stem.with_suffix(suffix), dpi=300)

In [None]:
svg_path = output_file_stem.with_suffix(".svg")
SVG(svg_path)