# Download and process data

In [1]:
# Run shell script that takes a .txt file with PDBIDs as input.
# Downloads every PDB entry listed in pdbids_010.txt into data/raw/ and then runs the
# processing steps implemented in download_and_process_data.sh (presumably cleaning and
# parsing into data/parsed/, which the "Parse" section reads; confirm against the script).
# NOTE(review): the log below reports cleaned files at "data/cleaned/data/raw/<ID>.pdb.pdb",
# which suggests the script builds the output name from the full input path. Check the script.
!./download_and_process_data.sh pdbids_010.txt

Successfully downloaded 4X2U.pdb to data/raw/4X2U.pdb. 1/10.
Successfully downloaded 2X96.pdb to data/raw/2X96.pdb. 2/10.
Successfully downloaded 4MXD.pdb to data/raw/4MXD.pdb. 3/10.
Successfully downloaded 3E9L.pdb to data/raw/3E9L.pdb. 4/10.
Successfully downloaded 1UWC.pdb to data/raw/1UWC.pdb. 5/10.
Successfully downloaded 4BGU.pdb to data/raw/4BGU.pdb. 6/10.
Successfully downloaded 2YSW.pdb to data/raw/2YSW.pdb. 7/10.
Successfully downloaded 4OW4.pdb to data/raw/4OW4.pdb. 8/10.
Successfully downloaded 2V5E.pdb to data/raw/2V5E.pdb. 9/10.
Successfully downloaded 1IXH.pdb to data/raw/1IXH.pdb. 10/10.
Successfully cleaned data/raw/1IXH.pdb to data/cleaned/data/raw/1IXH.pdb.pdb. 1/10.
Successfully cleaned data/raw/1UWC.pdb to data/cleaned/data/raw/1UWC.pdb.pdb. 2/10.
Successfully cleaned data/raw/2V5E.pdb to data/cleaned/data/raw/2V5E.pdb.pdb. 3/10.
Successfully cleaned data/raw/2X96.pdb to data/cleaned/data/raw/2X96.pdb.pdb. 4/10.
Successfully cleaned data/raw/2YSW.pdb to data/cleane

# Imports

In [30]:
import random
import glob

import torch
import numpy as np

import matplotlib.pyplot as plt

from torch.utils.data import Dataset, DataLoader
from typing import List

# Settings

In [3]:
DEVICE = "cpu" # "cpu" or "cuda"
BATCH_SIZE = 100
LEARNING_RATE = 0.0003
EPOCHS = 1
SIGMA = 0.6  # Gaussian blur width passed to CavityModel (same as its default)

# Seed every RNG the notebook uses (file shuffling, eval sampling, weight init)
# so that Restart & Run All reproduces the same split and numbers.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

In [32]:
class ResidueEnvironment():
    """
    Container for one residue environment parsed from a ``.npz`` file.

    Attributes:
        coords_2d_arr: coordinates of the environment atoms, one row per atom.
        atom_types: numeric atom type for each row of ``coords_2d_arr``.
        aa_onehot: one-hot encoding of the residue's amino-acid type (the label).
    """
    def __init__(self, coords_2d_arr: np.ndarray,
                 atom_types: np.ndarray, aa_onehot: np.ndarray):
        self.coords_2d_arr = coords_2d_arr
        self.atom_types = atom_types
        self.aa_onehot = aa_onehot


