In [1]:
# !unzip drive/MyDrive/Pets-data/images.zip
# !cp -r drive/MyDrive/Pets-data/annotations .
# !ls

In [2]:
import os
import yaml
from typing import Tuple, List

import torch
import torchvision
from tqdm import tqdm
from PIL import Image
import torch.nn as nn
from pydantic import BaseModel
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split

In [3]:
# Define constants

CONFIG_FILE_PATH = 'config.yaml'
IMAGE_COL_IDX = 0
CLASS_ID_COL_IDX = 1
SPECIES_COL_IDX = 2
POSSIBLE_NUM_CLASSES = {2, 37}

In [4]:
class AdamOptimizerConfig(BaseModel):
    lr: float
    weight_decay: float

In [5]:
class Config(BaseModel):
    device: str
    num_classes: int
    batch_size: int
    max_num_epochs: int
    patience: int
    adam_optimizer_config: AdamOptimizerConfig
    num_batch_norm_layers_to_train_params: int
    num_batch_norm_layers_to_update_running_stats: int

In [6]:
with open(CONFIG_FILE_PATH, encoding='utf-8') as f:
    config_dict = yaml.load(f, Loader=yaml.FullLoader)

config = Config.model_validate(config_dict)

assert config.num_classes in POSSIBLE_NUM_CLASSES
assert 0 <= config.num_batch_norm_layers_to_train_params <= 36  # 36 batch norm layers in resnet34
assert 0 <= config.num_batch_norm_layers_to_update_running_stats <= 36  # 36 batch norm layers in resnet34

In [7]:
class ImageDataset(Dataset):
    def __init__(self, filenames: List[str], labels: List[int]) -> None:
        self.filenames = filenames
        self.labels = labels
        self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image = Image.open(os.path.join('images', f'{self.filenames[idx]}.jpg')).convert('RGB')
        label = self.labels[idx]

        transformed_img = self.transformation(image)

        return transformed_img.to(config.device), torch.tensor(label).to(config.device)

In [8]:
def get_image_names_and_labels(annotations_file_path: str, num_classes: int) -> Tuple[List[str], List[int]]:
    filenames: List[str] = []
    labels: List[int] = []

    with open(annotations_file_path, encoding='utf-8') as f:
        lines = f.readlines()

    label_col_idx = SPECIES_COL_IDX if num_classes == 2 else CLASS_ID_COL_IDX

    for line in lines:
        line_split = line.split()
        filenames.append(line_split[IMAGE_COL_IDX])
        labels.append(int(line_split[label_col_idx]) - 1)

    return filenames, labels

In [9]:
def get_batch_norm_layers(model: nn.Module) -> List[nn.Module]:
    return [
        module
        for module in model.modules()
        if isinstance(module, nn.BatchNorm2d)
    ]

In [10]:
def get_pretrained_model_with_trainable_last_layer(num_batch_norm_layers_to_train_params: int) -> nn.Module:
    model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.IMAGENET1K_V1)

    for param in model.parameters():
        param.requires_grad = False
    
    batch_norm_layers = get_batch_norm_layers(model)
    batch_norm_layers_to_train_params = batch_norm_layers[-num_batch_norm_layers_to_train_params:] if num_batch_norm_layers_to_train_params else []

    for batch_norm_layer in batch_norm_layers_to_train_params:
        for param in batch_norm_layer.parameters():
            param.requires_grad = True

    model.fc = nn.Linear(in_features=model.fc.in_features, out_features=config.num_classes)

    return model.to(config.device)

In [11]:
def get_model_accuracy(model: nn.Module, data_loader: DataLoader) -> float:
    correct_predictions_cnt = 0
    total_predictions_cnt = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc='Computing accuracy'):
            outputs = model(inputs)
            correct_predictions_cnt += (torch.argmax(outputs, axis=1) == labels).sum()
            total_predictions_cnt += len(outputs)
    return correct_predictions_cnt / total_predictions_cnt

In [12]:
def train_single_epoch(
        model: nn.Module,
        train_data_loader: DataLoader,
        criterion: nn.Module,
        optimizer: torch.optim.Optimizer,
        num_batch_norm_layers_to_update_running_stats: int,
        ) -> float:
    model.train()
    batch_norm_layers = get_batch_norm_layers(model)
    batch_norm_layers_to_not_update_running_stats = (
        batch_norm_layers[:-num_batch_norm_layers_to_update_running_stats]
        if num_batch_norm_layers_to_update_running_stats
        else batch_norm_layers
    )
    for batch_norm_layer in batch_norm_layers_to_not_update_running_stats:
        batch_norm_layer.eval()
    train_loss_sum = 0.0
    train_samples_cnt = 0
    for inputs, labels in tqdm(train_data_loader, desc='Training model'):
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss_sum += loss.item() * len(outputs)
        train_samples_cnt += len(outputs)
    return train_loss_sum / train_samples_cnt

In [13]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> Tuple[nn.Module, torch.optim.Optimizer]:
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, 'checkpoint.pt')

