In [1]:
import torch
import torch
import numpy as np
import sys
import argparse
import os.path

sys.path.append("./training")

In [2]:
from utils import worker_init_fn, get_pdbs, loader_pdb, build_training_clusters, PDB_dataset, StructureDataset, StructureLoader

data_path = "./pdb_2021aug02_sample"
params = {
    "LIST"    : f"{data_path}/list.csv", 
    "VAL"     : f"{data_path}/valid_clusters.txt",
    "TEST"    : f"{data_path}/test_clusters.txt",
    "DIR"     : f"{data_path}",
    "DATCUT"  : "2030-Jan-01",
    "RESCUT"  : 3.5, #resolution cutoff for PDBs
    "HOMO"    : 0.70 #min seq.id. to detect homo chains
}

LOAD_PARAM = {'batch_size': 1,
              'shuffle': False,
              'pin_memory':False,
              'num_workers': 4}

train, valid, test = build_training_clusters(params, False)

train_set = PDB_dataset(list(train.keys()), loader_pdb, train, params)
train_loader = torch.utils.data.DataLoader(train_set, worker_init_fn=worker_init_fn, **LOAD_PARAM)

In [3]:
pdb_dict_train = get_pdbs(train_loader)
dataset_train = StructureDataset(pdb_dict_train, truncate=None, max_length=100000)
loader_train = StructureLoader(dataset_train, batch_size=10000)

### Mess around with architecture here:

In [113]:
from model_utils import ProteinFeatures, PositionWiseFeedForward, gather_nodes, cat_neighbors_nodes
import torch.nn as nn
import torch.nn.functional as F


class EncLayer(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(EncLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)
        self.norm3 = nn.LayerNorm(num_hidden)

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W11 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W12 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W13 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, E_idx, mask_V=None, mask_attend=None):
        """ Parallel computation of full transformer layer """
        
        # Compute node update.
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        ## Duplicates the node embeddings up to K (nearest neighbors) for concatenation with edge embeddings.
        h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_EV.size(-2),-1)
        h_EV = torch.cat([h_V_expand, h_EV], -1)
        ## Compute message passing through linear layers in self.num_hidden embedding space.
        h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
        
        ## mask_V and mask_attend will not be None during training at least.
        ## Zeros out the updates to the indices being ignored, before the update is applied to the node embeddings.
        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message
            
        ## Compute the message by summing over the neighbor indices and dividing by a scale factor (should probably be the degree)
        dh = torch.sum(h_message, -2) / self.scale
        h_V = self.norm1(h_V + self.dropout1(dh))
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))
        
        ## mask_V and mask_attend will not be None during training at least.
        ## Zeros out the nodes being ignored, after the update is applied to the node embeddings (removes any layer biases added to the zeroed nodes).
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V

        # Compute edge update.
        h_EV = cat_neighbors_nodes(h_V, h_E, E_idx)
        h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_EV.size(-2),-1)
        h_EV = torch.cat([h_V_expand, h_EV], -1)
        h_message = self.W13(self.act(self.W12(self.act(self.W11(h_EV)))))
        h_E = self.norm3(h_E + self.dropout3(h_message))
        ## Looks like encoder block does not perform a 'dense layer' update for edge embeddings.
        return h_V, h_E



class DecLayer(nn.Module):
    def __init__(self, num_hidden, num_in, dropout=0.1, num_heads=None, scale=30):
        super(DecLayer, self).__init__()
        self.num_hidden = num_hidden
        self.num_in = num_in
        self.scale = scale
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(num_hidden)
        self.norm2 = nn.LayerNorm(num_hidden)

        self.W1 = nn.Linear(num_hidden + num_in, num_hidden, bias=True)
        self.W2 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.W3 = nn.Linear(num_hidden, num_hidden, bias=True)
        self.act = torch.nn.GELU()
        self.dense = PositionWiseFeedForward(num_hidden, num_hidden * 4)

    def forward(self, h_V, h_E, mask_V=None, mask_attend=None):
        """ Parallel computation of full transformer layer """
        # The decoding layer works exactly like the encoding layer, but it does not perform an edge update. See line-by-line explanation from above.

        # This block computes the node update.
        h_V_expand = h_V.unsqueeze(-2).expand(-1,-1,h_E.size(-2),-1) 
        h_EV = torch.cat([h_V_expand, h_E], -1)
        h_message = self.W3(self.act(self.W2(self.act(self.W1(h_EV)))))
        if mask_attend is not None:
            h_message = mask_attend.unsqueeze(-1) * h_message
        dh = torch.sum(h_message, -2) / self.scale 
        h_V = self.norm1(h_V + self.dropout1(dh))
        dh = self.dense(h_V)
        h_V = self.norm2(h_V + self.dropout2(dh))
        if mask_V is not None:
            mask_V = mask_V.unsqueeze(-1)
            h_V = mask_V * h_V
            
        # Return updated node embeddings.
        return h_V

