In [2]:
# Extract the pet images and copy the annotation lists from Google Drive into the
# working directory. `-q` keeps thousands of "inflating: ..." lines out of the notebook
# output; `-n` skips files that already exist, so re-running this cell neither re-extracts
# everything nor stops at unzip's interactive overwrite prompt.
!unzip -q -n drive/MyDrive/Pets-data/images.zip
!cp -r drive/MyDrive/Pets-data/annotations .
!ls

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: images/Egyptian_Mau_91.jpg  
  inflating: images/Egyptian_Mau_92.jpg  
  inflating: images/Egyptian_Mau_93.jpg  
  inflating: images/Egyptian_Mau_94.jpg  
  inflating: images/Egyptian_Mau_95.jpg  
  inflating: images/Egyptian_Mau_96.jpg  
  inflating: images/Egyptian_Mau_97.jpg  
  inflating: images/Egyptian_Mau_98.jpg  
  inflating: images/Egyptian_Mau_99.jpg  
  inflating: images/english_cocker_spaniel_1.jpg  
  inflating: images/english_cocker_spaniel_10.jpg  
  inflating: images/english_cocker_spaniel_100.jpg  
  inflating: images/english_cocker_spaniel_101.jpg  
  inflating: images/english_cocker_spaniel_102.jpg  
  inflating: images/english_cocker_spaniel_103.jpg  
  inflating: images/english_cocker_spaniel_104.jpg  
  inflating: images/english_cocker_spaniel_105.jpg  
  inflating: images/english_cocker_spaniel_106.jpg  
  inflating: images/english_cocker_spaniel_107.jpg  
  inflating: images/english_co

In [99]:
import os
import yaml
from typing import Tuple, List

import torch
import torchvision
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from pydantic import BaseModel
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
import numpy as np

In [56]:
# Define constants

# YAML file that is parsed and validated into `Config`.
CONFIG_FILE_PATH = 'config.yaml'

# Column positions in the whitespace-separated annotation lists
# (annotations/trainval.txt, annotations/test.txt): image name, class id, species.
IMAGE_COL_IDX, CLASS_ID_COL_IDX, SPECIES_COL_IDX = 0, 1, 2

# Supported label granularities: 2 reads the species column, 37 reads the class-id column.
POSSIBLE_NUM_CLASSES = {2, 37}

In [57]:
class AdamOptimizerConfig(BaseModel):
    """Hyperparameters forwarded to torch.optim.Adam by the training loops."""

    # Initial learning rate; the loops divide it by 100 when earlier layers are unfrozen.
    lr: float
    # Weight decay passed straight through to Adam.
    weight_decay: float

In [60]:
class Config(BaseModel):
    """Experiment settings, validated from the dict parsed out of CONFIG_FILE_PATH."""

    # Device string passed to `.to(...)` for the model and for every sample tensor.
    device: str
    # 2 -> labels come from the species column, otherwise from the class-id column;
    # must be a member of POSSIBLE_NUM_CLASSES.
    num_classes: int
    # Batch size for every DataLoader.
    batch_size: int
    # Upper bound on training epochs (early stopping can end training sooner).
    max_num_epochs: int
    # Training stops once the current epoch exceeds the best epoch by more than this many epochs.
    patience: int
    # Hyperparameters for torch.optim.Adam.
    adam_optimizer_config: AdamOptimizerConfig
    # How many of the last BatchNorm2d layers get trainable affine parameters (0 = none).
    num_batch_norm_layers_to_train_params: int
    # How many of the last BatchNorm2d layers keep updating running statistics while training (0 = none).
    num_batch_norm_layers_to_update_running_stats: int
    # Epoch index at the end of which the last `n_hidden_layers_to_train` hidden layers are unfrozen.
    train_earlier_layers_delay: int
    # Number of hidden (non-head) parameterised layers to unfreeze at `train_earlier_layers_delay`.
    n_hidden_layers_to_train: int
    # Used only for its truthiness: selects the pseudo-labelling pipeline.
    use_pseudo_labelling: int
    # Fraction of the training split that keeps its labels when pseudo-labelling (train_size of the split).
    labelled_data_ratio: float
    # Pseudo-label loss weight schedule used by `alpha(t)`: 0 before epoch T_1,
    # linear ramp towards ALPHA_F for T_1 <= t < T_2, ALPHA_F from T_2 onwards.
    # Epochs before T_1 also train on the labelled subset only.
    T_1: int
    T_2: int
    ALPHA_F: float

