## Summary

---

## Imports

In [1]:
%env TF_FORCE_UNIFIED_MEMORY=1
%env XLA_PYTHON_CLIENT_MEM_FRACTION=4.0

env: TF_FORCE_UNIFIED_MEMORY=1
env: XLA_PYTHON_CLIENT_MEM_FRACTION=4.0


In [2]:
import os
import tempfile
import urllib.request
from datetime import datetime
from pathlib import Path

import dotenv
import elaspic2 as el2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import torch
from kmbio import PDB
from kmtools import structure_tools
from tqdm.auto import tqdm



## Parameters

In [3]:
# Output directory for this notebook's cached results. The relative name is
# resolved against the current working directory, so the location depends on
# where the kernel was started. It is created on first run if missing.
NOTEBOOK_DIR = Path("35_cagi6_hmbs_alphafold").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/scratch/strokach/workspace/elaspic2-cagi6/notebooks/35_cagi6_hmbs_alphafold')

In [4]:
# Identifier of the protein under study; it is interpolated into the names of
# the FASTA, PDB and alignment input files loaded further down.
UNIPROT_ID = "P08397"

UNIPROT_ID

'P08397'

In [5]:
# Torch device passed to the ProteinSolver and ProtBert model loaders:
# CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [6]:
# Timestamp of this run, truncated to the hour. It changes on every re-run, so
# anything derived from it is not reproducible across runs.
# NOTE(review): not referenced by any cell visible in this section - confirm it
# is still needed.
version = datetime.now().isoformat(timespec="hours")

version

'2021-09-28T19'

## Load data

In [7]:
def load_sequence(sequence_file):
    """Read a FASTA-style file and return its residues as a single string.

    Lines beginning with ``>`` are treated as headers and skipped. Every other
    line is stripped of surrounding whitespace and the pieces are concatenated
    in file order; if the file holds several records, their sequences end up
    joined together.

    Args:
        sequence_file: Object exposing ``open("rt")``, e.g. a ``pathlib.Path``.

    Returns:
        The concatenated sequence (empty string for an empty file).
    """
    with sequence_file.open("rt") as fin:
        # Stream line by line; strip() drops the trailing newline of each line.
        return "".join(
            raw_line.strip() for raw_line in fin if not raw_line.startswith(">")
        )

In [8]:
# FASTA input lives in the sibling "30_cagi6_hmbs" notebook directory.
sequence_file = NOTEBOOK_DIR.parent.joinpath("30_cagi6_hmbs", f"{UNIPROT_ID}.fasta")

# Fail early if the input has not been produced yet.
assert sequence_file.is_file()

In [9]:
# Sequence string used as input to the AlphaFold, ProteinSolver and ProtBert
# builders below.
sequence = load_sequence(sequence_file)

# Show only the first residues as a quick sanity check.
sequence[:5]

'MSGNG'

In [10]:
# PDB structure from the sibling "30_cagi6_hmbs" directory; passed to
# ProteinSolver.build further down.
structure_file = NOTEBOOK_DIR.parent.joinpath("30_cagi6_hmbs", f"{UNIPROT_ID}.pdb")

# Fail early if the input has not been produced yet.
assert structure_file.is_file()

In [11]:
# Multiple sequence alignment passed to AlphaFold.build as `msa`.
# NOTE(review): the file is named ".a3m.gz" but is read with pq.read_table,
# i.e. as a Parquet table with an "alignment" column - confirm the on-disk
# format matches. Unlike the sequence/structure inputs there is no explicit
# is_file() assertion; a missing file surfaces as a read error instead.
alignment_file = NOTEBOOK_DIR.parent.joinpath("30_cagi6_hmbs", f"{UNIPROT_ID}.a3m.gz")

alignment = pq.read_table(alignment_file).to_pandas()["alignment"].values.tolist()

