[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/food-analytic/train-classification/blob/main/notebooks/train_convnext_colab.ipynb)

In [None]:
# Install the third-party packages used below (timm models/optimizers, torchinfo
# model summaries, W&B logging); -qqq keeps pip's output quiet.
%pip install -qqq timm
%pip install -qqq torchinfo
%pip install -qqq wandb

In [None]:
!mkdir '/content/data/'
!unzip -qq -O utf-8 '/content/drive/Shareddrives/Food Analytic/Data/chula_food_330.zip' -d '/content/data/'

In [None]:
import os
import random
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
from torchinfo import summary
import timm
from timm.optim import create_optimizer_v2
import wandb

In [None]:
# Authenticate with Weights & Biases before a run is started with `wandb.init`.
wandb.login()

In [None]:
# Run configuration: data locations, model choice and optimisation settings.
# The same dict is passed to `wandb.init(config=...)` for experiment tracking.
config = dict(
    # Data
    batch_size=160,
    image_size=(224, 224),
    seed=42,
    train_path="/content/data/chula_food_330/train",
    val_path="/content/data/chula_food_330/val",
    test_path="/content/data/chula_food_330/test",
    num_workers=2,

    # Model
    base_model='convnext_base_in22k',
    dropout=0.2,

    # Training
    num_epochs=10,
    lr=5e-5,
    optimizer="madgrad",
    weight_save_path='/content/drive/Shareddrives/Food Analytic/models/Classification/convnext_v2_zoku.pt',
    weight_load_path='/content/drive/Shareddrives/Food Analytic/models/Classification/convnext_v2.pt',

    # Logging
    project="classification-convnext",
)

# Seed the Python, NumPy and PyTorch global RNGs, in that order, for reproducibility.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(config["seed"])

# Prefer the GPU when one is visible to PyTorch.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def get_weighted_random_sampler(dataset, seed=None):
    """Build a class-balanced ``WeightedRandomSampler`` for ``dataset``.

    Each sample is weighted by the inverse frequency of its label, so every class
    is drawn with roughly equal probability. The sampler draws ``len(dataset)``
    indices with replacement from a dedicated, seeded generator.

    Args:
        dataset: object exposing ``targets``, a sequence of integer class labels
            (e.g. ``torchvision.datasets.ImageFolder``).
        seed: seed for the sampler's private ``torch.Generator``. Defaults to
            ``config["seed"]`` so existing callers keep their behavior.

    Returns:
        A ``WeightedRandomSampler`` over the indices of ``dataset``.
    """
    if seed is None:
        seed = config["seed"]

    targets = np.asarray(dataset.targets)
    class_counts = np.bincount(targets)
    # A label value that never occurs gets count 0 (and an inf weight), but that
    # entry is never selected by the indexing below — silence the divide warning.
    with np.errstate(divide="ignore"):
        label_weights = 1.0 / class_counts
    weights = label_weights[targets]

    return WeightedRandomSampler(
        weights,
        len(weights),
        replacement=True,
        generator=torch.Generator().manual_seed(seed),
    )

