## Summary

### CAGI6-Sherloc

```bash
NOTEBOOK_PATH="$(realpath 32_process_alphafold.ipynb)"
export NOTEBOOK_PATH
export DATASET_NAME="cagi6-sherloc"
export DATASET_PATH="30_cagi6_sherloc/input-data-gby-protein.parquet"
export ORIGINAL_ARRAY_TASK_COUNT=4182
unset ARRAY_TASK_OFFSET

# One array job per block of shards, written as "<array range>:<memory>".
# The memory request grows with the array index (later shards presumably hold
# larger proteins). Together the ranges cover shards 1-4182 exactly once.
for spec in \
  1-500:18G \
  501-1000:24G \
  1001-1500:24G \
  1501-2000:24G \
  2001-2500:32G \
  2501-3000:32G \
  3001-3500:40G \
  3501-4000:48G \
  4001-4182:62G; do
  sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT \
    --array="${spec%%:*}" --time 3:00:00 --ntasks-per-node 2 --mem="${spec#*:}" \
    ../scripts/run_notebook_cpu.sh
done
```

### Humsavar

```bash
NOTEBOOK_PATH="$(realpath 32_process_alphafold.ipynb)"
export NOTEBOOK_PATH
export DATASET_NAME="humsavar"
export DATASET_PATH="30_humsavar/humsavar-gby-protein-waln.parquet"
export ORIGINAL_ARRAY_TASK_COUNT=12557
unset ARRAY_TASK_OFFSET

# !!! Modify the lines below !!!
# The notebook processes shard (SLURM_ARRAY_TASK_ID + ARRAY_TASK_OFFSET), so
# shards >= 10000 are submitted as small array indices plus an offset.
# Keep all ranges disjoint: an index listed twice runs the same shard twice,
# and the second run aborts on (or races with) the already-written output file.

# Shards 1-9999.
sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1-9999 --time 3:00:00 --gres=gpu:p100:1 --mem=18G ../scripts/run_notebook_gpu.sh

export ARRAY_TASK_OFFSET=10000

# Shards 10000-11000.
sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT,ARRAY_TASK_OFFSET --array=0-1000 --time 3:00:00 --gres=gpu:v100l:1 --mem=32G ../scripts/run_notebook_gpu.sh

# Shards 11001-12557 (starts at 1001 so shard 11000 is not submitted twice).
sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT,ARRAY_TASK_OFFSET --array=1001-2557 --time 24:00:00 --gres=gpu:v100l:1 --mem=46G ../scripts/run_notebook_gpu.sh
```

---

## Imports

In [1]:
import concurrent.futures
import itertools
import os
import pickle
import string
import tempfile
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import scipy.special as sps
from alphafold.common import residue_constants
from scipy import stats
from sklearn import metrics, model_selection
from tqdm.auto import tqdm

In [2]:
# Widen pandas' display limits so wide feature tables are not truncated.
# Fully-qualified option names are required: since pandas 1.4 the short
# patterns "max_columns" / "max_rows" also match the styler.render.* options
# and set_option raises OptionError ("Pattern matched multiple keys").
pd.set_option("display.max_columns", 1000)
pd.set_option("display.max_rows", 1000)

## Parameters

In [3]:
# Working directory owned by this notebook, as an absolute path; created on
# the first run and reused afterwards.
NOTEBOOK_DIR = Path.cwd().joinpath("32_process_alphafold").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/32_process_alphafold')

In [4]:
# Directory holding the per-task input shards read further down
# (<INPUT_DIR>/<dataset>/shard-<task>-of-<count>.parquet).
INPUT_DIR = Path.cwd().joinpath("31_run_alphafold_wt").resolve()

INPUT_DIR

PosixPath('/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/31_run_alphafold_wt')

In [5]:
# Results are written to the sibling "31_run_alphafold" directory, not into
# NOTEBOOK_DIR. NOTE(review): confirm this location is intended.
OUTPUT_DIR = (NOTEBOOK_DIR.parent / "31_run_alphafold").resolve()
OUTPUT_DIR.mkdir(exist_ok=True)

OUTPUT_DIR

PosixPath('/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/31_run_alphafold')

In [20]:
# Job parameters arrive via environment variables (see the sbatch commands in
# the Summary section); in an interactive session they are simply unset.
DATASET_NAME = os.getenv("DATASET_NAME")
DATASET_PATH = os.getenv("DATASET_PATH")
ARRAY_TASK_OFFSET = int(os.getenv("ARRAY_TASK_OFFSET", "0"))

