# Imports

In [None]:
%load_ext autoreload
%autoreload 2
import glob
import os
import random
import pathlib

import pandas as pd
import torch
from Bio.PDB.Polypeptide import index_to_one
from collections import OrderedDict
from torch.utils.data import DataLoader, Dataset

from cavity_model import (
    CavityModel,
    ResidueEnvironment,
    ResidueEnvironmentsDataset,
)
import helpers
from typing import Optional
from visualization import scatter_pred_vs_true, plot_validation_performance

%load_ext nb_black

# Cavity Model

Download and process Cavity Model data

In [None]:
# # Hack to find the conda activate path since bash scripts
# # don't necessarily work with the conda activate command
# conda_path = !which conda
# conda_path = list(conda_path)[0]
# conda_activate_path = pathlib.Path(conda_path).parent.parent / "bin" / "activate"
# if not conda_activate_path.is_file():
#     raise FileNotFoundError(
#         "Could not find your conda activate path needed for running bash scripts."
#     )

In [None]:
# # Run shell script that takes a .txt file with PDBIDs as input.
# !./get_and_parse_pdbs_for_cavity_model.sh $conda_activate_path data/pdbids_2336.txt

Global variables

In [None]:
# Main parameters
# Checkpoint to warm-start from; set to None to train the cavity model from scratch
WARM_START: Optional[str] = "cavity_models/model_epoch_02.pt"
DEVICE: str = "cuda"  # "cpu" or "cuda"
TRAIN_VAL_SPLIT: float = 0.9  # passed to helpers.train_val_split (presumably train fraction)
BATCH_SIZE: int = 100
SHUFFLE_PDBS: bool = True  # shuffle parsed PDB files before the train/val split
LEARNING_RATE: float = 2e-4
EPOCHS: int = 6
PATIENCE_CUTOFF: int = 1  # passed to helpers.train_loop (presumably early-stop patience)
EPS: float = 1e-9  # passed to the NLL helpers (presumably a log/division floor)

# Parameters for simulation stride
STRIDE_FRAGMENTS: int = 2
STRIDE_MD: int = 4

# Parameters specific to downstream model
BATCH_SIZE_DDG: int = 40
SHUFFLE_DDG: bool = True
LEARNING_RATE_DDG: float = 1e-3
EPOCHS_DDG: int = 200

# Plot title and color for each data key, kept side by side so the two
# mappings cannot drift apart; the derived dicts preserve this order.
_dataset_plot_styles = {
    "dms": ("DMS", "steelblue"),
    "guerois": ("Guerois", "firebrick"),
    "protein_g": ("Protein G", "forestgreen"),
    "symmetric_direct": ("Symmetric (Direct)", "olive"),
    "symmetric_inverse": ("Symmetric (Inverse)", "olive"),
}
dataset_name_mapping = {key: title for key, (title, _) in _dataset_plot_styles.items()}
dataset_color_mapping = {key: color for key, (_, color) in _dataset_plot_styles.items()}
del _dataset_plot_styles

Load Parsed PDBs and perform train/val split

In [None]:
if WARM_START is not None:
    # A pretrained checkpoint is used, so no training data needs to be loaded
    print(f"Warm start: {WARM_START}")
else:
    # Collect the parsed PDB coordinate files (sorted for a stable base order)
    # and optionally shuffle them before splitting into train/validation sets
    parsed_pdb_filenames = sorted(glob.glob("data/pdbs/parsed/*coord*"))
    if SHUFFLE_PDBS:
        random.shuffle(parsed_pdb_filenames)
    dataloader_train, dataset_train, dataloader_val, dataset_val = (
        helpers.train_val_split(
            parsed_pdb_filenames, TRAIN_VAL_SPLIT, DEVICE, BATCH_SIZE
        )
    )

Train the cavity model

In [None]:
if WARM_START is None:
    # Define model, loss and optimizer
    cavity_model_net = CavityModel(DEVICE).to(DEVICE)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cavity_model_net.parameters(), lr=LEARNING_RATE)

    # Create directory for model files. exist_ok=True replaces the racy
    # exists()/mkdir() pair and also creates any missing parent directories.
    # NOTE(review): models_dirpath is not passed to train_loop; presumably the
    # helper writes its checkpoints into this same directory — confirm.
    models_dirpath = "cavity_models/"
    os.makedirs(models_dirpath, exist_ok=True)

    # Train loop; the returned path is later loaded as the inference model
    best_model_path = helpers.train_loop(
        dataloader_train,
        dataloader_val,
        cavity_model_net,
        loss_function,
        optimizer,
        EPOCHS,
        PATIENCE_CUTOFF,
    )
