In [1]:
# !unzip drive/MyDrive/Pets-data/images.zip
# !cp -r drive/MyDrive/Pets-data/annotations .
# !ls

In [2]:
import os
from typing import List, Tuple

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torchvision
import yaml
from PIL import Image
from pydantic import BaseModel, Field
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [3]:
# Define constants

CONFIG_FILE_PATH = 'config.yaml'

# Column positions in the whitespace-separated annotation files (annotations/*.txt).
IMAGE_COL_IDX = 0      # image file stem; '.jpg' is appended when the image is opened
CLASS_ID_COL_IDX = 1   # label column used when num_classes != 2
SPECIES_COL_IDX = 2    # label column used when num_classes == 2

# Supported label granularities: 2 -> species column, 37 -> class-id column
# (presumably cat/dog vs. the 37 Oxford-IIIT Pet breeds — confirm against the dataset).
# frozenset so the "constant" cannot be mutated by accident later in the notebook.
POSSIBLE_NUM_CLASSES = frozenset({2, 37})

In [4]:
class AdamOptimizerConfig(BaseModel):
    """Hyperparameters forwarded to ``torch.optim.Adam``.

    Both fields are required and range-checked during validation, so a typo in
    the YAML config fails at load time instead of silently breaking training.
    """

    lr: float = Field(gt=0)  # learning rate; must be strictly positive
    weight_decay: float = Field(ge=0)  # weight-decay coefficient; 0 disables it

In [5]:
class Config(BaseModel):
    """Experiment settings parsed from the YAML config file.

    Numeric fields are range-checked during validation so that an invalid
    setting (e.g. ``batch_size: 0``) is rejected before any data is loaded.
    """

    device: str  # torch device string passed to ``.to(...)``, e.g. 'cpu' or 'cuda'
    num_classes: int = Field(gt=0)  # output size of the classification head
    batch_size: int = Field(gt=0)
    max_num_epochs: int = Field(gt=0)  # at least one epoch is needed to produce a checkpoint
    patience: int = Field(ge=0)  # early-stopping patience, in epochs
    adam_optimizer_config: AdamOptimizerConfig

In [6]:
with open(CONFIG_FILE_PATH, encoding='utf-8') as f:
    # safe_load builds only plain YAML types (dicts, lists, scalars), which is all the
    # config schema needs, and cannot instantiate arbitrary Python objects.
    config_dict = yaml.safe_load(f)

config = Config.model_validate(config_dict)

# Explicit check rather than `assert`, which is skipped when Python runs with -O.
if config.num_classes not in POSSIBLE_NUM_CLASSES:
    raise ValueError(
        f'num_classes must be one of {sorted(POSSIBLE_NUM_CLASSES)}, got {config.num_classes}'
    )

In [7]:
class ImageDataset(Dataset):
    """Map-style dataset yielding ``(preprocessed image, label)`` tensor pairs.

    Images are read lazily from ``<images_dir>/<filename>.jpg``, converted to RGB
    and passed through the preprocessing bundled with the ImageNet ResNet-34 weights.

    NOTE(review): both returned tensors are moved to ``config.device`` here. That is
    fine with the default single-process loading used by this notebook's DataLoaders,
    but would break CUDA in worker processes if ``num_workers > 0`` were ever set.
    """

    def __init__(self, filenames: List[str], labels: List[int], images_dir: str = 'images') -> None:
        """
        Args:
            filenames: Image file stems (without the ``.jpg`` extension).
            labels: Integer class label for each entry of ``filenames``.
            images_dir: Directory containing the ``.jpg`` files.

        Raises:
            ValueError: if ``filenames`` and ``labels`` differ in length.
        """
        if len(filenames) != len(labels):
            raise ValueError(f'Got {len(filenames)} filenames but {len(labels)} labels')
        self.filenames = filenames
        self.labels = labels
        self.images_dir = images_dir
        # Preprocessing that matches the pretrained weights the model is built from.
        self.transformation = torchvision.models.ResNet34_Weights.IMAGENET1K_V1.transforms()

    def __len__(self) -> int:
        return len(self.filenames)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        image_path = os.path.join(self.images_dir, f'{self.filenames[idx]}.jpg')
        # The context manager releases the file handle deterministically; convert()
        # returns an in-memory copy that stays valid after the file is closed.
        with Image.open(image_path) as image:
            rgb_image = image.convert('RGB')

        transformed_img = self.transformation(rgb_image)
        # Build the label tensor directly on the target device instead of copying it there.
        label = torch.tensor(self.labels[idx], device=config.device)

        return transformed_img.to(config.device), label

