In [1]:
# !unzip drive/MyDrive/Pets-data/images.zip
# !cp -r drive/MyDrive/Pets-data/annotations .
# !ls

In [2]:
import os
import yaml
from typing import Tuple, List

import torch
import torchvision
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from pydantic import BaseModel
import matplotlib.pyplot as plt
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [3]:
# Define constants

# Path of the YAML file that is parsed and validated into a `Config` instance.
CONFIG_FILE_PATH = 'config.yaml'
# Zero-based positions of the whitespace-separated columns of an annotation line.
# Image name; '.jpg' is appended to it when the image is opened.
IMAGE_COL_IDX = 0
# Label column used for any `num_classes` other than 2; the loader subtracts 1 from it.
CLASS_ID_COL_IDX = 1
# Label column used when `num_classes == 2`; the loader subtracts 1 from it.
SPECIES_COL_IDX = 2
# The only values of `num_classes` accepted when the config is checked.
POSSIBLE_NUM_CLASSES = {2, 37}

In [4]:
class AdamOptimizerConfig(BaseModel):
    """Hyperparameters forwarded to ``torch.optim.Adam`` by the training cell."""

    # Learning rate for the initial optimizer; the optimizer that is re-created
    # after hidden layers are unfrozen uses this value divided by 100.
    lr: float
    # Passed unchanged as Adam's `weight_decay` argument.
    weight_decay: float

In [5]:
class Config(BaseModel):
    """Run configuration, validated from the YAML file at ``CONFIG_FILE_PATH``."""

    # Device string handed to `.to(...)` for the model and for every dataset sample.
    device: str
    # Checked against POSSIBLE_NUM_CLASSES; sets the output size of the replaced
    # `fc` layer and selects which annotation column is used as the label.
    num_classes: int
    # Batch size shared by the train, validation and test DataLoaders.
    batch_size: int
    # Upper bound on the number of training epochs.
    max_num_epochs: int
    # Early stopping triggers once `epoch > best_epoch + patience`.
    patience: int
    # Learning rate and weight decay for Adam.
    adam_optimizer_config: AdamOptimizerConfig
    # How many of the last BatchNorm2d layers get trainable parameters (asserted 0..36).
    num_batch_norm_layers_to_train_params: int
    # How many of the last BatchNorm2d layers stay in train mode during training;
    # all earlier ones are switched to eval mode (asserted 0..36).
    num_batch_norm_layers_to_update_running_stats: int
    # Epoch index at the end of which hidden layers are unfrozen.
    train_earlier_layers_delay: int
    # Number of last hidden layers to unfreeze; unfreezing is skipped unless this is > 0.
    n_hidden_layers_to_train: int

In [6]:
# Parse the YAML config. `safe_load` only builds plain Python containers and
# scalars, which is everything a config file needs, and is the loader PyYAML
# recommends instead of `yaml.load(..., Loader=FullLoader)`.
with open(CONFIG_FILE_PATH, encoding='utf-8') as f:
    config_dict = yaml.safe_load(f)

# Let pydantic check field presence and types.
config = Config.model_validate(config_dict)

# Upper bound used by the checks below: number of batch norm layers in resnet34.
NUM_BATCH_NORM_LAYERS_IN_RESNET34 = 36

# Value-range checks that the pydantic field types alone do not express.
assert config.num_classes in POSSIBLE_NUM_CLASSES, (
    f'num_classes must be one of {sorted(POSSIBLE_NUM_CLASSES)}, got {config.num_classes}'
)
assert 0 <= config.num_batch_norm_layers_to_train_params <= NUM_BATCH_NORM_LAYERS_IN_RESNET34, (
    'num_batch_norm_layers_to_train_params must be in '
    f'[0, {NUM_BATCH_NORM_LAYERS_IN_RESNET34}], got {config.num_batch_norm_layers_to_train_params}'
)
assert 0 <= config.num_batch_norm_layers_to_update_running_stats <= NUM_BATCH_NORM_LAYERS_IN_RESNET34, (
    'num_batch_norm_layers_to_update_running_stats must be in '
    f'[0, {NUM_BATCH_NORM_LAYERS_IN_RESNET34}], got {config.num_batch_norm_layers_to_update_running_stats}'
)

