## 5. Training loop

The training loop below iteratively adjusts the model's parameters to minimize the loss. For a specified number of epochs it runs a training phase followed by a validation phase, so the model's performance on held-out data is measured after every epoch.

#### Key Steps in the Training Loop:

1. **Epoch Loop**: The outer loop runs for a specified number of epochs, where each epoch signifies a complete pass through the training dataset.

2. **Training Phase**:
   - The model is set to training mode using `model.train()`, which ensures that layers like dropout and batch normalization operate in training mode.
   - For each batch in the `train_loader`, the inputs and labels are transferred to the appropriate device (GPU or CPU).
   - Gradients are zeroed using `optimizer.zero_grad()` to prevent accumulation from previous iterations.
   - The model makes predictions which are compared against the true labels using the loss function (`criterion`).
   - Backpropagation is used to compute gradients, and the optimizer updates the model parameters.

3. **Loss Calculation**:
   - The loss of each batch is added to `running_loss`, which is divided by the number of batches at the end of the epoch. The resulting average per-batch loss is printed to monitor training progress.

4. **Validation Phase**:
   - The model is switched to evaluation mode with `model.eval()`, so that dropout is disabled and batch normalization uses its running statistics instead of per-batch statistics.
   - The validation loop computes the model's accuracy on the validation dataset without updating the model parameters. Wrapping it in `torch.no_grad()` disables gradient tracking, which saves memory and computation.
   - The validation accuracy is printed to assess how well the model generalizes to unseen data.
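
The three behaviours the steps above rely on (gradient accumulation, mode-dependent layers, and disabled gradient tracking) can be checked in isolation. The following is a minimal sketch on toy tensors, not part of the notebook's pipeline; the names `w`, `layer`, and `x` are made up for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# 1) Gradients accumulate across backward() calls unless they are reset.
w = torch.ones(3, requires_grad=True)
(2 * w).sum().backward()
first = w.grad.clone()        # gradient of sum(2w) is 2 for every element
(2 * w).sum().backward()
accumulated = w.grad.clone()  # the second backward() was added to the first
w.grad.zero_()                # manual reset; optimizer.zero_grad() clears the
                              # gradients of the optimizer's parameters

# 2) Dropout is active in training mode and an identity in evaluation mode.
layer = nn.Dropout(p=0.5)
x = torch.ones(1000)
layer.train()
train_out = layer(x)          # some entries zeroed, survivors scaled by 1 / (1 - p)
layer.eval()
eval_out = layer(x)           # input passes through unchanged

# 3) Results computed inside torch.no_grad() are not tracked by autograd.
with torch.no_grad():
    y = (2 * w).sum()
tracked = y.requires_grad
```

Skipping the reset in step 1 would make every optimizer step use the sum of all gradients seen so far, which is why `optimizer.zero_grad()` is called at the start of each batch.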

In [1]:
import torch
from tqdm import tqdm

def train(model, criterion, optimizer, train_loader, val_loader, num_epochs=5):
    """
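    Train the model and report validation accuracy after each epoch.

    Expects a global `device` (a `torch.device`) to be defined before calling.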
    
    Args:
        model: The model to train.
        criterion: The loss function.
        optimizer: The optimizer.
        train_loader: DataLoader for the training data.
        val_loader: DataLoader for the validation data.
        num_epochs (int): Number of epochs to train.
    
    Returns:
        model: The trained model.
    """
    for epoch in range(num_epochs):
        # Set model to training mode
        model.train()
        running_loss = 0.0
        for inputs, labels in tqdm(train_loader):
            # Move data to the appropriate device
            inputs, labels = inputs.to(device), labels.to(device)
            
            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass to get model outputs
            outputs = model(inputs)

            # Compute the loss
            loss = criterion(outputs, labels)
            # Backward pass to compute gradients
            loss.backward()
            # Update model parameters
            optimizer.step()

            # Accumulate the running loss
            running_loss += loss.item()
        
        epoch_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
        
        # Validation phase
        # Set the model to evaluation mode
        model.eval()
        correct = 0
        total = 0
        # Disable gradient computation for validation
        with torch.no_grad():
            for data in val_loader:
                images, labels = data
                # Move validation data to the appropriate device
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                # Get the predicted class (index of the largest logit)
                _, predicted = torch.max(outputs, 1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        print(f"Validation accuracy: {100 * correct / total:.2f}%")

    return model
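
To show what `train` expects as arguments, here is a minimal setup sketch. Everything in it is an assumption made for illustration: the random tensors stand in for a real dataset, the one-layer classifier stands in for the actual model, and the `device` definition shows one common way to choose between GPU and CPU. It only wires the pieces together and checks a single batch.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# `train` reads `device` from the enclosing scope, so it must exist first.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

torch.manual_seed(0)

# Hypothetical stand-in data: 3x8x8 "images" with labels from 4 classes.
train_ds = TensorDataset(torch.randn(64, 3, 8, 8), torch.randint(0, 4, (64,)))
val_ds = TensorDataset(torch.randn(32, 3, 8, 8), torch.randint(0, 4, (32,)))
train_loader = DataLoader(train_ds, batch_size=16, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=16)

# Hypothetical toy classifier; any nn.Module returning (batch, num_classes)
# logits fits the loop.
model = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4)).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Sanity-check the wiring on one batch before starting a full run.
inputs, labels = next(iter(train_loader))
logits = model(inputs.to(device))
loss = criterion(logits, labels.to(device))

# With `train` from the cell above in scope:
# model = train(model, criterion, optimizer, train_loader, val_loader, num_epochs=2)
```

Checking one batch first catches shape and device mismatches cheaply, before a full epoch is spent on them.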

## 6. Inference

The inference function predicts the class label of a single image using a trained model. It preprocesses the image, passes it through the model, and returns the class with the largest output as the prediction.

#### Key Steps in the Inference Process:

1. **Image Preprocessing**:
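   - The image is preprocessed to match the input requirements of the model: it is resized to 256 × 256, converted to a PyTorch tensor, and normalized per channel with the ImageNet mean and standard deviation. A batch dimension is then added, because the model expects a batch of images.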

2. **Model Evaluation**:
   - The model is set to evaluation mode using `model.eval()`, which ensures that layers like dropout and batch normalization behave appropriately during inference.

3. **Prediction**:
   - The preprocessed image is passed through the model to obtain the output logits.
   - The `torch.max` function returns the index of the largest logit, which corresponds to the class with the highest predicted probability. This index is returned as the predicted label.
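
The relationship between logits and probabilities in the prediction step can be sketched as follows. The logit values are made up for illustration. Because softmax preserves the ordering within each row, taking the maximum over logits selects the same class as taking it over probabilities; softmax is only needed when a confidence value is wanted.

```python
import torch

# Hypothetical logits for a batch of 2 images and 3 classes.
logits = torch.tensor([[1.0, 3.0, 0.5],
                       [2.0, -1.0, 2.5]])

# Softmax turns each row of logits into probabilities that sum to 1.
probs = torch.softmax(logits, dim=1)

# torch.max along dim=1 returns (values, indices); the indices are class ids.
_, pred_from_logits = torch.max(logits, 1)
confidence, pred_from_probs = torch.max(probs, 1)

# argmax is a shorter way to get only the indices.
pred_argmax = logits.argmax(dim=1)
```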

In [2]:
import torch
import torchvision.transforms as T
from PIL import Image

def predict_image(model, image_path):
    """
    Predict the class of a sample image.
    
    Args:
        model: The trained model.
        image_path (str): Path to the image to predict.
    
    Returns:
        int: Predicted class label.
    """
    transform = T.Compose([
        T.Resize((256, 256)),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = Image.open(image_path).convert("RGB")
    # Apply the transformations and add a batch dimension
    image = transform(image).unsqueeze(0)
    image = image.to(device)

    model.eval()  # Set the model to evaluation mode
    with torch.no_grad():
        outputs = model(image)
        _, predicted = torch.max(outputs, 1)
        return predicted.item()