# Protein secondary structure prediction using Target Embedding Autoencoders

## Load data and packages

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from torch.nn.functional import relu, elu, relu6, sigmoid, tanh, softmax
from torch.autograd import Variable
import torch.nn as nn
import sklearn.model_selection as model_selection
import pickle
from typing import *
import matplotlib
import matplotlib.pyplot as plt
from IPython.display import Image, display, clear_output
import numpy as np

import seaborn as sns
import pandas as pd
sns.set_style("whitegrid")
# `make_vae_plots` lives in a local `plotting.py`, which may be missing (e.g. on Colab).
try:
    from plotting import make_vae_plots
except Exception as ex:
    # Report the failure with upload instructions instead of stopping here;
    # `make_vae_plots` stays undefined when the import fails.
    print(f"If using Colab, you may need to upload `plotting.py`. \
          \nIn the left pannel, click `Files > upload to session storage` and select the file `plotting.py` from your computer \
          \n---------------------------------------------")
    print(ex)

In [None]:
# Loading data
# Insert your path to dataset (drive_path)
drive_path = 'C:/Users/augus/OneDrive - Danmarks Tekniske Universitet/Studiemappe/9. semester/Deep Learning/project/'


def _load_pickle(filename):
    """Unpickle `drive_path + filename`, closing the file handle afterwards."""
    # NOTE(review): unpickling can execute arbitrary code; only load trusted files.
    with open(drive_path + filename, 'rb') as handle:
        return pickle.load(handle)


# For every dataset, element 0 is sliced along its last axis: columns 0-19
# become the inputs and columns 57-64 (eight columns) become the targets.
traindata = _load_pickle('train_reduced.pickle')
X = traindata[0][:,:,0:20]
y = traindata[0][:,:,57:65]
CASP12data = _load_pickle('Casp12Data.pickle')
X_casp = CASP12data[0][:,:,0:20]
y_casp = CASP12data[0][:,:,57:65]
TS115data = _load_pickle('TS115.pickle')
X_TS115 = TS115data[0][:,:,0:20]
y_TS115 = TS115data[0][:,:,57:65]
CB513data = _load_pickle('CB513.pickle')
X_CB513 = CB513data[0][:,:,0:20]
y_CB513 = CB513data[0][:,:,57:65]
labels = traindata[1].tolist()

In [3]:
# Positions whose eight class entries are all zero are treated as padding.
# An extra ninth class is appended to flag them, so that padding is not
# mistaken for the first class.
def _append_padding_class(targets):
    """Return `targets` with one extra trailing column that flags all-zero rows.

    `targets` is indexed as [sequence, position, class]. A position whose
    class entries are all zero gets a 1 in the new column, every other
    position gets a 0.
    """
    is_padding = np.all(targets == 0, axis=2, keepdims=True)
    # float64 flags give the same dtype promotion as stacking with np.zeros.
    return np.concatenate((targets, is_padding.astype(np.float64)), axis=2)


# Guard on the class count so that re-running the cell does not append a
# second padding column.
if y.shape[2] == 8:
    y = _append_padding_class(y)
    y_casp = _append_padding_class(y_casp)
    y_TS115 = _append_padding_class(y_TS115)
    y_CB513 = _append_padding_class(y_CB513)

NameError: name 'y' is not defined

In [None]:
# Converting arrays to tensors for PyTorch
def _float_tensor(array):
    """Copy `array` into a tensor of dtype torch.float."""
    return torch.tensor(array, dtype=torch.float)


def _target_tensor(array):
    """Float tensor of `array` with its last two axes swapped."""
    return _float_tensor(array).permute(0, 2, 1)


# Inputs keep their axis order; targets get their last two axes swapped.
X_train, y_train = _float_tensor(X), _target_tensor(y)
X_casp, y_casp = _float_tensor(X_casp), _target_tensor(y_casp)
X_TS115, y_TS115 = _float_tensor(X_TS115), _target_tensor(y_TS115)
X_CB513, y_CB513 = _float_tensor(X_CB513), _target_tensor(y_CB513)
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler
from torchvision.transforms import ToTensor
from functools import reduce

