<a href="https://colab.research.google.com/github/evaneill/vae_network/blob/master/notebooks/VR_alpha_AE_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from __future__ import print_function
import argparse
import torch
import torch.utils.data
from torch import nn, optim, Tensor as T
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
from torch.autograd import Variable
from torch.distributions.multinomial import Multinomial
import datetime
import os
import pickle
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from scipy.io import loadmat
import logging
import math

In [0]:
batch_size = 20
test_batch_size = 20
testing_frequency=50
epochs = 501
seed = 1
log_interval = 100
log_test_value = 100
K = 5
learning_rate = 2e-4
discrete_data = True
cuda = torch.cuda.is_available()

torch.manual_seed(seed)

data_name = 'omniglot'

alpha = .5
model_type = 'vralpha'

device = torch.device("cuda" if cuda else "cpu")

if model_type!="general_alpha" and model_type!="vralpha":
	model_name=model_type
else:
	model_name = model_type+str(alpha)

logging_filename = f'{model_name}_{data_name}_K{K}_M{batch_size}.log'
logging.basicConfig(filename=logging_filename,level=logging.DEBUG)



In [0]:
from google.colab import drive
drive.mount('/content/drive')

# Load data with random initialized train/test split
if os.environ.get('CLOUDSDK_CONFIG') is not None:   
    fpath = "/content/drive/My Drive/data/chardata.mat"
else:
    fpath = os.path.abspath('data/chardata.mat')

data = loadmat(fpath)

# From iwae repository
data_train = data['data'].T.astype('float32').reshape((-1, 28, 28)).reshape((-1, 28*28), order='F') 
data_test = data['testdata'].T.astype('float32').reshape((-1, 28, 28)).reshape((-1, 28*28), order='F')

data_train_t, data_test_t = T(data_train), T(data_test)

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
# Define the model
class omniglot2_model(nn.Module):
    def __init__(self):
        super(omniglot2_model, self).__init__()

        self.fc1 = nn.Linear(784, 200)
        self.fc2 = nn.Linear(200, 200)
        self.fc31 = nn.Linear(200, 100) # stochastic 1
        self.fc32 = nn.Linear(200, 100)

        self.fc4 = nn.Linear(100,100)
        self.fc5 = nn.Linear(100,100)
        self.fc61 = nn.Linear(100,50) # Innermost (stochastic 2)
        self.fc62 = nn.Linear(100,50)

        self.fc7 = nn.Linear(50,100)
        self.fc8 = nn.Linear(100,100)
        self.fc81 = nn.Linear(100,100) # stochastic 1
        self.fc82 = nn.Linear(100,100)

        self.fc9 = nn.Linear(100, 200)
        self.fc10 = nn.Linear(200, 200)
        self.fc11 = nn.Linear(200, 784) # reconstruction

        self.K = K

    def encode(self, x):
        #h1 = F.relu(self.fc1(x))
        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        mu, log_std = self.fc31(h2), self.fc32(h2)

        z1 = self.reparameterize(mu, log_std)
        h3 = torch.tanh(self.fc4(z1))
        h4 = torch.tanh(self.fc5(h3))

        return self.fc61(h4), self.fc62(h4), [x,z1]

    def reparameterize(self, mu, logstd,test=False):
        std = torch.exp(logstd)
        if test==True:
          eps = torch.zeros_like(mu)
        else:
          eps = torch.randn_like(std)
        return mu + eps*std

    def decode(self, z,test=False):
        #h3 = F.relu(self.fc3(z))
        h5 = torch.tanh(self.fc7(z))
        h6 = torch.tanh(self.fc8(h5))
        mu, log_std = self.fc81(h6), self.fc82(h6)

        z1 = self.reparameterize(mu, log_std,test=test)
        h7 = torch.tanh(self.fc9(z1))
        h8 = torch.tanh(self.fc10(h7))

        return torch.sigmoid(self.fc11(h8))

    def forward(self, x):
        mu, logstd, _= self.encode(x.view(-1, 784))
        z = self.reparameterize(mu, logstd)
        return self.decode(z), mu, logstd

    def compute_loss_for_batch(self, data, model, K=K,test=False,alpha=alpha):
        # data = (N,560)
        if model_type=='vae':
            alpha=1
        elif model_type in ('iwae','vrmax'):
            alpha=0
        else:
            # use whatever alpha is defined in hyperparameters
            if abs(alpha-1)<=1e-3:
                alpha=1

        data_k_vec = data.repeat_interleave(K,0)

        mu, log_std , [x,z1] = self.encode(data_k_vec)
        # (B*K, #latents)
        z = model.reparameterize(mu, log_std)

        # Log p(z) (prior)
        log_p_z = torch.sum(-0.5 * z ** 2, 1)-.5*z.shape[1]*T.log(torch.tensor(2*np.pi))

        # q (z | h1)
        log_qz_h1 = compute_log_probabitility_gaussian(z, mu, log_std)

        h1 = torch.tanh(self.fc1(x))
        h2 = torch.tanh(self.fc2(h1))
        mu, log_std = self.fc31(h2), self.fc32(h2)

        # q (h1 | x)
        log_qh1_x = compute_log_probabitility_gaussian(z1, mu, log_std)

        h5 = torch.tanh(self.fc7(z))
        h6 = torch.tanh(self.fc8(h5))
        mu, log_std = self.fc81(h6), self.fc82(h6)

        # log p(h1 | z)
        log_ph1_z = compute_log_probabitility_gaussian(z1,mu,log_std)

        h7 = torch.tanh(self.fc9(z1))
        h8 = torch.tanh(self.fc10(h7))

        decoded = torch.sigmoid(self.fc11(h8))

        # log p(x | h1)
        log_px_h1 = compute_log_probabitility_bernoulli(decoded,x)
        
        # hopefully this reshape operation magically works like always
        if model_type == 'iwae' or test==True:
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, K)
        elif model_type =='vae':
            # treat each sample for a given data point as you would treat all samples in the minibatch
            # 1/K value because loss values seemed off otherwise
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, 1)*1/K
            return decoded, -torch.sum(log_w_matrix)
        elif model_type=='general_alpha' or model_type=='vralpha':
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, K) * (1-alpha)
        elif model_type == 'vrmax':
            log_w_matrix = (log_p_z + log_ph1_z + log_px_h1 - log_qz_h1 - log_qh1_x).view(-1, K).max(axis=1,keepdim=True).values
            return 0, 0, 0, -torch.sum(log_w_matrix)
        
        log_w_minus_max = log_w_matrix - torch.max(log_w_matrix, 1, keepdim=True)[0]
        ws_matrix = torch.exp(log_w_minus_max)
        ws_norm = ws_matrix / torch.sum(ws_matrix, 1, keepdim=True)

        if model_type=='vralpha' and not test:
            sample_dist = Multinomial(1,ws_norm)
            ws_sum_per_datapoint = log_w_matrix.gather(1,sample_dist.sample().argmax(1,keepdim=True))
        else:
            ws_sum_per_datapoint = torch.sum(log_w_matrix * ws_norm, 1)
        
        if model_type in ["general_alpha","vralpha"] and not test:
            ws_sum_per_datapoint/=(1-alpha)
        
        loss = -torch.sum(ws_sum_per_datapoint)

        return decoded, loss

