## Summary

Calculate features using [ProtBert](https://github.com/agemagician/ProtTrans).

---

## Imports

In [1]:
import concurrent.futures
import os
import re
import subprocess
import sys
import tempfile
from pathlib import Path

import jax
import kmbio
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from elaspic2.plugins.alphafold import AlphaFold, AlphaFoldAnalyzeError, AlphaFoldBuildError
from jax.lib import xla_bridge
from kmbio import PDB
from tqdm.notebook import tqdm



## Parameters

In [2]:
NOTEBOOK_DIR = Path("31_run_alphafold").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/home/kimlab5/strokach/workspace/elaspic/elaspic2-cagi6/notebooks/31_run_alphafold')

In [3]:
if (slurm_tmpdir := os.getenv("SLURM_TMPDIR")) is not None:
    os.environ["TMPDIR"] = slurm_tmpdir
    
print(tempfile.gettempdir())

/tmp/strokach


In [4]:
DATASET_NAME = os.getenv("DATASET_NAME")
DATASET_NAME = os.getenv("DATASET_PATH")
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")

TASK_ID = int(TASK_ID) if TASK_ID is not None else None
TASK_COUNT = int(TASK_COUNT) if TASK_COUNT is not None else None

DATASET_NAME, TASK_ID, TASK_COUNT

(None, None, None)

In [5]:
DEBUG = TASK_ID is None

if DEBUG:
    DATASET_NAME = "cagi6-sherloc"
    DATASET_PATH = str(
        NOTEBOOK_DIR.parent.joinpath("30_cagi6_sherloc", "input-data-gby-protein.parquet")
    )
    TASK_ID = 1
    TASK_COUNT = 4182
else:
    assert DATASET_NAME is not None
    assert DATASET_PATH is not None
    assert TASK_ID is not None
    assert TASK_COUNT is not None

DATASET_NAME, TASK_ID, TASK_COUNT

('cagi6-sherloc', 1, 4182)

In [6]:
device = xla_bridge.get_backend().platform

device



'cpu'

## Workspace

### Initialize model

In [7]:
!ls /home/kimlab5/strokach/workspace/elaspic/elaspic2/src/elaspic2/plugins/alphafold/data/

__init__.py  params  __pycache__


In [8]:
AlphaFold.load_model(device=device)

### Load data

In [9]:
pfile = pq.ParquetFile(DATASET_PATH)

pfile.num_row_groups

4182

In [10]:
assert TASK_COUNT == pfile.num_row_groups

In [11]:
input_df = pfile.read_row_group(TASK_ID - 1).to_pandas(integer_object_nulls=True)

In [12]:
display(input_df.head(2))
print(len(input_df))

Unnamed: 0,protein_id,mutation_id,mutation,effect,sequence,structure,alignment
0,P26678,"[NM_002667.3:c.25C>T, NM_002667.3:c.22A>C, NM_...","[R9C, T8P, F32S, P21T, R25C, R25H, I38T, F32S,...","[Pathogenic, Uncertain significance, None, Unc...",MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCLILICLLLIC...,HEADER ...,"[>101\n, MEKVQYLTRSAIRRASTIEMPQQARQKLQNLFINFCL..."


1


In [13]:
tup = next(input_df.itertuples(index=False))

iterable_fields = []
for field in tup._fields:
    try:
        if len(getattr(tup, field)) == len(tup.mutation):
            iterable_fields.append(field)
    except TypeError:
        pass

iterable_fields

['mutation_id', 'mutation', 'effect']

In [14]:
if DEBUG:
    for field in iterable_fields:
        input_df[field] = input_df[field].str[:3]

### Score mutations

In [15]:
def validate_mutation(mutation):
    aa = "GVALICMFWPDESTYQNKRH"
    if re.search(f"^[{aa}][1-9]+[0-9]*[{aa}]$", mutation) is None:
        print(f"Skipping mutation {mutation} because it appears to be malformed.")
        return False

    if mutation[0] == mutation[-1]:
        print(
            f"Skipping mutation {mutation} because the wildtype and mutant residues are the same."
        )
        return False

    return True

In [16]:
def worker(mutation, data):
    try:
        results = AlphaFold.analyze_mutation(f"A_{mutation}", data)
    except AlphaFoldAnalyzeError as error:
        print(f"{error!r}")
        return None
    else:
        results = {f"alphafold_core_{key}": value for key, value in results.items()}
        return results

In [17]:
results = []
for tup in tqdm(input_df.itertuples(index=False), total=len(input_df)):
    assert all([(len(getattr(tup, field)) == len(tup.mutation)) for field in iterable_fields])

    data = AlphaFold.build(tup.sequence, ligand_sequence=None, msa=tup.alignment)

    for mutation_idx, mutation in enumerate(tqdm(tup.mutation, leave=False)):
        if not validate_mutation(mutation):
            continue

        result = worker(mutation, data)
        if result is None:
            continue

        results.append(
            {
                "protein_id": tup.protein_id,
                "mutation": mutation,
            }
            | {field: getattr(tup, field)[mutation_idx] for field in iterable_fields}
            | result
        )

results_df = pd.DataFrame(results)

  0%|          | 0/1 [00:00<?, ?it/s]

Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid syntax (tmpfk7cxx3n.py, line 12)


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid syntax (tmpfk7cxx3n.py, line 12)


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid syntax (tmpfk7cxx3n.py, line 12)
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid syntax (tmp6zy4w5lt.py, line 27)


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid syntax (tmp6zy4w5lt.py, line 27)


Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: invalid syntax (tmp6zy4w5lt.py, line 27)


2021-09-17 20:49:07.102788: E external/org_tensorflow/tensorflow/compiler/xla/service/slow_operation_alarm.cc:55] 
********************************
Very slow compile?  If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
Compiling module jit_apply_fn.66745
********************************


  0%|          | 0/3 [00:00<?, ?it/s]

