In [None]:
from __future__ import print_function
import torch
from torch import nn, optim
from torch.autograd import Variable
import torch.nn.functional as F
import torch.distributions as dist
from torchvision import datasets, transforms
import torch.utils.data as utils

import numpy as np
from tqdm import tqdm_notebook as tqdm

from matplotlib import pyplot as plt

In [None]:
%load_ext autoreload
%autoreload 2
from models import *

In [None]:
# Load the stored 7x7 kernels, reshape them to (N, 1, 7, 7) and wrap them in
# a shuffled mini-batch loader.
train_7x7 = np.load('kernels/kernels_7x7.npy')
train_7x7 = train_7x7.reshape(-1, 1, 7, 7)
train_7x7_dataset = utils.TensorDataset(torch.Tensor(train_7x7))
train_7x7_loader = utils.DataLoader(
    dataset=train_7x7_dataset,
    batch_size=128,
    shuffle=True,
)

# Same pipeline for the stored 5x5 kernels, reshaped to (N, 1, 5, 5).
train_5x5 = np.load('kernels/kernels_5x5.npy')
train_5x5 = train_5x5.reshape(-1, 1, 5, 5)
train_5x5_dataset = utils.TensorDataset(torch.Tensor(train_5x5))
train_5x5_loader = utils.DataLoader(
    dataset=train_5x5_dataset,
    batch_size=128,
    shuffle=True,
)

In [None]:
# Build each VAE, move it to the GPU, and only then create its optimizer:
# torch.optim recommends constructing optimizers after the model has been
# moved, so the optimizer is guaranteed to reference the on-device parameters.
vae7x7 = VAE7x7(32, 2)
vae7x7.cuda()
optimizer7x7 = optim.Adam(vae7x7.parameters())

vae5x5 = VAE5x5(64, 6)
vae5x5.cuda()
optimizer5x5 = optim.Adam(vae5x5.parameters())

# A trailing bare `None` keeps the cell from displaying anything.
None

In [None]:
def train_vae_n_epochs(vae, optimizer, train_loader, n=1):
    """Train ``vae`` for ``n`` epochs on the value returned by ``vae.elbo``.

    Every batch is moved to the default CUDA device, ``vae.elbo(x, beta=1.)``
    is evaluated, back-propagated, and one optimizer step is taken.

    Args:
        vae: model exposing ``train()`` and ``elbo(x, beta)``; ``elbo`` must
            return a scalar tensor supporting ``backward()`` and ``item()``.
        optimizer: optimizer that updates the parameters of ``vae``.
        train_loader: iterable yielding 1-tuples ``(x,)`` of batches.
        n: number of passes over ``train_loader``.

    Returns:
        list[float]: the mean per-batch loss of each epoch, in order. The
        original version returned nothing, so callers that ignore the return
        value are unaffected.
    """
    epoch_losses = []
    for _ in tqdm(range(n)):
        vae.train()
        running_loss = 0.0
        n_batches = 0
        for (x,) in train_loader:
            x = x.cuda()
            optimizer.zero_grad()
            loss = vae.elbo(x, beta=1.)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_batches += 1
        # max(..., 1) guards against an empty loader (mean reported as 0.0).
        epoch_losses.append(running_loss / max(n_batches, 1))
    return epoch_losses

In [None]:
# 30 * 5 = 150 epochs for the 7x7 VAE; afterwards the whole module object
# (not only a state_dict) is serialized to disk.
train_vae_n_epochs(
    vae=vae7x7,
    optimizer=optimizer7x7,
    train_loader=train_7x7_loader,
    n=150,
)
torch.save(obj=vae7x7, f='./models/serialized_vae7x7')

In [None]:
# 2 * 5 = 10 epochs for the 5x5 VAE; afterwards the whole module object
# (not only a state_dict) is serialized to disk.
train_vae_n_epochs(
    vae=vae5x5,
    optimizer=optimizer5x5,
    train_loader=train_5x5_loader,
    n=10,
)
torch.save(obj=vae5x5, f='./models/serialized_vae5x5')

#### vae7x7 check

In [None]:
# Take the first mini-batch the (shuffled) 7x7 loader yields, move it to the
# GPU and run it through the 7x7 VAE. `vae` selects the model used by the
# following cells.
(x,) = next(iter(train_7x7_loader))
x = x.cuda()
vae = vae7x7
z_mean, z_logvar, z, x_mean, x_logvar = vae7x7(x)

In [None]:
# Draw the first 25 entries of the batch `x` on a 5x5 grid and save the figure.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=[8,8])
axes = axes.ravel()
batch_np = x.cpu().data.numpy()

for i, ax in enumerate(axes):
    ax.imshow(batch_np[i, 0, :, :], cmap='RdYlBu')
    ax.set_axis_off()
