In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
# MNIST training split stored under ../data (downloaded when requested via
# download=True); ToTensor() is applied to every image on access.
train_ds = torchvision.datasets.MNIST(
    root='../data', train=True, download=True, transform=transforms.ToTensor()
)
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [3]:
# MNIST test split from the same root directory. No `download` argument is
# passed here, so this relies on the files already being present under ../data.
test_ds = torchvision.datasets.MNIST(
    root='../data', transform=transforms.ToTensor(), train=False
)
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [4]:
# Batches of 32 samples: the training loader shuffles, the test loader does not.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [5]:
# Number of scalar values in one image tensor (first sample of the training
# set); used as the width of the model's input layer.
input_size = torch.numel(train_ds[0][0])

In [6]:
class Model(nn.Module):
    """Fully connected classifier: Linear -> ReLU -> Linear.

    Args:
        input_size: number of features in each input sample.
        hidden_size: number of units in the hidden layer.
        num_classes: number of output scores produced per sample.
    """

    def __init__(self, input_size, hidden_size, num_classes):
        super().__init__()
        # Sub-modules are registered in the order fc1, relu, fc2.
        self.fc1 = nn.Linear(in_features=input_size, out_features=hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_size, out_features=num_classes)

    def forward(self, x):
        """Return the raw (un-normalised) class scores for ``x``.

        ``x`` must have ``input_size`` features in its last dimension; no
        flattening is done here, so callers reshape images beforehand.
        """
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

In [7]:
# One output score per dataset class, with a 512-unit hidden layer.
model = Model(
    input_size=input_size,
    hidden_size=512,
    num_classes=len(train_ds.classes),
)
model

Model(
  (fc1): Linear(in_features=784, out_features=512, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=512, out_features=10, bias=True)
)

In [8]:
# Cross-entropy loss applied to the model's raw scores, optimised with Adam
# over every parameter of `model`.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.001)

In [9]:
from tqdm import tqdm_notebook as tqdm

In [10]:
def _count_correct(model, test_dl, device):
    """Return ``(correct, total)`` prediction counts for ``model`` over ``test_dl``.

    The model is switched to eval mode and left there; the caller is
    responsible for switching back to train mode.
    """
    correct = 0
    total = 0
    model.eval()
    # Inference only: disable autograd so no graph is built for test batches.
    with torch.no_grad():
        for img, y in test_dl:
            img, y = img.to(device), y.to(device)
            predicted = model(img.view(img.size(0), -1)).argmax(1)
            total += predicted.size(0)
            # .item() keeps the running count a plain Python int instead of
            # a device tensor.
            correct += (predicted == y).sum().item()
    return correct, total


def one_epoch_man(model, train_dl, test_dl, device=None, eval_every=500):
    """Train ``model`` for one pass over ``train_dl`` with periodic evaluation.

    Uses the notebook-level ``criterion`` and ``optimizer`` objects, so the
    optimizer must have been built from this model's parameters.

    Args:
        model: network taking flattened images, shape ``(batch, features)``.
        train_dl: iterable of ``(images, labels)`` training batches.
        test_dl: iterable of ``(images, labels)`` evaluation batches.
        device: device to run on. Defaults to CUDA when available (the
            original hard-coded behaviour), otherwise CPU.
        eval_every: evaluate on ``test_dl`` every this many training batches,
            starting with the first one.

    Prints the current batch loss and ``correct / total`` test counts at each
    evaluation step; returns nothing.
    """
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    for i, (images, labels) in enumerate(tqdm(train_dl)):
        # Evaluation below leaves the model in eval mode, so re-enable
        # training mode at the start of every step.
        model.train()
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()

        # Flatten each image to a vector before the fully connected layers.
        preds = model(images.view(images.size(0), -1))

        loss = criterion(preds, labels)

        loss.backward()
        optimizer.step()
        if i % eval_every == 0:
            print(f'loss: {loss.item()}')
            correct, total = _count_correct(model, test_dl, device)
            print(f'{correct} / {total}')

In [11]:
# Ten consecutive epochs on the same model; each call continues from the
# weights left by the previous one.
for i in range(10):
    one_epoch_man(model=model, train_dl=train_dl, test_dl=test_dl)

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 2.303450107574463
2032 / 10000
loss: 0.21171051263809204
9408 / 10000
loss: 0.11840342730283737
9540 / 10000
loss: 0.08933466672897339
9595 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.15834201872348785
9649 / 10000
loss: 0.056095972657203674
9669 / 10000
loss: 0.06765744835138321
9719 / 10000
loss: 0.009176060557365417
9742 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.009727329015731812
9749 / 10000
loss: 0.13334141671657562
9756 / 10000
loss: 0.07056547701358795
9759 / 10000
loss: 0.10157201439142227
9792 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.06497161090373993
9725 / 10000
loss: 0.013200405985116959
9779 / 10000
loss: 0.04134632647037506
9752 / 10000
loss: 0.196173757314682
9768 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.010723695158958435
9794 / 10000
loss: 0.04777716100215912
9796 / 10000
loss: 0.021515458822250366
9810 / 10000
loss: 0.010138049721717834
9822 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.00419173389673233
9803 / 10000
loss: 0.0008922368288040161
9783 / 10000
loss: 0.00978529080748558
9801 / 10000
loss: 0.014179781079292297
9806 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.00013455748558044434
9783 / 10000
loss: 0.04466291517019272
9824 / 10000
loss: 0.004383556544780731
9797 / 10000
loss: 0.002289876341819763
9813 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.0007619261741638184
9788 / 10000
loss: 0.0008009970188140869
9813 / 10000
loss: 0.00011439621448516846
9771 / 10000
loss: 0.016691170632839203
9826 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.04656466096639633
9798 / 10000
loss: 0.0012153014540672302
9822 / 10000
loss: 0.004689980298280716
9799 / 10000
loss: 0.04137860983610153
9814 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.0009190738201141357
9816 / 10000
loss: 0.0008556023240089417
9813 / 10000
loss: 0.002465665340423584
9823 / 10000
loss: 0.005442731082439423
9807 / 10000

