In [1]:
# test file for train.py

In [2]:
import sys
sys.path.append('/home/felix/Research/Adversarial Research/FGN---Research/')
import Finite_Gaussian_Network_lib as fgnl
import Finite_Gaussian_Network_lib.fgn_helper_lib as fgnh

In [3]:
from __future__ import print_function

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [4]:
# random seeds
# Fix the global PyTorch and NumPy RNG seeds before any model or loader is created.
torch.manual_seed(665)
np.random.seed(326)

In [5]:
# simple model
# single fully connected layer (no hidden layers) applied to flattened 28x28 images
class Classic_MNIST_Net(nn.Module):
    """Linear classifier mapping 784 flattened inputs to 10 log-probabilities."""

    def __init__(self):
        super(Classic_MNIST_Net, self).__init__()

        # the only layer: 28*28 input features -> 10 outputs
        n_inputs = 28 * 28
        self.fl = nn.Linear(n_inputs, 10)

    def forward(self, x):
        """Reshape `x` into rows of 28*28 values and return log-softmax scores per row."""
        # flatten everything into (-1, 784)
        flat = x.view(-1, 28 * 28)
        scores = self.fl(flat)
        # normalise over the last dimension
        return F.log_softmax(scores, dim=-1)

In [6]:
# Select the compute device: CUDA when requested and available, otherwise the CPU
print("CUDA Available: ", torch.cuda.is_available())
use_cuda = True
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

CUDA Available:  True


In [7]:
# define model
# Instantiate the single-layer classifier and move it to the selected `device`.
model = Classic_MNIST_Net().to(device)

In [8]:
# dataloader declaration
# Both loaders read the same local MNIST copy (no download) with identical preprocessing.
batch_size = 10000
mnist_root = '/home/felix/Research/Adversarial Research/MNIST-dataset'
# convert images to tensors, then normalize with mean 0.1307 and std 0.3081
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# training split, shuffled
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(mnist_root, train=True, download=False, transform=mnist_transform),
    batch_size=batch_size,
    shuffle=True,
)

# test split, fixed order
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(mnist_root, train=False, download=False, transform=mnist_transform),
    batch_size=batch_size,
    shuffle=False,
)

In [9]:
# loss function
def loss_func(model, output, target):
    """Return the negative log-likelihood loss of `output` against `target`.

    `model` is accepted as the first argument but is not used here.
    """
    nll = F.nll_loss(output, target)
    return nll

In [10]:
# optimizer
# Plain SGD over all model parameters, learning rate 0.01, momentum 0.5.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.5)

In [11]:
# number of training epochs passed to each fgnh.train call below
epochs = 3

In [12]:
# number of correct pred function
def pred_func(output, target):
    """Count the rows of `output` whose highest-scoring index equals the matching `target` entry."""
    # index of the max log-probability per row, kept as a column
    _, top_class = output.max(1, keepdim=True)
    hits = top_class.eq(target.view_as(top_class))
    return hits.sum().item()

In [13]:
### test 1: run on same device
# Calls fgnh.train for `epochs` epochs with the prediction counter and the test loader supplied;
# no `device` argument is passed.
# NOTE(review): save_hist=2 presumably controls how often parameter histories are recorded -- confirm in fgn_helper_lib.
res = fgnh.train(model, train_loader, loss_func, optimizer, epochs, save_hist=2, verbose=True, pred_func=pred_func, test_loader=test_loader)

Epoch 0 Train set - Average loss: 2.0711, Accuracy: 17386/60000 (29%)
Test set - Average loss: 1.5468, Accuracy: 5902/10000 (59%)
Epoch 1 Train set - Average loss: 1.3447, Accuracy: 39769/60000 (66%)
Test set - Average loss: 1.0968, Accuracy: 7441/10000 (74%)
Epoch 2 Train set - Average loss: 1.0226, Accuracy: 45794/60000 (76%)
Test set - Average loss: 0.8932, Accuracy: 7981/10000 (80%)


In [14]:
# # Expected

# Epoch 0 Train set - Average loss: 2.0711, Accuracy: 17386/60000 (29%)
# Test set - Average loss: 1.5468, Accuracy: 5902/10000 (59%)
# Epoch 1 Train set - Average loss: 1.3447, Accuracy: 39769/60000 (66%)
# Test set - Average loss: 1.0968, Accuracy: 7441/10000 (74%)
# Epoch 2 Train set - Average loss: 1.0226, Accuracy: 45794/60000 (76%)
# Test set - Average loss: 0.8932, Accuracy: 7981/10000 (80%)

In [15]:
# test 1 check: train() was called without a `device` argument, so the model is
# expected to still be on the device type selected earlier in this notebook.
# Only the type is compared, since a bare "cuda" device does not equal "cuda:0".
param_device = next(model.parameters()).device
assert param_device.type == device.type, \
    "model is on %s, expected device type %s" % (param_device, device.type)
