In [1]:
from ABDB import database as db
import numpy as np
import json
import sys
import time
import datetime
from rich.progress import track
import copy
from sklearn.model_selection import train_test_split
import pickle
import torch
from einops import rearrange



### 1. Data preparation
**1.1 Get data from SAbDab**

Extract CDR sequences and backbone-atom coordinates from antibodies in SAbDab.

In [2]:
# Lookup tables for the 20 standard amino acids:
#   long2short: three-letter code -> one-letter code
#   short2long: one-letter code   -> three-letter code
#   short2num:  one-letter code   -> integer index (position of the letter in aa1)
aa1 = "ACDEFGHIKLMNPQRSTVWY"
aa3 = [
    "ALA", "CYS", "ASP", "GLU", "PHE", "GLY", "HIS", "ILE", "LYS", "LEU",
    "MET", "ASN", "PRO", "GLN", "ARG", "SER", "THR", "VAL", "TRP", "TYR",
]

# aa1 and aa3 list the residues in the same order, so pairing them position by position
# gives the conversions in both directions.
long2short = dict(zip(aa3, aa1))
short2long = dict(zip(aa1, aa3))
short2num = {letter: index for index, letter in enumerate(aa1)}

In [3]:
# functions to filter entries in SAbDAb

def filter_abs(pdb_list):
    '''
    Filter a list of PDB ids obtained from SAbDab, keeping only entries whose first FAB
    has two distinct, present chains.

    An entry is dropped when
      - it has no FABs at all,
      - the heavy and light chain of its first FAB have the same name, or
      - either chain of its first FAB is 'NA' (missing).
    Only the first FAB of each entry is inspected.

    pdb_list: iterable of PDB ids accepted by db.fetch
    returns:  list with the ids that passed the filter, in input order
    '''
    filtered_list = []

    for pdb in track(pdb_list, description='Filter FABs'):
        fabs = db.fetch(pdb).fabs

        # An entry without FABs has nothing to inspect; indexing fabs[0] would raise
        # IndexError and abort the scan over the whole database.
        if len(fabs) == 0:
            continue

        fab = fabs[0]
        if fab.VH == fab.VL or fab.VH == 'NA' or fab.VL == 'NA':
            continue

        filtered_list.append(pdb)

    return filtered_list

In [4]:
# functions to extract and format relevant data from a SAbDab FAB. Given a FAB, two dictionaries are returned: one with the CDR and
# anchor sequences and one with their backbone coordinates

def split_structure_in_regions(fab):
    '''
    Group the residues of a FAB by antibody region.

    Takes a FAB as input and returns a dictionary with keys: region names, values: list of the
    residues in that region, in the order they are yielded (heavy chain first, then light chain).
    regions = ['fwh1', 'cdrh1', 'fwh2', 'cdrh2', 'fwh3', 'cdrh3', 'fwh4', 'fwl1', 'cdrl1', 'fwl2', 'cdrl2', 'fwl3', 'cdrl3', 'fwl4']
    '''
    structure = fab.get_structure()
    residues_by_region = {}

    for chain_id in (fab.VH, fab.VL):
        # Chain.get_residues() is a generator over the residues of the chain
        for res in structure[chain_id].get_residues():
            # res.region names the CDR or framework region the residue belongs to
            residues_by_region.setdefault(res.region, []).append(res)

    return residues_by_region

def get_slice(ab_regions, CDR):
    '''
    Return the residues of one CDR together with two anchor residues on either side.

    ab_regions: FAB split into regions (keys such as 'fwh1', 'cdrh1', ...)
    CDR:        loop name such as "H1" or "L3" (chain letter followed by loop number)

    The result is the last two residues of the preceding framework region, the whole CDR,
    and the first two residues of the following framework region. The lists stored in
    ab_regions are not modified.
    '''
    chain_id = CDR[0].lower()
    loop_number = CDR[1]
    region_suffix = chain_id + loop_number

    # slicing copies, so the in-place extensions below only touch the new list
    window = ab_regions['fw' + region_suffix][-2:]
    window += ab_regions['cdr' + region_suffix]
    window += ab_regions['fw' + chain_id + str(int(loop_number) + 1)][:2]

    return window