In [136]:
# Number of nn.BatchNorm2d modules in torchvision's resnet34
# (1 stem + 2 per BasicBlock * 16 blocks + 3 downsample branches).
RESNET34_NUM_BATCH_NORM_LAYERS = 36

with open(CONFIG_FILE_PATH, encoding='utf-8') as f:
    # safe_load only constructs plain YAML types (dicts, lists, scalars), which is all Config needs.
    config_dict = yaml.safe_load(f)

config = Config.model_validate(config_dict)

assert config.num_classes in POSSIBLE_NUM_CLASSES, (
    f'num_classes must be one of {sorted(POSSIBLE_NUM_CLASSES)}, got {config.num_classes}'
)
assert 0 <= config.num_batch_norm_layers_to_train_params <= RESNET34_NUM_BATCH_NORM_LAYERS, (
    f'num_batch_norm_layers_to_train_params must be in [0, {RESNET34_NUM_BATCH_NORM_LAYERS}]'
)
assert 0 <= config.num_batch_norm_layers_to_update_running_stats <= RESNET34_NUM_BATCH_NORM_LAYERS, (
    f'num_batch_norm_layers_to_update_running_stats must be in [0, {RESNET34_NUM_BATCH_NORM_LAYERS}]'
)
if config.use_pseudo_labelling:
    # The labelled/unlabelled split passes this as a float train_size, which
    # train_test_split only accepts strictly between 0 and 1; fail here instead of mid-pipeline.
    assert 0 < config.labelled_data_ratio < 1, (
        f'labelled_data_ratio must be in (0, 1) when pseudo-labelling, got {config.labelled_data_ratio}'
    )

In [62]:
class ImageDataset(Dataset):
    """Labelled image dataset yielding (image tensor, label tensor) pairs.

    Sample `idx` is the file '<images_dir>/<filenames[idx]>.jpg', converted to RGB
    and passed through the ResNet34 ImageNet preset transform. When
    `use_augmentations` is True, a random horizontal flip is applied afterwards.
    The label is `labels[idx]` wrapped in a tensor.

    NOTE(review): both tensors are moved to `config.device` here. That works with the
    DataLoader default num_workers=0, but CUDA tensors created inside worker processes
    are unsupported. Moving batches in the training loop would be safer if workers are added.
    """

    def __init__(self, filenames: List[str], labels: List[int], use_augmentations: bool, images_dir: str = 'images') -> None:
        assert len(filenames) == len(labels), 'labels must have the same size as filenames'
        self.filenames = filenames
        self.labels = labels
        self.images_dir = images_dir

        base_transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([base_transformation, transforms.RandomHorizontalFlip()])
            if use_augmentations
            else base_transformation
        )

    def __len__(self) -> int:
        return len(self.filenames)

    def _load_image(self, filename: str) -> torch.Tensor:
        """Open '<images_dir>/<filename>.jpg' as RGB and apply `self.transformation`."""
        # `with` guarantees the file handle is released even if decoding or the transform fails.
        with Image.open(os.path.join(self.images_dir, f'{filename}.jpg')) as image:
            return self.transformation(image.convert('RGB'))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        transformed_img = self._load_image(self.filenames[idx])
        label = torch.tensor(self.labels[idx])
        return transformed_img.to(config.device), label.to(config.device)