In [18]:
display(results_df.head(2))
print(len(results_df))

Unnamed: 0,protein_id,mutation,mutation_id,effect,alphafold_core_scores_residue_plddt_wt,alphafold_core_scores_protein_plddt_wt,alphafold_core_scores_protein_max_predicted_aligned_error_wt,alphafold_core_scores_proten_ptm_wt,alphafold_core_features_residue_experimentally_resolved_wt,alphafold_core_features_residue_predicted_lddt_wt,...,alphafold_core_features_residue_experimentally_resolved_mut,alphafold_core_features_residue_predicted_lddt_mut,alphafold_core_features_residue_msa_first_row_mut,alphafold_core_features_residue_single_mut,alphafold_core_features_residue_structure_module_mut,alphafold_core_features_protein_experimentally_resolved_mut,alphafold_core_features_protein_predicted_lddt_mut,alphafold_core_features_protein_msa_first_row_mut,alphafold_core_features_protein_single_mut,alphafold_core_features_protein_structure_module_mut
0,P26678,R9C,NM_002667.3:c.25C>T,Pathogenic,60.143024,76.669229,31.75,0.428104,"[0.91980046, 0.93805194, 1.1262269, 0.71249473...","[-6.096446, -7.130394, -6.2001715, -5.482113, ...",...,"[0.86495286, 0.8837349, 0.99142677, 0.69788605...","[-5.995908, -7.0391264, -6.134736, -5.429315, ...","[0.25319177, 2.9904017, -10.597328, 4.9771013,...","[13.532112, -14.943062, 22.028437, -28.517002,...","[0.004844746, 0.010447413, -0.005723976, 0.001...","[0.5032713, 0.5308786, 0.6542605, 0.49861935, ...","[-6.077047, -7.7973776, -6.7493367, -5.9762626...","[1.7017388, 2.4887137, -3.5103338, 2.986837, -...","[8.305865, 9.063175, 22.442371, -10.052725, -5...","[0.0020534596, 0.011630214, -0.005989557, 0.00..."
1,P26678,T8P,NM_002667.3:c.22A>C,Uncertain significance,60.421725,76.669229,31.75,0.428104,"[-0.9088592, -0.91381735, -1.0385773, -1.09823...","[-5.636691, -7.329773, -6.1853166, -5.387794, ...",...,"[-0.68684113, -0.6815068, -0.8155767, -0.88297...","[-5.5785475, -7.255927, -6.1465797, -5.368289,...","[4.0005136, 2.2147536, -8.588517, 3.2489474, -...","[0.6381166, 15.6405945, 40.020134, -30.35323, ...","[0.005779301, 0.007556449, -0.0058354866, 0.00...","[0.5103155, 0.5385052, 0.66315275, 0.50580376,...","[-6.0924473, -7.7991695, -6.75408, -5.981562, ...","[1.7752985, 2.522232, -3.5305977, 3.0123546, -...","[8.289321, 9.056566, 22.490763, -10.121397, -5...","[0.0021891003, 0.0117158955, -0.0059981886, 0...."


3


### Save results

In [19]:
output_file = NOTEBOOK_DIR.joinpath(DATASET_NAME, f"shard-{TASK_ID}-of-{TASK_COUNT}.parquet")
output_file.parent.mkdir(exist_ok=True)

output_file

PosixPath('/home/kimlab5/strokach/workspace/elaspic/elaspic2-cagi6/notebooks/31_run_alphafold/cagi6-sherloc/shard-1-of-4182.parquet')

In [20]:
pq.write_table(pa.Table.from_pandas(results_df, preserve_index=False), output_file)