# Imports

In [1]:
import random
import glob

import torch
import numpy as np
from typing import List

from torch.utils.data import Dataset, DataLoader

# Download and process data

In [2]:
# Run shell script that takes a .txt file with PDBIDs as input.
!./download_and_process_data.sh pdbids_010.txt

Successfully downloaded 4X2U.pdb to data/raw/4X2U.pdb. 1/10.
Successfully downloaded 2X96.pdb to data/raw/2X96.pdb. 2/10.
Successfully downloaded 4MXD.pdb to data/raw/4MXD.pdb. 3/10.
Successfully downloaded 3E9L.pdb to data/raw/3E9L.pdb. 4/10.
Successfully downloaded 1UWC.pdb to data/raw/1UWC.pdb. 5/10.
Successfully downloaded 4BGU.pdb to data/raw/4BGU.pdb. 6/10.
Successfully downloaded 2YSW.pdb to data/raw/2YSW.pdb. 7/10.
Successfully downloaded 4OW4.pdb to data/raw/4OW4.pdb. 8/10.
Successfully downloaded 2V5E.pdb to data/raw/2V5E.pdb. 9/10.
Successfully downloaded 1IXH.pdb to data/raw/1IXH.pdb. 10/10.
Successfully cleaned data/raw/1IXH.pdb and added it to data/cleaned/. 1/10.
Successfully cleaned data/raw/1UWC.pdb and added it to data/cleaned/. 2/10.
Successfully cleaned data/raw/2V5E.pdb and added it to data/cleaned/. 3/10.
Successfully cleaned data/raw/2X96.pdb and added it to data/cleaned/. 4/10.
Successfully cleaned data/raw/2YSW.pdb and added it to data/cleaned/. 5/10.
Successfu

# Settings

In [3]:
DEVICE = "cuda" # "cpu" or "cuda"
BATCH_SIZE = 100
LEARNING_RATE = 0.0003
EPOCHS = 10
TRAIN_VAL_SPLIT = 0.8

# Data set and data loader

In [4]:
class ResidueEnvironment:
    """
    Residue environment class
    """
    def __init__(self, coords_2d_arr: np.ndarray, atom_types: np.ndarray, aa_onehot: np.ndarray):
        self._coords_2d_arr = coords_2d_arr
        self._atom_types = atom_types
        self._aa_onehot = aa_onehot
        
    @property
    def coords_2d_arr(self):
        return self._coords_2d_arr
    
    @property
    def atom_types(self):
        return self._atom_types
    
    @property
    def aa_onehot(self):
        return self._aa_onehot
        
    def __repr__(self):
        return (f"<ResidueEnvironment objects with {self.coords_2d_arr.shape[0]} "
                f"atoms and residue class {np.argmax(self.aa_onehot)}>")

        
class ResidueEnvironmentsDataset(Dataset):
    def __init__(self, npz_filenames: List[str], transform=None):
        self._res_env_objects = self._parse_envs(npz_filenames)
        self._transform = transform
        
    @property
    def res_env_objects(self):
        return self._res_env_objects
    
    @property
    def transform(self):
        return self._transform

    def __len__(self):
        return len(self.res_env_objects)

    def __getitem__(self, idx):
        sample = self.res_env_objects[idx]        
        if self.transform:
            sample = self.transform(sample)
        return sample
    
    def _parse_envs(self, npz_filenames):
        res_env_objects = []
        for i in range(len(npz_filenames)):
            coordinate_features = np.load(npz_filenames[i])
            atom_coords_prot_seq = coordinate_features["positions"]
            restype_onehots_prot_seq = coordinate_features["aa_onehot"]
            selector_prot_seq = coordinate_features["selector"]
            atom_types_flattened = coordinate_features["atom_types_numeric"]
            N_residues = selector_prot_seq.shape[0]
            for resi_i in range(N_residues):
                selector = selector_prot_seq[resi_i]
                selector_masked = selector[selector>-1] # Remove Filler
                coords_mask = atom_coords_prot_seq[resi_i, :, 0] != -99.0 # To remove filler
                coords = atom_coords_prot_seq[resi_i][coords_mask]            
                atom_types = atom_types_flattened[selector_masked]
                restype_onehot = restype_onehots_prot_seq[resi_i]
                res_env_objects.append(ResidueEnvironment(coords, atom_types, restype_onehot))
        return res_env_objects


