In [1]:
%load_ext memory_profiler

from memory_profiler import profile
import os, random, math, argparse
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_ABDG_v2 import ProteinGCN
# from data import 
from data import ProteinDataset, get_train_val_test_loader,collate_pool
from utils import randomSeed
import config as cfg
import numpy as np
import tqdm
import matplotlib.pyplot as plt
import pickle

In [2]:
def from_torch(x):
    """
    Convert from a torch tensor to numpy array
    """
    return x.detach().cpu().numpy()

In [3]:
N_way = 5 # Number of ways
K_shot = 5
Q_queryperclass = 1 # K_shot is the number of shots

aa_num_neighbors = 1 #How many adjacent amino acids to the target?

ISPROTEINGCN_EMBEDDED = True
# Data loading parameters only
# max_amino=200
# num_writers = 20    # Number of drawings per character
# num_alphabets = 50   # Number of alphabets in Omniglot (explained below)
# num_characters = 1623 # Number of characters in the entire Omniglot

In [4]:
args = {'name': 'abdg_demo',
            'pkl_dir': '/home/ambeck/6.883ProteinDocking/data/firstDB5try/',
            'protein_dir': '/home/ambeck/6.883ProteinDocking/data/firstDB5try/',
            'save_dir': './data/pkl/results/',
            'id_prop': 'protein_id_prop.csv',
            'atom_init': 'protein_atom_init.json',
            'pretrained': './pretrained/pretrained.pth.tar',
            'avg_sample': 500,
            'seed': 1234,
            'epochs': 0,
            'batch_size': 1,
            'train': 0.0,
            'val': 0.0,
            'test': 1.0,
            'testing': False,
            'lr': 0.001, 'h_a': 64, 'h_g': 32,
            'n_conv': 4, 'save_checkpoints': True,
            'print_freq': 10, 
            'workers': 1,
           }
print('Torch Device being used: ', cfg.device)

# create the savepath
savepath = args["save_dir"] + str(args["name"]) + '/'
if not os.path.exists(savepath):
    os.makedirs(savepath)

randomSeed(args["seed"])

# create train/val/test dataset separately
assert os.path.exists(args["protein_dir"]), '{} does not exist!'.format(args["protein_dir"])
dirs_label = [d[:10] for d in os.listdir(args["pkl_dir"]) if not d.startswith('.DS_Store')]
# all_dirs = [d for d in os.listdir(args["protein_dir"]) if not d.startswith('.DS_Store')]
base_dir=set(dirs_label)
dir_r = []
dir_l = []
dir_r.extend(d+'r_u_cleane.pkl' for d in base_dir)
dir_l.extend(d+'l_u_cleane.pkl' for d in base_dir)
all_dirs = []
for r,l in zip(dir_r, dir_l):
    all_dirs.append(r)
    all_dirs.append(l)

dir_len = len(all_dirs)
indices = list(range(dir_len))
random.shuffle(indices)

train_size = math.floor(args["train"] * dir_len)
val_size = math.floor(args["val"] * dir_len)
test_size = math.floor(args["test"] * dir_len)

if val_size == 0:
    print(
        'No protein directory given for validation!! Please recheck the split ratios, ignore if this is intended.')
if test_size == 0:
    print('No protein directory given for testing!! Please recheck the split ratios, ignore if this is intended.')

test_dirs = all_dirs[:test_size]
train_dirs = all_dirs[test_size:test_size + train_size]
val_dirs = all_dirs[test_size + train_size:test_size + train_size + val_size]
print('Testing on {} protein directories:'.format(len(test_dirs)))

@profile
def loadProteinDataSetAndModel():

    dataset = ProteinDataset(args["pkl_dir"], args["id_prop"], args["atom_init"], random_seed=args["seed"])

    print('Dataset length: ', len(dataset))

    # load all model args from pretrained model
    if args["pretrained"] is not None and os.path.isfile(args["pretrained"]):
        print("=> loading model params '{}'".format(args["pretrained"]))
        model_checkpoint = torch.load(args["pretrained"], map_location=lambda storage, loc: storage)
        model_args = argparse.Namespace(**model_checkpoint['args'])
        # override all args value with model_args
        args["h_a"] = model_args.h_a
        args["h_g"] = model_args.h_g
        args["n_conv"] = model_args.n_conv
        args["random_seed"] = model_args.seed
        args["lr"] = model_args.lr

        print("=> loaded model params '{}'".format(args["pretrained"]))
    else:
        print("=> no model params found at '{}'".format(args["pretrained"]))

    args["random_seed"] = args["seed"]
    structures, _, _ = dataset[0]
    
    h_b = structures[1].shape[-1]
    args['h_b'] = h_b  # Dim of the bond embedding initialization

    # Use DataParallel for faster training
    print("Let's use", torch.cuda.device_count(), "GPUs and Data Parallel Model.")
    # print(kwargs)
    model = ProteinGCN(**args)
    return model, dataset


