In [1]:
import torch
from torch import nn
from torchvision import models, datasets, transforms
import time
from tqdm.auto import tqdm
import os

In [2]:
def set_requires_grad(model, value=False):
    for param in model.parameters():
        param.requires_grad = value

input_size = 224
batch_size = 4
normalize = transforms.Compose([
    transforms.Resize(input_size),
    transforms.CenterCrop(input_size),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

num_classes = 10

criterion = nn.CrossEntropyLoss()

def train_model(model, dataloaders, criterion, optimizer, num_epochs=3):
    since = time.time()

    acc_history = {'train': [], 'val': []}
    loss_history = {'train': [], 'val': []}

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            n_batches = len(dataloaders[phase])
            for inputs, labels in tqdm(dataloaders[phase], total=n_batches):
                inputs = inputs.to(device)
                labels = labels.to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    # Get model outputs and calculate loss
                    # Special case for inception because in training it has an auxiliary output. In train
                    #   mode we calculate the loss by summing the final output and the auxiliary output
                    #   but in testing we only consider the final output.
                    outputs = model(inputs)
                    loss = criterion(outputs, labels)

                    _, preds = torch.max(outputs, 1)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(dataloaders[phase].dataset)
            epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)

            print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))
            loss_history[phase].append(epoch_loss)
            acc_history[phase].append(epoch_acc)

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return model, acc_history

def test_model(model, testloader, criterion):
    n_batches = len(testloader)
    running_loss = 0.0
    running_corrects = 0

    for inputs, labels in tqdm(testloader, total=n_batches):
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        _, preds = torch.max(outputs, 1)

        running_loss += loss.item() * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    loss = running_loss / len(testloader.dataset)
    acc = running_corrects.double() / len(testloader.dataset)
    print('Val Loss: {:.4f} Val Acc: {:.4f}'.format(loss, acc))
    return acc, loss

cuda:0


In [3]:
# Load model

resnet18_model = models.resnet18(pretrained=True)
num_ftrs = resnet18_model.fc.in_features
resnet18_model.fc = nn.Linear(num_ftrs, num_classes)

In [4]:
# Load dataset

# download dataset and unpack:
dataset_path = "imagenette2"
if not os.path.exists(dataset_path):
    !wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
    !tar xvzf imagenette2.tgz

trainset_imagenette = datasets.ImageFolder(root='./%s/train/' % dataset_path, transform=normalize)
trainloader_imagenette = torch.utils.data.DataLoader(trainset_imagenette, batch_size=batch_size, shuffle=True, num_workers=2)

testset_imagenette = datasets.ImageFolder(root='./%s/val/' % dataset_path, transform=normalize)
testloader_imagenette = torch.utils.data.DataLoader(testset_imagenette, batch_size=batch_size, shuffle=False, num_workers=2)

loaders_imagenette = {'train': trainloader_imagenette, 'val': testloader_imagenette}

In [5]:
train_optimizer = torch.optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9)
resnet18_model = resnet18_model.to(device)

In [6]:
# train model
train_model(resnet18_model, loaders_imagenette, criterion, train_optimizer, num_epochs=50)

Epoch 0/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.5597 Acc: 0.8279


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1696 Acc: 0.9455

Epoch 1/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.3450 Acc: 0.8972


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2330 Acc: 0.9254

Epoch 2/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.2404 Acc: 0.9284


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1711 Acc: 0.9478

Epoch 3/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.1906 Acc: 0.9449


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1753 Acc: 0.9508

Epoch 4/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.1532 Acc: 0.9538


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1911 Acc: 0.9447

Epoch 5/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.1281 Acc: 0.9620


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2136 Acc: 0.9442

Epoch 6/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.1024 Acc: 0.9688


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1655 Acc: 0.9546

Epoch 7/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0921 Acc: 0.9759


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1773 Acc: 0.9526

Epoch 8/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0725 Acc: 0.9804


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1815 Acc: 0.9480

Epoch 9/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0854 Acc: 0.9773


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2186 Acc: 0.9447

Epoch 10/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0605 Acc: 0.9823


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2165 Acc: 0.9450

Epoch 11/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0540 Acc: 0.9844


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2169 Acc: 0.9432

Epoch 12/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0634 Acc: 0.9834


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1802 Acc: 0.9513

Epoch 13/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0387 Acc: 0.9887


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1758 Acc: 0.9549

Epoch 14/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0534 Acc: 0.9851


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1966 Acc: 0.9501

Epoch 15/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0436 Acc: 0.9884


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1677 Acc: 0.9557