In [12]:
# Table of mutations to score; its "mut" column (e.g. "A112R") drives every
# per-mutation loop in the sections below.
results_to_fill_file = NOTEBOOK_DIR.parent.joinpath("30_cagi6_hmbs", "results-to-fill.parquet")

results_to_fill_df = pq.read_table(results_to_fill_file).to_pandas()

In [13]:
# Preview the first two rows to check the column layout.
results_to_fill_df.head(2)

Unnamed: 0,aa_substitution,score,sd,comments,mut
0,p.Ala112Arg,*,*,*,A112R
1,p.Ala112Asn,*,*,*,A112N


## Run AlphaFold WT

In [14]:
def predictions_to_embeddings(predictions):
    """Flatten an AlphaFold prediction dict into plain Python containers.

    Array-valued entries are converted with ``np.asarray(...).tolist()`` rather
    than the legacy ``DeviceArray.to_py()`` method, so the function works with
    NumPy arrays and with JAX arrays regardless of whether the installed JAX
    version still provides ``to_py``.

    Args:
        predictions: Mapping with the keys ``plddt``,
            ``max_predicted_aligned_error``, ``ptm``,
            ``experimentally_resolved["logits"]``, ``predicted_lddt["logits"]``
            and ``representations[...]`` (``msa_first_row``, ``single``,
            ``structure_module``).

    Returns:
        A flat dict: ``plddt`` is passed through unchanged, the two scalar
        scores are unwrapped with ``.item()``, and every other entry is a
        nested Python list.

    Raises:
        KeyError: If an expected key is missing from ``predictions``.
    """

    def as_nested_list(array):
        # np.asarray accepts anything implementing the array protocol.
        return np.asarray(array).tolist()

    return {
        "plddt": predictions["plddt"],
        "max_predicted_aligned_error": predictions["max_predicted_aligned_error"].item(),
        "ptm": predictions["ptm"].item(),
        #
        "experimentally_resolved": as_nested_list(
            predictions["experimentally_resolved"]["logits"]
        ),
        "predicted_lddt": as_nested_list(predictions["predicted_lddt"]["logits"]),
        #
        "msa_first_row": as_nested_list(predictions["representations"]["msa_first_row"]),
        "single": as_nested_list(predictions["representations"]["single"]),
        "structure_module": as_nested_list(predictions["representations"]["structure_module"]),
    }

In [15]:
def get_mutation_embeddings(idx, predictions):
    """Collect per-residue and whole-protein features for one position.

    Args:
        idx: Zero-based residue index; must be non-negative.
        predictions: Flat dict holding ``plddt``,
            ``max_predicted_aligned_error``, ``ptm`` and the five embedding
            entries listed in ``feature_keys`` below.

    Returns:
        Dict with four ``scores_*`` entries followed by a
        ``features_residue_<key>`` entry (row ``idx``) and then a
        ``features_protein_<key>`` entry (mean over axis 0) for each embedding,
        all cast to ``float32``.

    Raises:
        AssertionError: If ``idx`` is negative.
    """
    assert idx >= 0

    feature_keys = (
        "experimentally_resolved",
        "predicted_lddt",
        "msa_first_row",
        "single",
        "structure_module",
    )
    raw_features = {key: predictions[key] for key in feature_keys}

    # NOTE: the "proten" spelling is kept on purpose; it is a persisted column name.
    output = {
        "scores_residue_plddt": predictions["plddt"][idx],
        "scores_protein_plddt": np.mean(predictions["plddt"]),
        "scores_protein_max_predicted_aligned_error": predictions["max_predicted_aligned_error"],
        "scores_proten_ptm": predictions["ptm"],
    }

    # Residue-level features first, then protein-level ones, to keep key order.
    for key, value in raw_features.items():
        output[f"features_residue_{key}"] = np.asarray(value)[idx].astype(np.float32)
    for key, value in raw_features.items():
        output[f"features_protein_{key}"] = np.asarray(value).mean(axis=0).astype(np.float32)

    return output