else:
    print(f"Warm start: {WARM_START}")

# ddG Prediction

Parse PDBs for DMS, Guerois and Protein G data sets

In [None]:
# # Parse PDBs for which we have ddG data
# !./get_and_parse_pdbs_for_dowstream_task.sh $conda_activate_path

Build a lookup dict of residue environments for each data set

In [None]:
# Index the residue environments of each data set's parsed PDBs by a
# "<pdb_id><chain_id>_<residue_number><one-letter residue type>" key so they
# can be matched against the rows of the ddG data.
parsed_pdbs_wildcards = {
    "dms": "data/data_dms/pdbs_parsed/*coord*",
    "protein_g": "data/data_protein_g/pdbs_parsed/*coord*",
    "guerois": "data/data_guerois/pdbs_parsed/*coord*",
    "symmetric": "data/data_symmetric/pdbs_parsed/*coord*",
}

resenv_datasets_look_up = {}
for dataset_key, pdbs_wildcard in parsed_pdbs_wildcards.items():
    parsed_pdb_filenames = sorted(glob.glob(pdbs_wildcard))
    dataset = ResidueEnvironmentsDataset(parsed_pdb_filenames, transformer=None)
    # Later environments with an identical key overwrite earlier ones
    resenv_datasets_look_up[dataset_key] = {
        (
            f"{resenv.pdb_id}{resenv.chain_id}_{resenv.pdb_residue_number}"
            f"{index_to_one(resenv.restype_index)}"
        ): resenv
        for resenv in dataset
    }

Load ddG data to dataframe

In [None]:
# Measured ddG data per data set, keyed like dataset_name_mapping /
# dataset_color_mapping. A plain dict keeps insertion order (guaranteed since
# Python 3.7), which fixes the order in which later cells visit the data sets.
ddg_data_dict = {
    "dms": pd.read_csv("data/data_dms/ddgs_parsed.csv"),
    "protein_g": pd.read_csv("data/data_protein_g/ddgs_parsed.csv"),
    "guerois": pd.read_csv("data/data_guerois/ddgs_parsed.csv"),
    "symmetric_direct": pd.read_csv("data/data_symmetric/ddgs_parsed_direct.csv"),
    "symmetric_inverse": pd.read_csv("data/data_symmetric/ddgs_parsed_inverse.csv"),
}

Populate dataframes with WT ResidueEnvironment objects and WT and MT residue-type indices

In [None]:
# Attach the WT ResidueEnvironment objects and the WT/MT residue-type indices
# to each ddG dataframe, matched via resenv_datasets_look_up.
# NOTE(review): the return value is discarded, so presumably the dataframes in
# ddg_data_dict are modified in place — confirm in helpers.
helpers.populate_dfs_with_resenvs(ddg_data_dict, resenv_datasets_look_up)

Populate dataframes with predicted NLLs and isolated WT and MT predicted NLLs as well as NLFs

In [None]:
# Load the best-performing cavity model from previous training. Use the
# warm-start checkpoint if one is configured; otherwise use the path returned
# by the training cell. The `is not None` test matches the check the training
# cell uses to decide whether to train.
if WARM_START is not None:
    best_model_path = WARM_START

cavity_model_infer_net = CavityModel(DEVICE).to(DEVICE)
# map_location remaps stored tensors onto DEVICE. Without it, a checkpoint
# saved from a GPU cannot be loaded when DEVICE == "cpu".
cavity_model_infer_net.load_state_dict(
    torch.load(best_model_path, map_location=DEVICE)
)
cavity_model_infer_net.eval()

# Add predicted NLLs / NLFs to every ddG dataframe
helpers.populate_dfs_with_nlls_and_nlfs(
    ddg_data_dict, cavity_model_infer_net, DEVICE, BATCH_SIZE, EPS
)

## Results without downstream model

### PDB statistics

In [None]:
# Measured vs. predicted ddG (no downstream model) for every data set
for data_key, ddg_df in ddg_data_dict.items():
    fig, ax = scatter_pred_vs_true(
        ddg_df["ddg"],
        ddg_df["ddg_pred_no_ds"],
        title=dataset_name_mapping[data_key],
        color=dataset_color_mapping[data_key],
    )

