<a href="https://colab.research.google.com/github/ell-hol/stonks-wid-codex/blob/main/classification_customData.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
"""
A simple classification model based on a pretrained ResNet18 backbone.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from torchvision.models import resnet18


class ResNet18(nn.Module):
    def __init__(self, n_classes):
        super().__init__()
        self.resnet = resnet18(pretrained=True)
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, n_classes)
        # self.resnet.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3,
        #                               bias=False)

    def forward(self, x):
        x = self.resnet(x)
        return x

"""
A function to train the defined ResNet18 model on a custom Dataset.
The custom Dataset is built from directories where each directory is a class.
"""

from torchvision.transforms import ToTensor
from torchvision.transforms import Compose
from torchvision.transforms import Normalize
from torchvision.transforms import Resize

from sklearn.utils import shuffle

from torch.utils.data import DataLoader
from torch.utils.data import TensorDataset
from torch.utils.data import Subset

from torch.optim import Adam
from torch.optim import SGD

from torch.optim.lr_scheduler import MultiStepLR

from torchvision.datasets import ImageFolder

from sklearn.metrics import accuracy_score
from sklearn.metrics import confusion_matrix

from sklearn.preprocessing import LabelEncoder

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

import os

RANDOM_STATE = 42


def train_model(dataset_dir,
                test_size=0.2,
                batch_size=32,
                num_workers=4,
                num_epochs=100,
                lr=0.001,
                exp_name=None,
                checkpoint_model=False,
                checkpoint_interval=10,
                use_gpu=False):
    """
    Trains a ResNet18 model on the specified dataset.

    Args:
        dataset_dir (str): The path to the dataset
        test_size (float): The size of the test set
        batch_size (int): The batch size
        num_workers (int): The number of worker threads to use for loading data
        num_epochs (int): The number of epochs to train for
        lr (float): The learning rate
        exp_name (str): The name of the experiment (for logging)
        checkpoint_model (bool): If True, save a checkpoint after each epoch
        checkpoint_interval (int): If checkpointing, save a checkpoint every
            `checkpoint_interval` epochs
        use_gpu (bool): If True, attempt to train on the GPU

    Returns:
        (torch.nn.Module, torch.optim.Optimizer, dict): A tuple containing the
            trained model, the optimizer, and a dictionary of metrics.

    """
    # Create the output directory
    os.makedirs("logs", exist_ok=True)

    # If no experiment name provided, make one
    if exp_name is None:
        exp_name = "exp_" + str(len(os.listdir("logs/"))).zfill(3)
        print("INFO: Using experiment name:", exp_name)

    # Prepare the dataset
    dataset = ImageFolder(dataset_dir,
                          transform=Compose([Resize((224, 224)), ToTensor()]))

    # Split the dataset into train and test sets
    train_size = int((1.0 - test_size) * len(dataset))
    test_size = len(dataset) - train_size
    train_dataset, test_dataset = \
        torch.utils.data.random_split(dataset, [train_size, test_size])

    # Encode the class labels
    le = LabelEncoder()
    le.fit(train_dataset.dataset.classes)

    # Create data loaders for training and test sets
    train_loader = DataLoader(train_dataset,
                              batch_size=batch_size,
                              num_workers=num_workers,
                              shuffle=True)
    test_loader = DataLoader(test_dataset,
                             batch_size=batch_size,
                             num_workers=num_workers,
                             shuffle=False)

    # Prepare the model
    model = ResNet18(len(le.classes_))
    if use_gpu:
        model = model.cuda()

    # Prepare the optimizer
    optimizer = Adam(model.parameters(), lr=lr)
    scheduler = MultiStepLR(optimizer, milestones=[75, 90, 95], gamma=0.1)

    # Set up the metrics
    metrics = {
        "loss": {},
        "acc": {},
        "confusion": {},
        "hist": {}
    }

    # Train the model
    for epoch in range(num_epochs):
        print("=" * 10, "EPOCH {}/{}".format(epoch + 1, num_epochs), "=" * 10)

        # Train the model for one epoch
        metrics = train(model, epoch, train_loader, optimizer, use_gpu, metrics)

        # Evaluate the model on the test set
        metrics = evaluate(model, epoch, test_loader, use_gpu, metrics)

        # Update the learning rate
        scheduler.step()

        # Logging
        if epoch % 10 == 0:
            print("\n\nLoss:", metrics["loss"][epoch])
            print("Accuracy:", metrics["acc"][epoch])

            print("\nConfusion Matrix")
            print(metrics["confusion"][epoch])

            # print("\nHistogram")
            # print(metrics["hist"][epoch])

        # # Plot the metrics
        # fig, axes = plt.subplots(1, 3, figsize=(10, 3))
        # axes[0].set_title("Loss")
        # axes[0].plot(metrics["loss"])
        # axes[0].grid(True)

        # axes[1].set_title("Accuracy")
        # axes[1].plot(metrics["acc"])
        # axes[1].grid(True)

        # # axes[2].set_title("Histogram")
        # # axes[2].plot(metrics["hist"])
        # # axes[2].grid(True)

        # plt.savefig("logs/" + exp_name + "/metrics.png")
        # plt.close(fig)

        # Save a checkpoint
        if checkpoint_model and (epoch % checkpoint_interval == 0):
            checkpoint_path = \
                "logs/" + exp_name + "/checkpoint_" + str(epoch) + ".pth"
            torch.save({
                "epoch": epoch,
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "loss": metrics["loss"][epoch],
                "acc": metrics["acc"][epoch],
                "confusion": metrics["confusion"][epoch]
            }, checkpoint_path)

    # Save the model
    torch.save(model.state_dict(), "logs/" + exp_name + "/model.pth")

    # Save the label encoder
    le_path = "logs/" + exp_name + "/le.npy"
    np.save(le_path, le.classes_)

    return model, optimizer, metrics

In [30]:
def train(model, epoch, data_loader, optimizer, use_gpu, metrics):
    """
    Trains the model for one epoch.

    Args:
        model (torch.nn.Module): The model to train.
        data_loader (torch.utils.data.DataLoader): The training dataset.
        optimizer (torch.optim.Optimizer): The optimizer.
        use_gpu (bool): If True, attempt to train on the GPU.
        metrics (dict): The dict where the metrics are stored.

    Returns:
        dict: The updated metrics dictionary.

    """
    # Prepare the metrics dictionary
    train_loss = []
    train_acc = []
    confusion = np.zeros((4, 4), dtype=int)

    # Set the model to train mode
    model.train()

    # Iterate over training batches
    for i, (inputs, targets) in enumerate(data_loader):
        # Prepare the inputs
        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        optimizer.zero_grad()

        # Forward propagation
        logits = model(inputs)

        # Loss computation
        loss = F.cross_entropy(logits, targets)

        # Back propagation
        loss.backward()
        optimizer.step()

        # Compute the total number of correct classifications
        _, predicted = torch.max(logits, 1)
        batch_size = inputs.size(0)
        num_correct = (predicted == targets).sum().item()
        batch_acc = num_correct / batch_size

        # Update metrics
        train_loss.append(loss.item())
        train_acc.append(batch_acc)
        confusion += confusion_matrix(targets.cpu().numpy(), predicted.cpu().numpy(), labels=[0, 1, 2, 3])

        # Print the results on the console
        print("\r[Epoch {}] [Batch {} / {}] [Loss: {:.4f}] [Acc: {:.2f}%]".format(
            epoch + 1,
            i + 1,
            len(data_loader),
            loss.item(),
            batch_acc * 100
        ), end="")

    # Update the metrics dictionary
    metrics["loss"][epoch] = np.mean(train_loss)
    metrics["acc"][epoch] = np.mean(train_acc)
    metrics["confusion"][epoch] = confusion

    return metrics


def evaluate(model, val_epoch, data_loader, use_gpu, metrics):
    """
    Evaluates the model on the specified dataset.

    Args:
        model (torch.nn.Module): The model to evaluate.
        data_loader (torch.utils.data.DataLoader): The dataset to evaluate on.
        use_gpu (bool): If True, attempt to train on the GPU.
        metrics (dict): The dict where the metrics are stored.

    Returns:
        dict: The updated metrics dictionary.

    """
    # Prepare the metrics dictionary
    test_loss = []
    test_acc = []
    confusion = np.zeros((4, 4), dtype=int)

    # Set the model to eval mode
    model.eval()

    # Iterate over test batches
    for i, (inputs, targets) in enumerate(data_loader):
        # Prepare the inputs
        if use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()

        # Forward propagation
        with torch.no_grad():
            logits = model(inputs)

        # Loss computation
        loss = F.cross_entropy(logits, targets)

        # Compute the total number of correct classifications
        _, predicted = torch.max(logits, 1)
        batch_size = inputs.size(0)
        num_correct = (predicted == targets).sum().item()
        batch_acc = num_correct / batch_size

        # Update metrics
        test_loss.append(loss.item())
        test_acc.append(batch_acc)
        confusion += confusion_matrix(targets.cpu().numpy(), predicted.cpu().numpy(), labels=[0, 1, 2, 3])

        # Print the results on the console
        print("\r[Batch {} / {}] [Loss: {:.4f}] [Acc: {:.2f}%]".format(
            i + 1,
            len(data_loader),
            loss.item(),
            batch_acc * 100
        ), end="")

    # Update the metrics dictionary
    metrics["loss"][val_epoch] = np.mean(test_loss)
    metrics["acc"][val_epoch] = np.mean(test_acc)
    metrics["confusion"][val_epoch] = confusion

    return metrics

In [31]:
!rm -rf data/.ipynb_checkpoints/

In [None]:
train_model(dataset_dir='data', use_gpu=True)

INFO: Using experiment name: exp_000


  cpuset_checked))


[Batch 1 / 1] [Loss: 4.4176] [Acc: 22.22%]

Loss: 4.417627811431885
Accuracy: 0.2222222222222222

Confusion Matrix
[[0 1 4 0]
 [0 0 2 0]
 [0 0 2 0]
 [0 0 0 0]]
[Batch 1 / 1] [Loss: 0.1055] [Acc: 100.00%]

Loss: 0.10554042458534241
Accuracy: 1.0

Confusion Matrix
[[5 0 0 0]
 [0 2 0 0]
 [0 0 2 0]
 [0 0 0 0]]
