In [1]:
import os
import glob
import matplotlib.pyplot as plt
import numpy as np
import math, random, time
import torch, torch_geometric
import torch.utils.data as data
import torch.nn as nn
import gvp.data
from gvp import GVP, GVPConvLayer, LayerNorm

import mdtraj as md

In [2]:
def create_structures(files, pdb):
    topology = md.load(pdb).topology
    start = time.time()
    #structures = {}
    structures = []
    for file in files:
        traj = md.load_xtc(file, top=pdb)
        traj.center_coordinates()
        backbone = traj.atom_slice(topology.select("protein and backbone")) 
        coords = backbone.xyz # has size (n_frames x n_atoms x 3)
        
        traj_coords = coords.reshape(-1, backbone.n_atoms*3) # has size (n_frames x n_atoms*3)
        for i in range(len(traj)):
            structures.append({'coords':traj_coords,
                               'name': f'{file[5:-4]}_frame{i}',
                               'seq': ['Q']*traj_coords.shape[0]})
            #structures[f'{file[5:-4]}_frame{i}'] = {'coords': traj_coords,
            #                                        'name': f'{file[5:-4]}_frame{i}',
            #                                        'seq': ['Q']*traj_coords.shape[0]} # seq doesn't matter for now
    end = time.time()
    print("create structures took", round((end-start)/60,4), "minutes")
    return structures
        

In [3]:
def train_test_split_files(structures):
    start = time.time()
    random.shuffle(structures)
    split = int(0.8*len(structures))
    train_struct = structures[0:split]
    test_struct = structures[split:]
    end = time.time()
    print("train test split took", round((end-start)/60,4), "minutes")
    return train_struct, test_struct

In [4]:
def train_test_split_filesSAVED(structures_path, model_dir):
    start = time.time()
    input_structures = torch.load(structures_path)
    random.shuffle(input_structures)
    split = int(0.8*len(input_structures))
    train_files = input_structures[0:split]
    test_files = input_structures[split:]
    torch.save(train_files, os.path.join(model_dir,'train_structures.pt'))
    torch.save(test_files, os.path.join(model_dir,'test_structures.pt'))
    end = time.time()
    print("Importing structures and splitting took", round((end-start)/60,4), "minutes")
    return

In [17]:
class LigandDataset(gvp.data.ProteinGraphDataset):
    def __init_(self, data_list,
                num_positional_embeddings=16,
                top_k=30, num_rbf=16, device="cpu"):
        super(LigandDataset, self).__init__(data_list,
                                            num_positional_embeddings,
                                            top_k, num_rbf, device)

In [6]:
"""
class LigandDataset(gvp.data.ProteinGraphDataset):
    def __init__(self, data_list, 
                 num_positional_embeddings=16,
                 top_k=30, num_rbf=16, device="cpu", node_counts=66,
                 tica_dict={}, structures_dict={}):
        data.Dataset.__init__(self)
        
        self.data_list = data_list 
        self.top_k = top_k
        self.num_rbf = num_rbf
        self.num_positional_embeddings = num_positional_embeddings
        self.device = device
        self.node_counts = node_counts
        
        self.letter_to_num = {'C': 4, 'D': 3, 'S': 15, 'Q': 5, 'K': 11, 'I': 9,
                       'P': 14, 'T': 16, 'F': 13, 'A': 0, 'G': 7, 'H': 8,
                       'E': 6, 'L': 10, 'R': 1, 'W': 17, 'V': 19, 
                       'N': 2, 'Y': 18, 'M': 12}
        self.num_to_letter = {v:k for k, v in self.letter_to_num.items()}

    def __getitem__(self, i): 
        return self._featurize_as_graph(self.data_list[i])

    def _featurize_as_graph(self, protein):
        # overwrite using from_dict and to_dict (on data)
        # dict.update to add the state information (as a tensor)
        data = super(LigandDataset, self)._featurize_as_graph(protein)
        with torch.no_grad():
            tica = torch.as_tensor(protein['tica'], device=self.device, dtype=torch.float)
        data_dict = data.to_dict()
        data_dict.update({'tica':tica})
        new_data = torch_geometric.data.Data.from_dict(data_dict)
        return new_data
"""