### IDP statistics

In [None]:
# Measured vs. IDP-statistics ddG predictions (no downstream model)
for data_key, ddg_df in ddg_data_dict.items():
    fig, ax = scatter_pred_vs_true(
        ddg_df["ddg"],
        ddg_df["ddg_pred_idp_no_ds"],
        title=dataset_name_mapping[data_key],
        color=dataset_color_mapping[data_key],
    )

### Symmetric, use both structures

In [None]:
# Builds the "symmetric_both" entry of ddg_data_dict used below
helpers.get_predictions_both_structures(ddg_data_dict)

# Plot direct and inverse predictions that use both structures
symmetric_both_df = ddg_data_dict["symmetric_both"]
for true_column, pred_column, plot_title in (
    ("ddg_dir", "ddg_pred_no_ds_both_dir", "Both structure (Direct)"),
    ("ddg_inv", "ddg_pred_no_ds_both_inv", "Both structure (Inverse)"),
):
    fig, ax = scatter_pred_vs_true(
        symmetric_both_df[true_column],
        symmetric_both_df[pred_column],
        color="olive",
        title=plot_title,
    )

### Phaistos statistics

In [None]:
# Write the 11-amino-acid sequence fragments used as input for the MC
# (Phaistos) simulations to CSV. This is only needed once.
# NOTE(review): the call runs unconditionally every time this cell executes;
# consider guarding it if regenerating the files is expensive or undesired.
helpers.output_sequence_fragments_to_csv(ddg_data_dict)

#### Protein G

In [None]:
# # DROP SOME ROWS FOR QUICKER TESTING
# ddg_data_dict["protein_g"] = ddg_data_dict["protein_g"].iloc[0:10]

In [None]:
data_set = "protein_g"

# Infer center-residue probabilities for the wild type first, then the mutant,
# using every STRIDE_FRAGMENTS-th simulation frame (presumably)
for is_wt in (True, False):
    helpers.infer_probabilities_for_center_residues(
        ddg_data_dict,
        data_set,
        cavity_model_infer_net,
        DEVICE,
        EPS,
        is_wt=is_wt,
        stride=STRIDE_FRAGMENTS,
    )

# Derive ddG predictions that account for the unfolded state
helpers.add_ddg_preds_with_unfolded_state(ddg_data_dict, data_set)

In [None]:
# Measured ddG vs. predictions using WT, MT, and combined unfolded-state terms
for pred_column, title_suffix in (
    ("ddg_pred_wt_phaistos_no_ds", "WT unfolded"),
    ("ddg_pred_mt_phaistos_no_ds", "MT unfolded"),
    ("ddg_pred_wt_and_mt_phaistos_no_ds", "WT & MT"),
):
    fig, ax = scatter_pred_vs_true(
        ddg_data_dict[data_set]["ddg"],
        ddg_data_dict[data_set][pred_column],
        color=dataset_color_mapping[data_set],
        title=f"{dataset_name_mapping[data_set]}, {title_suffix}",
    )

#### Guerois (Phaistos statistics)

In [None]:
# # DROP SOME ROWS FOR QUICKER TESTING
# ddg_data_dict["guerois"] = ddg_data_dict["guerois"].iloc[0:10]

In [None]:
data_set = "guerois"

# Infer center-residue probabilities for the wild type first, then the mutant.
# NOTE(review): unlike the Protein G cell, no stride=STRIDE_FRAGMENTS is passed,
# so the helper's default stride is used — confirm this is intentional.
for is_wt in (True, False):
    helpers.infer_probabilities_for_center_residues(
        ddg_data_dict, data_set, cavity_model_infer_net, DEVICE, EPS, is_wt=is_wt
    )

# Derive ddG predictions that account for the unfolded state
helpers.add_ddg_preds_with_unfolded_state(ddg_data_dict, data_set)

In [None]:
# Measured ddG vs. predictions using WT, MT, and combined unfolded-state terms
for pred_column, title_suffix in (
    ("ddg_pred_wt_phaistos_no_ds", "WT unfolded"),
    ("ddg_pred_mt_phaistos_no_ds", "MT unfolded"),
    ("ddg_pred_wt_and_mt_phaistos_no_ds", "WT & MT"),
):
    fig, ax = scatter_pred_vs_true(
        ddg_data_dict[data_set]["ddg"],
        ddg_data_dict[data_set][pred_column],
        color=dataset_color_mapping[data_set],
        title=f"{dataset_name_mapping[data_set]}, {title_suffix}",
    )

