In [None]:
'''
Change your code in such a way that all of these are in their respective files:
model
training code
testing code
regularization techniques (dropout, L1, L2, etc)
dataloader/transformations/image-augmentations
misc items like finding misclassified images
'''

In [1]:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

from torchsummary import summary

# Let's visualize some of the images
#%matplotlib inline
import matplotlib.pyplot as plt
import argparse

from torch.optim.lr_scheduler import StepLR,OneCycleLR

from train import *
from test import *
from model import *
from plotter import *
from data import *
# from model_group_norm import *
# from model_layer_norm import *

#from parser_args import norm, epochs

In [2]:
norm='bn'
epochs=25

In [3]:
SEED = 1

# CUDA?
cuda = torch.cuda.is_available()
print("CUDA Available?", cuda)

# For reproducibility
torch.manual_seed(SEED)

if cuda:
    torch.cuda.manual_seed(SEED)



# train dataloader
train_loader = load_train()

# test dataloader
test_loader = load_test()

CUDA Available? True


In [4]:
# Printing the summary of the model
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print("Using: ",device)

if norm == 'bn':
	print("Loading Batchnorm Model")
	model = Net().to(device)
elif norm == 'group':
	print("Loading Group Model")
	model = Net_group_norm().to(device)

elif norm == 'layer':
	print("Loading layer Model")
	model = Net_layer_norm().to(device)

#model = Net().to(device)
#model = Net_group_norm().to(device)
#model = Net_layer_norm().to(device)

model.apply(weights_init)
summary(model, input_size=(1, 28, 28))


Using:  cuda
Loading Batchnorm Model
----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
            Conv2d-1            [-1, 8, 26, 26]              72
              ReLU-2            [-1, 8, 26, 26]               0
           Dropout-3            [-1, 8, 26, 26]               0
            Conv2d-4           [-1, 16, 24, 24]           1,152
              ReLU-5           [-1, 16, 24, 24]               0
       BatchNorm2d-6           [-1, 16, 24, 24]              32
           Dropout-7           [-1, 16, 24, 24]               0
         MaxPool2d-8           [-1, 16, 12, 12]               0
            Conv2d-9            [-1, 8, 12, 12]             128
           Conv2d-10           [-1, 10, 10, 10]             720
             ReLU-11           [-1, 10, 10, 10]               0
      BatchNorm2d-12           [-1, 10, 10, 10]              20
          Dropout-13           [-1, 10, 10, 10]               0
  

In [5]:
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

scheduler = OneCycleLR(optimizer, max_lr=0.020,epochs=epochs,steps_per_epoch=len(train_loader))


for epoch in range(epochs):
    print("EPOCH:", epoch+1)
    bn_train_losses, bn_train_acc = train(model, device, train_loader, optimizer, epoch)
    bn_test_losses, bn_test_acc = test(model, device, test_loader)

  0%|          | 0/469 [00:00<?, ?it/s]

EPOCH: 1


Loss=13.369365692138672 Batch_id=0 Accuracy=9.38:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.5878, Accuracy: 985/10000 (9.85%)

EPOCH: 2


Loss=13.329697608947754 Batch_id=0 Accuracy=4.69:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4938, Accuracy: 936/10000 (9.36%)

EPOCH: 3


Loss=13.102058410644531 Batch_id=0 Accuracy=12.50:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4394, Accuracy: 945/10000 (9.45%)

EPOCH: 4


Loss=13.30330753326416 Batch_id=0 Accuracy=8.59:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4073, Accuracy: 984/10000 (9.84%)

EPOCH: 5


Loss=13.305551528930664 Batch_id=0 Accuracy=3.91:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3876, Accuracy: 854/10000 (8.54%)

EPOCH: 6


Loss=13.123869895935059 Batch_id=0 Accuracy=8.59:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3776, Accuracy: 732/10000 (7.32%)

EPOCH: 7


Loss=13.180496215820312 Batch_id=0 Accuracy=7.81:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3726, Accuracy: 653/10000 (6.53%)

EPOCH: 8


Loss=13.053777694702148 Batch_id=0 Accuracy=13.28:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3716, Accuracy: 626/10000 (6.26%)

EPOCH: 9


Loss=13.52599811553955 Batch_id=0 Accuracy=9.38:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3728, Accuracy: 600/10000 (6.00%)

EPOCH: 10


Loss=13.415253639221191 Batch_id=0 Accuracy=7.81:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3756, Accuracy: 600/10000 (6.00%)

EPOCH: 11


Loss=13.293859481811523 Batch_id=0 Accuracy=8.59:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3803, Accuracy: 571/10000 (5.71%)

EPOCH: 12


Loss=13.03773021697998 Batch_id=0 Accuracy=9.38:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3859, Accuracy: 614/10000 (6.14%)

EPOCH: 13


Loss=13.047112464904785 Batch_id=0 Accuracy=11.72:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.3932, Accuracy: 662/10000 (6.62%)

EPOCH: 14


Loss=13.297542572021484 Batch_id=0 Accuracy=8.59:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4011, Accuracy: 667/10000 (6.67%)

EPOCH: 15


Loss=13.302099227905273 Batch_id=0 Accuracy=2.34:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4093, Accuracy: 697/10000 (6.97%)

EPOCH: 16


Loss=12.962848663330078 Batch_id=0 Accuracy=7.81:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4176, Accuracy: 713/10000 (7.13%)

EPOCH: 17


Loss=13.14380931854248 Batch_id=0 Accuracy=13.28:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4267, Accuracy: 760/10000 (7.60%)

EPOCH: 18


Loss=13.103028297424316 Batch_id=0 Accuracy=10.94:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4342, Accuracy: 789/10000 (7.89%)

EPOCH: 19


Loss=13.27099609375 Batch_id=0 Accuracy=9.38:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4449, Accuracy: 809/10000 (8.09%)

EPOCH: 20


Loss=13.125944137573242 Batch_id=0 Accuracy=7.03:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4553, Accuracy: 830/10000 (8.30%)

EPOCH: 21


Loss=13.291780471801758 Batch_id=0 Accuracy=10.16:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4652, Accuracy: 851/10000 (8.51%)

EPOCH: 22


Loss=13.27049446105957 Batch_id=0 Accuracy=10.16:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4774, Accuracy: 893/10000 (8.93%)

EPOCH: 23


Loss=13.178642272949219 Batch_id=0 Accuracy=11.72:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.4899, Accuracy: 906/10000 (9.06%)

EPOCH: 24


Loss=13.291738510131836 Batch_id=0 Accuracy=10.16:   0%|          | 0/469 [00:00<?, ?it/s]
  0%|          | 0/469 [00:00<?, ?it/s]


Test set: Average loss: 2.5010, Accuracy: 923/10000 (9.23%)

EPOCH: 25


Loss=13.105228424072266 Batch_id=0 Accuracy=9.38:   0%|          | 0/469 [00:00<?, ?it/s]



Test set: Average loss: 2.5147, Accuracy: 932/10000 (9.32%)

