## Train ResNet34 model

In [1]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [2]:
import torch.nn.functional as F
import torch

In [3]:
from model.model import Resnet34
from data_loader.data_loaders import Cifar100DataLoader

In [4]:
torch.cuda.is_available()

True

In [5]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [6]:
dl = Cifar100DataLoader('../data', 32)

Files already downloaded and verified


In [7]:
data, target = next(iter(dl))
data.shape, target

(torch.Size([32, 3, 224, 224]),
 tensor([21, 88, 83, 57,  7, 34, 32, 64, 88, 81, 35, 47, 67, 19, 92, 34,  3, 74,
         86, 15,  7, 92, 46, 69, 57, 33, 94, 56, 72, 96, 52, 86]))

In [8]:
model = Resnet34()
model.to(device)

Resnet34(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (res_block1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (res_block2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (res_block3): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (res_block4): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), paddi

In [9]:
model(data.to(device))

tensor([[ 0.0165,  0.3172,  0.0550,  ..., -0.0256, -0.2807,  0.1660],
        [ 0.0516,  0.6298,  0.1419,  ..., -0.1033, -0.5431,  0.3836],
        [-0.0096,  0.5368,  0.1708,  ..., -0.1154, -0.5084,  0.4780],
        ...,
        [-0.0083,  0.3715,  0.1078,  ..., -0.0188, -0.3420,  0.2538],
        [-0.0057,  0.4617,  0.0986,  ..., -0.0284, -0.4133,  0.3126],
        [ 0.0614,  0.5981,  0.0983,  ..., -0.0798, -0.4713,  0.3259]],
       device='cuda:0', grad_fn=<AddmmBackward>)

### Define Loss function and optimizer

In [10]:
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
criterion = F.cross_entropy
optimizer = torch.optim.SGD(trainable_params, lr=0.01, momentum=0.9)

In [11]:
for epoch in range(10):
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(dl):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        optimizer.step()
        running_loss+= loss.item()
        
        if batch_idx % 199 == 0: # print every 2000 mini_batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
        

[1,     1] loss: 0.023
[1,   200] loss: 4.432
[1,   399] loss: 4.165
[1,   598] loss: 4.023
[1,   797] loss: 3.880
[1,   996] loss: 3.758
[1,  1195] loss: 3.673
[1,  1394] loss: 3.579
[2,     1] loss: 0.019
[2,   200] loss: 3.420
[2,   399] loss: 3.312
[2,   598] loss: 3.215
[2,   797] loss: 3.141
[2,   996] loss: 3.036
[2,  1195] loss: 2.938
[2,  1394] loss: 2.859
[3,     1] loss: 0.013
[3,   200] loss: 2.666
[3,   399] loss: 2.616
[3,   598] loss: 2.528
[3,   797] loss: 2.465
[3,   996] loss: 2.403
[3,  1195] loss: 2.304
[3,  1394] loss: 2.261
[4,     1] loss: 0.010
[4,   200] loss: 2.091
[4,   399] loss: 2.076
[4,   598] loss: 2.018
[4,   797] loss: 1.977
[4,   996] loss: 1.956
[4,  1195] loss: 1.931
[4,  1394] loss: 1.899
[5,     1] loss: 0.009
[5,   200] loss: 1.693
[5,   399] loss: 1.734
[5,   598] loss: 1.673
[5,   797] loss: 1.676
[5,   996] loss: 1.689
[5,  1195] loss: 1.646
[5,  1394] loss: 1.613
[6,     1] loss: 0.009
[6,   200] loss: 1.423
[6,   399] loss: 1.441
[6,   598] 