# The SLURM array index is shifted by ARRAY_TASK_OFFSET to get the shard id.
_raw_task_id = os.getenv("SLURM_ARRAY_TASK_ID")
if _raw_task_id is None:
    TASK_ID = None
else:
    TASK_ID = int(_raw_task_id) + ARRAY_TASK_OFFSET

# Prefer the dataset-wide shard count over the size of this particular array:
# one dataset may be split across several sbatch submissions.
_raw_task_count = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv(
    "SLURM_ARRAY_TASK_COUNT"
)
TASK_COUNT = None if _raw_task_count is None else int(_raw_task_count)

TASK_ID, TASK_COUNT, DATASET_NAME, DATASET_PATH, ARRAY_TASK_OFFSET

(None,
 4182,
 'cagi6-sherloc',
 '30_cagi6_sherloc/input-data-gby-protein.parquet',
 0)

In [7]:
# Without a SLURM task id we are running interactively: fall back to a fixed
# example shard so the notebook can be stepped through by hand.
DEBUG = TASK_ID is None

if not DEBUG:
    # Batch run: every parameter must have come from the environment.
    assert TASK_ID is not None
    assert TASK_COUNT is not None
    assert DATASET_NAME is not None
    assert DATASET_PATH is not None
else:
    DATASET_NAME = "cagi6-sherloc"
    DATASET_PATH = str(
        NOTEBOOK_DIR.parent / "30_cagi6_sherloc" / "input-data-gby-protein.parquet"
    )
    TASK_ID = 1
    TASK_COUNT = 4182

    # Alternative example shard from the humsavar dataset:
    # DATASET_NAME = "humsavar"
    # DATASET_PATH = str(
    #     NOTEBOOK_DIR.parent / "30_humsavar" / "humsavar-gby-protein-waln.parquet"
    # )
    # TASK_ID = 10000
    # TASK_COUNT = 12557

TASK_ID, TASK_COUNT, DATASET_NAME, DATASET_PATH

(1,
 4182,
 'cagi6-sherloc',
 '/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/30_cagi6_sherloc/input-data-gby-protein.parquet')

In [8]:
# Final destination of this shard's results. Batch runs refuse to start when
# it already exists, so finished shards are never recomputed or overwritten.
output_file = OUTPUT_DIR.joinpath(
    DATASET_NAME, f"results-{TASK_ID}-of-{TASK_COUNT}.parquet"
)
output_file.parent.mkdir(parents=True, exist_ok=True)

# DEBUG runs only write to a temporary file (see "Write results"), so an
# existing result for the example shard must not block interactive work.
if not DEBUG and output_file.is_file():
    raise FileExistsError(f"Output file {output_file!r} already exists!")

## Workspace

### Load input data

In [9]:
# Input shard for this task: <INPUT_DIR>/<dataset>/shard-<task>-of-<count>.parquet
input_file = INPUT_DIR / DATASET_NAME / f"shard-{TASK_ID}-of-{TASK_COUNT}.parquet"
assert input_file.is_file(), input_file

input_file

PosixPath('/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/31_run_alphafold_wt/cagi6-sherloc/shard-1-of-4182.parquet')

In [10]:
# AlphaFold outputs to load from the input shard, grouped by kind.
_sequence_columns = [
    "single",
    "experimentally_resolved",
    "predicted_lddt",
    "msa_first_row",
    "structure_module",
    "max_predicted_aligned_error",
    "plddt",
    "ptm",
]
_pairwise_columns = [
    "pair",
    "distogram",
    # "distogram_bin_edges",
    "aligned_confidence_probs",
    "predicted_aligned_error",
]
_msa_columns = [
    "msa",
    "masked_msa",
]

columns = ["protein_id", *_sequence_columns, *_pairwise_columns, *_msa_columns]

In [11]:
# Each input shard is expected to hold exactly one protein (one row).
_input_table = pq.read_table(input_file, columns=columns)
input_df = _input_table.to_pandas(integer_object_nulls=True)
del _input_table

assert len(input_df) == 1
# Keep only a plain dict for that single row and release the DataFrame.
input_row = input_df.iloc[0].to_dict()
del input_df

In [12]:
# These columns come back from parquet as nested sequences (presumably arrays
# of per-row arrays); collapse each into one dense ndarray so it can be indexed
# along several axes below.
_NESTED_COLUMNS = (
    "distogram",
    "masked_msa",
    "predicted_aligned_error",
    "aligned_confidence_probs",
    "msa",
    "pair",
)
for _column in _NESTED_COLUMNS:
    input_row[_column] = np.stack(
        [np.stack(_inner) for _inner in input_row[_column]]
    )

In [13]:
list(input_row.keys())  # names of the AlphaFold outputs loaded for this protein

