In [1]:
from importlib.util import find_spec
if find_spec("model") is None:
    import sys
    sys.path.append('..')

In [2]:
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import torch

In [3]:
from data_loader.data_loaders import Cifar100DataLoader

In [4]:
torch.cuda.is_available()

True

In [5]:
device = torch.device("cuda:0" if torch.cuda.is_available() else 'cpu')

In [6]:
resnet = models.resnet34(pretrained=False)

In [None]:
resnet

In [7]:
num_features = resnet.fc.in_features
num_features

512

In [8]:
resnet.fc = nn.Linear(num_features, 100)
resnet

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [9]:
resnet.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [10]:
dl = Cifar100DataLoader('../data', 32)

Files already downloaded and verified


### Define Loss function and optimizer

In [12]:
trainable_params = filter(lambda p: p.requires_grad, resnet.parameters())
criterion = F.cross_entropy
optimizer = torch.optim.SGD(trainable_params, lr=0.01, momentum=0.9)

In [13]:
for epoch in range(10):
    
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(dl):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = resnet(data)
        loss = criterion(output, target)
        loss.backward()
        
        optimizer.step()
        running_loss+= loss.item()
        
        if batch_idx % 199 == 0: # print every 2000 mini_batches
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, batch_idx + 1, running_loss / 200))
            running_loss = 0.0
        

[1,     1] loss: 0.023
[1,   200] loss: 4.520
[1,   399] loss: 4.124
[1,   598] loss: 3.961
[1,   797] loss: 3.852
[1,   996] loss: 3.749
[1,  1195] loss: 3.641
[1,  1394] loss: 3.551
[2,     1] loss: 0.017
[2,   200] loss: 3.297
[2,   399] loss: 3.230
[2,   598] loss: 3.151
[2,   797] loss: 3.044
[2,   996] loss: 2.931
[2,  1195] loss: 2.869
[2,  1394] loss: 2.787
[3,     1] loss: 0.014
[3,   200] loss: 2.527
[3,   399] loss: 2.497
[3,   598] loss: 2.396
[3,   797] loss: 2.349
[3,   996] loss: 2.294
[3,  1195] loss: 2.264
[3,  1394] loss: 2.196
[4,     1] loss: 0.009
[4,   200] loss: 2.010
[4,   399] loss: 1.972
[4,   598] loss: 1.943
[4,   797] loss: 1.954
[4,   996] loss: 1.917
[4,  1195] loss: 1.893
[4,  1394] loss: 1.825
[5,     1] loss: 0.006
[5,   200] loss: 1.655
[5,   399] loss: 1.674
[5,   598] loss: 1.637
[5,   797] loss: 1.642
[5,   996] loss: 1.637
[5,  1195] loss: 1.627
[5,  1394] loss: 1.634
[6,     1] loss: 0.007
[6,   200] loss: 1.423
[6,   399] loss: 1.411
[6,   598] 