In [0]:
training_file = "drive/My Drive/assign4data/training_30"
validation_file = "drive/My Drive/assign4data/validation"
testing_file = "drive/My Drive/assign4data/testing"
ng_val = 1

INVALID_ANGLE = 10
import glob
import os.path
import os
import platform
import numpy as np
import h5py
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn.functional import softmax 
import pickle as pkl
import random
import torch.optim as optim
import subprocess
import torch.utils.data


In [3]:
#skip this section if not running on colab
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:



import numpy as np
from Bio.PDB.vectors import Vector, calc_angle, calc_dihedral


def process_tertiary(tertiary):
    '''compute the bond lengths, bond angles, and dihedral angles'''
    phi = []
    psi = []
    omega = []
    bond_angle_CNCa = []
    bond_angle_NCaC = []
    bond_angle_CaCN = []
    bond_len_NCa = []
    bond_len_CaC = []
    bond_len_CN = []
    # convert tertiary coords into Vectors
    pV = [vec for vec in map(lambda v: Vector(v[0], v[1], v[2]),
                             zip(tertiary[0], tertiary[1], tertiary[2]))]

    for i in range(0, len(pV), 3):
        # check for zero coords
        norm_im1 = False
        norm_i = False
        norm_i1 = False
        norm_i2 = False
        norm_i3 = False
        norm_i4 = False
        if i > 0 and pV[i-1].norm() > 0:
            norm_im1 = True
        if pV[i].norm() > 0:
            norm_i = True
        if pV[i+1].norm() > 0:
            norm_i1 = True
        if pV[i+2].norm() > 0:
            norm_i2 = True
        if i + 3 < len(pV) and pV[i+3].norm() > 0:
            norm_i3 = True
        if i + 3 < len(pV) and pV[i+4].norm() > 0:
            norm_i4 = True

        # compute bond lengths
        if norm_im1 and norm_i:
            blen_CN = (pV[i-1]-pV[i]).norm()
            bond_len_CN.append(blen_CN)

        if norm_i and norm_i1:
            blen_NCa = (pV[i]-pV[i+1]).norm()
            bond_len_NCa.append(blen_NCa)

        if norm_i1 and norm_i2:
            blen_CaC = (pV[i+1]-pV[i+2]).norm()
            bond_len_CaC.append(blen_CaC)

        # compute bond angles
        if norm_im1 and norm_i and norm_i1:
            theta_CNCa = calc_angle(pV[i-1], pV[i], pV[i+1])  # C-N-Ca
            bond_angle_CNCa.append(theta_CNCa)

        if norm_i and norm_i1 and norm_i2:
            theta_NCaC = calc_angle(pV[i], pV[i+1], pV[i+2])  # N-Ca-C
            bond_angle_NCaC.append(theta_NCaC)

        if norm_i1 and norm_i2 and norm_i3:
            theta_CaCN = calc_angle(pV[i+1], pV[i+2], pV[i+3])  # Ca-C-N
            bond_angle_CaCN.append(theta_CaCN)

        # compute dihedral angles
        if norm_im1 and norm_i and norm_i1 and norm_i2:
            phi_i = calc_dihedral(
                pV[i-1], pV[i], pV[i+1], pV[i+2])  # N-Ca-C-N
        else:
            phi_i = INVALID_ANGLE
        phi.append(phi_i)

        if norm_i and norm_i1 and norm_i2 and norm_i3:
            psi_i = calc_dihedral(
                pV[i], pV[i+1], pV[i+2], pV[i+3])  # C-N-Ca-C
        else:
            psi_i = INVALID_ANGLE
        psi.append(psi_i)

        if norm_i1 and norm_i2 and norm_i3 and norm_i4:
            omega_i = calc_dihedral(
                pV[i+1], pV[i+2], pV[i+3], pV[i+4])  # Ca-C-N-Ca
        else:
            omega_i = INVALID_ANGLE
        omega.append(omega_i)

    return (phi, psi, omega, bond_angle_NCaC, bond_angle_CaCN,
            bond_angle_CNCa, bond_len_CN, bond_len_NCa, bond_len_CaC)




