In [None]:
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.init as init
import torch.nn.functional as F
import torch.utils as utils
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.datasets as dsets
from torch.utils.data import Dataset,DataLoader
from torchvision import models

import cv2

import matplotlib.pyplot as plt

import time
import copy
import random as rd
import sys
import os
os.chdir("/juhyeong/projects/2019연구학점제/")
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

#codes.py
sys.path.insert(0, '../')
import codes

device = torch.device('cuda')

In [None]:
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='../../data', train=True,
                                        download=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=256,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='../../data', train=False,
                                       download=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=256,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

In [None]:
examples = enumerate(testloader)
batch_idx, (example_data, example_targets) = next(examples)

fig = plt.figure()
for i in range(6):
    plt.subplot(2,3,i+1)
    plt.tight_layout()
    plt.imshow((example_data[i].transpose(0, 1).transpose(1, 2) + 1) / 2)
    plt.title("Ground Truth: {}".format(example_targets[i]))
    plt.xticks([])
    plt.yticks([])

In [None]:
class flatten(nn.Module):
    def __init__(self):
        super(flatten, self).__init__()

    def forward(self, x):
        shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, shape)

In [10]:
class BNN(nn.Module):
    def __init__(self):
        super(BNN, self).__init__()
        self.means = codes.resnet152(pretrained = True)
        self.vars = codes.resnet152(pretrained = False)
        self.sample_net = codes.resnet152(pretrained = False)
        
    def sample(self):
        for mean, var, sample in zip(self.means.modules(), self.vars.modules(), self.sample_net.modules()):
            if isinstance(sample, nn.Conv2d):
                sample.weight = torch.nn.Parameter(torch.randn(mean.weight.shape, device=device) * var.weight.exp() + mean.weight)
            elif isinstance(sample, nn.BatchNorm2d):
                sample.weight = torch.nn.Parameter(torch.randn(mean.weight.shape, device=device) * var.weight.exp() + mean.weight)
                sample.bias = torch.nn.Parameter(torch.randn(mean.bias.shape, device=device) * var.bias.exp() + mean.bias)
            elif isinstance(sample, nn.Linear):
                sample.bias = torch.nn.Parameter(torch.randn(mean.bias.shape, device=device) * var.bias.exp() + mean.bias)
            
        
    def forward(self, x):
        return self.sample_net(x)

model = BNN().cuda()

<h1>Training</h1>

In [11]:
epoch = 1000
lr = 3e-4
optimizer_vars = torch.optim.Adam(model.vars.parameters(), lr = lr)

In [12]:
criterion = nn.CrossEntropyLoss()

In [None]:
for run in range(epoch):
    start = time.time()
    
    #Training
    model.train()
    for ind, data in enumerate(trainloader):
        model.sample()
        optimizer_vars.zero_grad()
        img, label = data
        output = model(img.cuda())
        loss = criterion(output, label.cuda())
        loss.backward()
        optimizer_vars.step()
    
    #Test
    model.eval()
    with torch.no_grad():
        test_loss = 0.0
        for ind, data in enumerate(testloader):
            model.sample()
            img, label = data
            output = model(img.cuda())
            test_loss += criterion(output, label.cuda())
        test_loss /= (len(testloader) * 256)
    print("epoch : %d, test loss = %5.5f, time: %f sec" %(run, test_loss, time.time() - start))

epoch : 0, test loss = 19.55981, time: 67.989458 sec
epoch : 1, test loss = 17.75086, time: 68.327421 sec
epoch : 2, test loss = 19.77752, time: 68.443818 sec
epoch : 3, test loss = 19.99368, time: 61.357231 sec
