## Summary

---

## Imports

In [1]:
# Environment variables for the GPU libraries, set in the first cell so they
# are in place before the import cell below runs.
# NOTE(review): a memory fraction of 4.0 is greater than 1; presumably this is
# intentional in combination with TF_FORCE_UNIFIED_MEMORY — confirm.
%env CUDA_VISIBLE_DEVICES=2
%env TF_FORCE_UNIFIED_MEMORY=1
%env XLA_PYTHON_CLIENT_MEM_FRACTION=4.0

env: CUDA_VISIBLE_DEVICES=2
env: TF_FORCE_UNIFIED_MEMORY=1
env: XLA_PYTHON_CLIENT_MEM_FRACTION=4.0


In [2]:
import concurrent.futures
import contextlib
import itertools
import os
import tempfile
import urllib.request
from datetime import datetime
from pathlib import Path

import dotenv
import elaspic2 as el2
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
import seaborn as sns
import torch
from elaspic2.plugins.rosetta_ddg import RosettaDDG
from kmbio import PDB
from kmtools import structure_tools
from tqdm.auto import tqdm



## Parameters

In [3]:
# Directory holding every downloaded input and cached result of this notebook.
# The relative name is resolved against the current working directory.
# NOTE(review): the saved output of this cell shows a different directory name
# (".../30_cagi6_cam"), so the stored output looks stale — re-run to confirm.
NOTEBOOK_DIR = Path("40_cagi6_cam_submission").resolve()
NOTEBOOK_DIR.mkdir(exist_ok=True)

NOTEBOOK_DIR

PosixPath('/scratch/strokach/workspace/elaspic2-cagi6/notebooks/30_cagi6_cam')

In [4]:
# UniProt accession of the protein analysed here; it names the local FASTA,
# structure and alignment files and is interpolated into the download URLs.
UNIPROT_ID = "P0DP23"

UNIPROT_ID

'P0DP23'

In [5]:
# Torch device passed to ProteinSolver, ProtBert and ELASPIC2 below:
# the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

device

device(type='cuda')

In [6]:
# Timestamp of this run, truncated to the hour.
# NOTE(review): `version` is not referenced again in the visible part of the
# notebook — confirm whether it is still needed.
version = datetime.now().isoformat(timespec="hours")

version

'2021-09-28T12'

## Download data

In [7]:
def download(url, filename):
    """Download `url` to `filename` without ever leaving a partial file behind.

    The cells below treat an existing destination file as a valid cache, so a
    download that is interrupted half-way must not leave a truncated file at
    `filename`. The data is therefore written to a sibling ``<name>.part``
    file first and only renamed to `filename` once the transfer finished.

    Args:
        url: Address to fetch (anything `urllib.request.urlretrieve` accepts).
        filename: Destination path (`str` or `pathlib.Path`).

    Raises:
        Whatever `urllib.request.urlretrieve` raises; the destination is left
        untouched in that case.
    """
    destination = Path(filename)
    partial = destination.with_name(destination.name + ".part")
    try:
        urllib.request.urlretrieve(url, partial)
        # Rename within the same directory so the destination only ever
        # appears as a complete file.
        partial.replace(destination)
    finally:
        # After a successful rename this is a no-op; after a failure it
        # removes the incomplete download.
        partial.unlink(missing_ok=True)

In [8]:
def load_sequence(sequence_file):
    """Return the residues stored in a FASTA file as one continuous string.

    Args:
        sequence_file: Path-like object exposing ``.open()``.

    Returns:
        All lines that do not start with ``>``, stripped of surrounding
        whitespace and concatenated in file order.
    """
    with sequence_file.open("rt") as fin:
        lines = fin.read().split("\n")
    # Lines starting with ">" are FASTA headers and carry no residues.
    return "".join(line.strip() for line in lines if not line.startswith(">"))

