In [1]:
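# Test notebook for train.py: trains a one-layer MNIST model with fixed seeds on CPU and compares the printed results against the expected outputs recorded in the commented "Expected" cells.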

In [2]:
from train import train

In [3]:
from __future__ import print_function

import os

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [4]:
# random seeds: torch's global RNG drives the weight init and the DataLoader
# shuffling below; numpy's global RNG is seeded as well
TORCH_SEED, NUMPY_SEED = 665, 326
torch.manual_seed(TORCH_SEED)
np.random.seed(NUMPY_SEED)

In [5]:
# simple model
# A single fully connected layer from the flattened image to class
# log-probabilities (multinomial logistic regression); no hidden layers.
class Classic_MNIST_Net(nn.Module):
    """Single-layer linear classifier.

    Args:
        in_feats: number of input features after flattening each sample;
            defaults to 28*28, one MNIST image.
        n_classes: number of output classes; defaults to 10.

    ``forward(x)`` flattens ``x`` to ``(-1, in_feats)`` and returns
    log-softmax scores of shape ``(batch, n_classes)``.
    """

    def __init__(self, in_feats=28 * 28, n_classes=10):
        super(Classic_MNIST_Net, self).__init__()
        self.in_feats = in_feats
        # one linear layer; keep the attribute name `fl` so parameter names
        # ('fl.weight', 'fl.bias') stay the same in saved histories
        self.fl = nn.Linear(in_feats, n_classes)

    def forward(self, x):
        # squash each image to a flat vector; reshape (unlike view) also
        # accepts non-contiguous inputs
        x = x.reshape(-1, self.in_feats)
        x = self.fl(x)
        # log-probabilities, which is what F.nll_loss expects
        return F.log_softmax(x, dim=-1)

In [6]:
# Device: CPU only, so runs are reproducible (GPU kernels may be nondeterministic)
device = torch.device("cpu")

In [7]:
# define model: build the one-layer network, then place it on `device`
model = Classic_MNIST_Net()
model = model.to(device)

In [8]:
# dataloader declaration
# Dataset root: override with the MNIST_DATA_DIR environment variable; the
# default is the original author's local copy. download=False, so the data
# must already exist there.
MNIST_DATA_DIR = os.environ.get(
    'MNIST_DATA_DIR', '/home/felix/Research/Adversarial Research/MNIST-dataset')
batch_size = 10000
mnist_transform = transforms.Compose([
    transforms.ToTensor(),
    # presumably the MNIST training-set pixel mean / std
    transforms.Normalize((0.1307,), (0.3081,)),
])
# shuffle=True draws from torch's global RNG, seeded earlier in the notebook
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST(MNIST_DATA_DIR, train=True, download=False,
                   transform=mnist_transform),
    batch_size=batch_size, shuffle=True)

In [9]:
# loss function
def loss_func(model, output, target):
    """Negative log-likelihood of `target` under the log-probabilities `output`.

    `model` is not used; it is accepted only to fit the loss-callback
    signature that train() is called with.
    """
    return F.nll_loss(input=output, target=target)

In [10]:
# optimizer: SGD with momentum over every parameter of the model
optimizer = optim.SGD(
    model.parameters(),
    lr=0.01,
    momentum=0.5,
)

In [11]:
# number of training epochs handed to train()
epochs = 3

In [12]:
# number of correct pred function
def pred_func(output, target):
    """Return how many rows of `output` predict their `target` class.

    The prediction for each row is its arg-max along dim 1 (ties resolve to
    the first maximal index); the match count is returned as a Python int.
    """
    predicted = output.argmax(dim=1)
    matches = predicted.eq(target.view_as(predicted))
    return matches.sum().item()

In [13]:
### test 1: full training run with verbose logging and per-parameter histories
res = train(
    model, device, train_loader, loss_func, optimizer, epochs,
    save_hist=2,
    verbose=True,
    pred_func=pred_func,
)

Train set: Average loss: 2.0711
Train set: Accuracy: 17386/60000 (29%)
Train set: Average loss: 1.3446
Train set: Accuracy: 39720/60000 (66%)
Train set: Average loss: 1.0228
Train set: Accuracy: 45769/60000 (76%)


In [14]:
# # Expected

# Train set: Average loss: 2.0711
# Train set: Accuracy: 17386/60000 (29%)
# Train set: Average loss: 1.3446
# Train set: Accuracy: 39720/60000 (66%)
# Train set: Average loss: 1.0228
# Train set: Accuracy: 45769/60000 (76%)


In [15]:
print(len(res))
# train() is expected to return (loss history, accuracy history, parameter histories)
assert len(res) == 3, len(res)

3


In [16]:
train_loss_hist = res[0]
print(len(train_loss_hist))
# one average-loss entry per epoch is expected (3 recorded for epochs = 3)
assert len(train_loss_hist) == epochs, len(train_loss_hist)

3


In [17]:
train_acc_hist = res[1]
print(len(train_acc_hist))
# one accuracy entry per epoch is expected (3 recorded for epochs = 3)
assert len(train_acc_hist) == epochs, len(train_acc_hist)

3


In [18]:
histories = res[2]
for k in histories:
    print(k)
    print(np.shape(histories[k]))

# Each history should be a stack of snapshots of one model parameter:
# same keys as the model's parameters, one shared snapshot count, and
# trailing dims equal to the parameter's own shape.
param_shapes = {name: tuple(p.shape) for name, p in model.named_parameters()}
assert set(histories) == set(param_shapes), sorted(histories)
snapshot_counts = {np.shape(h)[0] for h in histories.values()}
assert len(snapshot_counts) == 1, snapshot_counts
for k, h in histories.items():
    assert tuple(np.shape(h)[1:]) == param_shapes[k], (k, np.shape(h))

fl.bias
(19, 10)
fl.weight
(19, 10, 784)


In [19]:
# full dump of the parameter histories, compared by eye with the expected cell below
print("%s" % (histories,))

{'fl.bias': array([[ 3.32987383e-02, -7.23120803e-03,  1.09616853e-02,
         6.93267584e-03,  1.65758226e-02, -8.95091798e-04,
         1.36601590e-02,  1.71544414e-03, -6.28255401e-03,
        -2.84874607e-02],
       [ 3.32987383e-02, -7.23120803e-03,  1.09616853e-02,
         6.93267584e-03,  1.65758226e-02, -8.95091798e-04,
         1.36601590e-02,  1.71544414e-03, -6.28255401e-03,
        -2.84874607e-02],
       [ 3.35246138e-02, -6.76682405e-03,  1.02818683e-02,
         7.13171065e-03,  1.65404957e-02, -8.33678932e-04,
         1.37796151e-02,  1.38693477e-03, -6.09182427e-03,
        -2.87047010e-02],
       [ 3.37206200e-02, -6.17475109e-03,  9.76385735e-03,
         7.26846326e-03,  1.64015424e-02, -7.47575599e-04,
         1.38603402e-02,  1.17455260e-03, -6.05171919e-03,
        -2.89671216e-02],
       [ 3.37153301e-02, -5.57578960e-03,  9.36043449e-03,
         7.33028213e-03,  1.63653549e-02, -6.56490680e-04,
         1.39077390e-02,  1.09073217e-03, -6.15823409e-03,

In [20]:
# # Expected 

# {'fl.bias': array([[ 3.32987383e-02, -7.23120803e-03,  1.09616853e-02,
#          6.93267584e-03,  1.65758226e-02, -8.95091798e-04,
#          1.36601590e-02,  1.71544414e-03, -6.28255401e-03,
#         -2.84874607e-02],
#        [ 3.32987383e-02, -7.23120803e-03,  1.09616853e-02,
#          6.93267584e-03,  1.65758226e-02, -8.95091798e-04,
#          1.36601590e-02,  1.71544414e-03, -6.28255401e-03,
#         -2.84874607e-02],
#        [ 3.35246138e-02, -6.76682405e-03,  1.02818683e-02,
#          7.13171065e-03,  1.65404957e-02, -8.33678932e-04,
#          1.37796151e-02,  1.38693477e-03, -6.09182427e-03,
#         -2.87047010e-02],
#        [ 3.37206200e-02, -6.17475109e-03,  9.76385735e-03,
#          7.26846326e-03,  1.64015424e-02, -7.47575599e-04,
#          1.38603402e-02,  1.17455260e-03, -6.05171919e-03,
#         -2.89671216e-02],
#        [ 3.37153301e-02, -5.57578960e-03,  9.36043449e-03,
#          7.33028213e-03,  1.63653549e-02, -6.56490680e-04,
#          1.39077390e-02,  1.09073217e-03, -6.15823409e-03,
#         -2.91311517e-02],
#        [ 3.36031877e-02, -5.02974400e-03,  9.16777831e-03,
#          7.30608916e-03,  1.63662881e-02, -5.73740632e-04,
#          1.38828484e-02,  1.12209178e-03, -6.36733603e-03,
#         -2.92292554e-02],
#        [ 3.33827250e-02, -4.54422366e-03,  9.08527710e-03,
#          7.29148835e-03,  1.63733810e-02, -4.76341142e-04,
#          1.38396751e-02,  1.22213317e-03, -6.60288846e-03,
#         -2.93230191e-02],
#        [ 3.30995917e-02, -4.17986605e-03,  9.11149476e-03,
#          7.25759333e-03,  1.64205637e-02, -3.63420433e-04,
#          1.37507198e-02,  1.34955288e-03, -6.83148578e-03,
#         -2.93665361e-02],
#        [ 3.28034498e-02, -3.89976054e-03,  9.18810349e-03,
#          7.19734840e-03,  1.65005494e-02, -1.97555535e-04,
#          1.36497617e-02,  1.48659886e-03, -7.10712094e-03,
#         -2.93731689e-02],
#        [ 3.25609073e-02, -3.69920279e-03,  9.20961797e-03,
#          7.14928238e-03,  1.66032892e-02,  1.69511384e-06,
#          1.35825314e-02,  1.61727204e-03, -7.36445421e-03,
#         -2.94127315e-02],
#        [ 3.23976353e-02, -3.54894344e-03,  9.22704116e-03,
#          7.08612800e-03,  1.67069882e-02,  1.69297753e-04,
#          1.34913605e-02,  1.75150228e-03, -7.60145951e-03,
#         -2.94313449e-02],
#        [ 3.22500393e-02, -3.38843977e-03,  9.24094114e-03,
#          7.01091625e-03,  1.67447850e-02,  3.70958151e-04,
#          1.34350937e-02,  1.88395521e-03, -7.79730687e-03,
#         -2.95027383e-02],
#        [ 3.21539268e-02, -3.24989366e-03,  9.20390896e-03,
#          6.93051750e-03,  1.67500861e-02,  5.91045828e-04,
#          1.34097682e-02,  1.99336372e-03, -7.96716753e-03,
#         -2.95673534e-02],
#        [ 3.20581198e-02, -3.15387081e-03,  9.19702463e-03,
#          6.88329339e-03,  1.67654864e-02,  7.78601388e-04,
#          1.34023214e-02,  2.06977408e-03, -8.14949162e-03,
#         -2.96030547e-02],
#        [ 3.19693238e-02, -3.10127437e-03,  9.19494405e-03,
#          6.83795055e-03,  1.67612601e-02,  9.82798636e-04,
#          1.33397384e-02,  2.13854923e-03, -8.24747980e-03,
#         -2.96276063e-02],
#        [ 3.18851210e-02, -3.09206359e-03,  9.22248047e-03,
#          6.79697888e-03,  1.67632680e-02,  1.19806919e-03,
#          1.32813193e-02,  2.23172572e-03, -8.36262945e-03,
#         -2.96760667e-02],
#        [ 3.18162255e-02, -3.09924828e-03,  9.24357586e-03,
#          6.79276092e-03,  1.68183465e-02,  1.38163066e-03,
#          1.32113155e-02,  2.32186145e-03, -8.49305093e-03,
#         -2.97452156e-02],
#        [ 3.17836069e-02, -3.10719362e-03,  9.22747143e-03,
#          6.79307850e-03,  1.68920662e-02,  1.52541790e-03,
#          1.31858150e-02,  2.39731511e-03, -8.66772234e-03,
#         -2.97816563e-02],
#        [ 3.17559876e-02, -3.08256340e-03,  9.19474382e-03,
#          6.78084092e-03,  1.69516094e-02,  1.64731813e-03,
#          1.31553337e-02,  2.47736787e-03, -8.79980810e-03,
#         -2.98326313e-02]], dtype=float32), 'fl.weight': array([[[-0.02761857,  0.01730156, -0.01803815, ...,  0.00740513,
#          -0.02643518, -0.03411844],
#         [ 0.01169194,  0.02675563,  0.0040841 , ...,  0.03023592,
#          -0.01391861, -0.01364717],
#         [ 0.00398423,  0.02647766, -0.01800681, ..., -0.01517222,
#          -0.02962218,  0.022289  ],
#         ...,
#         [-0.0224168 , -0.02705372,  0.0266892 , ..., -0.02287552,
#          -0.01109305,  0.03482215],
#         [ 0.02593573, -0.03128365, -0.00146213, ...,  0.02788687,
#          -0.00809344,  0.02571486],
#         [ 0.01351358,  0.00582217, -0.00266019, ..., -0.02303619,
#          -0.00391343,  0.02609985]],

#        [[-0.02761857,  0.01730156, -0.01803815, ...,  0.00740513,
#          -0.02643518, -0.03411844],
#         [ 0.01169194,  0.02675563,  0.0040841 , ...,  0.03023592,
#          -0.01391861, -0.01364717],
#         [ 0.00398423,  0.02647766, -0.01800681, ..., -0.01517222,
#          -0.02962218,  0.022289  ],
#         ...,
#         [-0.0224168 , -0.02705372,  0.0266892 , ..., -0.02287552,
#          -0.01109305,  0.03482215],
#         [ 0.02593573, -0.03128365, -0.00146213, ...,  0.02788687,
#          -0.00809344,  0.02571486],
#         [ 0.01351358,  0.00582217, -0.00266019, ..., -0.02303619,
#          -0.00391343,  0.02609985]],

#        [[-0.02771439,  0.01720574, -0.01813397, ...,  0.00730931,
#          -0.026531  , -0.03421426],
#         [ 0.01149494,  0.02655863,  0.00388711, ...,  0.03003892,
#          -0.01411561, -0.01384417],
#         [ 0.00427262,  0.02676605, -0.01771842, ..., -0.01488384,
#          -0.02933379,  0.02257739],
#         ...,
#         [-0.02227744, -0.02691436,  0.02682855, ..., -0.02273616,
#          -0.0109537 ,  0.03496151],
#         [ 0.02585482, -0.03136456, -0.00154304, ...,  0.02780596,
#          -0.00817434,  0.02563395],
#         [ 0.01360574,  0.00591433, -0.00256803, ..., -0.02294403,
#          -0.00382127,  0.02619201]],

#        ...,

#        [[-0.02698967,  0.01793046, -0.01740925, ...,  0.00803403,
#          -0.02580627, -0.03348954],
#         [ 0.00993911,  0.0250028 ,  0.00233127, ...,  0.02848309,
#          -0.01567145, -0.0154    ],
#         [ 0.00471308,  0.0272065 , -0.01727797, ..., -0.01444338,
#          -0.02889334,  0.02301785],
#         ...,
#         [-0.02267405, -0.02731097,  0.02643195, ..., -0.02313277,
#          -0.0113503 ,  0.0345649 ],
#         [ 0.02687345, -0.03034594, -0.00052441, ...,  0.02882459,
#          -0.00715571,  0.02665257],
#         [ 0.01404714,  0.00635573, -0.00212663, ..., -0.02250263,
#          -0.00337987,  0.0266334 ]],

#        [[-0.02697583,  0.0179443 , -0.01739541, ...,  0.00804787,
#          -0.02579243, -0.03347571],
#         [ 0.00994248,  0.02500617,  0.00233464, ...,  0.02848646,
#          -0.01566808, -0.01539663],
#         [ 0.00471991,  0.02721334, -0.01727113, ..., -0.01443655,
#          -0.0288865 ,  0.02302468],
#         ...,
#         [-0.02270606, -0.02734298,  0.02639994, ..., -0.02316478,
#          -0.01138231,  0.0345329 ],
#         [ 0.02694754, -0.03027184, -0.00045031, ...,  0.02889869,
#          -0.00708162,  0.02672667],
#         [ 0.0140626 ,  0.00637118, -0.00211118, ..., -0.02248717,
#          -0.00336442,  0.02664886]],

#        [[-0.02696411,  0.01795601, -0.01738369, ...,  0.00805958,
#          -0.02578072, -0.03346399],
#         [ 0.00993203,  0.02499572,  0.00232419, ...,  0.02847601,
#          -0.01567852, -0.01540708],
#         [ 0.00473379,  0.02722722, -0.01725725, ..., -0.01442267,
#          -0.02887262,  0.02303856],
#         ...,
#         [-0.02274002, -0.02737693,  0.02636598, ..., -0.02319874,
#          -0.01141627,  0.03449894],
#         [ 0.02700358, -0.03021581, -0.00039428, ...,  0.02895472,
#          -0.00702558,  0.0267827 ],
#         [ 0.01408422,  0.00639281, -0.00208955, ..., -0.02246555,
#          -0.00334279,  0.02667048]]], dtype=float32)}