# keep the value as the last expression so the cell still displays it
param_device

device(type='cuda', index=0)

In [16]:
### test 2 :run on cpu instead of cuda
# Same call as test 1 plus device=torch.device('cpu').
# It reuses the `model` and `optimizer` already trained in test 1, so training continues
# from that state rather than from a fresh initialization.
res = fgnh.train(model, train_loader, loss_func, optimizer, epochs, save_hist=2, verbose=True, pred_func=pred_func, test_loader=test_loader, device=torch.device('cpu'))

Epoch 0 Train set - Average loss: 0.8631, Accuracy: 48105/60000 (80%)
Test set - Average loss: 0.7783, Accuracy: 8226/10000 (82%)
Epoch 1 Train set - Average loss: 0.7674, Accuracy: 49345/60000 (82%)
Test set - Average loss: 0.7038, Accuracy: 8377/10000 (84%)
Epoch 2 Train set - Average loss: 0.7030, Accuracy: 50126/60000 (84%)
Test set - Average loss: 0.6513, Accuracy: 8479/10000 (85%)


In [17]:
# test 2 check: train() was given device=torch.device('cpu'), so the model
# parameters are expected to be on the CPU now.
param_device = next(model.parameters()).device
assert param_device.type == 'cpu', \
    "model is on %s, expected cpu" % (param_device,)
# keep the value as the last expression so the cell still displays it
param_device

device(type='cpu')

In [18]:
# # Expected

# Warning: device specified. This might change model location (cuda<->cpu)
# Epoch 0 Train set - Average loss: 0.8631, Accuracy: 48105/60000 (80%)
# Warning: device specified. This might change model location (cuda<->cpu)
# Test set - Average loss: 0.7782, Accuracy: 8226/10000 (82%)
# Epoch 1 Train set - Average loss: 0.7674, Accuracy: 49345/60000 (82%)
# Warning: device specified. This might change model location (cuda<->cpu)
# Test set - Average loss: 0.7038, Accuracy: 8377/10000 (84%)
# Epoch 2 Train set - Average loss: 0.7030, Accuracy: 50126/60000 (84%)
# Warning: device specified. This might change model location (cuda<->cpu)
# Test set - Average loss: 0.6513, Accuracy: 8479/10000 (85%)

In [19]:
# The result of train() is expected to expose the loss/accuracy histories for both
# splits plus the recorded parameter histories; fail loudly if any of them is missing.
expected_keys = ['train_loss_hist', 'train_acc_hist', 'test_loss_hist', 'test_acc_hist', 'histories']
missing_keys = [key for key in expected_keys if key not in res.keys()]
assert not missing_keys, "train() result is missing keys: %s" % (missing_keys,)
print(len(res))
# list(...) keeps the printed form a plain list under both Python 2 and Python 3
print(list(res.keys()))

5
['train_loss_hist', 'train_acc_hist', 'test_loss_hist', 'test_acc_hist', 'histories']


In [20]:
# Training-loss history returned by the last train() call.
train_loss_hist = res['train_loss_hist']
print(len(train_loss_hist))
# The recorded run produced one entry per epoch (3 for epochs = 3); assert that instead of eyeballing it.
assert len(train_loss_hist) == epochs, \
    "expected %d train loss entries, got %d" % (epochs, len(train_loss_hist))

3


In [21]:
# Training-accuracy history returned by the last train() call.
train_acc_hist = res['train_acc_hist']
print(len(train_acc_hist))
# The recorded run produced one entry per epoch (3 for epochs = 3); assert that instead of eyeballing it.
assert len(train_acc_hist) == epochs, \
    "expected %d train accuracy entries, got %d" % (epochs, len(train_acc_hist))

3


In [22]:
histories = res['histories']
# print each recorded parameter name followed by the shape of its history
for param_name, param_hist in histories.items():
    print(param_name)
    print(np.shape(param_hist))

fl.bias
(19, 10)
fl.weight
(19, 10, 784)


In [23]:
# Dump the full recorded parameter histories (large output) for comparison with the expected values.
print(histories)

