<a href="https://colab.research.google.com/github/angzhifan/Importance-Weighted-Autoencoders/blob/main/IWAE_VAE.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This is a PyTorch implementation of the IWAE and VAE models from the paper *Importance Weighted Autoencoders* by Yuri Burda, Roger Grosse & Ruslan Salakhutdinov.

In [None]:
#IWAE & VAE
#Angzhi (Andrew) Fan, fana@uchicago.edu
#Oct 5, 2020

import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import math

# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

# VAE or IWAE with one layer
class VAE_1(nn.Module):
    """Autoencoder with a single 50-dimensional stochastic layer.

    The same network serves as a VAE or an IWAE; it draws ``k`` latent
    samples per input and decodes every one of them.
    """

    def __init__(self, t):
        """Create the layers; ``t`` is the number of latent samples per input."""
        super(VAE_1, self).__init__()
        self.k = t
        # Encoder: 784 -> 200 -> 200 -> (mu, log_sigma), each of size 50.
        self.fc1 = nn.Linear(28*28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3_mu = nn.Linear(200, 50)
        self.fc3_sigma = nn.Linear(200, 50)
        # Decoder: 50 -> 200 -> 200 -> 784 (no output non-linearity).
        self.fc4 = nn.Linear(50, 200)
        self.fc5 = nn.Linear(200, 200)
        self.fc6 = nn.Linear(200, 28*28)

    def forward(self, x):
        """Encode ``x``, draw ``k`` reparameterized samples and decode them.

        Returns a tuple ``(decoded, mu, log_sigma, eps)`` where ``decoded`` is
        ``(batch, k, 784)``, ``mu`` and ``log_sigma`` are ``(batch, 1, 50)``
        and ``eps`` is the ``(batch, k, 50)`` standard-normal noise that was
        used to build the latent samples.
        """
        flat = x.view(-1, 1, 28*28)
        hidden = torch.tanh(self.fc2(torch.tanh(self.fc1(flat))))

        mu = self.fc3_mu(hidden).view(-1, 1, 50)
        log_sigma = self.fc3_sigma(hidden).view(-1, 1, 50)

        # Tile the posterior parameters k times along the sample axis.
        mu_tiled = mu.repeat(1, self.k, 1)
        sigma_tiled = torch.exp(log_sigma.repeat(1, self.k, 1))
        eps = torch.randn_like(mu_tiled)
        latent = mu_tiled + sigma_tiled*eps

        decoded = torch.tanh(self.fc5(torch.tanh(self.fc4(latent))))
        return self.fc6(decoded), mu, log_sigma, eps
    


cpu


In [None]:
# VAE or IWAE with two layers
class VAE_2(nn.Module):
    """Autoencoder with two stochastic layers (100 and 50 units).

    Usable as a VAE or an IWAE; ``k`` samples of the first stochastic layer
    are drawn per input and propagated through the rest of the network.
    """

    def __init__(self, t):
        """Create the layers; ``t`` is the number of latent samples per input."""
        super(VAE_2, self).__init__()
        self.k = t
        # Input -> first stochastic layer h1 (100 units).
        self.fc1 = nn.Linear(28*28, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc3_mu = nn.Linear(200, 100)
        self.fc3_sigma = nn.Linear(200, 100)
        # h1 -> second stochastic layer h2 (50 units).
        self.fc4 = nn.Linear(100, 100)
        self.fc5 = nn.Linear(100, 100)
        self.fc6_mu = nn.Linear(100, 50)
        self.fc6_sigma = nn.Linear(100, 50)
        # h2 -> Gaussian parameters over h1 (100 units).
        self.fc7 = nn.Linear(50, 100)
        self.fc8 = nn.Linear(100, 100)
        self.fc9_mu = nn.Linear(100, 100)
        self.fc9_sigma = nn.Linear(100, 100)
        # h1 -> 784 outputs (no output non-linearity).
        self.fc10 = nn.Linear(100, 200)
        self.fc11 = nn.Linear(200, 200)
        self.fc12 = nn.Linear(200, 28*28)

    def forward(self, x):
        """Run the full encode / sample / decode pass.

        Returns the 9-tuple ``(decoded, mu1, mu2, mu3, log_sigma1,
        log_sigma2, log_sigma3, eps1, eps2)``.  ``mu1``/``log_sigma1`` keep a
        sample axis of size 1; every other tensor has ``k`` entries on that
        axis.  ``eps1`` and ``eps2`` are the noise draws used for ``h1`` and
        ``h2`` respectively.
        """
        flat = x.view(-1, 1, 28*28)
        enc1 = torch.tanh(self.fc2(torch.tanh(self.fc1(flat))))
        mu1 = self.fc3_mu(enc1)
        log_sigma1 = self.fc3_sigma(enc1)

        # k reparameterized samples of h1 per input.
        mu1_tiled = mu1.repeat(1, self.k, 1)
        eps1 = torch.randn_like(mu1_tiled)
        h1 = mu1_tiled + torch.exp(log_sigma1.repeat(1, self.k, 1))*eps1

        enc2 = torch.tanh(self.fc5(torch.tanh(self.fc4(h1))))
        mu2 = self.fc6_mu(enc2)
        log_sigma2 = self.fc6_sigma(enc2)

        # One sample of h2 for each sample of h1.
        eps2 = torch.randn_like(mu2)
        h2 = mu2 + torch.exp(log_sigma2)*eps2

        dec2 = torch.tanh(self.fc8(torch.tanh(self.fc7(h2))))
        mu3 = self.fc9_mu(dec2)
        log_sigma3 = self.fc9_sigma(dec2)

        # The reconstruction is computed from h1 itself, not from mu3.
        dec1 = torch.tanh(self.fc11(torch.tanh(self.fc10(h1))))
        return (self.fc12(dec1), mu1, mu2, mu3,
                log_sigma1, log_sigma2, log_sigma3, eps1, eps2)
    


In [None]:
# Mount Google Drive at /content/drive; the dataset, checkpoint and log paths
# used in this notebook all live under '/content/drive/My Drive'.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
def load_data(data_dir='/content/drive/My Drive/dataset/BinaryMNIST'):
    """Load the binarized MNIST splits stored as whitespace-separated text.

    Args:
        data_dir: Directory containing ``binarized_mnist_train.amat``,
            ``binarized_mnist_valid.amat`` and ``binarized_mnist_test.amat``.
            Defaults to the Google Drive location used by this notebook.

    Returns:
        ``(mnist_train, mnist_test)`` as 2-D NumPy arrays with one example
        per row.  The training array is the train split followed by the
        validation split.
    """
    import os

    def _load_split(split):
        path = os.path.join(data_dir, 'binarized_mnist_' + split + '.amat')
        # ndmin=2 keeps a one-row file 2-D so concatenation stays row-wise.
        return np.loadtxt(path, ndmin=2)

    mnist_train = np.concatenate([_load_split('train'), _load_split('valid')])
    mnist_test = _load_split('test')
    return mnist_train, mnist_test

# Read both splits once at module level and show their shapes, one per line.
mnist_train, mnist_test = load_data()
print(mnist_train.shape, mnist_test.shape, sep="\n")

(60000, 784)
(10000, 784)


In [None]:
def test_function_1(net):
    testloader = torch.utils.data.DataLoader(mnist_test, batch_size=20,shuffle=False)
    nll = 0.0
    A_u = torch.zeros(50)
    for i,data in enumerate(testloader, 0):
      with torch.no_grad():
        test = data.view(-1,1,28*28).to(device)
        output = net(test.float())
        
        # stochastic layer
        eps = torch.randn_like(output[2].repeat(1,5000,1))
        h = output[1].repeat(1,5000,1) + torch.exp(output[2].repeat(1,5000,1))*eps

        # output of x using the new epsilon
        output_x = torch.tanh(net.fc4(h))
        output_x = torch.tanh(net.fc5(output_x))
        output_x = net.fc6(output_x)
        log_prob_condi = torch.sum(output_x*test.repeat(1,5000,1)-torch.log(1+torch.exp(output_x)), 2)

        # log weights, unnormalized
        log_weights = log_prob_condi-(h*h).sum(2)/2+(eps*eps).sum(2)/2+output[2].repeat(1,5000,1).sum(2)

        # estimate log likelihood using L_5000
        L_5000 = log_weights.max(1)[0].mean()+torch.log(torch.exp(log_weights
                        -log_weights.max(1)[0].view(-1,1)).mean(1)).mean()
        nll -= L_5000.item()
        A_u += output[1].view(-1,50).var(0).cpu()
    return nll/500, sum(A_u.detach().numpy()/500>0.01)

print("Finished loading test function 1")

Finished loading test function 1


In [None]:
def test_function_2(net):
    testloader = torch.utils.data.DataLoader(mnist_test, batch_size=20,shuffle=False)
    nll = 0.0
    A_u_1 = torch.zeros(100)
    A_u_2 = torch.zeros(50)
    for i,data in enumerate(testloader, 0):
      with torch.no_grad():
        test = data.view(-1,1,28*28).to(device)
        output = net(test.float())
        
        # stochastic layer sampling
        eps1 = torch.randn_like(output[1].repeat(1,5000,1))
        # stochastic layer h1
        h1 = output[1].repeat(1,5000,1) + torch.exp(output[4].repeat(1,5000,1))*eps1
        
        x = torch.tanh(net.fc4(h1))
        x = torch.tanh(net.fc5(x))

        # stochastic layer h2
        mu2 = net.fc6_mu(x)
        log_sigma2 = net.fc6_sigma(x)
        eps2 = torch.randn_like(mu2)
        
        x = torch.tanh(net.fc10(h1))
        x = torch.tanh(net.fc11(x))
        x = net.fc12(x)

        # log conditional prob
        log_prob_condi = torch.sum(x*test.repeat(1,5000,1), 2)-torch.sum(torch.log(1+torch.exp(x)), 2)

        # log weights, unnormalized
        h2 = (mu2+torch.exp(log_sigma2)*eps2)
        x = torch.tanh(net.fc7(h2))
        x = torch.tanh(net.fc8(x))
        mu3 = net.fc9_mu(x)
        log_sigma3 = net.fc9_sigma(x)
        h1 = h1-mu3
        log_p_h1_h2 = -(h1*h1/torch.exp(2*log_sigma3)).sum(2)/2-log_sigma3.sum(2)
        log_q_h1_x = -(eps1*eps1).sum(2)/2-output[4].repeat(1,5000,1).sum(2)
        log_q_h2_h1 = -(eps2*eps2).sum(2)/2-log_sigma2.sum(2)
        log_weights = log_prob_condi+log_p_h1_h2-(h2*h2).sum(2)/2-log_q_h1_x-log_q_h2_h1


        # estimate log likelihood using L_5000
        L_5000 = log_weights.max(1)[0].mean()+torch.log(torch.exp(log_weights
                        -log_weights.max(1)[0].view(-1,1)).mean(1)).mean()
        nll -= L_5000.item()
        A_u_1 += output[1].view(-1,100).var(0).cpu()
        A_u_2 += mu2[:,0,:].var(0).cpu()
    return nll/500, (sum(A_u_1.detach().numpy()/500>0.01),sum(A_u_2.detach().numpy()/500>0.01))

print("Finished loading test function 2")


Finished loading test function 2


In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
mod = 'iwae'       # training objective: 'iwae' or 'vae'
layer = 2          # number of stochastic layers: 1 or 2
batch_size = 20
k = 5              # latent samples per training example
continued = 0      # epoch to resume from; 0 starts a fresh run

# Epochs at which a checkpoint is saved, the test set is evaluated and the
# learning rate is divided by 10**(1/7).
decay_epochs = [1, 4, 13, 40, 121, 364, 1093, 3280]

if mod == 'iwae':
    # Flat offset of each batch row's first sample. The loop below selects
    # samples with gather() and no longer reads this; the name is kept in case
    # other notebook cells still reference it.
    index_vec = torch.tensor([i*k for i in range(batch_size)]).to(device)
elif mod == 'vae':
    pass
else:
    raise Exception("Invalid Mode")

if layer == 1:
    net = VAE_1(k)
elif layer == 2:
    net = VAE_2(k)
else:
    raise Exception("Invalid Layer Number")

# Shared path pieces for checkpoints and the text log.
model_prefix = '/content/drive/My Drive/IWAE_VAE/model/'+mod+'_net'+'_layer'+str(layer)+'_k'+str(k)+'_'
log_path = '/content/drive/My Drive/IWAE_VAE/outfile_'+mod+'_layer'+str(layer)+'_k'+str(k)+'_'+'.txt'

if continued == 0:
    decay_num = 0
elif continued < 3281:
    # Count only the decays applied strictly before `continued`: when
    # `continued` is itself a decay epoch, the training loop below applies
    # that decay on its first iteration, so counting it here as well would
    # decay the learning rate twice.
    decay_num = sum(e < continued for e in decay_epochs)
    if continued in decay_epochs:
        checkpoint_path = model_prefix+str(continued)+'.pth'
    else:
        # No per-epoch checkpoint exists for this epoch; fall back to the
        # final-model file written at the end of a run.
        checkpoint_path = model_prefix+'.pth'
    # map_location lets a checkpoint saved on GPU be loaded on a CPU runtime.
    net.load_state_dict(torch.load(checkpoint_path, map_location=device))
else:
    raise Exception("Invalid Starting Epoch")

net.to(device)

# Start a new log for a fresh run; append when resuming so the earlier
# history is not truncated.
with open(log_path, 'w' if continued == 0 else 'a') as outfile:
    outfile.write('output of the code '+'\n'+'author:Angzhi Fan fana@uchicago.edu'+'\n')

start = time.time()
trainloader = torch.utils.data.DataLoader(mnist_train, batch_size=batch_size,
                                          shuffle=True, num_workers=2)
learning_rate = 0.001/10**(decay_num/7)
optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate,
                             betas=(0.9, 0.999), eps=1e-04)
for epoch in range(continued, 122):
    if epoch in decay_epochs:
        # Checkpoint, decay the learning rate (with a fresh optimizer), and
        # evaluate on the test set.
        PATH = model_prefix+str(epoch)+'.pth'
        torch.save(net.state_dict(), PATH)
        learning_rate /= 10**(1/7)
        optimizer = torch.optim.Adam(net.parameters(), lr=learning_rate,
                                     betas=(0.9, 0.999), eps=1e-04)
        with open(log_path, 'a') as outfile:
            if layer == 1:
                nll, A_u = test_function_1(net)
            else:
                nll, A_u = test_function_2(net)
            print('NLL:', nll, 'active units:', A_u)
            outfile.write('test average (NLL):'+str(nll)+'\n')
            outfile.write('test average (active units):'+str(A_u)+'\n')
            outfile.write('learning rate decay'+'\n')
        print("learning rate=", learning_rate)
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        train = data.view(-1, 1, 28*28).to(device)
        optimizer.zero_grad()

        output = net(train.float())
        logits = output[0]
        # Bernoulli log-likelihood x*l - log(1+exp(l)); softplus avoids the
        # overflow of exp() for large logits. `train` broadcasts over the k
        # samples.
        log_prob_condi = torch.sum(logits*train - F.softplus(logits), 2)

        if layer == 1:
            mu, log_sigma, eps = output[1], output[2], output[3]
            # stochastic layer
            h = mu + torch.exp(log_sigma)*eps
            # log weights, unnormalized
            log_weights = log_prob_condi - (h*h).sum(2)/2 + (eps*eps).sum(2)/2 + log_sigma.sum(2)
        else:
            # stochastic layer h1 minus mu3
            h1 = output[1] + torch.exp(output[4])*output[7] - output[3]
            # stochastic layer h2
            h2 = output[2] + torch.exp(output[5])*output[8]
            # log weights, unnormalized
            log_p_h1_h2 = -(h1*h1/torch.exp(2*output[6])).sum(2)/2 - output[6].sum(2)
            log_q_h1_x = -(output[7]*output[7]).sum(2)/2 - output[4].sum(2)
            log_q_h2_h1 = -(output[8]*output[8]).sum(2)/2 - output[5].sum(2)
            log_weights = log_prob_condi + log_p_h1_h2 - (h2*h2).sum(2)/2 - log_q_h1_x - log_q_h2_h1

        if mod == 'vae':
            loss = -log_weights.mean()
        else:
            # Sample one of the k latent samples per example with probability
            # proportional to its importance weight, and back-propagate
            # through that sample's log weight only.
            probs = torch.softmax(log_weights.detach(), dim=1)
            picked = torch.multinomial(probs, 1)
            # gather() indexes per row, so a smaller final batch also works.
            loss = -log_weights.gather(1, picked).mean()

        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if i % 500 == 499:
            print('[%d, %5d] loss: %.3f' %
                  (epoch+1, i + 1, running_loss/500))
            with open(log_path, 'a') as outfile:
                outfile.write('[%d, %5d] loss: %.3f' %
                              (epoch+1, i + 1, running_loss/500)+'\n')
            running_loss = 0.0

# Final model file (no epoch number in the name).
PATH = model_prefix+'.pth'
torch.save(net.state_dict(), PATH)

print('Finished Training')
with open(log_path, 'a') as outfile:
    outfile.write('Finished Training'+'\n'+'time cost:'+str(time.time()-start)+'\n')