Torch Device being used:  cpu
No protein directory given for validation!! Please recheck the split ratios, ignore if this is intended.
Testing on 464 protein directories:


In [5]:
model, dataset = loadProteinDataSetAndModel()

ERROR: Could not find file <ipython-input-4-ffed259b5ef1>
NOTE: %mprun can only be used on functions defined in physical files, and not in the IPython environment.
Dataset length:  920
=> loading model params './pretrained/pretrained.pth.tar'
=> loaded model params './pretrained/pretrained.pth.tar'
Let's use 0 GPUs and Data Parallel Model.


In [6]:
test_loader = get_train_val_test_loader(dataset, train_dirs, val_dirs, test_dirs,
                                                                      collate_fn    = collate_pool,
                                                                      num_workers   = args["workers"],
                                                                        batch_size    = args["batch_size"],
                                                                      pin_memory    = False,
                                                                      predict=True)

In [7]:
def getInputHACK(inputs):
    return [inputs[0], inputs[1], inputs[2], inputs[4], inputs[5]]

# Relation Network Code

### Set Up Batching

In [8]:
def make_dictionary(names, folder):
    '''names = list of protein_names, folder = 'emb_ligand', etc '''
    dict_out={}
    for n in names:
        filename = os.path.join(fullpath,folder,n +'.pkl')
        if os.path.exists(filename):
            if "emb" in filename:
                amino_emb,_ = torch.load(filename)
            else:
                amino_emb = torch.load(filename)[2].squeeze()
            dict_out[n] = amino_emb
    return dict_out


In [9]:
def loadAdjacencyMatrix(pdb,pdb_dir,chooseBound=True):
    #DB5 specific
    boundchar = 'b'
    if not chooseBound:
        boundchar = 'u'
    #[(500, 64, 'A'), (1116, 141), 'A'],
    adjacencies_full = np.load(os.path.join(pdb_dir, pdb, f'{pdb}_{boundchar}_adjacencies.npy'),allow_pickle=True)

    adjacencies_short = [[a[0][1], a[1][1]] for a in adjacencies_full]
    #Correct for the 1-indexing of the PDB to the 0-indexing of the ProteinGCN
    adjacencies_short = np.array(adjacencies_short)-1
    return adjacencies_short

In [10]:
#loading amino embeddings from files and putting into dictionaries
fullpath='/mnt/disks/amanda200/bounddb5_processed/'
protein_names=torch.load(fullpath + 'names.pkl')
#NOTE: there was an extra addition 'prot' that must be removed here
protein_names.remove('prot')

proteins = [n for n in protein_names]
if ISPROTEINGCN_EMBEDDED:
    ligands = make_dictionary(protein_names, 'emb_ligand')
    receptors = make_dictionary(protein_names, 'emb_receptor')        
else:
    ligands = make_dictionary(protein_names, 'ligand')
    receptors = make_dictionary(protein_names, 'receptor')        

In [11]:
f = torch.load('/mnt/disks/amanda200/bounddb5_processed/ligand/3S9D.pkl')
for i in np.arange(len(f)):
    print(f"{i}:  {f[i].size()}")


0:  torch.Size([1, 1032])
1:  torch.Size([1, 1032, 50, 43])
2:  torch.Size([1, 1032, 50])
3:  torch.Size([1, 1032])
4:  torch.Size([1, 1032])


In [12]:
ligand_sizes = [ligands[n].size() for n in ligands.keys()]
max_amino_l = np.max(ligand_sizes,0)
receptor_sizes = [receptors[n].size() for n in receptors.keys()]
max_amino_r = np.max(receptor_sizes,0)

print(max_amino_l)
print(max_amino_r)

[1500   64]
[2130   64]


In [14]:
labels={}
for protein in protein_names:
    # Load the adjacencey matrix that is a set of pairs of amino acids between ligand and receptors
    adj = torch.load(os.path.join(fullpath,'adjacencies', protein +'.pkl'))
    #Get the indices
    ligand_no = [adj[n][0] for n in np.arange(len(adj))]
    receptor_no = [adj[n][1] for n in np.arange(len(adj))]
    #Make a dense adjacency matrix 
    num_aminos_l = ligands[protein].shape[0]
    num_aminos_r = receptors[protein].shape[0]
    label = torch.zeros((num_aminos_l,num_aminos_r))
    label[ligand_no,receptor_no]=1
    labels[protein]=label

