[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/food-analytic/train-classification/blob/main/notebooks/train_convnext_colab.ipynb)

In [None]:
# Mount Google Drive inside the Colab VM at /content/drive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%pip install -qqq timm
%pip install -qqq wandb

In [None]:
!mkdir '/content/data/'
!unzip -qq -O utf-8 '/content/drive/Shareddrives/Food Analytic/Data/chula_food_353.zip' -d '/content/data/'

In [None]:
import os
import random
from typing import Dict, List, Tuple
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, WeightedRandomSampler
from torchvision.datasets import ImageFolder
from torchvision import transforms
import timm
from timm.optim import create_optimizer_v2
import wandb

In [None]:
# Log this runtime in to Weights & Biases.
wandb.login()

In [None]:
# Single source of truth for every tunable used by the notebook.
config = dict(
    # Data
    batch_size=480,
    image_size=(224, 224),
    seed=42,
    train_path="/content/data/chula_food_353/train",
    val_path="/content/data/chula_food_353/val",
    test_path="/content/data/chula_food_353/test",
    min_sample_per_class=60,
    val_length=15,
    test_length=15,
    num_workers=2,
    # Model
    base_model="convnext_base_in22k",
    dropout=0.2,
    # Training
    num_epochs=10,
    lr=1e-3,
    optimizer="madgrad",
    weight_save_path="/content/convnext.pt",
    # Logging
    project="classification-convnext",
)

# Seed the Python, NumPy and PyTorch RNGs from the same configured value.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])
random.seed(config["seed"])

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def get_weighted_random_sampler(dataset):
    """Build a seeded sampler that weights each sample by 1 / (size of its class).

    Args:
        dataset: Object exposing ``targets``, the integer class label of every
            sample in dataset order.

    Returns:
        A ``WeightedRandomSampler`` that draws ``len(dataset.targets)`` indices
        per pass, with replacement, from a generator seeded with
        ``config["seed"]``.
    """
    class_counts = np.bincount(dataset.targets)
    inverse_frequency = 1 / class_counts
    # Look up each sample's weight through its own label.
    sample_weights = inverse_frequency[dataset.targets]

    seeded_generator = torch.Generator()
    seeded_generator.manual_seed(config["seed"])

    return WeightedRandomSampler(
        weights=sample_weights,
        num_samples=len(sample_weights),
        replacement=True,
        generator=seeded_generator,
    )