In [8]:
def get_image_names_and_labels(annotations_file_path: str, num_classes: int) -> Tuple[List[str], List[int]]:
    """Parse an annotation file into parallel lists of image stems and labels.

    Each data line is whitespace-separated. The image stem is read from column
    ``IMAGE_COL_IDX``. The label is read from ``SPECIES_COL_IDX`` when
    ``num_classes == 2``, otherwise from ``CLASS_ID_COL_IDX``, and is shifted down
    by one (presumably the file stores 1-based ids; the shift yields the 0-based
    indices ``nn.CrossEntropyLoss`` expects).

    Blank lines and lines starting with ``#`` are skipped.

    Args:
        annotations_file_path: Path to the annotation text file.
        num_classes: Number of target classes; selects the label column.

    Returns:
        ``(filenames, labels)`` with one entry per data line, in file order.
    """
    label_col_idx = SPECIES_COL_IDX if num_classes == 2 else CLASS_ID_COL_IDX
    filenames: List[str] = []
    labels: List[int] = []

    with open(annotations_file_path, encoding='utf-8') as f:
        for line in f:
            line_split = line.split()
            # Skip blank and comment/header lines instead of crashing on them.
            if not line_split or line_split[0].startswith('#'):
                continue
            filenames.append(line_split[IMAGE_COL_IDX])
            labels.append(int(line_split[label_col_idx]) - 1)

    return filenames, labels

In [9]:
def get_pretrained_model_with_trainable_last_layer() -> nn.Module:
    """Create an ImageNet-pretrained ResNet-34 with a fresh, trainable ``fc`` head.

    Every pretrained parameter is frozen; only the newly created final linear layer
    (``config.num_classes`` outputs) keeps ``requires_grad=True``.

    Returns:
        The model, moved to ``config.device``.
    """
    pretrained_weights = torchvision.models.ResNet34_Weights.IMAGENET1K_V1
    backbone = torchvision.models.resnet34(weights=pretrained_weights)

    # Freeze the whole pretrained network first...
    for frozen_param in backbone.parameters():
        frozen_param.requires_grad = False

    # ...then swap in a new head; its parameters are created trainable by default.
    head_in_features = backbone.fc.in_features
    backbone.fc = nn.Linear(in_features=head_in_features, out_features=config.num_classes)

    return backbone.to(config.device)

In [10]:
def get_model_accuracy(model: nn.Module, data_loader: DataLoader) -> float:
    """Compute top-1 accuracy of ``model`` over every batch yielded by ``data_loader``.

    Side effect: leaves ``model`` in eval mode. Gradients are disabled during the pass.

    Args:
        model: Classifier returning per-class scores with classes along dim 1.
        data_loader: Yields ``(inputs, labels)`` batches already on the model's device.

    Returns:
        Fraction of correctly classified samples, as a plain Python ``float`` in [0, 1].

    Raises:
        ValueError: if the loader yields no samples.
    """
    correct_predictions_cnt = 0
    total_predictions_cnt = 0
    model.eval()
    with torch.no_grad():
        for inputs, labels in tqdm(data_loader, desc='Computing accuracy'):
            outputs = model(inputs)
            predictions = outputs.argmax(dim=1)
            # .item() keeps the counters as Python ints so the result is a real float.
            correct_predictions_cnt += (predictions == labels).sum().item()
            total_predictions_cnt += labels.size(0)
    if total_predictions_cnt == 0:
        raise ValueError('Cannot compute accuracy on an empty data loader')
    return correct_predictions_cnt / total_predictions_cnt