# Both the training and the evaluation batch size are 65 sequences
batch_size = eval_batch_size = 65
# Target-only loaders for the VAE stage; the CASP12 loader uses batches of three
train_loader = DataLoader(dataset=y_train, batch_size=batch_size)
casp_loader = DataLoader(dataset=y_casp, batch_size=3)
TS115_loader = DataLoader(dataset=y_TS115, batch_size=batch_size)
CB513_loader = DataLoader(dataset=y_CB513, batch_size=batch_size)

## Stage 1 training: Variational Autoencoder (VAE)

### Implementing the reparameterization trick

In [None]:
import math 
import torch
from torch import nn, Tensor
from torch.nn.functional import softplus
from torch.distributions import Distribution

class ReparameterizedDiagonalGaussian(Distribution):
    r"""
    A distribution `N(y | mu, sigma I)` compatible with the reparameterization trick given `epsilon ~ N(0, 1)`.

    `mu` and `log_sigma` must have the same shape; the standard deviation is
    stored as `log_sigma.exp()`.
    """
    # No constraints are declared for the constructor arguments. An empty
    # mapping keeps the inherited `repr` usable.
    arg_constraints = {}
    has_rsample = True

    def __init__(self, mu: Tensor, log_sigma:Tensor):
        assert mu.shape == log_sigma.shape, f"Tensors `mu` : {mu.shape} and ` log_sigma` : {log_sigma.shape} must be of the same shape"
        # Initialise the base-class bookkeeping (`batch_shape`, `event_shape`).
        # Argument validation is switched off because no constraints exist.
        super().__init__(batch_shape=mu.shape, validate_args=False)
        self.mu = mu
        self.sigma = log_sigma.exp()

    def sample_epsilon(self) -> Tensor:
        r"""`\eps ~ N(0, I)`, drawn with the shape, dtype and device of `mu`"""
        return torch.empty_like(self.mu).normal_()

    def sample(self) -> Tensor:
        """sample `z ~ N(z | mu, sigma)` (without gradients)"""
        with torch.no_grad():
            return self.rsample()

    def rsample(self) -> Tensor:
        """sample `z ~ N(z | mu, sigma)` (with the reparameterization trick) """
        # z = mu + sigma * eps keeps z differentiable w.r.t. mu and sigma.
        return self.mu + self.sigma * self.sample_epsilon()

    def log_prob(self, z:Tensor) -> Tensor:
        """return the element-wise log probability: log `p(z)`"""
        # Log-density of a univariate Gaussian, evaluated per element.
        logprob = -0.5 * math.log(2 * math.pi) - self.sigma.log() - 0.5 * ((z - self.mu)/self.sigma)**2
        return logprob

## Defining the VAE

In [None]:
from numbers import Number
import math
import torch
from torch.distributions import constraints
from torch.distributions.uniform import Uniform
from torch.distributions.transformed_distribution import TransformedDistribution
from torch.distributions.transforms import AffineTransform, ExpTransform
from torch.distributions.utils import broadcast_all
from torch.distributions.bernoulli import Bernoulli

# Stage-1 (VAE) convolution settings: the channel count is the size of
# axis 1 of the target tensor.
channels = y_train.size(1)
kernel_size_conv1, stride_conv1, padding_conv1 = 9, 1, 4