In [9]:
# Fetch the challenge files from Synapse only when the prediction template is
# not already present in NOTEBOOK_DIR.
if not NOTEBOOK_DIR.joinpath("cam_predictions.tsv").is_file():
    # `subprocess` is not imported anywhere else in the notebook; without this
    # import the `subprocess.run` call below fails with a NameError.
    import subprocess

    import synapseclient
    import synapseutils

    # Credentials are read from the environment (populated from ../.env).
    dotenv.load_dotenv("../.env")
    syn = synapseclient.Synapse()
    syn.login(os.environ["SYNAPSE_USERNAME"], os.environ["SYNAPSE_PASSWORD"])
    _ = synapseutils.syncFromSynapse(syn, "syn25835637", path=NOTEBOOK_DIR)
    # Unpack the training-set archive in place, overwriting existing files (-o).
    _ = subprocess.run(
        ["unzip", "-o", "CAGI6_CaM_challenge_Training_Set.zip"], cwd=NOTEBOOK_DIR, check=True
    )

In [10]:
# Local copy of the UniProt FASTA file; downloaded only when it is missing.
sequence_file = NOTEBOOK_DIR.joinpath(f"{UNIPROT_ID}.fasta")

if not sequence_file.is_file():
    download(f"https://www.uniprot.org/uniprot/{UNIPROT_ID}.fasta", sequence_file)

# Protein sequence as a single string (header lines removed by load_sequence).
sequence = load_sequence(sequence_file)

In [11]:
# Local copy of the AlphaFold model in PDB format; downloaded only when missing.
structure_file = NOTEBOOK_DIR.joinpath(f"{UNIPROT_ID}.pdb")

if not structure_file.is_file():
    download(f"https://alphafold.ebi.ac.uk/files/AF-{UNIPROT_ID}-F1-model_v1.pdb", structure_file)

# Full text of the PDB file, kept in memory for the consistency check below.
structure_blob = structure_file.read_text()

In [12]:
# Cached multiple sequence alignment for `sequence`.
# NOTE(review): despite the ".a3m.gz" suffix the file is written and read as a
# Parquet table with a single "alignment" column.
alignment_file = NOTEBOOK_DIR.joinpath(f"{UNIPROT_ID}.a3m.gz")

if not alignment_file.is_file():
    from elaspic2.plugins.alphafold import mmseqs2

    dotenv.load_dotenv("../.env")
    with mmseqs2.api_gateway(mmseqs2.MMSEQS2_HOST_URL) as gateway:
        alignment = mmseqs2.run_mmseqs2(sequence, gateway=gateway)
        # The second entry of the alignment must be the query sequence itself.
        assert alignment[1] == f"{sequence}\n"
        alignment_df = pd.DataFrame({"alignment": alignment})
        pq.write_table(pa.Table.from_pandas(alignment_df, preserve_index=False), alignment_file)

# Always reload from disk so fresh and cached runs use the same list of lines.
alignment = pq.read_table(alignment_file).to_pandas()["alignment"].values.tolist()

## Load data

In [13]:
def mutation_matches_sequence(mutation, sequence):
    """Check that the wild-type residue of `mutation` agrees with `sequence`.

    Args:
        mutation: Mutation written as ``<wt><position><mut>`` (e.g. ``"N54I"``),
            with a 1-based position.
        sequence: Protein sequence the mutation refers to.

    Returns:
        True when `sequence` has residue ``<wt>`` at ``<position>``; False when
        the residue differs or the position lies outside the sequence.

    Raises:
        ValueError: If the position part of `mutation` is not an integer.
    """
    wt, pos = mutation[0], int(mutation[1:-1])
    # Positions are 1-based. Without this guard, position 0 or a negative
    # position would silently wrap around to the end of the sequence through
    # negative indexing, and a position past the end would raise IndexError.
    if not 1 <= pos <= len(sequence):
        return False
    return sequence[pos - 1] == wt

