## Summary

---

## Imports

In [1]:
import concurrent.futures
import itertools
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from scipy import stats
from sklearn import metrics, model_selection
from tqdm.auto import tqdm

Matplotlib created a temporary config/cache directory at /tmp/matplotlib-f6nvvi3w because the default path (/home/p/pmkim/strokach/.config/matplotlib) is not a writable directory; it is highly recommended to set the MPLCONFIGDIR environment variable to a writable directory, in particular to speed up the import of Matplotlib and to better support multiprocessing.


In [2]:
pd.set_option("max_columns", 1000)
pd.set_option("max_rows", 1000)

## Parameters

In [3]:
NOTEBOOK_DIR = Path("32_analyze_alphafold_wt").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/gpfs/fs0/scratch/p/pmkim/strokach/workspace/elaspic2-cagi6/notebooks/32_analyze_alphafold_wt')

## Load results

In [4]:
DATASET_NAME = "humsavar"
DATASET_PATH = str(
    NOTEBOOK_DIR.parent.joinpath("30_humsavar", "humsavar-gby-protein.parquet")
)
DATASET_ALN_PATH = str(
    NOTEBOOK_DIR.parent.joinpath("30_humsavar", "humsavar-gby-protein-waln.parquet")
)
TASK_COUNT = 612
TASK_COUNT_ALN = 12557

DATASET_NAME, DATASET_PATH, TASK_COUNT, TASK_COUNT_ALN

('humsavar',
 '/gpfs/fs0/scratch/p/pmkim/strokach/workspace/elaspic2-cagi6/notebooks/30_humsavar/humsavar-gby-protein.parquet',
 612,
 12557)

In [5]:
pfile = pq.ParquetFile(DATASET_PATH)

assert TASK_COUNT == pfile.num_row_groups

In [6]:
total_num_mutations = 0
for row_group in tqdm(range(pfile.num_row_groups)):
    num_mutations = (
        pfile.read_row_group(row_group, columns=["mutation"])
        .to_pandas()["mutation"]
        .str.len()
        .sum()
    )

    total_num_mutations += num_mutations

total_num_mutations

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=612.0), HTML(value='')))




61179

In [7]:
pfile_aln = pq.ParquetFile(DATASET_ALN_PATH)

assert TASK_COUNT_ALN == pfile_aln.num_row_groups

In [8]:
total_num_aln_mutations = 0
for row_group in tqdm(range(pfile_aln.num_row_groups)):
    num_mutations = (
        pfile_aln.read_row_group(row_group, columns=["mutation"])
        .to_pandas()["mutation"]
        .str.len()
        .sum()
    )

    total_num_aln_mutations += num_mutations

total_num_aln_mutations

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12557.0), HTML(value='')))




61174

In [9]:
def get_result_files(result_dir, task_count=TASK_COUNT):
    if "msa_analysis" in str(result_dir):
        prefix = "result"
    else:
        prefix = "shard"

    present_files = []
    missing_files = []
    for i in tqdm(range(1, task_count + 1)):
        path = result_dir.joinpath(f"{prefix}-{i}-of-{task_count}.parquet")
        if path.is_file():
            present_files.append(path)
        else:
            missing_files.append(path)
    return present_files, missing_files

In [10]:
def read_files(files, columns=None):
    dfs = []
    for file in tqdm(files):
        try:
            df = pq.read_table(file, columns=columns).to_pandas(integer_object_nulls=True)
        except pa.ArrowInvalid as error:
            print(error)
            continue
        dfs.append(df)
    return pd.concat(dfs, ignore_index=True)

### MSA

In [11]:
msa_result_dir = NOTEBOOK_DIR.parent.joinpath("31_run_msa_analysis", DATASET_NAME)

In [12]:
present_files, missing_files = get_result_files(msa_result_dir, TASK_COUNT_ALN)

assert len(missing_files) == 0
len(present_files), len(missing_files)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12557.0), HTML(value='')))




(12557, 0)

In [13]:
result_msa_df = read_files(present_files)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12557.0), HTML(value='')))




In [14]:
display(result_msa_df.head(2))
print(len(result_msa_df))

