## Summary

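Calculate per-mutation features with [ProteinSolver](https://gitlab.com/ostrokach/proteinsolver) for one row group (one SLURM array task) of the input dataset: "core" features from the protein alone and, when a ligand sequence is available, "interface" features from the protein–ligand complex.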

---

## Imports

In [None]:
import concurrent.futures
import os
import re
import socket
import subprocess
import sys
import tempfile
from pathlib import Path

import kmbio
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from ev2.plugins.proteinsolver import (
    ProteinSolver,
    ProteinSolverAnalyzeError,
    ProteinSolverBuildError,
)
from kmbio import PDB
from tqdm.notebook import tqdm

In [None]:
# Load the ProteinSolver model before any work is dispatched.
# NOTE(review): presumably this loaded state is inherited by the ProcessPoolExecutor
# workers started further down (fork start method) — TODO confirm.
ProteinSolver.load_model()

## Parameters

In [None]:
# Working directory named after this notebook, located under the current working
# directory; created on first run.
NOTEBOOK_DIR = Path.cwd().joinpath("02_run_proteinsolver").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

In [None]:
# Root directory for pipeline outputs: "<DATAPKG_OUTPUT_DIR>/elaspic-v2" when that
# environment variable is set, otherwise the parent of this notebook's directory.
OUTPUT_DIR = (
    Path(os.environ["DATAPKG_OUTPUT_DIR"]).joinpath("elaspic-v2").resolve()
    if "DATAPKG_OUTPUT_DIR" in os.environ
    else NOTEBOOK_DIR.parent
)
OUTPUT_DIR.mkdir(exist_ok=True)

OUTPUT_DIR

In [None]:
# When SLURM provides a per-job temporary directory (SLURM_TMPDIR), use it for all
# temporary files created by this process.
if (slurm_tmpdir := os.getenv("SLURM_TMPDIR")) is not None:
    os.environ["TMPDIR"] = slurm_tmpdir
    # `tempfile` caches the directory it chose on first use in `tempfile.tempdir`.
    # If anything already called `tempfile.gettempdir()` (directly or via a library),
    # changing TMPDIR alone would be silently ignored. Clearing the cache makes the
    # next lookup honour the new TMPDIR.
    tempfile.tempdir = None

print(tempfile.gettempdir())

In [None]:
# Number of worker processes for the ProcessPoolExecutor used to evaluate mutations.
if "scinet" in socket.gethostname():
    # NOTE(review): fixed core count on SciNet hosts — presumably because the affinity
    # mask there does not reflect the allocation; TODO confirm.
    CPU_COUNT = 40
elif hasattr(os, "sched_getaffinity"):
    # Number of CPUs in this process's affinity mask.
    CPU_COUNT = len(os.sched_getaffinity(0))
else:
    # `os.sched_getaffinity` is Linux-only; fall back to the machine's CPU count
    # (which may be None when it cannot be determined).
    CPU_COUNT = os.cpu_count() or 1

# Use half of the available CPUs. NOTE(review): presumably to leave headroom for work
# done inside each worker — TODO confirm. Clamp to at least one, because a single CPU
# would otherwise give 1 // 2 == 0 and ProcessPoolExecutor rejects max_workers=0.
CPU_COUNT = max(1, CPU_COUNT // 2)

CPU_COUNT

In [None]:
# Job parameters supplied through the environment; all are None when the variables are
# unset (e.g. when the notebook is run interactively). A non-empty
# ORIGINAL_ARRAY_TASK_COUNT takes precedence over SLURM_ARRAY_TASK_COUNT.
DATASET_NAME = os.getenv("DATASET_NAME")
TASK_ID, TASK_COUNT = (
    None if raw is None else int(raw)
    for raw in (
        os.getenv("SLURM_ARRAY_TASK_ID"),
        os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT"),
    )
)

DATASET_NAME, TASK_ID, TASK_COUNT

In [None]:
# Without a SLURM array task ID we are running interactively: fall back to a fixed
# debugging configuration. Otherwise every job parameter must have been provided.
DEBUG = TASK_ID is None

if DEBUG:
    DATASET_NAME = "elaspic-training-set-interface"
    TASK_ID = 1
    TASK_COUNT = 26
else:
    # Explicit check instead of `assert`, which is stripped when Python runs with -O,
    # and which would not say which parameter is missing.
    _missing = [
        name
        for name, value in (
            ("DATASET_NAME", DATASET_NAME),
            ("TASK_ID", TASK_ID),
            ("TASK_COUNT", TASK_COUNT),
        )
        if value is None
    ]
    if _missing:
        raise RuntimeError(f"Missing required job parameters: {', '.join(_missing)}")

DATASET_NAME, TASK_ID, TASK_COUNT

## Workspace

### Load data

In [None]:
# Input parquet file for this dataset, located under OUTPUT_DIR/01_load_data.
input_file = OUTPUT_DIR / "01_load_data" / f"{DATASET_NAME}.parquet"

input_file

In [None]:
# Open the input parquet file. Each array task reads one of its row groups, so the
# number of row groups is shown for comparison with TASK_COUNT.
pfile = pq.ParquetFile(input_file)

pfile.num_row_groups

In [None]:
# Each array task processes exactly one row group, so the array size must equal the
# number of row groups. Raise explicitly (an `assert` is stripped under `python -O`)
# and report both values to make a misconfigured job easy to diagnose.
if TASK_COUNT != pfile.num_row_groups:
    raise ValueError(
        f"TASK_COUNT ({TASK_COUNT}) does not match the number of row groups "
        f"({pfile.num_row_groups}) in {input_file}."
    )

In [None]:
# Load this task's row group. TASK_ID is treated as 1-based (the debug default is 1),
# hence the `- 1`. `integer_object_nulls=True` converts integer columns that contain
# nulls to Python objects (ints / None) rather than floats.
INPUT_DF = pfile.read_row_group(TASK_ID - 1).to_pandas(integer_object_nulls=True)

In [None]:
# Peek at the loaded rows and report how many there are. `display` is provided by the
# IPython/Jupyter kernel (it is not imported in this notebook).
display(INPUT_DF.head(2))
print(len(INPUT_DF))

### Create tasks

In [None]:
# Results for this notebook go to a sub-directory of OUTPUT_DIR with the notebook's name.
output_dir = OUTPUT_DIR / NOTEBOOK_DIR.name
output_dir.mkdir(exist_ok=True)

output_dir

In [None]:
# A well-formed point mutation is "<wild-type aa><position><mutant aa>", e.g. "G12V":
# one of the 20 standard amino acids, a position without leading zeros, and another
# amino acid. The pattern is compiled once instead of being rebuilt for every mutation.
# `fullmatch` is used because `$` in `re.search` also matches before a trailing newline,
# which let strings such as "G12V\n" through.
_AMINO_ACIDS = "GVALICMFWPDESTYQNKRH"
_MUTATION_RE = re.compile(f"[{_AMINO_ACIDS}][1-9][0-9]*[{_AMINO_ACIDS}]")

# One task per (protein, unique well-formed mutation): (data_core, data_interface, data_mut).
tasks = []
for row in tqdm(INPUT_DF.itertuples(), total=len(INPUT_DF)):
    # ProteinSolver.build takes a file path, so dump the structure text to a temporary
    # .pdb file. Build the "core" input from the protein alone (no ligand) and, when the
    # row has a ligand sequence, the "interface" input from protein + ligand.
    # NOTE: re-opening a NamedTemporaryFile by name while it is still open works on
    # POSIX but not on Windows.
    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp_file:
        with open(tmp_file.name, "wt") as fout:
            fout.write(row.protein_structure)
        try:
            data_core = ProteinSolver.build(tmp_file.name, row.protein_sequence, None)
            if row.ligand_sequence is not None:
                data_interface = ProteinSolver.build(
                    tmp_file.name, row.protein_sequence, row.ligand_sequence
                )
            else:
                data_interface = None
        except ProteinSolverBuildError as e:
            # Report and skip proteins that cannot be built; keep processing the rest.
            print(e)
            continue

    seen_mutations = set()
    for idx, mutation in enumerate(row.mutation):
        # Keep only the first occurrence of a repeated mutation.
        if mutation in seen_mutations:
            print(
                f"Already added mutation '{mutation}' for protein ({row.unique_id}, {row.dataset}, {row.name})."
            )
            continue
        seen_mutations.add(mutation)

        if _MUTATION_RE.fullmatch(mutation) is None:
            print(f"Skipping mutation {mutation} because it appears to be malformed.")
            continue

        data_mut = {
            "unique_id": row.unique_id,
            "mutation": mutation,
            # `effect` is parallel to `mutation`, hence the shared index.
            "effect": row.effect[idx],
        }
        tasks.append((data_core, data_interface, data_mut))

len(tasks)

In [None]:
# Inspect the first task: a (data_core, data_interface, data_mut) tuple.
tasks[0]

### Evaluate mutations

In [None]:
def worker(input):
    """Evaluate a single mutation task with ProteinSolver.

    Args:
        input: A ``(data_core, data_interface, data_mut)`` tuple. ``data_core`` and
            ``data_interface`` are inputs produced by ``ProteinSolver.build``;
            ``data_interface`` is ``None`` when there is no ligand. ``data_mut`` is a
            dict holding at least the ``"mutation"`` string.

    Returns:
        A new dict with the entries of ``data_mut``, followed by every core result
        prefixed with ``proteinsolver_core_`` and, when an interface input exists, every
        interface result prefixed with ``proteinsolver_interface_``. Returns ``None``
        (after printing the error) if ProteinSolver fails to analyze the mutation.
    """
    data_core, data_interface, data_mut = input
    # NOTE(review): the mutation is always addressed as "A_<mutation>" — presumably
    # chain A of the structure; TODO confirm.
    mutation_id = f"A_{data_mut['mutation']}"
    try:
        core_scores = ProteinSolver.analyze_mutation(mutation_id, data_core)
        interface_scores = (
            {}
            if data_interface is None
            else ProteinSolver.analyze_mutation(mutation_id, data_interface)
        )
    except ProteinSolverAnalyzeError as error:
        print(error)
        return None

    merged = dict(data_mut)
    merged.update(
        (f"proteinsolver_core_{key}", value) for key, value in core_scores.items()
    )
    merged.update(
        (f"proteinsolver_interface_{key}", value)
        for key, value in interface_scores.items()
    )
    return merged

In [None]:
# Run the worker once in-process on the first task before starting the process pool,
# so problems show up with a readable traceback. (Raises IndexError if `tasks` is empty.)
worker(tasks[0])

In [None]:
# Evaluate all tasks in parallel. `pool.map` yields results in task order; tqdm only
# tracks progress.
with concurrent.futures.ProcessPoolExecutor(max_workers=CPU_COUNT) as pool:
    results = list(tqdm(pool.map(worker, tasks), total=len(tasks)))

# `worker` returns None for mutations that could not be analyzed; drop those.
results_df = pd.DataFrame([record for record in results if record is not None])

In [None]:
# One output file per array task: "<dataset>-<task id>-<task count>.parquet".
output_file = output_dir / f"{DATASET_NAME}-{TASK_ID}-{TASK_COUNT}.parquet"

output_file

In [None]:
# Write this task's results. `results_df` was built from a list of dicts, so its index
# is a plain RangeIndex that carries no information and is not stored.
pq.write_table(pa.Table.from_pandas(results_df, preserve_index=False), output_file)