class ProteinMPNN(nn.Module):
    def __init__(self, num_letters=21, node_features=128, edge_features=128,
        hidden_dim=128, num_encoder_layers=3, num_decoder_layers=3,
        vocab=21, k_neighbors=32, augment_eps=0.1, dropout=0.1):
        super(ProteinMPNN, self).__init__()

        # Hyperparameters
        self.node_features = node_features
        self.edge_features = edge_features
        self.hidden_dim = hidden_dim

        self.features = ProteinFeatures(node_features, edge_features, top_k=k_neighbors, augment_eps=augment_eps)

        self.W_e = nn.Linear(edge_features, hidden_dim, bias=True)
        self.W_s = nn.Embedding(vocab, hidden_dim)

        # Encoder layers
        self.encoder_layers = nn.ModuleList([
            EncLayer(hidden_dim, hidden_dim*2, dropout=dropout)
            for _ in range(num_encoder_layers)
        ])

        # Decoder layers
        self.decoder_layers = nn.ModuleList([
            DecLayer(hidden_dim, hidden_dim*3, dropout=dropout)
            for _ in range(num_decoder_layers)
        ])
        self.W_out = nn.Linear(hidden_dim, num_letters, bias=True)

        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, X, S, mask, chain_M, residue_idx, chain_encoding_all):
        """ Graph-conditioned sequence model """
        device=X.device
        
        # Prepare node and edge embeddings
        # Constructs edge features from X.
        # Computes virtual Cb atom position and encodes into RBF, adds additional edge features as needed.
        E, E_idx = self.features(X, mask, residue_idx, chain_encoding_all)
        
        # Node features are intialized to zeros in the same hidden dimensionality as edges (128 by default)
        h_V = torch.zeros((E.shape[0], E.shape[1], E.shape[-1]), device=E.device)
        # Linear transformation over the edges to move them into the hidden dimensionality
        h_E = self.W_e(E)

        # Encoder is unmasked self-attention
        # Duplicate the mask out to K (nearest neighbors) so that it can be used to mask which nearest neighbors will be used.
        mask_attend = gather_nodes(mask.unsqueeze(-1),  E_idx).squeeze(-1)
        mask_attend = mask.unsqueeze(-1) * mask_attend
        
        # Perform the forward pass of the encoder layers with checkpointing.
        # Computes a node and edge update which is used as input to the next layer and builds up the node representation.
        for layer in self.encoder_layers:
            h_V, h_E = torch.utils.checkpoint.checkpoint(layer, h_V, h_E, E_idx, mask, mask_attend)

        # Concatenate sequence embeddings for autoregressive decoder
        # Concatenates linear transformation of sequence (not one-hot encoded for some reason) to edge embeddings in place of node embeddings. 
        # Moves sequence from [B, N] -> [B, N, (node embedding) 128]
        h_S = self.W_s(S)
        h_ES = cat_neighbors_nodes(h_S, h_E, E_idx)
        
        # (Justas comment) Build encoder embeddings [idk what this means]
        ## Concatenates zeros to edges in same shape as nodes would be concatenated to edges.
        h_EX_encoder = cat_neighbors_nodes(torch.zeros_like(h_S), h_E, E_idx)
        ## Concatenates node embeddings from encoder to zeros concatenated to edges.
        h_EXV_encoder = cat_neighbors_nodes(h_V, h_EX_encoder, E_idx)
        
        ## select only visible chains and only the residues that actually exist (not the padded residues added during batching) 
        ## element-wise multiplication of boolean vectors performs logical AND.
        ## Update chain_M to include missing regions (Justas comment, this implies to me that mask also tracks gaps in structure/sequence)
        chain_M = chain_M*mask 
        
        # Apply auto-regressive masking
        ## Creates a biased random shuffle of the indices in range [0,N), 0-indices in chain_M will be at the beginning of the shuffle while 1-indices will be at the end.
        ## (Justas Comment) [numbers will be smaller for places where chain_M = 0.0 and higher for places where chain_M = 1.0]
        decoding_order = torch.argsort((chain_M+0.0001)*(torch.abs(torch.randn(chain_M.shape, device=device)))) 
        
        ## Creates a permutation matrix by constructing one-hot encodings for all of the indices in the biased random decoding order.
        mask_size = E_idx.shape[1]
        permutation_matrix_reverse = torch.nn.functional.one_hot(decoding_order, num_classes=mask_size).float()
        
        ## Creates a random matrix for each protein that encodes the decoding order as a mask. Each row in the mask for each protein complex
        ##   corresponds to which indices are visible for each step of the decoding process. Follows the 
        ##   For example:
        ###    [0, 1, 0, 0, 0]
        ###    [0, 0, 0, 0, 0]
        ###    [1, 1, 0, 0, 0]
        ###    [1, 1, 0, 1, 0]
        ###    [1, 1, 0, 1, 1] for a 5 amino acid protein complex.
        order_mask_backward = torch.einsum('ij, biq, bjp->bqp',(1-torch.triu(torch.ones(mask_size,mask_size, device=device))), permutation_matrix_reverse, permutation_matrix_reverse)
        
        # print(chain_M.shape)
        # print()
        # print(chain_M[0, :])
        # print(E_idx*chain_M.unsqueeze(-1).expand(-1, -1, 5)[0, :])
        
        ## This creates a mask for which neighbors are visible for each index in the random decoding order.
        ###  since each row in the order_mask_backward corresponds to a different index in the protein complex during auto-regressive decoding
        ###  each index has its K nearest neighbors stored in the E_idx matrix. 
        ###  This call of the gather function creates a mask to find visible neighbors for each connection in E_idx.
        mask_attend = torch.gather(order_mask_backward, 2, E_idx).unsqueeze(-1)
        print(mask_attend.shape)
              
        ## Convert the mask from [B, N] -> [B, N, 1, 1] moves data to deepest index in a compatible dimensionality with mask_attend.
        mask_1D = mask.view([mask.size(0), mask.size(1), 1, 1])
        
        ## Construct a mask and an inverse mask. Masks[B, N, KNN k, 1]
        mask_bw = mask_1D * mask_attend
        mask_fw = mask_1D * (1. - mask_attend)

        # Add sequence where appropriate.
        ## Using the forward mask select the indices of the zeros concat to node embeddings concat to edges for fw matrix.
        h_EXV_encoder_fw = mask_fw * h_EXV_encoder
        for layer in self.decoder_layers:
            ## Construct a matrix using the node embeddings updated on each iteration merged concat to sequence embedding concat to edges.
            h_ESV = cat_neighbors_nodes(h_V, h_ES, E_idx)
            ## Use the inverse mask to select only the visible indices of the seq embedding concat to node embedding concat to the edges 
            ##   add forward mask so we have values for every index.
            h_ESV = mask_bw * h_ESV + h_EXV_encoder_fw
            ## Compute node update with masked indices.
            h_V = torch.utils.checkpoint.checkpoint(layer, h_V, h_ESV, mask)

        # Linear map of logits to output dimensionality, softmax to prepare logits for NLL loss and return.
        logits = self.W_out(h_V)
        log_probs = F.log_softmax(logits, dim=-1)
        return log_probs