In [0]:
def compute_log_probabitility_gaussian(obs, mu, logstd, axis=1):
    # leaving out constant factor related to 2 pi in formula
    return torch.sum(-0.5 * ((obs-mu) / torch.exp(logstd)) ** 2 - logstd, axis)-.5*obs.shape[1]*T.log(torch.tensor(2*np.pi))

def compute_log_probabitility_bernoulli(obs, p, axis=1):
    return torch.sum(p*torch.log(obs) + (1-p)*torch.log(1-obs), axis)

In [0]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, [data] in enumerate(train_loader):
        # (B, 1, F1, F2) (e.g.
        data = data.to(device)
        optimizer.zero_grad()

        #recon_batch, mu, logvar = model(data)
        #loss = loss_function(recon_batch, data, mu, logvar)
        recon_batch, loss = model.compute_loss_for_batch(data, model)
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))
            logging.info('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader),
                loss.item() / len(data)))

    print('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch,  train_loss / len(train_loader.dataset)))
    logging.info('====> Epoch: {} Average loss: {:.4f}'.format(
          epoch,  train_loss / len(train_loader.dataset)))

# pycharm thinks that I want to run a test whenever I define a function that has 'test' as prefix
# this messes with running the model and is the reason why the function is called _test
def _test(epoch):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for i, [data] in enumerate(test_loader):
            data = data.to(device)
            recon_batch, mu, logvar = model(data)
            _, loss = model.compute_loss_for_batch(data, model, K=5000, test=True)
            test_loss += loss.item()
            #test_loss += loss_function(recon_batch, data, mu, logvar).item()
            if i == 0:
                n = min(data.size(0), 8)
                comparison = torch.cat([data[:n].view(-1,1,28,28),
                                      recon_batch.view(test_batch_size, 1, 28, 28)[:n]])
                save_image(comparison.cpu(),
                         'results/reconstruction_' + str(epoch) + '.png', nrow=n)

    test_loss /= len(test_loader.dataset)
    #test_loss *= 5000
    print('====> Test set loss: {:.4f}'.format(test_loss))
    logging.info('====> Test set loss: {:.4f}'.format(test_loss))

In [0]:
# Initialize a model and data loaders
train_loader = DataLoader(TensorDataset(data_train_t),batch_size=batch_size,shuffle=True,pin_memory=True)
test_loader = DataLoader(TensorDataset(data_test_t),batch_size=test_batch_size,shuffle=True,pin_memory=True)

device = torch.device('cuda')
model = omniglot2_model().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Call the training shenanigans
if torch.cuda.is_available(): 
    print("Training on GPU")
    logging.info("Training on GPU")

os.makedirs('results/',exist_ok=True)

print(datetime.datetime.now())
logging.info(datetime.datetime.now())
for epoch in range(1, epochs + 1):
    train(epoch)
    if epoch % testing_frequency == 1:
        _test(epoch)
        with torch.no_grad():
            sample = torch.randn(64, 50).to(device)
            sample = model.decode(sample,test=True).cpu()
            save_image(sample.view(64, 1, 28, 28),
                        'results/sample_' + str(epoch) + '.png')
print(datetime.datetime.now())
print("Training finished")
logging.info("Training finished")

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
====> Epoch: 145 Average loss: 122.1879
====> Epoch: 146 Average loss: 122.1319
====> Epoch: 147 Average loss: 122.0478
====> Epoch: 148 Average loss: 122.0550
====> Epoch: 149 Average loss: 122.0203
====> Epoch: 150 Average loss: 121.9575
====> Epoch: 151 Average loss: 121.8978
====> Test set loss: 115.1674
====> Epoch: 152 Average loss: 121.8563
====> Epoch: 153 Average loss: 121.7931
====> Epoch: 154 Average loss: 121.7666
====> Epoch: 155 Average loss: 121.7312
====> Epoch: 156 Average loss: 121.6809
====> Epoch: 157 Average loss: 121.6113
====> Epoch: 158 Average loss: 121.5863
====> Epoch: 159 Average loss: 121.5477
====> Epoch: 160 Average loss: 121.5408
====> Epoch: 161 Average loss: 121.4436
====> Epoch: 162 Average loss: 121.4539
====> Epoch: 163 Average loss: 121.4077
====> Epoch: 164 Average loss: 121.3459
====> Epoch: 165 Average loss: 121.2716
====> Epoch: 166 Average loss: 121.2558
====> Epoch: 167 Average 

In [0]:
if 'results' in os.listdir():
    fstring = f'{model_name}_{data_name}_K{K}_M{batch_size}'
    !rm -r {fstring}

    !mkdir {fstring}
    !mkdir {fstring}/samples
    !mkdir {fstring}/recons
    !mv results/reconstruction_*  {fstring}/recons/
    !mv results/sample_* {fstring}/samples/
    !rm -r results

rm: cannot remove 'vralpha0.5_omniglot_K5_M20': No such file or directory


In [0]:
from zipfile import ZipFile
import os

import pickle as pkl

# with open(f'{model_type}_{data_name}_K{K}_M{batch_size}_grads.pkl','wb') as f:
#   pkl.dump((mu_grads,output_grads),f)

with open(f'{model_name}_{data_name}_K{K}_M{batch_size}/{model_name}_{data_name}_K{K}_M{batch_size}.pt','wb') as f:
  torch.save(model,f)

fstring = f'{model_name}_{data_name}_K{K}_M{batch_size}'
!mv {fstring}.log {fstring}/
# !mv {fstring}_grads.pkl {fstring}/
with ZipFile(f'drive/My Drive/experiment results/{model_name}_{data_name}_L2_K{K}_M{batch_size}_ours.zip','w') as f:
  f.write(f'{model_name}_{data_name}_K{K}_M{batch_size}/{model_name}_{data_name}_K{K}_M{batch_size}.pt')
  f.write(f'{model_name}_{data_name}_K{K}_M{batch_size}/{model_name}_{data_name}_K{K}_M{batch_size}.log')
  # f.write(f'{model_type}_{data_name}_K{K}_M{batch_size}/{model_type}_{data_name}_K{K}_M{batch_size}_grads.pkl')
  for img in os.listdir(f'{model_name}_{data_name}_K{K}_M{batch_size}/samples'):
    if img.endswith('.png'):
      f.write(f'{model_name}_{data_name}_K{K}_M{batch_size}/samples/'+img)

  for img in os.listdir(f'{model_name}_{data_name}_K{K}_M{batch_size}/recons'):
    if img.endswith('.png'):
      f.write(f'{model_name}_{data_name}_K{K}_M{batch_size}/recons/'+img)
# with open('iwae_silhouettes_LR00001.pt','wb') as f:
#   torch.save(model,f)


  "type " + obj.__name__ + ". It won't be checked "