In [None]:
# ImageNet channel statistics used to normalise both pipelines.
_imagenet_mean = [0.485, 0.456, 0.406]
_imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: resize, timm RandAugment ("rand-m9-mstd0.5"), tensor, normalise.
train_transform = transforms.Compose([
    transforms.Resize(size=config["image_size"]),
    timm.data.auto_augment.rand_augment_transform(config_str="rand-m9-mstd0.5", hparams={}),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

# Evaluation pipeline: identical preprocessing without augmentation.
test_transform = transforms.Compose([
    transforms.Resize(size=config["image_size"]),
    transforms.ToTensor(),
    transforms.Normalize(_imagenet_mean, _imagenet_std),
])

train_dataset = ImageFolder(root=config["train_path"], transform=train_transform)
val_dataset = ImageFolder(root=config["val_path"], transform=test_transform)
test_dataset = ImageFolder(root=config["test_path"], transform=test_transform)

# Class imbalance is handled by sampling, not loss weighting: the train loader
# draws inverse-frequency-weighted indices instead of shuffling.
sampler = get_weighted_random_sampler(train_dataset)

# Settings shared by all three loaders.
_loader_kwargs = dict(
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    pin_memory=True,
)

train_loader = DataLoader(
    train_dataset,
    sampler=sampler,
    generator=torch.Generator().manual_seed(config["seed"]),
    **_loader_kwargs,
)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)

In [None]:
class ChulaFoodNet(nn.Module):
    """Pretrained timm backbone with its classifier replaced for ``num_classes`` outputs.

    The backbone is ``timm.create_model(config['base_model'], pretrained=True,
    drop_rate=config['dropout'])``; only ``head.fc`` is swapped for a freshly
    initialised ``nn.Linear``.

    Args:
        num_classes: number of logits produced by the new classification layer.
    """

    def __init__(self, num_classes):
        super().__init__()
        self.pretrained_model = timm.create_model(
            config['base_model'], pretrained=True, drop_rate=config['dropout']
        )
        # Size the new layer from the backbone's existing classifier rather than
        # hard-coding 1024 (the ConvNeXt-Base width), so other backbones chosen via
        # config['base_model'] that expose `head.fc` also work.
        in_features = self.pretrained_model.head.fc.in_features
        self.pretrained_model.head.fc = nn.Linear(in_features, num_classes)

    def forward(self, input):
        """Return the class logits for a batch of images."""
        return self.pretrained_model(input)

In [None]:
model = ChulaFoodNet(len(train_dataset.classes))

# Partial fine-tuning: freeze the whole backbone, then unfreeze only the last
# blocks of stage 2 (index 14 onward), all of stage 3, and the classification head.
model.pretrained_model.requires_grad_(False)
for trainable_part in (
    model.pretrained_model.stages[2].blocks[14:],
    model.pretrained_model.stages[3],
    model.pretrained_model.head,
):
    trainable_part.requires_grad_(True)

if config['weight_load_path'] is not None:
    # Deserialize onto the CPU: the checkpoint may have been saved from a GPU
    # runtime, and torch.load would otherwise fail on a CPU-only runtime.
    # load_state_dict copies the values into the model's own tensors regardless.
    model.load_state_dict(torch.load(config['weight_load_path'], map_location='cpu'))

model.to(device)

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

optimizer = create_optimizer_v2(
    model.parameters(),
    config['optimizer'],
    lr=config['lr'],
)

# Cosine annealing with warm restarts: first cycle lasts T_0 = 10 epochs, each
# subsequent cycle is T_mult = 2 times longer. Stepped per batch inside `train`.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer=optimizer,
    T_0=10,
    T_mult=2,
)

summary(model, input_size=(config['batch_size'], 3, *config['image_size']))

In [None]:
def train(model, loader, device, epoch, criterion, optimizer, scheduler):
    """Run one training epoch, stepping the optimizer and scheduler every batch.

    Args:
        model: network to optimise; switched to train mode.
        loader: iterable of batches where ``data[0]`` are inputs and ``data[1]`` labels.
        device: device the inputs and labels are moved to.
        epoch: 1-based epoch number (the training loop counts from 1). Used for
            the progress-bar label and to position the fractional scheduler step.
        criterion: loss function called as ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per batch.
        scheduler: scheduler stepped once per batch with a fractional epoch.

    Returns:
        ``(epoch_loss, accuracy_top_1, accuracy_top_3, lr)``. The loss is the
        sample-weighted mean as a Python float; accuracies are percentages; ``lr``
        is the learning rate used for the last batch. An empty loader yields zeros
        and the current learning rate.
    """
    model.train()
    num_correct_top_1 = 0
    num_correct_top_3 = 0
    num_data = 0
    running_loss = 0.0
    epoch_loss = accuracy_top_1 = accuracy_top_3 = 0.0
    lr = optimizer.param_groups[0]['lr']
    num_batches = len(loader)

    pbar = tqdm(enumerate(loader), total=num_batches)
    for batch_idx, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)

        lr = optimizer.param_groups[0]['lr']
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # `epoch` is 1-based but the scheduler expects a 0-based fractional epoch
        # (PyTorch's documented pattern is `scheduler.step(epoch + i / iters)` with
        # epoch starting at 0). Without the `- 1` the cosine schedule is shifted by
        # a whole epoch and restarts at peak LR during the final epoch.
        scheduler.step(epoch - 1 + batch_idx / num_batches)

        # Top-k needs k <= number of classes.
        top_k = min(3, outputs.size(1))
        predicted_top_1 = outputs.argmax(dim=1)
        predicted_top_3 = outputs.topk(top_k, dim=1).indices
        num_correct_top_1 += (predicted_top_1 == labels).sum().item()
        # Broadcast labels against the (batch, k) predictions; each row matches at most once.
        num_correct_top_3 += (predicted_top_3 == labels.unsqueeze(1)).sum().item()

        batch_size = labels.size(0)
        num_data += batch_size
        # .item() detaches the loss: accumulating the tensor itself would keep every
        # batch's autograd graph alive and return a device tensor instead of a float.
        running_loss += loss.item() * batch_size
        epoch_loss = running_loss / num_data
        accuracy_top_1 = num_correct_top_1 / num_data * 100
        accuracy_top_3 = num_correct_top_3 / num_data * 100
        pbar.set_description(
            f'[Training Epoch {epoch}] LR: {lr:.6f}, Loss: {epoch_loss:.4f}, Top 1 Accuracy: {accuracy_top_1:.4f}, Top 3 Accuracy {accuracy_top_3:.4f}'
        )

    return epoch_loss, accuracy_top_1, accuracy_top_3, lr

