# Set up


In [None]:
%pip install torchtune torchao

In [None]:
import os

import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchtune.training import get_cosine_schedule_with_warmup
from torchvision.datasets import ImageFolder
from tqdm import tqdm

In [None]:
class Bottleneck(nn.Module):
    """Residual block built from a 1x1 -> 3x3 -> 1x1 convolution stack.

    The last 1x1 convolution widens the feature map to
    ``out_channels * expansion`` channels. The (optionally projected) block
    input is then added onto that result before the final ReLU.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count produced by the first two convolutions.
        i_downsample: Optional module applied to the block input before the
            addition. The caller has to supply one whenever the input would
            otherwise not match the shape of the convolution stack's output.
        stride: Stride of the 3x3 convolution.
    """

    expansion = 4

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()
        widened = out_channels * self.expansion

        # 1x1 convolution: maps the input to the narrow working width.
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size=1, stride=1, padding=0
        )
        self.batch_norm1 = nn.BatchNorm2d(out_channels)

        # 3x3 convolution: the only layer of the stack that may stride.
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=stride, padding=1
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        # 1x1 convolution: widens to the block's output channel count.
        self.conv3 = nn.Conv2d(
            out_channels, widened, kernel_size=1, stride=1, padding=0
        )
        self.batch_norm3 = nn.BatchNorm2d(widened)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv_stack(x) + shortcut(x))``."""
        shortcut = x.clone()

        out = self.conv1(x)
        out = self.batch_norm1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.batch_norm2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.batch_norm3(out)

        if self.i_downsample is not None:
            shortcut = self.i_downsample(shortcut)

        out += shortcut
        return self.relu(out)


class Block(nn.Module):
    """Residual block built from two 3x3 convolutions.

    The (optionally projected) block input is added onto the output of the
    convolution stack before the final ReLU.

    Args:
        in_channels: Channel count of the incoming tensor.
        out_channels: Channel count produced by both convolutions.
        i_downsample: Optional module applied to the block input before the
            addition. The caller has to supply one whenever ``stride != 1``
            or ``in_channels != out_channels`` so both addends share a shape.
        stride: Stride of the first convolution.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, i_downsample=None, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=stride,
            bias=False,
        )
        self.batch_norm1 = nn.BatchNorm2d(out_channels)
        # Only the first convolution strides. Striding here as well would
        # shrink the feature map twice while the shortcut is shrunk once,
        # making the residual addition fail on mismatched shapes.
        self.conv2 = nn.Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            padding=1,
            stride=1,
            bias=False,
        )
        self.batch_norm2 = nn.BatchNorm2d(out_channels)

        self.i_downsample = i_downsample
        self.stride = stride
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv_stack(x) + shortcut(x))``."""
        identity = x.clone()

        # Each convolution is followed by its own batch-norm layer.
        x = self.relu(self.batch_norm1(self.conv1(x)))
        x = self.batch_norm2(self.conv2(x))

        if self.i_downsample is not None:
            identity = self.i_downsample(identity)

        x += identity
        x = self.relu(x)

        return x


class ResNet(nn.Module):
    """ResNet-style classifier assembled from a caller-supplied block type.

    The network is a 7x7 stem convolution with batch norm, ReLU and max
    pooling, followed by four stages of residual blocks, adaptive average
    pooling to 1x1 and a linear classification layer.

    Args:
        ResBlock: Block class (or factory) exposing an ``expansion``
            attribute and accepting
            ``(in_channels, out_channels, i_downsample=..., stride=...)``.
        layer_list: Number of blocks in each of the four stages; entries
            0 through 3 are read.
        num_classes: Output size of the final linear layer.
        num_channels: Channel count of the input images.
    """

    def __init__(self, ResBlock, layer_list, num_classes, num_channels=3):
        super().__init__()
        # Channel count entering the next stage; ``_make_layer`` reads it
        # and advances it as stages are built.
        self.in_channels = 64

        # Stem.
        self.conv1 = nn.Conv2d(
            num_channels, 64, kernel_size=7, stride=2, padding=3, bias=False
        )
        self.batch_norm1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU()
        self.max_pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # (planes, stride) for layer1..layer4, registered in that order.
        stage_specs = ((64, 1), (128, 2), (256, 2), (512, 2))
        for idx, (width, stage_stride) in enumerate(stage_specs):
            stage = self._make_layer(
                ResBlock, layer_list[idx], planes=width, stride=stage_stride
            )
            setattr(self, f"layer{idx + 1}", stage)

        # Head.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * ResBlock.expansion, num_classes)

    def _make_layer(self, ResBlock, blocks, planes, stride=1):
        """Build one stage of ``blocks`` residual blocks as an ``nn.Sequential``.

        Only the first block receives ``stride`` and, when the stage changes
        the resolution or channel count, a 1x1-conv + batch-norm projection
        for its shortcut. ``self.in_channels`` is updated to the stage's
        output width as a side effect.
        """
        out_width = planes * ResBlock.expansion
        shape_preserved = stride == 1 and self.in_channels == out_width

        shortcut = None
        if not shape_preserved:
            shortcut = nn.Sequential(
                nn.Conv2d(
                    self.in_channels,
                    out_width,
                    kernel_size=1,
                    stride=stride,
                ),
                nn.BatchNorm2d(out_width),
            )

        stage = [
            ResBlock(self.in_channels, planes, i_downsample=shortcut, stride=stride)
        ]
        self.in_channels = out_width

        # The remaining blocks keep both resolution and width.
        stage.extend(ResBlock(self.in_channels, planes) for _ in range(blocks - 1))

        return nn.Sequential(*stage)

    def forward(self, x):
        """Map a batch of images to a ``(batch, num_classes)`` tensor of scores."""
        features = self.conv1(x)
        features = self.batch_norm1(features)
        features = self.relu(features)
        features = self.max_pool(features)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        flat = pooled.reshape(pooled.shape[0], -1)
        return self.fc(flat)


def ResNet50(num_classes=10, channels=3):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3-4-6-3.

    Args:
        num_classes: Output size of the classification layer.
        channels: Channel count of the input images.
    """
    stage_depths = [3, 4, 6, 3]
    return ResNet(
        ResBlock=Bottleneck,
        layer_list=stage_depths,
        num_classes=num_classes,
        num_channels=channels,
    )