assert len(result_msa_df) == total_num_aln_mutations

Unnamed: 0,protein_id,mutation,effect,msa_count_wt,msa_count_mut,msa_count_total,msa_proba_wt,msa_proba_mut,msa_proba_total,msa_length,msa_proba,msa_H,msa_KL
0,A0A0C5B5G6,K14Q,US,3.0,0.0,3.0,-1.7492,-3.135494,-61.32359,3,-1.7492,-0.0,2.849038
1,P0CJ72,T13I,LB/B,5.0,9.0,16.0,-1.791759,-1.280934,-66.477422,16,-1.160392,0.862912,1.853562


61174


### AlphaFold WT

In [15]:
afwt_result_dir = NOTEBOOK_DIR.parent.joinpath("31_run_alphafold_wt", DATASET_NAME)

In [16]:
present_files, missing_files = get_result_files(afwt_result_dir, TASK_COUNT_ALN)

len(present_files), len(missing_files)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=12557.0), HTML(value='')))




(12552, 5)

In [17]:
protein_mutations_df = pq.read_table(DATASET_ALN_PATH, columns=["protein_id", "mutation", "effect"]).to_pandas()

In [18]:
def get_mutation_embeddings(idx, predictions):
    assert idx >= 0

    # Sequence
    def as_residue(x):
        return x[idx].astype(np.float32)

    def as_protein(x):
        return x.mean(axis=0).astype(np.float32)

    # Pairwise
    def average_over_rows(x):
        return x[idx, :, :].mean(axis=0)

    def average_over_columns(x):
        return x[:, idx, :].mean(axis=0)

    def extract_diagonal(x):
        return x[idx, idx, :]

    # MSA
    def extract_msa_mean(x):
        return x[:, idx, :].mean(axis=0)

    def extract_msa_max(x):
        return x[:, idx, :].max(axis=0)

    sequence_embeddings = {
        "experimentally_resolved": predictions["experimentally_resolved"],
        "predicted_lddt": predictions["predicted_lddt"],
        "msa_first_row": predictions["msa_first_row"],
        "single": predictions["single"],
        "structure_module": predictions["structure_module"],
    }

    pairwise_embeddings = {
        "distogram": predictions["distogram"],
        "aligned_confidence_probs": predictions["aligned_confidence_probs"],
        "pair": predictions["pair"],
    }

    msa_embeddings = {
        "masked_msa": predictions["masked_msa"],
        "msa": predictions["msa"],
    }

    output = {
        # Sequence
        "scores_residue_plddt": predictions["plddt"][idx],
        "scores_protein_plddt": np.mean(predictions["plddt"]),
        "scores_protein_max_predicted_aligned_error": predictions[
            "max_predicted_aligned_error"
        ],
        "scores_proten_ptm": predictions["ptm"],
        **{
            f"features_residue_{key}": as_residue(value)
            for key, value in sequence_embeddings.items()
        },
        **{
            f"features_protein_{key}": as_protein(value)
            for key, value in sequence_embeddings.items()
        },
        # Pairwise 2D
        "predicted_aligned_error_row": predictions["predicted_aligned_error"][idx, :]
        .mean()
        .item(),
        "predicted_aligned_error_col": predictions["predicted_aligned_error"][:, idx]
        .mean()
        .item(),
        "predicted_aligned_error_col": predictions["predicted_aligned_error"][
            idx, idx
        ].item(),
        # Pairwise 3D
        **{
            f"features_residue_row_{key}": average_over_rows(value)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_residue_col_{key}": average_over_columns(value)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_residue_diag_{key}": extract_diagonal(value)
            for key, value in pairwise_embeddings.items()
        },
        # MSA
        **{
            f"features_residue_mean_{key}": extract_msa_mean(value)
            for key, value in msa_embeddings.items()
        },
        **{
            f"features_residue_max_{key}": extract_msa_max(value)
            for key, value in msa_embeddings.items()
        },
    }

    return output

In [19]:
protein_mutation_lookup = protein_mutations_df.set_index("protein_id")["mutation"].to_dict()

In [21]:
# present_files = present_files[:20]

