# Imports

In [1]:
import glob
import os
import random
from typing import Callable, List, Union

import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset

from cavity_model import (
    ResidueEnvironment,
    ResidueEnvironmentsDataset,
    ToTensor,
    CavityModel,
)

%load_ext nb_black

<IPython.core.display.Javascript object>

# Download and process data

In [2]:
# Run shell script that takes a .txt file with PDBIDs as input.
!./download_and_process_data.sh data/pdbids_002.txt

Successfully downloaded 4X2U.pdb to data/pdbs/raw/4X2U.pdb. 1/2.
Successfully downloaded 2X96.pdb to data/pdbs/raw/2X96.pdb. 2/2.
Successfully cleaned data/pdbs/raw/2X96.pdb and added it to data/pdbs/cleaned/. 1/2.
Successfully cleaned data/pdbs/raw/4X2U.pdb and added it to data/pdbs/cleaned/. 2/2.
Successfully parsed 2X96_clean.pdb and moved parsed file to data/pdbs/parsed. Finished 1/2.
Successfully parsed 4X2U_clean.pdb and moved parsed file to data/pdbs/parsed. Finished 2/2.


<IPython.core.display.Javascript object>

# Global variables

In [3]:
DEVICE = "cuda"  # "cpu" or "cuda"
BATCH_SIZE = 100
LEARNING_RATE = 3e-4
EPOCHS = 5
TRAIN_VAL_SPLIT = 0.8

<IPython.core.display.Javascript object>

# Parse and train/val split

In [4]:
parsed_pdb_filenames = sorted(glob.glob("data/pdbs/parsed/*coord*"))
random.shuffle(parsed_pdb_filenames)

n_train_pdbs = int(len(parsed_pdb_filenames) * TRAIN_VAL_SPLIT)
filenames_train = parsed_pdb_filenames[:n_train_pdbs]
filenames_val = parsed_pdb_filenames[n_train_pdbs:]

to_tensor_transformer = ToTensor(DEVICE)

dataset_train = ResidueEnvironmentsDataset(
    filenames_train, transformer=to_tensor_transformer
)
dataset_val = ResidueEnvironmentsDataset(
    filenames_val, transformer=to_tensor_transformer
)


dataloader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=to_tensor_transformer.collate_cat,
    drop_last=True,
)
dataloader_val = DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=to_tensor_transformer.collate_cat,
    drop_last=True,
)

print(
    f"Training data set includes {len(filenames_train)} pdbs with "
    f"{len(dataset_train)} environments."
)
print(
    f"Validation data set includes {len(filenames_val)} pdbs with "
    f"{len(dataset_val)} environments."
)

Training data set includes 1 pdbs with 598 environments.
Validation data set includes 1 pdbs with 889 environments.


<IPython.core.display.Javascript object>

# Train

In [5]:
def _train_step(
    cavity_model: CavityModel,
    optimizer: torch.optim.Adam,
    loss_function: torch.nn.CrossEntropyLoss,
) -> (torch.Tensor, float):
    cavity_model.train()
    optimizer.zero_grad()
    batch_y_pred = cavity_model(batch_x)
    loss_batch = loss_function(batch_y_pred, torch.argmax(batch_y, dim=-1))
    loss_batch.backward()
    optimizer.step()
    return (batch_y_pred, loss_batch.detach().cpu().item())


# Define model
cavity_model = CavityModel(DEVICE).to(DEVICE)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cavity_model.parameters(), lr=LEARNING_RATE)

# Train loop
for epoch in range(EPOCHS):
    loss_running_mean = 0.0
    labels_true = []
    labels_pred = []
    for batch_x, batch_y in dataloader_train:
        # Take train step
        batch_y_pred, loss_batch = _train_step(cavity_model, optimizer, loss_function)

        # Exponential running mean for the loss
        loss_running_mean = loss_running_mean * 0.9 + loss_batch * 0.1

        labels_true.append(torch.argmax(batch_y, dim=-1).detach().cpu().numpy())
        labels_pred.append(torch.argmax(batch_y_pred, dim=-1).detach().cpu().numpy())
    acc_train = np.mean((np.reshape(labels_true, -1) == np.reshape(labels_pred, -1)))

    # Eval loop. Due to memory, we don't pass the whole data set to the model
    labels_true_val = []
    labels_pred_val = []
    for batch_x_val, batch_y_val in dataloader_val:
        cavity_model.eval()
        batch_y_pred_val = cavity_model(batch_x_val)
        labels_true_val.append(torch.argmax(batch_y_val, dim=-1).detach().cpu().numpy())
        labels_pred_val.append(
            torch.argmax(batch_y_pred_val, dim=-1).detach().cpu().numpy()
        )
    acc_val = np.mean(
        (np.reshape(labels_true_val, -1) == np.reshape(labels_pred_val, -1))
    )

    print(
        f"Epoch {epoch+1:2d}. Train loss: {loss_running_mean:5.3f}. "
        f"Train Acc: {acc_train:4.2f}. Val Acc: {acc_val:4.2f}"
    )

Epoch  1. Train loss: 1.292. Train Acc: 0.04. Val Acc: 0.06
Epoch  2. Train loss: 0.787. Train Acc: 0.61. Val Acc: 0.06
Epoch  3. Train loss: 0.552. Train Acc: 0.85. Val Acc: 0.07
Epoch  4. Train loss: 0.406. Train Acc: 0.95. Val Acc: 0.06
Epoch  5. Train loss: 0.302. Train Acc: 0.99. Val Acc: 0.07


<IPython.core.display.Javascript object>