In [14]:
def sequence_matches_structure(sequence, structure_blob):
    """Check that chain "A" of a PDB structure has exactly the given sequence.

    Args:
        sequence: Expected sequence.
        structure_blob: Text content of a PDB file.

    Returns:
        Result of comparing `sequence` with the sequence that
        `structure_tools.get_chain_sequence` reports for chain "A" of the
        first model.
    """
    # PDB.load is given a file name, so the text is written to a temporary
    # file that lives only for the duration of the load.
    with tempfile.NamedTemporaryFile(suffix=".pdb") as tmp_file:
        Path(tmp_file.name).write_text(structure_blob)
        structure = PDB.load(tmp_file.name)

    chain_a = structure[0]["A"]
    chain_sequence = structure_tools.get_chain_sequence(
        chain_a, if_unknown="replace", unknown_residue_marker=""
    )
    return sequence == chain_sequence

In [15]:
# Submission template shipped with the challenge data; its columns define the
# layout of the files written at the end of the notebook.
result_template_df = pd.read_csv(NOTEBOOK_DIR.joinpath("cam_predictions.tsv"), sep="\t")

result_template_df.head(2)

Unnamed: 0,CaM-variant,Tm,Standard-Deviation-Tm,%Unfold,Standard-Deviation-%Unfold,Stabilizing-vs-Destabilizing,Comments
0,p.N54I,*,*,*,*,*,*
1,p.F90L,*,*,*,*,*,*


In [16]:
# Work on a copy so the template stays untouched.
result_df = result_template_df.copy()
# Drop the first two characters of the variant name (the "p." prefix seen in
# the template preview), leaving e.g. "N54I".
result_df["mut"] = result_df["CaM-variant"].str[2:]

In [17]:
# Every mutation's wild-type residue must agree with the downloaded sequence.
assert result_df["mut"].map(lambda mut: mutation_matches_sequence(mut, sequence)).all()

In [18]:
# The downloaded structure must cover exactly the downloaded sequence.
assert sequence_matches_structure(sequence, structure_blob)

In [19]:
# Re-check on the (possibly cached) alignment that its second entry is the
# query sequence followed by a newline.
assert alignment[1] == f"{sequence}\n"

## Run Rosetta

In [20]:
# Build the Rosetta input once; the same object is passed to every
# per-mutation analysis below.
# NOTE(review): `interface=0` presumably selects the single-chain
# (non-interface) mode — confirm against the elaspic2 RosettaDDG plugin.
rosetta_ddg_data = RosettaDDG.build(
    structure_file,
    protocol="cartesian_ddg",
    energy_function="beta_nov16_cart",
    interface=0,
)

In [21]:
def rosetta_ddg_worker(mut, data):
    """Run the Rosetta analysis for one mutation on chain "A".

    Args:
        mut: Mutation such as ``"N54I"``.
        data: Object returned by ``RosettaDDG.build``.

    Returns:
        Dict with the mutation under ``"mut"`` followed by every value that
        ``RosettaDDG.analyze_mutation`` returned, each key prefixed with
        ``rosetta_``.
    """
    raw = RosettaDDG.analyze_mutation(f"A_{mut}", data)
    row = {"mut": mut}
    row.update((f"rosetta_{key}", value) for key, value in raw.items())
    return row

In [22]:
# Rosetta results are cached on disk; the computation only runs when the cache
# file is missing.
rosetta_results_file = NOTEBOOK_DIR.joinpath("rosetta-results.parquet")

if not rosetta_results_file.is_file():
    # One thread per mutation; all of them share the prepared Rosetta input.
    with concurrent.futures.ThreadPoolExecutor(len(result_df)) as pool:
        pending = pool.map(
            rosetta_ddg_worker,
            result_df["mut"].values.tolist(),
            itertools.repeat(rosetta_ddg_data),
        )
        rosetta_results = list(tqdm(pending, total=len(result_df)))
    rosetta_results_df = pd.DataFrame(rosetta_results)
    pq.write_table(
        pa.Table.from_pandas(rosetta_results_df, preserve_index=False), rosetta_results_file
    )
