In [1]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms

In [2]:
# MNIST training split stored under ../data; download=True lets torchvision
# fetch the files, and every image is converted to a tensor on access.
train_ds = torchvision.datasets.MNIST(root='../data', train=True, download=True,
                                      transform=transforms.ToTensor())
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ../data
    Split: Train

In [3]:
# MNIST held-out split read from the same ../data root. No download flag is
# passed here, so the files are expected to already be present on disk.
test_ds = torchvision.datasets.MNIST(root='../data', train=False,
                                     transform=transforms.ToTensor())
test_ds

Dataset MNIST
    Number of datapoints: 10000
    Root location: ../data
    Split: Test

In [4]:
# Batches of 32 samples; only the training loader shuffles its data.
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=32, shuffle=True)
test_dl = torch.utils.data.DataLoader(test_ds, batch_size=32, shuffle=False)

In [5]:
# Element count of the first training image tensor; used as the width of the
# model's input layer.
first_image = train_ds[0][0]
input_size = first_image.numel()

In [6]:
# A single fully connected layer: one flattened image in, one score per
# dataset class out.
num_classes = len(train_ds.classes)
model = nn.Linear(in_features=input_size, out_features=num_classes)
model

Linear(in_features=784, out_features=10, bias=True)

In [7]:
# Cross-entropy loss over the class scores, optimised with Adam (lr = 1e-3).
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

In [8]:
from tqdm import tqdm_notebook as tqdm

In [9]:
def _count_correct(model, data_loader, device):
    """Return ``(correct, total)`` prediction counts for ``model`` on ``data_loader``.

    The model is put in eval mode and the forward passes run under
    ``torch.no_grad()`` so no autograd graph is built while evaluating.
    """
    correct = 0
    total = 0
    model.eval()
    with torch.no_grad():
        for img, y in data_loader:
            img, y = img.to(device), y.to(device)
            # Flatten each image to a vector before the forward pass.
            predicted = model(img.view(img.size(0), -1)).argmax(1)
            total += predicted.size(0)
            # .item() keeps the running count a plain Python int.
            correct += (predicted == y).sum().item()
    return correct, total


def one_epoch_man(model,train_dl,test_dl):
    """Train ``model`` for one pass over ``train_dl`` with periodic evaluation.

    Uses the notebook-level ``optimizer`` and ``criterion`` objects. On every
    500th batch (including the first) the current batch loss is printed,
    followed by the number of correct predictions over all of ``test_dl``.

    Args:
        model: module to train; it is moved to the GPU when CUDA is available
            and left on the CPU otherwise.
        train_dl: iterable of ``(images, labels)`` training batches.
        test_dl: iterable of ``(images, labels)`` evaluation batches.

    Returns:
        None. Progress is reported through ``print``.
    """
    # Fall back to the CPU instead of crashing when no GPU is present.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    for i,(images,labels) in enumerate(tqdm(train_dl)):
        # Re-enter train mode every step because evaluation switches to eval mode.
        model.train()
        images,labels = images.to(device),labels.to(device)
        optimizer.zero_grad()

        # Flatten each image to a vector before the forward pass.
        preds = model(images.view(images.size(0),-1))

        loss = criterion(preds,labels)

        loss.backward()
        optimizer.step()
        if i % 500 == 0:
            print(f'loss: {loss.item()}')
            correct, total = _count_correct(model, test_dl, device)
            print(f'{correct} / {total}')

In [10]:
# Ten passes over the training data; each pass prints its own loss and
# test-accuracy checkpoints.
num_epochs = 10
for epoch in range(num_epochs):
    one_epoch_man(model, train_dl, test_dl)

HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 2.307391881942749
1448 / 10000
loss: 0.5055556893348694
8944 / 10000
loss: 0.5419247150421143
9059 / 10000
loss: 0.33290350437164307
9134 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.10277147591114044
9141 / 10000
loss: 0.5010400414466858
9165 / 10000
loss: 0.3365917205810547
9194 / 10000
loss: 0.6060167551040649
9209 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.3381955623626709
9234 / 10000
loss: 0.4573444724082947
9219 / 10000
loss: 0.5584512948989868
9222 / 10000
loss: 0.1452184021472931
9249 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.26949429512023926
9238 / 10000
loss: 0.40749120712280273
9237 / 10000
loss: 0.4994536340236664
9245 / 10000
loss: 0.3265574276447296
9255 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.09514773637056351
9241 / 10000
loss: 0.16090840101242065
9236 / 10000
loss: 0.322355180978775
9237 / 10000
loss: 0.39800000190734863
9257 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.37501662969589233
9252 / 10000
loss: 0.0905991643667221
9254 / 10000
loss: 0.2562391757965088
9276 / 10000
loss: 0.18049339950084686
9268 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.3158821165561676
9273 / 10000
loss: 0.13705819845199585
9274 / 10000
loss: 0.0626390129327774
9261 / 10000
loss: 0.040510986000299454
9262 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.11182151734828949
9249 / 10000
loss: 0.10912596434354782
9279 / 10000
loss: 0.3939589560031891
9285 / 10000
loss: 0.3031099736690521
9267 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.24903181195259094
9253 / 10000
loss: 0.31026411056518555
9251 / 10000
loss: 0.2181904911994934
9250 / 10000
loss: 0.4772520661354065
9260 / 10000



HBox(children=(IntProgress(value=0, max=1875), HTML(value='')))

loss: 0.24217399954795837
9270 / 10000
loss: 0.23280517756938934
9242 / 10000
loss: 0.1942419409751892
9280 / 10000
loss: 0.44210106134414673
9260 / 10000

