## Summary

```bash
# Environment consumed by the notebook below: DATASET_NAME,
# RUN_ALPHAFOLD_NOTEBOOK_NAME and ORIGINAL_ARRAY_TASK_COUNT are read with
# os.getenv in the "Parameters" section.
export DATASET_NAME="humsavar"
# export RUN_ALPHAFOLD_NOTEBOOK_NAME="31_run_alphafold_wt"
export RUN_ALPHAFOLD_NOTEBOOK_NAME="31_run_alphafold_wt_template"

# NOTEBOOK_PATH is not read by the notebook itself; presumably it tells
# ../scripts/run_notebook_cpu.sh which notebook to execute (confirm in that script).
export NOTEBOOK_PATH="$(realpath 32_process_alphafold.ipynb)"
# Total number of shards. The array is submitted in pieces below, so the
# notebook is given the full count explicitly; it prefers this variable over
# SLURM_ARRAY_TASK_COUNT.
export ORIGINAL_ARRAY_TASK_COUNT=26

# Tasks 1-9, 32 tasks per node.
sbatch --export DATASET_NAME,RUN_ALPHAFOLD_NOTEBOOK_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1-9 --ntasks-per-node=32 --time 6:00:00 --mem=240G --account=rrg-pmkim ../scripts/run_notebook_cpu.sh

# Tasks 10-18, 16 tasks per node.
sbatch --export DATASET_NAME,RUN_ALPHAFOLD_NOTEBOOK_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=10-18 --ntasks-per-node=16 --time 6:00:00 --mem=240G --account=rrg-pmkim ../scripts/run_notebook_cpu.sh

# Tasks 19-26, 8 tasks per node.
sbatch --export DATASET_NAME,RUN_ALPHAFOLD_NOTEBOOK_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=19-26 --ntasks-per-node=8 --time 6:00:00 --mem=240G --account=rrg-pmkim ../scripts/run_notebook_cpu.sh



# Submissions of individual tasks (1, 2 and 24) with a --mail-user address;
# presumably re-runs of tasks that did not finish (confirm before reusing).
sbatch --export DATASET_NAME,RUN_ALPHAFOLD_NOTEBOOK_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1,2 --ntasks-per-node=32 --time 6:00:00 --mem=240G --account=rrg-pmkim --mail-user=alexey.strokach@kimlab.org ../scripts/run_notebook_cpu.sh

sbatch --export DATASET_NAME,RUN_ALPHAFOLD_NOTEBOOK_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=24 --ntasks-per-node=8 --time 6:00:00 --mem=240G --account=rrg-pmkim --mail-user=alexey.strokach@kimlab.org ../scripts/run_notebook_cpu.sh

```

---

## Imports

In [None]:
import concurrent.futures
import itertools
import os
import pickle
import string
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy.special as sps
from alphafold.common import residue_constants
from scipy import stats
from sklearn import metrics, model_selection
from tqdm.auto import tqdm

In [None]:
# Show wide/long frames in full when displayed in the notebook.
# The fully-qualified option names are required: the short forms "max_columns"
# and "max_rows" are matched as patterns and, on pandas versions that also
# define "styler.render.max_columns"/"styler.render.max_rows", match more than
# one option and raise OptionError.
pd.set_option("display.max_columns", 1000)
pd.set_option("display.max_rows", 1000)

## Parameters

In [None]:
# Absolute path of this notebook's working directory, resolved relative to the
# current working directory and created on first use. exist_ok=True keeps the
# call harmless when the directory is already there.
NOTEBOOK_DIR = Path("32_process_alphafold").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

# Bare expression: shown as the cell output.
NOTEBOOK_DIR

In [None]:
# Number of CPUs this process is allowed to run on (its scheduling affinity
# set), never less than one.
CPU_COUNT = len(os.sched_getaffinity(0)) or 1

CPU_COUNT

In [None]:
# Job parameters come from the environment. Values that are absent stay None
# so the next cell can tell an interactive session from a batch run.
TASK_ID = os.environ.get("SLURM_ARRAY_TASK_ID")
if TASK_ID is not None:
    TASK_ID = int(TASK_ID)

# An explicitly exported ORIGINAL_ARRAY_TASK_COUNT wins over the count that
# Slurm reports; an empty value falls through to the Slurm variable.
TASK_COUNT = os.environ.get("ORIGINAL_ARRAY_TASK_COUNT") or os.environ.get(
    "SLURM_ARRAY_TASK_COUNT"
)
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

