In [1]:
import numpy as np
import pickle
import einops

# from main import VAE

In [2]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import pickle

class VAE(nn.Module):
    """Fully connected variational autoencoder over flattened inputs.

    The encoder maps a flattened input to the mean and log-variance of a
    diagonal Gaussian over the latent code; the decoder maps a latent
    sample back to a vector of values in (0, 1) through a sigmoid.

    Args:
        input_dim: Length of each flattened input vector. Defaults to 1200.
        hidden_dim: Width of the single hidden layer used by both the
            encoder and the decoder. Defaults to 400.
        latent_dim: Size of the latent code. Defaults to 20.

    The layer attribute names (``fc1``, ``fc21``, ``fc22``, ``fc3``, ``fc4``)
    are unchanged, so a state dict saved with the default sizes still loads.
    """

    def __init__(self, input_dim=1200, hidden_dim=400, latent_dim=20):
        super(VAE, self).__init__()

        # Kept as a plain attribute so forward() can flatten inputs without
        # repeating the width as a literal.
        self.input_dim = input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc21 = nn.Linear(hidden_dim, latent_dim)  # mean head
        self.fc22 = nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc3 = nn.Linear(latent_dim, hidden_dim)
        self.fc4 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Return ``(mu, logvar)`` for a batch of flattened inputs."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparameterize(self, mu, logvar):
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from N(0, I).

        Noise is sampled on every call, including after ``model.eval()``.
        """
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z):
        """Map latent codes to reconstructions with values in (0, 1)."""
        h3 = F.relu(self.fc3(z))
        return torch.sigmoid(self.fc4(h3))

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_dim)`` and return ``(recon, mu, logvar)``."""
        mu, logvar = self.encode(x.view(-1, self.input_dim))
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

    

In [3]:
# Build the model on the CPU and restore its weights from the checkpoint file.
# map_location="cpu" remaps tensors that were saved from another device, so the
# checkpoint also loads on a machine that does not have that device.
# NOTE: torch.load unpickles the file; only load checkpoints you trust.
model = VAE().to("cpu")
model.load_state_dict(torch.load("model_chkpts/vae_e100.chkpt", map_location="cpu"))

<All keys matched successfully>

In [4]:
# Load signal peptide data.
# map_location="cpu" remaps tensors saved from another device so the files load
# on a CPU-only machine, matching the CPU-only evaluation in this notebook.
train_tensors = torch.load('train99_tensors.pt', map_location="cpu")
valid_tensors = torch.load('valid99_tensors.pt', map_location="cpu")


# Create the evaluation dataloader over the validation tensors.
# shuffle=False keeps the batch order identical between runs.
test_loader = torch.utils.data.DataLoader(
    valid_tensors, batch_size=128, shuffle=False)



In [5]:
def loss_function(recon_x, x, mu, logvar):
    """Summed binary cross-entropy between a reconstruction and its target.

    Args:
        recon_x: Reconstructed batch; values must lie in [0, 1], as
            ``F.binary_cross_entropy`` requires.
        x: Target batch. It is viewed with ``recon_x``'s shape, so it may
            carry extra trailing dimensions as long as the element count
            matches.
        mu: Unused; accepted so the signature stays that of a VAE loss.
        logvar: Unused; see ``mu``.

    Returns:
        Scalar tensor holding the BCE summed over every element.
    """
    # View the target like the reconstruction instead of hard-coding the
    # flattened width, so the loss works for any input size.
    BCE = F.binary_cross_entropy(recon_x, x.view_as(recon_x), reduction='sum')

    # The KL term is deliberately left out of the returned value.
    # see Appendix B from VAE paper:
    # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
    # https://arxiv.org/abs/1312.6114
    # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    #     KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return BCE