In [None]:
class VariationalAutoencoder(nn.Module):
    r"""Convolutional variational autoencoder with a diagonal-Gaussian latent code.

    The encoder is one `Conv1d` followed by a softmax over channels; its output
    is split into the posterior mean and log standard deviation. The decoder is
    one `Conv1d` mapping latent codes back to one value per channel and position.

    Kernel size, stride and padding are read from the module-level
    `kernel_size_conv1`, `stride_conv1` and `padding_conv1` when the model is
    constructed.
    """

    def __init__(self, input_shape:torch.Size, latent_features:int) -> None:
        """Build the encoder, the decoder and the prior parameters.

        Args:
            input_shape: number of channels of the data. An int is used
                directly; for a shape the first entry is taken
                (assumed to be the channel count - TODO confirm).
            latent_features: size of the latent code per position.
        """
        super(VariationalAutoencoder, self).__init__()

        self.input_shape = input_shape
        self.latent_features = latent_features
        # Kept under its original name for code that reads this attribute.
        self.latent_features_cat = latent_features
        self.observation_features =  self.input_shape

        # Size the layers from the constructor arguments, so that later
        # re-assignments of module-level names cannot change the architecture.
        n_channels = input_shape if isinstance(input_shape, int) else input_shape[0]

        # Inference Network
        self.encoder = nn.Sequential(
            nn.Conv1d(in_channels=n_channels, out_channels=2*latent_features, kernel_size=kernel_size_conv1, stride=stride_conv1, padding=padding_conv1,bias=True),
            nn.Softmax(dim = 1)
        )
        # Generative Model
        self.decoder = nn.Sequential(
            nn.Conv1d(in_channels=latent_features, out_channels=n_channels, kernel_size=kernel_size_conv1, stride=stride_conv1, padding=padding_conv1,bias=True)
        )

        # define the parameters of the prior, chosen as p(z) = N(0, I)
        self.register_buffer('prior_params', torch.zeros(torch.Size([1, 2*latent_features])))

    def posterior(self, x:Tensor) -> Distribution:
        r"""return the distribution `q(z|x) = N(z | \mu(x), \sigma(x))`

        `x` is indexed as (batch, channel, position). The returned distribution
        has one row per (position, batch element) pair, with the position index
        outermost.
        """
        # compute the parameters of the posterior
        h_x = self.encoder(x)
        # (batch, 2*latent, position) -> (position*batch, 2*latent)
        h_x = h_x.permute(1,2,0).flatten(start_dim=1).permute(1,0)
        mu, log_sigma =  h_x.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def prior(self, batch_size:int=1)-> Distribution:
        """return the distribution `p(z)` expanded to `batch_size` rows"""
        prior_params = self.prior_params.expand(batch_size, *self.prior_params.shape[-1:])
        mu, log_sigma = prior_params.chunk(2, dim=-1)
        return ReparameterizedDiagonalGaussian(mu, log_sigma)

    def observation_model(self, z:Tensor) -> Tuple[Distribution, Tensor]:
        """return the distribution `p(x|z)` together with the raw decoder output

        `z` is indexed as (batch, position, latent feature). The Bernoulli has
        one row per (batch element, position) pair with the batch index
        outermost; the raw decoder output is (batch, channel, position).
        """
        px_logits1 = self.decoder(z.permute(0,2,1))
        # NOTE(review): the sigmoid output is passed to Bernoulli as *logits*,
        # so the sigmoid is effectively applied twice. Left unchanged here;
        # confirm whether `Bernoulli(logits=<raw decoder output>)` was intended.
        px_logits = torch.sigmoid(px_logits1.permute(0,2,1))
        px_logits = px_logits.reshape(px_logits.shape[0]*px_logits.shape[1],-1)
        return Bernoulli(logits=px_logits), px_logits1

    def forward(self, x) -> Dict[str, Any]:
        """compute the posterior q(z|x) (encoder), sample z~q(z|x) and return the distribution p(x|z) (decoder)

        `x` is indexed as (batch, position, channel). All returned
        distributions and `z` have one row per (position, batch element) pair
        with the position index outermost.
        """
        # define the posterior q(z|x) / encode x into q(z|x)
        x = x.permute(0,2,1)
        batch, length = x.size(0), x.size(2)
        qz = self.posterior(x)

        # define the prior p(z)
        pz = self.prior(batch_size=batch*length)

        # sample the posterior using the reparameterization trick: z ~ q(z | x)
        z = qz.rsample()

        # The rows of z are ordered position-major. Regroup them into
        # (batch, position, latent) so the decoder convolves along the
        # positions of one sequence rather than across batch elements.
        z_sequences = z.reshape(length, batch, -1).permute(1,0,2)

        # define the observation model p(x|z) = B(x | g(z))
        px_batch_major = self.observation_model(z_sequences)[0]
        # Put the rows back into position-major order so that px lines up
        # row by row with z, qz and pz.
        px_logits = px_batch_major.logits.reshape(batch, length, -1).permute(1,0,2)
        px = Bernoulli(logits=px_logits.reshape(length*batch, -1))

        return {'px': px, 'pz': pz, 'qz': qz, 'z': z}

    def sample_from_prior(self, batch_size:int=100):
        """sample z~p(z) and return p(x|z)"""
        # define the prior p(z)
        pz = self.prior(batch_size=batch_size)

        # sample the prior
        z = pz.rsample()

        # define the observation model p(x|z) = B(x | g(z)); the decoder needs a
        # position axis, so every sample is decoded as a sequence of length one
        px = self.observation_model(z.unsqueeze(1))[0]

        return {'px': px, 'pz': pz, 'z': z}