In [0]:
#functions from util
AA_ID_DICT = {'A': 1, 'C': 2, 'D': 3, 'E': 4, 'F': 5, 'G': 6, 'H': 7, 'I': 8, 'K': 9,
              'L': 10, 'M': 11, 'N': 12, 'P': 13, 'Q': 14, 'R': 15, 'S': 16, 'T': 17,
              'V': 18, 'W': 19, 'Y': 20}
MAX_SEQUENCE_LENGTH = 100
def encode_primary_string(primary):
    return list([AA_ID_DICT[aa] for aa in primary])

def calc_pairwise_distances(chain_a, chain_b, use_gpu):
    distance_matrix = torch.Tensor(chain_a.size()[0], chain_b.size()[0]).type(torch.float)
    # add small epsilon to avoid boundary issues
    epsilon = 10 ** (-4) * torch.ones(chain_a.size(0), chain_b.size(0))
    if use_gpu:
        distance_matrix = distance_matrix.cuda()
        epsilon = epsilon.cuda()

    for idx, row in enumerate(chain_a.split(1)):
        distance_matrix[idx] = torch.sum((row.expand_as(chain_b) - chain_b) ** 2, 1).view(1, -1)

    return torch.sqrt(distance_matrix + epsilon)

def read_protein_from_file(file_pointer):
    """The algorithm Defining Secondary Structure of Proteins (DSSP) uses information on e.g. the
    position of atoms and the hydrogen bonds of the molecule to determine the secondary structure
    (helices, sheets...).
    """
    dict_ = {}
    _dssp_dict = {'L': 0, 'H': 1, 'B': 2, 'E': 3, 'G': 4, 'I': 5, 'T': 6, 'S': 7}
    _mask_dict = {'-': 0, '+': 1}

    while True:
        next_line = file_pointer.readline()
        if next_line == '[ID]\n':
            id_ = file_pointer.readline()[:-1]
            dict_.update({'id': id_})
        elif next_line == '[PRIMARY]\n':
            primary = encode_primary_string(file_pointer.readline()[:-1])
            dict_.update({'primary': primary})
        elif next_line == '[EVOLUTIONARY]\n':
            evolutionary = []
            for _residue in range(21):
                evolutionary.append([float(step) for step in file_pointer.readline().split()])
            dict_.update({'evolutionary': evolutionary})
        elif next_line == '[SECONDARY]\n':
            secondary = list([_dssp_dict[dssp] for dssp in file_pointer.readline()[:-1]])
            dict_.update({'secondary': secondary})
        elif next_line == "[TERTIARY]\n":
            tertiary = []
            # 3 dimension
            for _axis in range(3):
                next_line = file_pointer.readline()
                tertiary.append(
                    [float(coord)/100 for coord in next_line.split()])
            phi, psi, omega,\
                    bond_angle_NCaC, bond_angle_CaCN, bond_angle_CNCa,\
                    bond_len_CN, bond_len_NCa, bond_len_CaC = process_tertiary(
                        tertiary)
            dict_.update({"tertiary": tertiary})
            dict_.update({"phi": phi})
            dict_.update({"psi": psi})

            
        elif next_line == '[MASK]\n':
            mask = list([_mask_dict[aa] for aa in file_pointer.readline()[:-1]])
            dict_.update({'mask': mask})
        elif next_line == '\n':
            return dict_
        elif next_line == '':
            return None


