# Imports

In [1]:
from torch.utils.data import Dataset, DataLoader
from typing import List, Callable
import numpy as np
import torch
import glob
import random

%load_ext nb_black

<IPython.core.display.Javascript object>

# Download and process data

In [2]:
# Run the shell script, which takes a .txt file listing PDB IDs as input.
# Per the captured output it downloads each entry to data/raw/<PDBID>.pdb;
# NOTE(review): it presumably also writes the parsed *coord* .npz files that
# the split cell globs from data/parsed/ — confirm against the script.
!./download_and_process_data.sh pdbids_250.txt

Successfully downloaded 4X2U.pdb to data/raw/4X2U.pdb. 1/250.
Successfully downloaded 2X96.pdb to data/raw/2X96.pdb. 2/250.
Successfully downloaded 4MXD.pdb to data/raw/4MXD.pdb. 3/250.
Successfully downloaded 3E9L.pdb to data/raw/3E9L.pdb. 4/250.
Successfully downloaded 1UWC.pdb to data/raw/1UWC.pdb. 5/250.
Successfully downloaded 4BGU.pdb to data/raw/4BGU.pdb. 6/250.
Successfully downloaded 2YSW.pdb to data/raw/2YSW.pdb. 7/250.
Successfully downloaded 4OW4.pdb to data/raw/4OW4.pdb. 8/250.
Successfully downloaded 2V5E.pdb to data/raw/2V5E.pdb. 9/250.
Successfully downloaded 1IXH.pdb to data/raw/1IXH.pdb. 10/250.
Successfully downloaded 3ZR9.pdb to data/raw/3ZR9.pdb. 11/250.
Successfully downloaded 4O7Q.pdb to data/raw/4O7Q.pdb. 12/250.
Successfully downloaded 3OBL.pdb to data/raw/3OBL.pdb. 13/250.
Successfully downloaded 2YVP.pdb to data/raw/2YVP.pdb. 14/250.
Successfully downloaded 1UNK.pdb to data/raw/1UNK.pdb. 15/250.
Successfully downloaded 5B2H.pdb to data/raw/5B2H.pdb. 16/250.
S

Successfully downloaded 4UEJ.pdb to data/raw/4UEJ.pdb. 131/250.
Successfully downloaded 5HI8.pdb to data/raw/5HI8.pdb. 132/250.
Successfully downloaded 1HG0.pdb to data/raw/1HG0.pdb. 133/250.
Successfully downloaded 1EX7.pdb to data/raw/1EX7.pdb. 134/250.
Successfully downloaded 1V7C.pdb to data/raw/1V7C.pdb. 135/250.
Successfully downloaded 1QJ8.pdb to data/raw/1QJ8.pdb. 136/250.
Successfully downloaded 2JHF.pdb to data/raw/2JHF.pdb. 137/250.
Successfully downloaded 2Z3V.pdb to data/raw/2Z3V.pdb. 138/250.
Successfully downloaded 3BL2.pdb to data/raw/3BL2.pdb. 139/250.
Successfully downloaded 2OKT.pdb to data/raw/2OKT.pdb. 140/250.
Successfully downloaded 1QB7.pdb to data/raw/1QB7.pdb. 141/250.
Successfully downloaded 3OUE.pdb to data/raw/3OUE.pdb. 142/250.
Successfully downloaded 2RBK.pdb to data/raw/2RBK.pdb. 143/250.
Successfully downloaded 4FD9.pdb to data/raw/4FD9.pdb. 144/250.
Successfully downloaded 1BMQ.pdb to data/raw/1BMQ.pdb. 145/250.
Successfully downloaded 1DW0.pdb to data

<IPython.core.display.Javascript object>

# Global variables

In [3]:
# Use the GPU when one is available; fall back to the CPU so the notebook
# also runs on machines without CUDA.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # "cpu" or "cuda"
BATCH_SIZE = 100
LEARNING_RATE = 3e-4
EPOCHS = 5
TRAIN_VAL_SPLIT = 0.8  # Fraction of PDB entries used for training

<IPython.core.display.Javascript object>

# Data classes