In [63]:
class ImageDatasetPseudoLabels(Dataset):
    """Training dataset that mixes labelled and unlabelled images for pseudo-labelling.

    Indices [0, len(labelled_filenames)) return labelled samples. The following
    len(unlabelled_filenames) indices return unlabelled samples whose label is the
    int64 sentinel UNLABELLED (-1). Images are read from '<images_dir>/<name>.jpg',
    converted to RGB and passed through the ResNet34 ImageNet preset transform. When
    `use_augmentations` is True, a random horizontal flip is applied afterwards.
    Both returned tensors are moved to `config.device`.
    """

    # Label returned for every image taken from `unlabelled_filenames`.
    UNLABELLED = -1

    def __init__(self, labelled_filenames: List[str], labels: List[int], unlabelled_filenames: List[str], use_augmentations: bool, images_dir: str = 'images') -> None:
        assert len(labelled_filenames) == len(labels), 'labels must have the same size as labelled_filenames'
        self.labelled_filenames = labelled_filenames
        self.labels = labels
        self.unlabelled_filenames = unlabelled_filenames
        self.images_dir = images_dir

        base_transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([base_transformation, transforms.RandomHorizontalFlip()])
            if use_augmentations
            else base_transformation
        )

    def __len__(self) -> int:
        return len(self.labelled_filenames) + len(self.unlabelled_filenames)

    def _load_image(self, filename: str) -> torch.Tensor:
        """Open '<images_dir>/<filename>.jpg' as RGB and apply `self.transformation`."""
        with Image.open(os.path.join(self.images_dir, f'{filename}.jpg')) as image:
            return self.transformation(image.convert('RGB'))

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        num_labelled = len(self.labelled_filenames)
        if idx >= num_labelled:  # unlabelled image
            transformed_img = self._load_image(self.unlabelled_filenames[idx - num_labelled])
            # Same dtype (int64) as real labels, so the default collate never stacks mixed
            # float/int label tensors and a batch's label dtype doesn't depend on its composition.
            label = torch.tensor(self.UNLABELLED)
        else:
            transformed_img = self._load_image(self.labelled_filenames[idx])
            label = torch.tensor(self.labels[idx])
        return transformed_img.to(config.device), label.to(config.device)

In [64]:
def get_image_names_and_labels(annotations_file_path: str, num_classes: int) -> Tuple[List[str], List[int]]:
    """Parse a whitespace-separated annotation list into image names and labels.

    Args:
        annotations_file_path: text file with one sample per line; column
            IMAGE_COL_IDX is the image name (without '.jpg').
        num_classes: 2 reads the label from SPECIES_COL_IDX, any other value
            from CLASS_ID_COL_IDX.

    Returns:
        (filenames, labels) in file order. Each label is the file's integer id minus 1,
        i.e. a 0-based class index as expected by nn.CrossEntropyLoss.

    Blank lines and lines starting with '#' are skipped instead of crashing the parser.
    """
    label_col_idx = SPECIES_COL_IDX if num_classes == 2 else CLASS_ID_COL_IDX
    filenames: List[str] = []
    labels: List[int] = []

    with open(annotations_file_path, encoding='utf-8') as f:
        for line in f:
            line_split = line.split()
            if not line_split or line_split[0].startswith('#'):
                continue
            filenames.append(line_split[IMAGE_COL_IDX])
            labels.append(int(line_split[label_col_idx]) - 1)

    return filenames, labels

In [65]:
def get_batch_norm_layers(model: nn.Module) -> List[nn.Module]:
    """Collect every nn.BatchNorm2d inside `model`, in `model.modules()` traversal order."""
    def is_batch_norm_2d(candidate: nn.Module) -> bool:
        return isinstance(candidate, nn.BatchNorm2d)

    return list(filter(is_batch_norm_2d, model.modules()))

In [66]:
def get_pretrained_model_with_trainable_last_layer(num_batch_norm_layers_to_train_params: int) -> nn.Module:
    """Build an ImageNet-pretrained resnet34 where only the new head (and optionally some BN layers) train.

    All pretrained parameters are frozen. Then the affine parameters of the last
    `num_batch_norm_layers_to_train_params` BatchNorm2d layers are unfrozen (0 = none).
    Finally `fc` is replaced by a fresh nn.Linear with `config.num_classes` outputs,
    which is trainable because new parameters have requires_grad=True.

    Returns:
        The model moved to `config.device`.
    """
    pretrained_weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1
    model = torchvision.models.resnet34(weights=pretrained_weights)
    model.requires_grad_(False)

    if num_batch_norm_layers_to_train_params:
        for bn_layer in get_batch_norm_layers(model)[-num_batch_norm_layers_to_train_params:]:
            bn_layer.requires_grad_(True)

    head_in_features = model.fc.in_features
    model.fc = nn.Linear(in_features=head_in_features, out_features=config.num_classes)

    return model.to(config.device)

