## Train ResNet34 model

In [5]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [6]:
import torch.nn.functional as F
import torch

In [7]:
from model.model import Resnet34
from data_loader.data_loaders import Cifar100DataLoader

In [8]:
torch.cuda.is_available()

True

In [9]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [10]:
dl = Cifar100DataLoader('../data', 32)

Files already downloaded and verified


In [11]:
data, target = next(iter(dl))
data.shape, target

(torch.Size([32, 3, 224, 224]),
 tensor([35, 20,  1, 57,  8, 34, 42, 46, 42, 65, 10, 87, 21, 93, 90, 48, 68, 37,
         46, 70,  8, 47, 32, 60, 42, 20,  2, 35, 98, 87, 32, 44]))

In [12]:
model = Resnet34()
model.to(device)

Resnet34(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3))
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (global_avg_pooling): AvgPool2d(kernel_size=7, stride=7, padding=0)
  (fc): Linear(in_features=512, out_features=100, bias=True)
  (res_block1): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)
    (3): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (res_block2): Sequential(
    (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU(inplace=True)


In [13]:
model(data.to(device))

tensor([[-1.0534, -1.1376, -2.0474,  ..., -2.2924,  1.6199,  4.1499],
        [-2.6068, -1.1981, -1.8668,  ..., -3.7633,  1.3853,  5.3627],
        [-1.4275, -1.0648, -1.9524,  ..., -2.4178,  1.7122,  3.7906],
        ...,
        [-0.8671, -0.8447, -1.9430,  ..., -1.9440,  2.0806,  3.7591],
        [-2.3039, -1.0482, -1.6786,  ..., -3.0560,  1.1409,  4.9372],
        [-1.0696, -1.1045, -1.8125,  ..., -1.9601,  1.6558,  3.8309]],
       device='cuda:0', grad_fn=<AddmmBackward>)

### Define Loss function and optimizer

In [10]:
trainable_params = filter(lambda p: p.requires_grad, model.parameters())
criterion = F.cross_entropy
optimizer = torch.optim.SGD(trainable_params, lr=0.01, momentum=0.9)

In [11]:
for epoch in range(10):
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(dl):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        
        optimizer.step()
        running_loss+= loss.item()
        
        if batch_idx % 199 == 0: # print every 2000 mini_batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
        

[1,     1] loss: 0.023
[1,   200] loss: 4.432
[1,   399] loss: 4.165
[1,   598] loss: 4.023
[1,   797] loss: 3.880
[1,   996] loss: 3.758
[1,  1195] loss: 3.673
[1,  1394] loss: 3.579
[2,     1] loss: 0.019
[2,   200] loss: 3.420
[2,   399] loss: 3.312
[2,   598] loss: 3.215
[2,   797] loss: 3.141
[2,   996] loss: 3.036
[2,  1195] loss: 2.938
[2,  1394] loss: 2.859
[3,     1] loss: 0.013
[3,   200] loss: 2.666
[3,   399] loss: 2.616
[3,   598] loss: 2.528
[3,   797] loss: 2.465
[3,   996] loss: 2.403
[3,  1195] loss: 2.304
[3,  1394] loss: 2.261
[4,     1] loss: 0.010
[4,   200] loss: 2.091
[4,   399] loss: 2.076
[4,   598] loss: 2.018
[4,   797] loss: 1.977
[4,   996] loss: 1.956
[4,  1195] loss: 1.931
[4,  1394] loss: 1.899
[5,     1] loss: 0.009
[5,   200] loss: 1.693
[5,   399] loss: 1.734
[5,   598] loss: 1.673
[5,   797] loss: 1.676
[5,   996] loss: 1.689
[5,  1195] loss: 1.646
[5,  1394] loss: 1.613
[6,     1] loss: 0.009
[6,   200] loss: 1.423
[6,   399] loss: 1.441
[6,   598] 