# Summary

Run ELASPIC2 (EL2) to predict the effect of mutations on binding affinity for rows in the ELASPIC database.

### Executing

```bash
export NOTEBOOK_PATH="$(realpath 20_el2_affinity.ipynb)"
export DATASET_NAME="elaspic-interface-mutation-local"
export ORIGINAL_ARRAY_TASK_COUNT=9
sbatch --export=DATASET_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1-9 --ntasks-per-node=40 --mem=0 ../scripts/run_notebook_cpu.sh

export NOTEBOOK_PATH="$(realpath 20_el2_affinity.ipynb)"
export DATASET_NAME="uniprot-domain-pair-mutation"
export ORIGINAL_ARRAY_TASK_COUNT=1358
sbatch --export=DATASET_NAME,NOTEBOOK_PATH,ORIGINAL_ARRAY_TASK_COUNT --array=1-1358 --ntasks-per-node=48 ../scripts/run_notebook_cpu.sh

# Cluster-specific sbatch flags (notes only, not commands):
#   On Cedar:   --ntasks-per-node=48
#   On Niagara: --ntasks-per-node=40 --mem=0
```

---

## Imports

In [None]:
import os
import socket
import tempfile
from pathlib import Path

import elaspic2 as el2
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from kmbio import PDB
from kmtools import structure_tools
from tqdm.notebook import tqdm

## Parameters

In [None]:
# Per-notebook directory, resolved relative to the current working directory
# and created on first run.
NOTEBOOK_DIR = Path("20_el2_affinity").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

In [None]:
# Results go under $DATAPKG_OUTPUT_DIR/elaspic2 when that variable is set;
# otherwise they go next to the notebook directory.
datapkg_output_dir = os.environ.get("DATAPKG_OUTPUT_DIR")
if datapkg_output_dir is None:
    OUTPUT_DIR = NOTEBOOK_DIR.parent
else:
    OUTPUT_DIR = (Path(datapkg_output_dir) / "elaspic2").resolve()
OUTPUT_DIR.mkdir(exist_ok=True)

OUTPUT_DIR

In [None]:
# Point temporary files at the SLURM-provided scratch directory when there is one.
slurm_tmpdir = os.getenv("SLURM_TMPDIR")
if slurm_tmpdir is not None:
    os.environ["TMPDIR"] = slurm_tmpdir

print(tempfile.gettempdir())

In [None]:
# Hosts whose name contains "scinet" are assigned a fixed count of 40;
# elsewhere use the number of CPUs this process is allowed to run on.
on_scinet = "scinet" in socket.gethostname()
CPU_COUNT = 40 if on_scinet else max(1, len(os.sched_getaffinity(0)))