class ResidueEnvironmentsDataset(Dataset):
    """
    PyTorch ``Dataset`` holding one ``ResidueEnvironment`` per residue found in
    the given parsed ``.npz`` files.

    NOTE(review): items are ``ResidueEnvironment`` objects, so batching them with
    a ``DataLoader`` presumably needs a custom ``collate_fn``. Confirm before use.
    """
    def __init__(self, npz_filenames: List[str], transform=None):
        """
        Load parsed pdb data in .npz format.

        Args:
            npz_filenames: paths of ``.npz`` files containing the arrays
                ``positions``, ``aa_onehot``, ``selector`` and ``atom_types_numeric``.
            transform: optional callable applied to each sample in ``__getitem__``.
        """
        self.transform = transform
        self.res_env_objects = []
        for npz_filename in npz_filenames:
            # The context manager closes the NpzFile's underlying file handle;
            # the arrays read inside it stay valid afterwards.
            with np.load(npz_filename) as coordinate_features:
                atom_coords_prot_seq = coordinate_features["positions"]
                restype_onehots_prot_seq = coordinate_features["aa_onehot"]
                selector_prot_seq = coordinate_features["selector"]
                atom_types_flattened = coordinate_features["atom_types_numeric"]
            for resi_i in range(selector_prot_seq.shape[0]):
                selector = selector_prot_seq[resi_i]
                selector_masked = selector[selector > -1]  # -1 marks filler entries
                coords_mask = atom_coords_prot_seq[resi_i, :, 0] != -99.0  # -99.0 marks filler coordinates
                coords = atom_coords_prot_seq[resi_i][coords_mask]
                atom_types = atom_types_flattened[selector_masked]
                restype_onehot = restype_onehots_prot_seq[resi_i]
                self.res_env_objects.append(ResidueEnvironment(coords, atom_types, restype_onehot))

    def __len__(self):
        """Number of residue environments loaded."""
        return len(self.res_env_objects)

    def __getitem__(self, idx):
        """
        Return the ``ResidueEnvironment`` at ``idx`` (after ``transform``, if set).

        ``idx`` may be an int, a tensor, or a list of ints. Tensors are converted
        with ``tolist()``, and lists return a list of samples.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if isinstance(idx, list):
            return [self.__getitem__(i) for i in idx]

        sample = self.res_env_objects[idx]
        if self.transform:
            sample = self.transform(sample)
        return sample

        
        

In [None]:
# Scratch check of list slicing. The previous `print(u[])` was a SyntaxError
# that stopped Restart & Run All at this cell.
u = [1, 2, 3, 4, 5]
print(u[:])

# Data Parser

In [14]:
class ResidueEnvironmentDensity():
    """
    Environment class, which contains all the relevant information on the structural environment
    of one residue: the coordinates of the surrounding atoms, their numeric atom types,
    and the one-hot amino-acid label of the residue.
    """
    def __init__(self, coords_2d_arr: np.ndarray, atom_types, aa_onehot):
        self.coords_2d_arr = coords_2d_arr
        self.atom_types = atom_types
        self.aa_onehot = aa_onehot

    def get_coords_2d_arr(self):
        """Return the environment atom coordinates (one row per atom)."""
        return self.coords_2d_arr

    def get_aa_onehot(self):
        """Return the one-hot amino-acid label of the residue."""
        return self.aa_onehot

    def get_atom_types(self):
        """Return the numeric atom type of each environment atom."""
        return self.atom_types

def parse_data(coordinate_features_filenames):
    """
    Function that parses environment files ("coordinate_features.npz") and
    returns ResidueEnvironmentDensity objects, one per residue in every file.

    Each file must contain the arrays ``positions``, ``aa_onehot``, ``selector`` and
    ``atom_types_numeric``. Filler entries (``selector == -1`` and coordinates of
    ``-99.0``) are stripped.
    """
    env_objects = []
    for filename in coordinate_features_filenames:
        # Close the NpzFile handle once the arrays are read (np.load keeps it open otherwise).
        with np.load(filename) as coordinate_features:
            atom_coords_prot_seq = coordinate_features["positions"]
            restype_onehots_prot_seq = coordinate_features["aa_onehot"]
            selector_prot_seq = coordinate_features["selector"]
            atom_types_flattened = coordinate_features["atom_types_numeric"]

        # Loop over protein sequence
        N_residues = selector_prot_seq.shape[0]
        for resi_i in range(N_residues):
            # Atoms in the environment as seen from residue resi_i
            selector = selector_prot_seq[resi_i]
            selector_masked = selector[selector > -1]  # Remove filler
            coords_mask = atom_coords_prot_seq[resi_i, :, 0] != -99.0  # Remove filler

            # Relevant data
            coords = atom_coords_prot_seq[resi_i][coords_mask]
            atom_types = atom_types_flattened[selector_masked]
            restype_onehot = restype_onehots_prot_seq[resi_i]
            env_objects.append(ResidueEnvironmentDensity(coords, atom_types, restype_onehot))
    return env_objects

def get_batch_x_and_batch_y(res_env_obj_list):
    """
    Function that takes a list of ResidueEnvironmentDensity objects and returns
    batches with dimensions that match the requirements of the CNN for x and y.

    Returns:
        x: 2D array with one row per atom of every environment, laid out as
           ``[env_index, atom_type, *coords]``. ``env_index`` is the position of
           the environment in ``res_env_obj_list``.
        y: array of the environments' one-hot amino-acid labels, in input order.

    Raises:
        ValueError: if ``res_env_obj_list`` is empty.
    """
    env_data_all = []
    batch_aa_onehots = []
    for env_i, res_env_obj in enumerate(res_env_obj_list):
        coords = np.asarray(res_env_obj.get_coords_2d_arr())
        atom_types = res_env_obj.get_atom_types()
        n_atoms = coords.shape[0]
        env_data = np.hstack([np.full((n_atoms, 1), env_i, dtype=float),  # env index column
                              np.reshape(atom_types, [-1, 1]),
                              coords])
        env_data_all.append(env_data)
        batch_aa_onehots.append(res_env_obj.get_aa_onehot())

    if not env_data_all:
        raise ValueError("get_batch_x_and_batch_y needs at least one residue environment")

    env_data_all_stacked = np.vstack(env_data_all)
    return env_data_all_stacked, np.array(batch_aa_onehots)

# Model

In [24]:
class CavityModel(torch.nn.Module):
    """
    3D CNN that predicts a residue's amino-acid class (21 outputs) from the atoms
    of its surrounding environment.

    ``forward`` takes a 2D tensor whose rows are ``[env_index, atom_type, x, y, z]``
    (the layout built by ``get_batch_x_and_batch_y``). It blurs every atom into a
    normalized Gaussian on an ``n x n x n`` grid with one channel per atom type,
    then runs the result through three Conv3d blocks and two dense layers.

    Args:
        device: device on which the grid and the density fields are allocated.
        sigma: width of the Gaussian used to blur each atom.
    """
    def __init__(self, device: str, sigma: float = 0.6):
        super().__init__()
        self.device = device
        self.n_atom_types = 6
        self.p = 1.0  # Spacing between neighbouring grid points (coordinate units, presumably Angstrom)
        self.n = 18  # Grid dimension (voxels per axis)
        self.sigma = sigma  # Width of the Gaussian
        self.sigma_p = self.sigma*self.p
        # Voxel centres: n points, self.p apart, symmetric around the origin.
        self.a = np.linspace(start=-self.n/2*self.p + self.p/2,
                             stop=self.n/2*self.p - self.p/2,
                             num=self.n)
        # Stack into one ndarray first; building a tensor from a list of ndarrays is slow.
        grid = np.stack(np.meshgrid(self.a, self.a, self.a, indexing="ij"))
        self.xx, self.yy, self.zz = torch.tensor(grid, dtype=torch.float32).to(self.device)

        self.conv1 = torch.nn.Sequential(torch.nn.Conv3d(self.n_atom_types, 16, kernel_size=(3,3,3), stride=2, padding=1),
                                         torch.nn.ReLU(),
                                         torch.nn.BatchNorm3d(16))
        self.conv2 = torch.nn.Sequential(torch.nn.Conv3d(16, 32, kernel_size=(3,3,3), stride=2, padding=0),
                                         torch.nn.ReLU(),
                                         torch.nn.BatchNorm3d(32))
        self.conv3 = torch.nn.Sequential(torch.nn.Conv3d(32, 64, kernel_size=(3,3,3), stride=1, padding=1),
                                         torch.nn.ReLU(),
                                         torch.nn.BatchNorm3d(64),
                                         torch.nn.Flatten())
        # For n = 18 the spatial size goes 18 -> 9 -> 4 -> 4, so 64 * 4**3 = 4096 features.
        self.dense1 = torch.nn.Sequential(torch.nn.Linear(in_features=4096, out_features=128),
                                          torch.nn.ReLU(),
                                          torch.nn.BatchNorm1d(128))
        self.dense2 = torch.nn.Linear(in_features=128, out_features=21)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores of shape (batch, 21) for atom rows ``x``."""
        x = self._gaussian_blurring(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.dense1(x)
        x = self.dense2(x)
        return x

    def _gaussian_blurring(self, x: torch.Tensor) -> torch.Tensor:
        """
        Turn atom rows ``[env_index, atom_type, x, y, z]`` into density fields of
        shape (batch, n_atom_types, n, n, n).

        Each atom contributes a Gaussian normalized to sum 1 over the grid. It is
        added to the field of the environment named in column 0 and the channel
        named in column 1. The batch size is ``max(env_index) + 1``, so
        environments are addressed by their index even when one has no atoms of
        a given type. Rows need not be sorted.
        """
        if x.shape[0] == 0:
            return torch.zeros((0, self.n_atom_types, self.n, self.n, self.n), device=self.device)

        env_index = x[:, 0].long()
        current_batch_size = int(env_index.max().item()) + 1
        n_voxels = self.n ** 3
        fields_torch = torch.zeros((current_batch_size, self.n_atom_types, self.n, self.n, self.n),
                                   device=self.device)
        grid_x = torch.reshape(self.xx, [-1, 1])
        grid_y = torch.reshape(self.yy, [-1, 1])
        grid_z = torch.reshape(self.zz, [-1, 1])
        for j in range(self.n_atom_types):
            mask_j = x[:, 1] == j
            if not bool(mask_j.any()):
                continue
            pos = x[mask_j, 2:]
            # density[v, k]: Gaussian of atom k evaluated at voxel v
            density = torch.exp(-((grid_x - pos[:, 0])**2 +
                                  (grid_y - pos[:, 1])**2 +
                                  (grid_z - pos[:, 2])**2) / (2 * self.sigma_p**2))
            # Normalize each atom to 1
            density = density / torch.sum(density, dim=0)
            # Sum the atoms of each environment into that environment's column.
            fields_j = torch.zeros((n_voxels, current_batch_size), dtype=density.dtype, device=density.device)
            fields_j.index_add_(1, env_index[mask_j], density)
            fields_torch[:, j] = fields_j.t().reshape(current_batch_size, self.n, self.n, self.n)
        return fields_torch


# Parse

In [25]:
# Parse environment files (one per protein)
N_TRAIN_PROTEINS = 8  # proteins used for training; the remaining ones are used for testing
env_filenames_wild_card = "data/parsed/*coord*"
# Sort first so the shuffle (and therefore the split) is reproducible under a fixed seed.
env_filenames = sorted(glob.glob(env_filenames_wild_card))
if not env_filenames:
    raise FileNotFoundError(
        f"No environment files match {env_filenames_wild_card!r}; run the download/processing cell first.")
random.shuffle(env_filenames)  # random protein-level train/test split
env_filenames_train = env_filenames[:N_TRAIN_PROTEINS]  # proteins for training
env_filenames_test = env_filenames[N_TRAIN_PROTEINS:]  # proteins for testing

env_objects_train = parse_data(env_filenames_train)
env_objects_test = parse_data(env_filenames_test)
print ("Number of training environments:", len(env_objects_train))
print ("Number of testing environments: ", len(env_objects_test))


Number of training environments: 3389
Number of testing environments:  1869


## Train

In [26]:
# Define model
aa_pred_conv3d = CavityModel(DEVICE, SIGMA).to(DEVICE)
loss = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(aa_pred_conv3d.parameters(), lr = LEARNING_RATE)

EVAL_EVERY = 1000  # evaluate whenever the batch offset i is a multiple of this


def _to_torch(env_objs):
    """Build float32 (x, y) tensors on DEVICE for a list of residue environments."""
    return [torch.tensor(arr, dtype=torch.float32).to(DEVICE)
            for arr in get_batch_x_and_batch_y(env_objs)]


def _evaluate(x_eval, y_eval):
    """Return (loss, accuracy) of the current model on one pre-built evaluation batch."""
    aa_pred_conv3d.eval()
    with torch.no_grad():  # no autograd graph is needed for evaluation, which saves memory
        y_pred = aa_pred_conv3d(x_eval)
    labels_true = torch.argmax(y_eval, dim = -1)
    labels_pred = torch.argmax(y_pred, dim = -1)
    eval_loss = loss(y_pred, labels_true).item()
    eval_accuracy = torch.mean((labels_true == labels_pred).double()).item()
    return eval_loss, eval_accuracy


# Define random subsets of training and testing sets for evaluation.
# This makes it simpler to fit in a single forward pass (memory-wise),
# instead of iterating over all the data. The size is capped by the set size
# because sampling is without replacement.
sample_size = 500
train_eval_inds = np.random.choice(len(env_objects_train), size = min(sample_size, len(env_objects_train)), replace = False)
test_eval_inds = np.random.choice(len(env_objects_test), size = min(sample_size, len(env_objects_test)), replace = False)
x_train_torch, y_train_torch = _to_torch([env_objects_train[ind] for ind in train_eval_inds])
x_test_torch, y_test_torch = _to_torch([env_objects_test[ind] for ind in test_eval_inds])

# Train
loss_train_list = []
loss_test_list = []
acc_train_list = []
acc_test_list = []
for epoch_i in range(EPOCHS):
    random.shuffle(env_objects_train)  # new batch composition every epoch
    # The range stops before the last partial batch: a small batch might mess up batch norm.
    for i in range(0, len(env_objects_train) - BATCH_SIZE + 1, BATCH_SIZE):
        batch_x_train, batch_y_train = _to_torch(env_objects_train[i:i+BATCH_SIZE])

        # Forward pass, backward pass, optimize
        aa_pred_conv3d.train()
        optimizer.zero_grad()
        batch_y_pred = aa_pred_conv3d(batch_x_train)
        labels = torch.argmax(batch_y_train, dim = -1)
        loss_batch_train = loss(batch_y_pred, labels)
        loss_batch_train.backward()
        optimizer.step()

        # Evaluate on the fixed subsets of training and testing environments
        if i % EVAL_EVERY == 0:
            loss_train, accuracy_train = _evaluate(x_train_torch, y_train_torch)
            loss_test, accuracy_test = _evaluate(x_test_torch, y_test_torch)
            loss_train_list.append(loss_train)
            acc_train_list.append(accuracy_train)
            loss_test_list.append(loss_test)
            acc_test_list.append(accuracy_test)

            print ("epoch {:1d}/{:1d}, step {:6d}, loss(train): {:5.2f}, loss(test): {:5.2f}, accuracy(train): {:5.2f}, accuracy(test): {:5.2f}".format(epoch_i+1, EPOCHS, i, loss_train, loss_test, accuracy_train, accuracy_test))

epoch 1/1, step      0, loss(train):  3.05, loss(test):  3.05, accuracy(train):  0.05, accuracy(test):  0.08


KeyboardInterrupt: 

## Plot

In [None]:
# Learning curves recorded by the training loop (one point per evaluation).
fig, ax_arr = plt.subplots(2, sharex=True)
ax_arr[0].plot(np.arange(len(loss_train_list)), loss_train_list, label = "train")
ax_arr[0].plot(np.arange(len(loss_test_list)), loss_test_list, label = "test")
ax_arr[1].plot(np.arange(len(acc_train_list)), acc_train_list, label = "train")
ax_arr[1].plot(np.arange(len(acc_test_list)), acc_test_list, label = "test")
ax_arr[0].set_ylabel("Cross Entropy loss")
ax_arr[1].set_ylabel("Accuracy")
ax_arr[1].set_xlabel("Evaluation point")
# plt.legend() would only label the current (last) axes; label both panels.
for ax in ax_arr:
    ax.legend()
plt.show()

### Discussion
This demo shows how to classify the missing amino acid from a structural environment in which the atoms of the query amino acid have been deliberately removed.

I train a simple Conv3D net to learn the shape/contact-pattern that the missing atoms have left behind.

Given the very limited number of PDB structures used in this demo (250 in the original write-up; the run above uses only the 10 IDs in `pdbids_010.txt`), it is not surprising that performance on the training data is much better than on the testing data. If you want to build on this demo, please include more PDB structures. 
For reference, I use around 2k structures that have been homology-reduced by sequence, with much success.