In [1]:
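# Test notebook for train.train(): trains a one-layer MNIST classifier and inspects the returned metric and parameter histories.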

In [2]:
from train import train

In [3]:
from __future__ import print_function

import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [4]:
# Seed NumPy's and PyTorch's global random number generators so that runs
# (and the "Expected" outputs recorded further down) are repeatable.
np.random.seed(326)
# trailing ';' keeps Jupyter from displaying the returned Generator object
torch.manual_seed(665);

In [5]:
# simple model
# Single-layer linear classifier (no hidden layers): each image is flattened
# and mapped straight to per-class log-probabilities.
class Classic_MNIST_Net(nn.Module):
    """Linear softmax classifier (multinomial logistic regression).

    Args:
        in_feats: number of features per sample after flattening;
            defaults to 28*28 (one MNIST image).
        n_classes: number of output classes; defaults to 10.

    The parameters are registered as ``fl.weight`` / ``fl.bias``; the
    training histories inspected later in this notebook are keyed by these
    names, so the attribute must stay called ``fl``.
    """

    def __init__(self, in_feats=28 * 28, n_classes=10):
        super(Classic_MNIST_Net, self).__init__()
        self.in_feats = in_feats
        self.fl = nn.Linear(in_feats, n_classes)

    def forward(self, x):
        """Return log-probabilities of shape ``(N, n_classes)``.

        ``x`` may be any tensor whose element count is a multiple of
        ``in_feats`` (e.g. a ``(N, 1, 28, 28)`` image batch); it is
        flattened to ``(N, in_feats)``.
        """
        # reshape (not view) so non-contiguous inputs such as transposed
        # or sliced batches are accepted as well
        x = x.reshape(-1, self.in_feats)
        x = self.fl(x)
        # log-probabilities, matching the F.nll_loss used as the loss function
        return F.log_softmax(x, dim=-1)

In [6]:
# Pick the compute device: the GPU when use_cuda is set and CUDA is present,
# otherwise the CPU.
print("CUDA Available: ", torch.cuda.is_available())
use_cuda = True
if use_cuda and torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

CUDA Available:  True


In [7]:
# Build the network, then move its parameters to the selected device.
model = Classic_MNIST_Net()
model = model.to(device)

In [8]:
# dataloader declaration
# Location of the pre-downloaded MNIST data (download=False below). Set the
# MNIST_DIR environment variable to use another location instead of editing
# this machine-specific default path.
DATA_DIR = os.environ.get(
    'MNIST_DIR', '/home/felix/Research/Adversarial Research/MNIST-dataset')

# Identical preprocessing for both splits: convert to a tensor, then
# normalise with (0.1307,), (0.3081,) (presumably the MNIST pixel mean/std).
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

batch_size = 10000
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=True, download=False,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST(DATA_DIR, train=False, download=False,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=False)

In [9]:
# loss function
def loss_func(model, output, target):
    """Mean negative log-likelihood of ``target`` under log-probabilities ``output``.

    ``model`` is not used here; it is presumably part of the loss-callback
    signature that ``train`` expects (TODO confirm in train.py).
    """
    nll = F.nll_loss(output, target)
    return nll

In [10]:
# optimizer: SGD with momentum over every parameter of the model
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.01,
    momentum=0.5)

In [11]:
# Number of passes over train_loader; the recorded outputs below (one
# train/test line pair per epoch, 3 history entries) were produced with 3.
epochs = 3

In [12]:
# number of correct pred function
def pred_func(output, target):
    """Count the rows of ``output`` whose arg-max along dim 1 equals ``target``.

    ``output`` holds per-class scores with classes along dim 1; ``target``
    holds one class index per row. Returns a Python number.
    """
    _, predicted = output.max(dim=1, keepdim=True)  # top-scoring class per row
    is_hit = predicted.eq(target.view_as(predicted))
    return is_hit.sum().item()

In [13]:
### test 1:
# Train for `epochs` epochs with a test-set evaluation printed after each
# one; the log should match the "Expected" cell below.
res = train(
    model, device, train_loader, loss_func, optimizer, epochs,
    save_hist=2,
    verbose=True,
    pred_func=pred_func,
    test_loader=test_loader,
)

Epoch 0 Train set - Average loss: 2.0711, Accuracy: 17386/60000 (29%)
Test set - Average loss: 0.0002, Accuracy: 5902/10000 (59%)
Epoch 1 Train set - Average loss: 1.3447, Accuracy: 39769/60000 (66%)
Test set - Average loss: 0.0001, Accuracy: 7441/10000 (74%)
Epoch 2 Train set - Average loss: 1.0226, Accuracy: 45794/60000 (76%)
Test set - Average loss: 0.0001, Accuracy: 7981/10000 (80%)


In [14]:
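# # Expected output of test 1 (recorded with the seeds set above)
# NOTE(review): the test "Average loss" (~1e-4) is about four orders of magnitude
# below the train loss (~1-2). train() may be dividing an already batch-averaged
# loss by the dataset size -- verify in train.py.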

# Epoch 0 Train set - Average loss: 2.0711, Accuracy: 17386/60000 (29%)
# Test set - Average loss: 0.0002, Accuracy: 5902/10000 (59%)
# Epoch 1 Train set - Average loss: 1.3447, Accuracy: 39769/60000 (66%)
# Test set - Average loss: 0.0001, Accuracy: 7441/10000 (74%)
# Epoch 2 Train set - Average loss: 1.0226, Accuracy: 45794/60000 (76%)
# Test set - Average loss: 0.0001, Accuracy: 7981/10000 (80%)

In [15]:
print(len(res))
print(res.keys())
# Check the recorded result structure automatically instead of only eyeballing it.
assert len(res) == 5
assert set(res.keys()) == {'train_loss_hist', 'train_acc_hist',
                           'test_loss_hist', 'test_acc_hist', 'histories'}

5
['train_loss_hist', 'train_acc_hist', 'test_loss_hist', 'test_acc_hist', 'histories']


In [16]:
train_loss_hist = res['train_loss_hist']
print(len(train_loss_hist))
# recorded output shows one loss entry per epoch
assert len(train_loss_hist) == epochs

3


In [17]:
train_acc_hist = res['train_acc_hist']
print(len(train_acc_hist))
# recorded output shows one accuracy entry per epoch
assert len(train_acc_hist) == epochs

3


In [18]:
histories = res['histories']
for k in histories.keys():
    print(k)
    print(np.shape(histories[k]))

# Each history has a leading snapshot axis in front of the parameter's own
# shape (e.g. (19, 10) for fl.bias), and every parameter must have the same
# number of snapshots.
param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
assert set(histories.keys()) == set(param_shapes.keys())
assert len(set(np.shape(h)[0] for h in histories.values())) == 1
for k in histories:
    assert tuple(np.shape(histories[k])[1:]) == param_shapes[k], k

fl.bias
(19, 10)
fl.weight
(19, 10, 784)


In [19]:
# Full dump of every parameter snapshot (numpy elides most of the large
# weight arrays with "..."); compare by eye with the Expected cell below.
print(repr(histories))

{'fl.bias': array([[ 3.30802016e-02, -7.54990801e-03,  1.15686804e-02,
         6.81053102e-03,  1.66102387e-02, -9.75035131e-04,
         1.35433599e-02,  1.97258219e-03, -6.46133162e-03,
        -2.83511076e-02],
       [ 3.32987383e-02, -7.23120803e-03,  1.09616853e-02,
         6.93267584e-03,  1.65758226e-02, -8.95091856e-04,
         1.36601590e-02,  1.71544345e-03, -6.28255401e-03,
        -2.84874607e-02],
       [ 3.35246138e-02, -6.76682405e-03,  1.02818692e-02,
         7.13171111e-03,  1.65404957e-02, -8.33678932e-04,
         1.37796151e-02,  1.38693373e-03, -6.09182473e-03,
        -2.87047010e-02],
       [ 3.37206200e-02, -6.17475063e-03,  9.76385828e-03,
         7.26846373e-03,  1.64015424e-02, -7.47575657e-04,
         1.38603402e-02,  1.17455143e-03, -6.05171965e-03,
        -2.89671216e-02],
       [ 3.37153301e-02, -5.57578960e-03,  9.36043542e-03,
         7.33028259e-03,  1.63653549e-02, -6.56490913e-04,
         1.39077390e-02,  1.09073101e-03, -6.15823455e-03,

In [20]:
# # Expected (approximate: exact numbers may be hard to reproduce across hardware/library versions)

# {'fl.bias': array([[ 3.30802016e-02, -7.54990801e-03,  1.15686804e-02,
#          6.81053102e-03,  1.66102387e-02, -9.75035131e-04,
#          1.35433599e-02,  1.97258219e-03, -6.46133162e-03,
#         -2.83511076e-02],
#        [ 3.32987383e-02, -7.23120803e-03,  1.09616853e-02,
#          6.93267584e-03,  1.65758226e-02, -8.95091856e-04,
#          1.36601590e-02,  1.71544345e-03, -6.28255401e-03,
#         -2.84874607e-02],
#        [ 3.35246138e-02, -6.76682405e-03,  1.02818692e-02,
#          7.13171111e-03,  1.65404957e-02, -8.33678932e-04,
#          1.37796151e-02,  1.38693373e-03, -6.09182473e-03,
#         -2.87047010e-02],
#        [ 3.37206200e-02, -6.17475063e-03,  9.76385828e-03,
#          7.26846373e-03,  1.64015424e-02, -7.47575657e-04,
#          1.38603402e-02,  1.17455143e-03, -6.05171965e-03,
#         -2.89671216e-02],
#        [ 3.37153301e-02, -5.57578960e-03,  9.36043542e-03,
#          7.33028259e-03,  1.63653549e-02, -6.56490913e-04,
#          1.39077390e-02,  1.09073101e-03, -6.15823455e-03,
#         -2.91311517e-02],
#        [ 3.36031839e-02, -5.02974354e-03,  9.16777924e-03,
#          7.30608962e-03,  1.63662881e-02, -5.73740981e-04,
#          1.38828484e-02,  1.12209062e-03, -6.36733649e-03,
#         -2.92292535e-02],
#        [ 3.33827212e-02, -4.54422319e-03,  9.08527803e-03,
#          7.29148882e-03,  1.63733810e-02, -4.76341491e-04,
#          1.38396751e-02,  1.22213177e-03, -6.60288846e-03,
#         -2.93230172e-02],
#        [ 3.31147946e-02, -4.13689064e-03,  9.05423239e-03,
#          7.24545959e-03,  1.64349154e-02, -3.39392980e-04,
#          1.37555692e-02,  1.40491198e-03, -6.88442029e-03,
#         -2.94009708e-02],
#        [ 3.28243934e-02, -3.80651280e-03,  9.10748262e-03,
#          7.19439890e-03,  1.65162701e-02, -2.07743185e-04,
#          1.36944819e-02,  1.50851230e-03, -7.15497555e-03,
#         -2.94281021e-02],
#        [ 3.25809531e-02, -3.58860637e-03,  9.15370136e-03,
#          7.13093020e-03,  1.65710170e-02, -4.26878869e-05,
#          1.36209792e-02,  1.64037093e-03, -7.39150960e-03,
#         -2.94269435e-02],
#        [ 3.23986970e-02, -3.43345827e-03,  9.19918623e-03,
#          7.02731311e-03,  1.65903680e-02,  1.86262696e-04,
#          1.35570318e-02,  1.79030304e-03, -7.59520801e-03,
#         -2.94722915e-02],
#        [ 3.22443321e-02, -3.34258121e-03,  9.22231469e-03,
#          7.00453296e-03,  1.66846029e-02,  3.78808356e-04,
#          1.34809557e-02,  1.85851101e-03, -7.78895989e-03,
#         -2.94943135e-02],
#        [ 3.21359597e-02, -3.28371208e-03,  9.21868626e-03,
#          6.96988916e-03,  1.67470351e-02,  5.52189711e-04,
#          1.34067256e-02,  1.95969245e-03, -7.92022515e-03,
#         -2.95380354e-02],
#        [ 3.20568271e-02, -3.25440825e-03,  9.24321730e-03,
#          6.91054249e-03,  1.68044157e-02,  7.43567653e-04,
#          1.33412303e-02,  2.05963012e-03, -8.08411278e-03,
#         -2.95727029e-02],
#        [ 3.19695100e-02, -3.23011167e-03,  9.24087968e-03,
#          6.83414610e-03,  1.68623682e-02,  9.71414440e-04,
#          1.32981492e-02,  2.16782233e-03, -8.23879614e-03,
#         -2.96271760e-02],
#        [ 3.19090448e-02, -3.17251147e-03,  9.20837466e-03,
#          6.79868087e-03,  1.68935917e-02,  1.18574023e-03,
#          1.32460361e-02,  2.25504744e-03, -8.37189239e-03,
#         -2.97039058e-02],
#        [ 3.18209343e-02, -3.14599765e-03,  9.21025872e-03,
#          6.81078807e-03,  1.69065818e-02,  1.35088956e-03,
#          1.32041685e-02,  2.32337625e-03, -8.48955475e-03,
#         -2.97432374e-02],
#        [ 3.17889675e-02, -3.09890183e-03,  9.20089427e-03,
#          6.83333678e-03,  1.69086009e-02,  1.48432131e-03,
#          1.31568331e-02,  2.37838528e-03, -8.62084981e-03,
#         -2.97833793e-02],
#        [ 3.17582972e-02, -3.06912046e-03,  9.18835867e-03,
#          6.82099862e-03,  1.69085953e-02,  1.61701499e-03,
#          1.31467422e-02,  2.44241301e-03, -8.76716617e-03,
#         -2.97979247e-02]], dtype=float32), 'fl.weight': array([[[-0.02752586,  0.01739426, -0.01794544, ...,  0.00749784,
#          -0.02634247, -0.03402574],
#         [ 0.01182714,  0.02689083,  0.0042193 , ...,  0.03037112,
#          -0.01378342, -0.01351198],
#         [ 0.00372674,  0.02622016, -0.01826431, ..., -0.01542972,
#          -0.02987968,  0.0220315 ],
#         ...,
#         [-0.02252588, -0.0271628 ,  0.02658011, ..., -0.0229846 ,
#          -0.01120214,  0.03471307],
#         [ 0.02601157, -0.03120781, -0.00138629, ...,  0.02796271,
#          -0.0080176 ,  0.0257907 ],
#         [ 0.01345574,  0.00576433, -0.00271803, ..., -0.02309403,
#          -0.00397127,  0.02604201]],

#        [[-0.02761857,  0.01730156, -0.01803815, ...,  0.00740513,
#          -0.02643518, -0.03411844],
#         [ 0.01169194,  0.02675563,  0.0040841 , ...,  0.03023592,
#          -0.01391861, -0.01364717],
#         [ 0.00398423,  0.02647766, -0.01800681, ..., -0.01517222,
#          -0.02962218,  0.022289  ],
#         ...,
#         [-0.0224168 , -0.02705372,  0.0266892 , ..., -0.02287552,
#          -0.01109305,  0.03482215],
#         [ 0.02593573, -0.03128365, -0.00146213, ...,  0.02788687,
#          -0.00809344,  0.02571486],
#         [ 0.01351358,  0.00582217, -0.00266019, ..., -0.02303619,
#          -0.00391343,  0.02609985]],

#        [[-0.02771439,  0.01720574, -0.01813397, ...,  0.00730931,
#          -0.026531  , -0.03421426],
#         [ 0.01149494,  0.02655863,  0.00388711, ...,  0.03003892,
#          -0.01411561, -0.01384417],
#         [ 0.00427262,  0.02676605, -0.01771842, ..., -0.01488384,
#          -0.02933379,  0.02257739],
#         ...,
#         [-0.02227744, -0.02691436,  0.02682855, ..., -0.02273616,
#          -0.0109537 ,  0.03496151],
#         [ 0.02585482, -0.03136456, -0.00154304, ...,  0.02780596,
#          -0.00817434,  0.02563395],
#         [ 0.01360574,  0.00591433, -0.00256803, ..., -0.02294403,
#          -0.00382127,  0.02619201]],

#        ...,

#        [[-0.02699167,  0.01792846, -0.01741125, ...,  0.00803203,
#          -0.02580827, -0.03349154],
#         [ 0.00995894,  0.02502263,  0.00235111, ...,  0.02850292,
#          -0.01565161, -0.01538017],
#         [ 0.00472721,  0.02722064, -0.01726383, ..., -0.01442925,
#          -0.0288792 ,  0.02303198],
#         ...,
#         [-0.02267469, -0.02731161,  0.0264313 , ..., -0.02313342,
#          -0.01135095,  0.03456426],
#         [ 0.02687196, -0.03034742, -0.00052589, ...,  0.02882311,
#          -0.0071572 ,  0.02665109],
#         [ 0.0140463 ,  0.00635489, -0.00212747, ..., -0.02250347,
#          -0.00338071,  0.02663256]],

#        [[-0.02697811,  0.01794202, -0.01739769, ...,  0.00804559,
#          -0.02579471, -0.03347798],
#         [ 0.00993896,  0.02500265,  0.00233113, ...,  0.02848295,
#          -0.01567159, -0.01540015],
#         [ 0.00473118,  0.02722461, -0.01725986, ..., -0.01442527,
#          -0.02887523,  0.02303595],
#         ...,
#         [-0.02269803, -0.02733495,  0.02640797, ..., -0.02315675,
#          -0.01137428,  0.03454093],
#         [ 0.02692766, -0.03029172, -0.00047019, ...,  0.0288788 ,
#          -0.0071015 ,  0.02670679],
#         [ 0.01406333,  0.00637191, -0.00211044, ..., -0.02248644,
#          -0.00336368,  0.02664959]],

#        [[-0.0269651 ,  0.01795503, -0.01738468, ...,  0.00805861,
#          -0.0257817 , -0.03346497],
#         [ 0.00992633,  0.02499002,  0.00231849, ...,  0.02847031,
#          -0.01568423, -0.01541279],
#         [ 0.0047365 ,  0.02722993, -0.01725454, ..., -0.01441996,
#          -0.02886991,  0.02304127],
#         ...,
#         [-0.02272519, -0.02736211,  0.02638081, ..., -0.02318391,
#          -0.01140144,  0.03451377],
#         [ 0.02698973, -0.03022965, -0.00040812, ...,  0.02894087,
#          -0.00703943,  0.02676886],
#         [ 0.0140695 ,  0.00637808, -0.00210427, ..., -0.02248027,
#          -0.00335751,  0.02665576]]], dtype=float32)}