# Train PawPrint+ on local GPU

In [None]:
WANDB_API_KEY = "..."
WANDB_ENTITY = "..."

In [2]:
import os
import glob
from pathlib import Path
from dataclasses import dataclass, field
from typing import Optional, Literal
from datetime import datetime
import gc
import warnings
from PIL import Image

import wandb
import numpy as np
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split
import torch
from torch import autocast
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingLR
from torchvision import transforms
from torch.utils.data import DataLoader, Dataset
import torchvision.models as models

## Utils

In [3]:
def _time_stamp():
    current_time = datetime.now()
    return current_time.strftime("%y%m%d%H%M%S%f")


def _format_name(name: str, max_len: int = 30) -> str:
    name = name.strip().lower().replace(" ", "-")
    return name[:max_len]


def clear_cache():
    gc.collect()
    torch.cuda.empty_cache()


def device(force_cuda=True) -> torch.device:
    has_cuda = torch.cuda.is_available()
    if force_cuda:
        assert has_cuda, "CUDA is not available."
        return torch.device("cuda")
    return torch.device("cuda") if has_cuda else torch.device("cpu")


def ignore_warnings():
    warnings.filterwarnings("ignore")


def fix_random_seed(seed=42):
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def join_path(*args):
    return os.path.join(os.getcwd(), *args)


class EarlyStopping(object):
    """Stop training when loss does not decrease"""

    def __init__(self, patience: int, path_to_save: str):
        self._min_loss = float("inf")
        self._patience = patience
        self._path = path_to_save
        self.__check_point = None
        self.__counter = 0

    def should_stop(self, loss: float, model: torch.nn.Module, epoch: int) -> bool:
        if loss < self._min_loss:
            self._min_loss = loss
            self.__counter = 0
            self.__check_point = epoch
            torch.save(model.state_dict(), self._path)
        elif loss > self._min_loss:
            self.__counter += 1
            if self.__counter == self._patience:
                return True
        return False

    def load(self, weights_only=True):
        return torch.load(self._path, weights_only=weights_only)

    @property
    def check_point(self):
        if self.__check_point is None:
            raise ValueError("No check point is saved!")
        return self.__check_point

    @property
    def best_loss(self):
        return self._min_loss


@dataclass()
class Config:
    name: str
    model: Literal["resnet152", "resnet50"]
    batch: int
    epochs: int
    weight_decay: float
    lr: float
    backbone_lr: float
    enable_fp16: Optional[bool] = field(default=False)
    patience: Optional[int] = field(default=0)

    # Will be filled automatically
    id: str = field(init=False)
    model_path: str = field(init=False)

    def __post_init__(self):
        self.id = _time_stamp()
        self.name = _format_name(self.name)
        self.model_path = f"{self.name}_{self.id}.pt"

        if self.batch < 1:
            raise ValueError("batch must be positive integer.")
        if self.epochs < 0:
            raise ValueError("epochs must be positive integer.")
        if self.lr < 0:
            raise ValueError("lr must be positive float.")

    def add(self, **kwargs):
        for k, v in kwargs.items():
            if k in self.__dict__.keys():
                raise KeyError(f"Duplicate key: {k}")
            object.__setattr__(self, k, v)

    def to_dict(self):
        return self.__dict__

    def __repr__(self):
        pairs = self.to_dict().items()
        pairs = ", ".join([f"{k}={v}" for k, v in pairs])
        name = str(self.__class__.__name__)
        return f"{name}({pairs})"

## Config

In [None]:
config = Config(
    name="resnet-152",
    model="resnet152",
    batch=64,
    epochs=100,
    lr=1e-3,
    backbone_lr=1e-3,
    enable_fp16=True,
    patience=30,
    weight_decay=1e-3,
)

In [5]:
ignore_warnings()
fix_random_seed(42)
device = device(force_cuda=True)

In [None]:
wandb.login(key=WANDB_API_KEY)
wandb.init(
    project="pawprint+", entity=WANDB_ENTITY, name=config.id, config=config.to_dict()
)

## Dataset

In [7]:
class PawPrintDataset(Dataset):
    def __init__(self, datasets, labels, transformations):
        self.transform = transforms.Compose(transformations)
        self.datasets = datasets
        self.labels = labels

    def __len__(self):
        return len(self.datasets)

    def __getitem__(self, idx):
        image_path, label = self.datasets[idx]
        image = Image.open(image_path).convert("RGB")
        label = torch.tensor(self.labels[label], dtype=torch.long)
        if self.transform:
            image = self.transform(image)
        return image, label

    def decode(self, label: int) -> str | None:
        for text_label, int_label in self.labels.items():
            if label == int_label:
                return text_label
        return None