else:
    rosetta_results_df = pq.read_table(rosetta_results_file).to_pandas()

## Run AlphaFold

In [23]:
def worker(mutation, data):
    """Run the AlphaFold analysis for one mutation on chain "A".

    The name `worker` is defined again by later cells of this notebook, so
    this definition is only the active one until the next of those cells runs.

    Args:
        mutation: Mutation such as ``"N54I"``.
        data: Object returned by ``AlphaFold.build``.

    Returns:
        Dict of the analysis values with every key prefixed by
        ``alphafold_core_``, or None when the analysis raised (the error is
        printed, not re-raised).
    """
    try:
        raw = AlphaFold.analyze_mutation(f"A_{mutation}", data)
    except Exception as error:
        print(f"{error!r}")
        return None
    return {f"alphafold_core_{key}": value for key, value in raw.items()}

In [24]:
# AlphaFold results are cached on disk; the model is only loaded and run when
# the cache file is missing.
alphafold_results_file = NOTEBOOK_DIR.joinpath("alphafold-results.parquet")

if alphafold_results_file.is_file():
    alphafold_results_df = pq.read_table(alphafold_results_file).to_pandas()
else:
    # These imports live in this branch, so they only run on a cache miss.
    from elaspic2.plugins.alphafold import AlphaFold, AlphaFoldAnalyzeError, AlphaFoldBuildError
    from jax.lib import xla_bridge

    jax_device = xla_bridge.get_backend().platform
    print(f"Device: {jax_device}")

    AlphaFold.load_model(device=jax_device)
    data = AlphaFold.build(sequence, ligand_sequence=None, msa=alignment)

    result_list = []
    for mut in tqdm(result_df["mut"], leave=False):
        result = worker(mut, data)
        # worker() returns None when the analysis failed; `dict | None` raises
        # TypeError, so fall back to an empty dict and keep a row that holds
        # only the mutation (its feature columns become missing values).
        result_list.append({"mut": mut} | (result or {}))
    alphafold_results_df = pd.DataFrame(result_list)

    pq.write_table(
        pa.Table.from_pandas(alphafold_results_df, preserve_index=False), alphafold_results_file
    )

In [25]:
# Preview the first two rows and report the number of result rows.
display(alphafold_results_df.head(2))
print(len(alphafold_results_df))