def cdr_anchor_seq(ab_regions, CDR):
    '''
    Return the sequence of a CDR plus two anchors on each side as a list of
    one-letter amino acid codes, given a FAB split into regions and a specified CDR.
    '''
    return [long2short[residue.resname] for residue in get_slice(ab_regions, CDR)]

def cdr_anchor_BB_coord(ab_regions, CDR):
    '''
    Return the backbone coordinates of a CDR plus two anchors on each side,
    given a FAB split into regions and a specified CDR.

    The result has shape (residues, 4, 3); the four atom slots of every residue are
    CA, C, N and CB, in that order.
    '''
    residues = get_slice(ab_regions, CDR)
    atom_names = ["CA", "C", "N", "CB"]
    coords = np.zeros((len(residues), 4, 3))

    for res_idx, residue in enumerate(residues):
        for atom_idx, atom_name in enumerate(atom_names):
            # glycine has no CB: fill that slot with the CA coordinates instead
            if residue.resname == 'GLY' and atom_name == 'CB':
                atom_name = "CA"

            coords[res_idx, atom_idx, :] = residue[atom_name].coord

    return coords

def get_cdr_anchor_seqs(ab_regions, CDRs = ["H1", "H2", "H3", "L1", "L2", "L3"]):
    '''
    Collect the CDR + anchor sequence of every requested CDR of a FAB.

    Returns a dictionary with keys: CDRs, values: CDR + anchor sequence
    '''
    return {CDR: cdr_anchor_seq(ab_regions, CDR) for CDR in CDRs}

def get_cdr_anchor_BB_coords(cdr_residues, CDRs = ["H1", "H2", "H3", "L1", "L2", "L3"]):
    '''
    Collect the CDR + anchor backbone coordinates of every requested CDR of a FAB.

    cdr_residues: FAB split into regions, as expected by cdr_anchor_BB_coord

    Returns a dictionary with keys: CDRs, values: CDR + anchor backbone coordinates
    '''
    return {CDR: cdr_anchor_BB_coord(cdr_residues, CDR) for CDR in CDRs}

In [5]:
# function to retrieve all PDB ids from SAbDab and runs functions in above cells for each individual FAB. 
def get_sabdab_fabs():
    '''
    Load every usable FAB from SAbDab and extract its CDR sequences and backbone coordinates.

    The PDB ids are first reduced with filter_abs; every FAB of each remaining entry is then
    processed. FABs for which the extraction raises an exception are skipped.

    returns CDR_seqs: list of dictionaries,
                      each dictionary contains data of one FAB, keys: CDR, value: CDR sequence
    returns CDR_BB_coords: list of dictionaries,
                           each dictionary contains data of one FAB, keys: CDR, value: CDR backbone coordinates
    Both lists are aligned: entry k of each list belongs to the same FAB.
    '''
    # use imgt numbering
    db.set_numbering_scheme("imgt")
    db.set_region_definition("imgt")

    # all pdb ids in SAbDab, reduced to the entries that pass the FAB filter
    usable_pdb_ids = filter_abs(list(db.db_summary.keys()))

    CDR_seqs, CDR_BB_coords = [], []

    for pdb_id in track(usable_pdb_ids, description='Load data from SAbDab'):
        for fab in db.fetch(pdb_id).fabs:
            try:
                regions = split_structure_in_regions(fab)
                fab_seqs = get_cdr_anchor_seqs(regions)
                fab_coords = get_cdr_anchor_BB_coords(regions)
            except Exception:
                # some FABs have errors and throw exceptions; those are ignored on purpose
                continue

            # only reached when all three steps succeeded, which keeps the two lists aligned
            CDR_seqs.append(fab_seqs)
            CDR_BB_coords.append(fab_coords)

    return CDR_seqs, CDR_BB_coords

Code that runs the above functions to download all FABs from SAbDab and extract the CDR and anchor sequences and their backbone coordinates. The data is saved to `.npy` files because the code interacting with SAbDab is slow.

In [6]:
# Fetches and parses every FAB in SAbDab. Returns two aligned lists of per-FAB dictionaries:
# CDR + anchor sequences and CDR + anchor backbone coordinates.
CDR_seqs, CDR_BB_coords = get_sabdab_fabs()