{'fl.bias': array([[ 0.0317583 , -0.00306912,  0.00918836,  0.006821  ,  0.0169086 ,
         0.00161701,  0.01314674,  0.00244241, -0.00876717, -0.02979792],
       [ 0.03169271, -0.00306558,  0.00917751,  0.00681374,  0.01689558,
         0.0017463 ,  0.01312565,  0.00254749, -0.00887485, -0.02981035],
       [ 0.03164063, -0.00301919,  0.00917185,  0.00675552,  0.01692882,
         0.00192847,  0.01309301,  0.00261771, -0.00901529, -0.02985333],
       [ 0.03160081, -0.00297409,  0.00916784,  0.00670845,  0.01696748,
         0.00212249,  0.01306231,  0.00269219, -0.0091623 , -0.02993697],
       [ 0.03158792, -0.00299298,  0.00921473,  0.00666074,  0.01697877,
         0.00227977,  0.01305747,  0.00274772, -0.00927933, -0.03000659],
       [ 0.03155868, -0.00301797,  0.00926989,  0.00665513,  0.01696343,
         0.00239735,  0.01303432,  0.00284923, -0.00940312, -0.03005874],
       [ 0.03156726, -0.0030394 ,  0.00931849,  0.00665033,  0.01695803,
         0.00251349,  0.01299941,

In [24]:
# # Expected? It may be hard to get exactly the same numbers.
# {'fl.bias': array([[ 0.0317583 , -0.00306912,  0.00918836,  0.006821  ,  0.0169086 ,
#          0.00161701,  0.01314674,  0.00244241, -0.00876717, -0.02979792],
#        [ 0.03169271, -0.00306558,  0.00917751,  0.00681374,  0.01689558,
#          0.0017463 ,  0.01312565,  0.00254749, -0.00887485, -0.02981035],
#        [ 0.03164063, -0.00301919,  0.00917185,  0.00675552,  0.01692882,
#          0.00192847,  0.01309301,  0.00261771, -0.00901529, -0.02985333],
#        [ 0.03160081, -0.00297409,  0.00916784,  0.00670845,  0.01696748,
#          0.00212249,  0.01306231,  0.00269219, -0.0091623 , -0.02993697],
#        [ 0.03158792, -0.00299298,  0.00921473,  0.00666074,  0.01697877,
#          0.00227977,  0.01305747,  0.00274772, -0.00927933, -0.03000659],
#        [ 0.03155868, -0.00301797,  0.00926989,  0.00665513,  0.01696342,
#          0.00239735,  0.01303432,  0.00284923, -0.00940312, -0.03005874],
#        [ 0.03156726, -0.0030394 ,  0.00931849,  0.00665033,  0.01695803,
#          0.00251349,  0.01299941,  0.00291684, -0.00951009, -0.03012616],
#        [ 0.03156338, -0.00305229,  0.00935068,  0.00665324,  0.01694838,
#          0.00262779,  0.01296424,  0.00300088, -0.00962061, -0.03018747],
#        [ 0.03154842, -0.00304772,  0.00936692,  0.00661394,  0.01692975,
#          0.00274667,  0.01294731,  0.00307006, -0.00971752, -0.03020962],
#        [ 0.0315061 , -0.00303468,  0.00935065,  0.00659341,  0.01697835,
#          0.00287742,  0.01291875,  0.00312745, -0.00981106, -0.03025818],
#        [ 0.03148526, -0.00303961,  0.00934593,  0.00659364,  0.01695795,
#          0.00298199,  0.01290194,  0.00320101, -0.00988235, -0.03029756],
#        [ 0.03144568, -0.0030544 ,  0.00936549,  0.00657824,  0.01692879,
#          0.0031133 ,  0.01292681,  0.00326856, -0.00997129, -0.03035296],
#        [ 0.03141832, -0.00306858,  0.00940121,  0.00656269,  0.01690525,
#          0.00328245,  0.01290004,  0.00334009, -0.01006637, -0.03042688],
#        [ 0.03137737, -0.0030847 ,  0.00943903,  0.00650129,  0.01685603,
#          0.00345934,  0.0128796 ,  0.00342903, -0.0101368 , -0.03047198],
#        [ 0.03134366, -0.00308979,  0.0094732 ,  0.00645276,  0.01683635,
#          0.00359652,  0.01285537,  0.00350935, -0.01022709, -0.03050212],
#        [ 0.03132072, -0.00308222,  0.00950977,  0.00642632,  0.01685657,
#          0.00370408,  0.01284076,  0.00355569, -0.01032062, -0.03056287],
#        [ 0.0313307 , -0.00308388,  0.00954992,  0.00644277,  0.01688461,
#          0.00377405,  0.0128127 ,  0.00358326, -0.01041259, -0.03063334],
#        [ 0.03133314, -0.00310814,  0.00953834,  0.00648684,  0.01686492,
#          0.00387385,  0.01280858,  0.00363231, -0.01050922, -0.03067241],
#        [ 0.03132744, -0.00312061,  0.00951506,  0.00653627,  0.01685102,
#          0.0039609 ,  0.0128044 ,  0.00369232, -0.01059169, -0.03072691]],
#       dtype=float32), 'fl.weight': array([[[-0.0269651 ,  0.01795503, -0.01738468, ...,  0.00805861,
#          -0.0257817 , -0.03346497],
#         [ 0.00992633,  0.02499002,  0.00231849, ...,  0.02847031,
#          -0.01568423, -0.01541279],
#         [ 0.0047365 ,  0.02722993, -0.01725454, ..., -0.01441996,
#          -0.02886991,  0.02304127],
#         ...,
#         [-0.02272519, -0.02736211,  0.02638081, ..., -0.02318391,
#          -0.01140144,  0.03451377],
#         [ 0.02698973, -0.03022965, -0.00040812, ...,  0.02894087,
#          -0.00703943,  0.02676886],
#         [ 0.0140695 ,  0.00637808, -0.00210427, ..., -0.02248027,
#          -0.00335751,  0.02665576]],

#        [[-0.02693727,  0.01798285, -0.01735685, ...,  0.00808643,
#          -0.02575388, -0.03343715],
#         [ 0.00992483,  0.02498852,  0.00231699, ...,  0.02846881,
#          -0.01568573, -0.01541429],
#         [ 0.0047411 ,  0.02723453, -0.01724994, ..., -0.01441535,
#          -0.02886531,  0.02304587],
#         ...,
#         [-0.02276976, -0.02740668,  0.02633623, ..., -0.02322849,
#          -0.01144602,  0.03446919],
#         [ 0.02703541, -0.03018397, -0.00036244, ...,  0.02898655,
#          -0.00699375,  0.02681454],
#         [ 0.01407477,  0.00638335, -0.00209901, ..., -0.022475  ,
#          -0.00335225,  0.02666103]],

#        [[-0.02691518,  0.01800495, -0.01733476, ...,  0.00810852,
#          -0.02573179, -0.03341505],
#         [ 0.00990515,  0.02496884,  0.00229731, ...,  0.02844913,
#          -0.01570541, -0.01543397],
#         [ 0.0047435 ,  0.02723693, -0.01724754, ..., -0.01441295,
#          -0.02886291,  0.02304827],
#         ...,
#         [-0.02279955, -0.02743647,  0.02630644, ..., -0.02325827,
#          -0.01147581,  0.03443941],
#         [ 0.02709499, -0.03012439, -0.00030286, ...,  0.02904613,
#          -0.00693417,  0.02687412],
#         [ 0.014093  ,  0.00640159, -0.00208077, ..., -0.02245677,
#          -0.00333401,  0.02667927]],

#        ...,

#        [[-0.0267837 ,  0.01813642, -0.01720328, ...,  0.00824   ,
#          -0.02560031, -0.03328357],
#         [ 0.00993259,  0.02499628,  0.00232475, ...,  0.02847657,
#          -0.01567797, -0.01540652],
#         [ 0.00458312,  0.02707655, -0.01740792, ..., -0.01457334,
#          -0.02902329,  0.02288789],
#         ...,
#         [-0.02320915, -0.02784607,  0.02589684, ..., -0.02366787,
#          -0.01188541,  0.03402981],
#         [ 0.02768774, -0.02953164,  0.00028989, ...,  0.02963888,
#          -0.00634142,  0.02746687],
#         [ 0.01442389,  0.00673248, -0.00174988, ..., -0.02212588,
#          -0.00300312,  0.02701016]],

#        [[-0.02678474,  0.01813539, -0.01720432, ...,  0.00823896,
#          -0.02560134, -0.03328461],
#         [ 0.00994288,  0.02500657,  0.00233504, ...,  0.02848686,
#          -0.01566768, -0.01539623],
#         [ 0.00458803,  0.02708146, -0.01740301, ..., -0.01456842,
#          -0.02901838,  0.0228928 ],
#         ...,
#         [-0.02322996, -0.02786688,  0.02587604, ..., -0.02368868,
#          -0.01190621,  0.03400901],
#         [ 0.02772873, -0.02949065,  0.00033088, ...,  0.02967988,
#          -0.00630043,  0.02750786],
#         [ 0.01444047,  0.00674905, -0.00173331, ..., -0.02210931,
#          -0.00298655,  0.02702673]],

#        [[-0.02678232,  0.01813781, -0.0172019 , ...,  0.00824138,
#          -0.02559893, -0.03328219],
#         [ 0.00994817,  0.02501186,  0.00234034, ...,  0.02849215,
#          -0.01566239, -0.01539094],
#         [ 0.00459791,  0.02709134, -0.01739313, ..., -0.01455855,
#          -0.0290085 ,  0.02290268],
#         ...,
#         [-0.02325542, -0.02789233,  0.02585058, ..., -0.02371414,
#          -0.01193167,  0.03398355],
#         [ 0.02776372, -0.02945566,  0.00036586, ...,  0.02971486,
#          -0.00626545,  0.02754284],
#         [ 0.01446358,  0.00677217, -0.00171019, ..., -0.02208619,
#          -0.00296343,  0.02704985]]], dtype=float32)}