Unnamed: 0,mut,alphafold_core_scores_residue_plddt_wt,alphafold_core_scores_protein_plddt_wt,alphafold_core_scores_protein_max_predicted_aligned_error_wt,alphafold_core_scores_proten_ptm_wt,alphafold_core_features_residue_experimentally_resolved_wt,alphafold_core_features_residue_predicted_lddt_wt,alphafold_core_features_residue_msa_first_row_wt,alphafold_core_features_residue_single_wt,alphafold_core_features_residue_structure_module_wt,...,alphafold_core_features_residue_experimentally_resolved_mut,alphafold_core_features_residue_predicted_lddt_mut,alphafold_core_features_residue_msa_first_row_mut,alphafold_core_features_residue_single_mut,alphafold_core_features_residue_structure_module_mut,alphafold_core_features_protein_experimentally_resolved_mut,alphafold_core_features_protein_predicted_lddt_mut,alphafold_core_features_protein_msa_first_row_mut,alphafold_core_features_protein_single_mut,alphafold_core_features_protein_structure_module_mut
0,N54I,72.424267,64.687376,31.75,0.416441,"[3.6272318, 3.6706228, 4.516265, 3.5398533, 4....","[-7.5897775, -8.689957, -7.7136374, -7.0810175...","[1.8796259, 10.791776, 0.033692777, 12.273933,...","[-8.44626, 15.287777, -11.845451, 1.3104451, -...","[-0.004927762, 0.00014060736, -0.0058904495, 0...",...,"[3.2068305, 3.2458231, 3.9859018, 3.062663, 3....","[-7.8205338, -8.494905, -7.730364, -7.1157904,...","[1.0629023, 9.450185, -1.1117198, 10.445597, -...","[1.0913221, 9.465578, -11.463981, -8.050004, -...","[-0.011385553, 0.013329156, -0.006098194, 0.00...","[1.4339311, 1.4835172, 1.6852022, 1.3888673, 1...","[-7.1293554, -8.000177, -7.0753617, -6.3793783...","[-0.6908954, 3.7915826, 2.4891148, 5.101147, -...","[14.626985, 8.639777, -2.4109125, 2.0749578, 0...","[0.0022276677, 0.0090920245, -0.0059451843, 0...."
1,F90L,63.138818,64.687376,31.75,0.416441,"[2.1501245, 2.2867854, 2.5959187, 2.258326, 2....","[-6.581604, -7.862013, -6.984727, -6.309676, -...","[5.532667, -6.649637, 14.638169, 5.3919473, -2...","[39.239716, -30.105417, 6.92773, -16.221613, 4...","[-0.0072678104, 0.026987039, -0.006245898, -0....",...,"[2.0710125, 2.218312, 2.5010533, 2.141931, 2.4...","[-7.2874045, -8.337455, -7.369378, -6.665057, ...","[4.8071737, -5.465674, 14.286465, 4.7100663, -...","[34.874672, -27.71313, 5.065748, -20.697674, 3...","[-0.0033393428, 0.023972645, -0.0062532555, -0...","[1.4434994, 1.4932677, 1.6971351, 1.3984808, 1...","[-7.170294, -8.036215, -7.0772285, -6.3579907,...","[-0.70935285, 3.8102627, 2.4673674, 5.082006, ...","[14.407975, 8.669145, -2.267529, 2.1149678, 0....","[0.0020835502, 0.008595544, -0.005956497, 0.00..."


16


## Run ProteinSolver

In [26]:
def worker(mutation, data):
    """Run the ProteinSolver analysis for one mutation on chain "A".

    This replaces the `worker` defined by an earlier cell of the notebook and
    is itself replaced by a later one.

    Args:
        mutation: Mutation such as ``"N54I"``.
        data: Object returned by ``ProteinSolver.build``.

    Returns:
        Dict of the analysis values with every key prefixed by
        ``proteinsolver_core_``, or None when the analysis raised (the error
        is printed, not re-raised).
    """
    try:
        raw = ProteinSolver.analyze_mutation(f"A_{mutation}", data)
    except Exception as error:
        print(f"{error!r}")
        return None
    return {f"proteinsolver_core_{key}": value for key, value in raw.items()}

In [27]:
# ProteinSolver results are cached on disk; the model is only loaded and run
# when the cache file is missing.
proteinsolver_results_file = NOTEBOOK_DIR.joinpath("proteinsolver-results.parquet")

if proteinsolver_results_file.is_file():
    proteinsolver_results_df = pq.read_table(proteinsolver_results_file).to_pandas()
else:
    # This import lives in this branch, so it only runs on a cache miss.
    from elaspic2.plugins.proteinsolver import ProteinSolver

    ProteinSolver.load_model(device=device)

    data = ProteinSolver.build(structure_file, sequence, None, remove_hetatms=False)

    result_list = []
    for mut in tqdm(result_df["mut"], leave=False):
        result = worker(mut, data)
        # worker() returns None when the analysis failed; `dict | None` raises
        # TypeError, so fall back to an empty dict and keep a row that holds
        # only the mutation (its feature columns become missing values).
        result_list.append({"mut": mut} | (result or {}))
    proteinsolver_results_df = pd.DataFrame(result_list)

    pq.write_table(
        pa.Table.from_pandas(proteinsolver_results_df, preserve_index=False),
        proteinsolver_results_file,
    )

In [28]:
# Preview the first two rows and report the number of result rows.
display(proteinsolver_results_df.head(2))
print(len(proteinsolver_results_df))