def validate(model, loader, device, epoch, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Args:
        model: network to evaluate; switched to eval mode.
        loader: iterable of batches where ``data[0]`` are inputs and ``data[1]`` labels.
        device: device the inputs and labels are moved to.
        epoch: epoch number shown in the progress-bar label.
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        ``(epoch_loss, accuracy_top_1, accuracy_top_3)``: the sample-weighted mean
        loss as a Python float and accuracies as percentages. An empty loader
        yields zeros.
    """
    model.eval()
    num_correct_top_1 = 0
    num_correct_top_3 = 0
    num_data = 0
    running_loss = 0.0
    epoch_loss = accuracy_top_1 = accuracy_top_3 = 0.0

    pbar = tqdm(enumerate(loader), total=len(loader))
    with torch.no_grad():
        for batch_idx, data in pbar:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Top-k needs k <= number of classes.
            top_k = min(3, outputs.size(1))
            predicted_top_1 = outputs.argmax(dim=1)
            predicted_top_3 = outputs.topk(top_k, dim=1).indices
            num_correct_top_1 += (predicted_top_1 == labels).sum().item()
            # Broadcast labels against the (batch, k) predictions; each row matches at most once.
            num_correct_top_3 += (predicted_top_3 == labels.unsqueeze(1)).sum().item()

            batch_size = labels.size(0)
            num_data += batch_size
            # .item() keeps the running loss a Python float rather than a device tensor.
            running_loss += loss.item() * batch_size
            epoch_loss = running_loss / num_data
            accuracy_top_1 = num_correct_top_1 / num_data * 100
            accuracy_top_3 = num_correct_top_3 / num_data * 100
            pbar.set_description(
                f'[Testing Epoch {epoch}] Loss: {epoch_loss:.4f}, Top 1 Accuracy: {accuracy_top_1:.4f}, Top 3 Accuracy {accuracy_top_3:.4f}'
            )

    return epoch_loss, accuracy_top_1, accuracy_top_3

def predict(model, loader, device):
    """Run ``model`` over ``loader`` and collect its outputs with the matching labels.

    Args:
        model: network to run; switched to eval mode.
        loader: iterable of batches where ``data[0]`` are inputs and ``data[1]`` labels.
        device: device the inputs are moved to for the forward pass.

    Returns:
        ``(predictions, targets)``: model outputs concatenated along dim 0 and the
        corresponding labels, both as CPU tensors in loader order.
    """
    model.eval()
    predictions = []
    targets = []
    with torch.no_grad():
        for data in tqdm(loader, total=len(loader)):
            outputs = model(data[0].to(device))
            # Under no_grad the outputs carry no graph, so no detach() is needed.
            predictions.append(outputs.cpu())
            # Labels are only needed on the CPU; skip the device round-trip.
            targets.append(data[1].cpu())
    return torch.cat(predictions), torch.cat(targets)

In [None]:
wandb.init(project=config['project'], config=config)

best_accuracy_top_1 = 0
best_accuracy_top_3 = 0
try:
    for epoch in range(1, config['num_epochs'] + 1):
        train_loss, train_accuracy_top_1, train_accuracy_top_3, lr = train(model, train_loader, device, epoch, criterion, optimizer, scheduler)
        wandb.log({'train_loss': train_loss, 'train_accuracy_top_1': train_accuracy_top_1, 'train_accuracy_top_3': train_accuracy_top_3, 'lr': lr}, commit=False)
        val_loss, val_accuracy_top_1, val_accuracy_top_3 = validate(model, val_loader, device, epoch, criterion)
        wandb.log({'val_loss': val_loss, 'val_accuracy_top_1': val_accuracy_top_1, 'val_accuracy_top_3': val_accuracy_top_3}, commit=True)
        # Checkpoint whenever EITHER validation metric reaches a new best. Note this
        # can overwrite a checkpoint with better top-1 by one with better top-3 only.
        if val_accuracy_top_1 > best_accuracy_top_1 or val_accuracy_top_3 > best_accuracy_top_3:
            torch.save(model.state_dict(), config['weight_save_path'])
            best_accuracy_top_1 = max(best_accuracy_top_1, val_accuracy_top_1)
            best_accuracy_top_3 = max(best_accuracy_top_3, val_accuracy_top_3)
finally:
    # Close the W&B run even if training raises or the cell is interrupted, so
    # the run is not left open and the next `wandb.init` starts cleanly.
    wandb.finish()

In [None]:
# Reload the best checkpoint written by the training loop. map_location lets this
# cell run in a fresh runtime whose device differs from the one that saved it.
model.load_state_dict(torch.load(config['weight_save_path'], map_location=device))
test_predictions, test_targets = predict(model, test_loader, device)

# Top-k needs k <= number of classes.
top_k = min(3, test_predictions.size(1))
predicted_top_1 = test_predictions.argmax(dim=1)
predicted_top_3 = test_predictions.topk(top_k, dim=1).indices

# Per-sample hit masks; labels are broadcast against the (N, k) top-k indices.
is_correct_top_1 = predicted_top_1 == test_targets
is_correct_top_3 = (predicted_top_3 == test_targets.unsqueeze(1)).any(dim=1)

num_classes = len(test_loader.dataset.classes)
correct_top_1_per_class = np.bincount(test_targets[is_correct_top_1].numpy(), minlength=num_classes)
correct_top_3_per_class = np.bincount(test_targets[is_correct_top_3].numpy(), minlength=num_classes)
data_per_class = np.bincount(test_targets.numpy(), minlength=num_classes)

# Classes with no test samples produce NaN (0 / 0) here.
accuracy_top_1_per_class = correct_top_1_per_class / data_per_class
accuracy_top_3_per_class = correct_top_3_per_class / data_per_class

accuracy_top_1 = correct_top_1_per_class.sum() / data_per_class.sum()
accuracy_top_3 = correct_top_3_per_class.sum() / data_per_class.sum()
print(f"Test Top 1 Accuracy: {accuracy_top_1:.4f}")
print(f"Test Top 3 Accuracy: {accuracy_top_3:.4f}")

In [None]:
# Per-class rows pair test-set class names with training-set counts, so both
# ImageFolder datasets must enumerate exactly the same class folders in the same
# order; otherwise rows would be silently misaligned.
if train_dataset.classes != test_dataset.classes:
    raise ValueError(
        "train and test datasets have different class lists; "
        "per-class evaluation rows would be misaligned"
    )

df_eval = pd.DataFrame({
    'class_name': test_dataset.classes,
    'num_training_data': np.bincount(train_dataset.targets, minlength=len(train_dataset.classes)),
    'accuracy_top_1': accuracy_top_1_per_class,
    'accuracy_top_3': accuracy_top_3_per_class,
})

df_eval.to_csv('eval.csv', index=False)