## Creating variational inference module

In [None]:
def reduce(x:Tensor) -> Tensor:
    """Sum every dimension of `x` except the first, giving one value per datapoint."""
    per_datapoint = x.view(x.shape[0], -1)
    return torch.sum(per_datapoint, dim=1)

class VariationalInference(nn.Module):
    r"""Evaluates the beta-weighted ELBO of a model on a batch.

    `L^\beta = E_q [ log p(x|z) ] - \beta * D_KL(q(z|x) | p(z))`, where the KL
    term is estimated as `log q(z|x) - log p(z)` on the sampled `z`.
    """

    def __init__(self, beta:float=1.):
        super().__init__()
        self.beta = beta

    def forward(self, model:nn.Module, x:Tensor) -> Tuple[Tensor, Dict]:
        """Run `model` on `x` and return `(loss, diagnostics, outputs)`.

        `loss` is the negated mean of the beta-weighted ELBO, `diagnostics`
        holds the `elbo` (beta-weighted), `log_px` and `kl` tensors, and
        `outputs` is the dictionary returned by the model.
        """
        dim0, dim1 = x.size(0), x.size(1)

        # forward pass, then pick the distributions and the latent sample
        outputs = model(x)
        px = outputs["px"]
        pz = outputs["pz"]
        qz = outputs["qz"]
        z = outputs["z"]

        # flatten x to one row per (axis-1 index, axis-0 index) pair, axis 1 outermost
        flat_x = x.permute(2,1,0).flatten(start_dim=1).permute(1,0)

        # per-row log probabilities
        log_px = reduce(px.log_prob(flat_x)).view(dim0, dim1)
        log_pz = reduce(pz.log_prob(z))
        log_qz = reduce(qz.log_prob(z))

        # single-sample KL estimate, reshaped like log_px
        kl = (log_qz - log_pz).view(dim0, dim1)

        # ELBO without (beta = 1) and with the beta weight; only the latter is used
        elbo = log_px - kl
        beta_elbo = log_px - self.beta * kl

        # the quantity that is minimised
        loss = -beta_elbo.mean()

        # collect the values reported to the caller
        with torch.no_grad():
            diagnostics = dict(elbo=beta_elbo, log_px=log_px, kl=kl)

        return loss, diagnostics, outputs

In [None]:
from collections import defaultdict
# define the models, evaluator and optimizer

# The VAE uses three latent features per position; its input size is the
# size of axis 1 of the target tensor.
latent_features_cat = 3
vae = VariationalAutoencoder(input_shape=y_train.shape[1], latent_features=latent_features_cat)

## Training the VAE

In [None]:
# Evaluator: variational inference with the KL term weighted by beta
beta = 0.04
vi = VariationalInference(beta)

# Adam optimizer over all VAE parameters
optimizer_vae = torch.optim.Adam(
    params=vae.parameters(),
    lr=1e-3,
    betas=(0.85, 0.95),
    weight_decay=1e-6,
)