def process_file(input_file, use_gpu):

    input_file_pointer = open(input_file, "r")
    all_proteins = []
    while True:
        # while there's more proteins to process
        next_protein = read_protein_from_file(input_file_pointer)
        if next_protein is None:
            break

        sequence_length = len(next_protein['primary'])

        if sequence_length > MAX_SEQUENCE_LENGTH:
            #print("Dropping protein as length too long:", sequence_length)
            continue

        phi = next_protein['phi']
        psi = next_protein['psi']

        primary_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        tertiary_padded = np.zeros((3, MAX_SEQUENCE_LENGTH))
        evo_padded = np.zeros((21,MAX_SEQUENCE_LENGTH))
        phi_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        psi_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        
        # masking and padding here happens so that the stored dataset is of the same size.
        # when the data is loaded in this padding is removed again.
        primary_padded[:sequence_length] = next_protein['primary']
        phi_padded[:sequence_length] = phi
        psi_padded[:sequence_length] = psi
        phi_mask = torch.Tensor([1 if x!=INVALID_ANGLE else 0 for x in phi_padded])
        psi_mask = torch.Tensor([1 if x!=INVALID_ANGLE else 0 for x in psi_padded])
        mask_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        mask_padded[:sequence_length] = next_protein['mask']
        mask_padded = mask_padded*phi_mask.numpy()*psi_mask.numpy()

        next_protein['tertiary'] = [next_protein['tertiary'][i] for i in range(1, len(next_protein['tertiary']), 3)]
        t_transposed = np.ravel(np.array(next_protein['tertiary']).T)
        t_reshaped = np.reshape(t_transposed, (sequence_length, 3)).T

        e_transposed = np.ravel(np.array(next_protein['evolutionary']).T)
        e_reshaped = np.reshape(e_transposed, (sequence_length,21)).T

        tertiary_padded[:, :sequence_length] = t_reshaped
        evo_padded[:,:sequence_length] = e_reshaped


        mask = torch.Tensor(mask_padded).type(dtype=torch.bool)

        prim = torch.masked_select(torch.Tensor(primary_padded)
                                   .type(dtype=torch.long), mask)
        phi_t = torch.masked_select(torch.Tensor(phi_padded)
                                   .type(dtype=torch.float), mask)
        psi_t = torch.masked_select(torch.Tensor(psi_padded)
                                   .type(dtype=torch.float), mask)

        
        pos = torch.masked_select(torch.Tensor(tertiary_padded), mask)\
                  .view(3, -1).transpose(0, 1).unsqueeze(1)



        
        tertiary = pos.squeeze(1)

        evol = torch.masked_select(torch.Tensor(evo_padded), mask)\
                  .view(21, -1).transpose(0, 1).unsqueeze(1) / 100

        
        evolutionary = evol.squeeze(1)

        primary_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        tertiary_padded = np.zeros((MAX_SEQUENCE_LENGTH, 3))
        evo_padded = np.zeros((MAX_SEQUENCE_LENGTH,21))
        phi_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        psi_padded = np.zeros(MAX_SEQUENCE_LENGTH)

        length_after_mask_removed = len(prim)

        primary_padded[:length_after_mask_removed] = prim.data.cpu().numpy()
        tertiary_padded[:length_after_mask_removed, :] = tertiary.data.cpu().numpy()
        evo_padded[:length_after_mask_removed, :] = evolutionary.data.cpu().numpy()
        phi_padded[:length_after_mask_removed] = phi_t.data.cpu().numpy()
        psi_padded[:length_after_mask_removed] = psi_t.data.cpu().numpy()

        mask_padded = np.zeros(MAX_SEQUENCE_LENGTH)
        mask_padded[:length_after_mask_removed] = np.ones(length_after_mask_removed)
        
        dict_ = {}
        dict_['primary'] = primary_padded
        dict_['tertiary'] = tertiary_padded
        dict_['mask_padded'] = mask_padded
        dict_['evolutionary'] = evo_padded
        dict_['seq_len'] = sequence_length
        dict_['phi'] = phi_padded
        dict_['psi'] = psi_padded
        all_proteins.append(dict_)


    return all_proteins

def sequence_onehot(seq):
    """Maps the given sequence into a one-hot encoded matrix."""
    one_hot = np.zeros((len(seq), 20), dtype=np.int32)

    for aa_index, aa_id in enumerate(seq):
        one_hot[int(aa_index), int(aa_id) - 1] = 1

    return torch.Tensor(one_hot).float()