In [31]:
# Cache the extracted data on disk. Both objects are lists of dictionaries, so NumPy
# stores them as pickled object arrays.
np.save('train_data/CDR_BB_coords.npy', CDR_BB_coords)
np.save('train_data/CDR_seqs.npy', CDR_seqs)

In [6]:
# Reload the cached extraction results. The files hold lists of dictionaries stored as
# object arrays, hence allow_pickle=True.
# NOTE(review): unpickling can execute arbitrary code, so only load files you created yourself.
with open('train_data/CDR_BB_coords.npy', 'rb') as infile:
    CDR_BB_coords = np.load(infile, allow_pickle=True)

with open('train_data/CDR_seqs.npy', 'rb') as infile:
    CDR_seqs = np.load(infile, allow_pickle=True)

**1.2 Format data to model inputs**

Data loaded from SAbDab is reformatted into the model inputs and the training outputs.

Each backbone atom corresponds to one node in the graph. Each atom is encoded as a vector with 41 elements: a one-hot encoding of the amino acid residue (20), a positional encoding within the loop (10), a one-hot encoding of the CDR loop or anchor (7) and a one-hot encoding of the atom type (4).

The input coordinates of each backbone atom are processed as follows: anchor residues keep their original positions, and the CDR residues are spaced equally on a straight line between the two anchors.

The training output coordinates are the backbone coordinates from the crystal structure, formatted identically to the input coordinates.

In [7]:
# functions that convert data extracted from SAbDab to model input

def encode(x, classes):
    '''
    One-hot encode the scalar index x as a vector of length classes.
    Used for the sequence encoding.
    '''
    vector = np.zeros(classes)
    # switch on the single position selected by x
    vector[x] = 1

    return vector

def one_hot(num_list, classes=20):
    '''
    One-hot encode a 1D sequence of indices.
    Used for the sequence encoding.

    Returns an array of shape (len(num_list), classes) whose row k is the one-hot
    vector of num_list[k].
    '''
    encoded = np.zeros((len(num_list), classes))

    for row, index in enumerate(num_list):
        encoded[row] = encode(index, classes)

    return encoded

def which_loop(loop_seq, cdr):
    '''
    Build a one-hot vector per residue describing which CDR it belongs to.

    The first two and last two residues of loop_seq are labelled "Anchor"; every residue
    in between gets the label of cdr. Returns an array of shape (len(loop_seq), 7).
    '''
    labels = ["H1", "H2", "H3", "L1", "L2", "L3", "Anchor"]
    n_residues = len(loop_seq)

    # start with every residue marked as anchor ...
    membership = np.zeros((n_residues, len(labels)))
    membership[:, -1] = 1

    # ... then overwrite everything but the two residues at each end with the CDR label
    cdr_row = np.array([1.0 if cdr == label else 0.0 for label in labels])
    membership[2:-2] = cdr_row[None].repeat(n_residues - 4, axis=0)

    return membership

def positional_encoding(sequence, n=5):
    '''
    Gives the network information on how close each residue is to the anchors.

    Returns an array of shape (len(sequence), 2 * n): for every frequency 2**i (i < n) one
    cosine and one sine column, evaluated at the relative position of each residue.
    '''
    L = len(sequence)
    positions = np.arange(L)
    columns = []

    for power in range(n):
        angles = (2 ** power) * np.pi * positions / L
        columns.append(np.cos(angles))
        columns.append(np.sin(angles))

    return np.array(columns).transpose()

def res_to_atom(res_encoding, n_atoms=4):
    '''
    Expand per-residue encodings to per-atom encodings by appending a one-hot atom type.

    res_encoding: array of shape (residues, features)
    n_atoms:      number of atoms per residue

    Returns an array of shape (residues, n_atoms, features + n_atoms). Every atom of a residue
    carries the residue's feature vector followed by a one-hot vector of its atom index.
    With the notebook's 37 residue features and 4 atoms this is the usual 41-element encoding.
    The sizes are taken from the input instead of being fixed to 37/41, so n_atoms other than 4
    and other feature widths work as well.
    '''
    res_encoding = np.asarray(res_encoding)
    n_residues, n_features = res_encoding.shape

    atom_encoding = np.zeros((n_residues, n_atoms, n_features + n_atoms))
    # residue features are shared by all atoms of the residue
    atom_encoding[:, :, :n_features] = res_encoding[:, None, :]
    # row j of the identity matrix is the one-hot encoding of atom j
    atom_encoding[:, :, n_features:] = np.eye(n_atoms)

    return atom_encoding