class ToTensor:
    def __call__(self, sample):        
        sample_env = np.hstack([np.reshape(sample.atom_types, [-1, 1]),
                               sample.coords_2d_arr])
        
        return {"x_": torch.tensor(sample_env, dtype=torch.float32).to(DEVICE), 
                "y_": torch.tensor(np.array(sample.aa_onehot), dtype=torch.float32).to(DEVICE)}

    @staticmethod
    def collate_cat(batch):
        target = torch.cat([torch.unsqueeze(b['y_'], 0) for b in batch], dim=0)
            
        # To collate the input, we need to add a column which 
        # specifies the environtment each atom belongs to
        env_id_batch = []
        for i, b in enumerate(batch):
            n_atoms = b['x_'].shape[0]
            env_id_arr = torch.zeros(n_atoms, dtype=torch.float32).to(DEVICE) + i
            env_id_batch.append(torch.cat([torch.unsqueeze(env_id_arr, 1), b['x_']], dim=1))            
        data = torch.cat(env_id_batch, dim=0)
            
        return data, target

# Model

In [5]:
class CavityModel(torch.nn.Module):
    def __init__(self, device: str, sigma: float = 0.6):
        super().__init__()
        self.device = device
        self.n_atom_types = 6
        self.p = 1.0 # Bins pr. Anstrom
        self.n = 18  # Grid dimension
        self.sigma = sigma # width of gaussian
        self.sigma_p = self.sigma*self.p
        self.a = np.linspace(start=-self.n/2*self.p + self.p/2, 
                             stop=self.n/2*self.p - self.p/2, 
                             num=self.n) 
        self.xx, self.yy, self.zz = torch.tensor(np.meshgrid(self.a, self.a, self.a, indexing="ij"),
                                                 dtype = torch.float32).to(self.device)

        self.conv1 = torch.nn.Sequential(torch.nn.Conv3d(6, 16, kernel_size=(3,3,3), stride=2, padding=1),
                                         torch.nn.ReLU(), 
                                         torch.nn.BatchNorm3d(16))
        self.conv2 = torch.nn.Sequential(torch.nn.Conv3d(16, 32, kernel_size=(3,3,3), stride=2, padding=0),
                                         torch.nn.ReLU(), 
                                         torch.nn.BatchNorm3d(32))
        self.conv3 = torch.nn.Sequential(torch.nn.Conv3d(32, 64, kernel_size=(3,3,3), stride=1, padding=1),
                                         torch.nn.ReLU(), 
                                         torch.nn.BatchNorm3d(64),
                                         torch.nn.Flatten())
        self.dense1 = torch.nn.Sequential(torch.nn.Linear(in_features=4096, out_features=128),
                                          torch.nn.ReLU(), 
                                          torch.nn.BatchNorm1d(128))
        self.dense2 = torch.nn.Linear(in_features=128, out_features=21)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._gaussian_blurring(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x 

    def _gaussian_blurring(self, x: torch.Tensor) -> torch.Tensor:
        current_batch_size = torch.unique(x[:, 0]).shape[0]
        fields_torch = torch.zeros((current_batch_size, self.n_atom_types, self.n, self.n, self.n)).to(DEVICE)
        for j in range(self.n_atom_types):
            mask_j = x[:,1]==j
            atom_type_j_data = x[mask_j]
            if atom_type_j_data.shape[0] > 0:
                pos = atom_type_j_data[:, 2:]
                density = torch.exp(-((torch.reshape(self.xx, [-1, 1]) - pos[:,0])**2 +\
                                      (torch.reshape(self.yy, [-1, 1]) - pos[:,1])**2 +\
                                      (torch.reshape(self.zz, [-1, 1]) - pos[:,2])**2) / (2 * self.sigma_p**2))
                # Normalize each atom to 1
                density /= torch.sum(density, dim=0)
                # Since column 0 of atom_type_j_data is sorted
                # I can use a trick to detect the boundaries based
                # on the change from one value to another.
                change_mask_j = (atom_type_j_data[:,0][:-1] != atom_type_j_data[:,0][1:]) # detect change in column 0
                # Add begin- and end indices
                ranges_i = torch.cat([torch.tensor([0]),
                                      torch.arange(atom_type_j_data.shape[0]-1)[change_mask_j]+1, 
                                      torch.tensor([atom_type_j_data.shape[0]]) ])
                for i in range(ranges_i.shape[0]):
                    if i < ranges_i.shape[0] - 1:
                        index_0, index_1 = ranges_i[i], ranges_i[i+1]
                        fields = torch.reshape(torch.sum(density[:,index_0:index_1], dim = 1), 
                                               [self.n, self.n, self.n])
                        fields_torch[i,j,:,:,:] = fields
        return fields_torch


# Parse and train/val split

In [6]:
parsed_pdb_filenames = sorted(glob.glob("data/parsed/*coord*"))
random.shuffle(parsed_pdb_filenames)

n_train_pdbs = int(len(parsed_pdb_filenames)*TRAIN_VAL_SPLIT)
filenames_train = parsed_pdb_filenames[:n_train_pdbs]
filenames_val = parsed_pdb_filenames[n_train_pdbs:]

data_set_train = ResidueEnvironmentsDataset(filenames_train, transform=ToTensor())
data_set_val = ResidueEnvironmentsDataset(filenames_val, transform=ToTensor())

dataloader_train = DataLoader(data_set_train, batch_size=BATCH_SIZE, shuffle=True, 
                              collate_fn=ToTensor.collate_cat, drop_last=True)
dataloader_val = DataLoader(data_set_val, batch_size=BATCH_SIZE, shuffle=True, 
                            collate_fn=ToTensor.collate_cat, drop_last=True)

print(f"Training data set includes {len(filenames_train)} pdbs with {len(data_set_train)} environments.")
print(f"Validation data setincludes {len(filenames_val)} pdbs with {len(data_set_val)} environments.")

Training data set includes 8 pdbs with 4414 environments.
Validation data setincludes 2 pdbs with 844 environments.


# Instantiate model, loss and optimizer
# Train

In [7]:
# Define model
cavity_model = CavityModel(DEVICE).to(DEVICE)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cavity_model.parameters(), lr=LEARNING_RATE)