def seq2block(prim_msa_i, prim_msa_j):
    # takes each prot fi and returns n*n*f block were fij = |fi cat fj|
    # todo: add options to have dim=2 be different options such as |fi cat fj|fi * fj|
    # prim_msa shape should be (prot_len x 41d)
    seq_len = prim_msa_i.shape[0]
    fij = torch.zeros((seq_len, seq_len, prim_msa_i.shape[1] * 2))
    for idx, i in enumerate(prim_msa_i):
        for jdx, j in enumerate(prim_msa_j):
            fij[idx][jdx] = torch.cat((i, j))
    return fij



In [0]:
import torch
import torch.nn as nn
import torch.utils.model_zoo as model_zoo
from torch.nn.functional import softmax 
import pickle as pkl
import numpy as np
import random
import torch.optim as optim
import subprocess
import torch.utils.data
#helper functions
#GLOBAL VARIABLE
INPUT_DIM = 82
OUTPUT_BINS = 65 #number of bins in output
ANGLE_BINS = 36
RESNET_DIM = 128 #number of layers inside of resnet


class BasicBlock(nn.Module):

    def __init__(self, dilation = 1):
        super(BasicBlock, self).__init__()
        norm_layer = nn.BatchNorm2d
        self.project_down = conv1x1(128, 64, stride=1)
        self.project_up   = conv1x1(64, 128, stride=1)
        self.bn64_1 = norm_layer(64)
        self.bn64_2 = norm_layer(64)
        self.bn128 = norm_layer(128)

        #dilations deal now with 64 incoming and 64 outcoming layers
        self.dilation = conv3x3(64, 64, stride=1, dilation = dilation) #when the block is initialized, the only thing that changes is the dilation filter used!
        self.elu = nn.ELU(inplace=True)

    def forward(self, x):
        
        identity = x
    
        #the deepmind basic block goes:
        
        #batchnorm
        out = self.bn128(x)
        
        #elu
        out = self.elu(out)
    
        #project down to 64
        out = self.project_down(out)
        
        #batchnorm
        out = self.bn64_1(out)

        #elu
        out = self.elu(out)   
        
        #cycle through 4 dilations
        out = self.dilation(out)  
        
        #batchnorm
        out = self.bn64_2(out)

        #elu
        out = self.elu(out)
        
        #project up to 128
        out = self.project_up(out)
        
        #identitiy addition 
        out = out + identity

        return out


class ResNet(nn.Module):

    def __init__(self,n_groups):
        super(ResNet, self).__init__()
        self.inplanes = RESNET_DIM
        
        self.conv1 = conv1x1(INPUT_DIM, RESNET_DIM, stride=1)
        self.conv2 = conv1x1(RESNET_DIM, OUTPUT_BINS )
        
        self.phi_i_conv  =  conv1x64(RESNET_DIM,ANGLE_BINS)
        self.phi_j_conv  =  conv1x64(RESNET_DIM,ANGLE_BINS)
        self.psi_i_conv  =  conv1x64(RESNET_DIM,ANGLE_BINS)
        self.psi_j_conv  =  conv1x64(RESNET_DIM,ANGLE_BINS)

        self.resnet_blocks = self._make_layer(n_groups)


    def _make_layer(self, n_groups):
        layers = []
        #here I need to pass in the correct dilations 1,2,4,8
        dilations = [1,2,4,8]
        for i,_ in enumerate(range(0, n_groups)):
            layers.append(BasicBlock( dilation = dilations[i]))

        return nn.Sequential(*layers)

    def forward(self, x):
        #fix input dimensions
        x = self.conv1(x)
  
        #propagate through RESNET blocks
        resnet_out = self.resnet_blocks(x)
        #renet_out has shape 1,128,64,64
        phii_out = self.phi_i_conv(resnet_out)
        phij_out = self.phi_j_conv(resnet_out)
        psii_out = self.psi_i_conv(resnet_out)
        psij_out = self.psi_j_conv(resnet_out)

        #fix output dimensions
        x = self.conv2(resnet_out)
        #FIX THIS TO WORK WITH BATCHES!
        m = nn.Softmax2d()
        return (m(x),m(phii_out),m(phij_out),m(psii_out),m(psij_out))