In [22]:
columns = [
    "protein_id",
    # Sequence
    "experimentally_resolved",
    "predicted_lddt",
    "msa_first_row",
    "single",
    "structure_module",
    "plddt",
    "max_predicted_aligned_error",
    "ptm",
    # Pairwise
    "distogram",
    # "distogram_bin_edges",
    "predicted_aligned_error",
    "aligned_confidence_probs",
    "pair",
    # MSA
    "masked_msa",
    "msa",
]


def worker(file):
    df = pq.read_table(file, columns=columns).to_pandas(integer_object_nulls=True)

    assert len(df) == 1
    row = df.iloc[0].to_dict()

    for column in [
        "distogram",
        "masked_msa",
        "predicted_aligned_error",
        "aligned_confidence_probs",
        "msa",
        "pair",
    ]:
        row[column] = np.stack(
            [np.stack(row[column][i]) for i in range(len(row[column]))]
        )

    mutations = protein_mutation_lookup[row["protein_id"]]

    worker_results = []
    for mutation in mutations:
        features = {
            f"alphafold_core_{key}_wt": value
            for key, value in get_mutation_embeddings(
                int(mutation[1:-1]) - 1, row
            ).items()
        }
        worker_results.append(
            {
                "protein_id": row["protein_id"],
                "mutation": mutation,
            }
            | features
        )
    return worker_results


with concurrent.futures.ProcessPoolExecutor(2) as pool:
    results = list(tqdm(pool.map(worker, present_files), total=len(present_files)))

result_af_df = pd.DataFrame([r for rr in results for r in rr])
len(result_af_df)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=20.0), HTML(value='')))




28

## Combine results

In [23]:
assert len(result_msa_df) == len(result_msa_df[["protein_id", "mutation"]].drop_duplicates())

In [24]:
result_df = result_msa_df.merge(result_af_df, on=["protein_id", "mutation"], how="left")

assert len(result_df) == len(result_df[["protein_id", "mutation"]].drop_duplicates())
assert not (
    set(result_msa_df[["protein_id", "mutation"]].apply(tuple, axis=1))
    | set(result_af_df[["protein_id", "mutation"]].apply(tuple, axis=1))
) - set(result_df[["protein_id", "mutation"]].apply(tuple, axis=1))

In [25]:
print(len(result_df))
result_df = result_df.dropna()

61174


In [26]:
display(result_df.head(2))
print(len(result_df))

