### Bird Species Classification with PyTorch Metric Learning

#### 1. Download the dataset
https://www.kaggle.com/datasets/gpiosenka/100-bird-species

In [None]:
# Click on the link above, login to kaggle and hit download button then copy the download url below
!wget -O archive.zip "<Dataset-URL>"
!unzip archive.zip -d dataset && rm -rf archive.zip

#### 2. Training

In [1]:
import os
import json
import torch
import PIL
from model.AlexNet import AlexNet
from utils.dataset import build_dataloader
from utils.common import get_all_embeddings, get_accuracy, log_to_file
from pytorch_metric_learning import distances, losses, miners, reducers, testers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from torchvision import datasets, transforms
from multiprocessing import cpu_count

# Let cuDNN benchmark and cache the fastest convolution algorithms;
# worthwhile because every image is resized to one fixed size before training.
cudnn_backend = torch.backends.cudnn
cudnn_backend.benchmark = True

In [10]:
# Constants
EPOCHS = 300
BATCH_SIZE = 1024

IMAGE_SIZE = 128
EMBEDDING_SIZE = 64

# Per-channel normalization statistics (the standard ImageNet mean/std).
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]

# Seed torch's global RNG so weight initialisation is repeatable (and, assuming
# build_dataloader shuffles via a standard torch DataLoader, so is shuffling).
# NOTE: cudnn.benchmark is enabled, so GPU runs may still not be bit-exact.
SEED = 42
torch.manual_seed(SEED)

# Use the first GPU when present; fall back to CPU instead of crashing on GPU-less machines.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
TRAIN_DATASET = "./dataset/train"
VAL_DATASET = "./dataset/valid"
SAVE_PATH = "./weights"
NUM_WORKERS = cpu_count()

# Image transformations: convert to tensor, resize to a fixed square, normalize.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.Normalize(MEAN, STD),
])

# Train Dataloader (shuffled every epoch)
train_dataset, train_loader = build_dataloader(
    batch_size=BATCH_SIZE,
    root_dir=TRAIN_DATASET,
    transform=transform,
    shuffle=True,
    num_workers=NUM_WORKERS
)

# Val Dataloader (fixed order)
val_dataset, val_loader = build_dataloader(
    batch_size=BATCH_SIZE,
    root_dir=VAL_DATASET,
    transform=transform,
    shuffle=False,
    num_workers=NUM_WORKERS
)

In [3]:
# Hyperparameters for the optimizer and the triplet objective.
LEARNING_RATE = 1e-3
# The miner and the loss must use the same margin, so it is defined once.
TRIPLET_MARGIN = 0.2

# Embedding network (project AlexNet wrapper, pretrained=True); presumably emits
# EMBEDDING_SIZE-dimensional embeddings per image — confirm in model/AlexNet.py.
model = AlexNet(
    input_size=IMAGE_SIZE,
    embedding_size=EMBEDDING_SIZE,
    pretrained=True,
).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Triplet-loss pieces from pytorch-metric-learning:
#   - cosine similarity as the distance measure,
#   - a reducer that averages only the per-triplet losses above 0,
#   - a miner that selects "semihard" triplets from each batch.
distance = distances.CosineSimilarity()
reducer = reducers.ThresholdReducer(low=0)
mining_func = miners.TripletMarginMiner(
    margin=TRIPLET_MARGIN,
    distance=distance,
    type_of_triplets="semihard",
)
loss_func = losses.TripletMarginLoss(
    margin=TRIPLET_MARGIN,
    distance=distance,
    reducer=reducer,
)

In [4]:
# Optionally continue training from the best checkpoint of a previous run.
# On a fresh run the checkpoint does not exist yet (SAVE_PATH is created later),
# so skip loading instead of raising FileNotFoundError.
# Set RESUME = False to always start from the model's initial weights.
RESUME = True
checkpoint_path = f"{SAVE_PATH}/model_best.pth"

if RESUME and os.path.exists(checkpoint_path):
    print(model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE)))
else:
    print(f"No checkpoint loaded from {checkpoint_path}; starting from the initial weights.")

<All keys matched successfully>

In [11]:
# Output directory for model checkpoints.
os.makedirs(SAVE_PATH, exist_ok=True)

# Run bookkeeping: sampled training stats, validation results,
# and the best validation accuracy seen so far in this run.
history = {"train": [], "val": [], "best_accuracy": 0.0}

# Start this run with a clean "training.log" (presumably the file log_to_file
# appends to — confirm in utils/common.py); a missing file is fine.
try:
    os.remove("training.log")
except FileNotFoundError:
    pass

In [None]:
# Cadence of validation (in epochs) and of training-loss logging (in iterations).
EVAL_EVERY = 2
LOG_EVERY = 20

for epoch in range(1, EPOCHS + 1):
    # ---- Training ----
    # Move the model back to DEVICE every epoch.
    # NOTE(review): presumably guards against get_accuracy moving it — confirm.
    model.to(DEVICE)
    model.train()

    for batch_idx, (data, labels) in enumerate(train_loader):
        data, labels = data.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        embeddings = model(data)
        # Mine triplets from this batch, then compute the triplet loss on them.
        indices_tuple = mining_func(embeddings, labels)
        loss = loss_func(embeddings, labels, indices_tuple)
        loss.backward()
        optimizer.step()

        if batch_idx % LOG_EVERY == 0:
            loss_value = loss.item()  # single host sync per log step
            num_triplets = mining_func.num_triplets
            history["train"].append({"epoch": epoch, "loss": loss_value, "triplets": num_triplets})
            msg = f"Epoch [{epoch}/{EPOCHS}] Iter [{batch_idx}/{len(train_loader)}], Loss: {loss_value}, Triplets: {num_triplets}"
            log_to_file(msg)
            print(msg)

    # ---- Validation & checkpointing ----
    # Evaluate every EVAL_EVERY epochs, and always after the last epoch so the final
    # weights and history are saved even if EPOCHS is not a multiple of EVAL_EVERY.
    if epoch % EVAL_EVERY != 0 and epoch != EPOCHS:
        continue

    model.eval()
    with torch.no_grad():
        # All embeddings are kept in memory during evaluation; set
        # DEVICE = torch.device('cpu') if GPU memory overflows.
        accuracy = get_accuracy(val_dataset, train_dataset, model, DEVICE)

    history["val"].append({"epoch": epoch, "accuracy": accuracy})
    msg = f"Val accuracy: {accuracy}"
    log_to_file(msg)
    print(msg)

    # Always keep the latest weights; additionally keep the best-so-far weights.
    torch.save(model.state_dict(), f"{SAVE_PATH}/model_latest.pth")
    if accuracy >= history["best_accuracy"]:
        history["best_accuracy"] = accuracy
        torch.save(model.state_dict(), f"{SAVE_PATH}/model_best.pth")

    with open("history.json", "w") as f:
        json.dump(history, f)