In [None]:
import logging
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from torch.utils.data import DataLoader, TensorDataset

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

from utils.config import Config
from utils.validation import TensorModel, BatchModel
from utils.architecures import NeuralNetworkV1, NeuralNetworkV2
from utils.training import multiclass_validation_step
from utils.plots import plot_decision_boundary

## Configs

- *The `Config` class is utilized to encapsulate model hyperparameters and device configuration, ensuring reproducibility and streamlined experiment management.*
- *The `wandb.init()` function initializes a new Weights & Biases run, specifying project details, team, experiment name, and relevant context parameters for comprehensive experiment tracking.*

In [None]:
# Experiment hyperparameters, collected in one mapping so they can be wrapped by
# `Config` and reported to the experiment tracker as a single unit.
hyperparameters = dict(
    name='multi-class-network-v1',
    epochs=100,
    batch_size=16,
    hidden_size=[32, 16, 8],
    learning_rate=0.0001,
    n_features=2,
    n_classes=4,
    weight_decay=0.0005,
)

# Project-level configuration object built from the raw values above.
config = Config(hyperparameters)

In [None]:
# Authenticate with Weights & Biases (prompts for an API key when not already logged in).
wandb.login()

# Context stored with the run: model / optimizer / loss names plus the raw hyperparameters.
run_context = {
    'model': 'MultiClassNetworkV1',
    'optimizer': 'Adam',
    'criterion': 'CrossEntropyLoss',
    'hyperparameters': hyperparameters,
}

# Start a new run under the project/entity below, named after the experiment config.
run = wandb.init(
    project='pytorch-bootcamp',
    entity='nikossacoff-development',
    name=config.name,
    config=run_context,
)

## Load data

- *We use the `torch.load()` function to load our training and evaluation data. We ensure they are on the correct device.*
- *We validate the tensors using `TensorModel`.*

In [None]:
# Load tensors straight onto the configured device. `map_location` remaps the saved
# storages while deserializing, so files written from a GPU session can still be
# loaded on a machine without that device (loading first and calling `.to()`
# afterwards fails in that situation, before the move ever happens).
data_dir = 'temp/data/multiclass-classification'
train_data = torch.load(f'{data_dir}/train_data.pth', map_location=config.device)
val_data = torch.load(f'{data_dir}/validation_data.pth', map_location=config.device)
eval_data = torch.load(f'{data_dir}/evaluation_data.pth', map_location=config.device)

# Validate tensors: each split is passed through `TensorModel` with
# `tensor_dimensions=2`, and the validated tensor is kept.
train_data = TensorModel(tensor=train_data, tensor_dimensions=2).tensor
val_data = TensorModel(tensor=val_data, tensor_dimensions=2).tensor
eval_data = TensorModel(tensor=eval_data, tensor_dimensions=2).tensor

In [None]:
# Every column except the last holds the features; the last column holds the label.
X_train = train_data[:, :-1]
y_train = train_data[:, -1]

X_val = val_data[:, :-1]
y_val = val_data[:, -1]

X_eval = eval_data[:, :-1]
y_eval = eval_data[:, -1]

# Report the resulting shapes so a wrong split is visible immediately.
logging.info("Training data: %s | Labels: %s", X_train.shape, y_train.shape)
logging.info("Validation data: %s | Labels: %s", X_val.shape, y_val.shape)
logging.info("Evaluation data: %s | Labels: %s", X_eval.shape, y_eval.shape)

## Build a model

### PyTorch Model

In [None]:
class MultiClassNetworkV1(nn.Module):
    """Fully connected classifier that outputs raw logits.

    The network is a stack of `Linear -> ReLU` pairs, one per entry in
    `hidden_size`, followed by a final `Linear` layer with `n_classes` outputs.
    No softmax is applied in `forward`.

    Args:
        n_features: Number of input features per sample.
        hidden_size: Width of each hidden layer, in order. Any length is
            accepted; an empty list yields a single linear layer.
        n_classes: Number of output logits.
        device: Device the module is moved to at construction time.
    """

    def __init__(self, n_features: int, hidden_size: list, n_classes: int, device: torch.device):
        super().__init__()

        # Build the stack from the list instead of indexing three fixed positions,
        # so every configured width is used. Layers are created in the same order
        # as before, which keeps the `stack.<idx>` state-dict keys unchanged.
        layers = []
        in_features = n_features
        for width in hidden_size:
            layers.append(nn.Linear(in_features, width))
            layers.append(nn.ReLU())
            in_features = width
        layers.append(nn.Linear(in_features, n_classes))

        self.stack = nn.Sequential(*layers)

        self.to(device)

    def forward(self, x):
        """Return the logits produced by the layer stack for input `x`."""
        return self.stack(x)

# Instantiate the network from the experiment configuration
# (arguments in order: n_features, hidden_size, n_classes, device).
model = MultiClassNetworkV1(config.n_features, config.hidden_size, config.n_classes, config.device)

