# Train PawPrint on Colab

In [None]:
WANDB_API_KEY = "..."
WANDB_ENTITY = "..."

In [2]:
!pip install -qq wandb

In [None]:
import sys
import os
import glob
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Literal
from datetime import datetime
import gc
import warnings
from PIL import Image

import wandb
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
import torch
from torch import autocast
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models

## Utils

In [4]:
def _time_stamp():
    current_time = datetime.now()
    return current_time.strftime("%y%m%d%H%M%S%f")


def _format_name(name: str, max_len: int = 30) -> str:
    name = name.strip().lower().replace(" ", "-")
    return name[:max_len]


def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()


def device(force_cuda=True) -> torch.device:
    has_cuda = torch.cuda.is_available()
    if force_cuda:
        assert has_cuda, "CUDA is not available."
        return torch.device("cuda")
    return torch.device("cuda") if has_cuda else torch.device("cpu")


def ignore_warnings():
    warnings.filterwarnings("ignore")


def fix_random_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


class EarlyStopping(object):
    """Stop training when loss does not decrease"""

    def __init__(self, patience: int, path_to_save: str):
        self._min_loss = float("inf")
        self._patience = patience
        self._path = path_to_save
        self.__check_point = None
        self.__counter = 0

    def should_stop(self, loss: float, model: torch.nn.Module, epoch: int) -> bool:
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            self.__check_point = epoch
            torch.save(model.state_dict(), self._path)
        elif loss > self._min_loss:
            self.__counter += 1
            if self.__counter == self._patience:
                return True
        return False

    def load(self, weights_only=True):
        return torch.load(self._path, weights_only=weights_only)

    @property
    def check_point(self):
        if self.__check_point is None:
            raise ValueError("No check point is saved!")
        return self.__check_point

    @property
    def best_loss(self):
        return self._min_loss


class Colab:
    def __init__(
        self, mount_path="/content/drive", project_path="/content/drive/MyDrive"
    ):
        assert (
            importlib.util.find_spec("google.colab") is not None
        ), "This class is only available in Google Colab. Cannot import `google.colab` in current environment."
        self._mount_path = mount_path
        self._project_path = project_path

    def mount_drive(self, force_remount=False):
        from google.colab import drive

        drive.mount(self._mount_path, force_remount=force_remount)
        sys.path.append(self._project_path)

    def join(self, *args):
        return os.path.join(self._project_path, *args)


@dataclass()
class Config:
    name: str
    model: Literal["resnet", "vit"]
    batch: int
    epochs: int
    weight_decay: float
    lr: float
    backbone_lr: float
    enable_fp16: Optional[bool] = field(default=False)
    patience: Optional[int] = field(default=0)

    # Will be filled automatically
    id: str = field(init=False)
    model_path: str = field(init=False)

    def __post_init__(self):
        self.id = _time_stamp()
        self.name = _format_name(self.name)
        self.model_path = f"{self.name}_{self.id}.pt"

        if self.batch < 1:
            raise ValueError("batch must be positive integer.")
        if self.epochs < 0:
            raise ValueError("epochs must be positive integer.")
        if self.lr < 0:
            raise ValueError("lr must be positive float.")

    def add(self, **kwargs):
        for k, v in kwargs.items():
            if k in self.__dict__.keys():
                raise KeyError(f"Duplicate key: {k}")
            object.__setattr__(self, k, v)

    def to_dict(self):
        return self.__dict__

    def __repr__(self):
        pairs = self.to_dict().items()
        pairs = ", ".join([f"{k}={v}" for k, v in pairs])
        name = str(self.__class__.__name__)
        return f"{name}({pairs})"

## Config

In [5]:
config = Config(
    name="vit-l-16",
    model="vit",
    batch=16,
    epochs=50,
    lr=1e-3,
    backbone_lr=1e-5,
    enable_fp16=True,
    patience=10,
    weight_decay=1e-3,
)

In [None]:
gdrive = Colab(project_path="/content/drive/MyDrive/pawprint")
gdrive.mount_drive()

ignore_warnings()
fix_random_seed(42)
device = device(force_cuda=True)

In [None]:
wandb.login(key=WANDB_API_KEY)
wandb.init(
    project="pawprint", entity=WANDB_ENTITY, name=config.id, config=config.to_dict()
)

## Dataset

In [8]:
class PawPrintDataset(Dataset):
    def __init__(self, datasets, labels, transformations):
        self.transform = transforms.Compose(transformations)
        self.datasets = datasets
        self.labels = labels

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, idx):
        image_path, label = self.datasets[idx]
        image = Image.open(image_path).convert("RGB")
        label = torch.tensor(self.labels[label], dtype=torch.long)
        if self.transform:
            image = self.transform(image)
        return image, label

    def decode(self, label: int) -> str | None:
        for text_label, int_label in self.labels.items():
            if label == int_label:
                return text_label
        return None