In [15]:
class SingleProteinTaskLoader:
    def __init__(self, protein_ligands,protein_receptors,complex_adjacencies, batch_size):
        """
        Do not concatenate the characters from different proteins together. When sampling a task,
        sample a protein first before sampling the shots within the protein

        proteins: a dictionary mapping from protein names to arrays of size (num_aminos, feat_size),
            where num_aminos is the number of amino acids in the protein
        batch_size: number of tasks to generate at once
        """
        self.batch_size = batch_size
        #Making the strong assumption that ligands and receptors have the same protein names 
        self.protein_names = list(protein_ligands.keys())
        self.protein_ligands = protein_ligands
        self.protein_receptors = protein_receptors
        self.complex_adjacencies = complex_adjacencies
        self.nbd = 1+2*aa_num_neighbors #The neigborhood around an amino acid
        self.h_b = 64 #The size of the hidden layer
    def __iter__(self):
        """
        Define this class as an iterable that yields a batch of tasks at each iteration when used in a for-loop.
    
        return: a batch of tasks
        """
        
        while True:
            batch = torch.as_tensor([])
            y_total = [] #all the labels
            for _ in range(self.batch_size):
                # Randomly sample a protein, then sample N_way classes from this proteins positive and negative
                # matching amino acids 
                protein_name = np.random.choice(self.protein_names)
                protein_ligand = self.protein_ligands[protein_name]
                
                protein_receptor = self.protein_receptors[protein_name]
                adjacencies = self.complex_adjacencies[protein_name]
                
                num_class = 2 #bind or doesn't bind

                #Initialize the set of positive and negative labels
                pos_labels_lig,pos_labels_receptor = np.where(self.complex_adjacencies[protein_name]>0)
                neg_labels_lig,neg_labels_receptor = np.where(self.complex_adjacencies[protein_name]==0)
                
                #Get N_way positive samples, N_way negative samples
                #NOTE: This even split of pos/neg is an important meta parameter
                pos_indices = np.random.choice(np.arange(len(pos_labels_lig)), size=N_way, replace=False)
                #NOTE: The negative samples can be taken from non-surface residues, which could be uninformative
                neg_indices = np.random.choice(np.arange(len(neg_labels_lig)), size=N_way, replace=False)
                
                #Here are the indices to positive/negative amino acid pairs and their labels
                indices_ligand = np.concatenate(( pos_labels_lig[pos_indices],neg_labels_lig[neg_indices]),axis=0);
                indices_receptor = np.concatenate((pos_labels_receptor[pos_indices],neg_labels_receptor[neg_indices]),axis=0);

                classes = np.concatenate((np.ones(N_way),np.zeros(N_way)),axis=0);
                #When making the batch, shuffle so we don't memorize positions
                permuted_indices = np.random.permutation(np.arange(len(classes)))
                                                                         
                supports_queries = torch.zeros((N_way + N_way,2,self.nbd, self.h_b)) 
                y = torch.zeros(N_way + N_way) #used to be K_shot + Q_queryperclass
                for i,p in enumerate(permuted_indices):
                    #(2, max_amino*num_receptors, max_amino*num_ligands)
                    y[i] = classes[p]
                    #Grow the local neighborhood with self.nbd
                    for j,n in enumerate(np.arange(-aa_num_neighbors,aa_num_neighbors+1)):
                        if (indices_ligand[p]+n) in range(0,protein_ligand.shape[0]):
                            supports_queries[i,0,j,:] = protein_ligand[indices_ligand[p]+n,:]
                        else:
                            supports_queries[i,0,j,:] = 0
                        if (indices_receptor[p]+n) in range(0,protein_receptor.shape[0]):    
                            supports_queries[i,1,j,:] = protein_receptor[indices_receptor[p]+n,:]
                        else:
                            supports_queries[i,1,j,:] = 0

                batch = torch.cat([batch,torch.as_tensor(supports_queries).unsqueeze(0)],dim=0)
                y_total.append(np.array(y))
           
            yield np.array(y_total),batch