['protein_id',
 'single',
 'experimentally_resolved',
 'predicted_lddt',
 'msa_first_row',
 'structure_module',
 'max_predicted_aligned_error',
 'plddt',
 'ptm',
 'pair',
 'distogram',
 'aligned_confidence_probs',
 'predicted_aligned_error',
 'msa',
 'masked_msa']

### Map protein ids to mutations

In [14]:
# Map each protein id to the list of mutations recorded for it in the dataset.
_mutation_table = pq.read_table(DATASET_PATH, columns=["protein_id", "mutation"])
_mutation_series = _mutation_table.to_pandas().set_index("protein_id")["mutation"]
protein_mutation_lookup = _mutation_series.to_dict()
del _mutation_table, _mutation_series

len(protein_mutation_lookup)

2833

In [15]:
# Attach this protein's mutations to its AlphaFold outputs. NOTE: the list may
# contain repeats ('F32S' appears twice for the example shard); every
# occurrence is processed and yields its own result row.
_protein_id = input_row["protein_id"]
input_row["mutations"] = protein_mutation_lookup[_protein_id]

input_row["mutations"]

array(['R9C', 'T8P', 'F32S', 'P21T', 'R25C', 'R25H', 'I38T', 'F32S',
       'I12M', 'I12V', 'V49M', 'S10L', 'R9H', 'M50T', 'L51I', 'L52I',
       'R25G', 'L42I', 'L44P', 'I18T', 'R14I'], dtype=object)

### Process AlphaFold embeddings

In [16]:
def process_mutation_embeddings(mutation, predictions):
    """Extract AlphaFold scores and feature vectors for one point mutation.

    Args:
        mutation: Mutation string such as "R9C": wild-type residue letter,
            1-based position, mutant residue letter.
        predictions: Mapping of AlphaFold output name -> array for the whole
            protein. The indexing below assumes per-residue outputs have the
            residue axis first, pairwise outputs are (L, L, ...) and MSA
            outputs are (N_seq, L, ...). These shapes are inferred from the
            indexing; confirm against the AlphaFold run that produced them.

    Returns:
        Dict of ``score_*`` scalars and ``features_*`` 1-D arrays. The key
        order is fixed and becomes the column order of the results table.

    Raises:
        ValueError: if the position is outside ``predictions["plddt"]``.
        KeyError: if a residue letter is not in
            ``residue_constants.restype_order_with_x``.
    """
    wt, pos, mut = mutation[0], mutation[1:-1], mutation[-1]
    idx = int(pos) - 1
    num_residues = len(predictions["plddt"])
    if not 0 <= idx < num_residues:
        raise ValueError(
            f"Mutation {mutation!r} is outside the protein "
            f"(1-based position must be in 1..{num_residues})."
        )

    af_wt_idx = residue_constants.restype_order_with_x[wt]
    af_mut_idx = residue_constants.restype_order_with_x[mut]

    # Sequence
    def as_residue(x):
        return x[idx].astype(np.float32)

    # Pairwise
    def agg_rows(x, fn):
        return fn(x[idx, :, :], axis=0)

    def agg_columns(x, fn):
        return fn(x[:, idx, :], axis=0)

    def extract_diagonal(x):
        return x[idx, idx, :]

    # MSA
    def extract_msa_logit(value, aa_idx):
        return value[:, idx, aa_idx]

    def extract_msa_logproba(value, aa_idx):
        # Slice out the mutated position *before* normalising. log_softmax over
        # the last axis is independent for every (sequence, position) pair, so
        # this equals log_softmax(value, axis=-1)[:, idx, aa_idx] without
        # materialising a full-size copy of the MSA logits on every call.
        return sps.log_softmax(value[:, idx, :], axis=-1)[:, aa_idx]

    sequence_embeddings = {
        "experimentally_resolved": predictions["experimentally_resolved"],
        "predicted_lddt": predictions["predicted_lddt"],
        "msa_first_row": predictions["msa_first_row"],
        "single": predictions["single"],
        "structure_module": predictions["structure_module"],
    }

    pairwise_embeddings = {
        "distogram": predictions["distogram"],
        # "aligned_confidence_probs": predictions["aligned_confidence_probs"],
        "pair": predictions["pair"],
    }

    msa_embeddings = {
        "msa": predictions["msa"],
    }

    output = {
        # Sequence
        "score_plddt": predictions["plddt"][idx].item(),
        # "score_protein_plddt": predictions["plddt"].mean().item(),  # no good
        # "score_protein_max_predicted_aligned_error": (
        #     predictions["max_predicted_aligned_error"]
        # ),  # all same
        # "score_protein_ptm": predictions["ptm"],  # no good
        # ...
        **{
            f"features_{key}": as_residue(value)
            for key, value in sequence_embeddings.items()
        },
        # Pairwise 2D
        "score_predicted_aligned_error_row_mean": (
            predictions["predicted_aligned_error"][idx, :].mean().item()
        ),
        # "score_predicted_aligned_error_row_max": (
        #     predictions["predicted_aligned_error"][idx, :].max().item()
        # ),  # no good
        "score_predicted_aligned_error_col_mean": (
            predictions["predicted_aligned_error"][:, idx].mean().item()
        ),
        # "score_predicted_aligned_error_col_max": (
        #     predictions["predicted_aligned_error"][:, idx].max().item()
        # ),  # no good
        "score_predicted_aligned_error_diag": (
            predictions["predicted_aligned_error"][idx, idx].item()
        ),
        # Pairwise 3D
        **{
            f"features_{key}_row_mean": agg_rows(value, np.mean)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_row_max": agg_rows(value, np.max)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_col_mean": agg_columns(value, np.mean)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_col_max": agg_columns(value, np.max)
            for key, value in pairwise_embeddings.items()
        },
        **{
            f"features_{key}_diag": extract_diagonal(value)
            for key, value in pairwise_embeddings.items()
        },
        # MSA
        **{
            f"score_msa_{process}_{agg}_{variant}": (
                agg_fn(process_fn(predictions["masked_msa"], variant_idx))
            )
            for process, process_fn in [
                # ("logits", extract_msa_logit),  # does not help
                ("logproba", extract_msa_logproba),
            ]
            for agg, agg_fn in [
                ("first", lambda x: x[0]),
                ("mean", lambda x: np.mean(x, axis=0)),
                ("max", lambda x: np.max(x, axis=0)),
            ]
            for variant, variant_idx in [
                ("wt", af_wt_idx),
                ("mut", af_mut_idx),
            ]
        },
        **{
            f"features_{key}_first": value[0, idx, :]
            for key, value in msa_embeddings.items()
        },
        **{
            f"features_{key}_mean": agg_columns(value, np.mean)
            for key, value in msa_embeddings.items()
        },
        **{
            f"features_{key}_max": agg_columns(value, np.max)
            for key, value in msa_embeddings.items()
        },
    }

    return output