DATASET_NAME = os.environ.get("DATASET_NAME")
RUN_ALPHAFOLD_NOTEBOOK_NAME = os.environ.get(
    "RUN_ALPHAFOLD_NOTEBOOK_NAME", "31_run_alphafold_wt_template"
)

TASK_ID, TASK_COUNT, DATASET_NAME, RUN_ALPHAFOLD_NOTEBOOK_NAME

In [None]:
# No array task id means the notebook is being run by hand rather than as an
# array job: switch to debug mode and use fixed parameters.
DEBUG = TASK_ID is None

if not DEBUG:
    # Batch run: every parameter must have been provided.
    assert all(
        value is not None
        for value in (TASK_ID, TASK_COUNT, DATASET_NAME, RUN_ALPHAFOLD_NOTEBOOK_NAME)
    )
else:
    TASK_ID, TASK_COUNT = 25, 26
    DATASET_NAME = "humsavar"
    RUN_ALPHAFOLD_NOTEBOOK_NAME = "31_run_alphafold_wt_template"

TASK_ID, TASK_COUNT, DATASET_NAME, RUN_ALPHAFOLD_NOTEBOOK_NAME

In [None]:
# DATASET_NAME = "humsavar"
# DATASET_PATH = str(
#     NOTEBOOK_DIR.parent.joinpath("30_humsavar", "humsavar-gby-protein.parquet")
# )
# DATASET_ALN_PATH = str(
#     NOTEBOOK_DIR.parent.joinpath("30_humsavar", "humsavar-gby-protein-waln.parquet")
# )
# TASK_COUNT = 612
# TASK_COUNT_ALN = 12557

# DATASET_NAME, DATASET_PATH, TASK_COUNT, TASK_COUNT_ALN

In [None]:
# Dataset-specific inputs. NOTE(review): this unconditionally replaces any
# DATASET_NAME value obtained from the environment earlier in the notebook.
DATASET_NAME = "humsavar"

# Input tables written by the sibling "30_humsavar" step; stored as strings.
DATASET_PATH = str(NOTEBOOK_DIR.parent / "30_humsavar" / "humsavar-gby-protein.parquet")
DATASET_ALN_PATH = str(
    NOTEBOOK_DIR.parent / "30_humsavar" / "humsavar-gby-protein-waln.parquet"
)

# Task counts used in upstream result-file names (the "-of-N" part); see
# get_result_files for how they are used.
DATASET_TASK_COUNT, DATASET_ALN_TASK_COUNT = 612, 12557

DATASET_NAME, DATASET_PATH, DATASET_TASK_COUNT, DATASET_ALN_TASK_COUNT

In [None]:
# Directory with the upstream AlphaFold outputs for this dataset:
# <parent of NOTEBOOK_DIR>/<RUN_ALPHAFOLD_NOTEBOOK_NAME>/<DATASET_NAME>.
AF_RESULT_DIR = NOTEBOOK_DIR.parent / RUN_ALPHAFOLD_NOTEBOOK_NAME / DATASET_NAME

AF_RESULT_DIR

In [None]:
# Output directory: <NOTEBOOK_DIR>/<DATASET_NAME>/<run name>, where the run name
# is the upstream notebook directory name with leading/trailing digits and
# underscores stripped and the remaining underscores turned into dashes
# (e.g. "31_run_alphafold_wt_template" -> "run-alphafold-wt-template").
output_dir = NOTEBOOK_DIR.joinpath(
    DATASET_NAME, AF_RESULT_DIR.parent.name.strip(string.digits + "_").replace("_", "-")
)
# parents=True: the intermediate <NOTEBOOK_DIR>/<DATASET_NAME> level is not
# created anywhere else in this notebook, so a plain mkdir fails with
# FileNotFoundError on a fresh checkout.
output_dir.mkdir(parents=True, exist_ok=True)

output_dir

## Load results

In [None]:
def get_result_files(result_dir, task_count=DATASET_TASK_COUNT):
    """Split the expected result files of a run into present and missing.

    The expected names are ``<prefix>-<i>-of-<task_count>.parquet`` for
    ``i`` in ``1..task_count``, where ``<prefix>`` is ``"result"`` when the
    directory path contains ``"msa_analysis"`` and ``"shard"`` otherwise.

    Args:
        result_dir: Directory (a ``Path``-like object with ``joinpath``) that is
            expected to hold the files.
        task_count: Number of expected files.

    Returns:
        A ``(present_files, missing_files)`` pair of path lists, each in
        increasing task order.
    """
    prefix = "result" if "msa_analysis" in str(result_dir) else "shard"

    found, absent = [], []
    for index in tqdm(range(1, task_count + 1)):
        candidate = result_dir.joinpath(f"{prefix}-{index}-of-{task_count}.parquet")
        bucket = found if candidate.is_file() else absent
        bucket.append(candidate)
    return found, absent