Unnamed: 0,mut,proteinsolver_core_score_wt,proteinsolver_core_score_mut,proteinsolver_core_features_residue_wt,proteinsolver_core_features_protein_wt,proteinsolver_core_features_residue_mut,proteinsolver_core_features_protein_mut
0,N54I,0.210572,0.009616,"[1.6711971759796143, 0.8220645189285278, -0.53...","[-1.2827523946762085, 0.29683440923690796, -0....","[-3.301334857940674, 0.29959312081336975, 0.00...","[-1.314837098121643, 0.28501763939857483, -0.0..."
1,F90L,0.553403,0.001049,"[-0.9363932609558105, 0.4986618161201477, -0.4...","[-1.282752275466919, 0.2968343496322632, -0.05...","[-3.6483242511749268, 0.03767576813697815, -0....","[-1.3031216859817505, 0.2846250534057617, -0.0..."


16


## Run ProtBert

In [29]:
def worker(mutation, data):
    """Run the ProtBert analysis for one mutation on chain "A".

    This replaces the `worker` functions defined by earlier cells of the
    notebook.

    Args:
        mutation: Mutation such as ``"N54I"``.
        data: Object returned by ``ProtBert.build``.

    Returns:
        Dict of the analysis values with every key prefixed by
        ``protbert_core_``, or None when the analysis raised (the error is
        printed, not re-raised).
    """
    try:
        raw = ProtBert.analyze_mutation(f"A_{mutation}", data)
    except Exception as error:
        print(f"{error!r}")
        return None
    return {f"protbert_core_{key}": value for key, value in raw.items()}

In [30]:
# ProtBert results are cached on disk; the model is only loaded and run when
# the cache file is missing.
protbert_results_file = NOTEBOOK_DIR.joinpath("protbert-results.parquet")

if protbert_results_file.is_file():
    protbert_results_df = pq.read_table(protbert_results_file).to_pandas()
else:
    # This import lives in this branch, so it only runs on a cache miss.
    from elaspic2.plugins.protbert import ProtBert

    ProtBert.load_model(device=device)
    data = ProtBert.build(sequence, ligand_sequence=None)

    result_list = []
    for mut in tqdm(result_df["mut"], leave=False):
        result = worker(mut, data)
        # worker() returns None when the analysis failed; `dict | None` raises
        # TypeError, so fall back to an empty dict and keep a row that holds
        # only the mutation (its feature columns become missing values).
        result_list.append({"mut": mut} | (result or {}))
    protbert_results_df = pd.DataFrame(result_list)

    pq.write_table(
        pa.Table.from_pandas(protbert_results_df, preserve_index=False), protbert_results_file
    )

In [31]:
# Preview the first two rows and report the number of result rows.
display(protbert_results_df.head(2))
print(len(protbert_results_df))

Unnamed: 0,mut,protbert_core_score_wt,protbert_core_score_mut,protbert_core_features_residue_wt,protbert_core_features_protein_wt,protbert_core_features_residue_mut,protbert_core_features_protein_mut
0,N54I,0.995928,0.000116,"[0.004911681637167931, -0.10091305524110794, 0...","[0.014456912875175476, -0.07412253320217133, 0...","[0.014029847458004951, -0.10067387670278549, 0...","[0.01633981615304947, -0.07155150175094604, 0...."
1,F90L,0.998659,0.000687,"[0.12213721126317978, -0.16561347246170044, 0....","[0.014456912875175476, -0.07412253320217133, 0...","[0.12696442008018494, -0.16875752806663513, 0....","[0.015587416477501392, -0.0756407380104065, 0...."


16


## Run MSA

## Run `ELASPIC2`

### Initialize the `ELASPIC2` model



In [32]:
# Instantiate the ELASPIC2 model on the torch device selected above.
model = el2.ELASPIC2(device=device)

Some weights of the model checkpoint at /scratch/strokach/workspace/elaspic2/src/elaspic2/plugins/protbert/data/prot_bert_bfd were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Make predictions