In [8]:
augmentation_pipeline = [
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomApply(
        [
            transforms.ColorJitter(
                brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1
            ),
            transforms.GaussianBlur(kernel_size=5, sigma=1.0),
        ],
        p=0.5,
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

normalize_pipeline = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]

In [9]:
def read_img_as_dataset(img_paths):
    label_dict = {}
    datasets = []
    tmp_labels = set()

    source_pattern = os.path.join(img_paths, "*/*/*.png")
    file_list = glob.glob(source_pattern)
    for file_path in file_list:
        text_label = Path(file_path).parent.name
        text_label = str(text_label).lower().strip()

        tmp_labels.add(text_label)
        datasets.append((file_path, text_label))

    for idx, text_label in enumerate(tmp_labels):
        label_dict[text_label] = idx

    return datasets, label_dict


train_imgs, label_dict = read_img_as_dataset(join_path("data", "PP+", "train"))
test_imgs, _ = read_img_as_dataset(join_path("data", "PP+", "test"))

train_imgs, val_imgs = train_test_split(
    train_imgs, test_size=0.1, shuffle=True, random_state=42
)

train_set = PawPrintDataset(
    train_imgs, label_dict, transformations=augmentation_pipeline
)
val_set = PawPrintDataset(val_imgs, label_dict, transformations=normalize_pipeline)
test_set = PawPrintDataset(test_imgs, label_dict, transformations=normalize_pipeline)

train_loader = DataLoader(
    train_set, batch_size=config.batch, pin_memory=True, shuffle=True
)
val_loader = DataLoader(val_set, batch_size=config.batch, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=config.batch, pin_memory=True)

print(f"train: {len(train_set)}")
print(f"val: {len(val_set)}")
print(f"test: {len(test_set)}")
print(f"labels: {label_dict}")
print(f"# of labels: {len(label_dict)}")

train: 522


val: 58
test: 1082
labels: {'dog_kong': 0, 'dog_bbibbi': 1, 'dog_mina': 2, 'dog_ming': 3, 'dog_pori': 4, 'dog_angae': 5, 'dog_pp': 6, 'dog_pony': 7, 'dog_coco': 8, 'dog_mi': 9, 'dog_bori': 10, 'dog_wangbal': 11}
# of labels: 12


## Model

In [10]:
# Load pretrained weights if needed
resnet152_model_path = join_path("resnet_152.pth")
if config.model == "resnet152" and not os.path.exists(resnet152_model_path):
    resnet_152 = models.resnet152(weights=models.ResNet152_Weights.IMAGENET1K_V1)
    torch.save(resnet_152.state_dict(), resnet152_model_path)

resnet50_model_path = join_path("resnet_50.pth")
if config.model == "resnet50" and not os.path.exists(resnet50_model_path):
    resnet_50 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    torch.save(resnet_50.state_dict(), resnet50_model_path)

Downloading: "https://download.pytorch.org/models/resnet152-394f9c45.pth" to C:\Users\admin/.cache\torch\hub\checkpoints\resnet152-394f9c45.pth
100%|██████████| 230M/230M [00:21<00:00, 11.3MB/s] 


In [11]:
class ResNet152(nn.Module):
    def __init__(self, num_labels):
        super(ResNet152, self).__init__()
        self.feature_extractor = models.resnet152()
        self.feature_extractor.load_state_dict(torch.load(resnet152_model_path))
        n_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_labels),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier(features)
        return output

In [None]:
class ResNet50(nn.Module):
    def __init__(self, num_labels):
        super(ResNet50, self).__init__()
        self.feature_extractor = models.resnet50()
        self.feature_extractor.load_state_dict(torch.load(resnet50_model_path))
        n_features = self.feature_extractor.fc.in_features
        self.feature_extractor.fc = nn.Identity()

        self.classifier = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_labels),
        )

    def forward(self, x):
        features = self.feature_extractor(x)
        output = self.classifier(features)
        return output

## Train

In [13]:
@torch.no_grad()
def validate(model, criterion, val_loader, device):
    model.eval()

    val_loss = list()
    model_preds = list()
    true_labels = list()

    for data in val_loader:
        image_input, test_label = data
        image_input = image_input.to(device, non_blocking=True)
        test_label = test_label.to(device, non_blocking=True)

        with autocast(
            device_type=str(device), enabled=config.enable_fp16, dtype=torch.float16
        ):
            output = model(image_input)
            batch_loss = criterion(output, test_label)

        val_loss.append(batch_loss.item())

        model_preds += output.argmax(1).detach().cpu().numpy().tolist()
        true_labels += test_label.detach().cpu().numpy().tolist()

    return val_loss


