# Skin Cancer Detection using Deep Learning
This Jupyter notebook demonstrates how to build a deep learning model for skin cancer detection using the ResNet architecture. The model classifies skin lesion images as either benign or malignant.

## Prerequisites

Please run `02_data_preparation.ipynb` before this notebook. This notebook loads that notebook's output, the processed datasets, from the location where it saves them.

### Dataset Requirements
- The processed datasets from `02_data_preparation.ipynb` must be in the `data/processed` folder
- Expected structure:
  - `data/processed/train_dataset.pt` - Processed training dataset
  - `data/processed/val_dataset.pt` - Processed validation dataset
  - `data/processed/test_dataset.pt` - Processed testing dataset
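Before you run the notebook, a quick check can confirm that the three processed files listed above exist. The helper below, `missing_processed_files`, is a hypothetical addition and is not part of the `scd` package. It only checks that the files are present, not that they load correctly.

```python
from pathlib import Path

# File names taken from the expected structure listed above
EXPECTED_FILES = ("train_dataset.pt", "val_dataset.pt", "test_dataset.pt")


def missing_processed_files(processed_dir: Path) -> list[str]:
    """Return the expected dataset files that are absent from processed_dir."""
    processed_dir = Path(processed_dir)
    return [name for name in EXPECTED_FILES if not (processed_dir / name).is_file()]


# Example usage from the notebooks directory (assumes data/ sits one level up):
# missing = missing_processed_files(Path.cwd().parent / "data" / "processed")
# if missing:
#     raise FileNotFoundError(f"Run 02_data_preparation.ipynb first; missing: {missing}")
```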

### Experiments Requirements
This project uses Weights & Biases (wandb) to track experiments with various model configurations. Before running experiments, complete the following steps:

- Create a wandb account at [wandb.ai](https://wandb.ai) if you don't have one yet
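- Get your W&B API key from your account settings
- Add your W&B API key to the `.env` file under the key `WANDB_API_KEY`; optionally set `WANDB_ENTITY` and `WANDB_PROJECT` as well, otherwise the notebook falls back to default values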
- If the `.env` file doesn't exist, copy the `.env.example` file to create `.env` and replace the placeholder with your actual API key
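The notebook reads three variables from the `.env` file (see *Get Environment Variables*). A small hypothetical helper such as `resolve_wandb_settings` can stop early with a clear error when the API key is missing, instead of failing later at the first `wandb.init()` call. It accepts the environment as a plain mapping, so you can try it without touching `os.environ`. The fallback entity and project mirror the defaults the notebook uses later.

```python
import os
from typing import Mapping

# Fallbacks mirror the defaults used in the "Get Environment Variables" cell
DEFAULT_ENTITY = "the_lab"
DEFAULT_PROJECT = "skin_cancer_detection"


def resolve_wandb_settings(env: Mapping[str, str] = os.environ) -> dict:
    """Return W&B settings from env, raising if WANDB_API_KEY is missing or empty."""
    api_key = env.get("WANDB_API_KEY")
    if not api_key:
        raise RuntimeError("WANDB_API_KEY is not set; add it to your .env file")
    return {
        "api_key": api_key,
        # `or` also replaces empty strings, matching the notebook's behaviour
        "entity": env.get("WANDB_ENTITY") or DEFAULT_ENTITY,
        "project": env.get("WANDB_PROJECT") or DEFAULT_PROJECT,
    }
```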

### Environment Setup
Install the required packages listed in `requirements.txt`, for example with `pip install -r requirements.txt`.

### Notebook Structure
This notebook is organised into the following sections:
- Import Libraries - Required Python packages for the project
- Constants - Define paths and other constants
- Get Environment Variables - Load configuration for Weights & Biases
- Load Datasets - Prepare processed datasets for model training
- Train the Model - Functions for model training, validation, and fine-tuning
- Save the Model - Utilities to persist trained models
- Test the Model - Evaluate model performance
- Create an Experiment - Set up Weights & Biases experiment tracking
- Experiments - Run training with specific configurations
- Model Architecture Analysis - Analyse different model architectures
- Hyperparameter Tuning - Optimise model performance through systematic hyperparameter search

## Import Libraries

In [None]:
# Standard library imports
import os
from pathlib import Path

# Third-party imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from dotenv import load_dotenv

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, roc_auc_score

import torch
import torch.nn as nn
from torch.nn.functional import softmax
import torch.optim as optim
from torch.utils.data import DataLoader

import yaml
import wandb
from wandb.sdk.wandb_run import Run
from tqdm import tqdm

# Local imports
from scd.utils.common import load_datasets
from scd.model import SkinCancerCNN

## Constants

In [None]:
num_classes = 2 # Malignant and Benign
random_state = 42

# Define paths
root_dir = Path.cwd().parent
model_dir = root_dir / 'models'
processed_data_dir = root_dir / 'data' / 'processed'

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Get Environment Variables

We load environment variables from the `.env` file, which contains configuration for Weights & Biases (W&B) experiments.

The following environment variables are loaded:
- `WANDB_API_KEY`: Authentication key for accessing W&B services
- `WANDB_ENTITY`: Username or organization name in W&B
- `WANDB_PROJECT`: Name of the project in W&B for organizing experiments

These configurations enable systematic tracking of our skin cancer detection experiments, allowing us to compare different model architectures and hyperparameter settings.

In [None]:
# Load environment variables
load_dotenv()
WANDB_API_KEY = os.getenv('WANDB_API_KEY')
WANDB_ENTITY = os.getenv('WANDB_ENTITY') or 'the_lab'
WANDB_PROJECT = os.getenv('WANDB_PROJECT') or 'skin_cancer_detection'

## Load Datasets

We load the pre-processed datasets that were prepared in the previous notebook. These datasets contain skin lesion images that have been resized to the appropriate dimensions for our ResNet34 model (384x384 pixels) and normalised according to ImageNet statistics.

The datasets are loaded as TensorDatasets, which contain both the image tensors and their corresponding labels (benign or malignant). We then create DataLoader objects to efficiently batch the data during training, validation, and testing.
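ImageNet normalisation is applied per channel as (x - mean) / std. The sketch below assumes the commonly used ImageNet statistics shown in the code, because this notebook does not restate the exact values used in `02_data_preparation.ipynb`. It also assumes pixel values have already been scaled to [0, 1].

```python
# Widely used ImageNet per-channel statistics (RGB order), assumed here
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalise_pixel(rgb: tuple[float, float, float]) -> tuple[float, ...]:
    """Normalise one RGB pixel (values in [0, 1]) channel by channel."""
    return tuple((value - mean) / std for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD))


def denormalise_pixel(norm: tuple[float, float, float]) -> tuple[float, ...]:
    """Invert normalise_pixel, e.g. to display a processed image."""
    return tuple(value * std + mean for value, mean, std in zip(norm, IMAGENET_MEAN, IMAGENET_STD))
```

In PyTorch the same per-channel operation is available as `torchvision.transforms.Normalize(mean, std)`. To display a processed image, you need the inverse, as in `denormalise_pixel`.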

In [None]:
batch_size = 32

# Load the saved datasets
train_tensor_dataset, val_tensor_dataset, test_tensor_dataset = load_datasets(processed_data_dir)

# Create DataLoader objects
train_loader = DataLoader(train_tensor_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_tensor_dataset, batch_size=batch_size, shuffle=False)
test_loader = DataLoader(test_tensor_dataset, batch_size=batch_size, shuffle=False)

## Train the Model

### Train Epoch

The `train_epoch` function handles one complete training cycle through all batches in the training dataset. It tracks key metrics and reports them to Weights & Biases if an experiment run is provided for logging, enabling experiment tracking and visualisation.

In [None]:
def train_epoch(model: nn.Module, train_loader: DataLoader, criterion: nn.Module, optimiser: torch.optim.Optimizer, epochs: int, current_epoch: int, wandb_run: Run = None, phase: str = 'Training') -> tuple:
    """
    Train the model for one epoch.

    Parameters
    ----------
    model : nn.Module
        The model to train.
    train_loader : DataLoader
        The data loader for the training data.
    criterion : nn.Module
        The loss function.
    optimiser : torch.optim.Optimizer
        The optimizer for updating model weights.
    epochs : int
        The total number of epochs for training.
    current_epoch : int
        The current epoch number.
    wandb_run : Run, optional
        The Weights & Biases run object for logging.
    phase : str, optional
        The phase of training (default is 'Training').

    Returns
    -------
    tuple
        The average loss and accuracy for the epoch.
    """
    # Run on whichever device the model's parameters live on
    device = next(model.parameters()).device
    model.train()

    running_loss = 0.0
    correct = 0
    total = 0
    mean_loss, accuracy = 0.0, 0.0  # defined even if the loader is empty

    train_pbar = tqdm(train_loader, desc=f"Epoch [{current_epoch+1}/{epochs}]", ncols=120)

    # Training loop
    for i, (images, labels) in enumerate(train_pbar):
        images, labels = images.to(device), labels.to(device)

        optimiser.zero_grad()

        # The model may return a tuple whose first element is the logits
        outputs = model(images)
        if isinstance(outputs, tuple):
            outputs = outputs[0]

        loss = criterion(outputs, labels)
        loss.backward()

        optimiser.step()

        # Running metrics (the loss is averaged over batches)
        running_loss += loss.item()
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        mean_loss = running_loss / (i + 1)
        accuracy = 100 * correct / total

        train_pbar.set_postfix({
            'loss': f'{mean_loss:.4f}',
            'acc': f'{accuracy:.2f}%',
        })

    # Log once per epoch
    if wandb_run:
        wandb_run.log({
            'train_loss': mean_loss,
            'train_accuracy': accuracy,
            'epoch': current_epoch + 1,
            'phase': phase
        })

    return mean_loss, accuracy
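`train_epoch` reports the running loss as the mean of per-batch losses (`running_loss / (i + 1)`). Each `loss.item()` is already an average over its own batch, so a smaller final batch gets the same weight as a full one. The stdlib sketch below uses made-up batch values to compare that figure with a sample-weighted mean. The gap is usually small, but it explains why the two numbers can differ slightly.

```python
def epoch_summary(batches: list[tuple[float, int, int]]) -> dict:
    """Summarise an epoch from (mean_batch_loss, n_correct, batch_size) tuples."""
    if not batches:
        raise ValueError("an epoch needs at least one batch")
    n_samples = sum(size for _, _, size in batches)
    return {
        # What train_epoch reports: every batch counts equally
        "batch_mean_loss": sum(loss for loss, _, _ in batches) / len(batches),
        # Each batch weighted by how many samples it contains
        "sample_mean_loss": sum(loss * size for loss, _, size in batches) / n_samples,
        "accuracy": 100 * sum(correct for _, correct, _ in batches) / n_samples,
    }


# Two full batches of 32 and a final partial batch of 4 (illustrative numbers)
summary = epoch_summary([(0.50, 24, 32), (0.40, 26, 32), (1.00, 2, 4)])
```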

### Validate Epoch

The `validate_epoch` function evaluates the model's performance on validation data during training. It measures metrics like loss, accuracy, and ROC AUC score to track how well the model generalizes to unseen data. This function is crucial for monitoring model performance, detecting overfitting, and determining when to stop training or adjust hyperparameters.

In [None]:
def validate_epoch(model: nn.Module, val_loader: DataLoader, criterion: nn.Module, epoch: int, wandb_run: Run = None, phase: str = 'Validation') -> tuple:
    """
    Validate the model for one epoch.

    Parameters
    ----------
    model : nn.Module
        The model to validate.
    val_loader : DataLoader
        The data loader for the validation data.
    criterion : nn.Module
        The loss function.
    epoch : int
        The current epoch number.
    wandb_run : Run, optional
        The Weights & Biases run object for logging.
    phase : str, optional
        The phase of validation (default is 'Validation').

    Returns
    -------
    tuple
        The average loss and accuracy for the epoch.
    """
    # Run on whichever device the model's parameters live on
    device = next(model.parameters()).device
    model.eval()

    val_loss = 0.0
    val_correct = 0
    val_total = 0
    all_labels = []
    all_probs = []

    val_pbar = tqdm(val_loader, desc=phase, ncols=120)

    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(val_pbar, start=1):
            images, labels = images.to(device), labels.to(device)

            # The model may return a tuple whose first element is the logits
            outputs = model(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Class probabilities; column 1 is the malignant class
            probs = softmax(outputs, dim=1)
            predicted = probs.argmax(dim=1)

            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            all_labels.append(labels.cpu().numpy())
            all_probs.append(probs.cpu().numpy())

            # Running mean over the batches seen so far (correct even for a smaller last batch)
            val_pbar.set_postfix({
                'val_loss': f'{val_loss / batch_idx:.4f}',
                'val_acc': f'{100 * val_correct / val_total:.2f}%'
            })

    avg_val_loss = val_loss / len(val_loader)
    val_accuracy = 100 * val_correct / val_total

    # Flatten arrays
    y_true = np.concatenate(all_labels)
    y_score = np.concatenate(all_probs)

    # Calculate AUC
    auc = roc_auc_score(y_true, y_score[:, 1])

    if wandb_run:
        wandb_run.log({
            f'{phase.lower()}_loss': avg_val_loss,
            f'{phase.lower()}_accuracy': val_accuracy,
            f'{phase.lower()}_auc': auc,
            'epoch': epoch + 1,
            'phase': phase
        })

    return avg_val_loss, val_accuracy
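ROC AUC measures how well the scores rank malignant cases above benign ones. It should therefore be computed from the predicted malignant-class probability (`y_score[:, 1]` above), not from hard 0/1 predictions. The stdlib sketch below computes AUC as the fraction of (malignant, benign) pairs that are ranked correctly, with ties counting as half. For binary labels this pairwise formulation gives the same value as `sklearn.metrics.roc_auc_score`, but it is O(n²) and meant only for illustration. The labels and scores are made up.

```python
def pairwise_auc(y_true: list[int], scores: list[float]) -> float:
    """ROC AUC as the probability that a random positive outranks a random negative."""
    pos = [s for y, s in zip(y_true, scores) if y == 1]
    neg = [s for y, s in zip(y_true, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative example")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # ties count as half a correct ranking
    return wins / (len(pos) * len(neg))


labels = [0, 0, 1, 1, 0, 1]
malignant_probs = [0.10, 0.40, 0.35, 0.80, 0.60, 0.70]
hard_preds = [int(p >= 0.5) for p in malignant_probs]

auc_from_probs = pairwise_auc(labels, malignant_probs)  # 7/9
auc_from_preds = pairwise_auc(labels, hard_preds)       # 6/9: thresholding discards ranking detail
```

Like this sketch, `roc_auc_score` raises a `ValueError` when `y_true` contains only one class. A very small validation split can hit this.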

### Fine-Tuning

Fine-tuning is a critical technique in transfer learning where we take a pre-trained model and selectively retrain some of its layers to adapt it to our specific task. The `fine_tuning()` function implements this technique for skin cancer classification by:

1. **Selective Layer Freezing**: The function freezes every parameter except those in `backbone.layer4`, the attention module and the classifier. The early layers, which have learned general visual features, stay fixed, while the later layers can adapt to skin lesion characteristics.

2. **Differential Learning Rates**: The function uses the base learning rate for the last ResNet stage (`backbone.layer4`) and the attention module, and 5× that rate for the final classification layer. This lets the classification head adapt faster than the pre-trained layers.

This approach reuses the general feature extractors the model has already learned while specialising the deeper layers to distinguish benign from malignant lesions. In transfer learning this typically outperforms both training from scratch and using the pre-trained features unchanged, although the gain should be confirmed on the validation set.

In [None]:
def fine_tuning(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, criterion: nn.Module, epochs: int = 5, patience: int = 5, learning_rate: float = 1e-4, wandb_run: Run = None) -> None:
    """
    Fine-tune the model on the training dataset with early stopping.

    Parameters
    ----------
    model : nn.Module
        The pre-trained model to fine-tune.
    train_loader : DataLoader
        DataLoader for the training dataset.
    val_loader : DataLoader
        DataLoader for the validation dataset.
    criterion : nn.Module
        Loss function to compute the loss.
    epochs : int, optional
        Number of epochs for fine-tuning. Defaults to 5.
    patience : int, optional
        Number of epochs with no improvement after which training will be stopped. Defaults to 5.
    learning_rate : float, optional
        Learning rate for the optimizer. Defaults to 1e-4.
    wandb_run : Run, optional
        Weights & Biases run object for logging metrics. Defaults to None.
    """
    # Initialise the best validation loss and counter for fine-tuning
    best_val_loss = float('inf')
    counter = 0

    # Unfreeze deeper layers for fine-tuning
    for name, param in model.named_parameters():
        if any(layer in name for layer in ['backbone.layer4', 'attention', 'classifier']):
            param.requires_grad = True
        else:
            param.requires_grad = False

    # Define optimiser
    optimiser = torch.optim.Adam([
        {"params": model.backbone.layer4.parameters(), "lr": learning_rate},
        {"params": model.attention.parameters(), "lr": learning_rate},
        {"params": model.classifier.parameters(), "lr": learning_rate * 5}
    ])

    # Fine-tuning loop with early stopping
    for epoch in range(epochs):
        _, _ = train_epoch(model, train_loader, criterion, optimiser, epochs, current_epoch=epoch, wandb_run=wandb_run, phase='Fine-tuning')
        val_loss, _ = validate_epoch(model, val_loader, criterion, epoch, wandb_run=wandb_run, phase='Fine-tuning Validation')

        # Early Stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break
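`fine_tuning()` decides which parameters to unfreeze by substring matching on their names. The stdlib sketch below applies the same rule to a hypothetical set of parameter names and sizes. In the real model, names come from `model.named_parameters()` and sizes from `param.numel()`. The sketch reports which parameters would train and at what learning rate, a quick way to confirm the 5× head multiplier before launching a run.

```python
# Same substring rule and head multiplier as fine_tuning()
TRAINABLE_KEYS = ("backbone.layer4", "attention", "classifier")
HEAD_LR_MULTIPLIER = 5


def fine_tuning_plan(params: dict[str, int], learning_rate: float) -> dict:
    """Map each trainable parameter name to its learning rate and count parameters."""
    plan = {"lr": {}, "trainable": 0, "frozen": 0}
    for name, numel in params.items():
        if any(key in name for key in TRAINABLE_KEYS):
            # Parameters under model.classifier get the higher learning rate
            multiplier = HEAD_LR_MULTIPLIER if name.startswith("classifier") else 1
            plan["lr"][name] = learning_rate * multiplier
            plan["trainable"] += numel
        else:
            plan["frozen"] += numel
    return plan


# Hypothetical names and sizes that only resemble the real model's
example_params = {
    "backbone.conv1.weight": 9408,
    "backbone.layer3.0.conv1.weight": 294912,
    "backbone.layer4.0.conv1.weight": 1179648,
    "attention.fc.weight": 262144,
    "classifier.weight": 1024,
    "classifier.bias": 2,
}
plan = fine_tuning_plan(example_params, learning_rate=1e-4)
```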

### Train Model

The `train_model()` function orchestrates the entire training process for skin cancer classification models. This function:

1. **Model Preparation**:
   - Moves the model to the appropriate device (CPU/GPU)
   - Initialises the Adam optimiser with the specified learning rate
   - When `fine_tune=True`, freezes the pre-trained backbone so that the first phase trains only the remaining layers

2. **Loss Function Configuration**:
   - Uses cross-entropy loss, optionally weighted by class to address dataset imbalance

3. **Training Process Management**:
   - Implements early stopping to prevent overfitting
   - Tracks the best validation loss
   - Logs metrics to Weights & Biases for experiment tracking

4. **Fine-tuning Support**:
   - Optionally fine-tunes the model after the initial training phase
   - Uses a selective layer unfreezing approach with differential learning rates

The function provides a comprehensive framework for training deep learning models on skin lesion datasets, with flexibility for experimenting with various training strategies and hyperparameters.
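With `use_class_weights=True`, the weights come from scikit-learn's `'balanced'` heuristic, `n_samples / (n_classes * count_c)`, so the rarer class contributes more to the loss. Here is a stdlib sketch of that formula, using made-up label counts (six benign, two malignant):

```python
from collections import Counter


def balanced_class_weights(labels: list[int]) -> dict[int, float]:
    """Per-class weights n_samples / (n_classes * count), as in sklearn's 'balanced' mode."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in sorted(counts.items())}


weights = balanced_class_weights([0, 0, 0, 0, 0, 0, 1, 1])
# weights[0] == 8 / (2 * 6) ≈ 0.667 (benign), weights[1] == 8 / (2 * 2) == 2.0 (malignant)
```

If `load_datasets` returns plain `TensorDataset` objects, as described in *Load Datasets*, the training labels can probably be read directly from `train_loader.dataset.tensors[1]`. That would avoid iterating over the loader, which also loads every image.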

In [None]:
def train_model(model: nn.Module, train_loader: DataLoader, val_loader: DataLoader, epochs: int = 5, learning_rate: float = 1e-3, fine_tune: bool = False, use_class_weights: bool = True, wandb_run: Run = None) -> None:
    """
    Train the model, optionally with the backbone frozen first (feature extraction) followed by fine-tuning.

    Parameters
    ----------
    model : nn.Module
        The model to train.
    train_loader : DataLoader
        DataLoader for the training dataset.
    val_loader : DataLoader
        DataLoader for the validation dataset.
    epochs : int, optional
        Number of epochs to train the model. Defaults to 5.
    learning_rate : float, optional
        Learning rate for the optimizer. Defaults to 1e-3.
    fine_tune : bool, optional
        If True, train with the pre-trained backbone frozen, then fine-tune the deeper layers. Defaults to False.
    use_class_weights : bool, optional
        Whether to use class weights in the loss function. Defaults to True.
    wandb_run : Run, optional
        Weights & Biases run object for logging metrics. Defaults to None.
    """

    # Set the device for training
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Move model to device
    model = model.to(device)

    # Initialise the optimiser
    optimiser = optim.Adam(model.parameters(), lr=learning_rate)

    # Set up early stopping parameters
    patience = 5
    best_val_loss = float('inf')
    counter = 0

    # Check if class weights should be used
    if use_class_weights:
        # Extract labels from the train_loader
        all_labels = []
        for _, labels in train_loader:
            all_labels.append(labels.cpu().numpy())
        train_labels = np.concatenate(all_labels)

        # Compute class weights using the labels
        class_weights = compute_class_weight('balanced', classes=np.unique(train_labels), y=train_labels)
        class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

        # Define the loss function with class weights
        criterion = nn.CrossEntropyLoss(weight=class_weights)
    else:
        # Define the loss function without class weights
        criterion = nn.CrossEntropyLoss()

    criterion = criterion.to(device)

    if fine_tune:
        # Feature extraction phase: freeze the pre-trained backbone, train the rest.
        # Parameter names use the `backbone.` prefix, as in fine_tuning().
        for name, param in model.named_parameters():
            param.requires_grad = not name.startswith('backbone.')

    # Training loop
    for epoch in range(epochs):
        _, _ = train_epoch(model, train_loader, criterion, optimiser, epochs, current_epoch=epoch, wandb_run=wandb_run)
        val_loss, _ = validate_epoch(model, val_loader, criterion, epoch, wandb_run=wandb_run)

        # Early Stopping
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                print("Early stopping triggered.")
                break


    # Fine-tuning the model
    if fine_tune:
        print('\n' + '#' * 50)
        print("Starting fine-tuning...")
        print('#' * 50, '\n')

        fine_tuning(
            model,
            train_loader,
            val_loader,
            criterion,
            epochs=epochs,
            patience=patience,
            learning_rate=learning_rate,
            wandb_run=wandb_run
        )
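The early-stopping rule in `train_model()` and `fine_tuning()` resets a counter whenever the validation loss improves. Training stops once the loss has failed to improve for `patience` consecutive epochs. The stdlib sketch below replays that rule over a list of made-up validation losses and returns the 0-based epoch at which training would stop, or `None` if it never does.

```python
from typing import Optional


def early_stop_epoch(val_losses: list[float], patience: int = 5) -> Optional[int]:
    """Return the 0-based epoch where early stopping triggers, or None if it never does."""
    best_val_loss = float("inf")
    counter = 0
    for epoch, val_loss in enumerate(val_losses):
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            counter = 0
        else:
            counter += 1
            if counter >= patience:
                return epoch
    return None


losses = [1.00, 0.80, 0.90, 0.85, 0.95, 0.90, 0.99, 0.70]
stop_at = early_stop_epoch(losses, patience=5)  # stops at epoch 6, before the 0.70 improvement
```

Both loops keep the weights from the final epoch, not the best one. Restoring the best weights would require saving a copy of `model.state_dict()` (for example with `copy.deepcopy`) at each improvement.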

## Save the Model

The `save_model_weights()` function saves the trained model's `state_dict` (its learned parameters) to disk, so the model can be reloaded for inference without retraining.

In [None]:
def save_model_weights(model: nn.Module, model_dir: str, model_name: str) -> None:
    """
    Save the trained model to the specified directory.

    Parameters
    ----------
    model : nn.Module
        The trained model to save.
    model_dir : str
        Directory where the model will be saved.
    model_name : str
        Name of the model file to save (e.g., 'model.pth').
    """
    # Ensure the model directory exists
    if not os.path.exists(model_dir):
        os.makedirs(model_dir)
    model_path = Path(model_dir) / model_name

    # Save the model state dictionary
    torch.save(model.state_dict(), model_path)
    print(f"Model saved to {model_path}")

## Test the Model

The `test_model()` function evaluates the performance of a trained deep learning model on unseen data. This function:

1. **Performs Inference on Test Data**:
   - Sets the model to evaluation mode to disable dropout and batch normalisation updates
   - Processes the test dataset batch by batch
   - Collects model predictions and ground truth labels

2. **Calculates Performance Metrics**:
   - Classification report with precision, recall, and F1-score for each class
   - Confusion matrix to visualise true positives, false positives, true negatives, and false negatives
   - ROC AUC score to evaluate the model's discriminative ability

3. **Visualisation and Reporting**:
   - Generates a visual heatmap of the confusion matrix
   - Formats and prints classification metrics for easy interpretation

4. **Optional Experiment Tracking**:
   - When provided with a Weights & Biases run object, logs all metrics and visualisations
   - Supports detailed experiment comparison and model versioning

This comprehensive evaluation helps determine the model's effectiveness at identifying malignant skin lesions and provides insights for potential improvements.



In [None]:
def test_model(model: nn.Module, test_loader: DataLoader, wandb_run: Run = None) -> None:
    """
    Test the model on the test dataset and log results.

    Parameters
    ----------
    model : nn.Module
        The trained model to evaluate.
    test_loader : DataLoader
        DataLoader for the test dataset.
    wandb_run : Run, optional
        Weights & Biases run object for logging. Defaults to None.
    """
    # Run on whichever device the model's parameters live on
    device = next(model.parameters()).device
    class_names = ['Benign', 'Malignant']

    # Initialise lists to store predictions, malignant-class probabilities and labels
    all_preds = []
    all_probs = []
    all_labels = []

    # Evaluate the model on the test set
    model.eval()
    with torch.no_grad():
        for images, labels in test_loader:
            images = images.to(device)

            # The model may return a tuple whose first element is the logits
            outputs = model(images)
            if isinstance(outputs, tuple):
                outputs = outputs[0]

            # Class probabilities and hard predictions
            probs = softmax(outputs, dim=1)
            preds = probs.argmax(dim=1)

            # Store predictions, malignant probabilities and labels as plain Python values
            all_preds.extend(preds.cpu().tolist())
            all_probs.extend(probs[:, 1].cpu().tolist())
            all_labels.extend(labels.tolist())

    # Classification report (printed, and kept as a dict for logging)
    print(classification_report(all_labels, all_preds, target_names=class_names))
    report = classification_report(all_labels, all_preds, target_names=class_names, output_dict=True)

    # Compute confusion matrix
    cm = confusion_matrix(all_labels, all_preds)

    # ROC AUC needs scores (malignant probabilities), not hard predictions
    auc = roc_auc_score(all_labels, all_probs)
    print('ROC AUC Score:', auc)

    # Log metrics and confusion matrix to Weights & Biases
    if wandb_run:
        wandb_run.log({
            'classification_report': wandb.Table(dataframe=pd.DataFrame(report).transpose()),
            'confusion_matrix': wandb.plot.confusion_matrix(
                probs=None,
                y_true=all_labels,
                preds=all_preds,
                class_names=class_names
            ),
            'accuracy': 100 * np.mean(np.array(all_preds) == np.array(all_labels)),
            'roc_auc_score': auc
        })

    # Plot confusion matrix
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
                xticklabels=['Benign', 'Malignant'],
                yticklabels=['Benign', 'Malignant'])
    plt.xlabel('Predicted')
    plt.ylabel('True label')
    plt.title('Confusion Matrix')
    plt.show()
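For a screening task, per-class recall matters more than overall accuracy. Recall on the malignant class (sensitivity) is the share of malignant lesions the model detects. The stdlib sketch below derives sensitivity, specificity and precision from the 2×2 layout that `confusion_matrix` returns for labels 0 (benign) and 1 (malignant), with rows as true labels and columns as predictions. The labels are made up.

```python
def _ratio(num: int, den: int) -> float:
    """Safe division that returns 0.0 when the denominator is zero."""
    return num / den if den else 0.0


def binary_metrics(y_true: list[int], y_pred: list[int]) -> dict:
    """Confusion matrix [[TN, FP], [FN, TP]] and derived rates, treating 1 as malignant."""
    tn = sum(t == 0 and p == 0 for t, p in zip(y_true, y_pred))
    fp = sum(t == 0 and p == 1 for t, p in zip(y_true, y_pred))
    fn = sum(t == 1 and p == 0 for t, p in zip(y_true, y_pred))
    tp = sum(t == 1 and p == 1 for t, p in zip(y_true, y_pred))
    return {
        "confusion_matrix": [[tn, fp], [fn, tp]],
        "sensitivity": _ratio(tp, tp + fn),  # recall of the malignant class
        "specificity": _ratio(tn, tn + fp),  # recall of the benign class
        "precision": _ratio(tp, tp + fp),    # share of malignant predictions that were correct
    }


metrics = binary_metrics([0, 0, 0, 0, 1, 1], [0, 0, 0, 1, 1, 0])
# confusion_matrix [[3, 1], [1, 1]], sensitivity 0.5, specificity 0.75, precision 0.5
```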

## Create an Experiment

The `create_wandb_experiment()` function integrates our skin cancer detection model with Weights & Biases (W&B) for experiment tracking. This function:

1. **Initialises W&B Experiment**: Creates a new experiment run with the provided name
2. **Configures Experiment Parameters**: Registers hyperparameters and settings in the W&B dashboard
3. **Establishes Project Context**: Links the experiment to our skin cancer detection project
4. **Handles Authentication**: `wandb` reads the `WANDB_API_KEY` environment variable, which `load_dotenv()` populated from `.env`, so the key does not need to be passed explicitly
5. **Returns Control Object**: Provides a run object for logging metrics throughout model training

This integration enables systematic tracking of model performance, visualisation of results, and comparison of different experimental configurations, which is essential for methodical research in deep learning applications for medical image analysis.

In [None]:
def create_wandb_experiment(name: str, config: dict) -> Run:
    """
    Create a Weights & Biases run for logging the model's metrics and artifacts.

    Parameters
    ----------
    name : str
        Name of the experiment run.
    config : dict
        Configuration dictionary containing hyperparameters and other settings.

    Returns
    -------
    Run
        A Weights & Biases run object for logging metrics and artifacts.
    """

    return wandb.init(
        # Set the wandb entity and project name
        entity=WANDB_ENTITY,
        project=WANDB_PROJECT,

        # Set the name of the run
        name=name,

        # Set the configuration for the run
        config=config,
    )

## Experiments

The `experiment()` function runs one complete experiment for a given configuration. It creates a W&B run, trains the model, evaluates it on the test set, optionally saves its weights, and then closes the run. Running it with different configurations lets you compare them side by side in W&B and choose the one that best balances accuracy, generalisation and computational cost.

In [None]:
from typing import Optional


def experiment(model: nn.Module, data_loaders: tuple, hyperparameters: Optional[dict] = None, experiment_name: str = 'experiment', save_model: bool = False, save_model_as: str = 'model.pth') -> None:
    """
    Run a complete skin cancer classification experiment: train, test and optionally save the model.

    Parameters
    ----------
    model: nn.Module
        Model to experiment with.
    data_loaders : tuple
        Tuple containing DataLoader objects for training, validation, and testing datasets.
    hyperparameters : dict, optional
        Dictionary containing hyperparameters for training. Missing keys fall back to the
        defaults of `train_model()`. Defaults to None.
    experiment_name : str, optional
        Name of the experiment for logging purposes. Defaults to 'experiment'.
    save_model : bool, optional
        Whether to save the trained model. Defaults to False.
    save_model_as : str, optional
        Name of the file to save the model as. Defaults to 'model.pth'.
    """
    # Copy to avoid mutating the caller's dict (and a shared mutable default)
    hyperparameters = dict(hyperparameters or {})

    # Create a Weights & Biases experiment
    run = create_wandb_experiment(
        name=experiment_name,
        config=hyperparameters
    )

    try:
        # Unpack the data loaders
        train_loader, val_loader, test_loader = data_loaders

        # Train the model
        print('#' * 50)
        print("Training the model...")
        print('#' * 50, '\n')
        train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            epochs=hyperparameters.get('epochs', 5),
            learning_rate=hyperparameters.get('learning_rate', 1e-3),
            use_class_weights=hyperparameters.get('use_class_weights', True),
            fine_tune=hyperparameters.get('fine_tune', False),
            wandb_run=run,
        )

        # Test the model
        print('\n\n' + '#' * 50)
        print("Testing the model...")
        print('#' * 50, '\n')
        test_model(
            model=model,
            test_loader=test_loader,
            wandb_run=run
        )

        # Save the model
        if save_model:
            print('#' * 50)
            print("Saving the model...")
            print('#' * 50, '\n')
            save_model_weights(
                model=model,
                model_dir=model_dir,
                model_name=save_model_as
            )
    finally:
        # Finish the Weights & Biases run, even if an exception was raised
        run.finish()

## Model Architecture Analysis

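In this section, we train and evaluate `SkinCancerCNN`, which builds on a pre-trained backbone, under a chosen configuration. Changing the `hyperparameters` dictionary below (for example, enabling `fine_tune`) lets us compare configurations in W&B.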

In [None]:
model = SkinCancerCNN()

# Initialise the constants
experiment_name = 'ResNet_skin_cancer_classification'

hyperparameters = {
    'epochs': 10,
    'learning_rate': 1e-4,
    'batch_size': batch_size,
    'use_class_weights': True,
    'fine_tune': False,
    'type': 'aug-training'
}

# Perform the experiment
try:
    experiment(
        model=model,
        data_loaders=(train_loader, val_loader, test_loader),
        hyperparameters=hyperparameters,
        experiment_name=experiment_name,
        save_model=True,
        save_model_as=f'{experiment_name}.pth'
    )
except Exception as e:
    print(f"An error occurred during the experiment: {e}")

## Hyperparameter Tuning

Hyperparameter tuning is a critical step in optimising models. We use Weights & Biases Sweeps to explore hyperparameter combinations systematically and find those that improve model performance.

### Sweep Configuration

We load a predefined sweep configuration from `configs/sweep.yaml` that specifies the hyperparameter search space, including learning rates, batch sizes, and number of epochs.

In [None]:
sweep_config_path = root_dir / 'configs' / 'sweep.yaml'
with open(sweep_config_path, 'r') as file:
    sweep_config = yaml.safe_load(file)
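The contents of `configs/sweep.yaml` are not reproduced in this notebook. As a hypothetical example, a config with the top-level keys W&B sweeps expect (`method`, `metric`, `parameters`) could look like the Python dict below once loaded. The metric `validation_loss` is chosen because `validate_epoch()` logs that key during the main training phase. The parameter values are illustrative. The helper counts how many runs an exhaustive grid would need, which can guide the agent's `count`.

```python
from math import prod

# Hypothetical sweep definition; the real one lives in configs/sweep.yaml
example_sweep_config = {
    "method": "random",  # W&B also supports "grid" and "bayes"
    "metric": {"name": "validation_loss", "goal": "minimize"},
    "parameters": {
        "learning_rate": {"values": [1e-3, 3e-4, 1e-4]},
        "batch_size": {"values": [16, 32]},
        "epochs": {"values": [5, 10]},
    },
}


def grid_size(sweep_config: dict) -> int:
    """Number of runs an exhaustive grid over the listed 'values' would require."""
    return prod(len(spec["values"]) for spec in sweep_config["parameters"].values())


n_combinations = grid_size(example_sweep_config)  # 3 * 2 * 2 = 12
```

With `method: random`, an agent limited to `count=5` runs tries only five sampled configurations instead of all twelve.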

### Training Function

The `train_sweep()` function initialises a W&B run for each hyperparameter combination, trains the model with those parameters, and logs performance metrics.

In [None]:
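def train_sweep() -> None:
    """Train one model with the hyperparameters chosen by the W&B sweep agent."""
    # Inside a sweep, wandb.init() receives the sampled hyperparameters via run.config.
    # The run name is left to W&B so that each run gets a distinct one.
    run = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
    )

    try:
        config = run.config

        # Rebuild the training loader in case the sweep varies the batch size
        sweep_batch_size = config.get('batch_size', batch_size)
        sweep_train_loader = DataLoader(train_tensor_dataset, batch_size=sweep_batch_size, shuffle=True)

        train_model(
            model=SkinCancerCNN(),
            train_loader=sweep_train_loader,
            val_loader=val_loader,
            epochs=config['epochs'],
            learning_rate=config['learning_rate'],
            use_class_weights=True,
            fine_tune=False,
            wandb_run=run,
        )
    finally:
        run.finish()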


### Sweep Agent

We run multiple training experiments automatically using the W&B agent, which selects hyperparameter combinations according to the search strategy defined in the sweep configuration.

In [None]:
# Run wandb sweep command using the sweep configuration
sweep_id = wandb.sweep(sweep_config, entity=WANDB_ENTITY, project=WANDB_PROJECT)

# Start the sweep agent to run the training function
wandb.agent(sweep_id, function=train_sweep, count=5)
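Under `method: random`, each agent run draws one value per parameter independently, and `count=5` caps the number of runs. The sketch below imitates that sampling with Python's `random` module so the idea can be inspected offline. It illustrates random search and is not W&B's implementation. The parameter space is hypothetical, and the seed reuses the notebook's `random_state` value of 42.

```python
import random


def sample_configs(parameters: dict[str, list], count: int, seed: int = 42) -> list[dict]:
    """Draw `count` configurations, choosing each parameter value uniformly at random."""
    rng = random.Random(seed)  # seeded for reproducibility
    return [
        {name: rng.choice(values) for name, values in parameters.items()}
        for _ in range(count)
    ]


space = {"learning_rate": [1e-3, 3e-4, 1e-4], "batch_size": [16, 32], "epochs": [5, 10]}
runs = sample_configs(space, count=5)
```

Unlike a grid search, random sampling can repeat a combination and can skip others. It is still a common choice when the full grid is too large to run.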