In [1]:
from PIL import Image
from torch.utils.data.dataset import Dataset
from scipy.misc import imread
import numpy as np
import torch
import torch.optim as optim
from core import networks

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import copy

In [2]:
#Hyperparameters for training
mini_batch_size = 128
lambda_ = 1

In [3]:
#dataload (MNIST)
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist-data/', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),])),
        batch_size=mini_batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist-data/', train=False, transform=transforms.Compose([transforms.ToTensor(),])
                       ),
        batch_size=mini_batch_size, shuffle=True)

In [4]:
#custom regularization

import torch.nn as nn
from core.networks import BayesianNetwork
def custom_regularization(saver_net, trainer_net,mini_batch_size, lambda_, loss=None):
    
    mean_reg = 0
    sigma_reg = 0
    
    #net1, net2에서 각 레이어에 있는 mean, sigma를 이용하여 regularization 구현

    #각 모델에 module 접근
    for saver, trainer in zip(saver_net.modules(),trainer_net.modules()):
        
        #만약 BayesianNetwork 이면
        if isinstance(saver,BayesianNetwork) and isinstance(trainer,BayesianNetwork):
            
            i = 0
            
            #Network 내부의 layer에 순차적으로 접근
            for saver_layer, trainer_layer in zip(saver.layer_arr, trainer.layer_arr):
            
            # calculate mean regularization

                trainer_mu = trainer_layer.weight_mu
                saver_mu = saver_layer.weight_mu
                
                trainer_sigma = torch.log1p(torch.exp(trainer_layer.weight_rho))
                saver_sigma = torch.log1p(torch.exp(saver_layer.weight_rho))
                
                
                
                #mean_reg += lambda_*(torch.div(trainer_layer.weight_mu, saver_layer.weight_rho)-torch.div(trainer_layer.weight_mu, trainer_layer.weight_rho)).norm(2)
                mean_reg += lambda_*(torch.div(trainer_mu, saver_sigma)-torch.div(saver_mu, saver_sigma)).norm(2)
    
            # calculate sigma_reg regularization
            
                #sigma_reg += torch.sum(torch.div(trainer_layer.weight_rho, saver_layer.weight_rho) - torch.log(torch.div(trainer_layer.weight_rho, saver_layer.weight_rho)))
                sigma_reg += torch.sum(torch.div(trainer_sigma*trainer_sigma, saver_sigma*saver_sigma) - torch.log(torch.div(trainer_sigma*trainer_sigma, saver_sigma*saver_sigma)))

            sigma_reg = sigma_reg/(mini_batch_size*2)
            mean_reg = mean_reg/(mini_batch_size*2)
            loss = loss/mini_batch_size
                
#             print (mean_reg, sigma_reg) # regularization value 확인

    loss = loss + mean_reg + sigma_reg 

    return loss

In [5]:
def train(saver_net,trainer_net, optimizer, epoch, mini_batch_size, lambda_, DEVICE):
    trainer_net.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(DEVICE), target.to(DEVICE)
        if data.shape[0] == mini_batch_size:
            #trainer_net.zero_grad()
            loss = trainer_net.sample_elbo(data, target, mini_batch_size, DEVICE)
            #loss = custom_regularization(saver_net, trainer_net, mini_batch_size, lambda_, loss)
            loss.backward()
            #print(trainer_net.l2.weight.rho.grad)
            #print(trainer_net.l.weight.rho.grad)
            #print(trainer_net.l1.weight.rho.grad.norm(2))
            
            optimizer.step()


In [6]:
import torch.nn.functional as F
def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data, sample=True)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True) # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)

    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

In [7]:
# GPU 설정
device_num = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Model Initialization
#Saver_Net : mu = 0, sigma = log(1+exp(1))
#trainer_ner : mu = [-5,5], sigma = log(1+exp([+1,+1]))
saver_net = networks.BayesianNetwork(init_type = 'zero', DEVICE = device_num).to(device_num)
trainer_net = networks.BayesianNetwork(init_type = 'random', DEVICE = device_num).to(device_num)

optimizer = optim.Adam(saver_net.parameters())
optimizer = optim.Adam(trainer_net.parameters())

for epoch in range(10):

    #0. trainet_net variance init
    
    trainer_net.variance_init() #trainer net의 variance크게 init
    trainer_net = trainer_net.to(device_num)
    
    #1. trainer_net training 하는데 regularization을 위해서 saver_net의 정보 이용
    
    train(saver_net, trainer_net, optimizer, epoch, mini_batch_size, lambda_, device_num)

    #2. 1 batch가 끝나면 saver_net에 trainet_net을 복사 (weight = mean, sigma)
    
    saver_net = copy.deepcopy(trainer_net)
    
    test(trainer_net, device_num, test_loader)
    




Test set: Average loss: 119.6466, Accuracy: 8734/10000 (87%)


Test set: Average loss: 86.9003, Accuracy: 9078/10000 (91%)


Test set: Average loss: 53.7971, Accuracy: 9173/10000 (92%)


Test set: Average loss: 33.2610, Accuracy: 9230/10000 (92%)


Test set: Average loss: 22.8122, Accuracy: 9127/10000 (91%)


Test set: Average loss: 17.7324, Accuracy: 8871/10000 (89%)


Test set: Average loss: 16.2816, Accuracy: 8564/10000 (86%)


Test set: Average loss: 14.3006, Accuracy: 8375/10000 (84%)


Test set: Average loss: 12.2172, Accuracy: 8222/10000 (82%)


Test set: Average loss: 11.0288, Accuracy: 8402/10000 (84%)