Unnamed: 0,protein_id,mutation,effect,msa_count_wt,msa_count_mut,msa_count_total,msa_proba_wt,msa_proba_mut,msa_proba_total,msa_length,msa_proba,msa_H,msa_KL,alphafold_core_scores_residue_plddt_wt,alphafold_core_scores_protein_plddt_wt,alphafold_core_scores_protein_max_predicted_aligned_error_wt,alphafold_core_scores_proten_ptm_wt,alphafold_core_features_residue_experimentally_resolved_wt,alphafold_core_features_residue_predicted_lddt_wt,alphafold_core_features_residue_msa_first_row_wt,alphafold_core_features_residue_single_wt,alphafold_core_features_residue_structure_module_wt,alphafold_core_features_protein_experimentally_resolved_wt,alphafold_core_features_protein_predicted_lddt_wt,alphafold_core_features_protein_msa_first_row_wt,alphafold_core_features_protein_single_wt,alphafold_core_features_protein_structure_module_wt,alphafold_core_predicted_aligned_error_row_wt,alphafold_core_predicted_aligned_error_col_wt,alphafold_core_features_residue_row_distogram_wt,alphafold_core_features_residue_row_aligned_confidence_probs_wt,alphafold_core_features_residue_row_pair_wt,alphafold_core_features_residue_col_distogram_wt,alphafold_core_features_residue_col_aligned_confidence_probs_wt,alphafold_core_features_residue_col_pair_wt,alphafold_core_features_residue_diag_distogram_wt,alphafold_core_features_residue_diag_aligned_confidence_probs_wt,alphafold_core_features_residue_diag_pair_wt,alphafold_core_features_residue_mean_masked_msa_wt,alphafold_core_features_residue_mean_msa_wt,alphafold_core_features_residue_max_masked_msa_wt,alphafold_core_features_residue_max_msa_wt
0,A0A0C5B5G6,K14Q,US,3.0,0.0,3.0,-1.7492,-3.135494,-61.32359,3,-1.7492,-0.0,2.849038,56.406708,62.003188,31.75,0.027069,"[0.0023722572, 0.018356942, 0.10303657, 0.1508...","[-5.2305098, -5.6203985, -4.857919, -3.7483947...","[-3.567015, -5.2736416, -9.009683, -3.05795, 5...","[24.392054, 3.2788322, 21.41848, 21.365389, 6....","[0.005355336, 0.010748506, -0.0057989154, 0.00...","[0.13932697, 0.16886736, 0.2806713, 0.26485315...","[-5.4999924, -6.2546177, -5.384321, -4.418623,...","[-1.6651598, -2.4425163, -5.507199, -0.8483903...","[16.753532, 1.5071084, 15.772654, 13.925894, -...","[0.0041642357, 0.01525829, -0.0061539887, 0.00...",8.574256,0.25235,"[2.496045768260956, -4.968722492456436, -5.866...","[0.06330700959642854, 0.06389473706803983, 0.0...","[3.3381232991814613, 29.497473165392876, -2.04...","[2.496045768260956, -4.968722492456436, -5.866...","[0.0632605097234773, 0.062133513205480995, 0.0...","[5.810895625501871, 36.592444146052, 0.6604679...","[114.23722839355469, -3.2029566764831543, -26....","[0.9999160766601562, 0.0, 0.0, 0.0, 0.0, 0.0, ...","[23.421672821044922, 601.749267578125, 35.1634...","[-2.298722798899403, 1.3645524807333007, 2.009...","[-4.500693058404397, -0.7133069470172791, 5.53...","[-1.4940643310546875, 2.198543071746826, 2.012...","[-4.41256046295166, -0.7046337127685547, 5.566..."
1,P0CJ72,T13I,LB/B,5.0,9.0,16.0,-1.791759,-1.280934,-66.477422,16,-1.160392,0.862912,1.853562,64.43125,61.304924,31.75,0.114342,"[0.08796918, 0.03081914, -0.044761796, 0.01308...","[-5.798191, -7.272913, -6.322961, -5.6527224, ...","[1.0994258, -0.6997881, -10.329717, -0.7453602...","[-18.75544, 4.724907, -9.10808, -33.123413, -3...","[0.004412666, 0.0069596395, -0.0060222056, 0.0...","[0.15945913, 0.1073039, 0.168821, 0.14149028, ...","[-5.7160945, -6.8177466, -5.850517, -5.021598,...","[2.1370337, 4.9773483, -5.5687284, 2.8266225, ...","[2.7686567, 3.993637, 20.736738, 1.4975674, -2...","[0.0040223706, 0.012024378, -0.006008984, 0.00...",10.042964,0.25028,"[0.8779546022415161, -5.235870281855266, -5.63...","[0.04259219509306907, 0.052577671876254804, 0....","[1.9545368279019992, 28.045276482899983, -2.61...","[0.8779546022415161, -5.235870281855266, -5.63...","[0.04238582096998774, 0.05880330579505729, 0.0...","[2.3460903018712997, 26.394239125152428, -2.37...","[129.9250030517578, -8.64072036743164, -35.039...","[0.999992311000824, 0.0, 0.0, 0.0, 0.0, 0.0, 0...","[7.677867412567139, 692.3701171875, 22.9353370...","[-0.5468269462191214, -0.44086736358526185, -0...","[1.1172368567231603, -0.863215982913971, -2.52...","[1.603255271911621, -0.39963558316230774, -0.2...","[3.5523734092712402, 1.110595703125, 2.2950477..."


28


## Calculate deltas

In [27]:
for column in list(result_df):
    if not column.endswith("_mut"):
        continue

    column_wt = column.removesuffix("_mut") + "_wt"
    if column_wt not in result_df:
        print(column_wt)
        continue

    column_change = column.removesuffix("_mut") + "_change"
    result_df[column_change] = result_df[column_wt] - result_df[column]
    del result_df[column]

## Save results

