# Set up


In [None]:
%pip install torchtune torchao

In [None]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchtune.training import get_cosine_schedule_with_warmup
from torchvision.datasets import ImageFolder
from tqdm import tqdm

In [None]:
def save_checkpoint(model: nn.Module, optimizer: optim.Optimizer, save_path):
    """Write the model's and optimizer's state dicts to ``save_path``.

    The file is a single ``torch.save``-serialized dict with the keys
    ``"model"`` and ``"optimizer"`` -- the layout ``load_checkpoint`` expects.
    """
    state = dict(model=model.state_dict(), optimizer=optimizer.state_dict())
    torch.save(state, save_path)


def load_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    save_path,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Restore model and optimizer state from a checkpoint file.

    Args:
        model: Module whose parameters are overwritten in place.
        optimizer: Optimizer whose state/param groups are overwritten in place.
        save_path: Path to a file holding a dict with ``"model"`` and
            ``"optimizer"`` state dicts (the layout ``save_checkpoint`` writes).
        device: Tensors in the file are mapped onto this device when loading.
            Note the default is evaluated once, at function definition time.

    Returns:
        The same ``(model, optimizer)`` objects that were passed in.

    Raises:
        KeyError: if the file lacks the ``"model"`` or ``"optimizer"`` entry.
    """
    # weights_only=True restricts unpickling to tensors and plain containers.
    checkpoint = torch.load(save_path, map_location=device, weights_only=True)

    model.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer

In [None]:
def load_dataset(root, batch_size=32, seed=42, num_workers=0):
    """Build train/validation/test DataLoaders from an ``ImageFolder`` tree.

    Expects ``{root}/train``, ``{root}/val`` and ``{root}/test`` directories in
    torchvision ``ImageFolder`` layout.

    Args:
        root: Dataset root directory.
        batch_size: Batch size used by all three loaders.
        seed: Passed to ``torch.manual_seed`` before anything is built. This
            reseeds torch's *global* RNG, so it also affects everything drawn
            from that RNG afterwards (train-loader shuffling, augmentation,
            later weight initialisation). ``None`` leaves the RNG untouched.
        num_workers: Worker processes per DataLoader (0 = load in the main
            process, the DataLoader default).

    Returns:
        ``(train_loader, valid_loader, test_loader)``. Only the train loader
        shuffles and applies random augmentation.
    """
    if seed is not None:
        torch.manual_seed(seed)

    # Random augmentation, applied to the training split only.
    data_augmentation = transforms.Compose(
        [
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
        ]
    )
    # Deterministic preprocessing shared by every split. The Normalize values
    # are hard-coded per-channel mean/std -- presumably measured on this
    # dataset's images (TODO confirm).
    preprocess = transforms.Compose(
        [
            transforms.Resize(size=(224, 224), antialias=True),
            transforms.ToTensor(),
            transforms.Normalize([0.7037, 0.6818, 0.6685], [0.2739, 0.2798, 0.2861]),
        ]
    )

    train_set = ImageFolder(
        f"{root}/train", transform=transforms.Compose([data_augmentation, preprocess])
    )
    valid_set = ImageFolder(f"{root}/val", transform=preprocess)
    test_set = ImageFolder(f"{root}/test", transform=preprocess)

    def _make_loader(dataset, shuffle):
        # One place for the loader settings shared by all splits.
        return DataLoader(
            dataset=dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )

    train_loader = _make_loader(train_set, shuffle=True)
    valid_loader = _make_loader(valid_set, shuffle=False)
    test_loader = _make_loader(test_set, shuffle=False)

    return train_loader, valid_loader, test_loader

In [None]:
def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader``, training if ``optimizer`` is given.

    With an optimizer the model is put in train mode and updated batch by
    batch. Without one it is put in eval mode and run with autograd disabled,
    so no graphs are built.

    Returns:
        ``(mean_batch_loss, accuracy_percent)``, where accuracy is taken over
        ``len(loader.dataset)`` samples.
    """
    training = optimizer is not None
    model.train(training)
    running_loss, correct = 0.0, 0
    with torch.set_grad_enabled(training), tqdm(
        total=len(loader), desc=desc, unit="batch"
    ) as pbar:
        for i, (images, labels) in enumerate(loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            running_loss += loss.item()
            # Predictions are only bookkeeping; keep them off the autograd graph.
            predicted = outputs.detach().argmax(dim=1)
            correct += (predicted == labels).sum().item()

            pbar.set_postfix({"loss": format(running_loss / (i + 1), ".4f")})
            pbar.update()

        epoch_loss = running_loss / len(loader)
        epoch_acc = 100 * correct / len(loader.dataset)
        pbar.set_postfix(
            {
                "loss": format(epoch_loss, ".4f"),
                "acc": format(epoch_acc, ".2f"),
            }
        )
    return epoch_loss, epoch_acc


def _restore_from_checkpoint(model, optimizer, scheduler, load_path, device):
    """Reload training state written under ``load_path`` by an earlier run.

    Reads ``results.csv`` and picks the newest logged epoch for which both
    ``resnet_{epoch}.pth`` and ``scheduler_{epoch}.pth`` exist. Model, optimizer
    and scheduler are restored from those files. The history is truncated to
    that epoch so that re-run epochs are not logged twice.

    Returns:
        ``(model, optimizer, results, start_epoch)``.

    Raises:
        FileNotFoundError: if ``results.csv`` is missing or no logged epoch has
            a matching model/scheduler checkpoint pair.
    """
    history = pd.read_csv(f"{load_path}/results.csv").to_dict("records")

    # results.csv is rewritten every epoch, but checkpoints are written only
    # periodically. The last logged epoch may therefore have nothing to load.
    for record in reversed(history):
        start_epoch = int(record["epoch"])
        model_file = f"{load_path}/resnet_{start_epoch}.pth"
        scheduler_file = f"{load_path}/scheduler_{start_epoch}.pth"
        if os.path.isfile(model_file) and os.path.isfile(scheduler_file):
            break
    else:
        raise FileNotFoundError(
            f"No resnet_/scheduler_ checkpoint pair in {load_path} "
            "matches an epoch listed in results.csv"
        )

    model, optimizer = load_checkpoint(model, optimizer, model_file, device)
    scheduler.load_state_dict(
        torch.load(scheduler_file, map_location=device, weights_only=True)
    )
    results = [r for r in history if int(r["epoch"]) <= start_epoch]
    return model, optimizer, results, start_epoch


def train(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    save_path="./results",
    num_epochs=100,
    lr=0.01,
    momentum=0.9,
    num_warmup_steps=5,
    weight_decay=0.0005,
    resume_training=False,
    load_path="./",
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    checkpoint_every=5,
):
    """Train ``model`` with SGD and a per-epoch cosine-with-warmup LR schedule.

    After every epoch the per-epoch metrics are appended to
    ``{save_path}/results.csv``. Every ``checkpoint_every`` epochs, and always
    after the final epoch, ``resnet_{epoch}.pth`` (model and optimizer) and
    ``scheduler_{epoch}.pth`` are written to ``save_path``.

    Args:
        model: Network to train; it is moved to ``device``.
        train_loader: Loader for the training split.
        valid_loader: Loader evaluated after each training epoch.
        save_path: Output directory; created if missing.
        num_epochs: Total number of epochs, including any already completed
            when resuming.
        lr, momentum, weight_decay: SGD hyper-parameters.
        num_warmup_steps: Warmup length, counted in epochs, because the
            scheduler is stepped once per epoch.
        resume_training: If true, restore state from ``load_path`` (see
            ``_restore_from_checkpoint``) and continue from that epoch.
        load_path: Directory holding a previous run's results and checkpoints.
        device: Device to train on.
        checkpoint_every: Checkpoint interval in epochs; must be positive.
    """
    os.makedirs(save_path, exist_ok=True)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_epochs
    )

    results = []
    start_epoch = 0
    if resume_training:
        model, optimizer, results, start_epoch = _restore_from_checkpoint(
            model, optimizer, scheduler, load_path, device
        )

        # Best-effort RNG fast-forward. Start one train-loader iteration per
        # skipped epoch and abandon it after the first batch. This does not
        # replay the RNG use of full epochs, so it is only an approximation.
        for _ in range(start_epoch):
            for _ in train_loader:
                break

        print(f"Resuming training from epoch {start_epoch}")

    print(f"Start training with {str(device).upper()}")
    for epoch in range(start_epoch, num_epochs):
        train_loss, train_acc = _run_epoch(
            model,
            train_loader,
            criterion,
            device,
            desc=f"Train epoch {epoch+1}/{num_epochs}",
            optimizer=optimizer,
        )
        valid_loss, valid_acc = _run_epoch(
            model,
            valid_loader,
            criterion,
            device,
            desc=f"Valid epoch {epoch+1}/{num_epochs}",
        )

        print("last_lr =", scheduler.get_last_lr())
        scheduler.step()

        # Rewrite the full history each epoch so an interrupted run keeps it.
        results.append(
            {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "valid_loss": valid_loss,
                "valid_acc": valid_acc,
            }
        )
        pd.DataFrame(results).to_csv(f"{save_path}/results.csv", index=False)

        # Also save after the last epoch so the final weights are never lost.
        if (epoch + 1) % checkpoint_every == 0 or epoch + 1 == num_epochs:
            save_checkpoint(model, optimizer, f"{save_path}/resnet_{epoch+1}.pth")
            torch.save(scheduler.state_dict(), f"{save_path}/scheduler_{epoch+1}.pth")

# Train


In [None]:
# Build the train/val/test loaders from the Kaggle dataset's ImageFolder tree.
data_root = "/kaggle/input/categories-classification/data"
train_loader, valid_loader, test_loader = load_dataset(data_root, batch_size=32)

In [None]:
# NOTE(review): ResNet101 is not defined in the visible cells; presumably it
# comes from another cell of this notebook -- confirm before running.
num_classes = 10
model = ResNet101(num_classes=num_classes)

In [None]:
# Continue training from the previous run's outputs in load_path, up to
# 100 epochs in total; new results and checkpoints go to ./resnet101.
train(
    model,
    train_loader,
    valid_loader,
    save_path="./resnet101",
    num_epochs=100,
    resume_training=True,
    load_path="/kaggle/input/checkpoint-11",
)