In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.optim.lr_scheduler import StepLR

import hyperparams as hp
from tqdm import tqdm

In [3]:
use_cuda = not hp.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else 'cpu')
torch.manual_seed(hp.seed)

<torch._C.Generator at 0x7ffb3933d670>

In [4]:
tr_dataset = datasets.MNIST('../data', train=True, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))

te_dataset = datasets.MNIST('../data', train=False, download=False,
                       transform=transforms.Compose([
                           transforms.ToTensor(),
                           transforms.Normalize((0.1307,), (0.3081,))
                       ]))

In [5]:
model = Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=hp.lr)

In [6]:
scheduler = StepLR(optimizer, step_size=1, gamma=hp.gamma)

In [7]:
for epoch in range(1, hp.epochs+1):

    model.train()
    dataloader = torch.utils.data.DataLoader(tr_dataset, 
                                              batch_size=hp.tr_batch_size, 
                                              shuffle=True,
                                              num_workers=8,
                                              pin_memory=True)

    tr_loss = 0
    correct = 0
    pbar = tqdm(dataloader)
    for batch_idx, (data, target) in enumerate(pbar, 1):
        pbar.set_description(f'Train epoch {epoch : 3d}')
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()

        tr_loss += loss
        optimizer.step()
        
        pred = output.argmax(dim=1, keepdim=True)
        correct += pred.eq(target.view_as(pred)).sum()
        
        pbar.set_postfix({'loss' : f'{tr_loss.item() / (batch_idx * hp.tr_batch_size) : .3f}',
                          'correct' : f'{correct.item() / (batch_idx * hp.tr_batch_size) : .3f}'})

    if epoch%hp.log_interval == 0:
        model.eval()
        dataloader = torch.utils.data.DataLoader(te_dataset, 
                                                  batch_size=hp.te_batch_size, 
                                                  shuffle=True,
                                                  num_workers=8,
                                                  pin_memory=True)

        te_loss = 0
        correct = 0
        for batch_idx, (data, target) in enumerate(dataloader, 1):
            data, target = data.to(device), target.to(device)
            output = model(data)
            loss = F.nll_loss(output, target)    
            te_loss += loss

            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()


        print(f'Epoch {epoch} : test loss = {te_loss/len(dataloader.dataset) : .3f}, test correct = {correct/len(dataloader.dataset) : .3f}')

Train epoch   1: 100%|██████████| 938/938 [00:04<00:00, 193.27it/s, loss=0.003, correct=0.940]
Train epoch   2: 100%|██████████| 938/938 [00:04<00:00, 197.35it/s, loss=0.001, correct=0.975]
Train epoch   3: 100%|██████████| 938/938 [00:04<00:00, 195.66it/s, loss=0.001, correct=0.981]
Train epoch   4: 100%|██████████| 938/938 [00:04<00:00, 192.47it/s, loss=0.001, correct=0.984]
Train epoch   5: 100%|██████████| 938/938 [00:05<00:00, 182.21it/s, loss=0.001, correct=0.986]
  0%|          | 0/938 [00:00<?, ?it/s]

Epoch 5 : test loss =  0.000, test correct =  0.990


Train epoch   6: 100%|██████████| 938/938 [00:05<00:00, 179.02it/s, loss=0.001, correct=0.987]
Train epoch   7: 100%|██████████| 938/938 [00:05<00:00, 166.38it/s, loss=0.001, correct=0.989]
Train epoch   8: 100%|██████████| 938/938 [00:05<00:00, 183.64it/s, loss=0.001, correct=0.990]
Train epoch   9: 100%|██████████| 938/938 [00:04<00:00, 188.95it/s, loss=0.000, correct=0.991]
Train epoch  10: 100%|██████████| 938/938 [00:05<00:00, 184.16it/s, loss=0.000, correct=0.991]


Epoch 10 : test loss =  0.000, test correct =  0.992
