## Summary

Calculate features using [ProtBert](https://github.com/agemagician/ProtTrans).

```bash
export NOTEBOOK_PATH="$(realpath 31_run_protbert.ipynb)"
export DATASET_NAME="cagi6-sherloc"
export DATASET_PATH="30_cagi6_sherloc/input-data-gby-protein.parquet"
export ORIGINAL_ARRAY_TASK_COUNT=4182

# === Graham ===
sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1-2800 --time 3:00:00 --gres=gpu:t4:1 ../scripts/run_notebook_gpu.sh

sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=181-200 --time 3:00:00 --gres=gpu:v100:1 ../scripts/run_notebook_gpu.sh

sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=201-306 ../scripts/run_notebook_cpu.sh


# === Cedar ===
sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=3518-3800 --ntasks-per-node=40 --mem=80G ../scripts/run_notebook_cpu.sh


# === Beluga ===

sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=5,20,30,68,93,94,97,202,263,266,281,284,294,306,309,350,359,391,409,435,465,473,486,535,582,616,658,693,704,705,706,739,756,861,951,975,992,1036,1038,1053,1098,1101,1106,1141,1191,1206,1223,1288,1322,1361,1364,1382,1399,1404,1427,1473,1502,1520,1555,1646,1667,1668,1691,1692,1718,1727,1731,1751,1758,1764,1842,1857,1917,1935,1964,1981,1985,2024,2036,2053,2054,2058,2103,2140,2143,2165,2175,2229,2318,2340,2353,2417,2433,2526,2643,2644,2684,2701,2742,2754,2767,2812,2823,2848,2858,2889,2907,2918,2949,2976,2981,2982,3021,3023,3027,3038,3070,3091,3108,3123,3140,3245,3264,3336,3352,3365,3388,3413,3510 --ntasks-per-node=40 --mem=80G ../scripts/run_notebook_cpu.sh
```

```bash
export NOTEBOOK_PATH="$(realpath 31_run_protbert.ipynb)"
export DATASET_NAME="humsavar"
export DATASET_PATH="30_humsavar/humsavar-gby-protein.parquet"
export ORIGINAL_ARRAY_TASK_COUNT=612

# Cedar
sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --job-name=protbert --array=1-400 --gres=gpu:p100:1 ../scripts/run_notebook_gpu.sh

sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --job-name=protbert --array=400-550 --gres=gpu:v100l:1 ../scripts/run_notebook_gpu.sh

sbatch --export DATASET_NAME,DATASET_PATH,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --job-name=protbert --array=550-612 --ntasks-per-node=32 --mem=110G --time=24:00:00 ../scripts/run_notebook_cpu.sh
```

---

## Imports

In [None]:
import concurrent.futures
import os
import re
import subprocess
import sys
import tempfile
from pathlib import Path

import kmbio
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import torch
from elaspic2.plugins.protbert import ProtBert, ProtBertAnalyzeError, ProtBertBuildError
from kmbio import PDB
from tqdm.notebook import tqdm

## Parameters

In [None]:
# Directory that holds this notebook's outputs. The relative name is resolved
# against the current working directory, and the directory is created on
# first use.
NOTEBOOK_DIR = Path("31_run_protbert").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

# Echo the resolved path as the cell output.
NOTEBOOK_DIR

In [None]:
# When SLURM provides a per-job scratch directory, point TMPDIR at it so that
# the tempfile module picks it up the first time it looks for a temp location.
slurm_tmpdir = os.environ.get("SLURM_TMPDIR")
if slurm_tmpdir is not None:
    os.environ["TMPDIR"] = slurm_tmpdir

# Show which temporary directory is actually in effect.
print(tempfile.gettempdir())

In [None]:
# Job configuration is passed in through environment variables (see the
# sbatch commands in the summary). Anything that is not set stays None.
DATASET_NAME, DATASET_PATH = (
    os.environ.get(key) for key in ("DATASET_NAME", "DATASET_PATH")
)

raw_task_id = os.environ.get("SLURM_ARRAY_TASK_ID")
# ORIGINAL_ARRAY_TASK_COUNT takes precedence over SLURM's own count when it is
# set to a non-empty value.
raw_task_count = os.environ.get("ORIGINAL_ARRAY_TASK_COUNT") or os.environ.get(
    "SLURM_ARRAY_TASK_COUNT"
)

# Convert the numeric settings to int, leaving missing values as None.
TASK_ID = None if raw_task_id is None else int(raw_task_id)
TASK_COUNT = None if raw_task_count is None else int(raw_task_count)

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

In [None]:
# Without a SLURM array task ID the notebook is treated as a debug run.
DEBUG = TASK_ID is None

if not DEBUG:
    # Batch run: every setting must have been supplied by the environment.
    assert DATASET_NAME is not None
    assert DATASET_PATH is not None
    assert TASK_ID is not None
    assert TASK_COUNT is not None
    DATASET_PATH = Path(DATASET_PATH).expanduser().resolve()
else:
    # Debug run: fall back to hard-coded settings for the humsavar dataset.
    DATASET_NAME = "humsavar"
    DATASET_PATH = str(
        NOTEBOOK_DIR.parent / "30_humsavar" / "humsavar-gby-protein.parquet"
        # NOTEBOOK_DIR.parent / "30_cagi6_sherloc" / "input-data-gby-protein.parquet"
    )
    TASK_ID = 612
    TASK_COUNT = 612  # 4182

DATASET_NAME, DATASET_PATH, TASK_ID, TASK_COUNT

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

device

In [None]:
# Each array task writes one shard, named after its task ID and the total
# task count, into a per-dataset subdirectory of the notebook directory.
shard_name = f"shard-{TASK_ID}-of-{TASK_COUNT}.parquet"
output_file = NOTEBOOK_DIR / DATASET_NAME / shard_name
output_file.parent.mkdir(exist_ok=True)

output_file

In [None]:
# Stop the notebook here if this shard's output file already exists, so that
# a resubmitted array task does not redo the work.
if output_file.is_file():
    raise Exception("Already finished!")

## Workspace

### Initialize model

In [None]:
# Load the ProtBert model onto the selected device.
# NOTE(review): presumably the model is stored on the ProtBert class for the
# later build/analyze calls — confirm against elaspic2.
ProtBert.load_model(device=device)

### Load data

In [None]:
# Open the dataset lazily; a single row group is read from it further down.
pfile = pq.ParquetFile(DATASET_PATH)

# Show how many row groups the file contains.
pfile.num_row_groups

In [None]:
# Each array task processes exactly one row group, so the declared task count
# must equal the number of row groups in the file.
assert TASK_COUNT == pfile.num_row_groups

In [None]:
# Display the row-group count again.
pfile.num_row_groups

In [None]:
# Task IDs are used as 1-based indices while row groups are 0-based, hence
# the "- 1".
# integer_object_nulls=True asks pyarrow to keep integer columns that contain
# nulls as Python objects.
input_df = pfile.read_row_group(TASK_ID - 1).to_pandas(integer_object_nulls=True)

In [None]:
# Preview the first two input rows and report the number of rows loaded.
display(input_df.head(2))
print(len(input_df))

In [None]:
# Work out which column identifies the protein. Datasets may call it either
# "protein_id" or "uniprot_id"; if both are present, the later candidate wins.
matching_columns = [name for name in ("protein_id", "uniprot_id") if name in input_df]
protein_id_column = matching_columns[-1] if matching_columns else None

assert protein_id_column is not None
protein_id_column

In [None]:
# Use the first row as a template to find the per-mutation columns, i.e. the
# columns whose value is a sequence with one entry per mutation.
tup = next(input_df.itertuples(index=False))

iterable_fields = []
for field in tup._fields:
    if field in [protein_id_column]:
        continue
    value = getattr(tup, field)
    # A plain string also has a length, so a string column (presumably the
    # protein sequence) could match the mutation count by coincidence and be
    # mistaken for a per-mutation column. Never treat strings as such.
    if isinstance(value, (str, bytes)):
        continue
    try:
        if len(value) == len(tup.mutation):
            iterable_fields.append(field)
    except TypeError:
        # Values without a length (numbers, None, ...) are per-protein fields.
        pass

iterable_fields

In [None]:
# In debug mode, keep only the first three entries of every per-mutation
# column so that the scoring loop finishes quickly.
# NOTE(review): relies on the pandas .str accessor slicing each element of
# these object columns — confirm that it handles the stored element type.
if DEBUG:
    for field in iterable_fields:
        input_df[field] = input_df[field].str[:3]

### Score mutations

In [None]:
def validate_mutation(mutation):
    """Check that a mutation string is well formed and actually changes the residue.

    A valid mutation is a wild-type amino acid letter, a positive position
    without leading zeros, and a mutant amino acid letter (for example
    ``"A12G"``). Only the 20 standard amino acid letters are accepted.

    Args:
        mutation: Mutation string to check.

    Returns:
        True if the mutation can be scored. False if it is malformed or if the
        wild-type and mutant residues are identical; in both cases the reason
        is printed.
    """
    aa = "GVALICMFWPDESTYQNKRH"
    # fullmatch is used instead of search with "^...$" because "$" also
    # matches just before a trailing newline, which would let a string such as
    # "A12G\n" through.
    if re.fullmatch(f"[{aa}][1-9][0-9]*[{aa}]", mutation) is None:
        print(f"Skipping mutation {mutation} because it appears to be malformed.")
        return False

    if mutation[0] == mutation[-1]:
        print(
            f"Skipping mutation {mutation} because the wildtype and mutant residues are the same."
        )
        return False

    return True

In [None]:
def worker(mutation, data):
    """Score one mutation with ProtBert and prefix the resulting feature names.

    Args:
        mutation: Mutation string, passed to ProtBert as ``"A_<mutation>"``
            (the ``A_`` prefix is presumably a chain identifier — confirm).
        data: Object returned by ``ProtBert.build`` for the protein.

    Returns:
        A dict mapping ``"protbert_core_<key>"`` to each value returned by
        ProtBert, or None if the analysis raised an exception (the exception
        is printed).
    """
    try:
        raw_scores = ProtBert.analyze_mutation(f"A_{mutation}", data)
    except Exception as error:
        print(f"{error!r}")
        return None

    return {f"protbert_core_{name}": score for name, score in raw_scores.items()}

In [None]:
# Score every valid mutation of every protein in this shard and collect one
# record per successfully scored mutation.
results = []
for row in tqdm(input_df.itertuples(index=False), total=len(input_df)):
    mutations = row.mutation
    # Every per-mutation column must line up with the mutation list.
    assert all([(len(getattr(row, name)) == len(mutations)) for name in iterable_fields])

    # The protein-level data is built once and reused for all its mutations.
    data = ProtBert.build(row.sequence, ligand_sequence=None)
    protein_id = getattr(row, protein_id_column)

    for position, mutation in enumerate(tqdm(mutations, leave=False)):
        if not validate_mutation(mutation):
            continue

        scores = worker(mutation, data)
        if scores is None:
            continue

        # Key order defines the output column order: identifiers first, then
        # the per-mutation fields, then the ProtBert features.
        record = {"protein_id": protein_id, "mutation": mutation}
        record.update((name, getattr(row, name)[position]) for name in iterable_fields)
        record.update(scores)
        results.append(record)

results_df = pd.DataFrame(results)

In [None]:
# Preview the first two result rows and report how many mutations were scored.
display(results_df.head(2))
print(len(results_df))

### Save results

In [None]:
# Recompute the shard path right before saving and make sure the dataset
# directory exists.
shard_name = f"shard-{TASK_ID}-of-{TASK_COUNT}.parquet"
output_file = NOTEBOOK_DIR / DATASET_NAME / shard_name
output_file.parent.mkdir(exist_ok=True)

output_file

In [None]:
# Results are only persisted for real array tasks, never for debug runs.
if not DEBUG:
    # Write to a temporary sibling first and rename it into place afterwards.
    # A job that is killed mid-write (e.g. on hitting its time limit) then
    # never leaves a truncated shard at the final path, where a resubmitted
    # task would mistake it for finished work.
    partial_file = output_file.with_name(output_file.name + ".tmp")
    pq.write_table(pa.Table.from_pandas(results_df, preserve_index=False), partial_file)
    partial_file.replace(output_file)