In [28]:
output_file = NOTEBOOK_DIR.joinpath("combined-results.parquet")

output_file

PosixPath('/gpfs/fs0/scratch/p/pmkim/strokach/workspace/elaspic2-cagi6/notebooks/32_analyze_alphafold_wt/combined-results.parquet')

In [29]:
pq.write_table(
    pa.Table.from_pandas(result_df, preserve_index=False),
    output_file,
    row_group_size=10_000,
)

## Exploratory data analysis

In [30]:
proteinsolver_columnms = [c for c in result_df if c.startswith("proteinsolver_")]
alphafold_columns = [c for c in result_df if c.startswith("alphafold_")]

In [31]:
result_df["effect"].value_counts()

LB/B    18
LP/P     7
US       3
Name: effect, dtype: int64

In [32]:
effect_map = {
    "LB/B": -1,
    "US": 0,
    "LP/P": 1,
}

result_df["effect_score"] = result_df["effect"].map(effect_map)

In [33]:
row = next(result_df.itertuples(index=False))._asdict()

for column, data in row.items():
    if isinstance(data, (list, tuple, np.ndarray)):
        result_df[[f"{column}_{i}" for i in range(len(data))]] = np.vstack(result_df[column].values)
        del result_df[column]

  self[col] = igetitem(value, i)


In [34]:
result_df = result_df.copy()

In [35]:
score_columns = list(result_df.select_dtypes(exclude=["object"]))
score_columns.remove("effect_score")

df = result_df.dropna(subset=score_columns + ["effect_score"])
df = df[df["effect_score"] != 0].reset_index(drop=True)

scores = []
for col in tqdm(score_columns):
    corr = stats.spearmanr(df["effect_score"], df[col])
    auc = metrics.roc_auc_score(df["effect_score"] > 0, df[col])
    precision = metrics.average_precision_score(df["effect_score"] > 0, df[col])
    scores.append(
        {
            "column": col,
            "spearman_rho": corr[0],
            "auc": auc,
            "average_precision": precision,
        }
    )
score_df = pd.DataFrame(scores)

score_df

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=3564.0), HTML(value='')))






Unnamed: 0,column,spearman_rho,auc,average_precision
0,msa_count_wt,0.358408,0.730159,0.651184
1,msa_count_total,0.185348,0.619048,0.573220
2,msa_proba_wt,0.345916,0.722222,0.576374
3,msa_proba_total,-0.407687,0.238095,0.210891
4,msa_length,0.198009,0.626984,0.578822
...,...,...,...,...
3559,alphafold_core_features_residue_max_msa_wt_251,0.370625,0.738095,0.616407
3560,alphafold_core_features_residue_max_msa_wt_252,-0.135896,0.412698,0.259549
3561,alphafold_core_features_residue_max_msa_wt_253,-0.037062,0.476190,0.320248
3562,alphafold_core_features_residue_max_msa_wt_254,0.037062,0.523810,0.455318


In [36]:
(
    score_df.assign(value_abs=lambda df: df["auc"].abs())
    .sort_values("value_abs", ascending=False)
    .drop("value_abs", axis=1)
    .head(1000)
)

Unnamed: 0,column,spearman_rho,auc,average_precision
2955,alphafold_core_features_residue_diag_pair_wt_77,0.716541,0.960317,0.940476
3519,alphafold_core_features_residue_max_msa_wt_211,0.691833,0.944444,0.754592
3320,alphafold_core_features_residue_max_msa_wt_12,0.679479,0.936508,0.839549
694,alphafold_core_features_residue_single_wt_335,0.679479,0.936508,0.890909
189,alphafold_core_features_residue_msa_first_row_...,0.667124,0.928571,0.840136
3499,alphafold_core_features_residue_max_msa_wt_191,0.65477,0.920635,0.784354
1516,alphafold_core_features_protein_single_wt_46,0.61878,0.896825,0.75
11,alphafold_core_scores_proten_ptm_wt,0.61878,0.896825,0.857143
3000,alphafold_core_features_residue_diag_pair_wt_122,0.617708,0.896825,0.82381
3094,alphafold_core_features_residue_mean_msa_wt_65,0.605354,0.888889,0.807463