def is_training():
    pass  # change BATCH_SIZE

def conv3x3(in_planes, out_planes, stride=1, dilation = 1):
    """3x3 convolution with padding"""
    padding = 1 + (dilation -1 ) #derived to ensure consistent size
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=padding, bias=True, dilation = dilation)

def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=True)
def conv1x64(in_planes, out_planes, stride=1, groups=1):
    """64x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=(1,64), stride=stride, groups=groups, bias=True)
def conv64x1(in_planes, out_planes, stride=1, groups=1):
    """64x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=(64,1), stride=stride, groups=groups, bias=True)


In [0]:
from torch.utils.data import IterableDataset
class Sequences(IterableDataset):
  def __init__(self, final_input):
    self.data = final_input # [(prim_msa,tertiary,binnedphi,binnedpsi,seqlen),(),....,()]
    self.bins= np.arange(float(2),float(22), float(20/64))
  def get_tile(self, p, i , j):
    prim_msa_i = self.data[p][0][i:i+64]
    prim_msa_j = self.data[p][0][j:j+64]
    return seq2block(prim_msa_i,prim_msa_j)
  def get_dmat(self, p, i, j):
    prim_msa_i = self.data[p][1][i:i+64]
    prim_msa_j = self.data[p][1][j:j+64]
    return calc_pairwise_distances(prim_msa_i,prim_msa_j,False)
  def bin_dmat(self, dmat):
    return np.digitize(dmat,self.bins)
  def __len__(self):
    return len(self.data)
  def __iter__(self):
    self.p =0
    self.i=0
    self.j=0
    return self
  def __next__(self):
    if(self.p<len(self.data)):
      a = self.get_tile(self.p,self.i,self.j)
    else:
      raise StopIteration
    a = a.permute(2,0,1)
    b_unbinned = self.get_dmat(self.p,self.i,self.j)
    b = self.bin_dmat(b_unbinned)

    phi_i = self.data[self.p][2][self.i:self.i+64].unsqueeze(1)
    phi_j = self.data[self.p][2][self.j:self.j+64].unsqueeze(1)
    psi_i = self.data[self.p][3][self.i:self.i+64].unsqueeze(1)
    psi_j = self.data[self.p][3][self.j:self.j+64].unsqueeze(1)

    ret_i = self.i
    ret_j = self.j
    ret_n = self.data[self.p][4]
    ret_p = self.p
    self.j = self.j+32
    if( self.j > len(self.data[self.p][0])-64):
      self.j = 0
      self.i += 32
    if(self.i > len(self.data[self.p][0])-64 ):
      self.i =0
      self.j =0
      self.p +=1
    if(self.p>=len(self.data)):
        raise StopIteration
    return (a,b,(phi_i,phi_j,psi_i,psi_j),(ret_i,ret_j, ret_n, ret_p))



In [0]:
train = process_file(training_file,use_gpu=False)

valid = process_file(validation_file,use_gpu=False)

test = process_file(testing_file,use_gpu=False)

In [0]:
# one hotting the primary and concatenating with MSA
for protein in train:
  one_hot = sequence_onehot(protein['primary'])
  evo_tensor = torch.Tensor(protein['evolutionary']).float()
  protein['prim_msa_cat'] = torch.cat((one_hot,evo_tensor),1)

# one hotting the primary and concatenating with MSA -- validation data
for protein in valid:
  one_hot = sequence_onehot(protein['primary'])
  evo_tensor = torch.Tensor(protein['evolutionary']).float()
  protein['prim_msa_cat'] = torch.cat((one_hot,evo_tensor),1)

