In [9]:
import numpy as np
import torch
import math
import os
import pickle
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from torchvision import transforms
from pathlib import Path
import torch.nn.init as init
from torchvision.datasets import MNIST

In [94]:
# --- Paths (all under one project folder on Google Drive) ---
SAVE_DIR = '/content/drive/MyDrive/AML_project'
DATA_DIR = SAVE_DIR + '/datasets/MNIST_DATASET'
MODEL_SAVE_PATH = SAVE_DIR + '/saved_models/mnist_model.h5'     # best regular model (torch.save)
EMA_SAVE_PATH = SAVE_DIR + '/saved_models/mnist_model_ema.h5'   # best EMA model (torch.save)
PICKLE_SAVE_PATH = SAVE_DIR + '/pickles/mnist_stats.pkl'        # per-epoch metric logs
LOAD_MODEL = False  # True: resume from MODEL_SAVE_PATH instead of a fresh model

# --- Reproducibility / model size ---
SEED = 42
K = 256            # latent (bottleneck) dimension
MC_SAMPLE_SIZE = 12

# --- Optimisation ---
BATCH_SIZE = 100
LR = 1e-4          # learning rate
BETA = 1e-3        # rate-distortion trade-off weight on the KL term
BETA_DECAY = 0.999 # EMA exponential decay
EPOCHS = 200       # the paper trains for 200 epochs

In [95]:
# Seed the RNGs this notebook uses so runs are repeatable.
np.random.seed(SEED)
torch.manual_seed(SEED)

# Prefer the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [96]:
# Always pass download=True: torchvision checks for the dataset files itself and
# skips the download when they are already present. Checking only that DATA_DIR
# exists would raise "Dataset not found" when the folder is empty or incomplete.
mnist_transform = transforms.ToTensor()  # uint8 28x28 images -> float tensors in [0, 1]
train_data = MNIST(root=DATA_DIR, train=True, download=True, transform=mnist_transform)
test_data = MNIST(root=DATA_DIR, train=False, download=True, transform=mnist_transform)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/train-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=9912422.0), HTML(value='')))


Extracting /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/train-images-idx3-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/train-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=28881.0), HTML(value='')))


Extracting /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/train-labels-idx1-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/t10k-images-idx3-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=1648877.0), HTML(value='')))


Extracting /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/t10k-images-idx3-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/t10k-labels-idx1-ubyte.gz


HBox(children=(FloatProgress(value=0.0, max=4542.0), HTML(value='')))


Extracting /content/drive/MyDrive/MNIST_DATASET/MNIST/raw/t10k-labels-idx1-ubyte.gz to /content/drive/MyDrive/MNIST_DATASET/MNIST/raw



In [97]:
# Settings shared by both loaders.
loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=1)

# Training: reshuffle every epoch and drop the last partial batch.
train_loader = DataLoader(train_data, shuffle=True, drop_last=True, **loader_kwargs)
# Testing: fixed order, keep every example.
test_loader = DataLoader(test_data, shuffle=False, drop_last=False, **loader_kwargs)

In [98]:
def reparametrize(mu, std, num_sample=1):
    """
    Reparameterization trick: z = mu + eps * std, with eps ~ N(0, 1).

    Args:
        mu: mean tensor.
        std: standard-deviation tensor with the same shape as `mu`.
        num_sample: number of independent samples to draw (default 1).

    Returns:
        Tensor of shape (num_sample, *mu.shape). Gradients flow into `mu` and
        `std`; the noise `eps` is a constant.
    """
    mu = mu.expand(num_sample, *mu.size())
    std = std.expand(num_sample, *std.size())
    # Draw the noise directly on std's device/dtype rather than on the CPU followed
    # by a move to a global `device`; this also works if std lives elsewhere.
    eps = torch.randn(std.size(), dtype=std.dtype, device=std.device)
    return mu + eps * std