In [16]:
class FullProteinTaskLoader:
    def __init__(SingleProteinTaskLoader):
        '''
        Inputs should be LSTM proteins so they are all the same length
        '''
        
  
    def __iter__(self):
        """
        return: a batch of tasks, which should be a tensor shaped
                (batch_size, max_amino*no_receptor, max_amino*no_ligands, feat_no*2 )
                labels for each task, which should be a tensor shaped
                (batch_size, max_amino*no_receptor, max_amino*no_ligands) filled with 1 or 0
            
            
            What are the classes? yes/no or pairs? Do we want to input the receptors and ligands separately? 
        """
        while True:
                protein_name = np.random.choice(self.protein_names)
                protein_ligand = self.protein_ligands[protein_name]
                batch = torch.as_tensor([])
                for _ in range(self.batch_size):
                    protein_type = np.random.choice(np.arange(2))
                    if bool(protein_type):
                        new_task = self.protein_ligands[protein_name]
                    batch = torch.cat([batch,torch.as_tensor])

                protein_receptor = self.protein_receptors[protein_name]
                adjacencies = self.complex_adjacencies[protein_name]
                yield (protein_name, protein_ligand) # Shape (batch_size, N_way, K_shot + Q_queryperclass, H, W)

### Networks

In [19]:
class AutoEncoder(nn.Module):
    def __init__(self, feat_size=64, bottleneck=8):
        super().__init__()
        
        ### Your code here ###
        
        self.encode_hidden1 = nn.Sequential(
            nn.Linear(in_features=feat_size,
            out_features= bottleneck^3),
            nn.ReLU(inplace=True)
        )
        self.encode_hidden2 = nn.Sequential(
            nn.Linear(in_features=bottleneck^3,
            out_features= bottleneck^2),
            nn.ReLU(inplace=True)
        )
        self.encode_output = nn.Sequential(
            nn.Linear(in_features=bottleneck^2,
            out_features=bottleneck),
            nn.ReLU(inplace=True)
        )   
        self.decode_hidden = nn.Sequential(
            nn.Linear(in_features=bottleneck,
            out_features=bottleneck^2),
            nn.ReLU(inplace=True)
        )
        self.decode_output1 = nn.Sequential(
            nn.Linear(in_features=bottleneck^2,
            out_features=bottleneck^3),
            nn.ReLU(inplace=True)
        )
        self.decode_output2 = nn.Sequential(
            nn.Linear(in_features=bottleneck^3,
            out_features=feat_size),
            nn.ReLU(inplace=True)
        )
        
    def forward(self, feats,labels):
        """
        Forward pass of the neural network

        """
        window_size, feature_size = feats.shape

        loss, accuracy = self.predictor(feats.view(-1,feature_size))

        return loss, accuracy
        
    def predictor(self, features):
        
        newstate = self.encode_hidden1(features)
        newstate = self.encode_hidden2(newstate)
        newstate = self.encode_output(newstate)
        newstate = self.decode_hidden(newstate)
        newstate = self.decode_output1(newstate)
        relations = self.decode_output2(newstate)
  
        loss = F.mse_loss(relations.view(-1,1).squeeze(), features.view(-1,1).squeeze())

        #accuracy = (torch.round(relations.squeeze()) == labels_.squeeze()).float().mean()
        accuracy = 0 #accuracy is just a place holder for now.
        return torch.as_tensor(loss), torch.as_tensor(accuracy)
        

In [67]:
class DecodeLayer(nn.Module):
    def __init__(self, window_size=100, bottleneck=8):
        super().__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(in_features=window_size*bottleneck,
            out_features=window_size^2),
            nn.ReLU(inplace=True),
            nn.Sigmoid()
            )
    def forward(self, feats,labels):
        """
        Forward pass of the neural network
        """
        #print(f"Feats size: {feats.size()}")
        window_size, feature_size = feats.shape

        loss, accuracy = self.predictor(feats.view(-1,feature_size),labels)
        #convolve across window?
        return loss, accuracy
        
    def predictor(self, features, labels):
        
        print(features.size())
        relations = self.fc1(features)

        loss = F.mse_loss(relations.view(-1,1).squeeze(), features.view(-1,1).squeeze())
        
        #accuracy = (torch.round(relations.squeeze()) == labels_.squeeze()).float().mean()
        accuracy = (round(relations.view(-1,1)) == labels).float().mean()
        return torch.as_tensor(loss), torch.as_tensor(accuracy)
        