In [6]:
def return_accuracy(recon_batch, data, pad_index=1):
    """Fraction of non-padding positions whose predicted class is correct.

    Args:
        recon_batch: Flattened scores of shape ``(batch, length * classes)``.
        data: Target tensor of shape ``(batch, length, classes)``; the true
            class of a position is the argmax over its last axis.
        pad_index: Class index treated as padding. Positions whose true
            class equals it are excluded from the score. Defaults to 1.

    Returns:
        Accuracy as a float in [0, 1], or NaN when the batch contains no
        non-padding position (the unguarded division used to raise
        ZeroDivisionError in that case).
    """
    num_classes = data.shape[-1]
    # Same split as the einops pattern "b (l c) -> b l c": the flattened axis
    # is divided into (length, classes) in row-major order.
    recon_batch = recon_batch.reshape(recon_batch.shape[0], -1, num_classes)
    recon_inds = torch.argmax(recon_batch, dim=2)
    true_inds = torch.argmax(data, dim=2)

    padmask = true_inds != pad_index
    num_scored = int(padmask.sum().item())
    if num_scored == 0:
        # Nothing to score; NaN makes the gap visible in downstream averages.
        return float("nan")

    num_correct = (recon_inds[padmask] == true_inds[padmask]).sum().item()
    return num_correct / num_scored

In [7]:
def CELoss(data, recon_batch, pad_index=1):
    """Mean cross-entropy over the non-padding positions of a batch.

    Args:
        data: Target tensor of shape ``(batch, length, classes)``; the true
            class of a position is the argmax over its last axis.
        recon_batch: Flattened scores of shape ``(batch, length * classes)``.
            They are passed to cross-entropy as unnormalized scores, i.e. a
            log-softmax over the class axis is applied internally.
        pad_index: Class index treated as padding; positions whose true
            class equals it do not contribute. Defaults to 1.

    Returns:
        The loss as a Python float (natural-log units), averaged over all
        non-padding positions. NaN when every position is padding.
    """
    # The class count comes from the data instead of a hard-coded 24.
    num_classes = data.shape[-1]

    # Flatten to one row per (sequence, position). Row-major order matches the
    # "b (l c) -> b l c" split of the reconstruction.
    logits = recon_batch.reshape(-1, num_classes)
    targets = torch.argmax(data, dim=-1).reshape(-1)

    # ignore_index drops padded targets from both the sum and the divisor,
    # which is what the per-position loop computed one element at a time.
    loss = nn.functional.cross_entropy(
        logits, targets, ignore_index=pad_index, reduction="mean")

    return loss.item()

In [8]:
# Evaluate the restored model on the held-out loader, collecting one accuracy
# and one cross-entropy value per batch plus a running summed BCE.
device = "cpu"
model.eval()

test_loss = 0
losses, accs = [], []

with torch.no_grad():  # inference only, no gradients needed
    for i, data in enumerate(test_loader):
        data = data.to(device)
        recon_batch, mu, logvar = model(data)

        accs.append(return_accuracy(recon_batch, data))
        test_loss += loss_function(recon_batch, data, mu, logvar).item()
        losses.append(CELoss(data, recon_batch))
#         if i == 0:
#             n = min(data.size(0), 8)
#             comparison = torch.cat([data[:n],
#                                   recon_batch.view(args.batch_size, 1, 50, 24)[:n]])
#             save_image(comparison.cpu(),
#                      'results/reconstruction_' + str(epoch) + '.png', nrow=n)

In [9]:
# b, r, data_c = data.shape
# recon_batch = einops.rearrange(recon_batch, "b (l c) -> b l c", c=data_c)
# recon_batch.shape

In [10]:
# Mean and standard deviation of the per-batch accuracies.
acc_mean, acc_std = np.mean(accs), np.std(accs)
print(acc_mean, acc_std)

0.4535704440063153 0.016660600485364174


In [None]:
# Accuracy is reported as-is; the negation belongs only to the log-likelihood
# cell, where a loss is turned into a likelihood.
print(f"The average accuracy is {np.average(accs):0.2f}")

In [11]:
# Mean and standard deviation of the per-batch cross-entropy values.
loss_mean, loss_std = np.mean(losses), np.std(losses)
print(loss_mean, loss_std)

2.884387585249814 0.011155884661706113


In [12]:
# The collected losses are cross-entropies (negative log likelihoods), so the
# negated mean is reported as the average log likelihood.
print(f"The average log likelihood is {-np.average(losses):0.2f}")

The average log likelihood is -2.88