In [11]:
def train_single_epoch(model: nn.Module, train_data_loader: DataLoader, criterion: nn.Module, optimizer: torch.optim.Optimizer) -> float:
    """Run one optimization pass over ``train_data_loader``.

    Puts ``model`` in train mode and performs one optimizer step per batch.

    Returns:
        Mean training loss over the epoch, weighted by the number of samples in
        each batch (so a smaller final batch counts proportionally).
    """
    model.train()
    weighted_loss_total = 0.0
    seen_samples = 0
    progress = tqdm(train_data_loader, desc='Training model')
    for batch_inputs, batch_labels in progress:
        optimizer.zero_grad()
        batch_outputs = model(batch_inputs)
        batch_loss = criterion(batch_outputs, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # criterion returns a per-batch mean, so rescale it to a per-batch sum.
        n_in_batch = len(batch_outputs)
        weighted_loss_total += batch_loss.item() * n_in_batch
        seen_samples += n_in_batch
    return weighted_loss_total / seen_samples

In [12]:
def save_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, path: str = 'checkpoint.pt') -> None:
    """Write the model and optimizer state dicts to a single file.

    Args:
        model: Module whose ``state_dict()`` is stored under ``'model_state_dict'``.
        optimizer: Optimizer whose ``state_dict()`` is stored under ``'optimizer_state_dict'``.
        path: Destination file; an existing file is overwritten.
    """
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    torch.save(checkpoint, path)

In [13]:
def load_checkpoint(model: nn.Module, optimizer: torch.optim.Optimizer, path: str = 'checkpoint.pt', map_location=None) -> None:
    """Restore model and optimizer state in place from a checkpoint file.

    Expects the dict layout produced by ``save_checkpoint``
    (``'model_state_dict'`` and ``'optimizer_state_dict'`` keys).

    Args:
        model: Module to load weights into.
        optimizer: Optimizer to load state into (its param groups are replaced by the saved ones).
        path: Checkpoint file to read.
        map_location: Forwarded to ``torch.load``; e.g. ``'cpu'`` lets a checkpoint
            saved on GPU be loaded on a CPU-only machine. ``None`` keeps saved devices.
    """
    # torch.load unpickles the file, so only load checkpoints produced by this code.
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

In [14]:
# Fixed seed so the train/validation split is identical across kernel restarts.
VAL_SPLIT_SEED = 42
VAL_FRACTION = 0.2

filenames_trainval, labels_trainval = get_image_names_and_labels('annotations/trainval.txt', num_classes=config.num_classes)
filenames_test, labels_test = get_image_names_and_labels('annotations/test.txt', num_classes=config.num_classes)
# Stratified split keeps the class proportions of trainval in both train and val.
filenames_train, filenames_val, labels_train, labels_val = train_test_split(
    filenames_trainval,
    labels_trainval,
    test_size=VAL_FRACTION,
    stratify=labels_trainval,
    random_state=VAL_SPLIT_SEED,
)

dataset_train = ImageDataset(filenames_train, labels_train)
dataset_val = ImageDataset(filenames_val, labels_val)
dataset_test = ImageDataset(filenames_test, labels_test)

# Only the training loader shuffles; evaluation order does not affect accuracy.
train_data_loader = DataLoader(dataset_train, batch_size=config.batch_size, shuffle=True)
val_data_loader = DataLoader(dataset_val, batch_size=config.batch_size, shuffle=False)
test_data_loader = DataLoader(dataset_test, batch_size=config.batch_size, shuffle=False)

In [15]:
# Seed torch's global RNG so the new head's initialization and the training
# DataLoader's shuffling are reproducible across kernel restarts.
TRAINING_SEED = 42
torch.manual_seed(TRAINING_SEED)

model = get_pretrained_model_with_trainable_last_layer()

criterion = nn.CrossEntropyLoss()
# The backbone is frozen, so only the new head's parameters are handed to the optimizer.
trainable_params = [param for param in model.parameters() if param.requires_grad]
optimizer = torch.optim.Adam(
    trainable_params,
    lr=config.adam_optimizer_config.lr,
    weight_decay=config.adam_optimizer_config.weight_decay,
)