In [21]:
def train(net, num_steps, print_prob=False):
    """
    Train the input neural network for num_steps
    
    net: an instance of MatchingNetwork, PrototypicalNetwork, or RelationNetwork
    num_steps: number of batches to train for
    
    return: the trained net, the training accuracies per step, and the validation accuracies per 100 steps
    """
    net = net.to(device)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3,weight_decay=0.1)
    train_accuracies = []
    train_losses = []
    val_accuracies = []
    val_losses=[]
    count_global=0
    val_count_global=0

    for step, data in zip(range(num_steps), train_metadataset):

        opt.zero_grad()
        loss, accuracy = net(torch.as_tensor(data, dtype=torch.float32, device=device),0)

        loss.backward(retain_graph=True,create_graph=True)
        opt.step()
        train_loss, train_accuracy = map(from_torch, (loss, accuracy))
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        if (step + 1) % 100 == 0:
            val_metadataset = val_meta_blocks[val_count_global,:,:,:].squeeze()
            for n in np.arange(val_metadataset.size(0)):
                loss_vec=[]
                temp_loss, val_accuracy = evaluate(net, val_metadataset[n,:,:]) 
                loss_vec.append(temp_loss)
            val_loss = np.mean(loss_vec)
            val_count_global+=1
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss)
            
            print('step=%s   train(loss=%.5g, accuracy=%.5g)  val(loss=%.5g, accuracy=%.5g)' % (
                step + 1, train_loss, train_accuracy, val_loss, val_accuracy
            ))
            if print_prob:
                print(np.concatenate((val_relations, val_labels.reshape(-1,1)),axis=1))
    return net, train_accuracies, train_losses, val_accuracies, val_losses

def evaluate(net, metadataset):
    """
    Evalate the trained net on either the validation or test metadataset
    
    """
    with torch.no_grad(): # Evaluate without gradients

        loss, accuracy = net(torch.as_tensor(metadataset.squeeze(), dtype=torch.float32, device=device), 0)
        loss, accuracy = map(from_torch, (loss, accuracy))

    return loss, accuracy

device="cpu"

# Running the Autoencoder

In [22]:
num_proteins = len(protein_names)
num_train, num_val = int(num_proteins * 0.6), int(num_proteins * 0.2)

# Train, val, test sets
#Note that we've already shuffled the proteins so we can simply divide the indices in order
train_val_test_splits = np.split(list(protein_names), [num_train, num_train + num_val])


sets = [{n: ligands[n] for n in pdbids} for pdbids in train_val_test_splits]
ligands_train, ligands_test, ligands_val = sets

sets = [{n: receptors[n] for n in pdbids} for pdbids in train_val_test_splits]
receptors_train, receptors_test, receptors_val = sets

sets = [{n: labels[n] for n in pdbids} for pdbids in train_val_test_splits]
labels_train, labels_test, labels_val = sets


train_metadataset = SingleProteinTaskLoader(ligands_train,receptors_train, labels_train,batch_size=64)
val_metadataset = SingleProteinTaskLoader(ligands_val,receptors_val, labels_val,batch_size=100)
test_metadataset = SingleProteinTaskLoader(ligands_test,receptors_test, labels_test,batch_size=100)


In [24]:
class WindowedProteinTaskLoader:
    def __init__(self, protein_ligands, protein_receptors,window_length, window_overlap):
        '''
        Inputs should be LSTM proteins so they are all the same length
        '''

        self.window=window_length #input the atom numbers?
        self.overlap=window_overlap
        self.protein_names = list(protein_ligands.keys())
        self.protein_ligands = protein_ligands
        self.protein_receptors = protein_receptors

  
    def create_batch(self):
        """
        return: a batch of tasks, which should be a tensor shaped
                (batch_size, window_size, feature_size) where batch_size = # windows per protein * protein
                the indexes used, dictionary with keys = protein names
                value of dictionary = torch tensor shaped (2*number of windows in protein or ligand, window_length)
        """
        batch = torch.as_tensor([])
        protein_idx = dict.fromkeys(self.protein_names, [])
        ligand_windows = dict.fromkeys(self.protein_names, [])
        receptor_windows = dict.fromkeys(self.protein_names, [])
        protein_count = 0
        for protein in self.protein_names:
            #print(protein)
            ligand = self.protein_ligands[protein]
            ligand_len = ligand.size(0)
            receptor = self.protein_receptors[protein]
            receptor_len = receptor.size(0)
            receptor_window = 0
            ligand_window = 0
            #window from front and back?
            idx_all = torch.as_tensor([])
            while (ligand_window*self.overlap+self.window) <= ligand_len:
                idx_forward = torch.arange(ligand_window*self.overlap, ligand_window*self.overlap + self.window)
                batch = torch.cat([batch, ligand[idx_forward,:].unsqueeze(dim=0)], dim=0)
                idx_all = torch.cat([idx_all, idx_forward.unsqueeze(dim=0)],dim=0)
                idx_back = torch.arange(ligand_len-(self.overlap*ligand_window)-1,ligand_len-(self.overlap*ligand_window)-self.window-1,-1)