In [14]:
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer) -> None:
    checkpoint = torch.load('checkpoint.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [15]:
filenames_trainval, labels_trainval = get_image_names_and_labels('annotations/trainval.txt', num_classes=config.num_classes)
filenames_test, labels_test = get_image_names_and_labels('annotations/test.txt', num_classes=config.num_classes)
filenames_train, filenames_val, labels_train, labels_val = train_test_split(filenames_trainval, labels_trainval, test_size=0.2, stratify=labels_trainval)

dataset_train = ImageDataset(filenames_train, labels_train)
dataset_val = ImageDataset(filenames_val, labels_val)
dataset_test = ImageDataset(filenames_test, labels_test)

train_data_loader = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True)
val_data_loader = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False)
test_data_loader = DataLoader(dataset_test, batch_size=config.batch_size, shuffle=False)

In [16]:
model = get_pretrained_model_with_trainable_last_layer(config.num_batch_norm_layers_to_train_params)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=config.adam_optimizer_config.lr, weight_decay=config.adam_optimizer_config.weight_decay)

max_val_accuracy = float('-inf')
argmax_epoch = -1

for epoch in range(config.max_num_epochs):
    print(f'Epoch #{epoch}:')
    train_loss = train_single_epoch(model, train_data_loader, criterion, optimizer, config.num_batch_norm_layers_to_update_running_stats)
    print(f'Train loss: {train_loss}')
    val_accuracy = get_model_accuracy(model, val_data_loader)
    if val_accuracy > max_val_accuracy:
        print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
        max_val_accuracy = val_accuracy
        argmax_epoch = epoch
        save_checkpoint(model, optimizer)
        print('Checkpoint saved')
    else:
        print(f'Validation accuracy: {100 * val_accuracy:.2f}% (worse than {100 * max_val_accuracy:.2f}% of epoch {argmax_epoch})')
        if epoch > argmax_epoch + config.patience:
            print(f'Early stopping')
            break
    print()

Epoch #0:


Training model: 100%|██████████| 92/92 [00:11<00:00,  7.77it/s]


Train loss: 1.697668322402498


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  7.71it/s]


Validation accuracy: 85.87% (new best)
Checkpoint saved

Epoch #1:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.55it/s]


Train loss: 0.49133794009685516


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  7.89it/s]


Validation accuracy: 88.32% (new best)
Checkpoint saved

Epoch #2:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.40it/s]


Train loss: 0.3179315498665623


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  7.88it/s]


Validation accuracy: 90.22% (new best)
Checkpoint saved

Epoch #3:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.54it/s]


Train loss: 0.24672293541548046


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.09it/s]


Validation accuracy: 90.35% (new best)
Checkpoint saved

Epoch #4:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.54it/s]


Train loss: 0.20043843806437825


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.05it/s]


Validation accuracy: 92.80% (new best)
Checkpoint saved

Epoch #5:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.46it/s]


Train loss: 0.16613922227659952


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.44it/s]


Validation accuracy: 93.07% (new best)
Checkpoint saved

Epoch #6:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.10it/s]


Train loss: 0.14449505592979814


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.39it/s]


Validation accuracy: 92.39% (worse than 93.07% of epoch 5)

Epoch #7:


Training model: 100%|██████████| 92/92 [00:13<00:00,  7.06it/s]


Train loss: 0.1234120801496117


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.51it/s]


Validation accuracy: 92.26% (worse than 93.07% of epoch 5)

Epoch #8:


Training model: 100%|██████████| 92/92 [00:13<00:00,  7.01it/s]


Train loss: 0.11033547989538182


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.49it/s]


Validation accuracy: 93.61% (new best)
Checkpoint saved

Epoch #9:


Training model: 100%|██████████| 92/92 [00:13<00:00,  7.01it/s]


Train loss: 0.09842986459641354


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.40it/s]


Validation accuracy: 93.07% (worse than 93.61% of epoch 8)

Epoch #10:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.10it/s]


Train loss: 0.09096334376574858


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.41it/s]


Validation accuracy: 92.39% (worse than 93.61% of epoch 8)

Epoch #11:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.10it/s]


Train loss: 0.0803109509465487


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.47it/s]


Validation accuracy: 93.21% (worse than 93.61% of epoch 8)

Epoch #12:


Training model: 100%|██████████| 92/92 [00:13<00:00,  6.98it/s]


Train loss: 0.07199683495918693


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.42it/s]


Validation accuracy: 93.34% (worse than 93.61% of epoch 8)

Epoch #13:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.09it/s]


Train loss: 0.06628377077615132


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.35it/s]


Validation accuracy: 93.34% (worse than 93.61% of epoch 8)

Epoch #14:


Training model: 100%|██████████| 92/92 [00:13<00:00,  6.98it/s]


Train loss: 0.062058944902990174


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.41it/s]

Validation accuracy: 93.21% (worse than 93.61% of epoch 8)
Early stopping





In [17]:
load_checkpoint(model, optimizer)
print('Checkpoint loaded')

test_accuracy = get_model_accuracy(model, test_data_loader)

print(f'Test accuracy: {100 * test_accuracy:.2f}%')

Checkpoint loaded


Computing accuracy: 100%|██████████| 115/115 [00:16<00:00,  7.14it/s]

Test accuracy: 90.95%