def read_img_as_dataset(img_paths):
    label_dict = {}
    datasets = []
    tmp_labels = set()

    source_pattern = os.path.join(img_paths, "*/*/*.png")
    file_list = glob.glob(source_pattern)
    for file_path in file_list:
        text_label = Path(file_path).parent.name
        text_label = str(text_label).lower().strip()

        tmp_labels.add(text_label)
        datasets.append((file_path, text_label))

    for idx, text_label in enumerate(tmp_labels):
        label_dict[text_label] = idx

    return datasets, label_dict

In [9]:
augmentation_pipeline = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

# Preprocessing for ImageNet pre-trained weights
normalize_pipeline = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

train_imgs, label_dict = read_img_as_dataset(gdrive.join("PP/train"))
test_imgs, _ = read_img_as_dataset(gdrive.join("PP/test"))

train_imgs, val_imgs = train_test_split(
    train_imgs, test_size=0.2, shuffle=True, random_state=42
)

train_set = PawPrintDataset(
    train_imgs, label_dict, transformations=augmentation_pipeline
)
val_set = PawPrintDataset(val_imgs, label_dict, transformations=normalize_pipeline)
test_set = PawPrintDataset(test_imgs, label_dict, transformations=normalize_pipeline)

train_loader = DataLoader(
    train_set, batch_size=config.batch, pin_memory=True, shuffle=True
)
val_loader = DataLoader(val_set, batch_size=config.batch, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=config.batch, pin_memory=True)

print(f"train: {len(train_set)}")
print(f"val: {len(val_set)}")
print(f"test: {len(test_set)}")
print(f"lables: {label_dict}")
print(f"# of lables: {len(label_dict)}")

train: 514
val: 129
test: 290
lables: {'dog_pp': 0, 'dog_sol': 1, 'dog_angae': 2, 'cat_sky': 3, 'dog_wangbal': 4, 'dog_pony': 5, 'dog_bbibbi': 6, 'dog_coco': 7, 'dog_mina': 8, 'dog_pori': 9, 'cat_seoli': 10, 'cat_geomson': 11, 'cat_munji': 12, 'dog_kong': 13, 'dog_mi': 14, 'cat_star': 15, 'dog_bori': 16, 'cat_morae': 17, 'cat_lilly': 18, 'dog_ming': 19}
# of lables: 20


## Model

In [10]:
# Load pretrained weights if needed
vit_model_path = gdrive.join("vit_l_16.pth")
if config.model == "vit" and not os.path.exists(vit_model_path):
    vit_l_16 = models.vit_l_16(weights=models.ViT_L_16_Weights.IMAGENET1K_V1)
    torch.save(vit_l_16.state_dict(), vit_model_path)

resnet_model_path = gdrive.join("resnet_152.pth")
if config.model == "resnet" and not os.path.exists(resnet_model_path):
    resnet_152 = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
    torch.save(resnet_152.state_dict(), resnet_model_path)

In [11]:
class ViTransformer(nn.Module):
    def __init__(self, num_labels):
        super(ViTransformer, self).__init__()
        self.feature_extractor = models.vit_l_16()
        self.feature_extractor.load_state_dict(torch.load(vit_model_path))
        n_features = self.feature_extractor.heads.head.in_features
        self.feature_extractor.heads.head = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_labels),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier(features)
        return output

In [12]:
class ResNet(nn.Module):
    def __init__(self, num_labels):
        super(ResNet, self).__init__()
        self.feature_extractor = models.resnet152()
        self.feature_extractor.load_state_dict(torch.load(resnet_model_path))
        n_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.Linear(512, num_labels),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier(features)
        return output

## Train

In [13]:
@torch.no_grad()
def validate(model, criterion, val_loader, device):
    model.eval()

    val_loss = list()
    model_preds = list()
    true_labels = list()

    for data in val_loader:
        image_input, test_label = data
        image_input = image_input.to(device, non_blocking=True)
        test_label = test_label.to(device, non_blocking=True)

        with autocast(
            device_type=str(device), enabled=config.enable_fp16, dtype=torch.float16
        ):
            output = model(image_input)
            batch_loss = criterion(output, test_label)

        val_loss.append(batch_loss.item())

        model_preds += output.argmax(1).detach().cpu().numpy().tolist()
        true_labels += test_label.detach().cpu().numpy().tolist()

    return val_loss