#### DMS (Phaistos statistics)

In [None]:
# # DROP SOME ROWS FOR QUICKER TESTING
# ddg_data_dict["dms"] = ddg_data_dict["dms"].iloc[0:10]

In [None]:
data_set = "dms"

# Infer center-residue probabilities for the wild type first, then the mutant.
# NOTE(review): unlike the Protein G cell, no stride=STRIDE_FRAGMENTS is passed,
# so the helper's default stride is used — confirm this is intentional.
for is_wt in (True, False):
    helpers.infer_probabilities_for_center_residues(
        ddg_data_dict, data_set, cavity_model_infer_net, DEVICE, EPS, is_wt=is_wt
    )
# Derive ddG predictions that account for the unfolded state
helpers.add_ddg_preds_with_unfolded_state(ddg_data_dict, data_set)

In [None]:
# Measured ddG vs. predictions using WT, MT, and combined unfolded-state terms
for pred_column, title_suffix in (
    ("ddg_pred_wt_phaistos_no_ds", "WT unfolded"),
    ("ddg_pred_mt_phaistos_no_ds", "MT unfolded"),
    ("ddg_pred_wt_and_mt_phaistos_no_ds", "WT & MT"),
):
    fig, ax = scatter_pred_vs_true(
        ddg_data_dict[data_set]["ddg"],
        ddg_data_dict[data_set][pred_column],
        color=dataset_color_mapping[data_set],
        title=f"{dataset_name_mapping[data_set]}, {title_suffix}",
    )

##### Molecular dynamics

In [None]:
# Infer WT and MT NLLs from the MD simulations (stored in the wt_nll_md and
# mt_nll_md columns) for each data set that has MD data
for md_data_set in ("protein_g", "guerois", "dms"):
    helpers.infer_molecular_dynamics_nlls(
        ddg_data_dict,
        md_data_set,
        DEVICE,
        EPS,
        cavity_model_infer_net,
        stride=STRIDE_MD,
    )

In [None]:
# Calculate ddGs from the MD-based NLLs for each data set
for md_data_set in ("protein_g", "guerois", "dms"):
    helpers.add_ddg_preds_with_md_simulations(ddg_data_dict, md_data_set)

In [None]:
# Measured ddG vs. MD-based predictions combined with PDB or Phaistos
# statistics, for each of the three data sets with MD predictions
md_plot_columns = (
    ("ddg_pred_md_pdb_statistics_no_ds", "MD & PDB statistics"),
    ("ddg_pred_md_phaistos_mt_and_wt_statistics_no_ds", "MD & Phaistos statistics"),
)
for data_set in ("protein_g", "guerois", "dms"):
    for pred_column, title_suffix in md_plot_columns:
        fig, ax = scatter_pred_vs_true(
            ddg_data_dict[data_set]["ddg"],
            ddg_data_dict[data_set][pred_column],
            color=dataset_color_mapping[data_set],
            title=f"{dataset_name_mapping[data_set]}, {title_suffix}",
        )

# Downstream model

### Performance without augmentation (vanilla)

Define training dataloader and eval dataloaders

In [None]:
from cavity_model import DDGToTensor, DDGToTensorPhaistosAndMD

In [None]:
# Training loaders use the downstream batch size / shuffle settings; validation
# loaders are built from the same data with only the tensor conversion.
ddg_to_tensor = DDGToTensor
ddg_dataloaders_train_dict = helpers.get_ddg_training_dataloaders(
    ddg_data_dict, BATCH_SIZE_DDG, SHUFFLE_DDG, ddg_to_tensor
)
ddg_dataloaders_val_dict = helpers.get_ddg_validation_dataloaders(
    ddg_data_dict, ddg_to_tensor
)

Train the downstream model and report its performance on the data not used during training

In [None]:
# Train the downstream ddG model and evaluate it on the validation loaders.
# The results are keyed by training data-set key (see the plotting cell below).
pearsons_r_results_dict = helpers.train_downstream_and_evaluate(
    ddg_dataloaders_train_dict,
    ddg_dataloaders_val_dict,
    DEVICE,
    LEARNING_RATE_DDG,
    EPOCHS_DDG,
)

In [None]:
# One validation-performance figure per training data set
for data_set, pearsons_r_results in pearsons_r_results_dict.items():
    _ = plot_validation_performance(
        f"Trained on {dataset_name_mapping[data_set]}",
        pearsons_r_results,
    )