def prepare_input_loop(CDR_coord, CDR_seq, CDR):
    '''
    Generates the input features fed into the network for a single CDR.

    CDR_coord: backbone coordinates of the CDR plus anchors
    CDR_seq:   one-letter sequence of the CDR plus anchors
    CDR:       loop name, e.g. "H1"

    Returns (input coordinates, per-atom encoding). In the input coordinates the first and
    last residue are untouched and everything in between is replaced by points spaced evenly
    on the straight line from residue 1 to residue -2, so those two residues also keep their
    original positions.
    '''
    start_coords = copy.deepcopy(CDR_coord)
    n_interpolated = len(CDR_coord) - 2
    start_coords[1:-1] = np.linspace(CDR_coord[1], CDR_coord[-2], n_interpolated)

    # per-residue features: amino acid one-hot, positional encoding, loop one-hot (in this order)
    residue_indices = np.array([short2num[amino] for amino in CDR_seq])
    amino_features = one_hot(residue_indices)
    loop_features = which_loop(CDR_seq, CDR)
    position_features = positional_encoding(CDR_seq)
    per_residue = np.concatenate([amino_features, position_features, loop_features], axis=1)

    # copy the residue features onto every atom and append the atom type
    per_atom = res_to_atom(per_residue)

    return start_coords, per_atom

def prepare_model_input(CDR_seq, CDR_BB_coord):
    '''
    Prepares the model inputs for a single FAB.

    CDR_seq:      dictionary, keys: CDR, values: CDR + anchor sequence
    CDR_BB_coord: dictionary, keys: CDR, values: CDR + anchor backbone coordinates

    Returns (geomins, encodings) as tensors with one row per atom; the loops are stacked in
    the iteration order of CDR_BB_coord.
    '''
    loop_geoms = []
    loop_encodings = []

    for CDR in CDR_BB_coord:
        loop_geom, loop_encoding = prepare_input_loop(CDR_BB_coord[CDR], CDR_seq[CDR], CDR)
        loop_encodings.append(loop_encoding)
        loop_geoms.append(loop_geom)

    # stack all loops along the residue axis and convert to tensors
    encodings = torch.from_numpy(np.concatenate(loop_encodings, axis=0))
    geomins = torch.from_numpy(np.concatenate(loop_geoms, axis=0))

    # flatten (residue, atom) into a single atom axis so atoms are no longer grouped per residue
    encodings = rearrange(encodings, "i a d -> (i a) d")
    geomins = rearrange(geomins, "i a d -> (i a) d")

    return geomins, encodings

def prepare_model_inputs(CDR_seqs, CDR_BB_coords):
    '''
    Prepares the model inputs for a list of FABs.

    CDR_seqs and CDR_BB_coords are aligned sequences with one entry per FAB.
    Returns (geomins, encodings), two lists with one tensor per FAB.
    '''
    geomins, encodings = [], []

    for fab_index in track(range(len(CDR_seqs)), description='Preparing model inputs'):
        fab_geom, fab_encoding = prepare_model_input(CDR_seqs[fab_index], CDR_BB_coords[fab_index])
        encodings.append(fab_encoding)
        geomins.append(fab_geom)

    return geomins, encodings

def prepare_model_output(CDR_BB_coords):
    '''
    Prepares the training targets, formatted identically to the inputs.

    CDR_BB_coords: sequence of dictionaries, keys: CDR, values: CDR + anchor backbone coordinates
    Returns a list with one tensor per FAB holding one row per atom; loops are stacked in the
    iteration order of the dictionary.
    '''
    geomouts = []

    for fab_coords in track(CDR_BB_coords, description='Preparing model outputs'):
        # stack all loops of the FAB along the residue axis
        stacked = np.concatenate(list(fab_coords.values()), axis=0)
        # convert to a tensor and flatten (residue, atom) into a single atom axis
        geomouts.append(rearrange(torch.from_numpy(stacked), "i a d -> (i a) d"))

    return geomouts

