# Imports

In [1]:
import glob
import os
import random
from typing import Callable, List, Union

import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader, Dataset
from Bio.PDB.Polypeptide import index_to_one

from cavity_model import (
    ResidueEnvironment,
    ResidueEnvironmentsDataset,
    ToTensor,
    CavityModel,
)

%load_ext nb_black

<IPython.core.display.Javascript object>

# Cavity Model

Download and process Cavity Model data

In [2]:
# # Run shell script that takes a .txt file with PDBIDs as input.
# !./get_parse_pdbs_cavity_model.sh data/pdbids_010.txt

<IPython.core.display.Javascript object>

Global variables for Cavity Model Training

In [3]:
DEVICE = "cuda"  # "cpu" or "cuda"
TRAIN_VAL_SPLIT = 0.8
BATCH_SIZE = 100
LEARNING_RATE = 3e-4
EPOCHS = 10
PATIENCE_CUTOFF = 2

<IPython.core.display.Javascript object>

Load Parsed PDBs and perform train/val split

In [4]:
parsed_pdb_filenames = sorted(glob.glob("data/pdbs/parsed/*coord*"))
random.shuffle(parsed_pdb_filenames)

n_train_pdbs = int(len(parsed_pdb_filenames) * TRAIN_VAL_SPLIT)
filenames_train = parsed_pdb_filenames[:n_train_pdbs]
filenames_val = parsed_pdb_filenames[n_train_pdbs:]

to_tensor_transformer = ToTensor(DEVICE)

dataset_train = ResidueEnvironmentsDataset(
    filenames_train, transformer=to_tensor_transformer
)
dataset_val = ResidueEnvironmentsDataset(
    filenames_val, transformer=to_tensor_transformer
)

dataloader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=to_tensor_transformer.collate_cat,
    drop_last=True,
)
dataloader_val = DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=to_tensor_transformer.collate_cat,
    drop_last=True,
)

print(
    f"Training data set includes {len(filenames_train)} pdbs with "
    f"{len(dataset_train)} environments."
)
print(
    f"Validation data set includes {len(filenames_val)} pdbs with "
    f"{len(dataset_val)} environments."
)

Training data set includes 8 pdbs with 4636 environments.
Validation data set includes 2 pdbs with 622 environments.


<IPython.core.display.Javascript object>

Train Cavity Model

In [5]:
def _train_step(
    cavity_model_net: CavityModel,
    optimizer: torch.optim.Adam,
    loss_function: torch.nn.CrossEntropyLoss,
) -> (torch.Tensor, float):
    """
    Helper function to take a training step
    """
    cavity_model_net.train()
    optimizer.zero_grad()
    batch_y_pred = cavity_model_net(batch_x)
    loss_batch = loss_function(batch_y_pred, torch.argmax(batch_y, dim=-1))
    loss_batch.backward()
    optimizer.step()
    return (batch_y_pred, loss_batch.detach().cpu().item())


def _eval_loop(
    cavity_model_net: CavityModel,
    data_loader_val,
    loss_function: torch.nn.CrossEntropyLoss,
) -> (float, float):
    """
    Helper function to perform an eval loop
    """
    # Eval loop. Due to memory, we don't pass the whole eval set to the model
    labels_true_val = []
    labels_pred_val = []
    loss_batch_list_val = []
    for batch_x_val, batch_y_val in dataloader_val:
        cavity_model_net.eval()
        batch_y_pred_val = cavity_model_net(batch_x_val)

        loss_batch_val = loss_function(
            batch_y_pred_val, torch.argmax(batch_y_val, dim=-1)
        )
        loss_batch_list_val.append(loss_batch_val.detach().cpu().item())

        labels_true_val.append(torch.argmax(batch_y_val, dim=-1).detach().cpu().numpy())
        labels_pred_val.append(
            torch.argmax(batch_y_pred_val, dim=-1).detach().cpu().numpy()
        )
    acc_val = np.mean(
        (np.reshape(labels_true_val, -1) == np.reshape(labels_pred_val, -1))
    )
    loss_val = np.mean(loss_batch_list_val)
    return acc_val, loss_val

<IPython.core.display.Javascript object>

In [6]:
# Define model
cavity_model_net = CavityModel(DEVICE).to(DEVICE)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cavity_model_net.parameters(), lr=LEARNING_RATE)

# Create directory for model files
models_dirpath = "cavity_models/"
if not os.path.exists(models_dirpath):
    os.mkdir(models_dirpath)