# one hotting the primary and concatenating with MSA -- validation data
for protein in test:
  one_hot = sequence_onehot(protein['primary'])
  evo_tensor = torch.Tensor(protein['evolutionary']).float()
  protein['prim_msa_cat'] = torch.cat((one_hot,evo_tensor),1)

In [0]:
#bin phi/psi values
import math
for protein in train:
  bins= np.arange(float(-math.pi),float(math.pi), float((2*math.pi)/35))
  protein['phi_binned'] = np.digitize(protein['phi'],bins)
  protein['psi_binned'] = np.digitize(protein['psi'],bins)

for protein in valid:
  bins= np.arange(float(-math.pi),float(math.pi), float((2*math.pi)/35))
  protein['phi_binned'] = np.digitize(protein['phi'],bins)
  protein['psi_binned'] = np.digitize(protein['psi'],bins)

for protein in test:
  bins= np.arange(float(-math.pi),float(math.pi), float((2*math.pi)/35))
  protein['phi_binned'] = np.digitize(protein['phi'],bins)
  protein['psi_binned'] = np.digitize(protein['psi'],bins)




In [0]:
#staging input features, labels and seq len for Dataloader
final_input = [] #[(primmsa,tertiary,phibin,psibin,seqlen)]
for protein in train:
  final_input.append((protein['prim_msa_cat'],torch.Tensor(protein['tertiary']).float(),
                      torch.Tensor(protein['phi_binned']).long(),
                      torch.Tensor(protein['psi_binned']).long(),
                      protein['seq_len']))

valid_input = []
for protein in valid:
  valid_input.append((protein['prim_msa_cat'],torch.Tensor(protein['tertiary']).float(),
                      torch.Tensor(protein['phi_binned']).long(),
                      torch.Tensor(protein['psi_binned']).long(),
                      protein['seq_len']))

test_input = []
for protein in test:
  test_input.append((protein['prim_msa_cat'],torch.Tensor(protein['tertiary']).float(),
                     torch.Tensor(protein['phi_binned']).long(),
                     torch.Tensor(protein['psi_binned']).long(),
                     protein['seq_len']))

In [15]:
print(len(final_input))
print(len(valid_input))
print(len(test_input))
print(final_input[0][2].shape)
print(final_input[0][3].shape)

11
11
11
torch.Size([100])
torch.Size([100])


In [0]:
train_data = Sequences(final_input)
dataloader = torch.utils.data.DataLoader(train_data, batch_size =1, num_workers = 0)

valid_data = Sequences(valid_input)
valid_dataloader = torch.utils.data.DataLoader(valid_data, batch_size =4, num_workers = 1)

test_data = Sequences(test_input)
test_dataloader = torch.utils.data.DataLoader(test_data, batch_size =1, num_workers = 1)


In [0]:
def train_model(model, n_epochs):
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    
    
    for e in range(1,n_epochs+1):
        model.train()
        ###################
        # train the model #
        ###################
        #(a,b,(phi_i,phi_j,psi_i,psi_j),(ret_i,ret_j, ret_n, ret_p))
        for(batch,(data,target,phipsi,(_,_,_,_))) in enumerate(dataloader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output[0], target)
            loss += criterion(output[1],phipsi[0]) #phi i
            loss += criterion(output[2],phipsi[1]) #phi j
            loss += criterion(output[3],phipsi[2]) #psi i
            loss += criterion(output[4],phipsi[3]) #psi j
            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())
        ######################    
        # validate the model #
        ######################
        model.eval() # prep model for evaluation
  
        for(batch,(data, target,phipsi,(_,_,_,_))) in enumerate(valid_dataloader):
            print(batch)
            # forward pass: compute predicted outputs by passing inputs to the model
            output = model(data)
            # calculate the loss
            loss = criterion(output[0], target)
            loss += criterion(output[1],phipsi[0]) #phi i
            loss += criterion(output[2],phipsi[1]) #phi j
            loss += criterion(output[3],phipsi[2]) #psi i
            loss += criterion(output[4],phipsi[3]) #psi j
            # record validation loss
            valid_losses.append(loss.item())
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)
        epoch_len = len(str(n_epochs))
        
        print_msg = (f'[{e:>{epoch_len}}/{n_epochs:>{epoch_len}}] ' +
                      f'train_loss: {train_loss:.5f} ' +
                      f'valid_loss: {valid_loss:.5f}')

        print(print_msg)
          
        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
    return  model, avg_train_losses, avg_valid_losses

