In [1]:
import numpy as np 
import torch
from torch.autograd import Variable
import torch.nn as nn 
import torch.nn.functional as F 
import torchvision 
import torch.optim as optim 
from torchvision import transforms
from tqdm import *

In [2]:
# DEFINE NETWORK
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        # my network is composed of only affine layers 
        self.fc1 = nn.Linear(28*28, 300)
        self.fc2 = nn.Linear(300, 100)
        self.fc3 = nn.Linear(100, 10)   
        # self.r   = nn.Parameter(data=torch.randn(5,5), requires_grad=True)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [3]:
net = Net()
print(net)

Net (
  (fc1): Linear (784 -> 300)
  (fc2): Linear (300 -> 100)
  (fc3): Linear (100 -> 10)
)


In [7]:
SoftmaxWithXent = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-04)

In [4]:
# DATA LOADERS 
def flat_trans(x):
    x.resize_(28*28)
    return x
mnist_transform = transforms.Compose(
                    [transforms.ToTensor(), transforms.Lambda(flat_trans)]
                  )
traindata = torchvision.datasets.MNIST(root="./mnist", train=True, download=True, transform=mnist_transform)
trainloader = torch.utils.data.DataLoader(traindata, batch_size=256, shuffle=True, num_workers=2)
testdata  = torchvision.datasets.MNIST(root="./mnist", train=False, download=True, transform=mnist_transform)
testloader = torch.utils.data.DataLoader(testdata, batch_size=256, shuffle=True, num_workers=2)

In [8]:
# TRAIN 
for epoch in range(100):

    print("Epoch: {}".format(epoch))
    running_loss = 0.0 
    # import ipdb; ipdb.set_trace()
    for data in tqdm(trainloader):
        
        # get the inputs 
        inputs, labels = data 
        # wrap them in a variable 
        inputs, labels = Variable(inputs), Variable(labels)
        # zero the gradients 
        optimizer.zero_grad() 
        
        # forward + loss + backward 
        outputs = net(inputs) # forward pass 
        loss = SoftmaxWithXent(outputs, labels) # compute softmax -> loss 
        loss.backward() # get gradients on params 
        optimizer.step() # SGD update 

        # print statistics 
        running_loss += loss.data[0]

    print('Epoch: {} | Loss: {}'.format(epoch, running_loss/2000.0))

print("Finished Training")


  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 0