In [33]:
# Build the per-protein inputs once; they are reused for every mutation.
# NOTE(review): `remove_hetatms=True` here, while the standalone ProteinSolver
# run earlier in the notebook passes `remove_hetatms=False` — confirm the
# difference is intended.
protein_features = model.build(
    structure_file=structure_file,
    protein_sequence=sequence,
    ligand_sequence=None,
    remove_hetatms=True,
)

protein_features

ELASPIC2Data(is_interface=False, protbert_data=ProtBertData(sequence='MADQLTEEQIAEFKEAFSLFDKDGDGTITTKELGTVMRSLGQNPTEAELQDMINEVDADGNGTIDFPEFLTMMARKMKDTDSEEEIREAFRVFDKDGNGYISAAELRHVMTNLGEKLTDEEVDEMIREADIDGDGQVNYEEFVQMMTAK'), proteinsolver_data=Data(edge_attr=[5330, 2], edge_index=[2, 5330], x=[149]))

In [34]:
# One feature dict per mutation, in the row order of result_df.
mutation_features = [
    model.analyze_mutation(mut, protein_features)
    for mut in tqdm(result_df["mut"], total=len(result_df))
]

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=16.0), HTML(value='')))




In [35]:
# In all cases, a higher score means less stable (same sign convention as ΔΔG).
# The ProtBert and ProteinSolver scores are the wild-type score minus the
# mutant score taken from the ELASPIC2 mutation features.
# NOTE(review): the saved output of this cell records
# `KeyError: 'pca-proteinsolver_core_features_protein_wt-core'`; presumably it
# is raised inside `model.predict_mutation_effect` — confirm and resolve before
# relying on the "el2_score" column used by the cells below.
result_df["protbert_score"] = [
    f["protbert_core_score_wt"] - f["protbert_core_score_mut"] for f in mutation_features
]
result_df["proteinsolver_score"] = [
    f["proteinsolver_core_score_wt"] - f["proteinsolver_core_score_mut"] for f in mutation_features
]
result_df["el2_score"] = model.predict_mutation_effect(mutation_features).tolist()

KeyError: 'pca-proteinsolver_core_features_protein_wt-core'

## Combine

In [None]:
# Expose Rosetta's "rosetta_dg_change" under the "<metric>_score" naming that
# the submission loop at the end of the notebook looks up.
rosetta_results_df["rosetta_dg_score"] = rosetta_results_df["rosetta_dg_change"]

In [None]:
# Remove a column left over from a previous run of this cell, so that merging
# again does not produce suffixed duplicates.
if "rosetta_dg_score" in result_df.columns:
    del result_df["rosetta_dg_score"]

# Left join: every row of result_df is kept, with the Rosetta score attached
# where one exists for the mutation.
rosetta_scores = rosetta_results_df[["mut", "rosetta_dg_score"]]
result_df = result_df.merge(rosetta_scores, on="mut", how="left")

## Visualize

In [None]:
# Pairwise scatter plots of the four predictors' scores.
sns.set(style="ticks", color_codes=True)
g = sns.pairplot(
    result_df[["protbert_score", "proteinsolver_score", "el2_score", "rosetta_dg_score"]]
)

## Write results

In [None]:
# Directory that receives the description file and the per-model TSV files.
output_dir = NOTEBOOK_DIR.joinpath("submission")
output_dir.mkdir(exist_ok=True)

output_dir

In [None]:
%%file {output_dir}/strokach_desc.md
# Submission for CAGI6—CAM challenge

## Overview

- `strokach_modelnumber_1.tsv` → Predictions made using ELASPIC2 [1].
- `strokach_modelnumber_2.tsv` → Predictions made using ProteinSolver [2].
- `strokach_modelnumber_3.tsv` → Predictions made using ProtBert [3].
- `strokach_modelnumber_4.tsv` → Predictions made using Rosetta's cartesian_ddg protocol.

## References

