# Imports

In [1]:
import glob
import os
import random
from typing import Callable, List, Tuple, Union

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from cavity_model import (
    ResidueEnvironment,
    ResidueEnvironmentsDataset,
    ToTensor,
    CavityModel,
)

%load_ext nb_black

<IPython.core.display.Javascript object>

# Download and process Cavity Model data

In [2]:
# Run shell script that takes a .txt file with PDBIDs as input.
# Judging from its output, it downloads each PDB into data/pdbs/raw/ and writes a
# cleaned copy into data/pdbs/cleaned/.
# NOTE(review): the next cell reads data/pdbs/parsed/*coord*, so presumably this
# script also writes parsed files there — verify against the script itself.
!./get_parse_pdbs_cavity_model.sh data/pdbids_010.txt

Successfully downloaded 4X2U.pdb to data/pdbs/raw/4X2U.pdb. 1/10.
Successfully downloaded 2X96.pdb to data/pdbs/raw/2X96.pdb. 2/10.
Successfully downloaded 4MXD.pdb to data/pdbs/raw/4MXD.pdb. 3/10.
Successfully downloaded 3E9L.pdb to data/pdbs/raw/3E9L.pdb. 4/10.
Successfully downloaded 1UWC.pdb to data/pdbs/raw/1UWC.pdb. 5/10.
Successfully downloaded 4BGU.pdb to data/pdbs/raw/4BGU.pdb. 6/10.
Successfully downloaded 2YSW.pdb to data/pdbs/raw/2YSW.pdb. 7/10.
Successfully downloaded 4OW4.pdb to data/pdbs/raw/4OW4.pdb. 8/10.
Successfully downloaded 2V5E.pdb to data/pdbs/raw/2V5E.pdb. 9/10.
Successfully downloaded 1IXH.pdb to data/pdbs/raw/1IXH.pdb. 10/10.
Successfully cleaned data/pdbs/raw/1IXH.pdb and added it to data/pdbs/cleaned/. 1/10.
Successfully cleaned data/pdbs/raw/1UWC.pdb and added it to data/pdbs/cleaned/. 2/10.
Successfully cleaned data/pdbs/raw/2V5E.pdb and added it to data/pdbs/cleaned/. 3/10.
Successfully cleaned data/pdbs/raw/2X96.pdb and added it to data/pdbs/cleaned/. 4

<IPython.core.display.Javascript object>

# Global variables for Cavity Model Training

In [3]:
# Use CUDA when it is available and fall back to CPU otherwise, so the notebook
# also runs on machines without a GPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_VAL_SPLIT = 0.8  # fraction of parsed PDB files used for training; the rest is validation
BATCH_SIZE = 100
LEARNING_RATE = 3e-4
EPOCHS = 10
PATIENCE_CUTOFF = 2  # early stopping triggers once the patience counter exceeds this value

# Seed every RNG used downstream: `random` (train/val file shuffle), numpy, and
# torch (DataLoader shuffling and, presumably, model weight initialization),
# so that Restart & Run All reproduces the same split and training run.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<IPython.core.display.Javascript object>

# Load parsed PDBs and split them into train/validation sets for the Cavity Model

In [4]:
# Collect the parsed coordinate files and shuffle them so the split below is
# random (reproducible only if the RNGs were seeded beforehand).
parsed_pdb_filenames = sorted(glob.glob("data/pdbs/parsed/*coord*"))
if not parsed_pdb_filenames:
    raise FileNotFoundError(
        "No files matching data/pdbs/parsed/*coord* were found; "
        "run the download/parse cell first."
    )
random.shuffle(parsed_pdb_filenames)

# Split at the file level, so no file contributes environments to both sets.
n_train_pdbs = int(len(parsed_pdb_filenames) * TRAIN_VAL_SPLIT)
filenames_train = parsed_pdb_filenames[:n_train_pdbs]
filenames_val = parsed_pdb_filenames[n_train_pdbs:]
if not filenames_train or not filenames_val:
    raise ValueError(
        f"TRAIN_VAL_SPLIT={TRAIN_VAL_SPLIT} over {len(parsed_pdb_filenames)} files "
        "leaves the training or validation split empty."
    )

to_tensor_transformer = ToTensor(DEVICE)

dataset_train = ResidueEnvironmentsDataset(
    filenames_train, transformer=to_tensor_transformer
)
dataset_val = ResidueEnvironmentsDataset(
    filenames_val, transformer=to_tensor_transformer
)