### Run the messed-with model here.

In [114]:
from model_utils import featurize

model = ProteinMPNN(
    node_features=128, 
    edge_features=128, 
    hidden_dim=128, 
    num_encoder_layers=3, 
    num_decoder_layers=3, 
    k_neighbors=5, # Set this to something small for simplicity during analysis.
    # k_neighbors=48, 
    dropout=0.1, 
    augment_eps=0.2
)

model.train();

In [115]:
for idx, batch in enumerate(loader_train):
    print("Protein Lengths:")
    print([len(p['seq']) for p in batch], '\n')
    
    # chain_M is the mask corresponding to 
    X, S, mask, lengths, chain_M, residue_idx, mask_self, chain_encoding_all = featurize(batch, torch.device('cpu'))
    
    # Multiplies the masks element-wise for a logical AND of the two masks, this is also performed in the forward pass.
    # This has the effect of selecting only visible chains and only the residues that actually exist (not the padded residues added during batching).
    mask_for_loss = mask * chain_M
    
    # Forward pass (is sequence aware!!), this is what we want to replicate for our model.
    log_probs = model(X, S, mask, chain_M, residue_idx, chain_encoding_all)
    
    break

Protein Lengths:
[2954, 3267] 

torch.Size([2, 3267, 5, 1])
torch.Size([2, 3267, 5, 1])
tensor([[[[0.],
          [1.],
          [0.],
          [1.],
          [1.]],

         [[0.],
          [0.],
          [0.],
          [1.],
          [0.]],

         [[0.],
          [1.],
          [1.],
          [1.],
          [0.]],

         ...,

         [[0.],
          [0.],
          [0.],
          [0.],
          [0.]],

         [[0.],
          [0.],
          [0.],
          [0.],
          [0.]],

         [[0.],
          [0.],
          [0.],
          [0.],
          [0.]]],


        [[[0.],
          [0.],
          [1.],
          [1.],
          [1.]],

         [[0.],
          [1.],
          [1.],
          [0.],
          [1.]],

         [[0.],
          [0.],
          [0.],
          [1.],
          [0.]],

         ...,

         [[0.],
          [1.],
          [1.],
          [1.],
          [1.]],

         [[0.],
          [0.],
          [1.],
          [1