plt.show()
fig.savefig('figures/original_7x7.png', dpi=300)

In [None]:
# Ask the currently selected VAE for n=25 generated kernels. This is pure
# inference, so autograd tracking is switched off to avoid building a graph.
with torch.no_grad():
    xs_gen = vae.generate(n=25)

In [None]:
# Draw the first 25 entries of `xs_gen` on a 5x5 grid and save the figure.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=[8,8])
axes = axes.ravel()
generated_np = xs_gen.cpu().data.numpy()

for i, ax in enumerate(axes):
    ax.imshow(generated_np[i, 0, :, :], cmap='RdYlBu')
    ax.set_axis_off()
plt.show()
fig.savefig('figures/generated_7x7.png', dpi=300)

#### vae5x5 check

In [None]:
# Take the first mini-batch the (shuffled) 5x5 loader yields, move it to the
# GPU and run it through the 5x5 VAE. `vae` selects the model used by the
# following cells.
(x,) = next(iter(train_5x5_loader))
x = x.cuda()
vae = vae5x5
z_mean, z_logvar, z, x_mean, x_logvar = vae5x5(x)

In [None]:
# Draw the first 25 entries of the batch `x` on a 5x5 grid and save the figure.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=[8,8])
axes = axes.ravel()
batch_np = x.cpu().data.numpy()

for i, ax in enumerate(axes):
    ax.imshow(batch_np[i, 0, :, :], cmap='RdYlBu')
    ax.set_axis_off()
plt.show()
fig.savefig('figures/original_5x5.png', dpi=300)

In [None]:
# Ask the currently selected VAE for n=25 generated kernels. This is pure
# inference, so autograd tracking is switched off to avoid building a graph.
with torch.no_grad():
    xs_gen = vae.generate(n=25)

In [None]:
# Draw the first 25 entries of `xs_gen` on a 5x5 grid and save the figure.
fig, axes = plt.subplots(nrows=5, ncols=5, figsize=[8,8])
axes = axes.ravel()
generated_np = xs_gen.cpu().data.numpy()

for i, ax in enumerate(axes):
    ax.imshow(generated_np[i, 0, :, :], cmap='RdYlBu')
    ax.set_axis_off()
plt.show()
fig.savefig('figures/generated_5x5.png', dpi=300)

In [None]:
# Put both VAEs into eval mode and stop gradient tracking for all of their
# parameters; from here on they are only used as fixed models.
for trained_vae in (vae7x7, vae5x5):
    trained_vae.eval()
    for p in trained_vae.parameters():
        p.requires_grad = False

In [None]:
def train_and_get_accs_every_10_epochs(model, optimizer, n_epochs=101, mode='dwp'):
    """Train ``model`` and record its test accuracy every 10th epoch.

    Reads the notebook-level ``train_loader``, ``test_loader`` and ``vae7x7``.

    Args:
        model: network whose first two children take a ``mode`` attribute and
            whose first child provides ``kl(vae)``.
        optimizer: optimizer that updates ``model``'s parameters.
        n_epochs: exclusive upper bound of the epoch counter; counting starts
            at 1, so ``n_epochs - 1`` epochs are run.
        mode: value written to ``.mode`` of the first two children.

    Returns:
        list of test accuracies in percent, one entry per evaluated epoch
        (epochs 10, 20, ...). A summary line is printed every 100th epoch.
    """
    layers = list(model.children())
    layers[0].mode = mode
    layers[1].mode = mode
    first_layer = layers[0]
    kl_weight = 1

    accs = []
    for epoch in tqdm(range(1, n_epochs)):
        model.train()
        for data, target in train_loader:
            data = data.cuda()
            target = target.cuda()
            optimizer.zero_grad()
            output = model(data)
            # Summed batch cross-entropy scaled by the number of batches,
            # plus the KL term of the first layer against vae7x7. The
            # analogous term for the second layer (against vae5x5) is
            # disabled in the original code and stays disabled here.
            loss = F.cross_entropy(output, target, reduction='sum') * len(train_loader)
            loss += first_layer.kl(vae7x7) * kl_weight
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1e6)
            optimizer.step()

        if epoch % 10 != 0:
            continue

        # Evaluation pass over the whole test loader.
        model.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for data, target in test_loader:
                data = data.cuda()
                target = target.cuda()
                output = model(data)
                test_loss += F.cross_entropy(output, target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(target.view_as(pred)).sum().item()
        test_loss /= len(test_loader.dataset)

        accs.append(100. * correct / len(test_loader.dataset))

        # Every multiple of 100 is also a multiple of 10, so the report can
        # live after the evaluation it depends on.
        if epoch % 100 == 0:
            print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
                test_loss, correct, len(test_loader.dataset),
                100. * correct / len(test_loader.dataset)))

    return accs