'\nclass LigandDataset(gvp.data.ProteinGraphDataset):\n    def __init__(self, data_list, \n                 num_positional_embeddings=16,\n                 top_k=30, num_rbf=16, device="cpu", node_counts=66,\n                 tica_dict={}, structures_dict={}):\n        data.Dataset.__init__(self)\n        \n        self.data_list = data_list \n        self.top_k = top_k\n        self.num_rbf = num_rbf\n        self.num_positional_embeddings = num_positional_embeddings\n        self.device = device\n        self.node_counts = node_counts\n        \n        self.letter_to_num = {\'C\': 4, \'D\': 3, \'S\': 15, \'Q\': 5, \'K\': 11, \'I\': 9,\n                       \'P\': 14, \'T\': 16, \'F\': 13, \'A\': 0, \'G\': 7, \'H\': 8,\n                       \'E\': 6, \'L\': 10, \'R\': 1, \'W\': 17, \'V\': 19, \n                       \'N\': 2, \'Y\': 18, \'M\': 12}\n        self.num_to_letter = {v:k for k, v in self.letter_to_num.items()}\n\n    def __getitem__(self, i): \n        return self._fe

In [7]:
class GCAE(nn.Module):
    '''
    Graph Convolutional AutoEncoder 
    
    Takes in protein structure graphs of type `torch_geometric.data.Data` 
    or `torch_geometric.data.Batch` and returns a 3D position per node of 
    shape [n_nodes, 3].
    
    Should be used with `gvp.data.ProteinGraphDataset`, or with generators
    of `torch_geometric.data.Batch` objects with the same attributes.
    
    :param node_in_dim: node dimensions in input graph, should be
                        (6, 3) if using features from GVP-GNN
    :param node_h_dim: hidden node dimensions to use in intermediate layers
    :param node_in_dim: edge dimensions in input graph, should be
                        (32, 1) if using original features from GVP-GNN
    :param edge_h_dim: hidden edge dimensions to embed to before use
                       in intermediate layers
    :param num_layers: number of layers in each of the encoder
                       and decoder modules
    :param drop_rate: rate to use in all dropout layers
    '''
    def __init__(self, node_in_dim, node_h_dim, 
                 edge_in_dim, edge_h_dim,
                 num_layers=3, drop_rate=0.1, node_num=66):
    
        super(GCAE, self).__init__()
        
        self.node_num = node_num
        
        self.W_v = nn.Sequential(
            GVP(node_in_dim, node_h_dim, activations=(None, None)),
            LayerNorm(node_h_dim)
        )
        self.W_e = nn.Sequential(
            GVP(edge_in_dim, edge_h_dim, activations=(None, None)),
            LayerNorm(edge_h_dim)
        )
        
        self.encoder_layers = nn.ModuleList(
                GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) 
            for _ in range(num_layers))
        
        
        self.squeeze_layer = nn.Sequential(
            nn.Linear(self.node_num*(node_h_dim[0]+node_h_dim[1]*3),1024),
            nn.ReLU(inplace=True), 
            nn.Dropout(p=drop_rate),
            nn.Linear(1024, 16)
        )
        
        self.unsqueeze_layer = nn.Sequential(
            nn.Linear(16, 1024), 
            nn.ReLU(inplace=True), 
            nn.Dropout(p=drop_rate),
            nn.Linear(1024, self.node_num*(node_h_dim[0]+node_h_dim[1]*3))
        )
        
        
        self.decoder_layers = nn.ModuleList(
                GVPConvLayer(node_h_dim, edge_h_dim, drop_rate=drop_rate) 
            for _ in range(num_layers))
        
        self.W_out = GVP(node_h_dim, (3, 0), activations=(None, None))


    def forward(self, h_V, edge_index, h_E):
        '''
        Forward pass to be used at train-time or test-time.
        
        :param h_V: tuple (s, V) of node embeddings
        :param edge_index: `torch.Tensor` of shape [2, num_edges]
        :param h_E: tuple (s, V) of edge embeddings
        :param seq: int `torch.Tensor` of shape [num_nodes]
        '''
        
        h_V = (h_V[0].reshape(h_V[0].shape[0]*h_V[0].shape[1],h_V[0].shape[2]),
               h_V[1].reshape(h_V[1].shape[0]*h_V[1].shape[1],h_V[1].shape[2],h_V[1].shape[3]))
        
        h_E = (h_E[0].reshape(h_E[0].shape[0]*h_E[0].shape[1],h_E[0].shape[2]),
               h_E[1].reshape(h_E[1].shape[0]*h_E[1].shape[1],h_E[1].shape[2],h_E[1].shape[3]))
        
        h_V = self.W_v(h_V)
        h_E = self.W_e(h_E)
        
        for layer in self.encoder_layers:
            h_V = layer(h_V, edge_index, h_E)
        
        encoder_embeddings = h_V
        
        flat_s = h_V[0].reshape(h_V[0].shape[0]//self.node_num, -1)
        flat_V = h_V[1].reshape(h_V[1].shape[0]//self.node_num, -1)
        h_V_stack = torch.cat((flat_s, flat_V),dim=1)
        h_V_stack = self.squeeze_layer(h_V_stack)

        h_V_small = torch.clone(h_V_stack)
        
        h_V_stack = self.unsqueeze_layer(h_V_stack)
        
        flat_s = h_V_stack[:,:self.node_num*encoder_embeddings[0].shape[1]]
        flat_V = h_V_stack[:,self.node_num*encoder_embeddings[0].shape[1]:]
        h_V = (flat_s.reshape(encoder_embeddings[0].shape), 
               flat_V.reshape(encoder_embeddings[1].shape))
       
        
        for layer in self.decoder_layers:
            h_V = layer(h_V, edge_index, h_E)
        logits = self.W_out(h_V)
        
        logits = logits.reshape(-1, self.node_num, 3)
        return logits, h_V_small 

In [8]:
def train_gnn(model_dir, model, train_structures, n_epoch=20):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    train_dataset = LigandDataset(train_structures)

    train_dataloader = torch_geometric.loader.DenseDataLoader(train_dataset, 
                                                              batch_size=256, 
                                                              shuffle=True, 
                                                              num_workers=4)

    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    losses = torch.zeros(n_epoch,1)
    #num_trained = len(glob.glob(os.path.join(model_dir,'epoch-*.pt')))
    # if we haven't started training yet, n_epochs will go from 0 -> n_epoch
    #n_epochs = list(range(num_trained, num_trained+n_epoch))
    n_epochs = list(range(0,n_epoch))
    for epoch in n_epochs:
        batch_losses = []
        start = time.time()
        for i, batch in enumerate(train_dataloader):
            optimizer.zero_grad()
            batch = batch.to(device)

            nodes = (batch.node_s, batch.node_v)
            edges = (batch.edge_s, batch.edge_v)
            GT = batch.x
            
            edge_index = batch.edge_index.permute([1,0,2])
            edge_index = edge_index.reshape(2, -1)

            pred,latent = model(nodes, edge_index, edges)

            # make it so if close in latent space it is close in tica space 
            # so minimize norm of (pairwise dist in latent minus pairwise dist in tica) 
            loss = ((GT - pred) ** 2).mean()
            
            
            batch_losses.append(loss.item())

            loss.backward()
            optimizer.step()

            #clear_output(wait=True)
            #print(f'EPOCH {epoch} BATCH {i} TRAIN loss: {loss:.4f}')
        
        path = os.path.join(model_dir, 'epoch-{}.pt'.format(epoch))
        torch.save(model.state_dict(), path)

        
        
        
        
        # save avg loss value over batch for each epoch
        losses[epoch] = np.mean(batch_losses)
        torch.save(losses, os.path.join(model_dir, 'losses.pt'))

        end = time.time()
        t = round((end-start)/60,4)
        print(f'----Epoch {epoch} | TRAIN loss: {losses[epoch, 0]} | Elapsed time: {t} minutes----')
    
    view_output(model_dir, GT, pred, prefix='train_output_')
    plot_loss(model_dir, weight)

    return model

In [9]:
def run_inference(model, model_dir, test_structures):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    #test_structures = torch.load(os.path.join(model_dir,structures))
    test_dataset = LigandDataset(test_structures)
    test_dataloader = torch_geometric.loader.DenseDataLoader(test_dataset, 
                                                             batch_size=256, 
                                                             shuffle=False, 
                                                             num_workers=4)

    model.eval()
    latents, preds, GTs = ([] for i in range(3))
    losses=[]
    for i, batch in enumerate(test_dataloader):
        batch = batch.to(device)
        nodes = (batch.node_s, batch.node_v)
        edges = (batch.edge_s, batch.edge_v)
        GT = batch.x
        edge_index = batch.edge_index.permute([1,0,2])
        edge_index = edge_index.reshape(2, -1)

        pred, latent = model(nodes, edge_index, edges)
        
        loss = ((GT - pred) ** 2).mean()
        losses.append(loss.item())
        

        GTs.extend(GT.cpu().detach().numpy())
        latents.extend(latent.cpu().detach().numpy())
        preds.extend(logits.cpu().detach().numpy())
    print("TEST LOSS: ", np.mean(losses))
    torch.save(GTs, os.path.join(model_dir, 'inference_GTs.pt'))
    torch.save(latents, os.path.join(model_dir, 'inference_latents.pt'))
    torch.save(preds, os.path.join(model_dir, 'inference_preds.pt'))
    view_output(model_dir, GT, pred, prefix='test_output_')
    return latents, preds


In [10]:
def view_output(model_dir, GT, pred, prefix=''):
    gt_coords = np.squeeze(GT.cpu().detach().numpy())
    pred_coords = np.squeeze(pred.cpu().detach().numpy())
    print(pred_coords.shape)
    gt_coords = gt_coords.reshape(-1,3)
    pred_coords = pred_coords.reshape(-1,3)
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(gt_coords[:,0], gt_coords[:,1], gt_coords[:,2])
    ax.scatter(pred_coords[:,0], pred_coords[:,1], pred_coords[:,2])
    fig.savefig(os.path.join(model_dir,prefix+"gnn_pred_and_GT.png"))

    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.scatter(gt_coords[:,0]-pred_coords[:,0], 
            gt_coords[:,1]-pred_coords[:,1],
            gt_coords[:,2]-pred_coords[:,2])
    fig.savefig(os.path.join(model_dir,prefix+"gnn_GT_-_pred.png"))
    return

In [11]:
def plot_loss(model_dir, prefix=''):
    losses = torch.load(os.path.join(model_dir, prefix+'losses.pt'))
    fig, ax = plt.subplots()
    loss = ax.plot(range(0,losses.shape[0]), losses,'-.')
    fig.title('Loss')
    #ax.legend(loss, (f'loss1 + {weight}*loss2', 'loss1', 'loss2'))
    print("COMBINED LOSS: ", losses)
    fig.savefig(os.path.join(model_dir,'loss.png'))

In [12]:
files = [f'data/trajectory-{i}.xtc' for i in range(1,29)]
pdb = 'data/fs-peptide.pdb'
structures = create_structures(files, pdb)

print(len(structures))
#torch.save(structures, 'gnn_model/input_structures.pt')



create structures took 0.7065 minutes
280000


In [18]:
model_dir = 'gnn_model/'
train_struct, test_struct = train_test_split_files(structures)

if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

node_h_dim = (100, 16)
edge_h_dim = (32, 1)
node_in_dim = (6, 3) #node dimensions in input graph, should be (6, 3) if using features from GVP-GNN
edge_in_dim = (32, 1) #edge dimensions in input graph should be (32, 1) if using original features from GVP-GNN
model = GCAE(node_in_dim, node_h_dim, edge_in_dim, edge_h_dim).to(device)

model = train_gnn(model_dir, model, train_struct)

latents, preds = run_inference(model, model_dir, test_struct)


train test split took 0.0021 minutes


Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/Users/lreeder/miniforge3/envs/pytorch/lib/python3.8/multiprocessing/spawn.py", line 116, in spawn_main
    exitcode = _main(fd, parent_sentinel)
  File "/Users/lreeder/miniforge3/envs/pytorch/lib/python3.8/multiprocessing/spawn.py", line 126, in _main
    self = reduction.pickle.load(from_parent)
AttributeError: Can't get attribute 'LigandDataset' on <module '__main__' (built-in)>


KeyboardInterrupt: 