def _evaluate_net(net, loader):
    """
    Evaluate one network on every batch of `loader` without tracking gradients.

    Returns a dict of plain Python floats, each averaged per example:
        'class_loss'   - cross-entropy, in bits
        'info_loss'    - KL(N(mu, std^2) || N(0, I)) summed over latent dims, in bits
        'total_loss'   - class_loss + BETA * info_loss
        'izy'          - log2(10) - class_loss; log2(10) is the label entropy in
                         bits under a uniform prior over the 10 classes
        'izx'          - same value as info_loss
        'accuracy'     - fraction of examples whose arg-max prediction equals the label
        'avg_accuracy' - placeholder for MC-averaged accuracy; always 0.0 for now

    Raises:
        ValueError: if `loader` yields no examples.
    """
    n_examples = 0
    class_bits = 0.0  # cross-entropy summed over examples
    info_bits = 0.0   # KL term summed over examples and latent dims
    n_correct = 0
    # Without no_grad, the running sums would keep every batch's autograd graph alive.
    with torch.no_grad():
        for images, labels in loader:
            x = images.to(device)
            y = labels.to(device)
            (mu, std), logit = net(x)
            # reduction='sum' so that dividing by n_examples below gives a true
            # per-example mean (the default 'mean' is already averaged per batch).
            class_bits += F.cross_entropy(logit, y, reduction='sum').item() / math.log(2)
            kl = -0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2))
            info_bits += kl.sum().item() / math.log(2)
            n_correct += torch.eq(logit.argmax(dim=1), y).sum().item()
            n_examples += y.size(0)

    if n_examples == 0:
        raise ValueError('evaluate(): the test loader yielded no examples')

    class_loss = class_bits / n_examples
    info_loss = info_bits / n_examples
    return {
        'class_loss': class_loss,
        'info_loss': info_loss,
        'total_loss': class_loss + BETA * info_loss,
        'izy': math.log(10, 2) - class_loss,
        'izx': info_loss,
        'accuracy': n_correct / n_examples,
        'avg_accuracy': 0.0,  # TODO: add MC logic
    }


def _print_test_result(title, stats):
    # Print one "[TEST RESULT - ...]" section in the notebook's log format.
    print(title)
    print('e:{} IZY:{:.2f} IZX:{:.2f}'
            .format(e, stats['izy'], stats['izx']), end=' ')
    print('acc:{:.4f} avg_acc:{:.4f}'
            .format(stats['accuracy'], stats['avg_accuracy']), end=' ')
    print('err:{:.4f} avg_erra:{:.4f}'
            .format(1 - stats['accuracy'], 1 - stats['avg_accuracy']))


def _checkpoint_if_best(hist, stats, obj, path):
    # Record `stats` into `hist` and torch.save `obj` to `path` when accuracy improves.
    if hist['acc'] < stats['accuracy']:
        hist['acc'] = stats['accuracy']
        hist['avg_acc'] = stats['avg_accuracy']
        hist['class_loss'] = stats['class_loss']
        hist['info_loss'] = stats['info_loss']
        hist['total_loss'] = stats['total_loss']
        hist['epoch'] = e
        hist['iter'] = iterations
        # Create the checkpoint folder if needed so saving cannot abort training.
        os.makedirs(os.path.dirname(path) or '.', exist_ok=True)
        torch.save(obj, path)
        print('Saved model to {}'.format(path))