def ResNet101(num_classes=10, channels=3):
    """Build a ``ResNet`` of ``Bottleneck`` blocks with stage depths 3-4-23-3.

    Args:
        num_classes: Output size of the classification layer.
        channels: Channel count of the input images.
    """
    stage_depths = [3, 4, 23, 3]
    return ResNet(
        ResBlock=Bottleneck,
        layer_list=stage_depths,
        num_classes=num_classes,
        num_channels=channels,
    )

In [None]:
def save_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    save_path,
    parallel=False,
):
    """Write the model and optimizer state dicts to ``save_path``.

    The file holds a dict with the keys ``"model"`` and ``"optimizer"``.

    Args:
        model: Network whose weights are saved.
        optimizer: Optimizer whose state is saved.
        save_path: Destination handed to ``torch.save``.
        parallel: When true, the weights are read from ``model.module``
            instead of ``model`` (e.g. for an ``nn.DataParallel`` wrapper), so
            the saved keys carry no wrapper prefix.
    """
    network = model.module if parallel else model
    payload = {
        "model": network.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(payload, save_path)


def load_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    save_path,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    parallel=False,
):
    """Restore model and optimizer state from a file written by ``save_checkpoint``.

    Args:
        model: Network that receives the saved weights (modified in place).
        optimizer: Optimizer that receives the saved state (modified in place).
        save_path: Checkpoint file holding ``"model"`` and ``"optimizer"``
            entries.
        device: Passed to ``torch.load`` as ``map_location``.
        parallel: When true, the weights are loaded into ``model.module``
            instead of ``model`` (e.g. for an ``nn.DataParallel`` wrapper).

    Returns:
        The same ``(model, optimizer)`` objects that were passed in.
    """
    checkpoint = torch.load(save_path, map_location=device, weights_only=True)

    target = model.module if parallel else model
    target.load_state_dict(checkpoint["model"])
    optimizer.load_state_dict(checkpoint["optimizer"])

    return model, optimizer