Epoch 16/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0299 Acc: 0.9929


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1772 Acc: 0.9524

Epoch 17/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0461 Acc: 0.9883


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1963 Acc: 0.9506

Epoch 18/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0281 Acc: 0.9925


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2004 Acc: 0.9508

Epoch 19/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0302 Acc: 0.9914


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2014 Acc: 0.9470

Epoch 20/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0281 Acc: 0.9927


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2064 Acc: 0.9511

Epoch 21/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0304 Acc: 0.9913


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1805 Acc: 0.9549

Epoch 22/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0343 Acc: 0.9906


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1955 Acc: 0.9498

Epoch 23/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0356 Acc: 0.9909


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2012 Acc: 0.9485

Epoch 24/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0339 Acc: 0.9913


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1946 Acc: 0.9508

Epoch 25/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


val Loss: 0.2508 Acc: 0.9457

Epoch 36/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0183 Acc: 0.9952


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2123 Acc: 0.9501

Epoch 37/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0140 Acc: 0.9965


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2310 Acc: 0.9501

Epoch 38/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0110 Acc: 0.9967


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1799 Acc: 0.9575

Epoch 39/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0202 Acc: 0.9955


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2077 Acc: 0.9493

Epoch 40/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0227 Acc: 0.9955


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2297 Acc: 0.9490

Epoch 41/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0120 Acc: 0.9965


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1833 Acc: 0.9506

Epoch 42/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0140 Acc: 0.9970


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1931 Acc: 0.9524

Epoch 43/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0251 Acc: 0.9955


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2303 Acc: 0.9411

Epoch 44/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0120 Acc: 0.9976


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1797 Acc: 0.9503

Epoch 45/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0172 Acc: 0.9963


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1697 Acc: 0.9546

Epoch 46/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0080 Acc: 0.9979


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2077 Acc: 0.9490

Epoch 47/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0110 Acc: 0.9978


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.2012 Acc: 0.9524

Epoch 48/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0160 Acc: 0.9973


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1837 Acc: 0.9516

Epoch 49/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 0.0109 Acc: 0.9980


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 0.1696 Acc: 0.9541

Training complete in 34m 3s


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU

In [7]:
# test model
test_model(resnet18_model, testloader_imagenette, criterion)

HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


Val Loss: 0.1696 Val Acc: 0.9541


(tensor(0.9541, device='cuda:0', dtype=torch.float64), 0.16964547558195273)

In [8]:
# save fc layer trained on imagenette
imagenette_fc = resnet18_model.fc

In [9]:
trainset_cifar10 = datasets.CIFAR10(root='./data', train=True, download=True, transform=normalize)
trainloader_cifar10 = torch.utils.data.DataLoader(trainset_cifar10, batch_size=batch_size, shuffle=True, num_workers=2)

testset_cifar10 = datasets.CIFAR10(root='./data', train=False, download=True, transform=normalize)
testloader_cifar10 = torch.utils.data.DataLoader(testset_cifar10, batch_size=batch_size, shuffle=False, num_workers=2)

loaders_cifar10 = {'train': trainloader_cifar10, 'val': testloader_cifar10}

Files already downloaded and verified
Files already downloaded and verified


In [None]:
num_ftrs = resnet18_model.fc.in_features
set_requires_grad(resnet18_model, False)
resnet18_model.fc = nn.Linear(num_ftrs, num_classes)
resnet18_model = resnet18_model.to(device)

In [None]:
pretrain_optimizer = torch.optim.SGD(resnet18_model.fc.parameters(), lr=0.001, momentum=0.9)
train_model(resnet18_model, loaders_cifar10, criterion, pretrain_optimizer, num_epochs=50)
# Train
set_requires_grad(resnet18_model, True)
train_optimizer = torch.optim.SGD(resnet18_model.parameters(), lr=0.001, momentum=0.9)
train_model(resnet18_model, loaders_cifar10, criterion, train_optimizer, num_epochs=50)

Epoch 0/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.5051 Acc: 0.4850


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0902 Acc: 0.6257

Epoch 1/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.3570 Acc: 0.5449


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0723 Acc: 0.6480

Epoch 2/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.3250 Acc: 0.5559


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9878 Acc: 0.6662

Epoch 3/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.3168 Acc: 0.5615


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9834 Acc: 0.6639

Epoch 4/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2979 Acc: 0.5684


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9920 Acc: 0.6707

Epoch 5/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2932 Acc: 0.5692


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0453 Acc: 0.6638

Epoch 6/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2938 Acc: 0.5715


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0955 Acc: 0.6437