In [67]:
def get_pretrained_model_with_trainable_n_layers(num_batch_norm_layers_to_train_params: int, number_of_hidden_layers_to_train: int=0) -> nn.Module:
    """Pretrained resnet34 with a fresh trainable head plus the last N hidden layers unfrozen.

    Equivalent to `get_pretrained_model_with_trainable_last_layer` followed by
    `make_hidden_layers_trainable`. It delegates to those helpers instead of keeping a
    third copy of the freezing and layer-selection logic (the copy that lived here
    also repeated its `nn.Sequential` check twice).

    Args:
        num_batch_norm_layers_to_train_params: how many of the last BatchNorm2d
            layers get trainable affine parameters (0 = none).
        number_of_hidden_layers_to_train: how many of the last parameterised hidden
            layers (excluding the head) to unfreeze (0 = none).

    Returns:
        The model on `config.device`.
    """
    model = get_pretrained_model_with_trainable_last_layer(num_batch_norm_layers_to_train_params)
    return make_hidden_layers_trainable(model, number_of_hidden_layers_to_train)

In [68]:
def make_hidden_layers_trainable(model: nn.Module, number_of_hidden_layers_to_train: int) -> nn.Module:
    """Unfreeze the last `number_of_hidden_layers_to_train` parameterised hidden layers.

    A "layer" is a leaf module that owns parameters (e.g. Conv2d, BatchNorm2d,
    Linear). Containers such as ResNet, BasicBlock or Sequential are skipped. For
    resnet34 this is exactly the list the previous ResNet/BasicBlock/Sequential
    isinstance filter produced, but it also works for other architectures
    (e.g. Bottleneck-based ResNets). The last such layer, the classifier head,
    is excluded. Counting starts from the layer closest to the output.

    Args:
        model: module to modify in place.
        number_of_hidden_layers_to_train: how many layers to unfreeze; values <= 0
            are a no-op (a negative slice bound used to unfreeze almost everything).

    Returns:
        The same `model`, for chaining.
    """
    if number_of_hidden_layers_to_train <= 0:
        return model

    parameterised_leaves = [
        module
        for module in model.modules()
        if next(module.children(), None) is None and next(module.parameters(), None) is not None
    ]
    # Drop the final layer (the head), then walk backwards from the output side.
    hidden_layers_last_first = list(reversed(parameterised_leaves[:-1]))

    for layer in hidden_layers_last_first[:number_of_hidden_layers_to_train]:
        layer.requires_grad_(True)

    return model

In [69]:
def get_model_accuracy(model: nn.Module, data_loader: DataLoader) -> float:
    """Top-1 accuracy of `model` over every batch of `data_loader`, as a Python float in [0, 1].

    Switches `model` to eval mode (and leaves it there) and runs without gradients.

    Raises:
        ValueError: if `data_loader` yields no samples.
    """
    correct_predictions_cnt = 0
    total_predictions_cnt = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc='Computing accuracy'):
            outputs = model(inputs)
            # Accumulated as a tensor on the outputs' device; converted once at the end.
            correct_predictions_cnt += (torch.argmax(outputs, dim=1) == labels).sum()
            total_predictions_cnt += len(outputs)
    if total_predictions_cnt == 0:
        raise ValueError('data_loader yielded no samples; cannot compute accuracy')
    # int() turns the 0-d count tensor into a Python int, so the signature's `float` holds.
    return int(correct_predictions_cnt) / total_predictions_cnt