# Both loaders use drop_last=True, so each split must hold at least one full
# batch; otherwise the loader yields nothing and the epoch metrics are undefined.
for split_name, split_dataset in (
    ("Training", dataset_train),
    ("Validation", dataset_val),
):
    if len(split_dataset) < BATCH_SIZE:
        raise ValueError(
            f"{split_name} set has {len(split_dataset)} environments, fewer than "
            f"BATCH_SIZE={BATCH_SIZE}; the DataLoader would yield no batches."
        )

dataloader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=to_tensor_transformer.collate_cat,
    drop_last=True,
)
# Validation is not shuffled: with drop_last=True the same tail environments are
# skipped every epoch, so validation losses stay comparable across epochs (which
# early stopping relies on).
# NOTE(review): drop_last=True still skips up to BATCH_SIZE - 1 validation
# environments; switching to drop_last=False needs an eval loop that handles a
# smaller final batch.
dataloader_val = DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=to_tensor_transformer.collate_cat,
    drop_last=True,
)

print(
    f"Training data set includes {len(filenames_train)} pdbs with "
    f"{len(dataset_train)} environments."
)
print(
    f"Validation data set includes {len(filenames_val)} pdbs with "
    f"{len(dataset_val)} environments."
)

Training data set includes 8 pdbs with 4405 environments.
Validation data set includes 2 pdbs with 853 environments.


<IPython.core.display.Javascript object>

# Train Cavity Model

In [5]:
def _train_step(
    cavity_model_net: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_function: torch.nn.CrossEntropyLoss,
    batch_x: torch.Tensor = None,
    batch_y: torch.Tensor = None,
) -> Tuple[torch.Tensor, float]:
    """
    Take one optimization step on a single batch.

    Args:
        cavity_model_net: Model to train; it is switched to train mode here.
        optimizer: Optimizer over the model's parameters.
        loss_function: Loss called as ``loss_function(logits, class_indices)``.
        batch_x: Input batch fed to the model. If omitted, the notebook-level
            global ``batch_x`` is used (legacy behaviour).
        batch_y: Target batch; ``argmax`` over its last dimension gives the class
            indices (e.g. one-hot targets). If omitted, the notebook-level global
            ``batch_y`` is used (legacy behaviour).

    Returns:
        A tuple ``(batch_y_pred, loss)``: the model output for the batch (still
        attached to the autograd graph) and the batch loss as a Python float.
    """
    # Backward compatibility: this helper used to read the notebook loop
    # variables batch_x / batch_y implicitly. Keep that fallback so existing
    # 3-argument calls still work, but prefer passing the batch explicitly.
    if batch_x is None:
        batch_x = globals()["batch_x"]
    if batch_y is None:
        batch_y = globals()["batch_y"]

    cavity_model_net.train()
    optimizer.zero_grad()
    batch_y_pred = cavity_model_net(batch_x)
    loss_batch = loss_function(batch_y_pred, torch.argmax(batch_y, dim=-1))
    loss_batch.backward()
    optimizer.step()
    return batch_y_pred, loss_batch.item()


def _eval_loop(
    cavity_model_net: torch.nn.Module,
    data_loader_val,
    loss_function: torch.nn.CrossEntropyLoss,
) -> Tuple[float, float]:
    """
    Evaluate the model over a validation loader, one batch at a time.

    Batches are evaluated separately (rather than the whole set at once) to keep
    memory bounded.

    Args:
        cavity_model_net: Model to evaluate; it is left in eval mode on return.
        data_loader_val: Iterable of ``(batch_x, batch_y)`` pairs; ``argmax``
            over the last dimension of ``batch_y`` gives the class indices.
        loss_function: Loss called as ``loss_function(logits, class_indices)``.

    Returns:
        ``(accuracy, loss)``: accuracy over all evaluated samples, and the mean
        of the per-batch losses.

    Raises:
        ValueError: If ``data_loader_val`` yields no batches.
    """
    cavity_model_net.eval()
    labels_true_val = []
    labels_pred_val = []
    loss_batch_list_val = []
    # No autograd graph is needed for evaluation; skipping it saves memory.
    with torch.no_grad():
        for batch_x_val, batch_y_val in data_loader_val:
            batch_y_pred_val = cavity_model_net(batch_x_val)
            batch_labels_true = torch.argmax(batch_y_val, dim=-1)

            loss_batch_val = loss_function(batch_y_pred_val, batch_labels_true)
            loss_batch_list_val.append(loss_batch_val.item())

            labels_true_val.append(batch_labels_true.cpu().numpy())
            labels_pred_val.append(
                torch.argmax(batch_y_pred_val, dim=-1).cpu().numpy()
            )

    if not loss_batch_list_val:
        raise ValueError(
            "data_loader_val yielded no batches; cannot compute validation metrics."
        )

    # np.concatenate (rather than reshaping the list of per-batch arrays) also
    # handles a smaller final batch.
    acc_val = float(
        np.mean(np.concatenate(labels_true_val) == np.concatenate(labels_pred_val))
    )
    loss_val = float(np.mean(loss_batch_list_val))
    return acc_val, loss_val