In [None]:
def load_dataset(root, batch_size=32):
    """Create train, validation and test ``DataLoader`` objects for an image dataset.

    Images are read with ``ImageFolder`` from ``{root}/train``,
    ``{root}/val`` and ``{root}/test``. Every split is resized to 224x224,
    converted to a tensor and normalized; the training split is additionally
    passed through random rotation, horizontal flip and colour jitter first.
    Only the training loader shuffles.

    Note that this seeds torch's global random generator with 42.

    Args:
        root: Directory containing the ``train``, ``val`` and ``test`` folders.
        batch_size: Batch size used by all three loaders.

    Returns:
        ``(train_loader, valid_loader, test_loader)``.
    """
    torch.manual_seed(42)

    # Random augmentations, used for the training split only.
    augment = transforms.Compose(
        [
            transforms.RandomRotation(degrees=15),
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ColorJitter(
                brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1
            ),
        ]
    )
    # Deterministic preprocessing shared by every split. The normalization
    # constants are presumably per-channel dataset statistics -- confirm.
    to_model_input = transforms.Compose(
        [
            transforms.Resize(size=(224, 224), antialias=True),
            transforms.ToTensor(),
            transforms.Normalize([0.7037, 0.6818, 0.6685], [0.2739, 0.2798, 0.2861]),
        ]
    )
    train_pipeline = transforms.Compose([augment, to_model_input])

    # All datasets are opened before any loader is created.
    datasets = (
        ImageFolder(f"{root}/train", transform=train_pipeline),
        ImageFolder(f"{root}/val", transform=to_model_input),
        ImageFolder(f"{root}/test", transform=to_model_input),
    )
    shuffle_flags = (True, False, False)

    train_loader, valid_loader, test_loader = (
        DataLoader(dataset=split, batch_size=batch_size, shuffle=flag)
        for split, flag in zip(datasets, shuffle_flags)
    )

    return train_loader, valid_loader, test_loader

In [None]:
def _latest_checkpoint_epoch(save_path, results):
    """Return the newest logged epoch whose model and scheduler checkpoints both exist."""
    for row in reversed(results):
        epoch = int(row["epoch"])
        has_model = os.path.exists(f"{save_path}/ViT_{epoch}.pth")
        has_scheduler = os.path.exists(f"{save_path}/scheduler_{epoch}.pth")
        if has_model and has_scheduler:
            return epoch
    raise FileNotFoundError(
        f"No checkpoint found in {save_path} for any epoch listed in results.csv"
    )