In [7]:
class ImageDataset(Dataset):
    """Dataset of JPEG images addressed by name, paired with integer labels.

    Every sample is preprocessed with the ImageNet preset transforms of
    ``ResNet34_Weights.IMAGENET1K_V1``; with ``use_augmentations`` a random
    horizontal flip is applied after that preprocessing.

    Args:
        filenames: Image names without the ``.jpg`` extension.
        labels: Integer class label for each entry of ``filenames``.
        use_augmentations: Whether to append the random horizontal flip.
        images_dir: Directory that contains the ``<name>.jpg`` files.

    Raises:
        ValueError: If ``filenames`` and ``labels`` differ in length.
    """

    def __init__(
            self,
            filenames: List[str],
            labels: List[int],
            use_augmentations: bool,
            images_dir: str = 'images',
            ) -> None:
        # Validate first: a mismatch would otherwise only surface as an
        # IndexError (or silently ignored labels) in the middle of an epoch.
        if len(filenames) != len(labels):
            raise ValueError(
                f'filenames and labels must have the same length, got {len(filenames)} and {len(labels)}'
            )

        self.filenames = filenames
        self.labels = labels
        self.images_dir = images_dir

        preprocessing = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()
        self.transformation = (
            transforms.Compose([preprocessing, transforms.RandomHorizontalFlip()])
            if use_augmentations
            else preprocessing
        )

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)`` on ``config.device``."""
        image_path = os.path.join(self.images_dir, f'{self.filenames[idx]}.jpg')
        # The context manager closes the file handle deterministically;
        # `convert` returns an independent, fully loaded image.
        with Image.open(image_path) as raw_image:
            image = raw_image.convert('RGB')

        transformed_img = self.transformation(image)
        label = torch.tensor(self.labels[idx])

        # NOTE(review): the training and evaluation loops feed batches straight
        # into the model, so samples are moved to the device here. Moving them in
        # those loops instead would be needed before using DataLoader workers
        # with a CUDA device.
        return transformed_img.to(config.device), label.to(config.device)

In [8]:
def get_image_names_and_labels(annotations_file_path: str, num_classes: int) -> Tuple[List[str], List[int]]:
    """Read image names and zero-based labels from a whitespace-separated annotation file.

    The label is taken from the ``SPECIES_COL_IDX`` column when ``num_classes``
    is 2 and from the ``CLASS_ID_COL_IDX`` column otherwise; 1 is subtracted
    from the stored value. Blank lines and lines whose first token starts with
    ``#`` are skipped.

    Args:
        annotations_file_path: Path of the annotation file.
        num_classes: Number of target classes; selects the label column.

    Returns:
        ``(filenames, labels)``, two lists of equal length in file order.

    Raises:
        IndexError: If a data line has fewer columns than the selected label column.
        ValueError: If the label column of a data line is not an integer.
    """
    label_col_idx = SPECIES_COL_IDX if num_classes == 2 else CLASS_ID_COL_IDX

    filenames: List[str] = []
    labels: List[int] = []

    # Stream the file line by line instead of materialising it with readlines().
    with open(annotations_file_path, encoding='utf-8') as f:
        for line in f:
            line_split = line.split()
            # Tolerate a trailing newline / empty lines and '#' header lines.
            if not line_split or line_split[0].startswith('#'):
                continue
            filenames.append(line_split[IMAGE_COL_IDX])
            labels.append(int(line_split[label_col_idx]) - 1)

    return filenames, labels

In [9]:
def get_batch_norm_layers(model: nn.Module) -> List[nn.Module]:
    """Collect every ``nn.BatchNorm2d`` submodule of ``model``, in ``model.modules()`` order."""
    batch_norm_layers: List[nn.Module] = []
    for submodule in model.modules():
        if isinstance(submodule, nn.BatchNorm2d):
            batch_norm_layers.append(submodule)
    return batch_norm_layers

In [10]:
def get_pretrained_model_with_trainable_last_layer(num_batch_norm_layers_to_train_params: int) -> nn.Module:
    """Build an ImageNet-pretrained ResNet34 whose new classification head is trainable.

    All pretrained parameters are frozen, the parameters of the last
    ``num_batch_norm_layers_to_train_params`` batch norm layers are unfrozen,
    and ``fc`` is replaced by a fresh ``nn.Linear`` with
    ``config.num_classes`` outputs.

    Args:
        num_batch_norm_layers_to_train_params: Number of trailing batch norm
            layers whose parameters should be trainable; 0 keeps all frozen.

    Returns:
        The model, moved to ``config.device``.
    """
    pretrained_weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1
    model = torchvision.models.resnet34(weights=pretrained_weights)

    # Freeze the whole pretrained network in one call.
    model.requires_grad_(False)

    # Re-enable gradients for the parameters of the trailing batch norm layers.
    if num_batch_norm_layers_to_train_params:
        trailing_batch_norm_layers = get_batch_norm_layers(model)[-num_batch_norm_layers_to_train_params:]
        for batch_norm_layer in trailing_batch_norm_layers:
            batch_norm_layer.requires_grad_(True)

    # A freshly constructed Linear layer has trainable parameters by default.
    head_in_features = model.fc.in_features
    model.fc = nn.Linear(in_features=head_in_features, out_features=config.num_classes)

    return model.to(config.device)

In [11]:
def get_pretrained_model_with_trainable_n_layers(num_batch_norm_layers_to_train_params: int, number_of_hidden_layers_to_train: int=0) -> nn.Module:
    """Build a pretrained ResNet34 with a trainable head and N trainable hidden layers.

    This is ``get_pretrained_model_with_trainable_last_layer`` followed by
    ``make_hidden_layers_trainable``; both steps are delegated to those
    functions so the freezing rules live in one place.

    Args:
        num_batch_norm_layers_to_train_params: Number of trailing batch norm
            layers whose parameters should be trainable.
        number_of_hidden_layers_to_train: Number of hidden layers, counted
            backwards from the layer just before ``fc``, to unfreeze.

    Returns:
        The model, on ``config.device``.

    Raises:
        ValueError: If ``number_of_hidden_layers_to_train`` is negative.
    """
    # Reject negative counts before loading the weights: used as a slice bound,
    # a negative value would unfreeze all but the first layers.
    if number_of_hidden_layers_to_train < 0:
        raise ValueError(
            f'number_of_hidden_layers_to_train must be non-negative, got {number_of_hidden_layers_to_train}'
        )

    model = get_pretrained_model_with_trainable_last_layer(num_batch_norm_layers_to_train_params)
    return make_hidden_layers_trainable(model, number_of_hidden_layers_to_train)

In [12]:
def make_hidden_layers_trainable(model: nn.Module, number_of_hidden_layers_to_train: int) -> nn.Module:
    """Unfreeze the last ``number_of_hidden_layers_to_train`` hidden layers of ``model`` in place.

    A layer is any submodule that directly owns parameters, so container
    modules are never counted. For torchvision's ResNet34 this selects the same
    Conv2d / BatchNorm2d / Linear modules as excluding the ResNet, BasicBlock
    and Sequential containers by type. The last such layer in
    ``model.modules()`` order (the classification head) is left out, and the
    remaining layers are counted backwards from the end.

    Args:
        model: Model whose layers should be unfrozen.
        number_of_hidden_layers_to_train: How many trailing hidden layers to unfreeze.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If ``number_of_hidden_layers_to_train`` is negative.
    """
    # Used as a slice bound below, a negative value would select from the wrong end.
    if number_of_hidden_layers_to_train < 0:
        raise ValueError(
            f'number_of_hidden_layers_to_train must be non-negative, got {number_of_hidden_layers_to_train}'
        )

    # recurse=False looks only at parameters registered on the module itself.
    layers_with_params = [
        layer
        for layer in model.modules()
        if next(layer.parameters(recurse=False), None) is not None
    ]
    # Drop the final layer (the head), then order from last hidden layer to first.
    hidden_layers_last_first = layers_with_params[:-1][::-1]

    for layer in hidden_layers_last_first[:number_of_hidden_layers_to_train]:
        for param in layer.parameters(recurse=False):
            param.requires_grad = True

    return model

In [13]:
def get_model_accuracy(model: nn.Module, data_loader: DataLoader) -> float:
    """Compute the top-1 accuracy of ``model`` over all batches of ``data_loader``.

    The model is switched to eval mode and left in that mode; gradients are
    disabled during the pass.

    Args:
        model: Model that maps a batch of inputs to per-class scores of shape
            ``(batch, num_classes)``.
        data_loader: Iterable of ``(inputs, labels)`` batches.

    Returns:
        Fraction of correctly classified samples, as a Python float.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no samples.
    """
    correct_predictions_cnt = 0
    total_predictions_cnt = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc='Computing accuracy'):
            outputs = model(inputs)
            correct_predictions_cnt += (torch.argmax(outputs, dim=1) == labels).sum()
            total_predictions_cnt += len(outputs)
    # The running count is a 0-d tensor on the model's device; convert it once
    # so the function returns a plain float, as its annotation promises.
    return int(correct_predictions_cnt) / total_predictions_cnt

In [14]:
def train_single_epoch(
        model: nn.Module,
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_batch_norm_layers_to_update_running_stats: int,
        ) -> float:
    """Run one optimisation pass over ``train_data_loader``.

    The model is put in train mode, after which every batch norm layer except
    the last ``num_batch_norm_layers_to_update_running_stats`` is switched to
    eval mode (all of them when that count is 0).

    Args:
        model: Model to train.
        train_data_loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss module called as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        num_batch_norm_layers_to_update_running_stats: Number of trailing batch
            norm layers that stay in train mode.

    Returns:
        Sum over batches of the batch loss times the batch size, divided by the
        total number of samples.
    """
    model.train()

    # Choose the batch norm layers to switch back to eval mode.
    all_batch_norm_layers = get_batch_norm_layers(model)
    if num_batch_norm_layers_to_update_running_stats:
        frozen_batch_norm_layers = all_batch_norm_layers[:-num_batch_norm_layers_to_update_running_stats]
    else:
        frozen_batch_norm_layers = all_batch_norm_layers
    for frozen_layer in frozen_batch_norm_layers:
        frozen_layer.eval()

    weighted_loss_total = 0.0
    samples_seen = 0
    for inputs, labels in tqdm(train_data_loader, desc='Training model'):
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        # Weight each batch loss by its size so a smaller final batch is not over-counted.
        samples_in_batch = len(outputs)
        weighted_loss_total += batch_loss.item() * samples_in_batch
        samples_seen += samples_in_batch

    return weighted_loss_total / samples_seen

In [15]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, checkpoint_path: str = 'checkpoint.pt') -> None:
    """Write the model and optimizer state dicts to ``checkpoint_path``.

    The file holds a dict with the keys ``'model_state_dict'`` and
    ``'optimizer_state_dict'``; an existing file is overwritten.

    Args:
        model: Model whose ``state_dict()`` is stored.
        optimizer: Optimizer whose ``state_dict()`` is stored.
        checkpoint_path: Destination file; defaults to ``'checkpoint.pt'`` in
            the current working directory.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, checkpoint_path)

