In [1]:
import torch as t
import torchvision
import torchvision.transforms as transforms
import numpy as np
import torch.nn.functional as F
import matplotlib.pylab as plt
from torch.nn.utils import clip_grad_value_
%matplotlib inline

from torch.autograd import Variable
import torch.nn as nn

import argparse

import torch.optim as optim

from primary_net import PrimaryNetwork

from torchvision import datasets
import tqdm
import os
import json

In [2]:
device = 'cuda' # cuda or cpu
device = t.device(device)
if device == 'cuda':
    t.backends.cudnn.deterministic = True
    t.backends.cudnn.benchmark = False

In [3]:
batch_size = 128
prior_sigma = 1.0 # априорная дисперсия
epoch_num = 250 #количество эпох
lamb = [0.01, 0.1, 1,  10, 100]
start_num = 5

lambda_encode = lambda x : (t.log(x) + 4.6052)/(4.6052+ 4.6052)
lambda_sample_num = 5
path_to_save = 'saved_cifar_2'

if not os.path.exists(path_to_save):
    os.mkdir(path_to_save)
    
learning_rate = 0.002
weight_decay = 0.0005
milestones = [168000, 336000, 400000, 450000, 550000, 600000]
max_iter = 1000000


In [4]:
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

trainset = torchvision.datasets.CIFAR10(root='../data', train=True,
                                        download=True, transform=transform_train)
trainloader = t.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=4)

testset = torchvision.datasets.CIFAR10(root='../data', train=False,
                                       download=True, transform=transform_test)
testloader = t.utils.data.DataLoader(testset, batch_size=128,
                                         shuffle=False, num_workers=4)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [5]:
lamb = [0.01, 0.1, 1,  10, 100]

In [6]:
def test_acc(net): # точность классификации
    acc = []    
    net.eval()
    lamb =  [0.01, 0.1, 1,  10, 100]
    for l in lamb:
        correct = 0
        for x,y in testloader: 
            x = x.to(device)
            y = y.to(device)  
            out = net(x,l)    
            correct += out.argmax(1).eq(y).sum().cpu().numpy()
            t.cuda.empty_cache()
        acc.append(correct / len(testset))
        t.cuda.empty_cache()
    net.train()
    return acc


In [7]:
def train_batches(net, loss_fn, optimizer, lam, label):
    tq = tqdm.tqdm(trainloader)
    losses = []
    for x,y in tq:            
        x = x.to(device)
        y = y.to(device)          
        optimizer.zero_grad()  
        loss = 0
        if lam is None:
            
            for _ in range(lambda_sample_num):  
                p = t.rand(1).to(device)*4 -2
                lam_param = 10**p[0]                
                #t.rand(1).to(device)[0]*100.0                  
                out = net(x, lambda_encode(lam_param))
                loss = loss + loss_fn(out, y)/lambda_sample_num
                loss += net.KLD(lambda_encode(lam_param))*lam_param/len(trainset)/lambda_sample_num
                #loss += net.KLD(lam_param)*t.log(lam_param)/len(trainset)/lambda_sample_num
                losses+=[loss.cpu().detach().numpy()]       
        tq.set_description(label+str(np.mean(losses)))
        loss.backward()       
        clip_grad_value_(net.parameters(), 1.0) # для стабильности градиента. С этим можно играться
        optimizer.step()
        
        #lr_scheduler.step()
    acc = test_acc(net)
    print (acc)
    return acc

In [None]:
t.manual_seed(0)
for start in range(start_num):         
    net = PrimaryNetwork(prior_sigma = prior_sigma, device = device)
    net = net.to(device)
    optim = t.optim.Adam(net.parameters(), lr=1e-4)
    #lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer=optimizer, milestones=milestones, gamma=0.5)
    loss_fn = nn.CrossEntropyLoss().to(device)  
    with open('acc.log', 'w') as out:
        pass
    for e in range(epoch_num):
        label = 'CIFAR, epoch {}: '.format(e)                
        acc = train_batches(net, loss_fn, optim, None, label)
        with open('acc.log', 'a') as out:
            out.write('{}:{}\n'.format(e, acc))
        t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_epoch_{}.cpk'.format( e)))
    t.save(net.state_dict(), os.path.join(path_to_save, 'cifar_start_{}.cpk'.format( start)))

CIFAR, epoch 0: 1.7368293: 100%|██████████| 391/391 [06:43<00:00,  1.03s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.0981, 0.0982, 0.0992, 0.1, 0.1]


CIFAR, epoch 1: 1.4371885: 100%|██████████| 391/391 [06:48<00:00,  1.04s/it]
  0%|          | 0/391 [00:00<?, ?it/s]

[0.1, 0.1, 0.1, 0.1, 0.1]


CIFAR, epoch 2: 1.4448034:   8%|▊         | 32/391 [00:33<06:10,  1.03s/it]

In [None]:
optim.zero_grad()

In [None]:
loss = net.KLD(lambda_encode(t.tensor(1.0)))/len(trainset)/lambda_sample_num

In [None]:
loss.backward()

In [None]:
self = net
self.h1_eps = t.distributions.Normal(t.zeros_like((w1_mean_all), device=self.device),
                                               t.ones_like(w1_sigma_all, device=self.device)*self.prior_sigma)
self.h2_eps = t.distributions.Normal(t.zeros_like((w2_mean_all), device=self.device),
                                               t.ones_like(w2_sigma_all, device=self.device)*self.prior_sigma)