def _run_epoch(model, loader, criterion, device, desc, optimizer=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy_percent)``.

    With an ``optimizer`` the pass updates the weights after every batch;
    without one it only evaluates, with gradient tracking switched off.
    The caller is responsible for ``model.train()`` / ``model.eval()``.
    """
    is_training = optimizer is not None
    running_loss, correct = 0.0, 0

    with tqdm(total=len(loader), desc=desc, unit="batch") as pbar:
        with torch.set_grad_enabled(is_training):
            for i, (images, labels) in enumerate(loader):
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                loss = criterion(outputs, labels)

                if is_training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                running_loss += loss.item()
                _, predicted = torch.max(outputs.data, 1)
                correct += (predicted == labels).sum().item()

                pbar.set_postfix({"loss": format(running_loss / (i + 1), ".4f")})
                pbar.update()

        mean_loss = running_loss / len(loader)
        accuracy = 100 * correct / len(loader.dataset)
        pbar.set_postfix(
            {
                "loss": format(mean_loss, ".4f"),
                "accuracy": format(accuracy, ".2f"),
            }
        )

    return mean_loss, accuracy


def train(
    model: nn.Module,
    train_loader: DataLoader,
    valid_loader: DataLoader,
    save_path: str,
    num_epochs=100,
    resume_training: bool = False,
    lr=0.01,
    momentum=0.9,
    weight_decay=0.0005,
    num_warmup_steps=5,
    parallel=False,
    device: torch.device = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
):
    """Train ``model`` with SGD and a warmup + cosine learning-rate schedule.

    After every epoch the train/validation loss and accuracy are appended to
    ``{save_path}/results.csv``. Whenever ``epoch % 5 == 0`` (zero-based
    epoch index) the model/optimizer state is written to
    ``{save_path}/ViT_{n}.pth`` and the scheduler state to
    ``{save_path}/scheduler_{n}.pth``, where ``n`` is the one-based epoch.

    Args:
        model: Network to train; moved to ``device``.
        train_loader: Batches of ``(images, labels)`` used for optimization.
        valid_loader: Batches of ``(images, labels)`` used for evaluation.
        save_path: Output directory; created if missing.
        num_epochs: Total number of epochs, also used as the schedule length
            (the scheduler is stepped once per epoch).
        resume_training: Continue from the newest epoch in ``results.csv``
            that has both checkpoint files, instead of starting at epoch 0.
        lr: Base learning rate for SGD.
        momentum: SGD momentum.
        weight_decay: SGD weight decay.
        num_warmup_steps: Warmup length in scheduler steps, i.e. epochs.
        parallel: Wrap the model in ``nn.DataParallel``.
        device: Device used for the model and every batch.

    Raises:
        FileNotFoundError: If resuming is requested but no logged epoch has
            a matching pair of checkpoint files.
    """
    os.makedirs(save_path, exist_ok=True)

    if parallel:
        model = nn.DataParallel(model)

    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(
        model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
    )
    # NOTE(review): the schedule is stepped per epoch, so the warmup
    # presumably starts from a zero learning rate for the whole first epoch
    # -- confirm against the torchtune documentation.
    scheduler = get_cosine_schedule_with_warmup(
        optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_epochs
    )

    # A fresh run starts with an empty history at epoch 0.
    results = []
    start_epoch = 0

    if resume_training:
        df = pd.read_csv(f"{save_path}/results.csv")
        results = df.to_dict("records")

        # Checkpoints are not written every epoch, so fall back to the most
        # recent epoch that has one and drop the log rows that will be redone.
        start_epoch = _latest_checkpoint_epoch(save_path, results)
        results = [row for row in results if int(row["epoch"]) <= start_epoch]

        # Keyword arguments: positionally, ``device`` comes before ``parallel``.
        model, optimizer = load_checkpoint(
            model,
            optimizer,
            f"{save_path}/ViT_{start_epoch}.pth",
            device=device,
            parallel=parallel,
        )
        scheduler.load_state_dict(
            torch.load(f"{save_path}/scheduler_{start_epoch}.pth")
        )

        # Pull one batch per completed epoch; presumably meant to advance the
        # random state the way the finished epochs did -- TODO confirm.
        for _ in range(start_epoch):
            for _ in train_loader:
                break

        print(f"Resuming training from epoch {start_epoch}")

    print(f"Start trainning with {str(device).upper()}")
    for epoch in range(start_epoch, num_epochs):
        # Train step
        model.train()
        train_loss, train_acc = _run_epoch(
            model,
            train_loader,
            criterion,
            device,
            f"Train epoch {epoch+1}/{num_epochs}",
            optimizer,
        )

        # Validation step
        model.eval()
        valid_loss, valid_acc = _run_epoch(
            model,
            valid_loader,
            criterion,
            device,
            f"Valid epoch {epoch+1}/{num_epochs}",
        )

        print("Last learning rate: ", scheduler.get_last_lr())
        scheduler.step()

        # Save results
        results.append(
            {
                "epoch": epoch + 1,
                "train_loss": train_loss,
                "train_acc": train_acc,
                "valid_loss": valid_loss,
                "valid_acc": valid_acc,
            }
        )
        df = pd.DataFrame(results)
        df.to_csv(f"{save_path}/results.csv", index=False)

        # Save checkpoint (same directory and file names the resume path reads)
        if epoch % 5 == 0:
            save_checkpoint(
                model, optimizer, f"{save_path}/ViT_{epoch+1}.pth", parallel
            )
            torch.save(scheduler.state_dict(), f"{save_path}/scheduler_{epoch+1}.pth")

# Train


In [None]:
# Build the train/validation/test loaders from the dataset root under
# /kaggle/input, using batches of 32 images.
train_loader, valid_loader, test_loader = load_dataset(
    "/kaggle/input/categories-classification/data", batch_size=32
)

In [None]:
# Instantiate the ResNet-101 variant with a 10-output classification layer.
model = ResNet101(num_classes=10)

In [None]:
# Start a fresh (non-resumed) 100-epoch run that writes its results and
# checkpoints under ./resnet101; parallel=True wraps the model in
# nn.DataParallel inside train().
train(
    model=model,
    train_loader=train_loader,
    valid_loader=valid_loader,
    save_path="./resnet101",
    num_epochs=100,
    resume_training=False,
    parallel=True,
)