def evaluate(model, model_ema):
    """
    Evaluate the trained network and its EMA copy on `test_loader`.

    Logs the results, prints a summary, and checkpoints each network when it
    reaches a new best test accuracy.

    Args:
        model: the network being trained.
        model_ema: EMA_smoothning wrapper; its `.model` is the network evaluated.

    Side effects:
        Appends to the global `book_keeper`. May update the globals `history` /
        `history_ema` and write MODEL_SAVE_PATH / EMA_SAVE_PATH. Both networks are
        in eval mode while this runs and are returned to train mode afterwards,
        even if evaluation raises.

    NOTE(review): reads the notebook globals `e` and `iterations`, which are set
    by the training loop; define them before calling this elsewhere.
    """
    model.eval()
    model_ema.model.eval()
    try:
        stats = _evaluate_net(model, test_loader)
        stats_ema = _evaluate_net(model_ema.model, test_loader)

        for suffix, s in (('', stats), ('_ema', stats_ema)):
            book_keeper['test_izx' + suffix].append(s['izx'])
            book_keeper['test_izy' + suffix].append(s['izy'])
            book_keeper['test_acc' + suffix].append(s['accuracy'])
            book_keeper['test_error' + suffix].append(1 - s['accuracy'])
            book_keeper['test_class_loss' + suffix].append(s['class_loss'])
            book_keeper['test_info_loss' + suffix].append(s['info_loss'])
            book_keeper['test_total_loss' + suffix].append(s['total_loss'])

        _print_test_result('[TEST RESULT - regular]', stats)
        _print_test_result('[TEST RESULT - ema]', stats_ema)
        print()

        _checkpoint_if_best(history, stats, model, MODEL_SAVE_PATH)
        _checkpoint_if_best(history_ema, stats_ema, model_ema, EMA_SAVE_PATH)
    finally:
        model.train()
        model_ema.model.train()

In [99]:
class MNIST_IB_VAE(nn.Module):
    """
    Variational-information-bottleneck classifier following the paper's MNIST net.

    Encoder: ReLU MLP 784 -> 1024 -> 1024 -> 2k, whose output is split into the
    mean and the (pre-softplus) std of a k-dimensional Gaussian over z.
    Decoder: one linear layer k -> 10 that produces class logits.
    Only one-shot evaluation (no MC averaging); works well for beta <= 1e-3.
    """

    def __init__(self, k=K):
        super(MNIST_IB_VAE, self).__init__()
        self.k = k

        hidden = 1024
        self.encoder = nn.Sequential(
            nn.Linear(784, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, hidden),
            nn.ReLU(True),
            nn.Linear(hidden, 2 * self.k),
        )
        self.decoder = nn.Sequential(nn.Linear(self.k, 10))

        # Xavier-uniform weights (ReLU gain) and zero biases for every linear/conv layer.
        relu_gain = nn.init.calculate_gain('relu')
        for sub in self.modules():
            if not isinstance(sub, (nn.Linear, nn.Conv2d)):
                continue
            nn.init.xavier_uniform_(sub.weight, gain=relu_gain)
            nn.init.zeros_(sub.bias)

    def forward(self, x):
        """
        Encode `x`, draw one reparameterized sample of z, and classify it.

        Returns:
            ((mu, std), logit): mu and std are (batch, k); logit is (batch, 10).
        """
        flat = x.view(x.size(0), -1)  # flatten e.g. (100, 1, 28, 28) -> (100, 784)
        params = self.encoder(flat)
        mu, std_raw = params[:, :self.k], params[:, self.k:]
        # softplus ("soft ReLU") keeps std positive; the -5 shift follows the paper
        std = F.softplus(std_raw - 5, beta=1)
        z = reparametrize(mu, std)  # leading sample dimension of size 1
        logits = self.decoder(z)
        return (mu, std), logits[0]

class EMA_smoothning(object):
    """
    Exponential moving average (Polyak & Juditsky, 1992) of another network's weights.

    Keeps a separate network (`self.model`) whose state is blended toward the
    training network on every `update`. It is used as a second, smoothed network
    when evaluating.

    Args:
        model: network instance that will hold the averaged weights.
        state_dict: initial weights, loaded strictly into `model`.
        beta_decay: fraction of the running average kept at each update
            (0.999 keeps 99.9% of the old value and mixes in 0.1% of the new one).
    """

    def __init__(self, model, state_dict, beta_decay=0.999):
        self.model = model
        self.model.load_state_dict(state_dict, strict=True)
        self.beta_decay = beta_decay

    @torch.no_grad()
    def update(self, new_state_dict):
        """
        Blend `new_state_dict` into the averaged weights, in place:
        avg <- beta_decay * avg + (1 - beta_decay) * new.

        Floating-point entries are averaged. Non-floating entries (e.g. integer
        counters such as BatchNorm's num_batches_tracked) are copied verbatim,
        because averaging them and loading the result back would silently truncate.

        Raises:
            KeyError: if `new_state_dict` lacks a key of the averaged model.
        """
        # state_dict() tensors share storage with the model's parameters and
        # buffers, so in-place updates modify the model directly.
        for key, avg in self.model.state_dict().items():
            new = new_state_dict[key].to(device=avg.device, dtype=avg.dtype)
            if avg.is_floating_point():
                avg.mul_(self.beta_decay).add_(new, alpha=1 - self.beta_decay)
            else:
                avg.copy_(new)

