## Summary

Calculate features using [AlphaFold](https://github.com/deepmind/alphafold) (for the wildtype protein only).

### Humsavar

```bash
# Inputs that the notebook reads from the job environment.
export NOTEBOOK_PATH="$(realpath 31_run_alphafold_wt_template.ipynb)"
export DATASET_NAME="humsavar"
export DATASET_PATH="30_humsavar/humsavar-gby-protein-waln.parquet"
# Must equal the number of row groups in the dataset (the notebook asserts this).
export ORIGINAL_ARRAY_TASK_COUNT=12557

# Shards 1-9999 (no offset).
sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1-9999 --time 3:00:00 --gres=gpu:p100:1 --mem=18G ../scripts/run_notebook_gpu.sh

# The notebook adds ARRAY_TASK_OFFSET to SLURM_ARRAY_TASK_ID, so array
# indices 0-1000 map to shards 10000-11000.
export ARRAY_TASK_OFFSET=10000
sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT,ARRAY_TASK_OFFSET --array=0-1000 --time 3:00:00 --gres=gpu:v100l:1 --mem=32G ../scripts/run_notebook_gpu.sh

# Shards 11001-12557. The range starts at 1001 because index 1000
# (shard 11000) is already covered by the previous submission.
export ARRAY_TASK_OFFSET=10000
sbatch --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT,ARRAY_TASK_OFFSET --array=1001-2557 --time 24:00:00 --gres=gpu:v100l:1 --mem=46G ../scripts/run_notebook_gpu.sh
```

**Narval:**

```bash
# Inputs that the notebook reads from the job environment (absolute paths).
export NOTEBOOK_PATH="$(realpath 31_run_alphafold_wt_template.ipynb)"
export DATASET_NAME="humsavar"
export DATASET_PATH="$(realpath 30_humsavar/humsavar-gby-protein-waln.parquet)"
export ORIGINAL_ARRAY_TASK_COUNT=12557

# The notebook adds ARRAY_TASK_OFFSET to SLURM_ARRAY_TASK_ID, so array index
# 2382 and the contiguous run 2385-2557 map to shards 12382 and 12385-12557.
export ARRAY_TASK_OFFSET=10000
sbatch \
  --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT,ARRAY_TASK_OFFSET \
  --array=2382,2385-2557 \
  --time 24:00:00 --gres=gpu:a100:1 --mem=124G --account=def-pmkim \
  ../scripts/run_notebook_gpu.sh
```

### cagi6-sherloc

```bash
# Inputs that the notebook reads from the job environment.
export NOTEBOOK_PATH="$(realpath 31_run_alphafold_wt_template2.ipynb)"
export DATASET_NAME="cagi6-sherloc"
export DATASET_PATH="30_cagi6_sherloc/input-data-gby-protein.parquet"
export ORIGINAL_ARRAY_TASK_COUNT=4182

#######################################
# Submit one 3-hour array job for a range of shards.
# Arguments: $1 - array index range (e.g. 1-1000)
#            $2 - GPU type for --gres (e.g. p100)
#            $3 - memory request (e.g. 18G)
#######################################
submit_shards() {
  sbatch \
    --export NOTEBOOK_PATH,DATASET_NAME,DATASET_PATH,ORIGINAL_ARRAY_TASK_COUNT \
    --array="$1" --time 3:00:00 --gres="gpu:$2:1" --mem="$3" \
    ../scripts/run_notebook_gpu.sh
}

# p100
submit_shards 1-1000 p100 18G
submit_shards 1001-2000 p100 18G

# v100l
submit_shards 2001-3000 v100l 28G
submit_shards 3001-3800 v100l 28G

# a100
submit_shards 3801-4182 a100 48G
```

---

## Imports

In [1]:
# Set before jax is imported in the next cell.
# NOTE(review): presumably these mirror the AlphaFold run settings that let
# JAX use unified (host + device) memory for large proteins — confirm.
%env TF_FORCE_UNIFIED_MEMORY=1
%env XLA_PYTHON_CLIENT_MEM_FRACTION=4.0

env: TF_FORCE_UNIFIED_MEMORY=1
env: XLA_PYTHON_CLIENT_MEM_FRACTION=4.0


In [2]:
import concurrent.futures
import os
import re
import subprocess
import sys
import tempfile
from pathlib import Path

import jax
import kmbio
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from elaspic2.plugins.alphafold import AlphaFold, AlphaFoldAnalyzeError, AlphaFoldBuildError
from jax.lib import xla_bridge
from kmbio import PDB
from tqdm.notebook import tqdm



## Parameters

In [3]:
# Output directory for this notebook, resolved against the current working
# directory.
NOTEBOOK_DIR = Path("31_run_alphafold_wt_template").resolve()
# Create it on first use; an existing directory is not an error.
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/31_run_alphafold_wt_template')

In [5]:
# When SLURM provides a per-job scratch directory, point TMPDIR at it so that
# the tempfile module picks it up.
slurm_tmpdir = os.getenv("SLURM_TMPDIR")
if slurm_tmpdir is not None:
    os.environ["TMPDIR"] = slurm_tmpdir

print(tempfile.gettempdir())

/tmp


In [6]:
# Job configuration arrives through environment variables exported by the
# sbatch command; every value stays None when the notebook is run by hand.
DATASET_NAME = os.getenv("DATASET_NAME")
DATASET_PATH = os.getenv("DATASET_PATH")
ARRAY_TASK_OFFSET = int(os.getenv("ARRAY_TASK_OFFSET", "0"))

# Shard number = SLURM array index shifted by the optional offset.
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
if TASK_ID is not None:
    TASK_ID = int(TASK_ID) + ARRAY_TASK_OFFSET

# An explicitly exported total takes precedence over SLURM's own task count.
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
if TASK_COUNT is not None:
    TASK_COUNT = int(TASK_COUNT)

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT, ARRAY_TASK_OFFSET

(None, None, None, None, 0)

In [7]:
# Without a SLURM array task id the notebook is being run interactively.
DEBUG = TASK_ID is None

if not DEBUG:
    # Batch mode: every input must have been provided through the environment.
    assert DATASET_NAME is not None
    assert DATASET_PATH is not None
    DATASET_PATH = Path(DATASET_PATH).expanduser().resolve()
    assert TASK_ID is not None
    assert TASK_COUNT is not None
else:
    # Interactive mode: fall back to one fixed shard of the cagi6-sherloc data.
    DATASET_NAME = "cagi6-sherloc"
    DATASET_PATH = str(NOTEBOOK_DIR.parent / "30_cagi6_sherloc" / "input-data-gby-protein.parquet")
    TASK_ID = 3500
    TASK_COUNT = 4182
    # Alternative interactive defaults (humsavar):
    # DATASET_NAME = "humsavar"
    # # DATASET_PATH = str(
    # #     NOTEBOOK_DIR.parent.joinpath("30_humsavar", "humsavar-gby-protein-waln.parquet")
    # # )
    # DATASET_PATH = "/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/30_humsavar/humsavar-gby-protein-waln.parquet"
    # TASK_ID = 12382
    # TASK_COUNT = 12557

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

('humsavar',
 '/lustre07/scratch/strokach/workspace/elaspic2-cagi6/notebooks/30_humsavar/humsavar-gby-protein-waln.parquet',
 12382,
 12557)

In [8]:
# Platform name of the active JAX backend; it is passed to
# AlphaFold.load_model further down.
device = xla_bridge.get_backend().platform

device

'gpu'

In [9]:
# One parquet shard per array task, grouped in a per-dataset sub-directory.
output_file = NOTEBOOK_DIR / DATASET_NAME / f"shard-{TASK_ID}-of-{TASK_COUNT}.parquet"
output_file.parent.mkdir(exist_ok=True)

output_file

PosixPath('/lustre07/scratch/sunyun/workspace/elaspic2-cagi6/notebooks/31_run_alphafold_wt/humsavar/shard-12382-of-12557.parquet')

In [10]:
# Stop here when this shard's output already exists, so that a resubmitted
# task does not recompute it.
if output_file.is_file():
    raise Exception("Already finished!")

## Workspace

### Initialize model

In [11]:
# Initialise the AlphaFold plugin for the backend platform detected earlier.
AlphaFold.load_model(device=device)

## Load data

In [12]:
# Open the dataset through the row-group reader; a single row group is read
# from it in a later cell.
pfile = pq.ParquetFile(DATASET_PATH)

pfile.num_row_groups

12557

In [13]:
# One array task per row group: the configured task count must match the file.
assert TASK_COUNT == pfile.num_row_groups

In [14]:
# TASK_ID is 1-based while row groups are indexed from 0.
input_df = pfile.read_row_group(TASK_ID - 1).to_pandas(integer_object_nulls=True)

In [15]:
# Show the first row and the number of rows in this row group.
display(input_df.head(1))
print(len(input_df))

Unnamed: 0,protein_id,mutation,effect,sequence,structure,alignment
0,O43157,"[R389W, S753L, D1891V]","[LB/B, LB/B, US]",MPALGPALLQALWAGWVLTLQPLPPTAFTPNGTYLQHLARDPTSGT...,HEADER ...,"[>101\n, MPALGPALLQALWAGWVLTLQPLPPTAFTPNGTYLQH..."


1


In [16]:
# Name of the column that holds the protein identifier. Either "protein_id"
# or "uniprot_id" is accepted; when both are present, "uniprot_id" is used.
protein_id_column = next(
    (name for name in ("uniprot_id", "protein_id") if name in input_df),
    None,
)

assert protein_id_column is not None
protein_id_column

'protein_id'

In [17]:
# Collect the per-mutation fields (e.g. "mutation", "effect"): every field of
# the first row whose length equals the length of that row's mutation list.
# Note that any sized value qualifies, including strings of matching length.
first_row = next(input_df.itertuples(index=False))
iterable_fields = []
for name in first_row._fields:
    if name == protein_id_column:
        continue
    try:
        same_length = len(getattr(first_row, name)) == len(first_row.mutation)
    except TypeError:
        # Values without a length cannot be per-mutation fields.
        continue
    if same_length:
        iterable_fields.append(name)

iterable_fields

['mutation', 'effect']

In [18]:
# Peek at the first row group of the file; the result is displayed, not stored.
pfile.read_row_group(0).to_pandas(integer_object_nulls=True)

Unnamed: 0,protein_id,mutation,effect,sequence,structure,alignment
0,A0A0C5B5G6,[K14Q],[US],MRWQEMGYIFYPRKLR,HEADER ...,"[>101\n, MRWQEMGYIFYPRKLR\n, >UniRef100_A0A0C5..."


In [19]:
# In debug mode, keep only the first three entries of every per-mutation field.
if DEBUG:
    for field in iterable_fields: 
        input_df[field] = input_df[field].str[:3]
    # Bare expression nested inside the `if`: it has no effect here.
    input_df

### Score mutations

In [20]:
def validate_mutation(mutation):
    """Check that a mutation string is well formed and changes the residue.

    A valid mutation is a wildtype amino acid letter, a 1-based position
    without leading zeros, and a mutant amino acid letter (e.g. ``"R389W"``).

    Args:
        mutation: Mutation string to check.

    Returns:
        True if the mutation can be used; False (after printing the reason)
        if it is malformed or if the wildtype and mutant residues are equal.
    """
    aa = "GVALICMFWPDESTYQNKRH"
    # fullmatch rather than search with "^...$": "$" also matches just before
    # a trailing newline, which would let strings such as "A10G\n" through.
    if re.fullmatch(f"[{aa}][1-9][0-9]*[{aa}]", mutation) is None:
        print(f"Skipping mutation {mutation} because it appears to be malformed.")
        return False

    if mutation[0] == mutation[-1]:
        print(
            f"Skipping mutation {mutation} because the wildtype and mutant residues are the same."
        )
        return False

    return True

In [21]:
def predictions_to_embeddings(predictions):
    """Convert an AlphaFold predictions mapping into single-row Arrow columns.

    Every value is wrapped in a one-element ``pa.array`` so that the returned
    dict can be combined into a table with exactly one row.

    Args:
        predictions: Mapping with the keys read below (``plddt``, ``ptm``,
            ``distogram``, ``representations``, ...). Entries accessed through
            ``.to_py()`` are presumably JAX device arrays — confirm.

    Returns:
        Dict mapping column name to a one-element ``pa.array``.
    """
    # All values are read from the `predictions` argument only; the function
    # must not depend on a global `data` object existing in the notebook.
    representations = predictions["representations"]
    distogram = predictions["distogram"]
    return {
        "plddt": pa.array([predictions["plddt"]]),
        "max_predicted_aligned_error": pa.array(
            [predictions["max_predicted_aligned_error"].item()]
        ),
        "ptm": pa.array([predictions["ptm"].item()]),
        #
        "experimentally_resolved": pa.array(
            [predictions["experimentally_resolved"]["logits"].to_py().tolist()]
        ),
        "predicted_lddt": pa.array([predictions["predicted_lddt"]["logits"].to_py().tolist()]),
        #
        "msa_first_row": pa.array([representations["msa_first_row"].to_py().tolist()]),
        "single": pa.array([representations["single"].to_py().tolist()]),
        "structure_module": pa.array([representations["structure_module"].to_py().tolist()]),
        # Pairwise metrics
        "distogram": pa.array([distogram["logits"].to_py().tolist()]),
        "distogram_bin_edges": pa.array([distogram["bin_edges"].to_py().tolist()]),
        "masked_msa": pa.array([predictions["masked_msa"]["logits"].to_py().tolist()]),
        "predicted_aligned_error": pa.array([predictions["predicted_aligned_error"].tolist()]),
        "aligned_confidence_probs": pa.array(
            [predictions["aligned_confidence_probs"].tolist()]
        ),
        "msa": pa.array([representations["msa"].to_py().tolist()]),
        "pair": pa.array([representations["pair"].to_py().tolist()]),
    }

In [None]:
# `results` is overwritten on every iteration, so the row group must contain
# exactly one protein.
assert len(input_df) == 1

for tup in tqdm(input_df.itertuples(index=False), total=len(input_df)):
    # Every per-mutation field must line up with the mutation list.
    assert all(len(getattr(tup, field)) == len(tup.mutation) for field in iterable_fields)

    # The structure is handed to AlphaFold.build as a file path, so write it
    # to a temporary .pdb file that is removed when the `with` block exits.
    with tempfile.NamedTemporaryFile("wt", suffix=".pdb") as tmp_file:
        tmp_file.write(tup.structure)
        # Make sure the contents are on disk before the file is read by name.
        tmp_file.flush()
        data = AlphaFold.build(
            tup.sequence, ligand_sequence=None, msa=tup.alignment, structure_file=tmp_file.name
        )

    # Read the identifier through the detected column name so that datasets
    # keyed by "uniprot_id" work too; the output column stays "protein_id".
    results = {
        "protein_id": pa.array([getattr(tup, protein_id_column)])
    } | predictions_to_embeddings(data.predictions)

    # The build result is no longer needed once the columns are extracted.
    del data

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1.0), HTML(value='')))

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'


In [None]:
# Assemble the Arrow table from the per-column arrays.
table = pa.Table.from_pydict(results)
# The dict is not needed once the table has been built.
del results

In [None]:
# table

### Save results

In [None]:
# Batch runs write the shard to its final location; debug runs write a file of
# the same name into the temporary directory instead.
destination = output_file if not DEBUG else Path(tempfile.gettempdir(), output_file.name)
print(destination)

# Write under a temporary name and rename only after the write succeeded.
# A task killed mid-write then never leaves a truncated shard behind that the
# "Already finished!" check would mistake for a completed one.
partial_file = destination.with_name(destination.name + ".tmp")
pq.write_table(table, partial_file)
partial_file.replace(destination)