In [16]:
# Wild-type AlphaFold features for every mutation, cached on disk as Parquet.
alphafold_wt_results_file = NOTEBOOK_DIR.joinpath("alphafold-wt-results.parquet")

if alphafold_wt_results_file.is_file():
    # Reuse the cached table instead of re-running the model.
    alphafold_wt_results_df = pq.read_table(alphafold_wt_results_file).to_pandas()
else:
    # Imported here so these packages are only required when the cache is missing.
    from elaspic2.plugins.alphafold import AlphaFold, AlphaFoldAnalyzeError, AlphaFoldBuildError
    from jax.lib import xla_bridge

    jax_device = xla_bridge.get_backend().platform
    print(f"Device: {jax_device}")

    AlphaFold.load_model(device=jax_device)
    data = AlphaFold.build(sequence, ligand_sequence=None, msa=alignment)
    predictions = predictions_to_embeddings(data.predictions)

    # The wild-type features depend only on the residue position, not on the
    # substituted amino acid, so compute them once per position and reuse the
    # result for every mutation at that position.
    wt_features_by_idx = {}
    features_list = []
    for mut in results_to_fill_df["mut"]:
        # "A112R" -> 1-based position 112 -> 0-based index 111.
        idx = int(mut[1:-1]) - 1
        if idx not in wt_features_by_idx:
            wt_features_by_idx[idx] = {
                f"alphafold_core_{key}_wt": value
                for key, value in get_mutation_embeddings(idx, predictions).items()
            }
        features_list.append({"mut": mut} | wt_features_by_idx[idx])
    alphafold_wt_results_df = pd.DataFrame(features_list)
    pq.write_table(
        pa.Table.from_pandas(alphafold_wt_results_df, preserve_index=False),
        alphafold_wt_results_file,
    )

In [17]:
# Preview the first rows and print the row count of the AlphaFold WT table.
display(alphafold_wt_results_df.head(2))
print(len(alphafold_wt_results_df))

Unnamed: 0,mut,alphafold_core_scores_residue_plddt_wt,alphafold_core_scores_protein_plddt_wt,alphafold_core_scores_protein_max_predicted_aligned_error_wt,alphafold_core_scores_proten_ptm_wt,alphafold_core_features_residue_experimentally_resolved_wt,alphafold_core_features_residue_predicted_lddt_wt,alphafold_core_features_residue_msa_first_row_wt,alphafold_core_features_residue_single_wt,alphafold_core_features_residue_structure_module_wt,alphafold_core_features_protein_experimentally_resolved_wt,alphafold_core_features_protein_predicted_lddt_wt,alphafold_core_features_protein_msa_first_row_wt,alphafold_core_features_protein_single_wt,alphafold_core_features_protein_structure_module_wt
0,A112R,96.101048,87.854437,31.75,0.869775,"[2.765605, 2.885155, 3.4066327, 2.7455544, 3.4...","[-11.647073, -12.94594, -11.152253, -9.916887,...","[0.5386157, 8.812037, -3.058997, 4.732732, -4....","[29.820736, 49.839714, 53.33713, -28.858473, 2...","[-0.00066452473, 0.011935189, -0.005865535, 0....","[2.5443485, 2.652794, 3.0835412, 2.5492835, 3....","[-8.669228, -9.790952, -8.569939, -7.6671114, ...","[4.0051045, 6.9503894, 0.70320404, 0.7040772, ...","[17.930984, 18.186771, 6.970354, 3.3894393, 8....","[0.00042588258, 0.010626375, -0.0059233517, 0...."
1,A112N,96.101048,87.854437,31.75,0.869775,"[2.765605, 2.885155, 3.4066327, 2.7455544, 3.4...","[-11.647073, -12.94594, -11.152253, -9.916887,...","[0.5386157, 8.812037, -3.058997, 4.732732, -4....","[29.820736, 49.839714, 53.33713, -28.858473, 2...","[-0.00066452473, 0.011935189, -0.005865535, 0....","[2.5443485, 2.652794, 3.0835412, 2.5492835, 3....","[-8.669228, -9.790952, -8.569939, -7.6671114, ...","[4.0051045, 6.9503894, 0.70320404, 0.7040772, ...","[17.930984, 18.186771, 6.970354, 3.3894393, 8....","[0.00042588258, 0.010626375, -0.0059233517, 0...."