#                     print(max(idx_back))
#                     print(min(idx_back))
#                     print(len(idx_back))
                batch = torch.cat([batch, ligand[idx_back,:].unsqueeze(dim=0)], dim=0)
                idx_all = torch.cat([idx_all, idx_back.unsqueeze(dim=0)],dim=0)
                ligand_window +=1
                #print(window_num)

            while (receptor_window*self.overlap+self.window) <= receptor_len:
                idx_forward = torch.arange(receptor_window*self.overlap, receptor_window*self.overlap + self.window)
                batch = torch.cat([batch, receptor[idx_forward,:].unsqueeze(dim=0)], dim=0)
                idx_all = torch.cat([idx_all, idx_forward.unsqueeze(dim=0)],dim=0)
                idx_back = torch.arange(receptor_len-(self.overlap*receptor_window)-1,receptor_len-(self.overlap*receptor_window)-self.window-1,-1)
                batch = torch.cat([batch, receptor[idx_back,:].unsqueeze(dim=0)], dim=0)
                idx_all = torch.cat([idx_all, idx_back.unsqueeze(dim=0)],dim=0)
                receptor_window +=1
                #print(window_num)
            protein_idx[protein] = idx_all
            receptor_windows[protein] = receptor_window
            ligand_windows[protein] = ligand_window
            protein_count+=1
            print(protein_count)


        return batch, protein_idx, ligand_windows, receptor_windows # Shape (batch_size, N_way, K_shot + Q_queryperclass, H, W)

In [26]:
train_windows = WindowedProteinTaskLoader(ligands_train,receptors_train, window_length=100, window_overlap=50)
val_windows = WindowedProteinTaskLoader(ligands_val,receptors_val, window_length=100, window_overlap=50)
#test_windows = WindowedProteinTaskLoader(ligands_test,receptors_test, window_length=100, window_overlap=50)



In [32]:
autonet=AutoEncoder()
train_steps = 100
dump, train_accs, train_loss, val_accs, val_loss = train(autonet, train_steps)
print(len(val_accs))

step=100   train(loss=117.43, accuracy=0)  val(loss=79.923, accuracy=0)
1


In [None]:
fig, (ax1,ax2) = plt.subplots(2,1,sharex=True)
ax1.plot((1 + np.arange(len(train_accs))), train_accs, label="Training Accuracy")
ax1.plot((1 + np.arange(len(val_accs)))*100, val_accs, label="Validation Accuracy")
ax2.set_xlabel('training step')
ax1.set_ylabel('accuracy')
ax1.legend()

ax2.plot((1 + np.arange(len(train_loss))), train_loss, label="Training Loss")
ax2.plot((1 + np.arange(len(val_loss))) * 100, val_loss, label="Validation Loss")
ax2.set_ylabel('validation loss')
ax2.legend()

## Code for Extended Pipeline/Future Directions. (Under Construction)

In [20]:
class CatAutoEncoder(AutoEncoder):
    def __init__(self): #depth is protein_no*protein_no
        """
        Define Linear layers for Autoencoder
        bottleneck = number of nodes at bottleneck
        """
        #Questions:
        #1. How do we deal with variable length proteins? zeropad?
        #2. How do we go from autoencoder output to protein structure?
                #can we do this with proteinGCN? Do we learn an adjacency matrix?
        #3. Convolve or LSTM along protein to take advantage of chain structure?
        
        # The backbone for Relation Network doesn't max pool for the last two convolution blocks
        super().__init__(bottleneck=8)
        ### Your code here ###

        self.decode_cat_hidden = nn.Sequential(
            nn.Linear(in_features=2*bottleneck,
            out_features=(2*bottleneck)^2),
            nn.ReLU(inplace=True)
        )
        self.decode_cat_output = nn.Sequential(
            nn.Linear(in_features=(2*bottleneck)^2,
            out_features=(2*bottleneck)^3),
            nn.ReLU(inplace=True)
        )
 

    def predictor(self, features, labels_):
        """
        receptor_emb = tensor of size (batch_no, task_no, r/l,protein_length?, feature size)
        return: a tuple (loss, accuracy) of two torch.float32 scalars representing the mean loss and
            mean accuracy of the batch
        """
        ### Your code here ###
        rec_state = self.encode_hidden(features[:,:,0,:,:].view(-1, nbd, feat_size))
        rec_state = self.encode_output(rec_state)
        lig_state = self.encode_hidden(features[:,:,1,:,:].view(-1, nbd, feat_size))
        lig_state = self.encode_output(lig_state)
        
        hidden_cat = torch.cat([rec_state,lig_state],dim=2)
        newstate = self.decode_hidden(hidden_cat)
        relations = self.decode_output(newstate)
        
        #fully connected layer on decoded output for labels or adj matrix?
        #do we want to learn embedded bound proteins from embedded unbound proteins?
        #train model on bound to bound then use as initialization for bound to unbound?
        
        loss = F.mse_loss(relations.squeeze(), labels_.squeeze())
        accuracy = (torch.round(relations.squeeze()) == labels_.squeeze()).float().mean()
        
        
        return relations, loss, accuracy