### Find finished files

In [None]:
# Pickle file in which the list of finished upstream result files is cached.
present_files_cache = output_dir / "file-list.pickle"

present_files_cache

In [None]:
# Checking every expected file is only done when no cached listing exists;
# otherwise the cached listing is used as-is. The cache is written by this
# notebook itself. Delete it to force a rescan.
if not present_files_cache.is_file():
    present_files, missing_files = get_result_files(
        AF_RESULT_DIR, DATASET_ALN_TASK_COUNT
    )
    # Refuse to continue (and to cache) while any upstream result is missing.
    assert not missing_files
    with present_files_cache.open("wb") as fout:
        pickle.dump(present_files, fout)
else:
    # NOTE: missing_files is not defined on this path.
    with present_files_cache.open("rb") as fin:
        present_files = pickle.load(fin)

In [None]:
# Number of upstream result files that will be processed (cell output).
len(present_files)

### Map protein ids to mutations

In [None]:
# Load only the columns needed to map each protein to its mutations.
protein_mutations_df = pq.read_table(
    DATASET_ALN_PATH, columns=["protein_id", "mutation", "effect"]
).to_pandas()

# The table must have exactly as many rows as there are result files.
assert len(present_files) == len(protein_mutations_df)

In [None]:
# Preview the first two rows. `display` is not imported in this file; it is
# expected to be provided by the notebook runtime.
display(protein_mutations_df.head(2))

In [None]:
# protein_id -> value of the "mutation" column for that protein.
protein_mutation_lookup = (
    protein_mutations_df.set_index("protein_id").loc[:, "mutation"].to_dict()
)

# A smaller dict than file list would mean duplicated protein ids collapsed.
assert len(protein_mutation_lookup) == len(present_files)

### Process AlphaFold embeddings

In [None]:
# Columns read from each upstream result file by `worker`. The grouping
# comments mirror the grouping used in `get_mutation_embeddings`.
columns = [
    "protein_id",
    # Sequence
    "single",
    "experimentally_resolved",
    "predicted_lddt",
    "msa_first_row",
    "structure_module",
    "max_predicted_aligned_error",
    "plddt",
    "ptm",
    # Pairwise
    "pair",
    "distogram",
    # "distogram_bin_edges",
    "aligned_confidence_probs",
    "predicted_aligned_error",
    # MSA
    "msa",
    "masked_msa",
]