### Loss function

- *For multi-class classification tasks, Cross-Entropy Loss is the standard choice due to its effectiveness in measuring the discrepancy between predicted and true class distributions.*
- *We employ PyTorch's `nn.CrossEntropyLoss`, which operates directly on the model's raw output logits, thereby mitigating numerical instability and ensuring robust optimization.*

In [None]:
# Cross-entropy over raw logits; 'mean' (the default) averages the loss over each batch.
criterion = nn.CrossEntropyLoss(reduction='mean')

### Optimizer

In [None]:
# Adam over all model parameters; learning rate and weight decay come from the config.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=config.learning_rate,
    weight_decay=config.weight_decay,
)

### DataLoaders

In [None]:
# Wrap each (features, labels) pair in a TensorDataset: training, validation, evaluation.
train_dataset, val_dataset, eval_dataset = (
    TensorDataset(features, labels)
    for features, labels in ((X_train, y_train), (X_val, y_val), (X_eval, y_eval))
)

# Only the training loader shuffles; validation and evaluation keep a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=config.batch_size)
val_loader, eval_loader = (
    DataLoader(dataset, shuffle=False, batch_size=config.batch_size)
    for dataset in (val_dataset, eval_dataset)
)

## Training Loop

In [None]:
# Train for a fixed number of epochs and validate after every epoch.
# Early stopping is not implemented: the loop always runs `config.epochs` epochs.
for epoch in range(config.epochs):
    # Re-enable training mode once per epoch, in case the validation helper
    # switched the model to evaluation mode.
    model.train()

    # Sum of the per-batch training losses for this epoch.
    train_loss_accum = 0.0

    # Batch variables get their own names so the full `X_train` / `y_train`
    # tensors stay intact in the notebook namespace (re-running earlier cells
    # after training would otherwise pick up the last batch).
    for X_batch, y_batch in train_loader:
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(X_batch)

        # Compute the loss; the labels are cast to integer class indices.
        batch_loss = criterion(outputs, y_batch.long())
        train_loss_accum += batch_loss.item()

        # Backward pass
        batch_loss.backward()

        # Clip gradients to prevent exploding gradients; the returned value is
        # the total gradient norm.
        grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        # The weight norm is a diagnostic only, so keep it out of the autograd graph.
        with torch.no_grad():
            weight_norm = torch.norm(torch.stack([torch.norm(param) for param in model.parameters()]))

        # Log plain Python scalars rather than tensors.
        run.log({
            'Metrics/Gradient norm': grad_norm.item(),
            'Metrics/Weight norm': weight_norm.item()
        })

        # Update the model parameters
        optimizer.step()

    ### Inference phase
    validation = multiclass_validation_step(
        model=model,
        dataloader=val_loader,
        criterion=criterion,
        device=config.device
    )

    # Mean training loss over the batches of this epoch.
    epoch_train_loss = train_loss_accum / len(train_loader)

    # One log call per epoch keeps all epoch-level values on the same W&B step
    # (every `run.log` call advances the step by default).
    run.log({
        'Metrics/Accuracy': validation['accuracy'],
        'Metrics/Precision': validation['precision'],
        'Metrics/Recall': validation['recall'],
        'Metrics/F1-Score': validation['f1_score'],
        'Loss/Training': epoch_train_loss,
        'Loss/Validation': validation['loss']
    })

    # Console summary every 5 epochs.
    if (epoch + 1) % 5 == 0:
        logging.info(f"Epoch: {epoch + 1}/{config.epochs} | Training loss: {epoch_train_loss:.4f} | Validation loss: {validation['loss']:.4f}")
        logging.info(f"Epoch: {epoch + 1}/{config.epochs} | Accuracy: {validation['accuracy']:.4f} | F1-Score: {validation['f1_score']:.4f}\n")

In [None]:
# Save the trained weights locally. The target directory is created first so a
# missing `temp/models` folder cannot fail the save after training has finished.
model_path = Path('temp/models/multi-class-nn.pth')
model_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), model_path)

# Generate a wandb artifact describing the saved model.
model_artifact = wandb.Artifact(
    name='multi-class-nn',
    type='model',
    metadata={
        'input_size': config.n_features,
        'hidden_size': config.hidden_size,
        'n_classes': config.n_classes,
        'epochs': config.epochs,
        'learning_rate': config.learning_rate,
        # `dropout` is not among the hyperparameters defined in this notebook;
        # fall back to None instead of raising if the config does not expose it.
        'dropout': getattr(config, 'dropout', None),
        'weight_decay': config.weight_decay
    }
)

# Attach the weights file under a stable name and upload the artifact with the run.
model_artifact.add_file(str(model_path), name='model.pth')
run.log_artifact(model_artifact)

In [None]:
# Close the Weights & Biases run that was opened with `wandb.init` earlier in the notebook.
run.finish()