In [4]:
class ResidueEnvironment:
    """
    Residue environment class used to hold necessary information about the
    atoms of the environment such as atomic coordinates, atom types and the
    class of the missing central amino acid.

    Parameters
    ----------
    xyz_coords: np.ndarray
        Numpy array with shape (n_atoms, 3) containing the x, y, z coordinates.
    atom_types: np.ndarray
        1D numpy array with one entry per atom (length n_atoms) containing the
        atom types. Integer values in range(6).
    restypes_onehot: np.ndarray
        1D one-hot vector of length 21 encoding the amino acid class of the
        missing central residue.
    """

    # The dataset keeps one instance per residue (tens of thousands), so drop
    # the per-instance __dict__ to save memory.
    __slots__ = ("_xyz_coords", "_atom_types", "_restypes_onehot")

    def __init__(
        self,
        xyz_coords: np.ndarray,
        atom_types: np.ndarray,
        restypes_onehot: np.ndarray,
    ):
        self._xyz_coords = xyz_coords
        self._atom_types = atom_types
        self._restypes_onehot = restypes_onehot

    @property
    def xyz_coords(self):
        """Atomic coordinates, shape (n_atoms, 3)."""
        return self._xyz_coords

    @property
    def atom_types(self):
        """Integer atom type of each atom, one entry per atom."""
        return self._atom_types

    @property
    def restypes_onehot(self):
        """One-hot class vector of the missing central residue."""
        return self._restypes_onehot

    def __repr__(self):
        return (
            f"<ResidueEnvironment objects with {self.xyz_coords.shape[0]} "
            f"atoms and residue class {np.argmax(self.restypes_onehot)}>"
        )


class ResidueEnvironmentsDataset(Dataset):
    """
    Residue environment dataset class

    Every residue of every input file becomes one ResidueEnvironment sample.

    Parameters
    ----------
    npz_filenames: List[str]
        List of parsed pdb filenames in .npz format. Each file must contain
        the arrays "positions", "aa_onehot", "selector" and
        "atom_types_numeric".
    transform: Callable
        A to-tensor transformer class, applied to each sample on access.
        No transform is applied when None.
    """

    def __init__(self, npz_filenames: List[str], transform: Callable = None):
        self._res_env_objects = self._parse_envs(npz_filenames)
        self._transform = transform

    @property
    def res_env_objects(self):
        return self._res_env_objects

    @property
    def transform(self):
        return self._transform

    def __len__(self):
        return len(self.res_env_objects)

    def __getitem__(self, idx):
        sample = self.res_env_objects[idx]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample

    def _parse_envs(self, npz_filenames: List[str]) -> List["ResidueEnvironment"]:
        """
        Build one ResidueEnvironment per residue from the given .npz files.

        Filler entries are removed per residue: selector values of -1 and
        coordinate rows whose x value is -99.0.
        """
        res_env_objects = []
        for npz_filename in npz_filenames:
            # np.load on an .npz returns an NpzFile that keeps the file open;
            # read the arrays eagerly and close it deterministically.
            with np.load(npz_filename) as coordinate_features:
                atom_coords_prot_seq = coordinate_features["positions"]
                restypes_onehots_prot_seq = coordinate_features["aa_onehot"]
                selector_prot_seq = coordinate_features["selector"]
                atom_types_flattened = coordinate_features["atom_types_numeric"]

            n_residues = selector_prot_seq.shape[0]
            for resi_i in range(n_residues):
                selector = selector_prot_seq[resi_i]
                selector_masked = selector[selector > -1]  # Remove filler
                coords_mask = (
                    atom_coords_prot_seq[resi_i, :, 0] != -99.0
                )  # Remove filler
                coords = atom_coords_prot_seq[resi_i][coords_mask]
                atom_types = atom_types_flattened[selector_masked]
                restypes_onehot = restypes_onehots_prot_seq[resi_i]
                res_env_objects.append(
                    ResidueEnvironment(coords, atom_types, restypes_onehot)
                )
        return res_env_objects