In [70]:
def train_single_epoch(
        model: nn.Module,
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_batch_norm_layers_to_update_running_stats: int,
        ) -> float:
    """Train `model` for one pass over `train_data_loader` and return the mean per-sample loss.

    The model is put in train mode. Every BatchNorm2d layer except the last
    `num_batch_norm_layers_to_update_running_stats` is then switched back to eval mode,
    so its running mean/variance stay frozen (0 freezes all of them).

    Returns:
        Sum over batches of (batch loss * batch size), divided by the number of samples seen.
    """
    model.train()
    all_bn_layers = get_batch_norm_layers(model)
    n_updating = num_batch_norm_layers_to_update_running_stats
    frozen_bn_layers = all_bn_layers[:-n_updating] if n_updating else all_bn_layers
    for bn_layer in frozen_bn_layers:
        bn_layer.eval()

    weighted_loss_total = 0.0
    num_seen = 0
    for batch_inputs, batch_labels in tqdm(train_data_loader, desc='Training model'):
        optimizer.zero_grad()
        predictions = model(batch_inputs)
        batch_loss = criterion(predictions, batch_labels)
        batch_loss.backward()
        optimizer.step()

        batch_size = len(predictions)
        weighted_loss_total += batch_loss.item() * batch_size
        num_seen += batch_size

    return weighted_loss_total / num_seen

In [118]:
def alpha(t):
    """Weight of the pseudo-label loss at epoch `t`.

    Returns 0 before config.T_1. For T_1 <= t < T_2 it ramps linearly from 0
    towards config.ALPHA_F. From T_2 onwards it is config.ALPHA_F.
    """
    t_start, t_end, alpha_final = config.T_1, config.T_2, config.ALPHA_F
    if t < t_start:
        return 0
    if t < t_end:
        ramp_progress = (t - t_start) / (t_end - t_start)
        return ramp_progress * alpha_final
    return alpha_final


def train_single_epoch_pseudo(
        model: nn.Module,
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_batch_norm_layers_to_update_running_stats: int,
        epoch: int,
        ) -> float:
    """Run one pseudo-labelling epoch and return the mean per-sample training loss.

    Samples labelled -1 are treated as unlabelled. Their target is the model's own
    argmax prediction (a pseudo-label), and their loss is weighted by `alpha(epoch)`.
    All other samples use the ordinary supervised `criterion`.

    Args:
        model: network being trained.
        train_data_loader: yields (inputs, labels) batches; -1 marks unlabelled samples.
        criterion: classification loss, e.g. nn.CrossEntropyLoss.
        optimizer: optimizer stepping the model's trainable parameters.
        num_batch_norm_layers_to_update_running_stats: only the last this-many
            BatchNorm2d layers keep updating running stats (0 = none).
        epoch: current epoch index, used for the pseudo-label weight schedule.

    Returns:
        Sum over batches of (batch loss * batch size), divided by the number of samples seen.
    """
    model.train()
    batch_norm_layers = get_batch_norm_layers(model)
    batch_norm_layers_to_not_update_running_stats = (
        batch_norm_layers[:-num_batch_norm_layers_to_update_running_stats]
        if num_batch_norm_layers_to_update_running_stats
        else batch_norm_layers
    )
    for batch_norm_layer in batch_norm_layers_to_not_update_running_stats:
        batch_norm_layer.eval()

    # The weight only depends on the epoch, so compute it once instead of per batch.
    unlabelled_loss_weight = alpha(epoch)
    train_loss_sum = 0.0
    train_samples_cnt = 0

    for inputs, labels in tqdm(train_data_loader, desc='Training model'):
        optimizer.zero_grad()
        outputs = model(inputs)

        # Boolean masks replace per-element Python loops over the label tensor.
        unlabelled_mask = labels == -1
        labelled_mask = ~unlabelled_mask

        # Start from a zero scalar on the outputs' device/dtype so `loss` is always a tensor.
        loss = outputs.new_zeros(())
        # Empty subsets are skipped explicitly (mean cross-entropy over zero samples is NaN).
        # The old NaN check also silently discarded genuine divergence NaNs, and could leave
        # `loss` as the int 0, on which `.backward()` fails.
        if labelled_mask.any():
            loss = loss + criterion(outputs[labelled_mask], labels[labelled_mask].long())
        if unlabelled_mask.any():
            unlabelled_outputs = outputs[unlabelled_mask]
            # Pseudo-labels are fixed targets, not part of the autograd graph.
            pseudo_labels = unlabelled_outputs.detach().argmax(dim=1)
            loss = loss + unlabelled_loss_weight * criterion(unlabelled_outputs, pseudo_labels)

        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * len(outputs)
        train_samples_cnt += len(outputs)

    return train_loss_sum / train_samples_cnt