In [0]:
model = ResNet(ng_val)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters())



In [41]:
model, train_losses, valid_losses = train_model(model,1)

0
1
2
3
4
5
6
7
8
9
10
[1/1] train_loss: 17.21141 valid_loss: 16.50090


In [0]:
import math
def Average(lst): 
    return sum(lst) / len(lst) 
def test_model(model):
  test_dataloader = torch.utils.data.DataLoader(test_data, batch_size =1, num_workers = 1)
  all_preds = []
  p_len = {}
  for (batch,(data,target,phipsi,(i_off,j_off,seq_len,p_num))) in enumerate(test_dataloader):
    p_len[p_num.item()] = seq_len.item()
  for i in range(0,len(test_data)):
    temp = [[1 for j in range(MAX_SEQUENCE_LENGTH)] for i in range(MAX_SEQUENCE_LENGTH)]
    for j in range(0,MAX_SEQUENCE_LENGTH):
      for k in range(0,MAX_SEQUENCE_LENGTH):
        temp[j][k] = (-1,0,0) # (true label , summed probability, contact or not)
    all_preds.append(temp)
  for (batch,(data,target,phipsi,(i_off,j_off,seq_len,p_num))) in enumerate(test_dataloader):
    output = model(data)
    output = output[0]
    output = output.squeeze(0)
    target = target.squeeze(0)
    summed = torch.sum(output[0:20],dim=0)  
    for i in range(0,64):
      for j in range(0,64):
        if(all_preds[p_num][i+i_off][j+j_off][0] == -1):
          if(target[i][j]<20):
            x=1
          else:
            x=0
          y = summed[i][j].item()
          if(all_preds[p_num][i+i_off][j+j_off][1]>=.5):
            z = 1
          else:
            z = 0
          all_preds[p_num][i+i_off][j+j_off] = (x,y,z)

  accuracy = []
  accuracy2 = []
  accuracy5 = []
  for i,preds in enumerate(all_preds):
    num_correct = 0
    total = 0
    for j in range(0,min(95,p_len[i])):
      for k in range(0,min(95,p_len[i])):
        total+=1
        if(preds[j][k][0]==preds[j][k][2]):
          num_correct+=1
    accuracy.append(num_correct/total)
    flat = [j for sub in preds[0:p_len[i]][0:p_len[i]] for j in sub] 
    flat.sort(key=lambda x:x[1])
    flat_2 = flat[0:math.floor(len(flat)/2)]
    flat_5 = flat[0:math.floor(len(flat)/5)]
    num_correct=0
    total=0
    for j in range(0,len(flat_2)):
      total+=1
      if(flat_2[j][0]==flat_2[j][2]):
        num_correct+=1
    accuracy2.append(num_correct/total)
    num_correct=0
    total=0
    for j in range(0,len(flat_5)):
      total+=1
      if(flat_5[j][0]==flat_5[j][2]):
        num_correct+=1
    accuracy5.append(num_correct/total)
  print("N-Accuracy: ",Average(accuracy))
  print("N/2-Accuracy: ",Average(accuracy2))
  print("N/5-Accuracy: ",Average(accuracy5))


Initial testing is for tertiary label prediction

In [45]:
test_model(model)

N-Accuracy:  0.6110701547315175
N/2-Accuracy:  0.5905577225203582
N/5-Accuracy:  0.502278763983636
