[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ijdoc/wandb-demos/blob/main/pytorch/10-demo.ipynb)

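# PyTorch Demo

> **Note:** the demo cells below call helpers (`init_training`, `train_step`, `eval_step`, `test`, `save_model_checkpoint`) that are defined in the **Setup** section at the end of this notebook — run the Setup cells first.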

## Install SDK

In [None]:
!pip install wandb

## Basic Integration

In [None]:
# 1. Import the W&B Python SDK
import wandb

# NOTE: init_training, train_step and eval_step are defined in the
# "Helper Functions" cell of the Setup section; run the Setup cells first.

# Hyperparameters for this run; also handed to wandb.init below as the run config.
config = {
    "epochs": 10,
    "learning_rate": 0.00003,
    "img_size": 128,
    "batch_size": 32,
}

# Build device, datasets, dataloaders, model, loss function and optimizer.
device, datasets, loaders, model, loss_fn, optimizer = init_training(config)

# 2. Initialize logging
wandb.init(config=config, job_type="train")

for epoch in range(config["epochs"]):
    # One pass over the training split, then evaluate on the validation split.
    loss = train_step(model, loaders["train"], loss_fn, optimizer, device)
    val_loss, accuracy = eval_step(model, loaders["val"], loss_fn, device)

    print(
        f'Epoch {epoch+1}/{config["epochs"]}, '
        f"Training Loss: {loss:.4f}, "
        f"Validation Loss: {val_loss:.4f}, "
        f"Accuracy: {accuracy:.4f}"
    )

    # 3. Log metrics (one wandb.log call per epoch)
    wandb.log({"train/loss": loss, "val/loss": val_loss, "val/acc": accuracy})

# 4. Done!! Close the W&B run.
wandb.finish()

## Log Artifacts

In [None]:
import wandb

# Same setup as the basic example, but with a higher learning rate.
config = {
    "epochs": 10,
    "learning_rate": 0.0001,
    "img_size": 128,
    "batch_size": 32,
}

device, datasets, loaders, model, loss_fn, optimizer = init_training(config)

wandb.init(config=config, job_type="train")

# 4. Reference the dataset used
# (declares the "playing-cards:v0" dataset artifact as used by this run)
wandb.use_artifact("team-jdoc/datasets/playing-cards:v0", type="dataset")

for epoch in range(config["epochs"]):
    loss = train_step(model, loaders["train"], loss_fn, optimizer, device)
    val_loss, accuracy = eval_step(model, loaders["val"], loss_fn, device)

    wandb.log({"train/loss": loss, "val/loss": val_loss, "val/acc": accuracy})

# 5. Log the resulting model
# `epoch` and `loss` still hold the values from the last loop iteration
# (0-based epoch index and that epoch's mean training loss).
model_path = save_model_checkpoint(model, optimizer, epoch, loss)
# Keep the returned artifact: the "Expand Lineage" cell passes it to use_artifact.
model_artifact = wandb.log_artifact(model_path, type="model")

wandb.finish()

## Expand Lineage

In [None]:
import wandb

# Evaluation run. Reuses model_artifact, config, model, datasets, loaders,
# loss_fn and device from the "Log Artifacts" cell above.
wandb.init(job_type="test")
# Declare the model artifact logged by the training run as used by this run.
wandb.use_artifact(model_artifact)

# NOTE: rebinds `loss` / `accuracy` from the training cell to test-split values.
(loss, accuracy, misses) = test(
    config["img_size"],
    model,
    datasets["test"].classes,
    loaders["test"],
    loss_fn,
    device,
)

wandb.log({"test/loss": loss, "test/acc": accuracy, "test/misses": len(misses)})

# 6. Log a Table
# `misses` is a DataFrame with columns image / truth / guess, one row per
# misclassified sample.
wandb.log({"misses": wandb.Table(dataframe=misses)})

wandb.finish()

# Setup

## Import Libraries

In [None]:
!ltt install torch
!pip install timm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import torchvision
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
import timm  # Where the model is stored

import matplotlib.pyplot as plt  # For data viz
import pandas as pd
import numpy as np
import sys
from tqdm.notebook import tqdm
from PIL import Image

# Report interpreter and library versions so results can be reproduced.
for _label, _version in (
    ("System Version:", sys.version),
    ("PyTorch version", torch.__version__),
    ("Torchvision version", torchvision.__version__),
    ("Numpy version", np.__version__),
    ("Pandas version", pd.__version__),
):
    print(_label, _version)

## Prepare Data 

In [None]:
import random

# Local copy of the "playing-cards:v0" dataset, one sub-folder per split.
data_root = "./artifacts/playing-cards:v0/"
train_folder = data_root + "train/"
valid_folder = data_root + "valid/"
test_folder = data_root + "test/"


class ImageFolderWithPaths(ImageFolder):
    """``torchvision.datasets.ImageFolder`` whose items also carry the file path.

    Each item is the tuple returned by ``ImageFolder.__getitem__`` with the
    sample's file path appended as the last element.
    """

    def __getitem__(self, index):
        # The DataLoader goes through this method, so the path travels with
        # every batch.
        base_item = super().__getitem__(index)
        file_path, _class_idx = self.imgs[index]
        return base_item + (file_path,)


def get_transform(img_size):
    """Return the preprocessing pipeline: square resize, then convert to a tensor.

    Args:
        img_size: target height and width in pixels.
    """
    pipeline = [transforms.Resize((img_size, img_size)), transforms.ToTensor()]
    return transforms.Compose(pipeline)


def prep_data(img_size, batch_size):
    """Build the train/val/test datasets and their DataLoaders.

    Args:
        img_size: side length passed to ``get_transform`` (square resize).
        batch_size: batch size used for all three DataLoaders.

    Returns:
        ``{"sets": {...}, "loaders": {...}}``, each keyed by
        ``"train"``, ``"val"`` and ``"test"``. Only the training loader is
        shuffled.
    """
    transform = get_transform(img_size)

    # Load datasets as tensors (items also carry their file path).
    train_dataset = ImageFolderWithPaths(train_folder, transform=transform)
    val_dataset = ImageFolderWithPaths(valid_folder, transform=transform)
    test_dataset = ImageFolderWithPaths(test_folder, transform=transform)

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    # Previously built from val_dataset by mistake, so "test" metrics were
    # actually computed on the validation split.
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return {
        "sets": {"train": train_dataset, "val": val_dataset, "test": test_dataset},
        "loaders": {
            "train": train_dataloader,
            "val": val_dataloader,
            "test": test_dataloader,
        },
    }


def show_sample(dataset):
    """Display one randomly chosen image from ``dataset`` with matplotlib.

    Items of ``dataset`` must be ``(image, label, path)`` triples; the original
    file at ``path`` is shown rather than the transformed tensor.
    """
    _image, _label, picked_path = random.choice(dataset)
    picture = Image.open(picked_path)
    plt.imshow(picture)
    plt.axis("off")  # Hide the axis
    plt.show()


# Preview a random training image (img_size=128, batch_size=128; the batch
# size does not matter for a single-sample preview).
show_sample(prep_data(128, 128)["sets"]["train"])

## Prepare Model

In [None]:
class CardClassifier(nn.Module):
    """Pretrained timm ``efficientnet_b0`` backbone with a new linear head.

    The backbone's last child module is dropped and replaced by
    ``Flatten -> Linear(<removed layer's in_features>, num_classes)``.

    Args:
        num_classes: number of output logits (defaults to 53).
    """

    def __init__(self, num_classes=53):
        super().__init__()

        # Keep the full pretrained network registered as a submodule.
        # NOTE(review): its original last layer therefore remains in
        # parameters()/state_dict() even though forward() never uses it.
        self.base_model = timm.create_model("efficientnet_b0", pretrained=True)

        # Everything except the last child becomes the feature extractor; the
        # last child is only consulted for its input width.
        *backbone_layers, old_head = self.base_model.children()
        self.features = nn.Sequential(*backbone_layers)

        # Recreate the last layer (the classifier)
        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(old_head.in_features, num_classes),
        )

    def forward(self, x):
        """Return class logits for the input batch ``x``."""
        return self.classifier(self.features(x))


# model = CardClassifier(num_classes)
# predictions = model(images)
# predictions.shape  # [batch_size, num_classes]

## Helper Functions

In [None]:
import os


def init_training(config):
    """Create everything a training run needs from ``config``.

    Reads ``config["img_size"]``, ``config["batch_size"]`` and
    ``config["learning_rate"]``.

    Returns:
        ``(device, datasets, loaders, model, loss_fn, optimizer)``. The device is
        ``cuda:0`` when CUDA is available, otherwise the CPU. The model is a
        ``CardClassifier`` moved to that device, the loss is cross-entropy, and
        the optimizer is Adam.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Training on {device}")

    data = prep_data(config["img_size"], config["batch_size"])

    model = CardClassifier().to(device)
    criterion = nn.CrossEntropyLoss()
    adam = optim.Adam(model.parameters(), lr=config["learning_rate"])

    return device, data["sets"], data["loaders"], model, criterion, adam


def train_step(model, dataloader, loss_fn, optimizer, device):
    """Run one training epoch and return the mean of the per-batch losses.

    Args:
        model: module to train; switched to training mode.
        dataloader: yields ``(inputs, targets, extra)`` triples; ``extra`` is
            ignored.
        loss_fn: callable ``loss_fn(predictions, targets)`` returning a scalar
            tensor.
        optimizer: optimizer updating ``model``'s parameters.
        device: device each batch is moved to.

    Returns:
        Sum of ``loss.item()`` over batches divided by ``len(dataloader)``.
    """
    model.train()
    batch_losses = []

    for inputs, targets, _extra in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Clear gradients left over from the previous step.
        optimizer.zero_grad()
        batch_loss = loss_fn(model(inputs), targets)
        batch_losses.append(batch_loss.item())

        batch_loss.backward()
        optimizer.step()

    return sum(batch_losses) / len(dataloader)


def eval_step(model, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` without tracking gradients.

    Args:
        model: module to evaluate; switched to eval mode.
        dataloader: yields ``(inputs, targets, extra)`` triples (``extra``
            ignored) and exposes ``.dataset`` for the sample count.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        device: device each batch is moved to.

    Returns:
        ``(average_loss, accuracy)``. ``average_loss`` is the mean of the
        per-batch losses. ``accuracy`` is the number of correct argmax
        predictions divided by ``len(dataloader.dataset)``.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0

    # Gradients are not needed for evaluation.
    with torch.no_grad():
        for inputs, targets, _extra in dataloader:
            inputs, targets = inputs.to(device), targets.to(device)

            logits = model(inputs)
            loss_sum += loss_fn(logits, targets).item()
            n_correct += (logits.argmax(dim=1) == targets).sum().item()

    n_batches = len(dataloader)
    n_samples = len(dataloader.dataset)
    return loss_sum / n_batches, n_correct / n_samples


def test(img_size, model, class_names, dataloader, loss_fn, device):
    """Evaluate ``model`` on ``dataloader`` and collect the misclassified samples.

    Args:
        img_size: side length each missed image is resized to before being
            wrapped in ``wandb.Image``.
        model: classifier to evaluate; switched to eval mode.
        class_names: sequence mapping class index -> display name.
        dataloader: yields ``(inputs, targets, image_paths)`` batches and
            exposes ``.dataset`` for the sample count.
        loss_fn: callable ``loss_fn(logits, targets)`` returning a scalar tensor.
        device: device each batch is moved to.

    Returns:
        ``(average_loss, accuracy, misses)``:
        - ``average_loss`` is the mean of the per-batch losses.
        - ``accuracy`` is correct predictions / ``len(dataloader.dataset)``.
        - ``misses`` is a DataFrame with columns ``image``, ``truth`` and
          ``guess``, one row per misclassified sample in loader order. It is
          empty (but keeps those columns) when nothing is missed.
    """
    model.eval()  # Set the model to evaluation mode
    total_loss = 0
    total_correct = 0

    resize = transforms.Resize((img_size, img_size))
    # Rows are collected in a list and the DataFrame is built once at the end;
    # growing a DataFrame with .loc re-allocates it on every append.
    miss_rows = []

    with torch.no_grad():  # No gradients needed, saves memory and computations
        for X_batch, y_batch, image_paths in dataloader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)

            y_pred = model(X_batch)
            loss = loss_fn(y_pred, y_batch)
            total_loss += loss.item()

            # Convert output scores to the predicted class index
            _, preds = torch.max(y_pred, 1)
            total_correct += (preds == y_batch).sum().item()

            # Bring labels/predictions to the CPU once as plain ints, so the
            # per-sample loop below does no device-tensor indexing.
            truths = y_batch.cpu().tolist()
            guesses = preds.cpu().tolist()
            for path, truth, guess in zip(image_paths, truths, guesses):
                if truth == guess:
                    continue
                # Close the source file once the resized copy has been made.
                with Image.open(path) as img:
                    thumbnail = resize(img)
                miss_rows.append(
                    {
                        "image": wandb.Image(thumbnail),
                        "truth": class_names[truth],
                        "guess": class_names[guess],
                    }
                )

    average_loss = total_loss / len(dataloader)
    accuracy = total_correct / len(dataloader.dataset)
    misses = pd.DataFrame(miss_rows, columns=["image", "truth", "guess"])

    # Return the average loss, accuracy and misclassifications
    return (
        average_loss,
        accuracy,
        misses,
    )


def save_model_checkpoint(
    model, optimizer, epoch, loss, checkpoints_dir="./checkpoints"
):
    """Save model and optimizer state to a ``.pth`` checkpoint file.

    The file is written to
    ``<checkpoints_dir>/<epoch>_card_classifier_checkpoint.pth``.

    Args:
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        epoch: stored in the checkpoint and used as the file-name prefix.
        loss: stored in the checkpoint.
        checkpoints_dir: output directory, created (with parents) if missing.
            Defaults to ``"./checkpoints"``.

    Returns:
        The path of the written checkpoint file.
    """
    # Save the model state and the optimizer state
    checkpoint = {
        "epoch": epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": optimizer.state_dict(),
        "loss": loss,
        # You can add more items to the checkpoint if needed
    }

    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(checkpoints_dir, exist_ok=True)
    filepath = os.path.join(
        checkpoints_dir, f"{epoch}_card_classifier_checkpoint.pth"
    )

    torch.save(checkpoint, filepath)

    return filepath

# Debug

In [None]:
!pip show wandb
import wandb

# Debug: list the names exposed by the installed wandb package.
dir(wandb)

In [None]:
# Debug: close a W&B run left open, e.g. if a demo cell was interrupted before
# reaching its own wandb.finish().
wandb.finish()

In [None]:
# Debug: number of misclassified test samples returned by test() in the
# "Expand Lineage" cell.
len(misses)