# Train loop
current_best_epoch_idx = -1
current_best_loss_val = 1e4
patience = 0
epoch_idx_to_model_path = {}
for epoch in range(EPOCHS):
    labels_true = []
    labels_pred = []
    loss_batch_list = []
    for batch_x, batch_y in dataloader_train:
        # Take train step
        batch_y_pred, loss_batch = _train_step(
            cavity_model_net, optimizer, loss_function
        )
        loss_batch_list.append(loss_batch)

        labels_true.append(torch.argmax(batch_y, dim=-1).detach().cpu().numpy())
        labels_pred.append(torch.argmax(batch_y_pred, dim=-1).detach().cpu().numpy())

    # Train epoch metrics
    acc_train = np.mean((np.reshape(labels_true, -1) == np.reshape(labels_pred, -1)))
    loss_train = np.mean(loss_batch_list)

    # Validation epoch metrics
    acc_val, loss_val = _eval_loop(cavity_model_net, dataloader_val, loss_function)

    print(
        f"Epoch {epoch:2d}. Train loss: {loss_train:5.3f}. "
        f"Train Acc: {acc_train:4.2f}. Val loss: {loss_val:5.3f}. "
        f"Val Acc {acc_val:4.2f}"
    )

    # Save model
    model_path = f"cavity_models/model_epoch_{epoch:02d}.pt"
    epoch_idx_to_model_path[epoch] = model_path
    torch.save(cavity_model_net.state_dict(), model_path)

    # Early stopping
    if loss_val < current_best_loss_val:
        current_best_loss_val = loss_val
        current_best_epoch_idx = epoch
        patience = 0
    else:
        patience += 1
    if patience > PATIENCE_CUTOFF:
        print(f"Early stopping activated.")
        break

print(
    f"Best epoch idx: {current_best_epoch_idx} with validation loss: "
    f"{current_best_loss_val:5.3f} and model_path: "
    f"{epoch_idx_to_model_path[current_best_epoch_idx]}"
)

Epoch  0. Train loss: 2.484. Train Acc: 0.32. Val loss: 3.171. Val Acc 0.05
Epoch  1. Train loss: 1.292. Train Acc: 0.76. Val loss: 2.696. Val Acc 0.22
Epoch  2. Train loss: 0.726. Train Acc: 0.92. Val loss: 2.727. Val Acc 0.19
Epoch  3. Train loss: 0.368. Train Acc: 0.99. Val loss: 2.755. Val Acc 0.21
Epoch  4. Train loss: 0.186. Train Acc: 1.00. Val loss: 2.783. Val Acc 0.20
Early stopping activated.
Best epoch idx: 1 with validation loss: 2.696 and model_path: cavity_models/model_epoch_01.pt


<IPython.core.display.Javascript object>

# ddG Prediction

Parse PDBs for DMS, Guerois and Protein G data sets

In [7]:
# # Parse PDBs for which we have ddG data
# !./get_parse_pdbs_dowstream_task.sh

<IPython.core.display.Javascript object>

Create temporary residue environment datasets as dicts to more easily match ddG data

In [8]:
parsed_pdbs_wildcards = {
    "dms": "data/data_dms/pdbs_parsed/*coord*",
    "protein_g": "data/data_protein_g/pdbs_parsed/*coord*",
    "guerois": "data/data_gueros/pdbs_parsed/*coord*",
    "symmetric": "data/data_symmetric/pdbs_parsed/*coord*",
}

datasets_look_up = {}
for dataset_key, pdbs_wildcard in parsed_pdbs_wildcards.items():
    parsed_pdb_filenames = sorted(glob.glob(pdbs_wildcard))
    dataset = ResidueEnvironmentsDataset(parsed_pdb_filenames, transformer=None)
    dataset_look_up = {}
    for resenv in dataset:
        key = (
            f"{resenv.pdb_id}{resenv.chain_id}_{resenv.pdb_residue_number}"
            f"{index_to_one(resenv.restype_index)}"
        )
        dataset_look_up[key] = resenv
    datasets_look_up[dataset_key] = dataset_look_up

<IPython.core.display.Javascript object>

Load ddG data as pandas.DataFrame

In [9]:
ddg_data_dict = {
    "dms": pd.read_csv("data/data_dms/ddgs_parsed.csv"),
    "protein_g": pd.read_csv("data/data_protein_g/ddgs_parsed.csv"),
    "guerois": pd.read_csv("data/data_guerois/ddgs_parsed.csv"),
}

<IPython.core.display.Javascript object>

In [10]:
for df in ddg_data_dict.values():
    display(df.head(10))

Unnamed: 0,pdbid,chainid,variant,ddg
0,1D5R,A,M1V,-0.065143
1,1D5R,A,T2A,0.224462
2,1D5R,A,T2D,-0.190667
3,1D5R,A,T2E,0.333408
4,1D5R,A,T2G,0.00139
5,1D5R,A,T2K,0.05239
6,1D5R,A,T2N,-0.072988
7,1D5R,A,T2P,0.15669
8,1D5R,A,T2R,0.12717
9,1D5R,A,T2S,0.175641


Unnamed: 0,pdbid,chainid,variant,ddg
0,1PGA,A,M1A,0.1407
1,1PGA,A,M1D,0.3795
2,1PGA,A,M1E,0.6414
3,1PGA,A,M1L,0.4573
4,1PGA,A,T2E,0.1299
5,1PGA,A,T2F,0.3008
6,1PGA,A,T2G,-0.668
7,1PGA,A,T2H,-0.1303
8,1PGA,A,T2I,1.004
9,1PGA,A,T2L,0.5417


Unnamed: 0,pdbid,chainid,variant,ddg
0,171L,A,A45E,0.01
1,1A2P,A,Y103F,0.0
2,1A2P,A,T105V,2.24
3,1A2P,A,I109A,2.07
4,1A2P,A,I109V,0.76
5,1A2P,A,V10A,3.39
6,1A2P,A,V10T,2.48
7,1A2P,A,R110A,0.41
8,1A2P,A,D12A,0.31
9,1A2P,A,D12G,1.29


<IPython.core.display.Javascript object>