In [1]:
import os
import sys
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

from adabelief_pytorch import AdaBelief

from wrapper import Variational_Flipout, Variational_LRT

In [167]:
transform=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])
train_set = datasets.MNIST(
    '~/data', train=True, download=True,
    transform=transform)
test_set = datasets.MNIST(
    '~/data', train=False,
    transform=transform)
train_loader = DataLoader(train_set, batch_size = 128, shuffle=True)
test_loader = DataLoader(test_set, batch_size = 128, shuffle=False)

def mul_sign(x) -> torch.Tensor:
    #Best performance on several experiments
    return x.mul(torch.empty(x.shape, device = x.device).uniform_(-1,1).sign())

In [446]:
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.b1 = nn.BatchNorm2d(1)
        self.l1 = Variational_Flipout(nn.Conv2d(1, 64, kernel_size=3))
        self.p1 = nn.MaxPool2d((2,2))
        self.b2 = nn.BatchNorm2d(64)
        self.l2 = Variational_Flipout(nn.Conv2d(64, 64, kernel_size=3))
        self.p2 = nn.MaxPool2d((2,2))
        self.b3 = nn.BatchNorm2d(64)
        self.l3 = Variational_Flipout(nn.Conv2d(64, 64, kernel_size=3))
        
        self.b4 = nn.BatchNorm1d(576)
        self.l4 = Variational_Flipout(nn.Linear(576, 64))
        self.b5 = nn.BatchNorm1d(64)
        self.l5 = Variational_Flipout(nn.Linear(64, 10))
        
    def forward(self, x):
        x = self.b1(x)
        x = self.l1(x)
        x = F.silu(x)
        x = self.p1(x)
        x = self.b2(x)
        x = self.l2(x)
        x = F.silu(x)
        x = self.p2(x)
        x = self.b3(x)
        x = self.l3(x)
        x = F.silu(x)
        
        x = x.view(x.shape[0], -1)
        
        x = self.b4(x)
        x = self.l4(x)
        x = F.silu(x)
        x = self.b5(x)
        x = self.l5(x)
                
        return x
    
    def kld(self):
        sum_kl = 0.0
        for module in self.modules():
            if isinstance(module, Variational_Flipout):
                sum_kl += module.kld()
        return sum_kl

In [444]:
class Variational_LRT(nn.Module):
    def __init__(self, module: nn.Module, weight_multiplcative_variance = True):
        super(Variational_LRT, self).__init__()
        """
        Wrapper class for existing torch modules.
        Use multiplicative noise in weight space to make layer stochastic.
        """
        
        #assert True in [isinstance(module, m) for m in registered_modules]

        self.weight_mean = module
        self.weight_logvar = nn.Parameter(self.weight_mean.weight.data.clone().detach().fill_(0))
        self.weight_multiplcative_variance = weight_multiplcative_variance
    
    def forward(self, x) -> torch.Tensor:
        weight = self.weight_mean.weight    
        bias = self.weight_mean.bias
        self.weight_mean.bias = None
        
        #Assume standard normal prior of weight and calculate output variance
        with torch.no_grad():
            self.weight_mean.weight.data = torch.ones(
                self.weight_mean.weight.data.shape,
                device = self.weight_mean.weight.data.device,
                requires_grad = False,
            )
            var_prior = self.weight_mean(x.pow(2))
        
        #Calculate LRT variance of layer output
        if self.weight_multiplcative_variance:
            self.weight_mean.weight.data = weight.pow(2) * self.weight_logvar.exp()
        else:
            self.weight_mean.weight.data = self.weight_logvar.exp()
        var = self.weight_mean(x.pow(2))
        
        self.weight_mean.weight.data = weight
        self.weight_mean.bias = bias
        mean = self.weight_mean(x)
        
        self._kld = (mean.pow(2) - var + (var - var_prior).exp() - 1).mean().div(2)
        return mean + var.sqrt() * torch.randn(var.shape, device = var.device, requires_grad = False)
    
    def kld(self) -> torch.Tensor:
        return self._kld

In [421]:
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.b1 = nn.BatchNorm2d(1)
        self.l1 = Variational_LRT(nn.Conv2d(1, 64, kernel_size=3))
        self.p1 = nn.MaxPool2d((2,2))
        self.b2 = nn.BatchNorm2d(64)
        self.l2 = Variational_LRT(nn.Conv2d(64, 64, kernel_size=3))
        self.p2 = nn.MaxPool2d((2,2))
        self.b3 = nn.BatchNorm2d(64)
        self.l3 = Variational_LRT(nn.Conv2d(64, 64, kernel_size=3))
        
        self.b4 = nn.BatchNorm1d(576)
        self.l4 = Variational_LRT(nn.Linear(576, 64))
        self.b5 = nn.BatchNorm1d(64)
        self.l5 = Variational_LRT(nn.Linear(64, 10))
        
    def forward(self, x):
        x = self.b1(x)
        x = self.l1(x)
        x = F.silu(x)
        x = self.p1(x)
        x = self.b2(x)
        x = self.l2(x)
        x = F.silu(x)
        x = self.p2(x)
        x = self.b3(x)
        x = self.l3(x)
        x = F.silu(x)
        
        x = x.view(x.shape[0], -1)
        
        x = self.b4(x)
        x = self.l4(x)
        x = F.silu(x)
        x = self.b5(x)
        x = self.l5(x)
                
        return x
    
    def kld(self):
        sum_kl = 0.0
        for module in self.modules():
            if isinstance(module, Variational_LRT):
                sum_kl += module.kld()
        return sum_kl

