In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torch.nn.functional as F

In [2]:
pre_processing = transforms.Compose([
#         transforms.RandomResizedCrop(28),
#         transforms.ColorJitter(brightness=0.5,contrast=0.5,saturation=0.3),
#         transforms.RandomHorizontalFlip(),
        transforms.ToTensor()
])

In [3]:
train_ds = torchvision.datasets.MNIST(
    root='../data',
    train=True,
    transform=pre_processing,
    download=True
)
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [4]:
test_ds = torchvision.datasets.MNIST(
    root='../data',
    train=False,
    transform=transforms.ToTensor()
)
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [5]:
train_dl = torch.utils.data.DataLoader(train_ds,32,True)
test_dl = torch.utils.data.DataLoader(test_ds,32,False)

In [6]:
input_size = train_ds[0][0].numel()

In [7]:
train_ds[0][0].size()

torch.Size([1, 28, 28])

In [8]:
def conv(ni,nf,ks=3,stride=1,bias=False):
    return nn.Conv2d(ni,nf,kernel_size=ks,stride=stride,padding=ks//2,bias=bias)

In [9]:
act_fn = nn.ReLU(inplace=True)

In [10]:
def conv_layer(ni,nf,ks=3,stride=1,zero_bn=False,act=True):
    bn = nn.BatchNorm2d(nf)
    nn.init.constant_(bn.weight,0. if zero_bn else 1.)
    layers = [conv(ni,nf,stride=stride),conv(nf,nf,stride=2),bn]
    if act: layers.append(act_fn)
    return nn.Sequential(*layers)

In [11]:
conv_layer(1,16)

Sequential(
  (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
  (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
  (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (3): ReLU(inplace)
)

In [12]:
class Model(nn.Module):
    def __init__(self,num_classes):
        super(Model,self).__init__()
        self.cnn1 = conv_layer(1,16) # 28/2=14
        self.cnn2 = conv_layer(16,32) # 14/2=8
        self.cnn3 = conv_layer(32,64) # 8/2=4
        self.cnn4 = conv_layer(64,128) # 4/2=2
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(128,num_classes)
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        
    def forward(self,x):
        x = self.cnn1(x)
#         x = self.dropout(x)
        x = self.cnn2(x)
#         x = self.dropout(x)
        x = self.cnn3(x)
#         x = self.dropout(x)
        x = self.cnn4(x)
#         x = self.dropout(x)
        x = self.pool(x)
        x = x.view(x.size(0),-1)
#         print(x.size())
        x = self.classifier(x)
        x = F.softmax(x,1)
        return x

In [13]:
model = Model(len(train_ds.classes))
model

Model(
  (cnn1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace)
  )
  (cnn2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): Conv2d(32, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace)
  )
  (cnn3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (1): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): ReLU(inplace)
  )
  (cnn4): Sequential(
    (0): Conv2d(64, 128, kernel_size

In [14]:
images = next(iter(train_dl))

In [15]:
model(images[0]);

In [16]:
# criterion = nn.MSELoss()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=1e-3)

In [17]:
from tqdm import tqdm_notebook as tqdm

In [18]:
def one_epoch_man(model,train_dl,test_dl):
    model = model.cuda()
    for i,(images,labels) in enumerate(tqdm(train_dl)):
        model.train()
        images,labels = images.cuda(),labels.cuda()
        optimizer.zero_grad()

        preds = model(images)

        loss = criterion(preds,labels)

        loss.backward()
        optimizer.step()
        if(i%500==0):
            correct = 0
            total = 0
            print(f'loss: {loss.item()}')
    for img,y in test_dl:
        model.eval()
        img,y = img.cuda(),y.cuda()
        preds = model(img)
        preds = preds.argmax(1)
        total += preds.size(0)
        correct += (preds==y).sum()
    print(f'{correct} / {total}')

In [None]:
for i in range(10):
    one_epoch_man(model,train_dl,test_dl)

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 2.307800769805908
loss: 1.4669852256774902
loss: 1.52872896194458
loss: 1.5193616151809692

9793 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.5167657136917114
loss: 1.5006762742996216
loss: 1.5315839052200317
loss: 1.5372432470321655

9853 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.4878425598144531
loss: 1.4626083374023438
loss: 1.4801675081253052
loss: 1.497884750366211

9848 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.5096697807312012
loss: 1.4635233879089355
loss: 1.4831551313400269
loss: 1.527153491973877

9865 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.4822602272033691
loss: 1.5363680124282837
loss: 1.4924169778823853
loss: 1.498084545135498

9854 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.5010157823562622
loss: 1.4611588716506958
loss: 1.465572714805603
loss: 1.4958045482635498

9919 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.4611512422561646
loss: 1.4629374742507935
loss: 1.4924473762512207
loss: 1.4614546298980713

9895 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 1.5046603679656982
loss: 1.4611519575119019
loss: 1.4746084213256836