Epoch 7/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2743 Acc: 0.5774


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9679 Acc: 0.6788

Epoch 8/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2780 Acc: 0.5751


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1595 Acc: 0.6372

Epoch 9/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2772 Acc: 0.5775


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1272 Acc: 0.6401

Epoch 10/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2713 Acc: 0.5782


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9967 Acc: 0.6645

Epoch 11/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2590 Acc: 0.5845


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1687 Acc: 0.6324

Epoch 12/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2583 Acc: 0.5824


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1443 Acc: 0.6431

Epoch 13/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2598 Acc: 0.5831


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9364 Acc: 0.6878

Epoch 14/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2551 Acc: 0.5862


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9676 Acc: 0.6776

Epoch 15/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2483 Acc: 0.5872


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9989 Acc: 0.6703

Epoch 16/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2477 Acc: 0.5856


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0036 Acc: 0.6715

Epoch 17/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2520 Acc: 0.5879


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1222 Acc: 0.6421

Epoch 18/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2527 Acc: 0.5860


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0016 Acc: 0.6715

Epoch 19/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2440 Acc: 0.5890


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1373 Acc: 0.6389

Epoch 20/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2538 Acc: 0.5861


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0498 Acc: 0.6569

Epoch 21/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2511 Acc: 0.5853


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0586 Acc: 0.6574

Epoch 22/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2440 Acc: 0.5904


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0777 Acc: 0.6612

Epoch 23/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2398 Acc: 0.5899


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0220 Acc: 0.6664

Epoch 24/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2468 Acc: 0.5902


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0684 Acc: 0.6519

Epoch 25/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2452 Acc: 0.5920


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9977 Acc: 0.6747

Epoch 26/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2478 Acc: 0.5889


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9372 Acc: 0.6859

Epoch 27/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2492 Acc: 0.5914


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9782 Acc: 0.6858

Epoch 30/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2403 Acc: 0.5925


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0454 Acc: 0.6561

Epoch 31/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2369 Acc: 0.5906


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1477 Acc: 0.6333

Epoch 32/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2352 Acc: 0.5922


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9728 Acc: 0.6830

Epoch 33/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2497 Acc: 0.5882


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0062 Acc: 0.6735

Epoch 34/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2395 Acc: 0.5896


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9822 Acc: 0.6755

Epoch 35/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2332 Acc: 0.5916


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0117 Acc: 0.6668

Epoch 36/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2387 Acc: 0.5931


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.0932 Acc: 0.6526

Epoch 37/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2410 Acc: 0.5902


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9672 Acc: 0.6735

Epoch 38/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2374 Acc: 0.5948


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9897 Acc: 0.6728

Epoch 39/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2465 Acc: 0.5893


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1875 Acc: 0.6301

Epoch 40/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2370 Acc: 0.5916


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.1676 Acc: 0.6362

Epoch 41/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2464 Acc: 0.5888


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9761 Acc: 0.6732

Epoch 42/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2439 Acc: 0.5916


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9985 Acc: 0.6724

Epoch 47/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2392 Acc: 0.5941


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9845 Acc: 0.6817

Epoch 48/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2410 Acc: 0.5934


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 1.3176 Acc: 0.6109

Epoch 49/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 1.2506 Acc: 0.5902


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.9506 Acc: 0.6848

Training complete in 72m 38s
Epoch 0/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.8190 Acc: 0.7398


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.3997 Acc: 0.8675

Epoch 1/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.3905 Acc: 0.8701


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.3898 Acc: 0.8782

Epoch 2/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.2620 Acc: 0.9126


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2864 Acc: 0.9080

Epoch 3/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.1868 Acc: 0.9375


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2664 Acc: 0.9128

Epoch 4/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.1372 Acc: 0.9536


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2497 Acc: 0.9244

Epoch 5/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.1021 Acc: 0.9656


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2918 Acc: 0.9167

Epoch 6/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0800 Acc: 0.9734


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2669 Acc: 0.9249

Epoch 7/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0651 Acc: 0.9784


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2761 Acc: 0.9221

Epoch 8/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0597 Acc: 0.9800


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2752 Acc: 0.9259

Epoch 9/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0417 Acc: 0.9868


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2738 Acc: 0.9265

Epoch 10/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0336 Acc: 0.9890


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2635 Acc: 0.9277

Epoch 11/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0296 Acc: 0.9904


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2463 Acc: 0.9337

Epoch 12/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0244 Acc: 0.9926


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2621 Acc: 0.9323

Epoch 13/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0194 Acc: 0.9933


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2655 Acc: 0.9321