def concatenate_data(encodings, geomins, geomouts):
    '''
    Bundles encodings, geomins and geomouts into a single list.

    Returns one dictionary per FAB with the keys 'encodings', 'geomins' and 'geomouts'.
    The length of the result follows encodings.
    '''
    return [
        {
            'encodings': encodings[index],
            'geomins': geomins[index],
            'geomouts': geomouts[index],
        }
        for index in range(len(encodings))
    ]

In [8]:
# Build the network inputs (starting coordinates and node encodings) and the training
# targets (crystal-structure coordinates); each is a list with one tensor per FAB.
geomins, node_encodings = prepare_model_inputs(CDR_seqs, CDR_BB_coords)
geomouts = prepare_model_output(CDR_BB_coords)

In [9]:
# One dictionary per FAB with the keys 'encodings', 'geomins' and 'geomouts'.
data = concatenate_data(node_encodings, geomins, geomouts)
# number of FABs in the data set
len(data)

8117

**1.3 Prepare data for training**

In [10]:
# data needs to be a single array or list containing encodings, geomins and geomouts
# First hold out 20% as the test set, then take 25% of the remaining 80% as the validation
# set: roughly a 60/20/20 train/validation/test split. The fixed random_state makes the
# split reproducible.
train, test = train_test_split(data, test_size=0.2, random_state=42)
train, validation = train_test_split(train, test_size=0.25, random_state=42)

len(train), len(validation), len(test)

(4869, 1624, 1624)

In [11]:
# All three splits are served with identical loader settings.
_loader_options = dict(
    batch_size=1,     # Batch size
    num_workers=1,    # Number of cpu's allocated to load the data (recommended is 4/GPU)
    shuffle=True,     # Whether to randomly shuffle data
    pin_memory=True,  # Enables faster data transfer to CUDA-enabled GPUs (page-locked memory)
)

train_dataloader = torch.utils.data.DataLoader(train, **_loader_options)
test_dataloader = torch.utils.data.DataLoader(test, **_loader_options)
val_dataloader = torch.utils.data.DataLoader(validation, **_loader_options)

### 2. Implement the EGNN

**2.1 EGNN layer that allows inputs of different lengths**