- [1] Strokach, A., Lu, T.Y., Kim, P.M., 2021. ELASPIC2 (EL2): Combining Contextualized Language Models and Graph Neural Networks to Predict Effects of Mutations. Journal of Molecular Biology, Computation Resources for Molecular Biology 433, 166810. https://doi.org/10.1016/j.jmb.2021.166810
- [2] Strokach, A., Becerra, D., Corbi-Verge, C., Perez-Riba, A., Kim, P.M., 2020. Fast and Flexible Protein Design Using Deep Graph Neural Networks. Cell Systems. https://doi.org/10.1016/j.cels.2020.08.016
- [3] Elnaggar, A., Heinzinger, M., Dallago, C., Rehawi, G., Wang, Y., Jones, L., Gibbs, T., Feher, T., Angerer, C., Steinegger, M., Bhowmik, D., Rost, B., 2020. ProtTrans: Towards Cracking the Language of Life’s Code Through Self-Supervised Deep Learning and High Performance Computing. bioRxiv 2020.07.12.199554. https://doi.org/10.1101/2020.07.12.199554

In [None]:
# Reference values that anchor the mapping from scores to the submission scale.
# NOTE(review): presumably the wild-type Tm and %Unfold, with 20 and 98 as the
# extremes assigned to the most destabilising mutation — confirm against the
# challenge description.
starting_tm = 59.2
starting_unfold = 71.8

# One submission table per predictor; file numbers follow the order of the list.
result_dfs = {}
for i, metric in enumerate(["el2", "proteinsolver", "protbert", "rosetta_dg"]):
    result_dfs[metric] = result_df.copy()
    score_values = result_dfs[metric][f"{metric}_score"].values

    max_ddg = result_dfs[metric][f"{metric}_score"].max()
    # Scores below zero are not allowed to stretch the scale: the lower anchor
    # is never less than 0.
    min_ddg = max(0, result_dfs[metric][f"{metric}_score"].min())

    # A zero, negative or NaN range would silently produce inf/NaN values or an
    # inverted scale in the submission file, so fail loudly instead.
    score_range = max_ddg - min_ddg
    if not score_range > 0:
        raise ValueError(
            f"Cannot rescale '{metric}' scores: max={max_ddg!r}, min={min_ddg!r}."
        )

    tm_scale = (starting_tm - 20) / score_range
    unfold_scale = (98 - starting_unfold) / score_range

    # Higher score -> lower Tm and higher %Unfold.
    result_dfs[metric]["Tm"] = starting_tm - tm_scale * (score_values - min_ddg)
    result_dfs[metric]["Standard-Deviation-Tm"] = 10.0
    result_dfs[metric]["%Unfold"] = starting_unfold + unfold_scale * (score_values - min_ddg)
    result_dfs[metric]["Standard-Deviation-%Unfold"] = 10.0
    result_dfs[metric]["Stabilizing-vs-Destabilizing"] = [
        ("Destabilizing" if ddg > 0 else "Stabilizing") for ddg in score_values
    ]
    # The template column is named "Comments"; writing to "Comment" would only
    # add a stray column that never reaches the output file.
    result_dfs[metric]["Comments"] = "*"

    print(metric, min_ddg)
    display(result_dfs[metric].head())

    # Only the template's columns, in the template's order, are written out.
    output_file = output_dir.joinpath(f"strokach_modelnumber_{i + 1}.tsv")
    result_dfs[metric][result_template_df.columns].to_csv(output_file, sep="\t", index=False)
#     !python {NOTEBOOK_DIR}/validation.py {output_file}

In [None]:
# For every predictor, show the distribution of the two predicted quantities
# side by side.
for metric, df in result_dfs.items():
    print(metric)

    fig, axs = plt.subplots(1, 2, figsize=(10, 3))

    for ax, (column, label) in zip(axs, [("Tm", "Tm"), ("%Unfold", "%unfold")]):
        ax.hist(df[column])
        ax.set_xlabel(label)