In [447]:
model = net().cuda()
num_epochs = 200
lr = 3e-4
criterion = nn.CrossEntropyLoss()
optimizer = AdaBelief(
    model.parameters(),
    lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decouple = True, rectify=False, 
)

[31mPlease check your arguments if you have upgraded adabelief-pytorch from version 0.0.5.
[31mModifications to default arguments:
[31m                           eps  weight_decouple    rectify
-----------------------  -----  -----------------  ---------
adabelief-pytorch=0.0.5  1e-08  False              False
>=0.1.0 (Current 0.2.0)  1e-16  True               True
[34mSGD better than Adam (e.g. CNN for Image Classification)    Adam better than SGD (e.g. Transformer, GAN)
----------------------------------------------------------  ----------------------------------------------
Recommended eps = 1e-8                                      Recommended eps = 1e-16
[34mFor a complete table of recommended hyperparameters, see
[34mhttps://github.com/juntang-zhuang/Adabelief-Optimizer
[32mYou can disable the log message by setting "print_change_log = False", though it is recommended to keep as a reminder.
[0m
Weight decoupling enabled in AdaBelief


In [None]:
print_interval = 50
for epoch in range(num_epochs):
    running_loss = 0.0
    total_loss = 0.0
    for ind, (x, y) in enumerate(train_loader):
        optimizer.zero_grad()
        loss = criterion(model(x.cuda()), y.cuda())
        kld = model.kld()
        #loss.backward()
        (loss + 0.1*kld).backward()
        optimizer.step()
        running_loss += loss.item()
        total_loss += loss.item()
        if ind % print_interval == print_interval - 1:
            print('[%d, %5d] loss: %.6f, kld: %.6f, acc: %.6f' %
                  (epoch + 1, ind + 1, running_loss / print_interval, kld))
            running_loss = 0.0

[1,    50] loss: 1.876510, kld: 38.720524
[1,   100] loss: 1.120174, kld: 32.141052
[1,   150] loss: 0.824844, kld: 25.933521
[1,   200] loss: 0.684259, kld: 20.661304
[1,   250] loss: 0.537141, kld: 16.437626
[1,   300] loss: 0.444920, kld: 13.141074
[1,   350] loss: 0.380234, kld: 10.645747
[1,   400] loss: 0.334852, kld: 8.800092
[1,   450] loss: 0.288808, kld: 7.420666
[2,    50] loss: 0.261173, kld: 6.132567
[2,   100] loss: 0.242622, kld: 5.467493
[2,   150] loss: 0.236099, kld: 5.007080
[2,   200] loss: 0.227049, kld: 4.658742
[2,   250] loss: 0.209621, kld: 4.387023
[2,   300] loss: 0.217690, kld: 4.200820
[2,   350] loss: 0.198898, kld: 4.047126
[2,   400] loss: 0.198934, kld: 3.934625
[2,   450] loss: 0.195539, kld: 3.853153
[3,    50] loss: 0.186454, kld: 3.733509
[3,   100] loss: 0.189648, kld: 3.676773
[3,   150] loss: 0.182483, kld: 3.622032
[3,   200] loss: 0.193903, kld: 3.625362
[3,   250] loss: 0.188844, kld: 3.593967
[3,   300] loss: 0.183582, kld: 3.559656
[3,   350

In [None]:
model.l1.weight_logvar

In [None]:
for epoch in range(num_epochs):
    running_loss = 0.0
    total_loss = 0.0
    for i, data in enumerate(train_loader):
        optimizer.zero_grad()
        model_state_dict = copy.deepcopy(net_fn.state_dict())
        #loss = - log_posterior_fn(model_state_dict, data)
        loss = - log_posterior_fn(net_fn, model_state_dict, data)
        loss.backward()
        optimizer.step()
        scheduler1.step()
        running_loss += loss.item()
        total_loss += loss.item()
        if i % 100 == 99:    # print every 100 mini-batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 100))
            running_loss = 0.0
    model_state_dict = copy.deepcopy(net_fn.state_dict())
    test_acc, all_test_probs = evaluate_fn(test_loader, model_state_dict)
    scheduler2.step(test_acc)
    
    print("Epoch {}".format(epoch))
    print("\tAverage loss: {}".format(total_loss / epoch_steps))
    print("\tTest accuracy: {}".format(test_acc))