In [69]:
def train_decoder(net, num_steps, print_prob=False):
    """
    Train the input neural network for num_steps
    
    net: an instance of MatchingNetwork, PrototypicalNetwork, or RelationNetwork
    num_steps: number of batches to train for
    
    return: the trained net, the training accuracies per step, and the validation accuracies per 100 steps
    """
    net = net.to(device)
    opt = torch.optim.Adam(net.parameters(), lr=1e-3,weight_decay=0.1)
    train_accuracies = []
    train_losses = []
    val_accuracies = []
    val_losses=[]
    count_global=0
    val_count_global=0
    #print(train_metadataset)
    for step, data, label in zip(range(num_steps), train_metadataset, train_latent_labels):
        #print(step)
        #print(data.size())
        #labels_ = torch.as_tensor(y) #Adding the _ because labels is a globally defined variable
        #batch_size = data.size(0)#Will this print the correct size?
        #print(f"batch size of task is: {data.size()}")
        

        opt.zero_grad()
        loss, accuracy = net(torch.as_tensor(data, dtype=torch.float32, device=device),train_latent_labels)
        #loss, accuracy = net(torch.as_tensor(data),0)
#         if step == 0:
#             loss.backward(retain_graph=True)
#         else:
#             loss.backward()
        #print(loss.size())
        loss.backward(retain_graph=True,create_graph=True)
        opt.step()
        train_loss, train_accuracy = map(from_torch, (loss, accuracy))
        train_accuracies.append(train_accuracy)
        train_losses.append(train_loss)
        if (step + 1) % 100 == 0:
            val_metadataset = val_decoder_blocks[val_count_global,:,:,:].squeeze()
            val_labelset = val_label_blocks[val_count_global,:,:].squeeze()
            for n in np.arange(val_metadataset.size(0)):
                loss_vec=[]
                temp_loss, val_accuracy = evaluate_decoder(net, val_metadataset[n,:,:], val_labelset) #whetre are these defined?
                loss_vec.append(temp_loss)
            val_loss = np.mean(loss_vec)
            val_count_global+=1
            val_accuracies.append(val_accuracy)
            val_losses.append(val_loss)
            
            print('step=%s   train(loss=%.5g, accuracy=%.5g)  val(loss=%.5g, accuracy=%.5g)' % (
                step + 1, train_loss, train_accuracy, val_loss, val_accuracy
            ))
            if print_prob:
                print(np.concatenate((val_relations, val_labels.reshape(-1,1)),axis=1))
    return net, train_accuracies, train_losses, val_accuracies, val_losses

def evaluate_decoder(net, metadataset, metalabels):
    """
    Evalate the trained net on either the validation or test metadataset
    
    net: an instance of MatchingNetwork, PrototypicalNetwork, or RelationNetwork
    metadataset: validation or test metadataset
    
    return: a tuple (loss, accuracy) of Python scalars
    """
    with torch.no_grad(): # Evaluate without gradients

        loss, accuracy = net(torch.as_tensor(metadataset.squeeze(), dtype=torch.float32, device=device), metalabels)
        loss, accuracy = map(from_torch, (loss, accuracy))

    return loss, accuracy

device="cpu"