In [12]:
class MaskedEGNN(torch.nn.Module):
    '''
    One EGNN layer acting on a fully connected graph.

    Every ordered pair of nodes exchanges a message computed from both feature vectors and the
    squared distance between the nodes. The messages update the node features (residual update)
    and move the coordinates along the pairwise difference vectors.

    Despite the class name, forward takes no mask: every node interacts with every node,
    including itself.

    node_dim:    size of the node feature vectors
    message_dim: size of the messages passed along the edges
    '''

    def __init__(self, node_dim, message_dim=32):
        super().__init__()

        # an edge sees the features of both end points plus one squared distance
        edge_input_dim = (node_dim * 2) + 1
        edge_hidden_dim = 2 * edge_input_dim

        # message function
        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(edge_input_dim, edge_hidden_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(edge_hidden_dim, message_dim),
            torch.nn.SiLU(),
        )

        # node update from the node's own features and its pooled messages
        self.node_mlp = torch.nn.Sequential(
            torch.nn.Linear(node_dim + message_dim, 2 * node_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(2 * node_dim, node_dim),
        )

        # one scalar weight per message, used to scale the coordinate update
        self.coors_mlp = torch.nn.Sequential(
            torch.nn.Linear(message_dim, 2 * message_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(2 * message_dim, 1),
        )

    def forward(self, node_features, coordinates):
        '''
        node_features: tensor of shape (batch, nodes, node_dim)
        coordinates:   tensor of shape (batch, nodes, dims)

        Returns (updated node features, updated coordinates) with unchanged shapes.
        '''
        # pairwise difference vectors x_i - x_j and their squared lengths
        offsets = rearrange(coordinates, 'b i d -> b i () d') - rearrange(coordinates, 'b j d -> b () j d')
        sq_dist = (offsets ** 2).sum(dim=-1, keepdim=True)

        # features of both end points of every edge
        feats_i = rearrange(node_features, 'b i d -> b i () d')
        feats_j = rearrange(node_features, 'b j d -> b () j d')
        feats_i, feats_j = torch.broadcast_tensors(feats_i, feats_j)

        messages = self.edge_mlp(torch.cat((feats_i, feats_j, sq_dist), dim=-1))

        # scalar weight per edge
        weights = rearrange(self.coors_mlp(messages), 'b i j () -> b i j')

        # difference vectors divided by the (clipped) squared distance; the clip avoids a
        # division by zero for coinciding points and for i == j
        scaled_offsets = offsets / sq_dist.clip(min=1e-8)

        coors_out = coordinates + torch.einsum('b i j, b i j c -> b i c', weights, scaled_offsets)

        # pool the messages of each node by summing over its partners
        pooled = messages.sum(dim=-2)

        # residual node update
        node_out = node_features + self.node_mlp(torch.cat((node_features, pooled), dim=-1))

        return node_out, coors_out

**2.2 EGNN model with 4 layers**

In [13]:
class EGNNModel(torch.nn.Module):
    '''
    A stack of MaskedEGNN layers applied one after another.

    node_dim:    size of the node feature vectors
    layers:      number of EGNN layers
    message_dim: message size used by every layer
    '''

    def __init__(self, node_dim, layers=4, message_dim=32):
        super().__init__()

        # build as many EGNN layers as requested
        stack = [MaskedEGNN(node_dim, message_dim = message_dim) for _ in range(layers)]
        self.layers = torch.nn.ModuleList(stack)

    def forward(self, node_features, coordinates):
        '''
        Pass node features and coordinates through every layer in turn and return the
        final (node features, coordinates) pair.
        '''
        feats, coords = node_features, coordinates

        for egnn_layer in self.layers:
            # each layer refines both the features and the coordinates
            feats, coords = egnn_layer(feats, coords)

        return feats, coords

**2.3 5 EGNNs in parallel**

In [14]:
class DecoyGen(torch.nn.Module):
    '''
    Several independent EGNN models run on the same input; each produces one decoy structure.

    dims_in: size of the node feature vectors
    decoys:  number of parallel EGNN models
    kwargs:  forwarded to every EGNNModel
    '''

    def __init__(self, dims_in=41, decoys=5, **kwargs):
        super().__init__()
        self.blocks = torch.nn.ModuleList([EGNNModel(node_dim=dims_in, **kwargs) for _ in range(decoys)])
        self.decoys = decoys

    def forward(self, node_features, coordinates):
        '''
        Returns a tensor of shape (decoys, *coordinates.shape[1:]) with the coordinates
        predicted by each block. The batch axis of the input is replaced by the decoy axis,
        so storing a block's output only works for a batch of size 1.
        '''
        decoy_shape = (self.decoys, *coordinates.shape[1:])
        geoms = torch.zeros(decoy_shape, device=coordinates.device)

        for index, block in enumerate(self.blocks):
            # keep only the predicted coordinates, the node features are discarded
            geoms[index] = block(node_features, coordinates)[1]

        return geoms

### 3. Train model

In [36]:
# set loss functions
def rmsd(prediction, truth):
    '''
    Root-mean-square deviation between two sets of coordinates.

    The last axis holds the coordinates of a point and the second to last axis enumerates the
    points. One RMSD is computed per leading index and the mean of those RMSDs is returned.
    '''
    squared_deviation = ((prediction - truth) ** 2).sum(dim=-1)
    per_structure = squared_deviation.mean(dim=-1).sqrt()
    return per_structure.mean()

def rmsds(preds, true):
    '''
    RMSD of every prediction to the reference, returned in ascending order.
    '''
    per_decoy = ((preds - true) ** 2).sum(dim=-1).mean(dim=-1) ** (1/2)
    ordered, _ = torch.sort(per_decoy)
    return ordered

def length_penalty(pred):
    '''
    Mean squared deviation of the distance between consecutive points (along axis 1)
    from 3.802.
    '''
    steps = pred[:, 1:] - pred[:, :-1]
    step_lengths = (steps ** 2).sum(dim=-1) ** (1/2)
    return ((step_lengths - 3.802) ** 2).mean()

def different_penalty(pred):
    '''
    Negative mean squared coordinate difference over all pairs of predictions.
    The more the predictions differ from one another, the lower (more negative) the value.
    '''
    first = rearrange(pred, "i n d -> i () n d")
    second = rearrange(pred, "j n d -> () j n d")
    pairwise_difference = first - second
    return -(pairwise_difference ** 2).mean()

def dist_check(pred, amino):
    '''
    Penalise bond lengths in the predicted CDR loops that deviate from ideal backbone geometry.

    pred:  coordinates of shape (decoys, atoms, 3); atoms are ordered residue by residue with
           four atoms per residue
    amino: node encodings of shape (1, atoms, 41); columns 30-35 flag the CDR loop
           (H1, H2, H3, L1, L2, L3) an atom belongs to and column 5 flags glycine

    For each of the six loops and each bond type (CA-CA of neighbouring residues, CA-N, CA-C,
    C-N peptide bond, CA-CB) the amount by which |length - ideal| exceeds the tolerance is
    averaged over all such bonds; the returned value is the sum of these averages. A bond type
    without any atom pair (absent loop, single-residue loop, glycine-only loop) contributes 0
    instead of NaN.

    The atom slots follow the order in which cdr_anchor_BB_coord extracts the coordinates,
    ["CA", "C", "N", "CB"]: slot 1 is C and slot 2 is N. Using slot 1 as N and slot 2 as C
    applies the CA-N, CA-C and C-N targets to the wrong atom pairs.
    '''
    CA, C, N, CB = 0, 1, 2, 3

    def per_residue(atom_mask):
        # (decoys, atoms, 3) -> (decoys, residues, 4, 3) for the selected atoms
        atoms = pred[:, atom_mask]
        return atoms.reshape(atoms.shape[0], atoms.shape[1] // 4, 4, atoms.shape[-1])

    def violation(atoms_a, atoms_b, ideal, tolerance):
        # mean of max(| |a - b| - ideal | - tolerance, 0); 0 when there is no pair to check
        if atoms_a.numel() == 0:
            return 0.0
        lengths = (atoms_a - atoms_b).pow(2).sum(-1).pow(1/2)
        return ((lengths - ideal).abs() - tolerance).clamp(0).mean()

    err = 0
    for i in range(6):
        in_loop = amino[0, :, 30 + i] == 1.0
        loop = per_residue(in_loop)

        # CA-CA of neighbouring residues
        err += violation(loop[:, 1:, CA], loop[:, :-1, CA], 3.82, 0.12)
        # CA-N
        err += violation(loop[:, :, CA], loop[:, :, N], 1.47, 0.01)
        # CA-C
        err += violation(loop[:, :, CA], loop[:, :, C], 1.53, 0.01)
        # C-N peptide bond: C of residue r to N of residue r + 1
        err += violation(loop[:, :-1, C], loop[:, 1:, N], 1.34, 0.01)
        # CA-CB, skipping glycine (its CB slot holds a copy of CA)
        non_gly = per_residue(in_loop & (amino[0, :, 5] != 1.0))
        err += violation(non_gly[:, :, CA], non_gly[:, :, CB], 1.54, 0.01)

    return err

def atom_dist(geom):
    '''
    Pairwise distance matrix of the atoms in geom.

    The pairwise difference vectors are averaged over the first axis before their length is
    taken. 1e-8 is added under the square root, so the diagonal is 1e-4 rather than 0.
    '''
    pairwise_difference = geom[:, None] - geom[:, :, None]
    mean_difference = pairwise_difference.mean(0)
    return ((mean_difference ** 2).sum(dim=-1) + 1e-8) ** (1/2)

def atom_dist_penal(geom, pred):
    '''
    Mean squared error between the true and predicted pairwise atom distances, restricted
    to atom pairs that are closer than 4.0 in the true structure.
    '''
    reference = atom_dist(geom)
    predicted = atom_dist(pred)
    close_pairs = reference < 4.0
    return ((reference - predicted)[close_pairs] ** 2).mean()

In [37]:
# train for one epoch
def run_epoch(model, optim, train_dataloader, test_dataloader, grad_clip=10.0):
    '''
    Train the model for one epoch and evaluate it afterwards.

    model:            network called as model(node_features, coordinates)
    optim:            optimiser holding the model parameters
    train_dataloader: iterable of batches used for the weight updates
    test_dataloader:  iterable of batches used for the evaluation only
    grad_clip:        maximum gradient norm per update step

    Every batch is a dictionary with the keys 'geomins', 'geomouts' and 'encodings'.
    Returns (mean training loss, mean evaluation loss). The model is left in eval mode.

    The evaluation loss is computed by running the model on every evaluation batch with the
    same loss as in training.
    '''
    # run every batch on the device the model lives on
    first_parameter = next(model.parameters(), None)
    device = first_parameter.device if first_parameter is not None else torch.device("cpu")

    def batch_loss(data):
        # forward pass and loss for one batch
        coordinates = data['geomins'].float().to(device)
        geomouts = data['geomouts'].float().to(device)
        node_features = data['encodings'].float().to(device)

        pred = model(node_features, coordinates)

        # RMSD to the crystal structure plus weighted geometry terms
        return rmsd(geomouts, pred) + 5*(20*atom_dist_penal(geomouts, pred)
                                        + dist_check(pred.mean(0, keepdim = True), node_features))

    epoch_train_losses = []
    model.train()                                                      # Train mode (no effect without dropout, but good practice)

    for data in train_dataloader:                                      # For each batch of data in the dataset
        optim.zero_grad()                                              # Delete old gradients

        loss = batch_loss(data)
        epoch_train_losses.append(loss.item())                         # Store value of loss function for training set

        loss.backward()                                                # Calculate loss gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)  # Clip the gradient norm to avoid very large updates
        optim.step()                                                   # Update model weights

    model.eval()                                                       # Set the model to eval mode
    with torch.no_grad():                                              # No gradients needed for the evaluation
        epoch_test_losses = [batch_loss(data).item() for data in test_dataloader]

    return np.mean(epoch_train_losses), np.mean(epoch_test_losses)

In [38]:
# use the GPU when one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
# new floating point tensors are created as 32-bit floats
torch.set_default_dtype(torch.float)

# initialise model
model = DecoyGen().to(device = device).float()
# per-epoch loss history; re-running this cell also resets the history
train_losses = []
test_losses = []

# set optimiser
optimiser = torch.optim.Adam(model.parameters(), lr=1e-4)

Train the model for up to 5000 epochs, but stop early if the loss has not improved for 150 epochs.

In [39]:
print(" Train |  Val ")
patience = 150
epochs_without_improvement = 0                                                         # Defined up front so the elif below never reads an unbound name

for epoch in range(5000):
    # Model selection and early stopping use the validation split, so the test split stays
    # untouched for the final evaluation. test_loss / test_losses keep their names for
    # compatibility with the setup cell but hold validation losses.
    train_loss, test_loss = run_epoch(model, optimiser, train_dataloader, val_dataloader)

    train_losses.append(train_loss)                                                    # Store train and validation loss
    test_losses.append(test_loss)

    if np.min(test_losses) == test_loss:                                               # If it is the best model on the validation set, save it
        torch.save(model.state_dict(), "best_model")                                   # This is how you save models in pytorch
        epochs_without_improvement = 0

    elif epochs_without_improvement < patience:                                        # If the model hasn't improved this epoch store that
        epochs_without_improvement += 1
    else:                                                                              # If the model hasn't improved in 'patience' epochs stop the training.
        break

    if train_loss > 1.5*np.min(train_losses):                                          # EGNNs are quite unstable, this reverts the model to a previous state if an epoch blows up
        model.load_state_dict(torch.load("previous_weights", map_location=torch.device(device)))
        optimiser.load_state_dict(torch.load("previous_optim", map_location=torch.device(device)))
    if train_loss == np.min(train_losses):                                             # Best training loss so far: remember this state as the revert target
        torch.save(model.state_dict(), "previous_weights")
        torch.save(optimiser.state_dict(), "previous_optim")

    print("{:6.2f} | {:6.2f}".format(train_loss, test_loss))                           # Train loss, then validation loss

 Train |  Val 


KeyboardInterrupt: 