In [72]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, path: str = 'checkpoint.pt') -> None:
    """Write the model and optimizer state dicts to `path` (default 'checkpoint.pt').

    The file is a dict with keys 'model_state_dict' and 'optimizer_state_dict',
    matching what `load_checkpoint` reads. Returns nothing (the previous
    `Tuple[...]` annotation did not match the behavior).
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)

In [73]:
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, path: str = 'checkpoint.pt') -> None:
    """Restore, in place, the state dicts that `save_checkpoint` wrote to `path`.

    Tensors are first loaded onto the CPU, so a checkpoint written on a GPU can be
    restored on a CPU-only runtime. `model.load_state_dict` copies values into the
    existing parameters on whatever device they live, and `optimizer.load_state_dict`
    casts the optimizer state to the matching parameters' devices.

    NOTE: torch.load unpickles the file; only load checkpoints this notebook wrote.
    """
    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [104]:
# Fixed seed for every random split below, so the train/val/labelled partitions
# (and therefore the reported validation/test accuracies) are reproducible across runs.
DATA_SPLIT_RANDOM_STATE = 42

filenames_trainval, labels_trainval = get_image_names_and_labels('annotations/trainval.txt', num_classes=config.num_classes)
filenames_test, labels_test = get_image_names_and_labels('annotations/test.txt', num_classes=config.num_classes)
# Stratified 80/20 train/validation split of trainval.txt; test.txt stays the held-out test set.
filenames_train, filenames_val, labels_train, labels_val = train_test_split(
    filenames_trainval,
    labels_trainval,
    test_size=0.2,
    stratify=labels_trainval,
    random_state=DATA_SPLIT_RANDOM_STATE,
)

if config.use_pseudo_labelling:
    # Keep labels for only `labelled_data_ratio` of the training split; the rest of the
    # images become unlabelled (their labels are discarded into `_`).
    filenames_train, unlabelled_filenames_train, labels_train, _ = train_test_split(
        filenames_train,
        labels_train,
        train_size=config.labelled_data_ratio,
        stratify=labels_train,
        random_state=DATA_SPLIT_RANDOM_STATE,
    )
    dataset_train = ImageDatasetPseudoLabels(filenames_train, labels_train, unlabelled_filenames_train, use_augmentations=True)
    # Labelled-only loader used for the supervised warm-up epochs (epoch < T_1).
    labelled_dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True)
    labelled_train_data_loader = DataLoader(labelled_dataset_train, batch_size=config.batch_size, shuffle=True)
else:
    dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True)

dataset_val = ImageDataset(filenames_val, labels_val, use_augmentations=False)
dataset_test = ImageDataset(filenames_test, labels_test, use_augmentations=False)

train_data_loader = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True)
val_data_loader = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False)
test_data_loader = DataLoader(dataset_test, batch_size=config.batch_size, shuffle=False)

In [129]:
def train_model_without_pseudo_labelling():
    """Fine-tune a pretrained resnet34 on the fully labelled training split.

    Each epoch trains on `train_data_loader`, then evaluates on `val_data_loader`.
    A checkpoint is saved whenever validation accuracy strictly improves. Training
    stops once more than `config.patience` epochs have passed since the best epoch.
    At the end of epoch `config.train_earlier_layers_delay`, the last
    `config.n_hidden_layers_to_train` hidden layers are unfrozen, and a fresh Adam
    optimizer with lr / 100 replaces the old one.

    Returns:
        (model, optimizer) after the last epoch that ran. This is not necessarily the
        best-on-validation state; that one is in the saved checkpoint.
    """
    adam_cfg = config.adam_optimizer_config
    model = get_pretrained_model_with_trainable_last_layer(config.num_batch_norm_layers_to_train_params)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=adam_cfg.lr, weight_decay=adam_cfg.weight_decay)

    best_val_accuracy = float('-inf')
    best_epoch = -1

    for epoch in range(config.max_num_epochs):
        print(f'Epoch #{epoch}:')
        train_loss = train_single_epoch(
            model,
            train_data_loader,
            criterion,
            optimizer,
            config.num_batch_norm_layers_to_update_running_stats,
        )
        print(f'Train loss: {train_loss}')

        val_accuracy = get_model_accuracy(model, val_data_loader)
        if not val_accuracy > best_val_accuracy:
            print(f'Validation accuracy: {100 * val_accuracy:.2f}% (worse than {100 * best_val_accuracy:.2f}% of epoch {best_epoch})')
            if epoch - best_epoch > config.patience:
                print('Early stopping')
                break
        else:
            print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
            best_val_accuracy, best_epoch = val_accuracy, epoch
            save_checkpoint(model, optimizer)
            print('Checkpoint saved')

        unfreeze_now = epoch == config.train_earlier_layers_delay and config.n_hidden_layers_to_train > 0
        if unfreeze_now:
            print(f'Making last {config.n_hidden_layers_to_train} hidden layers trainable')
            make_hidden_layers_trainable(model, config.n_hidden_layers_to_train)
            # Fresh optimizer (Adam state reset) with a 100x smaller learning rate for all parameters.
            optimizer = torch.optim.Adam(model.parameters(), lr=adam_cfg.lr / 100, weight_decay=adam_cfg.weight_decay)
        print()

    return model, optimizer

In [130]:
# Train using pseudo labelling
def train_model_with_pseudo_labelling():
    """Fine-tune a pretrained resnet34 with pseudo-labelling on the partially labelled split.

    Epochs before `config.T_1` are a supervised warm-up on `labelled_train_data_loader`.
    Later epochs run `train_single_epoch_pseudo` on `train_data_loader`, which mixes
    labelled and unlabelled images. After every epoch the model is evaluated on
    `val_data_loader`, and a checkpoint is saved whenever validation accuracy strictly
    improves. Training stops once more than `config.patience` epochs have passed since
    the best epoch. At the end of epoch `config.train_earlier_layers_delay`, the last
    `config.n_hidden_layers_to_train` hidden layers are unfrozen, and a fresh Adam
    optimizer with lr / 100 replaces the old one.

    Returns:
        (model, optimizer) after the last epoch that ran. This is not necessarily the
        best-on-validation state; that one is in the saved checkpoint.
    """
    adam_cfg = config.adam_optimizer_config
    model = get_pretrained_model_with_trainable_last_layer(config.num_batch_norm_layers_to_train_params)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=adam_cfg.lr, weight_decay=adam_cfg.weight_decay)

    best_val_accuracy = float('-inf')
    best_epoch = -1

    for epoch in range(config.max_num_epochs):
        print(f'Epoch #{epoch}:')
        n_bn_stats = config.num_batch_norm_layers_to_update_running_stats
        warming_up = epoch < config.T_1
        if warming_up:
            # Supervised warm-up on the labelled subset only.
            train_loss = train_single_epoch(model, labelled_train_data_loader, criterion, optimizer, n_bn_stats)
        else:
            train_loss = train_single_epoch_pseudo(model, train_data_loader, criterion, optimizer, n_bn_stats, epoch)
        print(f'Train loss: {train_loss}')

        val_accuracy = get_model_accuracy(model, val_data_loader)
        if not val_accuracy > best_val_accuracy:
            print(f'Validation accuracy: {100 * val_accuracy:.2f}% (worse than {100 * best_val_accuracy:.2f}% of epoch {best_epoch})')
            if epoch - best_epoch > config.patience:
                print('Early stopping')
                break
        else:
            print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
            best_val_accuracy, best_epoch = val_accuracy, epoch
            save_checkpoint(model, optimizer)
            print('Checkpoint saved')

        unfreeze_now = epoch == config.train_earlier_layers_delay and config.n_hidden_layers_to_train > 0
        if unfreeze_now:
            print(f'Making last {config.n_hidden_layers_to_train} hidden layers trainable')
            make_hidden_layers_trainable(model, config.n_hidden_layers_to_train)
            # Fresh optimizer (Adam state reset) with a 100x smaller learning rate for all parameters.
            optimizer = torch.optim.Adam(model.parameters(), lr=adam_cfg.lr / 100, weight_decay=adam_cfg.weight_decay)
        print()

    return model, optimizer

In [137]:
# Train the model: pick the training loop according to the pseudo-labelling flag in the config.
run_training = (
    train_model_with_pseudo_labelling
    if config.use_pseudo_labelling
    else train_model_without_pseudo_labelling
)
model, optimizer = run_training()

Epoch #0:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.63it/s]


Train loss: 3.62202474860107


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.20it/s]


Validation accuracy: 27.45% (new best)
Checkpoint saved
20

Epoch #1:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.78it/s]


Train loss: 2.6361574834706833


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.02it/s]


Validation accuracy: 48.10% (new best)
Checkpoint saved
20

Epoch #2:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.72it/s]


Train loss: 1.9394163347425915


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.15it/s]


Validation accuracy: 65.62% (new best)
Checkpoint saved
20

Epoch #3:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.66it/s]


Train loss: 1.405381446792966


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.03it/s]


Validation accuracy: 77.17% (new best)
Checkpoint saved
20

Epoch #4:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.40it/s]


Train loss: 1.0447046643211728


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.09it/s]


Validation accuracy: 80.43% (new best)
Checkpoint saved
20

Epoch #5:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.79it/s]


Train loss: 0.8161704653785342


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.10it/s]


Validation accuracy: 82.47% (new best)
Checkpoint saved
Making last 20 hidden layers trainable
20

Epoch #6:


Training model: 100%|██████████| 10/10 [00:02<00:00,  4.70it/s]


Train loss: 0.48174960434842273


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.13it/s]


Validation accuracy: 82.47% (worse than 82.47% of epoch 5)
20

Epoch #7:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.19it/s]


Train loss: 0.1930431545186205


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.18it/s]


Validation accuracy: 82.74% (new best)
Checkpoint saved
20

Epoch #8:


Training model: 100%|██████████| 10/10 [00:02<00:00,  4.83it/s]


Train loss: 0.1035721719670458


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.03it/s]


Validation accuracy: 83.15% (new best)
Checkpoint saved
20

Epoch #9:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.31it/s]


Train loss: 0.059961599912367713


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.18it/s]


Validation accuracy: 82.61% (worse than 83.15% of epoch 8)
20

Epoch #10:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.25it/s]


Train loss: 0.03926236973125107


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.08it/s]


Validation accuracy: 83.70% (new best)
Checkpoint saved
20

Epoch #11:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.12it/s]


Train loss: 0.025408938622140154


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.11it/s]


Validation accuracy: 84.92% (new best)
Checkpoint saved
20

Epoch #12:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.15it/s]


Train loss: 0.017211835096482518


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.09it/s]


Validation accuracy: 85.05% (new best)
Checkpoint saved
20

Epoch #13:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.38it/s]


Train loss: 0.013835453713426784


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.16it/s]


Validation accuracy: 84.10% (worse than 85.05% of epoch 12)
20

Epoch #14:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.26it/s]


Train loss: 0.010459635425739142


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.12it/s]


Validation accuracy: 83.97% (worse than 85.05% of epoch 12)
20

Epoch #15:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.33it/s]


Train loss: 0.009738714893532245


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.22it/s]


Validation accuracy: 84.10% (worse than 85.05% of epoch 12)
20

Epoch #16:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.24it/s]


Train loss: 0.006852827326324927


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.09it/s]


Validation accuracy: 84.38% (worse than 85.05% of epoch 12)
20

Epoch #17:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.26it/s]


Train loss: 0.005093538173835496


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.21it/s]


Validation accuracy: 84.24% (worse than 85.05% of epoch 12)
20

Epoch #18:


Training model: 100%|██████████| 10/10 [00:01<00:00,  5.34it/s]


Train loss: 0.004741057397189493


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.09it/s]

Validation accuracy: 84.38% (worse than 85.05% of epoch 12)
Early stopping





In [138]:
# Restore the checkpoint saved at the best validation accuracy, then score it on the held-out test split.
load_checkpoint(model, optimizer)
print('Checkpoint loaded')

test_accuracy = get_model_accuracy(model, test_data_loader)
test_accuracy_pct = 100 * test_accuracy
print(f'Test accuracy: {test_accuracy_pct:.2f}%')

Checkpoint loaded


Computing accuracy: 100%|██████████| 115/115 [00:22<00:00,  5.05it/s]

Test accuracy: 86.29%