class ToTensor:
    """To-tensor transformer class."""

    def __call__(self, sample: "ResidueEnvironment"):
        """
        Convert a single ResidueEnvironment object into x_ and y_ tensors.

        Returns
        -------
        dict
            "x_": float32 tensor with one row per atom and columns
                [atom type, x, y, z], placed on DEVICE.
            "y_": float32 tensor holding sample.restypes_onehot, on DEVICE.
        """
        sample_env = np.hstack(
            [np.reshape(sample.atom_types, [-1, 1]), sample.xyz_coords]
        )

        return {
            "x_": torch.tensor(sample_env, dtype=torch.float32).to(DEVICE),
            "y_": torch.tensor(
                np.array(sample.restypes_onehot), dtype=torch.float32
            ).to(DEVICE),
        }

    @staticmethod
    def collate_cat(batch: List[dict]):
        """
        Collate method used by the dataloader to collate a batch of samples
        produced by ToTensor.__call__ (dicts with "x_" and "y_").

        Returns
        -------
        data: torch.Tensor
            All atoms of the batch concatenated, with an extra leading column
            holding the index of the environment each atom belongs to:
            [env index, atom type, x, y, z].
        target: torch.Tensor
            The "y_" tensors stacked along a new leading batch dimension.
        """
        target = torch.stack([b["y_"] for b in batch], dim=0)

        # To collate the input, we need to add a column which
        # specifies the environment each atom belongs to.
        env_id_batch = []
        for env_idx, b in enumerate(batch):
            atoms = b["x_"]
            env_id_col = torch.full(
                (atoms.shape[0], 1),
                float(env_idx),
                dtype=torch.float32,
                device=atoms.device,
            )
            env_id_batch.append(torch.cat([env_id_col, atoms], dim=1))
        data = torch.cat(env_id_batch, dim=0)

        return data, target

<IPython.core.display.Javascript object>

# Model class

