In [7]:
import time
import torch
import torch.nn as nn
import numpy as np
import module
from train import train, eval_, normal_grad, newton_grad
from dataset_wrapper import Cifar10Wrapper
import module
import matplotlib.pyplot as plt
import random


class Cifar10Model1(nn.Module):
    """Small all-convolutional CIFAR-10 classifier built from weight-normalised convolutions.

    `self.main` stacks three 3x3, stride-2, padding-1 `module.WnConv2d` layers
    (3 -> 33 -> 64 -> 128 channels), each followed by ReLU, and a 4x4 average pool.
    `self.fc` is a plain `nn.Linear(128, 10)` classifier head. Assuming `WnConv2d` takes
    the same positional arguments as `nn.Conv2d`, a 32x32 input shrinks to 16 -> 8 -> 4
    pixels, and the pool leaves a 128x1x1 feature map per example.
    """

    def __init__(self):
        super(Cifar10Model1, self).__init__()

        model = module.Sequential()
        # NOTE(review): 33 output channels here vs. 32 in Cifar10Model2 looks like a typo.
        # It gives the weight-norm and plain models slightly different widths. Kept as-is
        # so existing results/checkpoints stay comparable; confirm before changing.
        model.add_module('c1', module.WnConv2d(3, 33, 3, 2, 1))
        model.add_module('relu1', nn.ReLU())
        model.add_module('c2', module.WnConv2d(33, 64, 3, 2, 1))
        model.add_module('relu2', nn.ReLU())
        # 'l3' is a convolution despite its name. The name is kept because it is part
        # of the parameter / state_dict keys.
        model.add_module('l3', module.WnConv2d(64, 128, 3, 2, 1))
        model.add_module('relu3', nn.ReLU())
        model.add_module('avg_pool', nn.AvgPool2d(4))
        self.main = model
        # An earlier variant used `module.WnLinear(128, 10)` here. get_param_g includes
        # the head only if it exposes its own `get_param_g`.
        self.fc = nn.Linear(128, 10)

    def get_param_g(self):
        """Return the list of `g` parameters (presumably weight-norm gains) of the model.

        The list comes from `self.main.get_param_g()`. The classifier head is appended
        only when it provides `get_param_g`; the current `nn.Linear` head does not,
        and calling it unconditionally raised AttributeError.
        """
        params = list(self.main.get_param_g())
        fc_get_param_g = getattr(self.fc, 'get_param_g', None)
        if fc_get_param_g is not None:
            params.append(fc_get_param_g())
        return params

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the input batch `x`."""
        features = self.main(x)
        # Flatten (batch, 128, 1, 1) -> (batch, 128). Unlike `.squeeze()`, this keeps
        # the batch dimension when the batch holds a single example.
        features = features.view(features.size(0), -1)
        return self.fc(features)


class Cifar10Model2(nn.Module):
    """Plain-convolution CIFAR-10 classifier (the non-weight-norm baseline).

    `self.main` stacks three 3x3, stride-2, padding-1 `nn.Conv2d` layers
    (3 -> 32 -> 64 -> 128 channels), each followed by ReLU, and a 4x4 average pool.
    `self.fc` is an `nn.Linear(128, 10)` classifier head. A 32x32 input shrinks to
    16 -> 8 -> 4 pixels, so the pool leaves a 128x1x1 feature map per example.
    """

    def __init__(self):
        super(Cifar10Model2, self).__init__()

        model = module.Sequential()
        model.add_module('c1', nn.Conv2d(3, 32, 3, 2, 1))
        model.add_module('relu1', nn.ReLU())
        model.add_module('c2', nn.Conv2d(32, 64, 3, 2, 1))
        model.add_module('relu2', nn.ReLU())
        # 'l3' is a convolution despite its name. The name is kept because it is part
        # of the parameter / state_dict keys.
        model.add_module('l3', nn.Conv2d(64, 128, 3, 2, 1))
        model.add_module('relu3', nn.ReLU())
        model.add_module('avg_pool', nn.AvgPool2d(4))
        self.main = model
        self.fc = nn.Linear(128, 10)

    def get_param_g(self):
        """Return the list of `g` parameters reported by `self.main` (and the head, if any).

        NOTE(review): `self.main` holds plain `nn.Conv2d` layers; whether
        `module.Sequential.get_param_g` handles them is not visible here. The `nn.Linear`
        head has no `get_param_g`, so it is only appended when one exists (calling it
        unconditionally raised AttributeError).
        """
        params = list(self.main.get_param_g())
        fc_get_param_g = getattr(self.fc, 'get_param_g', None)
        if fc_get_param_g is not None:
            params.append(fc_get_param_g())
        return params

    def forward(self, x):
        """Return class logits of shape (batch, 10) for the input batch `x`."""
        features = self.main(x)
        # Flatten (batch, 128, 1, 1) -> (batch, 128). Unlike `.squeeze()`, this keeps
        # the batch dimension when the batch holds a single example.
        features = features.view(features.size(0), -1)
        return self.fc(features)


def set_all_seeds(rand_seed):
    """Seed every RNG used by the experiments from a single integer.

    Python's `random` module is seeded with `rand_seed` itself. Then NumPy's global
    RNG, torch's CPU RNG and torch's CUDA RNG (via `torch.cuda.manual_seed`) are
    seeded, in that order, with successive draws of `random.randint(100000, 1000000)`.
    The whole setup is therefore determined by `rand_seed` alone.
    """
    random.seed(rand_seed)
    lowest, highest = int(1e5), int(1e6)
    for seed_fn in (np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seed_fn(random.randint(lowest, highest))


In [8]:
# Global seed for `random`, NumPy and torch, applied before the dataset is loaded.
SEED = 100009
set_all_seeds(SEED)

dataset = Cifar10Wrapper.load_default()
# Flatten the label arrays to 1-D int32 vectors. Re-running this cell is harmless:
# astype/reshape give the same result on already-converted arrays.
dataset.train_ys = dataset.train_ys.astype(np.int32).reshape((-1,))
dataset.test_ys = dataset.test_ys.astype(np.int32).reshape((-1,))
# reshape((-1,)) would silently flatten e.g. one-hot (N, 10) labels into N*10 entries,
# so check that there is still exactly one training label per training image.
assert dataset.train_ys.shape[0] == dataset.train_xs.shape[0], 'expected one label per training image'

# Parenthesised single-argument prints give the same output under Python 2 and 3.
print(dataset.train_ys.min())
print('%s %s' % (dataset.train_xs.min(), dataset.train_xs.max()))
print(dataset.train_xs.shape)

Dataset loaded from /home/hengyuah/datasets/cifar10/cifar10.h5
0
-1.0 1.0
(50000, 3, 32, 32)


In [18]:
# Number of passes over the training set, shared by every configuration below.
epochs = 20
# Fourth positional argument of `train`, identical for all runs. Presumably the
# learning rate; confirm against train.train's signature.
LR = 0.1
# Each run is re-seeded with this value, so every configuration starts from the same
# reproducible RNG state, however many random numbers earlier runs consumed.
RUN_SEED = 100009


def run_experiment(model_cls, use_kfac):
    """Train a freshly built `model_cls` on the GPU and print the elapsed wall-clock time.

    `use_kfac` is forwarded as the fifth positional argument of `train`; it is the flag
    that distinguishes the "K-FAC" runs from the plain ones. Returns
    ``(model, train_acc, test_acc, times)``. The first element of `train`'s return value
    is discarded, as in the original cell.
    """
    set_all_seeds(RUN_SEED)
    net = model_cls().cuda()
    start = time.time()
    _, train_hist, test_hist, time_hist = train(net, dataset, normal_grad, LR, use_kfac, epochs)
    print('time: %s' % (time.time() - start))
    return net, train_hist, test_hist, time_hist


wn_model, wn_train_acc, wn_test_acc, wn_times = run_experiment(Cifar10Model1, False)
wn_kfac_model, wn_kfac_train_acc, wn_kfac_test_acc, wn_kfac_times = run_experiment(Cifar10Model1, True)
model, train_acc, test_acc, times = run_experiment(Cifar10Model2, False)
kfac_model, kfac_train_acc, kfac_test_acc, kfac_times = run_experiment(Cifar10Model2, True)

epoch: 1, loss: 1.9483
accumulate time 1.42750406265
train acc: 0.34964
eval acc: 0.352
----------------
epoch: 2, loss: 1.6738
accumulate time 2.83571887016
train acc: 0.43434
eval acc: 0.4414
----------------
epoch: 3, loss: 1.5315
accumulate time 4.24497485161
train acc: 0.47878
eval acc: 0.4762
----------------
epoch: 4, loss: 1.4431
accumulate time 5.65098905563
train acc: 0.497
eval acc: 0.4922
----------------
epoch: 5, loss: 1.3759
accumulate time 7.06046795845
train acc: 0.53364
eval acc: 0.5238
----------------
epoch: 6, loss: 1.3169
accumulate time 8.46775484085
train acc: 0.52498
eval acc: 0.5107
----------------
epoch: 7, loss: 1.2687
accumulate time 9.8728659153
train acc: 0.56068
eval acc: 0.552
----------------
epoch: 8, loss: 1.2257
accumulate time 11.281414032
train acc: 0.58208
eval acc: 0.5681
----------------
epoch: 9, loss: 1.1870
accumulate time 12.6903269291
train acc: 0.56282
eval acc: 0.5432
----------------
epoch: 10, loss: 1.1522
accumulate time 14.106271982

In [29]:
import os

# Directory the comparison figures are written to. savefig does not create it.
PLOT_DIR = './plots'
if not os.path.isdir(PLOT_DIR):
    os.makedirs(PLOT_DIR)

# The training log reports its first accuracies after "epoch: 1", so number the x axis
# from 1 (assumes `train` appends exactly one entry per completed epoch).
x = range(1, len(wn_train_acc) + 1)

# Histories of every configuration. 'epoch' and 'time' are the two x axes used below.
RUNS = [
    {'label': 'Weight Norm', 'style': 'r-', 'epoch': x,
     'train': wn_train_acc, 'test': wn_test_acc, 'time': wn_times},
    {'label': 'Weight Norm K-FAC', 'style': 'b-', 'epoch': x,
     'train': wn_kfac_train_acc, 'test': wn_kfac_test_acc, 'time': wn_kfac_times},
    {'label': 'Normal', 'style': 'g-', 'epoch': x,
     'train': train_acc, 'test': test_acc, 'time': times},
    {'label': 'Normal K-FAC', 'style': 'c-', 'epoch': x,
     'train': kfac_train_acc, 'test': kfac_test_acc, 'time': kfac_times},
]


def plot_comparison(x_key, y_key, xlabel, ylabel, title, filename):
    """Overlay run[y_key] against run[x_key] for every run in RUNS and save the figure.

    The figure is written to ``PLOT_DIR/filename`` and returned. It is left open so the
    notebook still renders it inline.
    """
    fig, ax = plt.subplots()
    for run in RUNS:
        ax.plot(run[x_key], run[y_key], run['style'], label=run['label'])
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(loc='lower right')
    fig.savefig(os.path.join(PLOT_DIR, filename))
    return fig


plot_comparison('epoch', 'train', 'Epoch', 'Train Accuracy', 'Train accuracy per epoch', 'train-acc.png')
plot_comparison('epoch', 'test', 'Epoch', 'Test Accuracy', 'Test accuracy per epoch', 'test-acc.png')
plot_comparison('time', 'test', 'Time', 'Test Accuracy', 'Test accuracy vs. training time', 'acc-time.png');

In [28]:
import os

# savefig does not create missing directories.
if not os.path.isdir('./plots'):
    os.makedirs('./plots')


def plot_train_vs_test(train_hist, test_hist, title, path):
    """Plot one configuration's train and test accuracy per epoch and save it to `path`.

    The epoch axis is built here from the history length, numbered from 1 (assumes
    one entry per completed epoch), so this cell does not rely on an `x` defined in
    another cell. Returns the figure, left open so it still renders inline.
    """
    epoch_axis = range(1, len(train_hist) + 1)
    fig, ax = plt.subplots()
    ax.plot(epoch_axis, train_hist, 'r-', label='Train')
    ax.plot(epoch_axis, test_hist, 'b-', label='Test')
    ax.set_ylabel('Accuracy')
    ax.set_xlabel('Epoch')
    ax.legend(loc='upper left')
    ax.set_title(title)
    fig.savefig(path)
    return fig


plot_train_vs_test(wn_train_acc, wn_test_acc, 'Weight Norm', './plots/wn.png')
plot_train_vs_test(wn_kfac_train_acc, wn_kfac_test_acc, 'Weight Norm K-FAC', './plots/wn-kfac.png')
plot_train_vs_test(train_acc, test_acc, 'Normal', './plots/normal.png')
plot_train_vs_test(kfac_train_acc, kfac_test_acc, 'Normal K-FAC', './plots/normal-kfac.png');