def train(
    model,
    optimizer,
    criterion,
    early_stopper,
    scheduler,
    train_loader,
    val_loader,
    device,
):
    clear_cache()

    for epoch in range(config.epochs):
        model.train()
        train_loss = list()
        for data in train_loader:

            image_input, train_label = data
            image_input = image_input.to(device, non_blocking=True)
            train_label = train_label.to(device, non_blocking=True)

            with autocast(
                device_type=str(device), enabled=config.enable_fp16, dtype=torch.float16
            ):
                output = model(image_input)
                batch_loss = criterion(output, train_label.long())

            train_loss.append(batch_loss.item())

            batch_loss.backward()
            optimizer.step()
            model.zero_grad()

        val_loss = validate(model, criterion, val_loader, device)
        train_loss = np.mean(train_loss)
        val_loss = np.mean(val_loss)
        wandb.log({"train_loss": train_loss, "val_loss": val_loss})

        tqdm.write(
            f"Epoch {epoch}, Train-Loss: {train_loss:.5f},  Val-Loss: {val_loss:.5f}"
        )

        if early_stopper.should_stop(val_loss, model, epoch):
            break

        scheduler.step()

    tqdm.write(f"\n\n -- EarlyStopping: [Epoch: {early_stopper.check_point}]")

In [14]:
num_labels = len(label_dict)
os.makedirs(join_path("weights"), exist_ok=True)

model = ResNet50(num_labels) if config.model == "resnet50" else ResNet152(num_labels)
optimizer = optim.AdamW(
    [
        {"params": model.feature_extractor.parameters(), "lr": config.backbone_lr},
        {"params": model.classifier.parameters(), "lr": config.lr},
    ],
    weight_decay=config.weight_decay,
)
scheduler = CosineAnnealingLR(optimizer, T_max=config.epochs, eta_min=1e-6)
criterion = nn.CrossEntropyLoss()
early_stopper = EarlyStopping(
    patience=config.patience,
    path_to_save=join_path("weights", f"{config.name}_{config.id}.pth"),
)

model.to(device)
criterion.to(device)

print(f"model: {model.__class__.__name__}")
print(f"optimizer: {optimizer.__class__.__name__}")

model: ResNet152
optimizer: AdamW


In [15]:
train(
    model,
    optimizer,
    criterion,
    early_stopper,
    scheduler,
    train_loader,
    val_loader,
    device,
)

Epoch 0, Train-Loss: 1.73740,  Val-Loss: 27.69752
Epoch 1, Train-Loss: 1.21956,  Val-Loss: 18.91076
Epoch 2, Train-Loss: 0.89048,  Val-Loss: 2.96842
Epoch 3, Train-Loss: 1.01027,  Val-Loss: 2.75351
Epoch 4, Train-Loss: 0.92323,  Val-Loss: 2.67590
Epoch 5, Train-Loss: 0.73747,  Val-Loss: 1.94974
Epoch 6, Train-Loss: 0.67854,  Val-Loss: 0.86418
Epoch 7, Train-Loss: 0.65213,  Val-Loss: 9.35126
Epoch 8, Train-Loss: 0.76542,  Val-Loss: 1.79660
Epoch 9, Train-Loss: 1.00416,  Val-Loss: 4.60462
Epoch 10, Train-Loss: 0.68537,  Val-Loss: 1.58082
Epoch 11, Train-Loss: 0.57287,  Val-Loss: 0.84101
Epoch 12, Train-Loss: 0.56027,  Val-Loss: 0.80167
Epoch 13, Train-Loss: 0.56570,  Val-Loss: 1.53265
Epoch 14, Train-Loss: 0.64697,  Val-Loss: 2.33636
Epoch 15, Train-Loss: 0.54306,  Val-Loss: 2.76127
Epoch 16, Train-Loss: 0.49223,  Val-Loss: 1.48786
Epoch 17, Train-Loss: 0.44982,  Val-Loss: 0.95415
Epoch 18, Train-Loss: 0.35043,  Val-Loss: 1.68461
Epoch 19, Train-Loss: 0.45740,  Val-Loss: 1.00265
Epoch 20

In [16]:
@torch.no_grad()
def evaluate(model, test_loader, device):
    clear_cache()
    model.eval()
    model_preds = list()
    true_labels = list()

    for data in tqdm(test_loader):
        image_input, test_label = data
        image_input = image_input.to(device, non_blocking=True)
        test_label = test_label.to(device, non_blocking=True)

        output = model(image_input)

        model_preds += output.argmax(1).detach().cpu().numpy().tolist()
        true_labels += test_label.detach().cpu().numpy().tolist()

    return {
        "accuracy": accuracy_score(true_labels, model_preds),
        "recall": recall_score(true_labels, model_preds, average="macro"),
        "precision": precision_score(true_labels, model_preds, average="macro"),
        "f1": f1_score(true_labels, model_preds, average="macro"),
    }

In [17]:
model = ResNet50(num_labels) if config.model == "resnet50" else ResNet152(num_labels)
model.load_state_dict(early_stopper.load())
model = model.to(device)

scores = evaluate(model, test_loader, device)

print()
for k, v in scores.items():
    wandb.summary[k] = v
    print(f"{k}: {v:.3f}")

100%|██████████| 17/17 [00:20<00:00,  1.19s/it]



accuracy: 0.328
recall: 0.093
precision: 0.073
f1: 0.075


In [None]:
wandb.finish()