In [2]:
import os
import sys
sys.path.append('../')
import tqdm
import math
import inspect

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import datasets, transforms

from adabelief_pytorch import AdaBelief

from wrapper import Variational_Flipout, Variational_LRT

os.environ["CUDA_VISIBLE_DEVICES"]="3"

In [2]:
transform=transforms.Compose([
    transforms.ToTensor(),
    #transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])
train_set = datasets.CIFAR10(
    '~/DATA', train=True, download=True,
    transform=transform)
test_set = datasets.CIFAR10(
    '~/DATA', train=False,
    transform=transform)
train_loader = DataLoader(train_set, batch_size = 128, shuffle=True, num_workers = 16)
test_loader = DataLoader(test_set, batch_size = 128, shuffle=False, num_workers = 16)

def mul_sign(x) -> torch.Tensor:
    #Best performance on several experiments
    return x.mul(torch.empty(x.shape, device = x.device).uniform_(-1,1).sign())

Files already downloaded and verified


In [4]:
class net(nn.Module):
    def __init__(self):
        super(net, self).__init__()
        self.b1 = nn.BatchNorm2d(3)
        self.l1 = nn.Conv2d(3, 64, kernel_size=3)
        self.p1 = nn.MaxPool2d((2,2))
        self.b2 = nn.BatchNorm2d(64)
        self.l2 = nn.Conv2d(64, 64, kernel_size=3)
        self.p2 = nn.MaxPool2d((2,2))
        self.b3 = nn.BatchNorm2d(64)
        self.l3 = nn.Conv2d(64, 64, kernel_size=3)
        
        self.flatten = nn.Flatten()
        
        self.b4 = nn.BatchNorm1d(1024)
        self.l4 = nn.Linear(1024, 64)
        self.b5 = nn.BatchNorm1d(64)
        self.l5 = nn.Linear(64, 10, bias = True)
        
    def forward(self, x):
        x = self.b1(x)
        x = self.l1(x)
        x = F.silu(x)
        x = self.p1(x)
        x = self.b2(x)
        x = self.l2(x)
        x = F.silu(x)
        x = self.p2(x)
        x = self.b3(x)
        x = self.l3(x)
        x = F.silu(x)
        
        x = self.flatten(x)
        
        x = self.b4(x)
        x = self.l4(x)
        x = F.silu(x)
        x = self.b5(x)
        x = self.l5(x)
                
        return x
    
    def kld(self):
        return 0.0
    
class net_Flipout(nn.Module):
    def __init__(self):
        super(net_Flipout, self).__init__()
        self.b1 = nn.BatchNorm2d(3)
        self.l1 = Variational_Flipout(nn.Conv2d(3, 64, kernel_size=3))
        self.p1 = nn.MaxPool2d((2,2))
        self.b2 = nn.BatchNorm2d(64)
        self.l2 = Variational_Flipout(nn.Conv2d(64, 64, kernel_size=3))
        self.p2 = nn.MaxPool2d((2,2))
        self.b3 = nn.BatchNorm2d(64)
        self.l3 = Variational_Flipout(nn.Conv2d(64, 64, kernel_size=3))
        
        self.flatten = nn.Flatten()
        
        self.b4 = nn.BatchNorm1d(1024)
        self.l4 = Variational_Flipout(nn.Linear(1024, 64))
        self.b5 = nn.BatchNorm1d(64)
        self.l5 = Variational_Flipout(nn.Linear(64, 10, bias = True))
        
    def forward(self, x):
        x = self.b1(x)
        x = self.l1(x)
        x = F.silu(x)
        x = self.p1(x)
        x = self.b2(x)
        x = self.l2(x)
        x = F.silu(x)
        x = self.p2(x)
        x = self.b3(x)
        x = self.l3(x)
        x = F.silu(x)
        
        x = self.flatten(x)
        
        x = self.b4(x)
        x = self.l4(x)
        x = F.silu(x)
        x = self.b5(x)
        x = self.l5(x)
                
        return x
    
    def kld(self):
        sum_kl = 0.0
        for module in self.modules():
            if isinstance(module, Variational_Flipout):
                sum_kl += module.kld()
        return sum_kl
    
class net_LRT(nn.Module):
    def __init__(self):
        super(net_LRT, self).__init__()
        self.b1 = nn.BatchNorm2d(3)
        self.l1 = Variational_LRT(nn.Conv2d(3, 64, kernel_size=3))
        self.p1 = nn.MaxPool2d((2,2))
        self.b2 = nn.BatchNorm2d(64)
        self.l2 = Variational_LRT(nn.Conv2d(64, 64, kernel_size=3))
        self.p2 = nn.MaxPool2d((2,2))
        self.b3 = nn.BatchNorm2d(64)
        self.l3 = Variational_LRT(nn.Conv2d(64, 64, kernel_size=3))
        
        self.flatten = nn.Flatten()
        
        self.b4 = nn.BatchNorm1d(1024)
        self.l4 = Variational_LRT(nn.Linear(1024, 64))
        self.b5 = nn.BatchNorm1d(64)
        self.l5 = Variational_LRT(nn.Linear(64, 10, bias = True))
        
    def forward(self, x):
        x = self.b1(x)
        x = self.l1(x)
        x = F.silu(x)
        x = self.p1(x)
        x = self.b2(x)
        x = self.l2(x)
        x = F.silu(x)
        x = self.p2(x)
        x = self.b3(x)
        x = self.l3(x)
        x = F.silu(x)
        
        x = self.flatten(x)
        
        x = self.b4(x)
        x = self.l4(x)
        x = F.silu(x)
        x = self.b5(x)
        x = self.l5(x)
                
        return x
    
    def kld(self):
        sum_kl = 0.0
        for module in self.modules():
            if isinstance(module, Variational_LRT):
                sum_kl += module.kld()
        return sum_kl

In [23]:
model = net().cuda()
num_epochs = 200
lr = 1e-2
criterion = nn.CrossEntropyLoss()
optimizer = AdaBelief(
    model.parameters(),
    lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decouple = True, rectify = False,
    print_change_log = False,
)

for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0
    
    model.train()
    for ind, (x, y) in enumerate(train_loader):
        y = y.cuda()
        optimizer.zero_grad()
        pred = model(x.cuda())
        loss = criterion(pred, y)
        kld = model.kld()
        #loss.backward()
        (loss + 0.1*kld).backward()
        #break
        optimizer.step()
        train_loss += loss.item()
        train_acc += (pred.argmax(dim = -1) == y).float().sum()
    train_loss /= len(train_set)
    train_acc /= len(train_set)

    model.eval()
    with torch.no_grad():
        for ind, (x, y) in enumerate(test_loader):
            y = y.cuda()
            pred = model(x.cuda())
            loss = criterion(pred, y)
            test_loss += loss.item()
            test_acc += (pred.argmax(dim = -1) == y).float().sum()
        test_loss /= len(test_set)
        test_acc /= len(test_set)
        
        kld = model.kld()
        
    print(
        '[%d] loss: %.6f/%.6f, acc: %.6f/%.6f, kld: %.6f' %
        (epoch + 1, train_loss, test_loss, train_acc, test_acc, kld)
    )

[1] loss: 0.010562/0.008737, acc: 0.507900/0.604900, kld: 0.000000
[2] loss: 0.007412/0.007241, acc: 0.664460/0.679000, kld: 0.000000
[3] loss: 0.006138/0.006675, acc: 0.724960/0.707000, kld: 0.000000
[4] loss: 0.005308/0.006092, acc: 0.761660/0.733300, kld: 0.000000
[5] loss: 0.004744/0.005840, acc: 0.787560/0.747200, kld: 0.000000
[6] loss: 0.004274/0.005975, acc: 0.806300/0.746000, kld: 0.000000
[7] loss: 0.003853/0.005820, acc: 0.828420/0.757800, kld: 0.000000
[8] loss: 0.003507/0.006040, acc: 0.843000/0.748800, kld: 0.000000
[9] loss: 0.003193/0.006187, acc: 0.854360/0.751100, kld: 0.000000
[10] loss: 0.002849/0.006514, acc: 0.869460/0.749300, kld: 0.000000
[11] loss: 0.002678/0.006465, acc: 0.876800/0.752000, kld: 0.000000
[12] loss: 0.002365/0.007127, acc: 0.891540/0.736700, kld: 0.000000
[13] loss: 0.002169/0.007379, acc: 0.899960/0.743500, kld: 0.000000
[14] loss: 0.002019/0.007738, acc: 0.907500/0.739200, kld: 0.000000
[15] loss: 0.001904/0.007774, acc: 0.911720/0.747600, kld

KeyboardInterrupt: 

In [18]:
model = net_Flipout().cuda()
num_epochs = 200
lr = 1e-2
criterion = nn.CrossEntropyLoss()
optimizer = AdaBelief(
    model.parameters(),
    lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decouple = True, rectify=False,
    print_change_log = False,
)

for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0
    
    model.train()
    for ind, (x, y) in enumerate(train_loader):
        y = y.cuda()
        optimizer.zero_grad()
        pred = model(x.cuda())
        loss = criterion(pred, y)
        kld = model.kld()
        #loss.backward()
        (loss + kld).backward()
        #break
        optimizer.step()
        train_loss += loss.item()
        train_acc += (pred.argmax(dim = -1) == y).float().sum()
    train_loss /= len(train_set)
    train_acc /= len(train_set)
    
    model.eval()
    with torch.no_grad():
        for ind, (x, y) in enumerate(test_loader):
            y = y.cuda()
            pred = model(x.cuda())
            loss = criterion(pred, y)
            test_loss += loss.item()
            test_acc += (pred.argmax(dim = -1) == y).float().sum()
        test_loss /= len(test_set)
        test_acc /= len(test_set)
        
        kld = model.kld()
        
    print(
        '[%d] loss: %.6f/%.6f, acc: %.6f/%.6f, kld: %.6f' %
        (epoch + 1, train_loss, test_loss, train_acc, test_acc, kld)
    )

[1] loss: 0.010948/0.009657, acc: 0.489980/0.561900, kld: 7.432627
[2] loss: 0.008287/0.008162, acc: 0.625160/0.629200, kld: 6.963572
[3] loss: 0.007275/0.007871, acc: 0.671960/0.647400, kld: 6.588394
[4] loss: 0.006841/0.007433, acc: 0.692320/0.675400, kld: 6.341869
[5] loss: 0.006627/0.007690, acc: 0.704660/0.659800, kld: 6.215690
[6] loss: 0.006416/0.007749, acc: 0.714460/0.660000, kld: 6.158609
[7] loss: 0.006246/0.007465, acc: 0.722040/0.675400, kld: 6.129203
[8] loss: 0.006022/0.007735, acc: 0.733040/0.667700, kld: 6.107191
[9] loss: 0.005899/0.007460, acc: 0.740680/0.682300, kld: 6.091041
[10] loss: 0.005714/0.007449, acc: 0.747940/0.688500, kld: 6.074485
[11] loss: 0.005553/0.007512, acc: 0.755380/0.681100, kld: 6.058074
[12] loss: 0.005419/0.007310, acc: 0.761800/0.689100, kld: 6.042441
[13] loss: 0.005227/0.007118, acc: 0.769900/0.698000, kld: 6.026711
[14] loss: 0.005134/0.007308, acc: 0.773940/0.688800, kld: 6.011161
[15] loss: 0.004997/0.007223, acc: 0.779460/0.698800, kld

KeyboardInterrupt: 

In [8]:
model = net_LRT().cuda()
num_epochs = 200
lr = 1e-2
criterion = nn.CrossEntropyLoss()
optimizer = AdaBelief(
    model.parameters(),
    lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decouple = True, rectify=False,
    print_change_log = False,
)

for epoch in range(num_epochs):
    train_loss = 0.0
    train_acc = 0.0
    test_loss = 0.0
    test_acc = 0.0
    
    model.train()
    for ind, (x, y) in enumerate(train_loader):
        y = y.cuda()
        optimizer.zero_grad()
        pred = model(x.cuda())
        loss = criterion(pred, y)
        kld = model.kld()
        #loss.backward()
        (loss + kld).backward()
        optimizer.step()
        train_loss += loss.item()
        train_acc += (pred.argmax(dim = -1) == y).float().sum()
    train_loss /= len(train_set)
    train_acc /= len(train_set)
    
    model.eval()
    with torch.no_grad():
        for ind, (x, y) in enumerate(test_loader):
            y = y.cuda()
            pred = model(x.cuda())
            loss = criterion(pred, y)
            test_loss += loss.item()
            test_acc += (pred.argmax(dim = -1) == y).float().sum()
        test_loss /= len(test_set)
        test_acc /= len(test_set)
        
        kld = model.kld()
        
    print(
        '[%d] loss: %.6f/%.6f, acc: %.6f/%.6f, kld: %.6f' %
        (epoch + 1, train_loss, test_loss, train_acc, test_acc, kld)
    )

Weight decoupling enabled in AdaBelief
[1] loss: 0.010528/0.009041, acc: 0.512180/0.596900, kld: 10.473148
[2] loss: 0.007785/0.007633, acc: 0.647540/0.657500, kld: 9.897487
[3] loss: 0.006807/0.007258, acc: 0.693140/0.681600, kld: 9.409855
[4] loss: 0.006280/0.007739, acc: 0.718000/0.665600, kld: 9.050241
[5] loss: 0.005991/0.007049, acc: 0.731180/0.693300, kld: 8.816916
[6] loss: 0.005771/0.007008, acc: 0.744780/0.701400, kld: 8.663939
[7] loss: 0.005614/0.007066, acc: 0.751280/0.704900, kld: 8.552006
[8] loss: 0.005427/0.007036, acc: 0.758900/0.703700, kld: 8.466328
[9] loss: 0.005167/0.007255, acc: 0.772160/0.695200, kld: 8.397057
[10] loss: 0.005131/0.006932, acc: 0.772020/0.705600, kld: 8.344603
[11] loss: 0.004941/0.007595, acc: 0.780800/0.692800, kld: 8.302938
[12] loss: 0.004853/0.007040, acc: 0.786220/0.713200, kld: 8.269974
[13] loss: 0.004674/0.007468, acc: 0.793720/0.691400, kld: 8.242685
[14] loss: 0.004578/0.007124, acc: 0.799560/0.705700, kld: 8.219312


KeyboardInterrupt: 