In [17]:
# Shards whose per-mutation failures are logged and skipped rather than
# aborting the job. NOTE(review): keyed on TASK_ID alone, so it applies to
# whichever dataset is being processed; confirm 1098 is meant for one dataset.
ERROR_TOLERANT_TASK_IDS = {1098}

results = []
for mutation in input_row["mutations"]:
    try:
        embeddings = process_mutation_embeddings(mutation, input_row)
    except Exception as e:
        if TASK_ID not in ERROR_TOLERANT_TASK_IDS:
            raise
        # Name the skipped mutation so it can be traced from the job log.
        print(
            f"Encountered an error for mutation {mutation!r} of protein "
            f"{input_row['protein_id']!r}: {e!r}!"
        )
        continue

    results.append(
        {
            "protein_id": input_row["protein_id"],
            "mutation": mutation,
        }
        | {f"alphafold_core_{key}": value for key, value in embeddings.items()}
    )

if results:
    results_df = pd.DataFrame(results)
else:
    # Every mutation was skipped: keep the identifying columns so the written
    # shard does not end up with an empty schema.
    results_df = pd.DataFrame(columns=["protein_id", "mutation"])
del results

### Write results

In [18]:
results_table = pa.Table.from_pandas(results_df, preserve_index=False)

if not DEBUG:
    # Write to a hidden temporary file next to the final path and rename it
    # into place. A job killed mid-write (e.g. by the SLURM time limit or OOM)
    # then cannot leave a truncated results file behind, which the
    # "already exists" check would refuse to regenerate. os.replace is atomic
    # within a single filesystem.
    partial_file = output_file.with_name(
        f".{output_file.name}.{os.getpid()}.partial"
    )
    try:
        pq.write_table(results_table, partial_file)
        os.replace(partial_file, output_file)
    except BaseException:
        partial_file.unlink(missing_ok=True)
        raise
else:
    # Interactive run: exercise the writer without touching the real output.
    with tempfile.NamedTemporaryFile() as tmp_file:
        print(tmp_file.name)
        pq.write_table(results_table, tmp_file.name)

/tmp/tmptewi2_vp


In [19]:
# Reaching this cell means every mutation was processed and the results were
# written; emit a completion marker for the job log.
print("Done!", end="\n")

Done!