# one history dictionary per dataset for the training curves
training_data, casp_data, TS115_data, CB513_data = (defaultdict(list) for _ in range(4))

# epoch counter and the number of epochs to train for
epoch, num_epochs = 0, 100

In [None]:
# training..
def _record_first_batch(loader, history):
    """Evaluate the VAE on the first batch of `loader` and log the results.

    The mean of every diagnostic is appended to `history`. Returns the batch
    together with the loss, diagnostics and model outputs.
    """
    batch = next(iter(loader))
    # perform a forward pass through the model and compute the ELBO
    loss, diagnostics, outputs = vi(vae, batch.permute(0,2,1))
    for k, v in diagnostics.items():
        history[k] += [v.mean().item()]
    return batch, loss, diagnostics, outputs


# `epoch` lives outside the loop, so re-running the cell continues counting
# from the last finished epoch.
while epoch < num_epochs:
    epoch+= 1
    print("Epoch ",epoch," of ",num_epochs)
    training_epoch_data = defaultdict(list)
    vae.train()

    # Go through each batch of targets in the training dataset using the loader
    for x in train_loader:
        # perform a forward pass through the model and compute the ELBO
        loss, diagnostics, outputs = vi(vae, x.permute(0,2,1))

        optimizer_vae.zero_grad()
        loss.backward()
        optimizer_vae.step()

        # gather data for the current batch
        for k, v in diagnostics.items():
            training_epoch_data[k] += [v.mean().item()]

    # gather data for the full epoch
    for k, v in training_epoch_data.items():
        training_data[k] += [np.mean(v)]

    # Evaluate on a single batch per test set, do not propagate gradients
    with torch.no_grad():
        vae.eval()
        x_casp, loss_casp, diagnostics_casp, outputs_casp = _record_first_batch(casp_loader, casp_data)
        x_TS115, loss_TS115, diagnostics_TS115, outputs_TS115 = _record_first_batch(TS115_loader, TS115_data)
        # All tensors stay on the default device; there is no device transfer.
        x_CB513, loss_CB513, diagnostics_CB513, outputs_CB513 = _record_first_batch(CB513_loader, CB513_data)

    # Plot the training curves and latent samples; `outputs` comes from the
    # last training batch of the epoch
    make_vae_plots(vae, X, y, outputs, training_data, casp_data,TS115_data,CB513_data)
plt.close()

# Stage 2: CNN

In [None]:
# One input channel for each of the 20 amino acids
channels = 20
# (kernel size, padding, stride) of the first convolution
kernel_size_conv1, padding_conv1, stride_conv1 = 15, 7, 1
# ... of the second convolution
kernel_size_conv2, padding_conv2, stride_conv2 = 9, 4, 1
# ... of the third convolution
kernel_size_conv3, padding_conv3, stride_conv3 = 5, 2, 1

## CNN architecture