for epoch in range(EPOCHS):
    # Train loop
    loss_running_mean = 0.0
    labels_true = []
    labels_pred = []    
    for batch_x, batch_y in dataloader_train:
        cavity_model.train()
        optimizer.zero_grad()    
        batch_y_pred = cavity_model(batch_x)
        loss_batch = loss(batch_y_pred, torch.argmax(batch_y, dim=-1))
        loss_batch.backward()
        optimizer.step()

        # Exponential running mean for the loss
        loss_running_mean = loss_running_mean*0.9 + loss_batch*0.1
        
        labels_true.append(torch.argmax(batch_y, dim=-1).detach().cpu().numpy())
        labels_pred.append(torch.argmax(batch_y_pred, dim=-1).detach().cpu().numpy())
    acc_train = np.mean((np.reshape(labels_true, -1) == np.reshape(labels_pred, -1)))
    
    # Eval loop. Due to memory, we don't want to pass the whole eval data set in one go
    labels_true_val = []
    labels_pred_val = []
    for batch_x_val, batch_y_val in dataloader_val:
        cavity_model.eval()
        batch_y_pred_val = cavity_model(batch_x_val)
        labels_true_val.append(torch.argmax(batch_y_val, dim=-1).detach().cpu().numpy())
        labels_pred_val.append(torch.argmax(batch_y_pred_val, dim=-1).detach().cpu().numpy())
    acc_val = np.mean((np.reshape(labels_true_val, -1) == np.reshape(labels_pred_val, -1)))

    print(f"Epoch {epoch+1:2d}. Train loss: {loss_running_mean:5.3f}. "
          f"Train Acc: {acc_train:4.2f}. Val Acc: {acc_val:4.2f}")   

Epoch  1. Train loss: 2.221. Train Acc: 0.30. Val Acc: 0.11
Epoch  2. Train loss: 1.257. Train Acc: 0.75. Val Acc: 0.20
Epoch  3. Train loss: 0.736. Train Acc: 0.92. Val Acc: 0.20
Epoch  4. Train loss: 0.395. Train Acc: 0.98. Val Acc: 0.20
Epoch  5. Train loss: 0.205. Train Acc: 1.00. Val Acc: 0.21
Epoch  6. Train loss: 0.115. Train Acc: 1.00. Val Acc: 0.21
Epoch  7. Train loss: 0.073. Train Acc: 1.00. Val Acc: 0.21
Epoch  8. Train loss: 0.052. Train Acc: 1.00. Val Acc: 0.21
Epoch  9. Train loss: 0.040. Train Acc: 1.00. Val Acc: 0.22
Epoch 10. Train loss: 0.032. Train Acc: 1.00. Val Acc: 0.21