Epoch 14/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0166 Acc: 0.9946


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.3100 Acc: 0.9253

Epoch 15/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0143 Acc: 0.9955


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2581 Acc: 0.9338

Epoch 16/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0130 Acc: 0.9961


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2520 Acc: 0.9369

Epoch 17/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0130 Acc: 0.9964


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2471 Acc: 0.9347

Epoch 18/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0102 Acc: 0.9971


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2583 Acc: 0.9364

Epoch 19/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0121 Acc: 0.9963


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2765 Acc: 0.9324

Epoch 20/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0103 Acc: 0.9969


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2612 Acc: 0.9366

Epoch 21/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0071 Acc: 0.9980


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2699 Acc: 0.9356

Epoch 22/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0060 Acc: 0.9983


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2508 Acc: 0.9379

Epoch 23/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0057 Acc: 0.9984


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2663 Acc: 0.9385

Epoch 24/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0039 Acc: 0.9990


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2565 Acc: 0.9403

Epoch 25/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0056 Acc: 0.9982


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2511 Acc: 0.9385

Epoch 26/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0062 Acc: 0.9983


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2538 Acc: 0.9386

Epoch 27/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0040 Acc: 0.9990


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2566 Acc: 0.9378

Epoch 28/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0030 Acc: 0.9992


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2605 Acc: 0.9399

Epoch 30/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0032 Acc: 0.9993


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2737 Acc: 0.9375

Epoch 31/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0024 Acc: 0.9994


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2867 Acc: 0.9407

Epoch 32/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0032 Acc: 0.9994


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2578 Acc: 0.9432

Epoch 33/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0032 Acc: 0.9990


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2468 Acc: 0.9407

Epoch 34/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0020 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2540 Acc: 0.9429

Epoch 35/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0023 Acc: 0.9994


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2635 Acc: 0.9386

Epoch 36/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0028 Acc: 0.9992


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2900 Acc: 0.9366

Epoch 37/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0029 Acc: 0.9991


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2532 Acc: 0.9416

Epoch 38/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0017 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2622 Acc: 0.9417

Epoch 39/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0019 Acc: 0.9996


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2663 Acc: 0.9429

Epoch 41/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0016 Acc: 0.9996


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2725 Acc: 0.9409

Epoch 42/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0021 Acc: 0.9995


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2735 Acc: 0.9410

Epoch 43/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0016 Acc: 0.9996


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.3077 Acc: 0.9349

Epoch 44/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0028 Acc: 0.9991


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2791 Acc: 0.9377

Epoch 45/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0025 Acc: 0.9994


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2773 Acc: 0.9401

Epoch 46/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0017 Acc: 0.9996


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2822 Acc: 0.9399

Epoch 47/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0026 Acc: 0.9993


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2850 Acc: 0.9349

Epoch 48/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0019 Acc: 0.9996


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2844 Acc: 0.9386

Epoch 49/49
----------


HBox(children=(FloatProgress(value=0.0, max=12500.0), HTML(value='')))


train Loss: 0.0013 Acc: 0.9997


HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


val Loss: 0.2774 Acc: 0.9387

Training complete in 163m 11s


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU

In [None]:
# test model on cifar10
test_model(resnet18_model, testloader_cifar10, criterion)

HBox(children=(FloatProgress(value=0.0, max=2500.0), HTML(value='')))


Val Loss: 0.2774 Val Acc: 0.9387


(tensor(0.9387, device='cuda:0', dtype=torch.float64), 0.2774390351295471)

In [None]:
# restore fc layer
resnet18_model.fc = imagenette_fc
resnet18_model = resnet18_model.to(device)

In [None]:
# test model on imagenette2
test_model(resnet18_model, testloader_imagenette, criterion)

HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


Val Loss: 1.5433 Val Acc: 0.4736


(tensor(0.4736, device='cuda:0', dtype=torch.float64), 1.5433107855669252)

In [None]:
# train only last layer on imagenette2
set_requires_grad(resnet18_model, False)
set_requires_grad(resnet18_model.fc, True)
train_model(resnet18_model, loaders_imagenette, criterion, train_optimizer, num_epochs=50)

Epoch 0/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4081 Acc: 0.5312


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2972 Acc: 0.5704

Epoch 1/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4077 Acc: 0.5330


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2889 Acc: 0.5801

Epoch 2/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4178 Acc: 0.5274


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3033 Acc: 0.5766

Epoch 3/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4091 Acc: 0.5394


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2877 Acc: 0.5829

Epoch 4/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4122 Acc: 0.5324


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3395 Acc: 0.5575