def train(
    model,
    optimizer,
    criterion,
    early_stopper,
    scheduler,
    train_loader,
    val_loader,
    device,
):
    clear_cache()

    for epoch in range(config.epochs):
        model.train()
        train_loss = list()
        for data in train_loader:

            image_input, train_label = data
            image_input = image_input.to(device, non_blocking=True)
            train_label = train_label.to(device, non_blocking=True)

            with autocast(
                device_type=str(device), enabled=config.enable_fp16, dtype=torch.float16
            ):
                output = model(image_input)
                batch_loss = criterion(output, train_label.long())

            train_loss.append(batch_loss.item())

            batch_loss.backward()
            optimizer.step()
            model.zero_grad()

        val_loss = validate(model, criterion, val_loader, device)
        train_loss = np.mean(train_loss)
        val_loss = np.mean(val_loss)
        wandb.log({"train_loss": train_loss, "val_loss": val_loss})

        tqdm.write(
            f"Epoch {epoch}, Train-Loss: {train_loss:.5f},  Val-Loss: {val_loss:.5f}"
        )

        if early_stopper.should_stop(val_loss, model, epoch):
            break

        scheduler.step()

    tqdm.write(f"\n\n -- EarlyStopping: [Epoch: {early_stopper.check_point}]")

In [14]:
num_labels = len(label_dict)
model = (
    ResNet(num_labels=num_labels)
    if config.model == "resnet"
    else ViTransformer(num_labels=num_labels)
)
optimizer = optim.AdamW(
    [
        {"params": model.feature_extractor.parameters(), "lr": config.backbone_lr},
        {"params": model.classifier.parameters(), "lr": config.lr},
    ],
    weight_decay=config.weight_decay,
)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: 0.9**epoch)
criterion = nn.CrossEntropyLoss()
early_stopper = EarlyStopping(
    patience=config.patience,
    path_to_save=gdrive.join("weights", f"{config.name}_{config.id}.pth"),
)

model.to(device)
criterion.to(device)

print(f"model: {model.__class__.__name__}")
print(f"optimizer: {optimizer.__class__.__name__}")

model: ViTransformer
optimizer: AdamW


In [15]:
train(
    model,
    optimizer,
    criterion,
    early_stopper,
    scheduler,
    train_loader,
    val_loader,
    device,
)

Epoch 0, Train-Loss: 1.49507,  Val-Loss: 0.89857
Epoch 1, Train-Loss: 0.48274,  Val-Loss: 0.40021
Epoch 2, Train-Loss: 0.22343,  Val-Loss: 0.35507
Epoch 3, Train-Loss: 0.15278,  Val-Loss: 0.31254
Epoch 4, Train-Loss: 0.11820,  Val-Loss: 0.34467
Epoch 5, Train-Loss: 0.07592,  Val-Loss: 0.32326
Epoch 6, Train-Loss: 0.05216,  Val-Loss: 0.30510
Epoch 7, Train-Loss: 0.06803,  Val-Loss: 0.29171
Epoch 8, Train-Loss: 0.06419,  Val-Loss: 0.36912
Epoch 9, Train-Loss: 0.04590,  Val-Loss: 0.28831
Epoch 10, Train-Loss: 0.04698,  Val-Loss: 0.23476
Epoch 11, Train-Loss: 0.04321,  Val-Loss: 0.25885
Epoch 12, Train-Loss: 0.02705,  Val-Loss: 0.25959
Epoch 13, Train-Loss: 0.03206,  Val-Loss: 0.25896
Epoch 14, Train-Loss: 0.02333,  Val-Loss: 0.24518
Epoch 15, Train-Loss: 0.01033,  Val-Loss: 0.28079
Epoch 16, Train-Loss: 0.02709,  Val-Loss: 0.22377
Epoch 17, Train-Loss: 0.02047,  Val-Loss: 0.23492
Epoch 18, Train-Loss: 0.02414,  Val-Loss: 0.25090
Epoch 19, Train-Loss: 0.02040,  Val-Loss: 0.26568
Epoch 20, 

## Evaluation

In [16]:
@torch.no_grad()
def evaluate(model, test_loader, device):
    clear_cache()
    model.eval()
    model_preds = list()
    true_labels = list()

    for data in test_loader:
        image_input, test_label = data
        image_input = image_input.to(device, non_blocking=True)
        test_label = test_label.to(device, non_blocking=True)

        output = model(image_input)

        model_preds += output.argmax(1).detach().cpu().numpy().tolist()
        true_labels += test_label.detach().cpu().numpy().tolist()

    return {
        "accuracy": accuracy_score(true_labels, model_preds),
        "recall": recall_score(true_labels, model_preds, average="macro"),
        "precision": precision_score(true_labels, model_preds, average="macro"),
        "f1": f1_score(true_labels, model_preds, average="macro"),
    }

In [17]:
model = (
    ResNet(num_labels=num_labels)
    if config.model == "resnet"
    else ViTransformer(num_labels=num_labels)
)
model.load_state_dict(early_stopper.load())
model = model.to(device)

scores = evaluate(model, test_loader, device)
for k, v in scores.items():
    wandb.summary[k] = v
    print(f"{k}: {v:.3f}")

accuracy: 0.907
recall: 0.852
precision: 0.912
f1: 0.866


In [None]:
wandb.finish()