In [16]:
def load_checkpoint(
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        checkpoint_path: str = 'checkpoint.pt',
        map_location='cpu',
        ) -> None:
    """Restore ``model`` and ``optimizer`` in place from a checkpoint file.

    The file must contain a dict with the keys ``'model_state_dict'`` and
    ``'optimizer_state_dict'``.

    Args:
        model: Model that receives the stored weights.
        optimizer: Optimizer that receives the stored state.
        checkpoint_path: File to read; defaults to ``'checkpoint.pt'``.
        map_location: Forwarded to ``torch.load``. The default ``'cpu'`` lets a
            checkpoint written on one device be read on a machine where that
            device is unavailable; ``load_state_dict`` then copies the values
            into the tensors the model and optimizer already hold.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [17]:
# Fixed seed for the train/validation split. Without it every kernel restart
# produces a different split, so validation accuracy, the early-stopping epoch
# and the saved checkpoint are not comparable between runs.
TRAIN_VAL_SPLIT_SEED = 42

filenames_trainval, labels_trainval = get_image_names_and_labels('annotations/trainval.txt', num_classes=config.num_classes)
filenames_test, labels_test = get_image_names_and_labels('annotations/test.txt', num_classes=config.num_classes)

# Hold out 20% of trainval for validation, stratified by label.
filenames_train, filenames_val, labels_train, labels_val = train_test_split(
    filenames_trainval,
    labels_trainval,
    test_size=0.2,
    stratify=labels_trainval,
    random_state=TRAIN_VAL_SPLIT_SEED,
)

# Only the training set is augmented; validation and test use the plain preprocessing.
dataset_train = ImageDataset(filenames_train, labels_train, use_augmentations=True)
dataset_val = ImageDataset(filenames_val, labels_val, use_augmentations=False)
dataset_test = ImageDataset(filenames_test, labels_test, use_augmentations=False)

# Only the training loader shuffles.
train_data_loader = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True)
val_data_loader = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False)
test_data_loader = DataLoader(dataset_test, batch_size=config.batch_size, shuffle=False)

In [18]:
# Training driver: fit the model for at most `config.max_num_epochs` epochs,
# checkpoint on every new best validation accuracy, stop early when validation
# stalls, and unfreeze hidden layers after a configured epoch.
model = get_pretrained_model_with_trainable_last_layer(config.num_batch_norm_layers_to_train_params)

criterion = nn.CrossEntropyLoss()
# The optimizer is given every model parameter, including the frozen ones.
optimizer = torch.optim.Adam(model.parameters(), lr=config.adam_optimizer_config.lr, weight_decay=config.adam_optimizer_config.weight_decay)

# Best validation accuracy seen so far and the epoch that produced it.
max_val_accuracy = float('-inf')
argmax_epoch = -1

for epoch in range(config.max_num_epochs):
    print(f'Epoch #{epoch}:')
    train_loss = train_single_epoch(model, train_data_loader, criterion, optimizer, config.num_batch_norm_layers_to_update_running_stats)
    print(f'Train loss: {train_loss}')
    val_accuracy = get_model_accuracy(model, val_data_loader)
    # Strict comparison: an epoch that only ties the best goes to the else branch.
    if val_accuracy > max_val_accuracy:
        print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
        max_val_accuracy = val_accuracy
        argmax_epoch = epoch
        save_checkpoint(model, optimizer)
        print('Checkpoint saved')
    else:
        print(f'Validation accuracy: {100 * val_accuracy:.2f}% (worse than {100 * max_val_accuracy:.2f}% of epoch {argmax_epoch})')
        # NOTE(review): with `>` the loop stops on the (patience + 1)-th epoch
        # after the best one - confirm that this is the intended meaning of `patience`.
        if epoch > argmax_epoch + config.patience:
            print(f'Early stopping')
            break
    # At the end of epoch `train_earlier_layers_delay`, unfreeze the last hidden
    # layers so they are trained from the next epoch on.
    if epoch == config.train_earlier_layers_delay and config.n_hidden_layers_to_train > 0:
        print (f'Making last {config.n_hidden_layers_to_train} hidden layers trainable')
        make_hidden_layers_trainable(model, config.n_hidden_layers_to_train)
        # A new Adam instance replaces the old one, so the previous optimizer
        # state is dropped; the learning rate divided by 100 applies to all parameters.
        optimizer = torch.optim.Adam(model.parameters(), lr=config.adam_optimizer_config.lr / 100, weight_decay=config.adam_optimizer_config.weight_decay)
    print()

Epoch #0:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.19it/s]


Train loss: 1.6885751090619876


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  7.88it/s]


Validation accuracy: 88.18% (new best)
Checkpoint saved

Epoch #1:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.44it/s]


Train loss: 0.5084084165485009


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.15it/s]


Validation accuracy: 91.85% (new best)
Checkpoint saved

Epoch #2:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.60it/s]


Train loss: 0.33681971288245655


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.19it/s]


Validation accuracy: 91.58% (worse than 91.85% of epoch 1)

Epoch #3:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.62it/s]


Train loss: 0.25884414504727593


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.08it/s]


Validation accuracy: 93.48% (new best)
Checkpoint saved

Epoch #4:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.44it/s]


Train loss: 0.20880886062007883


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.13it/s]


Validation accuracy: 93.34% (worse than 93.48% of epoch 3)

Epoch #5:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.58it/s]


Train loss: 0.17626360577085745


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.12it/s]


Validation accuracy: 93.61% (new best)
Checkpoint saved
Making last 3 hidden layers trainable

Epoch #6:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.33it/s]


Train loss: 0.14512542906500722


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.00it/s]


Validation accuracy: 93.61% (worse than 93.61% of epoch 5)

Epoch #7:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.54it/s]


Train loss: 0.1267847473449681


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.14it/s]


Validation accuracy: 93.48% (worse than 93.61% of epoch 5)

Epoch #8:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.57it/s]


Train loss: 0.11769688429067963


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.23it/s]


Validation accuracy: 93.48% (worse than 93.61% of epoch 5)

Epoch #9:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.53it/s]


Train loss: 0.10921055382198613


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.32it/s]


Validation accuracy: 93.61% (worse than 93.61% of epoch 5)

Epoch #10:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.34it/s]


Train loss: 0.09937814983498791


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.14it/s]


Validation accuracy: 94.29% (new best)
Checkpoint saved

Epoch #11:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.45it/s]


Train loss: 0.09259062867773615


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.04it/s]


Validation accuracy: 94.16% (worse than 94.29% of epoch 10)

Epoch #12:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.40it/s]


Train loss: 0.08814183286512675


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.01it/s]


Validation accuracy: 94.16% (worse than 94.29% of epoch 10)

Epoch #13:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.51it/s]


Train loss: 0.08415749513179711


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  7.96it/s]


Validation accuracy: 93.61% (worse than 94.29% of epoch 10)

Epoch #14:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.47it/s]


Train loss: 0.07949819259912423


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.28it/s]


Validation accuracy: 93.61% (worse than 94.29% of epoch 10)

Epoch #15:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.44it/s]


Train loss: 0.07381837369631165


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.13it/s]


Validation accuracy: 93.89% (worse than 94.29% of epoch 10)

Epoch #16:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.59it/s]


Train loss: 0.07130614143755773


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.05it/s]

Validation accuracy: 93.07% (worse than 94.29% of epoch 10)
Early stopping





In [19]:
# Restore the most recently saved checkpoint (written on the last new-best
# validation epoch of the training cell) before evaluating.
load_checkpoint(model, optimizer)
print('Checkpoint loaded')

# Accuracy of the restored model on the test split.
test_accuracy = get_model_accuracy(model, test_data_loader)

print(f'Test accuracy: {100 * test_accuracy:.2f}%')

Checkpoint loaded


Computing accuracy: 100%|██████████| 115/115 [00:15<00:00,  7.57it/s]

Test accuracy: 91.44%