### Performance without augmentation (with Phaistos and MD)

In [None]:
# Same loaders as the vanilla setup, but converted with the transform that also
# includes the Phaistos and MD predictions
ddg_to_tensor = DDGToTensorPhaistosAndMD
ddg_dataloaders_train_dict = helpers.get_ddg_training_dataloaders(
    ddg_data_dict, BATCH_SIZE_DDG, SHUFFLE_DDG, ddg_to_tensor
)
ddg_dataloaders_val_dict = helpers.get_ddg_validation_dataloaders(
    ddg_data_dict, ddg_to_tensor
)

In [None]:
# Train and evaluate the downstream model on the Phaistos-and-MD loaders.
# The results are keyed by training data-set key (see the plotting cell below).
pearsons_r_results_dict = helpers.train_downstream_and_evaluate(
    ddg_dataloaders_train_dict,
    ddg_dataloaders_val_dict,
    DEVICE,
    LEARNING_RATE_DDG,
    EPOCHS_DDG,
)

In [None]:
# One validation-performance figure per training data set
for data_set, pearsons_r_results in pearsons_r_results_dict.items():
    _ = plot_validation_performance(
        f"Trained on {dataset_name_mapping[data_set]}",
        pearsons_r_results,
    )

### Performance with augmentation (vanilla)

Before training, we "augment" the data simply by adding, for each mutation, the reverse mutation with the negated ddG value (-ddG)

In [None]:
# Augmented ddG data (reverse mutations with -ddG added), used only for
# training. Validation below still uses the original ddg_data_dict.
# NOTE(review): presumably the helper returns a new dict and leaves
# ddg_data_dict unmodified — confirm, since validation relies on that.
ddg_data_dict_augmented = helpers.augment_with_reverse_mutation(ddg_data_dict)

Define training dataloader (augmented data) and eval dataloaders (original data)

In [None]:
# Train on the augmented data, but validate on the original (unaugmented) data
ddg_to_tensor = DDGToTensor
ddg_dataloaders_train_dict = helpers.get_ddg_training_dataloaders(
    ddg_data_dict_augmented, BATCH_SIZE_DDG, SHUFFLE_DDG, ddg_to_tensor
)
ddg_dataloaders_val_dict = helpers.get_ddg_validation_dataloaders(
    ddg_data_dict, ddg_to_tensor
)

Train the downstream model and report its performance on the data not used during training

In [None]:
# Train on the augmented loaders and evaluate on the original-data loaders.
# The results are keyed by training data-set key (see the plotting cell below).
pearsons_r_results_dict = helpers.train_downstream_and_evaluate(
    ddg_dataloaders_train_dict,
    ddg_dataloaders_val_dict,
    DEVICE,
    LEARNING_RATE_DDG,
    EPOCHS_DDG,
)

In [None]:
# One validation-performance figure per training data set
for data_set, pearsons_r_results in pearsons_r_results_dict.items():
    _ = plot_validation_performance(
        f"Trained on {dataset_name_mapping[data_set]}",
        pearsons_r_results,
    )

### Performance with augmentation (with Phaistos and MD)

In [None]:
# Augmented training data with the Phaistos-and-MD transform; validation stays
# on the original (unaugmented) data
ddg_to_tensor = DDGToTensorPhaistosAndMD
ddg_dataloaders_train_dict = helpers.get_ddg_training_dataloaders(
    ddg_data_dict_augmented, BATCH_SIZE_DDG, SHUFFLE_DDG, ddg_to_tensor
)
ddg_dataloaders_val_dict = helpers.get_ddg_validation_dataloaders(
    ddg_data_dict, ddg_to_tensor
)

In [None]:
# Train on the augmented Phaistos-and-MD loaders; evaluate on the original data.
# The results are keyed by training data-set key (see the plotting cell below).
pearsons_r_results_dict = helpers.train_downstream_and_evaluate(
    ddg_dataloaders_train_dict,
    ddg_dataloaders_val_dict,
    DEVICE,
    LEARNING_RATE_DDG,
    EPOCHS_DDG,
)

In [None]:
# One validation-performance figure per training data set
for data_set, pearsons_r_results in pearsons_r_results_dict.items():
    _ = plot_validation_performance(
        f"Trained on {dataset_name_mapping[data_set]}",
        pearsons_r_results,
    )