# Use half of the available CPUs, but never fewer than one.
CPU_COUNT = max(1, CPU_COUNT // 2)

CPU_COUNT

In [None]:
# Job configuration is read from the environment (see the `sbatch --export=...`
# invocation at the top of the notebook and the SLURM array variables).
DATASET_NAME = os.getenv("DATASET_NAME")
TASK_ID = os.getenv("SLURM_ARRAY_TASK_ID")
# An explicitly exported ORIGINAL_ARRAY_TASK_COUNT wins over SLURM's own count.
TASK_COUNT = os.getenv("ORIGINAL_ARRAY_TASK_COUNT") or os.getenv("SLURM_ARRAY_TASK_COUNT")
TASK_ID_OFFSET = os.getenv("SLURM_TASK_ID_OFFSET")

TASK_ID = int(TASK_ID) if TASK_ID is not None else None
TASK_COUNT = int(TASK_COUNT) if TASK_COUNT is not None else None
TASK_ID_OFFSET = int(TASK_ID_OFFSET) if TASK_ID_OFFSET is not None else 0

# Outside of a SLURM array job TASK_ID is None. Applying the offset
# unconditionally would raise `TypeError` (None + int) and make the
# `TASK_ID is None` debug fallback further down unreachable.
if TASK_ID is not None:
    TASK_ID += TASK_ID_OFFSET

DATASET_NAME, TASK_ID, TASK_COUNT, TASK_ID_OFFSET

In [None]:
# Without an array task id we are running interactively: fall back to a fixed
# example task. Otherwise every job parameter must have been provided.
DEBUG = TASK_ID is None

if not DEBUG:
    assert DATASET_NAME is not None
    assert TASK_ID is not None
    assert TASK_COUNT is not None
else:
    DATASET_NAME, TASK_ID, TASK_COUNT = "uniprot-domain-pair-mutation", 50, 1358

DATASET_NAME, TASK_ID, TASK_COUNT

## Workspace

### Load data

In [None]:
# Input lives in a sibling data package; `strict=True` makes a missing file
# fail here rather than later.
input_file = (
    OUTPUT_DIR / ".." / "elaspic-data" / "12_el2_to_recalculate" / f"{DATASET_NAME}.parquet"
).resolve(strict=True)

input_file

In [None]:
# Open the input Parquet file so that individual row groups can be read.
pfile = pq.ParquetFile(input_file)

pfile.num_row_groups

In [None]:
# Each array task processes exactly one row group, so the two counts must agree.
assert TASK_COUNT == pfile.num_row_groups, (TASK_COUNT, pfile.num_row_groups)

In [None]:
# Task ids start at 1 (see `--array=1-N`), row groups are indexed from 0.
INPUT_DF = pfile.read_row_group(TASK_ID - 1).to_pandas(integer_object_nulls=True)

In [None]:
# Quick sanity check: show the first two rows and the number of rows to process.
display(INPUT_DF.head(2))
print(len(INPUT_DF))

### Create tasks

In [None]:
# Instantiate the ELASPIC2 model once; `prepare_row` reuses it for every row.
model = el2.ELASPIC2()

In [None]:
def map_mutation_to_chain(structure, chain_id, mutation):
    """Renumber a mutation so its position is relative to the start of a chain.

    The position in ``mutation`` (e.g. ``"G4K"``) is interpreted as
    ``residue_idx + 1``, where ``residue_idx`` is the column of that name in
    ``structure.to_dataframe()``. The returned mutation keeps the wild-type and
    mutant letters and uses a 1-based position counted from the smallest
    ``residue_idx`` found in ``chain_id``.

    Args:
        structure: Object exposing ``to_dataframe()`` with ``chain_id`` and
            ``residue_idx`` columns.
        chain_id: Identifier of the chain that carries the mutation.
        mutation: Mutation string: one letter, an integer position, one letter.

    Returns:
        The mutation string with the chain-relative position.

    Raises:
        KeyError: If the residue addressed by ``mutation`` is not part of
            ``chain_id``.
        ValueError: If the position part of ``mutation`` is not an integer.
    """
    atoms = structure.to_dataframe()
    chain_residue_idx = atoms.loc[atoms["chain_id"] == chain_id, "residue_idx"]

    # Shift the chain's residue indices so that its first residue becomes 0.
    shifted_idx = chain_residue_idx - chain_residue_idx.min()
    renumbering = dict(zip(chain_residue_idx, shifted_idx))

    # Mutation positions are 1-based, residue indices are 0-based.
    position = int(mutation[1:-1])
    chain_position = renumbering[position - 1] + 1

    return f"{mutation[0]}{chain_position}{mutation[-1]}"

In [None]:
def prepare_row(tup):
    """Compute ELASPIC2 predictions for one input row.

    Args:
        tup: Named tuple as produced by ``DataFrame.itertuples()``. The fields
            read here are ``structure`` (text that is written to a ``.pdb``
            file and loaded with ``PDB.load``), ``chain_modeller`` and
            ``mutation_modeller``. ``Index``, ``model_filename_wt`` and
            ``structure`` are dropped from the output.

    Returns:
        A dict holding the remaining input fields plus ``protbert_score``,
        ``proteinsolver_score`` and ``el2_score``. ``None`` is returned when
        the structure is blank, the mutation does not match the chain
        sequence, or no other chain yields a non-empty sequence.
    """

    def sequence_of(chain):
        # Same sequence-extraction settings for the mutated chain and partners.
        return structure_tools.get_chain_sequence(
            chain, if_unknown="replace", unknown_residue_marker=""
        )

    if not tup.structure.strip():
        return None

    mutated_chain_id = tup.chain_modeller

    # The temporary file has to stay alive until both `model.build` calls finish.
    with tempfile.NamedTemporaryFile(suffix=".pdb") as pdb_file:
        pdb_path = pdb_file.name
        with open(pdb_path, "wt") as fout:
            fout.write(tup.structure)
        structure = PDB.load(pdb_path)
        first_model = structure[0]
        protein_sequence = sequence_of(first_model[mutated_chain_id])

        mutation = map_mutation_to_chain(structure, mutated_chain_id, tup.mutation_modeller)
        wt_aa = mutation[0]
        position = int(mutation[1:-1])

        # The wild-type residue must be present at the mutated position.
        if position > len(protein_sequence) or protein_sequence[position - 1] != wt_aa:
            print("Protein sequence does not match mutation")
            return None

        # The ligand is the first other chain that has a non-empty sequence.
        partner_sequences = (
            sequence_of(first_model[chain.id])
            for chain in first_model.chains
            if chain.id != mutated_chain_id
        )
        ligand_sequence = next((seq for seq in partner_sequences if seq), "")
        if not ligand_sequence:
            print(f"Skipping row with no ligand sequence: {tup._replace(structure='')}")
            return None

        # Stability features use the protein alone, affinity features the pair.
        protein_stability_features = model.build(
            structure_file=pdb_path,
            protein_sequence=protein_sequence,
            ligand_sequence=None,
            remove_hetatms=True,
        )
        protein_affinity_features = model.build(
            structure_file=pdb_path,
            protein_sequence=protein_sequence,
            ligand_sequence=ligand_sequence,
            remove_hetatms=True,
        )

    mutation_stability_features = model.analyze_mutation(mutation, protein_stability_features)
    mutation_affinity_features = model.analyze_mutation(mutation, protein_affinity_features)

    # Assemble the output row, dropping the bulky / bookkeeping fields.
    row = tup._asdict()
    for dropped_field in ("Index", "model_filename_wt", "structure"):
        del row[dropped_field]

    row["protbert_score"] = (
        mutation_affinity_features["protbert_interface_score_wt"]
        - mutation_affinity_features["protbert_interface_score_mut"]
    )
    # NOTE(review): unlike protbert_score this is the wild-type value only, not
    # a wt - mut difference -- confirm that this is intended.
    row["proteinsolver_score"] = mutation_affinity_features["proteinsolver_interface_score_wt"]
    row["el2_score"] = model.predict_mutation_effect(
        [mutation_stability_features], [mutation_affinity_features]
    ).item()

    return row

In [None]:
# Process rows one at a time; a failing row is reported and skipped so that it
# does not abort the whole task.
results = []
for tup in tqdm(INPUT_DF.itertuples(), total=len(INPUT_DF)):
    try:
        row = prepare_row(tup)
    except Exception as e:
        print(f"Encountered an error: {e}")
        continue
    if row is not None:
        results.append(row)

In [None]:
# Collect the per-row result dicts into a single table.
results_df = pd.DataFrame(results)

results_df.head()

In [None]:
# One output file per array task: <dataset>-<task id>-<task count>.parquet
output_filename = f"{DATASET_NAME}-{TASK_ID:04d}-{TASK_COUNT:04d}.parquet"
output_file = OUTPUT_DIR / NOTEBOOK_DIR.name / DATASET_NAME / output_filename
output_file.parent.mkdir(exist_ok=True, parents=True)

output_file

In [None]:
# Write the results to Parquet, leaving out the pandas index.
pq.write_table(pa.Table.from_pandas(results_df, preserve_index=False), output_file)

In [None]:
# Create (or truncate) an empty marker file next to the output to record that
# this task ran to completion.
success_marker = output_file.with_suffix(".SUCCESS")
success_marker.open("w").close()