100%|██████████| 235/235 [00:02<00:00, 90.19it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 0 | Loss: 0.26510209834575654
Epoch: 1


100%|██████████| 235/235 [00:02<00:00, 89.16it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 1 | Loss: 0.24053700566291808
Epoch: 2


100%|██████████| 235/235 [00:02<00:00, 88.27it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 2 | Loss: 0.1729998619556427
Epoch: 3


100%|██████████| 235/235 [00:02<00:00, 88.38it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 3 | Loss: 0.10580456021428108
Epoch: 4


100%|██████████| 235/235 [00:02<00:00, 86.85it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 4 | Loss: 0.07702752785384655
Epoch: 5


100%|██████████| 235/235 [00:02<00:00, 94.40it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 5 | Loss: 0.06401847395300865
Epoch: 6


100%|██████████| 235/235 [00:02<00:00, 91.00it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 6 | Loss: 0.05675997690856457
Epoch: 7


100%|██████████| 235/235 [00:02<00:00, 89.04it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 7 | Loss: 0.05220143035054207
Epoch: 8


100%|██████████| 235/235 [00:02<00:00, 89.12it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 8 | Loss: 0.04889440377056599
Epoch: 9


100%|██████████| 235/235 [00:02<00:00, 84.31it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 9 | Loss: 0.04650342792272568
Epoch: 10


100%|██████████| 235/235 [00:02<00:00, 84.01it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 10 | Loss: 0.04453187927603722
Epoch: 11


100%|██████████| 235/235 [00:02<00:00, 82.95it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 11 | Loss: 0.04299124870449304
Epoch: 12


100%|██████████| 235/235 [00:02<00:00, 90.99it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 12 | Loss: 0.041677878201007844
Epoch: 13


100%|██████████| 235/235 [00:02<00:00, 86.69it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 13 | Loss: 0.04054376235604286
Epoch: 14


100%|██████████| 235/235 [00:02<00:00, 90.25it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 14 | Loss: 0.039399667359888556
Epoch: 15


100%|██████████| 235/235 [00:02<00:00, 88.92it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 15 | Loss: 0.03847125098854303
Epoch: 16


100%|██████████| 235/235 [00:02<00:00, 88.16it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 16 | Loss: 0.0376138912960887
Epoch: 17


100%|██████████| 235/235 [00:02<00:00, 87.08it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 17 | Loss: 0.03678096740692854
Epoch: 18


100%|██████████| 235/235 [00:02<00:00, 86.05it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 18 | Loss: 0.03601983071118593
Epoch: 19


100%|██████████| 235/235 [00:02<00:00, 87.80it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 19 | Loss: 0.03529281847923994
Epoch: 20


100%|██████████| 235/235 [00:02<00:00, 89.25it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 20 | Loss: 0.0345717043876648
Epoch: 21


100%|██████████| 235/235 [00:02<00:00, 90.09it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 21 | Loss: 0.03394990995526314
Epoch: 22


100%|██████████| 235/235 [00:02<00:00, 87.91it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 22 | Loss: 0.03331883928179741
Epoch: 23


100%|██████████| 235/235 [00:02<00:00, 91.94it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 23 | Loss: 0.03271557410806417
Epoch: 24


100%|██████████| 235/235 [00:02<00:00, 91.26it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 24 | Loss: 0.03214496745914221
Epoch: 25


100%|██████████| 235/235 [00:02<00:00, 93.51it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 25 | Loss: 0.031612935788929465
Epoch: 26


100%|██████████| 235/235 [00:02<00:00, 92.32it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 26 | Loss: 0.03103170207887888
Epoch: 27


100%|██████████| 235/235 [00:02<00:00, 91.27it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 27 | Loss: 0.030534479841589926
Epoch: 28


100%|██████████| 235/235 [00:02<00:00, 92.56it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 28 | Loss: 0.030111917696893215
Epoch: 29


100%|██████████| 235/235 [00:02<00:00, 92.44it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 29 | Loss: 0.029537266813218593
Epoch: 30


100%|██████████| 235/235 [00:02<00:00, 81.28it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 30 | Loss: 0.029061766885221003
Epoch: 31


100%|██████████| 235/235 [00:02<00:00, 86.13it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 31 | Loss: 0.028611181169748305
Epoch: 32


100%|██████████| 235/235 [00:02<00:00, 83.13it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 32 | Loss: 0.02819912164658308
Epoch: 33


100%|██████████| 235/235 [00:02<00:00, 84.85it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 33 | Loss: 0.02770839723944664
Epoch: 34


100%|██████████| 235/235 [00:02<00:00, 89.16it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 34 | Loss: 0.027307927526533604
Epoch: 35


100%|██████████| 235/235 [00:02<00:00, 83.25it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 35 | Loss: 0.026880054093897342
Epoch: 36


100%|██████████| 235/235 [00:02<00:00, 85.93it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 36 | Loss: 0.026493503049015998
Epoch: 37


100%|██████████| 235/235 [00:02<00:00, 86.02it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 37 | Loss: 0.026096525978296994
Epoch: 38


100%|██████████| 235/235 [00:02<00:00, 89.16it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 38 | Loss: 0.025685831509530545
Epoch: 39


100%|██████████| 235/235 [00:02<00:00, 86.99it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 39 | Loss: 0.025285733807832004
Epoch: 40


100%|██████████| 235/235 [00:02<00:00, 83.21it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 40 | Loss: 0.02490786326676607
Epoch: 41


100%|██████████| 235/235 [00:03<00:00, 78.25it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 41 | Loss: 0.02457058247178793
Epoch: 42


100%|██████████| 235/235 [00:02<00:00, 85.74it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 42 | Loss: 0.0242060499638319
Epoch: 43


100%|██████████| 235/235 [00:02<00:00, 85.60it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 43 | Loss: 0.023886387936770916
Epoch: 44


100%|██████████| 235/235 [00:02<00:00, 83.56it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 44 | Loss: 0.02353991438448429
Epoch: 45


100%|██████████| 235/235 [00:02<00:00, 92.18it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 45 | Loss: 0.02318738928064704
Epoch: 46


100%|██████████| 235/235 [00:02<00:00, 93.73it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 46 | Loss: 0.022866623654961588
Epoch: 47


100%|██████████| 235/235 [00:02<00:00, 94.16it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 47 | Loss: 0.022589462604373695
Epoch: 48


100%|██████████| 235/235 [00:02<00:00, 92.28it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 48 | Loss: 0.022265738792717458
Epoch: 49


100%|██████████| 235/235 [00:02<00:00, 90.84it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 49 | Loss: 0.022013390723615883
Epoch: 50


100%|██████████| 235/235 [00:02<00:00, 82.24it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 50 | Loss: 0.021684760563075542
Epoch: 51


100%|██████████| 235/235 [00:02<00:00, 91.29it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 51 | Loss: 0.021391395904123783
Epoch: 52


100%|██████████| 235/235 [00:02<00:00, 91.66it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 52 | Loss: 0.021101510643959046
Epoch: 53


100%|██████████| 235/235 [00:02<00:00, 91.08it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 53 | Loss: 0.020797466799616814
Epoch: 54


100%|██████████| 235/235 [00:02<00:00, 92.60it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 54 | Loss: 0.020576763160526753
Epoch: 55


100%|██████████| 235/235 [00:02<00:00, 91.58it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 55 | Loss: 0.020267052918672562
Epoch: 56


100%|██████████| 235/235 [00:02<00:00, 91.93it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 56 | Loss: 0.02000119899213314
Epoch: 57


100%|██████████| 235/235 [00:02<00:00, 92.36it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 57 | Loss: 0.01974423221498728
Epoch: 58


100%|██████████| 235/235 [00:02<00:00, 91.10it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 58 | Loss: 0.019510283313691618
Epoch: 59


100%|██████████| 235/235 [00:02<00:00, 90.74it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 59 | Loss: 0.01928082912415266
Epoch: 60


100%|██████████| 235/235 [00:02<00:00, 90.72it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 60 | Loss: 0.019033031221479178
Epoch: 61


100%|██████████| 235/235 [00:02<00:00, 89.96it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 61 | Loss: 0.01878956948593259
Epoch: 62


100%|██████████| 235/235 [00:02<00:00, 85.72it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 62 | Loss: 0.01856284584477544
Epoch: 63


100%|██████████| 235/235 [00:02<00:00, 79.55it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 63 | Loss: 0.01834191329777241
Epoch: 64


100%|██████████| 235/235 [00:02<00:00, 89.81it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 64 | Loss: 0.01811850141361356
Epoch: 65


100%|██████████| 235/235 [00:02<00:00, 92.33it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 65 | Loss: 0.017908441960811615
Epoch: 66


100%|██████████| 235/235 [00:02<00:00, 90.55it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 66 | Loss: 0.0176704741679132
Epoch: 67


100%|██████████| 235/235 [00:02<00:00, 91.26it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 67 | Loss: 0.017521689131855964
Epoch: 68


100%|██████████| 235/235 [00:02<00:00, 92.95it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 68 | Loss: 0.017243793427944184
Epoch: 69


100%|██████████| 235/235 [00:02<00:00, 85.52it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 69 | Loss: 0.01706268846616149
Epoch: 70


100%|██████████| 235/235 [00:02<00:00, 86.40it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 70 | Loss: 0.01686530910432339
Epoch: 71


100%|██████████| 235/235 [00:02<00:00, 87.19it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 71 | Loss: 0.016666311677545308
Epoch: 72


100%|██████████| 235/235 [00:02<00:00, 92.90it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 72 | Loss: 0.01647274700179696
Epoch: 73


100%|██████████| 235/235 [00:02<00:00, 89.44it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 73 | Loss: 0.01628171836771071
Epoch: 74


100%|██████████| 235/235 [00:02<00:00, 87.56it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 74 | Loss: 0.016079639868810773
Epoch: 75


100%|██████████| 235/235 [00:03<00:00, 77.75it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 75 | Loss: 0.015897211652249096
Epoch: 76


100%|██████████| 235/235 [00:02<00:00, 79.22it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 76 | Loss: 0.015712793178856373
Epoch: 77


100%|██████████| 235/235 [00:02<00:00, 91.24it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 77 | Loss: 0.015555827017873525
Epoch: 78


100%|██████████| 235/235 [00:03<00:00, 64.73it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 78 | Loss: 0.015382702829316258
Epoch: 79


100%|██████████| 235/235 [00:02<00:00, 85.75it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 79 | Loss: 0.015258693866431713
Epoch: 80


100%|██████████| 235/235 [00:02<00:00, 88.64it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 80 | Loss: 0.015066837957128883
Epoch: 81


100%|██████████| 235/235 [00:02<00:00, 86.06it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 81 | Loss: 0.014875676104798913
Epoch: 82


100%|██████████| 235/235 [00:02<00:00, 87.82it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 82 | Loss: 0.014742665207013488
Epoch: 83


100%|██████████| 235/235 [00:02<00:00, 90.03it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 83 | Loss: 0.01455264101549983
Epoch: 84


100%|██████████| 235/235 [00:02<00:00, 92.54it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 84 | Loss: 0.014405724147334695
Epoch: 85


100%|██████████| 235/235 [00:02<00:00, 93.56it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 85 | Loss: 0.01423355701379478
Epoch: 86


100%|██████████| 235/235 [00:02<00:00, 92.56it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 86 | Loss: 0.014113330921158195
Epoch: 87


100%|██████████| 235/235 [00:02<00:00, 89.00it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 87 | Loss: 0.013999794153496623
Epoch: 88


100%|██████████| 235/235 [00:02<00:00, 89.13it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 88 | Loss: 0.013806370962411165
Epoch: 89


100%|██████████| 235/235 [00:02<00:00, 86.60it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 89 | Loss: 0.01367637206055224
Epoch: 90


100%|██████████| 235/235 [00:02<00:00, 89.34it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 90 | Loss: 0.013544081555679441
Epoch: 91


100%|██████████| 235/235 [00:02<00:00, 93.48it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 91 | Loss: 0.01340012520365417
Epoch: 92


100%|██████████| 235/235 [00:02<00:00, 96.47it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 92 | Loss: 0.013241797314956784
Epoch: 93


100%|██████████| 235/235 [00:02<00:00, 87.03it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 93 | Loss: 0.013086641211993993
Epoch: 94


100%|██████████| 235/235 [00:02<00:00, 87.61it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 94 | Loss: 0.012960964800789952
Epoch: 95


100%|██████████| 235/235 [00:02<00:00, 90.16it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 95 | Loss: 0.01282974653504789
Epoch: 96


100%|██████████| 235/235 [00:02<00:00, 93.26it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 96 | Loss: 0.012720767145976425
Epoch: 97


100%|██████████| 235/235 [00:02<00:00, 91.84it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 97 | Loss: 0.012593369944021106
Epoch: 98


100%|██████████| 235/235 [00:02<00:00, 91.09it/s]
  0%|          | 0/235 [00:00<?, ?it/s]

Epoch: 98 | Loss: 0.012487378928810359
Epoch: 99


100%|██████████| 235/235 [00:02<00:00, 93.11it/s]

Epoch: 99 | Loss: 0.012336424391716718
Finished Training





In [10]:
# TEST 
correct = 0.0 
total = 0 
for data in testloader:
    images, labels = data 
    outputs = net(Variable(images))
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0) 
    correct += (predicted == labels).sum()

print("Accuracy: {}".format(correct/total))

print("Dumping weights to disk")
weights_dict = {} 
# import ipdb; ipdb.set_trace()
for param in list(net.named_parameters()):
    print("Serializing Param", param[0])
    weights_dict[param[0]] = param[1] 
with open("weights.pkl","wb") as f:
    import pickle 
    pickle.dump(weights_dict, f)
print("Finished dumping to disk..")


Accuracy: 0.9643
Dumping weights to disk
Serializing Param fc1.weight
Serializing Param fc1.bias
Serializing Param fc2.weight
Serializing Param fc2.bias
Serializing Param fc3.weight
Serializing Param fc3.bias
Finished dumping to disk..