In [None]:
# Three convolutional stages (convolution, ReLU, batch norm, dropout) followed
# by a 1x1 convolution with sigmoid and batch norm that outputs the latent features
class Net(nn.Module):
    """CNN that maps an input with `channels` channels to `latent_features_cat` channels."""

    def __init__(self):
        super().__init__()

        def conv_stage(n_in, n_out, kernel, stride, pad):
            # convolution -> ReLU -> batch norm -> dropout with p = 0.5
            return nn.Sequential(
                nn.Conv1d(in_channels=n_in, out_channels=n_out, kernel_size=kernel, stride=stride, padding=pad),
                nn.ReLU(),
                nn.BatchNorm1d(n_out),
                nn.Dropout(p=0.5))

        self.conv1 = conv_stage(channels, 25, kernel_size_conv1, stride_conv1, padding_conv1)
        self.conv2 = conv_stage(25, 35, kernel_size_conv2, stride_conv2, padding_conv2)
        self.conv3 = conv_stage(35, 40, kernel_size_conv3, stride_conv3, padding_conv3)

        # kernel size 1: mixes the 40 channels separately at every position
        self.fc1_encode1 = nn.Sequential(
            nn.Conv1d(in_channels=40, out_channels=latent_features_cat, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
            nn.BatchNorm1d(latent_features_cat)
        )

    def forward(self, x):
        """Apply the four stages in order and return the result."""
        for stage in (self.conv1, self.conv2, self.conv3, self.fc1_encode1):
            x = stage(x)
        return x
net = Net()

## Training the CNN

In [None]:
# Loss and optimizer for the CNN
import torch.optim as optim
# The CNN predicts continuous latent values, so mean squared error is the loss.
criterion1 = nn.MSELoss()
# Adam optimizer over the CNN parameters
optimizer_net = optim.Adam(
    params=net.parameters(),
    lr=1e-3,
    betas=(0.85, 0.95),
    weight_decay=1e-6,
)

In [None]:
class MyDataset(Dataset):
    """Dataset that pairs `X[i]` with `y[i]`; its length is `len(X)`."""

    def __init__(self, X, y):
        self.data, self.targets = X, y

    def __getitem__(self, index):
        """Return the `(input, target)` pair stored at `index`."""
        return self.data[index], self.targets[index]

    def __len__(self):
        """Return the number of inputs."""
        return len(self.data)
# Loaders yielding (input, target) pairs; the CASP12 loader uses batches of three
batch_size = 65
TrainLoader = DataLoader(dataset=MyDataset(X=X_train, y=y_train), batch_size=batch_size)
CASPLoader = DataLoader(dataset=MyDataset(X=X_casp, y=y_casp), batch_size=3)
TS115Loader = DataLoader(dataset=MyDataset(X=X_TS115, y=y_TS115), batch_size=batch_size)
CB513Loader = DataLoader(dataset=MyDataset(X=X_CB513, y=y_CB513), batch_size=batch_size)

In [None]:
# Histories of the losses and accuracies, one pair of lists per dataset
train_loss, train_accuracy = [], []
casp_loss, casp_accuracy = [], []
TS115_loss, TS115_accuracy = [], []
CB513_loss, CB513_accuracy = [], []
### Number of epochs for the CNN training
num_epoch = 100

In [None]:
# The VAE encoder turns the targets into the latent codes that the CNN learns
# to predict. The VAE must not train in this stage: eval() only switches its
# layers to evaluation mode, so the latent codes are additionally sampled
# under torch.no_grad() to keep gradients out of the VAE parameters.
vae.eval()


def _latent_targets(batch_labels):
    """Sample the VAE latent codes of `batch_labels` without tracking gradients."""
    with torch.no_grad():
        return vae(batch_labels.permute(0,2,1))["z"]


def _latent_predictions(batch_inputs):
    """Run the CNN on `batch_inputs` and return one row per latent target row.

    The VAE posterior flattens its encoder output with `permute(1,2,0)`, i.e.
    with the position index outermost. The CNN output is flattened the same
    way, so row i of the prediction and row i of `z` describe the same
    position of the same sequence.
    """
    predicted = net(batch_inputs.permute(0,2,1))
    return predicted.permute(1,2,0).flatten(start_dim=1).permute(1,0)


def _mean_latent_loss(loader):
    """Return the MSE between predicted and sampled latents, averaged over the batches of `loader`."""
    total_loss = 0.0
    for batch_inputs, batch_labels in loader:
        target_latents = _latent_targets(batch_labels)
        total_loss += criterion1(_latent_predictions(batch_inputs), target_latents).item()
    return total_loss/len(loader)


# Training
for epoch in range(num_epoch):  # loop over the dataset multiple times
    print('Epoch ',epoch+1,' of ',num_epoch)
    running_loss = 0.0
    net.train()
    for inputs, labels in TrainLoader:
        # Encoding the target values
        labels_latent = _latent_targets(labels)

        # zero the parameter gradients
        optimizer_net.zero_grad()

        # forward + backward + optimize
        outputs = _latent_predictions(inputs)
        # Loss between the sampled latent variables and the predicted latent variables
        loss = criterion1(outputs,labels_latent)
        loss.backward()
        running_loss += loss.item()
        optimizer_net.step()
    train_loss.append(running_loss/len(TrainLoader))
    # Validating on the test sets
    net.eval()
    with torch.no_grad():
        print('Evaluating on CASP')
        casp_loss.append(_mean_latent_loss(CASPLoader))
        print('Evaluating on TS115')
        TS115_loss.append(_mean_latent_loss(TS115Loader))
        print('Evaluating on CB513')
        CB513_loss.append(_mean_latent_loss(CB513Loader))
# Stage 2 loss curves, one line per dataset, saved to 'CNNLoss'
_stage2_epochs = range(1, num_epoch+1)
for _curve, _label in ((train_loss, 'Train'), (casp_loss, 'CASP12'), (TS115_loss, 'TS115'), (CB513_loss, 'CB513')):
    plt.plot(_stage2_epochs, _curve, label=_label)
plt.legend()
plt.suptitle('Stage 2: Loss curves')
plt.xlabel('epochs')
plt.ylabel('loss')
plt.savefig('CNNLoss')
plt.close()

# Stage 3: Training both the VAE and CNN

In [None]:
### Stage 3 training: the CNN and the VAE are trained together
num_epochs, epoch = 100, 0
# Negative log-likelihood loss. PyTorch's cross entropy is not used because
# the decoded latent variables are already given as log probabilities.
def criterion2(input, target):
    """Return the NLL loss of `input` against the arg-max class of `target`.

    The class index is taken along axis 2 of `target`.
    """
    class_index = target.argmax(dim=2)
    nll = nn.NLLLoss()
    return nll(input, class_index)

# Histories for stage 3: VAE loss, CNN loss, combined loss and accuracy per dataset
train_vaeLoss, train_CNNLoss, train_FinalLoss, train_accuracy1 = [], [], [], []
CASP_vaeLoss, CASP_CNNLoss, CASP_FinalLoss, CASP_accuracy1 = [], [], [], []
TS115_vaeLoss, TS115_CNNLoss, TS115_FinalLoss, TS115_accuracy1 = [], [], [], []
CB513_vaeLoss, CB513_CNNLoss, CB513_FinalLoss, CB513_accuracy1 = [], [], [], []
# training..
def _decode_and_score(latents, targets):
    """Decode CNN latents with the VAE decoder and score them against `targets`.

    `latents` is the CNN output and `targets` the matching batch of one-hot
    labels, both with the feature/class axis second.

    Returns the NLL loss tensor, the number of correctly classified positions
    and the total number of positions.
    """
    n_seq, n_classes, n_pos = targets.shape
    px = vae.observation_model(latents.permute(0,2,1))[0]
    # observation_model emits one row per (sequence, position) pair with the
    # sequence index outermost, so the targets are flattened in the same order.
    flat_targets = targets.permute(0,2,1).reshape(n_seq*n_pos, n_classes)
    log_px = px.log_prob(flat_targets).view(n_seq, n_pos, n_classes)
    loss = criterion2(log_px.permute(0,2,1), targets.detach().permute(0,2,1))
    # The predicted class comes from the decoder output alone. Taking the
    # arg-max of log_prob(targets) would leak the labels into the prediction.
    predicted = px.logits.argmax(dim=1)
    actual = flat_targets.argmax(dim=1)
    return loss, torch.sum(predicted == actual).item(), actual.shape[0]


def _evaluate_first_batch(loader):
    """Return `(vae_loss, cnn_loss, accuracy)` for the first batch of `loader`.

    Local names are used for the batch so the dataset tensors are not rebound.
    """
    batch_inputs, batch_targets = next(iter(loader))
    # perform a forward pass through the VAE and compute the ELBO loss
    vae_loss, _, _ = vi(vae, batch_targets.permute(0,2,1))
    cnn_loss, n_correct, n_total = _decode_and_score(net(batch_inputs.permute(0,2,1)), batch_targets)
    return vae_loss.item(), cnn_loss.item(), n_correct/n_total


while epoch < num_epochs:
    epoch+= 1
    print('Epoch ',epoch, ' of ', num_epochs)
    training_epoch_data = defaultdict(list)
    vae.train()
    net.train()
    running_loss_vae = 0.0
    running_loss_CNN = 0.0
    running_loss_FinalLoss = 0.0
    total = 0
    correct = 0
    print('Training...')
    for inputs, labels in TrainLoader:
        loss_vae, diagnostics, outputs = vi(vae, labels.permute(0,2,1))
        running_loss_vae += loss_vae.item()
        outputs_net = net(inputs.permute(0,2,1))
        CNNLoss, batch_correct, batch_total = _decode_and_score(outputs_net, labels)
        total += batch_total
        correct += batch_correct

        # CNN step: backward() must be called, otherwise the CNN gets no gradients
        optimizer_net.zero_grad()
        CNNLoss.backward()
        optimizer_net.step()
        running_loss_CNN += CNNLoss.item()

        # VAE step: the CNN loss enters as a detached constant, so only the
        # ELBO term contributes gradients. zero_grad() also clears what the
        # CNN step left on the decoder.
        FinalLoss = CNNLoss.detach() + loss_vae
        running_loss_FinalLoss += FinalLoss.item()
        optimizer_vae.zero_grad()
        FinalLoss.backward()
        optimizer_vae.step()
        # gather data for the current batch
        for k, v in diagnostics.items():
            training_epoch_data[k] += [v.mean().item()]
    train_vaeLoss.append(running_loss_vae/len(TrainLoader))
    train_CNNLoss.append(running_loss_CNN/len(TrainLoader))
    train_FinalLoss.append(running_loss_FinalLoss/len(TrainLoader))
    train_accuracy1.append(correct/total)

    # Evaluate on a single batch per test set with no backpropagation. The
    # losses are per-batch means already, so they are stored without dividing
    # by the number of batches.
    with torch.no_grad():
        vae.eval()
        net.eval()

        print('Testing on CASP12')
        vae_loss, cnn_loss, accuracy = _evaluate_first_batch(CASPLoader)
        CASP_vaeLoss.append(vae_loss)
        CASP_CNNLoss.append(cnn_loss)
        CASP_FinalLoss.append(cnn_loss + vae_loss)
        CASP_accuracy1.append(accuracy)

        print('Testing on TS115')
        vae_loss, cnn_loss, accuracy = _evaluate_first_batch(TS115Loader)
        TS115_vaeLoss.append(vae_loss)
        TS115_CNNLoss.append(cnn_loss)
        TS115_FinalLoss.append(cnn_loss + vae_loss)
        TS115_accuracy1.append(accuracy)

        print('Testing on CB513')
        vae_loss, cnn_loss, accuracy = _evaluate_first_batch(CB513Loader)
        CB513_vaeLoss.append(vae_loss)
        CB513_CNNLoss.append(cnn_loss)
        CB513_FinalLoss.append(cnn_loss + vae_loss)
        CB513_accuracy1.append(accuracy)
def _plot_histories(title, ylabel, filename, histories):
    """Plot `(label, values)` pairs against 1-based epoch numbers and save the figure.

    The x axis is derived from the number of recorded values, so it always
    matches the data regardless of which epoch counter was used for training.
    """
    for label, values in histories:
        plt.plot(range(1, len(values)+1), values, label=label)
    plt.legend()
    plt.suptitle(title)
    plt.xlabel('epochs')
    plt.ylabel(ylabel)
    plt.savefig(filename)
    plt.close()


_plot_histories('Stage 3: Loss curves', 'loss', 'CNNLossStage3', (
    ('Train', train_CNNLoss),
    ('CASP12', CASP_CNNLoss),
    ('TS115', TS115_CNNLoss),
    ('CB513', CB513_CNNLoss),
))
_plot_histories('Stage 3: Accuracies', 'accuracy', 'CNNaccuracyStage3', (
    ('Train', train_accuracy1),
    ('CASP12', CASP_accuracy1),
    ('TS115', TS115_accuracy1),
    ('CB513', CB513_accuracy1),
))