Epoch 5/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.3984 Acc: 0.5407


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3332 Acc: 0.5631

Epoch 6/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4027 Acc: 0.5370


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2639 Acc: 0.6023

Epoch 7/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.3994 Acc: 0.5423


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3125 Acc: 0.5796

Epoch 8/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4119 Acc: 0.5309


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3126 Acc: 0.5819

Epoch 9/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4034 Acc: 0.5379


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2924 Acc: 0.5809

Epoch 10/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4060 Acc: 0.5313


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3571 Acc: 0.5526

Epoch 11/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4107 Acc: 0.5346


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3380 Acc: 0.5592

Epoch 12/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4140 Acc: 0.5337


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3305 Acc: 0.5661

Epoch 13/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4159 Acc: 0.5243


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2705 Acc: 0.5926

Epoch 14/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4060 Acc: 0.5354


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3185 Acc: 0.5725

Epoch 15/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4060 Acc: 0.5316


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3242 Acc: 0.5750

Epoch 16/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4118 Acc: 0.5323


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


train Loss: 1.4037 Acc: 0.5393


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2572 Acc: 0.6036

Epoch 26/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4042 Acc: 0.5349


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3886 Acc: 0.5363

Epoch 27/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4085 Acc: 0.5368


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3249 Acc: 0.5758

Epoch 28/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4128 Acc: 0.5299


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2846 Acc: 0.5824

Epoch 29/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4177 Acc: 0.5348


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2908 Acc: 0.5862

Epoch 30/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4121 Acc: 0.5361


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2544 Acc: 0.6005

Epoch 31/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4176 Acc: 0.5337


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3024 Acc: 0.5865

Epoch 32/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4143 Acc: 0.5299


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3502 Acc: 0.5615

Epoch 33/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4153 Acc: 0.5330


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3169 Acc: 0.5743

Epoch 34/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4083 Acc: 0.5382


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3119 Acc: 0.5687

Epoch 35/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4031 Acc: 0.5371


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2931 Acc: 0.5783

Epoch 36/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4056 Acc: 0.5349


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2970 Acc: 0.5837

Epoch 37/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4040 Acc: 0.5376


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3194 Acc: 0.5707

Epoch 38/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4094 Acc: 0.5387


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2643 Acc: 0.5977

Epoch 39/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4066 Acc: 0.5332


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3148 Acc: 0.5763

Epoch 40/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4142 Acc: 0.5293


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2862 Acc: 0.5804

Epoch 41/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4101 Acc: 0.5340


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3232 Acc: 0.5682

Epoch 42/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.3992 Acc: 0.5407


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2549 Acc: 0.6082

Epoch 45/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4031 Acc: 0.5364


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2836 Acc: 0.5916

Epoch 46/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4106 Acc: 0.5370


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2933 Acc: 0.5804

Epoch 47/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4099 Acc: 0.5337


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.2958 Acc: 0.5893

Epoch 48/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4126 Acc: 0.5356


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3024 Acc: 0.5809

Epoch 49/49
----------


HBox(children=(FloatProgress(value=0.0, max=2368.0), HTML(value='')))


train Loss: 1.4098 Acc: 0.5331


HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


val Loss: 1.3561 Acc: 0.5496

Training complete in 42m 15s


(ResNet(
   (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
   (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
   (relu): ReLU(inplace=True)
   (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
   (layer1): Sequential(
     (0): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU(inplace=True)
       (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
     )
     (1): BasicBlock(
       (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
       (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
       (relu): ReLU

In [None]:
# test model on imagenette2
test_model(resnet18_model, testloader_imagenette, criterion)

HBox(children=(FloatProgress(value=0.0, max=982.0), HTML(value='')))


Val Loss: 1.3561 Val Acc: 0.5496


(tensor(0.5496, device='cuda:0', dtype=torch.float64), 1.3561149500585665)

Итоги:

1. Трейн на imagenette2: Val Loss: 0.1696 Val Acc: 0.9541
2. После дообучения нового последнего слоя на CIFAR10: Val Loss: 0.2774 Val Acc: 0.9387
3. После возврата слоя тест на imagenette2: Val Loss: 1.5433 Val Acc: 0.4736
4. После дообучения возвращенного слоя на imagenette2: Val Loss: 1.3561 Val Acc: 0.5496

Исходное качество не достигнуто (accuracy было 0.9541, стало 0.5496).

Вывод: после замены слоя, обучения на другом датасете и возврата слоя сеть перестает обучаться нормально на исходном датасете. Тем самым подтверждается эффект катастрофического забывания.