In [5]:
class CavityModel(torch.nn.Module):
    """
    3D convolutional neural network for missing amino acid classification.

    The atoms of each environment are rendered onto a cubic grid (one channel
    per atom type) by Gaussian blurring, then passed through three 3D
    convolution blocks and two dense layers producing 21 class logits.

    Parameters
    ----------
    n_atom_types: int
        Number of atom types, i.e. input channels. (C, H, N, O, S, P)
    bins_per_angstrom: float
        Number of grid points per Angstrom.
    grid_dim: int
        Grid dimension (number of grid points along each axis).
    sigma: float
        Standard deviation used for gaussian blurring
    """

    def __init__(
        self,
        n_atom_types: int = 6,
        bins_per_angstrom: float = 1.0,
        grid_dim: int = 18,
        sigma: float = 0.6,
    ):
        super().__init__()

        self._n_atom_types = n_atom_types
        self._bins_per_angstrom = bins_per_angstrom
        self._grid_dim = grid_dim
        self._sigma = sigma

        self._model()

    @property
    def n_atom_types(self):
        return self._n_atom_types

    @property
    def bins_per_angstrom(self):
        return self._bins_per_angstrom

    @property
    def grid_dim(self):
        return self._grid_dim

    @property
    def sigma(self):
        return self._sigma

    @property
    def sigma_p(self):
        return self.sigma * self.bins_per_angstrom

    @property
    def lin_spacing(self):
        # Grid point coordinates along one axis, centered on the origin.
        # NOTE(review): the spacing between points equals bins_per_angstrom,
        # i.e. the value acts as Angstrom-per-bin; identical for the default
        # of 1.0 — confirm the intended meaning before changing it.
        lin_spacing = np.linspace(
            start=-self.grid_dim / 2 * self.bins_per_angstrom
            + self.bins_per_angstrom / 2,
            stop=self.grid_dim / 2 * self.bins_per_angstrom
            - self.bins_per_angstrom / 2,
            num=self.grid_dim,
        )
        return lin_spacing

    @staticmethod
    def _conv_out_dim(size: int, kernel: int, stride: int, padding: int) -> int:
        """Spatial output size of a convolution along one axis."""
        return (size + 2 * padding - kernel) // stride + 1

    def _model(self):
        # Grid point coordinates, registered as non-persistent buffers so that
        # `model.to(device)` moves them while state_dict keys stay unchanged.
        grid = np.stack(
            np.meshgrid(
                self.lin_spacing,
                self.lin_spacing,
                self.lin_spacing,
                indexing="ij",
            )
        )
        grid = torch.tensor(grid, dtype=torch.float32)
        self.register_buffer("xx", grid[0], persistent=False)
        self.register_buffer("yy", grid[1], persistent=False)
        self.register_buffer("zz", grid[2], persistent=False)

        self.conv1 = torch.nn.Sequential(
            torch.nn.Conv3d(
                self.n_atom_types, 16, kernel_size=(3, 3, 3), stride=2, padding=1
            ),
            torch.nn.ReLU(),
            torch.nn.BatchNorm3d(16),
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv3d(16, 32, kernel_size=(3, 3, 3), stride=2, padding=0),
            torch.nn.ReLU(),
            torch.nn.BatchNorm3d(32),
        )
        self.conv3 = torch.nn.Sequential(
            torch.nn.Conv3d(32, 64, kernel_size=(3, 3, 3), stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.BatchNorm3d(64),
            torch.nn.Flatten(),
        )

        # Flattened size after the conv stack (4096 for grid_dim=18).
        conv1_dim = self._conv_out_dim(self.grid_dim, kernel=3, stride=2, padding=1)
        conv2_dim = self._conv_out_dim(conv1_dim, kernel=3, stride=2, padding=0)
        conv3_dim = self._conv_out_dim(conv2_dim, kernel=3, stride=1, padding=1)
        dense_in_features = 64 * conv3_dim ** 3

        self.dense1 = torch.nn.Sequential(
            torch.nn.Linear(in_features=dense_in_features, out_features=128),
            torch.nn.ReLU(),
            torch.nn.BatchNorm1d(128),
        )
        self.dense2 = torch.nn.Linear(in_features=128, out_features=21)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._gaussian_blurring(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

    def _gaussian_blurring(self, x: torch.Tensor) -> torch.Tensor:
        """
        Render the atoms of a batch onto per-atom-type density grids.

        Parameters
        ----------
        x: torch.Tensor
            Tensor of shape (n_atoms, 5). Each row represents an atom, where:
                column 0 is the index of the environment in the batch the
                atom belongs to (0 .. batch_size - 1)
                column 1 describes the atom type
                column 2,3,4 are the x, y, z coordinates, respectively

        Returns
        -------
        fields_torch: torch.Tensor
            Represents the structural environment with gaussian blurring and
            has shape
            (batch_size, self.n_atom_types, self.grid_dim, self.grid_dim,
            self.grid_dim). Each atom contributes a total density of 1 to the
            channel of its atom type in its own environment.
        """
        # Environment ids are 0 .. batch_size - 1, so the largest id + 1 is the
        # batch size. NOTE(review): a trailing environment with no atoms at
        # all still cannot be detected from x alone.
        n_envs = int(x[:, 0].max().item()) + 1 if x.shape[0] > 0 else 0
        n_grid_points = self.grid_dim ** 3
        fields_torch = torch.zeros(
            (
                n_envs,
                self.n_atom_types,
                self.grid_dim,
                self.grid_dim,
                self.grid_dim,
            ),
            device=x.device,
        )

        # Flattened grid coordinates as column vectors, shape (n_grid_points, 1).
        grid_x = torch.reshape(self.xx, [-1, 1])
        grid_y = torch.reshape(self.yy, [-1, 1])
        grid_z = torch.reshape(self.zz, [-1, 1])
        two_sigma_sq = 2 * self.sigma_p ** 2

        for j in range(self.n_atom_types):
            atom_type_j_data = x[x[:, 1] == j]
            if atom_type_j_data.shape[0] == 0:
                continue

            pos = atom_type_j_data[:, 2:]
            # density[g, a]: Gaussian of atom a evaluated at grid point g.
            density = torch.exp(
                -(
                    (grid_x - pos[:, 0]) ** 2
                    + (grid_y - pos[:, 1]) ** 2
                    + (grid_z - pos[:, 2]) ** 2
                )
                / two_sigma_sq
            )

            # Normalize each atom to 1
            density = density / torch.sum(density, dim=0)

            # Sum each atom's density into the row of the environment it
            # belongs to. Indexing by the actual environment id (rather than
            # by the order in which environments containing type j appear)
            # keeps channels aligned when some environments lack this type.
            env_ids = atom_type_j_data[:, 0].long()
            summed = torch.zeros(
                (n_envs, n_grid_points), dtype=density.dtype, device=x.device
            )
            summed.index_add_(0, env_ids, density.t())
            fields_torch[:, j] = torch.reshape(
                summed, [n_envs, self.grid_dim, self.grid_dim, self.grid_dim]
            )
        return fields_torch

<IPython.core.display.Javascript object>

# Parse and train/val split

In [6]:
# Split by PDB entry (not by environment) so that environments from the same
# protein never appear in both the training and the validation set.
parsed_pdb_filenames = sorted(glob.glob("data/parsed/*coord*"))
random.shuffle(parsed_pdb_filenames)

n_train_pdbs = int(len(parsed_pdb_filenames) * TRAIN_VAL_SPLIT)
filenames_train = parsed_pdb_filenames[:n_train_pdbs]
filenames_val = parsed_pdb_filenames[n_train_pdbs:]

dataset_train = ResidueEnvironmentsDataset(filenames_train, transform=ToTensor())
dataset_val = ResidueEnvironmentsDataset(filenames_val, transform=ToTensor())

# drop_last=True avoids a trailing batch of size 1, which the model's
# BatchNorm1d layer rejects in training mode.
dataloader_train = DataLoader(
    dataset_train,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=ToTensor.collate_cat,
    drop_last=True,
)
# Sample order is irrelevant for evaluation, so the validation set is not
# shuffled. NOTE(review): drop_last=True means the final partial batch of
# validation environments is never evaluated.
dataloader_val = DataLoader(
    dataset_val,
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=ToTensor.collate_cat,
    drop_last=True,
)

print(
    f"Training data set includes {len(filenames_train)} pdbs with "
    f"{len(dataset_train)} environments."
)
print(
    f"Validation data set includes {len(filenames_val)} pdbs with "
    f"{len(dataset_val)} environments."
)

Training data set includes 40 pdbs with 18185 environments.
Validation data setincludes 10 pdbs with 3189 environments.


<IPython.core.display.Javascript object>

# Train

In [None]:
from typing import Tuple


def _train_step(
    cavity_model: CavityModel,
    optimizer: torch.optim.Adam,
    loss_function: torch.nn.CrossEntropyLoss,
    batch_x: torch.Tensor,
    batch_y: torch.Tensor,
) -> Tuple[torch.Tensor, float]:
    """
    Run one optimization step on a single batch.

    Parameters
    ----------
    cavity_model: CavityModel
        Model to train; switched to training mode.
    optimizer: torch.optim.Adam
        Optimizer over the model's parameters.
    loss_function: torch.nn.CrossEntropyLoss
        Loss applied to the logits and the target class indices.
    batch_x: torch.Tensor
        Collated atoms of the batch, as produced by ToTensor.collate_cat.
    batch_y: torch.Tensor
        One-hot targets of shape (batch_size, n_classes); converted to class
        indices with argmax for the loss.

    Returns
    -------
    (batch_y_pred, loss): the model logits for the batch and the batch loss
    as a Python float.
    """
    cavity_model.train()
    optimizer.zero_grad()
    batch_y_pred = cavity_model(batch_x)
    loss_batch = loss_function(batch_y_pred, torch.argmax(batch_y, dim=-1))
    loss_batch.backward()
    optimizer.step()
    return batch_y_pred, loss_batch.detach().cpu().item()


def _accuracy(labels_true, labels_pred) -> float:
    """Fraction of matching labels over lists of per-batch label arrays.

    Returns NaN when no batches were collected.
    """
    if not labels_true:
        return float("nan")
    return float(
        np.mean(np.concatenate(labels_true) == np.concatenate(labels_pred))
    )


# Define model
cavity_model = CavityModel().to(DEVICE)
loss_function = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cavity_model.parameters(), lr=LEARNING_RATE)

# Train loop
for epoch in range(EPOCHS):
    loss_running_mean = 0.0
    labels_true = []
    labels_pred = []
    for batch_x, batch_y in dataloader_train:
        # Take train step
        batch_y_pred, loss_batch = _train_step(
            cavity_model, optimizer, loss_function, batch_x, batch_y
        )

        # Exponential running mean for the loss. It restarts at 0 every
        # epoch, so it is biased low over the first few batches.
        loss_running_mean = loss_running_mean * 0.9 + loss_batch * 0.1

        labels_true.append(torch.argmax(batch_y, dim=-1).detach().cpu().numpy())
        labels_pred.append(
            torch.argmax(batch_y_pred, dim=-1).detach().cpu().numpy()
        )
    acc_train = _accuracy(labels_true, labels_pred)

    # Eval loop. Due to memory, we don't pass the whole data set to the model;
    # no_grad additionally avoids building autograd graphs for evaluation.
    cavity_model.eval()
    labels_true_val = []
    labels_pred_val = []
    with torch.no_grad():
        for batch_x_val, batch_y_val in dataloader_val:
            batch_y_pred_val = cavity_model(batch_x_val)
            labels_true_val.append(
                torch.argmax(batch_y_val, dim=-1).cpu().numpy()
            )
            labels_pred_val.append(
                torch.argmax(batch_y_pred_val, dim=-1).cpu().numpy()
            )
    acc_val = _accuracy(labels_true_val, labels_pred_val)

    print(
        f"Epoch {epoch+1:2d}. Train loss: {loss_running_mean:5.3f}. "
        f"Train Acc: {acc_train:4.2f}. Val Acc: {acc_val:4.2f}"
    )