In [None]:
# Training pipeline: resize, RandAugment, then tensor conversion and
# per-channel normalisation.
train_transform = transforms.Compose([
    transforms.Resize(size=config["image_size"]),
    timm.data.auto_augment.rand_augment_transform(
        config_str="rand-m9-mstd0.5",
        hparams={},
    ),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Evaluation pipeline: identical to training minus the random augmentation.
test_transform = transforms.Compose([
    transforms.Resize(size=config["image_size"]),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# One ImageFolder per split; only the training split is augmented.
train_dataset, val_dataset, test_dataset = (
    ImageFolder(root=config[path_key], transform=split_transform)
    for path_key, split_transform in (
        ("train_path", train_transform),
        ("val_path", test_transform),
        ("test_path", test_transform),
    )
)

# Use WeightedRandomSampler to tackle the class imbalance problem
sampler = get_weighted_random_sampler(train_dataset)

# The training loader draws indices from the weighted sampler.
train_loader = DataLoader(
    train_dataset,
    sampler=sampler,
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
    pin_memory=True,
    generator=torch.Generator().manual_seed(config["seed"]),
)

# Validation and test loaders share the same settings and keep dataset order.
val_loader, test_loader = (
    DataLoader(
        eval_dataset,
        batch_size=config["batch_size"],
        shuffle=False,
        num_workers=config["num_workers"],
        pin_memory=True,
    )
    for eval_dataset in (val_dataset, test_dataset)
)

In [None]:
class ChulaFoodNet(nn.Module):
    """Pretrained timm backbone whose classifier layer is replaced for this task.

    The backbone named by ``config['base_model']`` is created with pretrained
    weights and ``config['dropout']`` as its drop rate; its ``head.fc`` layer
    is then swapped for a fresh ``nn.Linear`` producing ``num_classes`` logits.
    """

    def __init__(self, num_classes):
        """Create the backbone and attach a new ``num_classes``-way classifier.

        Args:
            num_classes: Number of output logits of the replacement classifier.
        """
        super().__init__()
        self.pretrained_model = timm.create_model(
            config['base_model'],
            pretrained=True,
            drop_rate=config['dropout'],
        )
        # Take the input width from the layer being replaced rather than
        # hard-coding 1024, so other backbone sizes work unchanged.
        in_features = self.pretrained_model.head.fc.in_features
        self.pretrained_model.head.fc = nn.Linear(in_features, num_classes)

    def forward(self, input):
        """Return the backbone's output (class logits) for ``input``."""
        return self.pretrained_model(input)

In [None]:
# One output logit per class folder found in the training split.
model = ChulaFoodNet(len(train_dataset.classes))

# Freeze the whole backbone, then re-enable gradients for the head only, so
# just the head receives gradient updates.
model.pretrained_model.requires_grad_(False)
model.pretrained_model.head.requires_grad_(True)

model = model.to(device)

criterion = nn.CrossEntropyLoss()

optimizer = create_optimizer_v2(
    model.parameters(),
    config["optimizer"],
    lr=config["lr"],
)

# Cosine schedule with warm restarts: first cycle lasts 10 epochs, each
# following cycle is twice as long as the previous one.
scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
    optimizer,
    T_0=10,
    T_mult=2,
)

In [None]:
def train(model, loader, device, epoch, criterion, optimizer, scheduler):
    """Train ``model`` for one pass over ``loader`` and log per-batch metrics.

    Args:
        model: Network to optimise; switched to training mode here.
        loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        device: Device each batch is moved to before the forward pass.
        epoch: 1-based epoch number, as passed by the training loop. Used in
            the progress-bar text and to position the LR schedule.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimiser stepped once per batch.
        scheduler: LR scheduler stepped once per batch with a fractional epoch.
    """
    model.train()
    num_correct = 0
    num_data = 0
    num_batches = len(loader)

    pbar = tqdm(enumerate(loader), total=num_batches)
    for batch_idx, data in pbar:
        inputs, labels = data[0].to(device), data[1].to(device)

        optimizer.zero_grad()
        outputs = model(inputs)

        _, predicted = torch.max(outputs, 1)
        num_correct += (predicted == labels).sum().item()
        num_data += labels.size(0)

        # Read the LR before stepping so the logged value is the one that was
        # actually applied to this batch.
        lr = optimizer.param_groups[0]['lr']

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        # The scheduler counts epochs from 0, so convert the 1-based epoch.
        # Passing it unshifted moves the whole cosine cycle one epoch ahead and
        # fires the warm restart one epoch too early.
        scheduler.step(epoch - 1 + batch_idx / num_batches)

        # Convert once to a Python float: logging the tensor itself would keep
        # a reference to the autograd graph.
        loss_value = loss.item()
        accumulate_accuracy = num_correct / num_data * 100
        pbar.set_description(f'[Training Epoch {epoch}] LR: {lr:.6f}, Loss: {loss_value:.4f}, Accuracy: {accumulate_accuracy:.4f}')
        wandb.log({"train_acc": accumulate_accuracy, "train_loss": loss_value, "lr": lr})

def validate(model, loader, device, epoch, criterion):
    """Evaluate ``model`` on ``loader`` without gradients, logging every batch.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable of ``(inputs, labels)`` batches that supports ``len()``.
        device: Device each batch is moved to before the forward pass.
        epoch: Epoch number shown in the progress-bar text.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(num_correct, num_data)``: correctly classified samples and
        total samples seen.
    """
    model.eval()
    correct_so_far = 0
    seen_so_far = 0

    progress = tqdm(enumerate(loader), total=len(loader))
    with torch.no_grad():
        for _, batch in progress:
            images = batch[0].to(device)
            targets = batch[1].to(device)

            logits = model(images)
            batch_loss = criterion(logits, targets)

            # Index 1 of torch.max(dim=...) holds the arg-max class per row.
            predictions = torch.max(logits, 1)[1]
            correct_so_far += (predictions == targets).sum().item()
            seen_so_far += targets.size(0)

            # Running accuracy over all batches seen so far, in percent.
            running_accuracy = correct_so_far / seen_so_far * 100
            progress.set_description(f'[Testing Epoch {epoch}] Loss: {batch_loss:.4f}, Accuracy: {running_accuracy:.4f}')
            wandb.log({"val_acc": running_accuracy, "val_loss": batch_loss})

    return correct_so_far, seen_so_far

def test_per_class(model, loader, device):
    """Count correct predictions and samples per class over ``loader``.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: DataLoader whose dataset exposes ``classes`` either directly or
            through one or more wrappers that keep the wrapped dataset in a
            ``dataset`` attribute.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(num_correct_per_class, num_data_per_class)`` of float NumPy
        arrays indexed by class id.
    """
    model.eval()

    # The loader may wrap the class-bearing dataset directly or via wrappers;
    # unwrap until ``classes`` is found instead of assuming exactly one
    # wrapper level (which raised AttributeError for a plain dataset).
    dataset = loader.dataset
    while not hasattr(dataset, "classes") and hasattr(dataset, "dataset"):
        dataset = dataset.dataset
    num_classes = len(dataset.classes)

    num_correct_per_class = np.zeros(num_classes)
    num_data_per_class = np.zeros(num_classes)

    pbar = tqdm(enumerate(loader), total=len(loader))
    with torch.no_grad():
        for batch_idx, data in pbar:
            inputs, labels = data[0].to(device), data[1].to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)

            labels = labels.cpu().detach().numpy()
            predicted = predicted.cpu().detach().numpy()
            # Histogram the labels of correctly classified samples; minlength
            # keeps the array shape fixed when a class is absent from a batch.
            num_correct_per_class += np.bincount(labels[predicted == labels], minlength=num_classes)
            num_data_per_class += np.bincount(labels, minlength=num_classes)

    return num_correct_per_class, num_data_per_class

In [None]:
# Open a W&B run that records the full configuration.
wandb.init(project=config["project"], config=config)

best_accuracy = 0
for epoch in range(1, config["num_epochs"] + 1):
    train(model, train_loader, device, epoch, criterion, optimizer, scheduler)
    num_correct, num_data = validate(model, val_loader, device, epoch, criterion)
    val_accuracy = num_correct / num_data

    # Only a strict improvement in validation accuracy triggers a checkpoint.
    if val_accuracy <= best_accuracy:
        continue
    torch.save(model.state_dict(), config["weight_save_path"])
    best_accuracy = val_accuracy

wandb.finish()

In [None]:
# Restore the checkpoint written by the training loop before testing.
model.load_state_dict(torch.load(config["weight_save_path"]))

test_correct_per_class, test_data_per_class = test_per_class(model, test_loader, device)

# Per-class accuracy, plus overall accuracy pooled across all classes.
test_accuracy_per_class = np.divide(test_correct_per_class, test_data_per_class)
test_accuracy = float(test_correct_per_class.sum()) / float(test_data_per_class.sum())
print(f"Test Accuracy: {test_accuracy:.4f}")

In [None]:
# Per-class evaluation table: class name, test accuracy and number of
# training images available for that class.
df_eval = pd.DataFrame({
    'class_name': test_dataset.classes,
    'accuracy': test_accuracy_per_class,
    # Count directly from the training dataset's labels; the previous
    # `train_split.indices` referred to a name that is never defined.
    'num_training_data': np.bincount(train_dataset.targets, minlength=len(train_dataset.classes)),
})

df_eval.to_csv('eval.csv', index=False)