In [38]:
def MixLatentState(latents, indexes, names, train_lw, train_rw, adj):
    ''' Want to build concatenated ligand and receptor sets. concatenate by windows (all combinations)
    adjacency matrix is receptor x ligand
    '''
    total_windows=0
    protein_idx=0
    batch=torch.as_tensor([])
    labels=torch.as_tensor([])
    for protein in names:
        temp_index=torch.as_tensor(indexes[protein],dtype=torch.long)
        temp_label=torch.as_tensor(adj[protein],dtype=torch.long)
        assert(temp_index.size(0) == (train_lw[protein]+train_rw[protein])*2)
        max_lw=2*train_lw[protein]
        max_rw=2*train_rw[protein]
        labeltemp=[]
        for ii, m in enumerate(np.arange(max_rw)+total_windows+max_lw):
            for jj, n in enumerate(np.arange(max_lw)+total_windows):
                temp=temp_label[:,temp_index[ii+max_lw]]
                temp=temp[temp_index[jj],:]
                labels = torch.cat([labels, temp.view(-1,1)],dim=1)
                latentcat = torch.cat([latents[n,:,:].unsqueeze(dim=0),latents[m,:,:].unsqueeze(dim=0)],dim=1)
                batch = torch.cat([batch, latentcat],dim=0)
                
        total_windows=total_windows+max_lw+max_rw
        print(total_windows)
    return batch, labels
            

### Unpack values

In [None]:
train_data, train_indexes, train_lw, train_rw = train_windows.create_batch()
train_proteins = ligands_train.keys()
val_proteins = ligands_val.keys()
#pickle.dump((train_data, train_indexes, train_proteins, train_lw, train_rw), open('tr_window50_overlap25.pkl','wb'))
val_data, val_indexes, val_lw, val_rw = val_windows.create_batch()
#pickle.dump((val_data, val_indexes, val_proteins, val_lw, val_rw), open('val_window50_overlap25.pkl','wb'))

In [30]:
#val_metadataset.size()
val_temp = val_data[:600,:,:]
val_meta_blocks = val_temp.unsqueeze(dim=0).view(20,-1,100,64)


### Find All Encoded/Latent States for DB5 and batch

In [None]:
#Batching Encoded Training Data
latent_states=torch.as_tensor([])
for n in torch.arange(train_data.size(0)):
    newstate = autonet.encode_hidden1(train_data[n,:,:])
    newstate = autonet.encode_hidden2(newstate)
    latenttemp = autonet.encode_output(newstate)
    
latent_states=torch.cat([latent_states, torch.as_tensor(latenttemp).unsqueeze(dim=0)],dim=0)
batched_latents, latent_labels = MixLatentState(latent_states,train_indexes,train_proteins,train_lw,train_rw,label_out)

In [None]:
#Batching Encoded Validation Data
val_states=torch.as_tensor([])
for n in torch.arange(val_data.size(0)):
    newstate = autonet.encode_hidden1(val_data[n,:,:])
    newstate = autonet.encode_hidden2(newstate)
    latenttemp = autonet.encode_output(newstate)
    
    latent_states=torch.cat([latent_states, torch.as_tensor(latenttemp).unsqueeze(dim=0)],dim=0)
batched_val, val_latent_labels = MixLatentState(latent_states,val_indexes,val_proteins,val_lw,val_rw,label_out)

In [61]:
#Adjust sizes
tempval=batched_val[:1600,:,:]
val_decoder_blocks = tempval.unsqueeze(dim=0).view(100,-1,200,8)
templabels=val_latent_labels.permute(1,0)
print(val_decoder_blocks.size())
templabels=templabels[:1600,:]
val_label_blocks=templabels.unsqueeze(dim=0).view(100,-1,10000)
print(val_label_blocks.size())

torch.Size([100, 16, 200, 8])
torch.Size([100, 16, 10000])


### Make Matching Adjacency Matrices as Labels 

In [47]:
label_out={}
for protein in val_proteins:
    # Load the adjacencey matrix that is a set of pairs of amino acids between ligand and receptors
    adj = torch.load(os.path.join(fullpath,'adjacencies', protein +'.pkl'))
    #Get the indices
    ligand_no = [adj[n][0] for n in np.arange(len(adj))]
    receptor_no = [adj[n][1] for n in np.arange(len(adj))]
    #Make a dense adjacency matrix 
    num_aminos_l = ligands[protein].shape[0]
    num_aminos_r = receptors[protein].shape[0]
    label = torch.zeros((num_aminos_l,num_aminos_r))
    label[ligand_no,receptor_no]=1
    label_out[protein]=torch.as_tensor(label,dtype=torch.long)

### Run AE Pipeline from Encoded proteins to Adjacency Matrix 

In [None]:
train_metadataset= train_data
val_decoder_blocks=val_metadataset

decodenet=DecodeLayer()
train_steps = 100
train_metadatset
dump, train_accs, train_loss, val_accs, val_loss = train_decoder(decodenet, train_steps)
print(len(val_accs))