In [1]:
from PIL import Image
from torch.utils.data.dataset import Dataset
from scipy.misc import imread
import numpy as np
import torch
import torch.optim as optim
from core import networks

import torchvision
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

import copy

In [2]:
#Hyperparameters for training
mini_batch_size = 128
lambda_ = 1e-8

In [3]:
#dataload (MNIST)
train_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist-data/', train=True, download=True,
                       transform=transforms.Compose([transforms.ToTensor(),])),
        batch_size=mini_batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
        datasets.MNIST('mnist-data/', train=False, transform=transforms.Compose([transforms.ToTensor(),])
                       ),
        batch_size=mini_batch_size, shuffle=True)

In [4]:
#custom regularization

import torch.nn as nn
from core.networks import BayesianNetwork
def custom_regularization(saver_net, trainer_net,mini_batch_size, lambda_, loss=None):
    
    mean_reg = 0
    sigma_reg = 0
    
    #net1, net2에서 각 레이어에 있는 mean, sigma를 이용하여 regularization 구현

    #각 모델에 module 접근
    for saver, trainer in zip(saver_net.modules(),trainer_net.modules()):
        
        #만약 BayesianNetwork 이면
        if isinstance(saver,BayesianNetwork) and  isinstance(trainer,BayesianNetwork):
            
            #Network 내부의 layer에 순차적으로 접근
            for saver_layer, trainer_layer in zip(saver.layer_arr, trainer.layer_arr):
            
            # calculate mean regularization
            
                mean_reg += lambda_*(torch.div(trainer_layer.weight_mu, saver_layer.weight_sigma)-torch.div(trainer_layer.weight_mu, trainer_layer.weight_sigma)).norm(2)
                
                
            # calculate sigma_reg regularization
            
                sigma_reg += torch.sum(torch.div(trainer_layer.weight_sigma, saver_layer.weight_sigma) - torch.log(torch.div(trainer_layer.weight_sigma, saver_layer.weight_sigma)))
            
            sigma_reg = sigma_reg/(mini_batch_size*2)
            mean_reg = mean_reg/(mini_batch_size*2)
                
#             print (mean_reg, sigma_reg) # regularization value 확인
    print ('loss')
    print (loss)
    print ()
#     print ('mean_reg')
#     print (mean_reg)
#     print ()
#     print ('sigma_reg')
#     print (sigma_reg)
#     print ()
    
    loss = loss + mean_reg + sigma_reg 

    return loss

In [5]:
def train(saver_net,trainer_net, optimizer, epoch, mini_batch_size, lambda_, DEVICE):
    trainer_net.train()
    
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(DEVICE), target.to(DEVICE)
        trainer_net.zero_grad()
        loss = trainer_net.sample_elbo(data, target, mini_batch_size, DEVICE) #홍준 코딩
        loss = custom_regularization(saver_net, trainer_net, mini_batch_size, lambda_, loss)
        loss.backward()
        optimizer.step()


In [6]:
# GPU 설정
device_num = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#Model Initialization
#Saver_Net : mu = 0, sigma = log(1+exp(1))
#trainer_ner : mu = [-5,5], sigma = log(1+exp([-5,+5]))
saver_net = networks.BayesianNetwork(init_type = 'zero', DEVICE = device_num).to(device_num)
trainer_net = networks.BayesianNetwork(init_type = 'random', DEVICE = device_num).to(device_num)

optimizer = optim.Adam(saver_net.parameters())
optimizer = optim.Adam(trainer_net.parameters())

for epoch in range(10):

    #0. trainet_net variance init
    
    trainer_net.variance_init() #trainer net의 variance크게 init
    trainer_net = trainer_net.to(device_num)
    
    #1. trainet_net training 하는데 regularization을 위해서 saver_net의 정보 이용
    
    train(saver_net, trainer_net, optimizer, epoch, mini_batch_size, lambda_, device_num)

    #2. 1 batch가 끝나면 saver_net에 trainet_net을 복사 (weight = mean, sigma)
    
    saver_net = copy.deepcopy(trainer_net)
    



loss
tensor(1.5980e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.3020e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1001e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1140e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.4299e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1075e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.3819e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1357e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.5741e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.4960e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1365e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.6115e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.6378e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2260e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1992e+08, device='cuda:0', grad_fn=<MeanBackward

loss
tensor(1.0817e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.3206e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1867e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.4410e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2242e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.4076e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2670e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.3664e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.5156e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1460e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2696e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2581e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2778e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2082e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1821e+08, device='cuda:0', grad_fn=<MeanBackward

loss
tensor(1.1247e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(95095960., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(93213016., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0723e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.4506e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0165e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0240e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0012e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.2888e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0022e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(97755584., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1978e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1505e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(91543072., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1911e+08, device='cuda:0', grad_fn=<MeanBackward1>)


loss
tensor(1.0832e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(99774288., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0898e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(99440080., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1089e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0657e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.3242e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1608e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.1275e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0196e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0382e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(92992344., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(1.0087e+08, device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(77828264., device='cuda:0', grad_fn=<MeanBackward1>)

loss
tensor(98287632., device='cuda:0', grad_fn=<MeanBackward1>)



RuntimeError: shape '[128, 1]' is invalid for input of size 96