## Experiment On The Max Accuracy

In [None]:
# Training-set sizes to sweep over. The wider sweep is kept for reference:
#train_sizes = [50,150,500,1000]
train_sizes = [50]

# Filled by the experiment loop: train_size -> {mode name -> accuracy list}.
accs_for_different_train_sizes = dict()

In [None]:
for train_size in train_sizes:

    # MNIST training split, cut down to its first `train_size` images.
    train_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root='./data',
                       train=True,
                       download=True,
                       transform=transforms.ToTensor()),
        batch_size=32,
    )
    train_loader.dataset.data = train_loader.dataset.data[:train_size]

    # Full MNIST test split.
    test_loader = torch.utils.data.DataLoader(
        datasets.MNIST(root='./data',
                       train=False,
                       transform=transforms.ToTensor()),
        batch_size=128,
    )

    # Train a fresh BayesNet per mode, in the order dwp, gaussian, log-uniform.
    accs_by_mode = {}
    for mode in ('dwp', 'gaussian', 'log-uniform'):
        model = BayesNet()
        model.cuda()
        optimizer = optim.Adam(model.parameters())
        accs_by_mode[mode] = train_and_get_accs_every_10_epochs(model,
                                                                optimizer,
                                                                n_epochs=7001,
                                                                mode=mode)

    # Keep the per-mode names bound, as the unrolled version did.
    accs_dwp = accs_by_mode['dwp']
    accs_gauss = accs_by_mode['gaussian']
    accs_logunif = accs_by_mode['log-uniform']

    accs_for_different_train_sizes[train_size] = {
        'gaussian': accs_gauss,
        'dwp': accs_dwp,
        'log-uniform': accs_logunif,
    }

In [None]:
# Accuracy curves for (at most) the first two training-set sizes, one panel
# per size.
sizes_to_plot = train_sizes[:2]
fig, axes = plt.subplots(ncols=2, figsize=[13, 4])

for ax, train_size in zip(axes, sizes_to_plot):
    for mode_name in ('gaussian', 'dwp', 'log-uniform'):
        curve = accs_for_different_train_sizes[train_size][mode_name][:70]
        # Accuracies are recorded every 10th epoch (10, 20, ...), so the
        # k-th entry belongs to epoch 10 * (k + 1); plotting against the raw
        # index would mislabel the "epoch" axis by a factor of ten.
        epochs = 10 * np.arange(1, len(curve) + 1)
        ax.plot(epochs, curve, label=mode_name)
    ax.set_ylabel('accuracy on validation')
    ax.set_xlabel('epoch')
    ax.set_title('{} objects in the train dataset'.format(train_size))
    ax.legend()

# Hide panels that received no data (fewer than two train sizes) instead of
# saving an empty frame.
for unused_ax in axes[len(sizes_to_plot):]:
    unused_ax.set_visible(False)
plt.show()

fig.savefig('figures/small_data.png', dpi=300)

## Initialization Experiments

In [None]:
def init_xavier(layer):
    """Re-initialise ``layer.weight`` in place with Xavier-uniform values."""
    weights = layer.weight.data
    # xavier_uniform_ fills the tensor it is given in place, so no
    # re-assignment of the result is needed.
    torch.nn.init.xavier_uniform_(weights)
def init_filters(layer, train):
    """Fill ``layer.weight`` with kernels drawn at random from ``train``.

    ``weight.shape[0] * weight.shape[1]`` kernels are sampled with
    replacement (via ``np.random.randint``) along the first axis of ``train``
    and reshaped to the shape of ``layer.weight``.

    Args:
        layer: module with a ``weight`` parameter of at least two dimensions.
        train: NumPy array whose first axis indexes kernels; the selected
            kernels must contain exactly as many values as ``layer.weight``.
    """
    weight = layer.weight
    n_filters = weight.shape[0] * weight.shape[1]
    inds = np.random.randint(0, train.shape[0], size=n_filters)
    new_weight = train[inds].reshape(tuple(weight.shape))
    # Keep the new values on the layer's own device and dtype instead of
    # forcing the default CUDA device; for a layer that already lives on the
    # GPU this is the same placement as before, and CPU layers now work too.
    layer.weight.data = torch.tensor(new_weight, dtype=weight.dtype, device=weight.device)
def init_vae(layer, vae):
    """Overwrite ``layer.weight`` with kernels produced by ``vae.generate``.

    Requests ``weight.shape[0] * weight.shape[1]`` kernels from ``vae`` and
    views the result in the shape of ``layer.weight``.
    """
    target_shape = layer.weight.shape
    n_kernels = target_shape[0] * target_shape[1]
    generated = vae.generate(n=n_kernels)
    layer.weight.data = generated.view(target_shape)

