In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
# Training split of MNIST, stored under ../data (download=True asks torchvision
# to fetch the files into that folder). ToTensor() makes each image a torch tensor.
train_ds = torchvision.datasets.MNIST(
    '../data', train=True, download=True, transform=transforms.ToTensor()
)
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [3]:
# Held-out test split, read from the same ../data root. No download is requested
# here, so this presumably relies on the files already being on disk (e.g. fetched
# by the training-split cell) -- confirm on a fresh checkout.
test_ds = torchvision.datasets.MNIST(
    '../data', train=False, transform=transforms.ToTensor()
)
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [4]:
# Mini-batches of 32 samples; only the training loader shuffles its data.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [5]:
# Number of scalar values in the first training image tensor.
# NOTE(review): not referenced by any other cell visible in this notebook.
input_size = torch.numel(train_ds[0][0])

In [6]:
# Shape of the first training image tensor (index [0][0] = first sample, image part).
train_ds[0][0].shape

torch.Size([1, 28, 28])

In [17]:
class Model(nn.Module):
    """Small convolutional classifier for single-channel images.

    Four stride-2 3x3 convolutions, each followed by ReLU, are applied in
    sequence; the resulting feature map is global-average-pooled to one value
    per channel (32 values) and fed to a linear layer that produces
    ``num_classes`` raw scores.

    Args:
        num_classes: number of output scores produced per input sample.
    """

    def __init__(self, num_classes):
        super().__init__()
        # With kernel_size=3, stride=2, padding=1 a side of length n becomes
        # floor((n - 1) / 2) + 1. The sizes below are for a 28x28 input.
        conv_kwargs = dict(kernel_size=3, stride=2, padding=1)
        self.cnn1 = nn.Conv2d(1, 16, **conv_kwargs)   # 28 -> 14
        self.cnn2 = nn.Conv2d(16, 32, **conv_kwargs)  # 14 -> 7
        self.cnn3 = nn.Conv2d(32, 64, **conv_kwargs)  # 7 -> 4
        self.cnn4 = nn.Conv2d(64, 32, **conv_kwargs)  # 4 -> 2
        self.pool = nn.AdaptiveAvgPool2d(1)           # any HxW -> 1x1 per channel
        self.classifier = nn.Linear(32, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Map a batch of shape (N, 1, H, W) to scores of shape (N, num_classes).

        No softmax is applied; the output is the linear layer's raw result.
        """
        for conv in (self.cnn1, self.cnn2, self.cnn3, self.cnn4):
            x = self.relu(conv(x))
        pooled = self.pool(x)
        # Drop the trailing 1x1 spatial dims: (N, 32, 1, 1) -> (N, 32).
        features = pooled.view(pooled.size(0), -1)
        return self.classifier(features)

In [18]:
# One output score per class label listed by the training dataset.
model = Model(num_classes=len(train_ds.classes))
model

Model(
  (cnn1): Conv2d(1, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (cnn2): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (cnn3): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (cnn4): Conv2d(64, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (pool): AdaptiveAvgPool2d(output_size=1)
  (classifier): Linear(in_features=32, out_features=10, bias=True)
  (relu): ReLU()
)

In [19]:
# Pull a single mini-batch from the training loader.
# NOTE: despite its name, `images` holds the whole batch as yielded by train_dl
# (the training loop unpacks each batch as an (images, labels) pair), which is
# why the following cell indexes images[0] to get the image tensor.
images = next(iter(train_dl))

In [20]:
# Smoke test: one forward pass on the first element of the sampled batch to
# check that the layers fit together; the trailing ';' suppresses the cell output.
model(images[0]);

In [21]:
# Adam over every model parameter, and cross-entropy computed on the model's
# raw output scores. Both are notebook-level names read by one_epoch_man.
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [22]:
from tqdm import tqdm_notebook as tqdm

In [23]:
def _count_correct(model, data_dl, device):
    """Return ``(correct, total)`` top-1 prediction counts of ``model`` over ``data_dl``.

    The model is switched to eval mode for the pass and put back into whatever
    mode it was in before returning.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0
    # Inference only: skip autograd bookkeeping so no graph is built per batch.
    with torch.no_grad():
        for img, y in data_dl:
            img, y = img.to(device), y.to(device)
            predicted = model(img).argmax(1)
            total += predicted.size(0)
            correct += (predicted == y).sum().item()
    model.train(was_training)
    return correct, total


def one_epoch_man(model, train_dl, test_dl, device=None, eval_every=500):
    """Train ``model`` for one pass over ``train_dl`` with periodic test evaluation.

    Uses the notebook-level ``optimizer`` and ``criterion`` objects, so the
    optimizer must have been built over this model's parameters.

    Args:
        model: the network to train; moved to ``device`` before training.
        train_dl: iterable of ``(images, labels)`` training batches.
        test_dl: iterable of ``(images, labels)`` batches used for evaluation.
        device: where to run. Defaults to CUDA when available, otherwise CPU
            (the original code called ``.cuda()`` unconditionally).
        eval_every: evaluate on ``test_dl`` whenever the batch index is a
            multiple of this value (index 0 included). A falsy value disables
            evaluation.

    Returns:
        None. Prints the current batch loss and ``correct / total`` test counts
        at every evaluation point.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.train()
    for i, (images, labels) in enumerate(tqdm(train_dl)):
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        loss = criterion(model(images), labels)

        loss.backward()
        optimizer.step()
        if eval_every and i % eval_every == 0:
            print(f'loss: {loss.item()}')
            correct, total = _count_correct(model, test_dl, device)
            print(f'{correct} / {total}')

In [None]:
# Run 10 training passes over train_dl. Each call prints the current batch loss
# and the test-set correct/total counts every 500 batches (see one_epoch_man).
for i in range(10):
    one_epoch_man(model,train_dl,test_dl)

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 2.306312084197998
1135 / 10000
loss: 0.22519510984420776
9275 / 10000
loss: 0.09997265040874481
9500 / 10000
loss: 0.09409777820110321
9629 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.1963108479976654
9655 / 10000
loss: 0.05183495581150055
9675 / 10000
loss: 0.25725218653678894
9727 / 10000
loss: 0.059354037046432495
9739 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.11649468541145325
9755 / 10000
loss: 0.1441713571548462
9796 / 10000
loss: 0.036565572023391724
9766 / 10000
loss: 0.15716519951820374
9805 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.03235571086406708
9802 / 10000
loss: 0.0048889219760894775
9794 / 10000
loss: 0.006922826170921326
9807 / 10000
loss: 0.12965995073318481
9802 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.02122536301612854
9836 / 10000
loss: 0.012179389595985413
9853 / 10000
loss: 0.007911339402198792
9852 / 10000
loss: 0.131360724568367
9809 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.030271992087364197
9838 / 10000
loss: 0.1358337700366974
9787 / 10000
loss: 0.04358801245689392
9831 / 10000
loss: 0.024636223912239075
9816 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.00821542739868164
9831 / 10000
loss: 0.00041037797927856445
9826 / 10000
loss: 0.00013162195682525635
9802 / 10000
loss: 0.02342340350151062
9836 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.08107158541679382
9805 / 10000
loss: 0.0002721250057220459
9840 / 10000
loss: 0.01000213623046875
9849 / 10000
loss: 0.02457377314567566
9838 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.10464093089103699
9836 / 10000
loss: 0.05562080442905426
9850 / 10000
loss: 0.011946946382522583
9844 / 10000
loss: 0.03361567109823227
9852 / 10000


HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.0024401098489761353
9864 / 10000