In [100]:
if LOAD_MODEL:
    # map_location lets a checkpoint written from a GPU session be reloaded on a
    # CPU-only runtime (and vice versa).
    # NOTE(review): this unpickles a full nn.Module (as saved by evaluate); only
    # load trusted files. Recent PyTorch releases default torch.load to
    # weights_only=True, which rejects full modules - pass weights_only=False
    # there if loading fails.
    model = torch.load(MODEL_SAVE_PATH, map_location=device)
else:
    model = MNIST_IB_VAE(k=K).to(device)

# The EMA copy starts from the same weights as the trained model.
model_ema = EMA_smoothning(MNIST_IB_VAE(k=K).to(device), model.state_dict(),
                           beta_decay=BETA_DECAY)

optimizer = optim.Adam(model.parameters(), lr=LR, betas=(0.5, 0.999))
scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=0.97)


def _empty_history():
    """Best-so-far test metrics; evaluate() overwrites them when accuracy improves."""
    return {
        'acc': 0.,
        'avg_acc': 0.,
        'info_loss': 0.,
        'class_loss': 0.,
        'total_loss': 0.,
        'epoch': 0,
        'iter': 0,
    }


history = _empty_history()
history_ema = _empty_history()

# Per-epoch metric logs, one list per key; pickled to PICKLE_SAVE_PATH at the end.
keys = [
        'test_izx', 'test_izy', 'test_acc', 'test_error',
        'test_izx_ema', 'test_izy_ema', 'test_acc_ema', 'test_error_ema',
        'test_class_loss', 'test_info_loss', 'test_total_loss',
        'test_class_loss_ema', 'test_info_loss_ema', 'test_total_loss_ema',
        'train_izx', 'train_izy', 'train_acc', 'train_error', 'train_izx_ema',
        'train_izy_ema', 'train_acc_ema', 'train_error_ema',
        'train_class_loss', 'train_info_loss', 'train_total_loss'
        ]
book_keeper = {key: [] for key in keys}

iterations = 0  # global optimizer-step counter, incremented by the training loop
epoch = 0

### Train loop

In [101]:
model.train()
model_ema.model.train()