6239


## Run AlphaFold

In [18]:
# def worker(mutation, data):
#     try:
#         results = AlphaFold.analyze_mutation(f"A_{mutation}", data)
#     except Exception as error:
#         print(f"{error!r}")
#         return None
#     else:
#         results = {f"alphafold_core_{key}": value for key, value in results.items()}
#         return results

In [19]:
# alphafold_results_file = NOTEBOOK_DIR.joinpath("alphafold-results.parquet")

# if alphafold_results_file.is_file():
#     alphafold_results_df = pq.read_table(alphafold_results_file).to_pandas()
# else:
#     from elaspic2.plugins.alphafold import AlphaFold, AlphaFoldAnalyzeError, AlphaFoldBuildError
#     from jax.lib import xla_bridge

#     jax_device = xla_bridge.get_backend().platform
#     print(f"Device: {jax_device}")

#     AlphaFold.load_model(device=jax_device)
#     data = AlphaFold.build(sequence, ligand_sequence=None, msa=alignment)

#     result_list = []
#     for mut in tqdm(results_to_fill_df["mut"], leave=False):
#         result = worker(mut, data)
#         result_list.append({"mut": mut} | result)
#     alphafold_results_df = pd.DataFrame(result_list)

#     pq.write_table(
#         pa.Table.from_pandas(alphafold_results_df, preserve_index=False), alphafold_results_file
#     )

In [20]:
# display(alphafold_results_df.head(2))
# print(len(alphafold_results_df))

## Run ProteinSolver

In [21]:
def worker(mutation, data):
    """Analyse one mutation with ProteinSolver on a best-effort basis.

    The mutation is prefixed with ``A_`` (presumably the chain identifier)
    before being passed to ``ProteinSolver.analyze_mutation``.

    Args:
        mutation: Mutation string such as ``"A112R"``.
        data: Object returned by ``ProteinSolver.build``.

    Returns:
        The analysis results with every key prefixed by
        ``proteinsolver_core_``, or ``None`` if the analysis raised.
    """
    try:
        raw_results = ProteinSolver.analyze_mutation(f"A_{mutation}", data)
    except Exception as error:
        # Report the failure and carry on; callers must handle the None.
        print(f"{error!r}")
        return None
    return {f"proteinsolver_core_{key}": value for key, value in raw_results.items()}

In [22]:
# ProteinSolver features for every mutation, cached on disk as Parquet.
proteinsolver_results_file = NOTEBOOK_DIR.joinpath("proteinsolver-results.parquet")

if proteinsolver_results_file.is_file():
    # Reuse the cached table instead of re-running the model.
    proteinsolver_results_df = pq.read_table(proteinsolver_results_file).to_pandas()
else:
    # Imported here so the plugin is only required when the cache is missing.
    from elaspic2.plugins.proteinsolver import ProteinSolver

    ProteinSolver.load_model(device=device)

    data = ProteinSolver.build(structure_file, sequence, None, remove_hetatms=False)

    result_list = []
    for mut in tqdm(results_to_fill_df["mut"], leave=False):
        result = worker(mut, data)
        # worker() returns None when the analysis fails, and `dict | None`
        # raises TypeError. Fall back to an empty dict so the mutation keeps
        # its row (feature columns become missing values).
        result_list.append({"mut": mut} | (result or {}))
    proteinsolver_results_df = pd.DataFrame(result_list)

    pq.write_table(
        pa.Table.from_pandas(proteinsolver_results_df, preserve_index=False),
        proteinsolver_results_file,
    )