In [None]:
# Full MNIST train/test loaders for the initialization experiments; both
# iterate in dataset order with batches of 128.
train_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='./data',
                           train=True,
                           download=True,
                           transform=transforms.ToTensor()),
    batch_size=128,
)

test_loader = torch.utils.data.DataLoader(
    dataset=datasets.MNIST(root='./data',
                           train=False,
                           transform=transforms.ToTensor()),
    batch_size=128,
)

In [None]:
def train_and_get_accs():
    """Train the global ``net`` for up to one epoch, testing after every batch.

    Reads the notebook-level ``net``, ``optimizer``, ``train_loader`` and
    ``test_loader``. After each training batch the whole test loader is
    evaluated; training stops as soon as the test accuracy exceeds 95%.

    Returns:
        list of test accuracies in percent, one per training batch processed
        (the entry that crossed 95% is included).
    """
    accs = []
    n_test = len(test_loader.dataset)
    for data, target in train_loader:
        net.train()
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        output = net(data)
        loss = F.cross_entropy(output, target)
        loss.backward()
        optimizer.step()

        net.eval()
        test_loss = 0
        correct = 0
        with torch.no_grad():
            for test_data, test_target in test_loader:
                test_data, test_target = test_data.cuda(), test_target.cuda()
                output = net(test_data)
                # Sum the per-sample losses so that dividing by the dataset
                # size below yields the true average; the default 'mean'
                # reduction would be divided by the dataset size a second time.
                test_loss += F.cross_entropy(output, test_target, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(test_target.view_as(pred)).sum().item()
        test_loss /= n_test
        accuracy = 100. * correct / n_test
        accs.append(accuracy)

        # Stop once the target accuracy is reached (no report for that batch).
        if accuracy > 95:
            break

        print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
            test_loss, correct, n_test,
            accuracy))

    return accs

In [None]:
accs_xavier = []

for _ in range(3):
    net = SmartInitializedNet()
    # Move the network to the GPU before building the optimizer, as
    # torch.optim recommends, so the optimizer references on-device parameters.
    net.cuda()
    optimizer = optim.Adam(net.parameters())

    layers = list(net.children())
    init_xavier(layers[0])
    init_xavier(layers[1])

    accs_xavier.append(train_and_get_accs())

# Runs stop at different batches; cut all of them to the shortest run so they
# can be averaged element-wise. The length is taken from the list itself
# rather than from a second hard-coded run count.
min_len = min(len(run) for run in accs_xavier)
accs_xavier = [run[:min_len] for run in accs_xavier]

In [None]:
accs_filters = []

for _ in range(3):
    net = SmartInitializedNet()
    # Move the network to the GPU before building the optimizer, as
    # torch.optim recommends, so the optimizer references on-device parameters.
    net.cuda()
    optimizer = optim.Adam(net.parameters())

    layers = list(net.children())
    init_filters(layers[0], train_7x7)
    init_filters(layers[1], train_5x5)

    accs_filters.append(train_and_get_accs())

# Runs stop at different batches; cut all of them to the shortest run so they
# can be averaged element-wise. The length is taken from the list itself
# rather than from a second hard-coded run count.
min_len = min(len(run) for run in accs_filters)
accs_filters = [run[:min_len] for run in accs_filters]

In [None]:
accs_dwp = []

for _ in range(3):
    net = SmartInitializedNet()
    # Move the network to the GPU before building the optimizer, as
    # torch.optim recommends, so the optimizer references on-device parameters.
    net.cuda()
    optimizer = optim.Adam(net.parameters())

    layers = list(net.children())
    init_vae(layers[0], vae7x7)
    init_vae(layers[1], vae5x5)

    accs_dwp.append(train_and_get_accs())

# Runs stop at different batches; cut all of them to the shortest run so they
# can be averaged element-wise. The length is taken from the list itself
# rather than from a second hard-coded run count.
min_len = min(len(run) for run in accs_dwp)
accs_dwp = [run[:min_len] for run in accs_dwp]

In [None]:
# Mean test-accuracy curve (averaged over the runs) for each initialization.
fig, ax = plt.subplots(figsize=[12,7])
ax.set_title('convergence for different initializations. Averaged over multiple runs')

curves = (
    ('xavier', accs_xavier),
    ('filters', accs_filters),
    ('dwp', accs_dwp),
)
for curve_label, runs in curves:
    ax.plot(np.array(runs).mean(0), label=curve_label)

ax.set_xlabel('batch id')
ax.set_ylabel('test accuracy')
ax.legend()
ax.grid()
plt.show()

In [None]:
# Save the current `fig` (the convergence figure from the previous cell when
# cells are run in order) as a 300-dpi PNG.
fig.savefig('figures/init.png', dpi=300)