for e in range(EPOCHS):
    # Epoch accumulators are plain Python numbers (via .item()) so they never hold
    # on to autograd graphs. Losses are summed per example and averaged at epoch end.
    epoch_num = 0
    epoch_class_loss = 0.0
    epoch_info_loss = 0.0
    epoch_correct = 0

    for idx, (images, labels) in enumerate(train_loader):
        iterations += 1

        x = images.to(device)
        y = labels.to(device)
        (mu, std), logit = model(x)

        class_loss = F.cross_entropy(logit, y).div(math.log(2))  # div(log(2)) -> transfer to bits
        # KL(N(mu, std^2) || N(0, I)) summed over latent dims, averaged over the batch, in bits
        info_loss = -0.5 * (1 + 2 * std.log() - mu.pow(2) - std.pow(2)).sum(1).mean().div(math.log(2))
        total_loss = class_loss + BETA * info_loss

        # log2(10) is the entropy, in bits, of a uniform label over the 10 classes
        izy_bound = math.log(10, 2) - class_loss
        izx_bound = info_loss

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # A secondary smoothed network for comparison
        model_ema.update(model.state_dict())

        prediction = logit.argmax(dim=1)
        batch_num = y.size(0)
        batch_correct = torch.eq(prediction, y).sum().item()
        accuracy = batch_correct / batch_num
        avg_accuracy = 0.0  # TODO: change this to MC sampling

        # book keeping
        epoch_class_loss += class_loss.item() * batch_num
        epoch_info_loss += info_loss.item() * batch_num
        epoch_correct += batch_correct
        epoch_num += batch_num

        if iterations % 100 == 0:
            print('i:{} IZY:{:.2f} IZX:{:.2f}'
                    .format(idx+1, izy_bound.item(), izx_bound.item()), end=' ')
            print('acc:{:.4f} avg_acc:{:.4f}'
                    .format(accuracy, avg_accuracy), end=' ')
            print('err:{:.4f} avg_err:{:.4f}'
                    .format(1 - accuracy, 1 - avg_accuracy))

    # NOTE(review): this steps after 0-based epochs 2, 4, 6, ..., so the first decay
    # happens after three epochs; confirm the offset is intended.
    if (e % 2) == 0 and (e != 0):
        scheduler.step()

    # Epoch-level training statistics (per-example means over the whole epoch).
    train_class_loss = epoch_class_loss / epoch_num
    train_info_loss = epoch_info_loss / epoch_num
    train_accuracy = epoch_correct / epoch_num

    book_keeper['train_izx'].append(train_info_loss)
    book_keeper['train_izy'].append(math.log(10, 2) - train_class_loss)
    book_keeper['train_acc'].append(train_accuracy)
    book_keeper['train_error'].append(1 - train_accuracy)
    book_keeper['train_class_loss'].append(train_class_loss)
    book_keeper['train_info_loss'].append(train_info_loss)
    book_keeper['train_total_loss'].append(train_class_loss + BETA * train_info_loss)

    evaluate(model, model_ema)

print("----- Training complete -----")

i:100 IZY:2.76 IZX:887.28 acc:0.8700 avg_acc:0.0000 err:0.1300 avg_err:1.0000
i:200 IZY:2.64 IZX:626.83 acc:0.9000 avg_acc:0.0000 err:0.1000 avg_err:1.0000
i:300 IZY:3.02 IZX:475.53 acc:0.9300 avg_acc:0.0000 err:0.0700 avg_err:1.0000
i:400 IZY:2.97 IZX:417.49 acc:0.9100 avg_acc:0.0000 err:0.0900 avg_err:1.0000
i:500 IZY:2.59 IZX:443.82 acc:0.8800 avg_acc:0.0000 err:0.1200 avg_err:1.0000
i:600 IZY:2.97 IZX:352.20 acc:0.9600 avg_acc:0.0000 err:0.0400 avg_err:1.0000
[TEST RESULT - regular]
e:0 IZY:3.03 IZX:162.86 acc:0.9374 avg_acc:0.0000 err:0.0626 avg_erra:1.0000
[TEST RESULT - ema]
e:0 IZY:2.18 IZX:1302.78 acc:0.7727 avg_acc:0.0000 err:0.2273 avg_erra:1.0000

Saved model to /content/drive/MyDrive/AML_project/saved_models/mnist_model.h5
Saved model to /content/drive/MyDrive/AML_project/saved_models/mnist_model_ema.h5
i:100 IZY:2.98 IZX:312.33 acc:0.9100 avg_acc:0.0000 err:0.0900 avg_err:1.0000
i:200 IZY:3.08 IZX:292.63 acc:0.9400 avg_acc:0.0000 err:0.0600 avg_err:1.0000
i:300 IZY:2.76 I

In [102]:
# Create the destination folder first so the dump cannot fail with
# FileNotFoundError at the end of a long training run.
os.makedirs(os.path.dirname(PICKLE_SAVE_PATH) or '.', exist_ok=True)
with open(PICKLE_SAVE_PATH, 'wb') as f:
    pickle.dump(book_keeper, f)