<IPython.core.display.Javascript object>

In [14]:
# Define model, loss and optimizer
cavity_model_net = CavityModel(DEVICE).to(DEVICE)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cavity_model_net.parameters(), lr=LEARNING_RATE)

# Create directory for model checkpoints (no error if it already exists)
models_dirpath = "cavity_models/"
os.makedirs(models_dirpath, exist_ok=True)

# Train loop with early stopping on the validation loss
current_best_epoch_idx = -1
current_best_loss_val = float("inf")
# Number of consecutive epochs without a new best validation loss
patience = 0
epoch_idx_to_model_path = {}
for epoch in range(EPOCHS):
    labels_true = []
    labels_pred = []
    loss_batch_list = []
    for batch_x, batch_y in dataloader_train:
        # Take train step. The batch is not passed explicitly: _train_step picks
        # up the module-level loop variables batch_x / batch_y bound here.
        batch_y_pred, loss_batch = _train_step(
            cavity_model_net, optimizer, loss_function
        )
        loss_batch_list.append(loss_batch)

        labels_true.append(torch.argmax(batch_y, dim=-1).detach().cpu().numpy())
        labels_pred.append(torch.argmax(batch_y_pred, dim=-1).detach().cpu().numpy())

    if not loss_batch_list:
        raise ValueError(
            "dataloader_train yielded no batches; the training set must hold "
            "at least one full batch."
        )

    # Train epoch metrics. np.concatenate (rather than reshaping the list of
    # per-batch arrays) also works if batches differ in size.
    acc_train = np.mean(np.concatenate(labels_true) == np.concatenate(labels_pred))
    loss_train = np.mean(loss_batch_list)

    # Validation epoch metrics
    acc_val, loss_val = _eval_loop(cavity_model_net, dataloader_val, loss_function)

    print(
        f"Epoch {epoch:2d}. Train loss: {loss_train:5.3f}. "
        f"Train Acc: {acc_train:4.2f}. Val loss: {loss_val:5.3f}. "
        f"Val Acc {acc_val:4.2f}"
    )

    # Save a checkpoint every epoch so the best one can be reloaded afterwards
    model_path = os.path.join(models_dirpath, f"model_epoch_{epoch:02d}.pt")
    epoch_idx_to_model_path[epoch] = model_path
    torch.save(cavity_model_net.state_dict(), model_path)

    # Early stopping: patience is reset on every improvement, so only
    # *consecutive* non-improving epochs count toward PATIENCE_CUTOFF.
    if loss_val < current_best_loss_val:
        current_best_loss_val = loss_val
        current_best_epoch_idx = epoch
        patience = 0
    else:
        patience += 1
    if patience > PATIENCE_CUTOFF:
        print(
            f"Early stopping activated. Best epoch idx: {current_best_epoch_idx} "
            f"with validation loss: {current_best_loss_val:5.3f} and \nmodel_path: "
            f"{epoch_idx_to_model_path[current_best_epoch_idx]}"
        )
        break
else:
    # All EPOCHS ran without early stopping: still report the best checkpoint.
    if current_best_epoch_idx >= 0:
        print(
            f"Training finished without early stopping. Best epoch idx: "
            f"{current_best_epoch_idx} with validation loss: "
            f"{current_best_loss_val:5.3f} and \nmodel_path: "
            f"{epoch_idx_to_model_path[current_best_epoch_idx]}"
        )



Epoch  0. Train loss: 2.493. Train Acc: 0.32. Val loss: 3.199. Val Acc 0.07
Epoch  1. Train loss: 1.261. Train Acc: 0.77. Val loss: 2.729. Val Acc 0.18
Epoch  2. Train loss: 0.696. Train Acc: 0.94. Val loss: 2.725. Val Acc 0.21
Epoch  3. Train loss: 0.347. Train Acc: 0.99. Val loss: 2.815. Val Acc 0.21
Epoch  4. Train loss: 0.176. Train Acc: 1.00. Val loss: 2.881. Val Acc 0.21
Epoch  5. Train loss: 0.099. Train Acc: 1.00. Val loss: 2.899. Val Acc 0.20
Early stopping activated. Best epoch idx: 2 with validation loss: 2.725 and 
model_path: cavity_models/model_epoch_02.pt


<IPython.core.display.Javascript object>