In [23]:
# Preview the first rows and print the row count of the ProteinSolver table.
display(proteinsolver_results_df.head(2))
print(len(proteinsolver_results_df))

Unnamed: 0,mut,proteinsolver_core_score_wt,proteinsolver_core_score_mut,proteinsolver_core_features_residue_wt,proteinsolver_core_features_protein_wt,proteinsolver_core_features_residue_mut,proteinsolver_core_features_protein_mut
0,A112R,0.810171,0.000203,"[-4.356575965881348, 0.01211337000131607, -0.5...","[-1.539688229560852, 0.31123825907707214, 0.07...","[-0.23516525328159332, -0.1763198971748352, 0....","[-1.528090000152588, 0.3134782314300537, 0.076..."
1,A112N,0.810172,0.000319,"[-4.356573581695557, 0.01211315393447876, -0.5...","[-1.539688229560852, 0.31123828887939453, 0.07...","[2.0055434703826904, 0.43138548731803894, -0.7...","[-1.5259792804718018, 0.3278365731239319, 0.07..."


6239


## Run ProtBert

In [24]:
def worker(mutation, data):
    """Analyse one mutation with ProtBert on a best-effort basis.

    The mutation is prefixed with ``A_`` (presumably the chain identifier)
    before being passed to ``ProtBert.analyze_mutation``. This definition
    replaces any earlier ``worker`` bound in the notebook namespace.

    Args:
        mutation: Mutation string such as ``"A112R"``.
        data: Object returned by ``ProtBert.build``.

    Returns:
        The analysis results with every key prefixed by ``protbert_core_``,
        or ``None`` if the analysis raised.
    """
    try:
        raw_results = ProtBert.analyze_mutation(f"A_{mutation}", data)
    except Exception as error:
        # Report the failure and carry on; callers must handle the None.
        print(f"{error!r}")
        return None
    return {f"protbert_core_{key}": value for key, value in raw_results.items()}

In [25]:
# ProtBert features for every mutation, cached on disk as Parquet.
protbert_results_file = NOTEBOOK_DIR.joinpath("protbert-results.parquet")

if protbert_results_file.is_file():
    # Reuse the cached table instead of re-running the model.
    protbert_results_df = pq.read_table(protbert_results_file).to_pandas()
else:
    # Imported here so the plugin is only required when the cache is missing.
    from elaspic2.plugins.protbert import ProtBert

    ProtBert.load_model(device=device)
    data = ProtBert.build(sequence, ligand_sequence=None)

    result_list = []
    for mut in tqdm(results_to_fill_df["mut"], leave=False):
        result = worker(mut, data)
        # worker() returns None when the analysis fails, and `dict | None`
        # raises TypeError. Fall back to an empty dict so the mutation keeps
        # its row (feature columns become missing values).
        result_list.append({"mut": mut} | (result or {}))
    protbert_results_df = pd.DataFrame(result_list)

    pq.write_table(
        pa.Table.from_pandas(protbert_results_df, preserve_index=False), protbert_results_file
    )

In [26]:
# Preview the first rows and print the row count of the ProtBert table.
display(protbert_results_df.head(2))
print(len(protbert_results_df))

Unnamed: 0,mut,protbert_core_score_wt,protbert_core_score_mut,protbert_core_features_residue_wt,protbert_core_features_protein_wt,protbert_core_features_residue_mut,protbert_core_features_protein_mut
0,A112R,0.936817,6e-06,"[0.0668850839138031, -0.07388728111982346, -0....","[-0.00546208256855607, -0.013093064539134502, ...","[0.08266019821166992, -0.07282465696334839, -0...","[-0.0036239889450371265, -0.013208648189902306..."
1,A112N,0.936817,2.6e-05,"[0.0668850839138031, -0.07388728111982346, -0....","[-0.00546208256855607, -0.013093064539134502, ...","[0.06758468598127365, -0.06582294404506683, -0...","[-0.003480394370853901, -0.013589110225439072,..."


6239


## Run MSA