In [None]:
def get_mutation_embeddings(mutation, predictions):
    """Extract scores and feature vectors for one mutation from AlphaFold outputs.

    Args:
        mutation: Mutation string ``<wt><position><mut>`` (e.g. ``"A2R"``): the
            first and last characters are residue codes and everything in
            between is the 1-based residue position.
        predictions: Mapping from output name to value for a single protein.
            Keys read: ``plddt``, ``ptm``, ``max_predicted_aligned_error``,
            ``predicted_aligned_error``, ``masked_msa``, ``msa``, the
            per-residue outputs (``experimentally_resolved``,
            ``predicted_lddt``, ``msa_first_row``, ``single``,
            ``structure_module``) and the pairwise outputs (``distogram``,
            ``aligned_confidence_probs``, ``pair``). Per-residue outputs are
            indexed by residue on axis 0, pairwise outputs as
            ``[row, column, channel]`` and MSA outputs as
            ``[sequence, residue, channel]`` (presumably with matching residue
            counts; confirm against the upstream notebook).

    Returns:
        Dict with ``score_*`` entries (single values) and ``features_*``
        entries (arrays taken at, or aggregated around, the mutated residue).
        Key names and their order are part of the output schema.

    Raises:
        AssertionError: If the position is smaller than 1.
        KeyError: If a residue code is not in
            ``residue_constants.restype_order_with_x``.
    """
    wt, pos, mut = mutation[0], mutation[1:-1], mutation[-1]
    idx = int(pos) - 1  # 1-based position -> 0-based array index
    assert idx >= 0

    af_wt_idx = residue_constants.restype_order_with_x[wt]
    af_mut_idx = residue_constants.restype_order_with_x[mut]

    # Sequence
    def as_residue(x):
        return x[idx].astype(np.float32)

    # Pairwise
    def agg_rows(x, fn):
        return fn(x[idx, :, :], axis=0)

    def agg_columns(x, fn):
        return fn(x[:, idx, :], axis=0)

    def extract_diagonal(x):
        return x[idx, idx, :]

    # MSA: only the mutated position is needed, so slice it out once and
    # normalise just that slice. The softmax runs over the last axis only, so
    # this matches normalising the whole tensor and slicing afterwards.
    masked_msa_at_site = predictions["masked_msa"][:, idx, :]
    masked_msa_scores = {
        "logits": masked_msa_at_site,
        "logproba": sps.log_softmax(masked_msa_at_site, axis=-1),
    }

    sequence_embeddings = {
        "experimentally_resolved": predictions["experimentally_resolved"],
        "predicted_lddt": predictions["predicted_lddt"],
        "msa_first_row": predictions["msa_first_row"],
        "single": predictions["single"],
        "structure_module": predictions["structure_module"],
    }

    pairwise_embeddings = {
        "distogram": predictions["distogram"],
        "aligned_confidence_probs": predictions["aligned_confidence_probs"],
        "pair": predictions["pair"],
    }

    msa_embeddings = {
        "msa": predictions["msa"],
    }

    output = {
        # Sequence
        "score_plddt": predictions["plddt"][idx].item(),
        "score_protein_plddt": predictions["plddt"].mean().item(),
        "score_protein_max_predicted_aligned_error": (
            predictions["max_predicted_aligned_error"]
        ),
        "score_protein_ptm": predictions["ptm"],
        # ...
        **{
            f"features_{key}": as_residue(value)
            for key, value in sequence_embeddings.items()
        },
        # Pairwise 2D
        "score_predicted_aligned_error_row_mean": (
            predictions["predicted_aligned_error"][idx, :].mean().item()
        ),
        "score_predicted_aligned_error_row_max": (
            predictions["predicted_aligned_error"][idx, :].max().item()
        ),
        "score_predicted_aligned_error_col_mean": (
            predictions["predicted_aligned_error"][:, idx].mean().item()
        ),
        "score_predicted_aligned_error_col_max": (
            predictions["predicted_aligned_error"][:, idx].max().item()
        ),
        "score_predicted_aligned_error_diag": (
            predictions["predicted_aligned_error"][idx, idx].item()
        ),
        # Pairwise 3D
        **{
            f"features_{key}_row_mean": agg_rows(value, np.mean)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_row_max": agg_rows(value, np.max)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_col_mean": agg_columns(value, np.mean)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_col_max": agg_columns(value, np.max)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_diag": extract_diagonal(value)
            for key, value in pairwise_embeddings.items()
        },
        # MSA: wild-type / mutant logit and log-probability at the mutated
        # position, taken from the first row and aggregated over all rows.
        **{
            f"score_msa_{process}_{agg}_{variant}": agg_fn(values[:, variant_idx])
            for process, values in masked_msa_scores.items()
            for agg, agg_fn in [
                ("first", lambda x: x[0]),
                ("mean", lambda x: np.mean(x, axis=0)),
                ("max", lambda x: np.max(x, axis=0)),
            ]
            for variant, variant_idx in [
                ("wt", af_wt_idx),
                ("mut", af_mut_idx),
            ]
        },
        **{
            f"features_{key}_first": value[0, idx, :]
            for key, value in msa_embeddings.items()
        },
        **{
            f"features_{key}_mean": agg_columns(value, np.mean)
            for key, value in msa_embeddings.items()
        },
        **{
            f"features_{key}_max": agg_columns(value, np.max)
            for key, value in msa_embeddings.items()
        },
    }

    return output