max_val_accuracy = float('-inf')
argmax_epoch = -1

for epoch in range(config.max_num_epochs):
    print(f'Epoch #{epoch}:')
    train_loss = train_single_epoch(model, train_data_loader, criterion, optimizer)
    print(f'Train loss: {train_loss}')
    val_accuracy = get_model_accuracy(model, val_data_loader)
    if val_accuracy > max_val_accuracy:
        print(f'Validation accuracy: {100 * val_accuracy:.2f}% (new best)')
        max_val_accuracy = val_accuracy
        argmax_epoch = epoch
        # Keep the best-so-far weights on disk; the evaluation cell reloads them.
        save_checkpoint(model, optimizer)
        print('Checkpoint saved')
    else:
        # Ties also land here, so "no improvement" rather than "worse".
        print(f'Validation accuracy: {100 * val_accuracy:.2f}% (no improvement over {100 * max_val_accuracy:.2f}% of epoch {argmax_epoch})')
        # Stop after `patience` consecutive epochs without improvement.
        if epoch - argmax_epoch >= config.patience:
            print('Early stopping')
            break
    print()

Epoch #0:


Training model: 100%|██████████| 92/92 [00:17<00:00,  5.34it/s]


Train loss: 0.19382216032270505


Computing accuracy: 100%|██████████| 23/23 [00:04<00:00,  5.17it/s]


Validation accuracy: 97.69% (new best)
Checkpoint saved

Epoch #1:


Training model: 100%|██████████| 92/92 [00:15<00:00,  5.86it/s]


Train loss: 0.0643600793150456


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  6.36it/s]


Validation accuracy: 99.18% (new best)
Checkpoint saved

Epoch #2:


Training model: 100%|██████████| 92/92 [00:14<00:00,  6.22it/s]


Train loss: 0.05387528233594545


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  6.26it/s]


Validation accuracy: 99.18% (worse than 99.18% of epoch 1)

Epoch #3:


Training model: 100%|██████████| 92/92 [00:14<00:00,  6.42it/s]


Train loss: 0.049693096024186714


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.03it/s]


Validation accuracy: 99.18% (worse than 99.18% of epoch 1)

Epoch #4:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.65it/s]


Train loss: 0.04158833804135413


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  8.04it/s]


Validation accuracy: 99.05% (worse than 99.18% of epoch 1)

Epoch #5:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.52it/s]


Train loss: 0.04413732932880521


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.50it/s]


Validation accuracy: 99.18% (worse than 99.18% of epoch 1)

Epoch #6:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.10it/s]


Train loss: 0.03366809609123384


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.02it/s]


Validation accuracy: 99.32% (new best)
Checkpoint saved

Epoch #7:


Training model: 100%|██████████| 92/92 [00:13<00:00,  7.01it/s]


Train loss: 0.03269082910770996


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.29it/s]


Validation accuracy: 98.91% (worse than 99.32% of epoch 6)

Epoch #8:


Training model: 100%|██████████| 92/92 [00:13<00:00,  7.07it/s]


Train loss: 0.030781664412959642


Computing accuracy: 100%|██████████| 23/23 [00:03<00:00,  7.23it/s]


Validation accuracy: 99.18% (worse than 99.32% of epoch 6)

Epoch #9:


Training model: 100%|██████████| 92/92 [00:12<00:00,  7.14it/s]


Train loss: 0.0310769115584781


Computing accuracy: 100%|██████████| 23/23 [00:02<00:00,  7.78it/s]

Validation accuracy: 98.78% (worse than 99.32% of epoch 6)






In [16]:
# Restore the weights from the epoch with the best validation accuracy before testing.
load_checkpoint(model, optimizer)
print('Checkpoint loaded')

# Held-out test split: evaluated once, after model selection on validation.
test_accuracy = get_model_accuracy(model, test_data_loader)
print(f'Test accuracy: {test_accuracy * 100:.2f}%')

Checkpoint loaded


Computing accuracy: 100%|██████████| 115/115 [00:16<00:00,  6.93it/s]

Test accuracy: 99.26%