In [None]:
def worker(file):
    """Turn one upstream result file into one feature record per mutation.

    Args:
        file: Path of a parquet file that must contain exactly one row (one
            protein) with the columns listed in ``columns``.

    Returns:
        A list of dicts, one per mutation of the protein, each holding
        ``protein_id``, ``mutation`` and the outputs of
        ``get_mutation_embeddings`` under an ``alphafold_core_`` prefix.
    """
    table = pq.read_table(file, columns=columns)
    frame = table.to_pandas(integer_object_nulls=True)
    assert len(frame) == 1

    row = frame.iloc[0].to_dict()
    # Only the dict is needed from here on; drop the tabular copies.
    del table, frame

    # These columns are re-assembled into regular arrays by stacking their
    # elements at two levels (presumably they come back from parquet as
    # nested sequences; confirm with the upstream writer).
    nested_columns = (
        "distogram",
        "masked_msa",
        "predicted_aligned_error",
        "aligned_confidence_probs",
        "msa",
        "pair",
    )
    for name in nested_columns:
        nested = row[name]
        row[name] = np.stack([np.stack(nested[i]) for i in range(len(nested))])

    protein_id = row["protein_id"]
    return [
        {
            "protein_id": protein_id,
            "mutation": mutation,
            **{
                f"alphafold_core_{key}": value
                for key, value in get_mutation_embeddings(mutation, row).items()
            },
        }
        for mutation in protein_mutation_lookup[protein_id]
    ]

In [None]:
# Smoke test: run the worker in-process on the first result file.
out = worker(present_files[0])

In [None]:
# Show the smoke-test records as a table (cell output).
pd.DataFrame(out)

### Write results

In [None]:
# NOTE(review): this value is not read before the write loop rebinds
# `cpu_count` as its loop variable; it looks like a leftover.
cpu_count = CPU_COUNT

# Number of worker processes to use for each chunk, keyed by 0-based chunk
# index. The counts fall from 32 to 1 as the chunk index grows; presumably
# later chunks need more memory per worker (TODO confirm the rationale).
cpu_counts = {
    0: 32,
    1: 26,
    2: 22,
    3: 18,
    4: 15,
    5: 12,
    6: 10,
    7: 10,
    8: 9,
    9: 9,
    10: 9,
    11: 9,
    12: 8,
    13: 7,
    14: 7,
    15: 6,
    16: 6,
    17: 5,
    18: 5,
    19: 5,
    20: 4,
    21: 3,
    22: 2,
    23: 1,
    24: 1,
    25: 1,
}

In [None]:
# Number of result files handled by one array task.
chunk_size = 500

# (0-based chunk index, worker count for that chunk, files in the chunk).
# A chunk index without an entry in cpu_counts raises KeyError.
file_chunks = [
    (chunk_index, cpu_counts[chunk_index], present_files[start : start + chunk_size])
    for chunk_index, start in enumerate(range(0, len(present_files), chunk_size))
]

In [None]:
# The number of chunks must equal the declared task count; the actual chunk
# count is reported in the assertion message.
assert len(file_chunks) == TASK_COUNT, len(file_chunks)

In [None]:
# Keep only the chunk belonging to this task (TASK_ID is 1-based). Slicing
# keeps the result a list; a TASK_ID outside the valid range yields an empty
# list rather than an error.
file_chunks = file_chunks[TASK_ID - 1 : TASK_ID]

len(file_chunks)

In [None]:
# Process each selected chunk in a pool of worker processes and stream the
# per-file results into a single parquet shard.
for i, cpu_count, file_chunk in file_chunks:
    print(i, cpu_count)

    output_file = output_dir.joinpath(
        f"features-shard-{i + 1:04d}-of-{TASK_COUNT:04d}.parquet"
    )
    if output_file.is_file():
        print(f"Skipping file with {i=} and {output_file=}.")
        continue

    # Write to a temporary sibling and rename only after everything has been
    # written. An interrupted run must not leave a truncated file at
    # `output_file`, because the is_file() check above would then treat the
    # shard as finished on the next run.
    partial_file = output_file.with_name(output_file.name + ".partial")

    writer = None
    try:
        with concurrent.futures.ProcessPoolExecutor(cpu_count) as pool:
            # pool.map yields results in input order.
            futures = pool.map(worker, file_chunk)
            for result in tqdm(futures, total=len(file_chunk)):
                table = pa.Table.from_pandas(pd.DataFrame(result), preserve_index=False)
                if writer is None:
                    # The schema of the first table defines the file schema.
                    writer = pq.ParquetWriter(partial_file, table.schema)
                writer.write_table(table)
    finally:
        # Always release the file handle, also when a worker raised.
        if writer is not None:
            writer.close()

    # Reached only when no exception occurred; nothing to publish when no
    # table was written at all.
    if writer is not None:
        partial_file.replace(output_file)

In [None]:
# Leave an empty marker file signalling that this task ran to the end.
done_marker = output_dir / f"features-shard-{TASK_ID:04d}-of-{TASK_COUNT:04d}.done"
done_marker.write_text("")