# Intro to Machine Learning for Plant Sciences (ML4PS2025) - Deep Learning lab

## Introduction

Python has multiple high-quality libraries for Deep Learning: **torch**, **tensorflow**, **keras**.

These libraries are standard, documented, fairly easy to use and highly optimized. Within their ecosystem, they cover all components and aspects of the trade: data processing and loading, model design and training, optimization and parallelization, testing and evaluating.

Here are some libraries within this ecosystem that may be useful:
- **numpy** for handling data
- **pandas** for working with datasets
- **scipy** for optimization and maths problems
- **rasterio** for handling raster data (e.g., satellite imagery)
- **sklearn** for metrics and testing
- **wandb** (Weights & Biases) for hyperparameter tuning and keeping track of trained models
- **lightning** (on top of **pytorch**) for scalability and deployment with **pytorch**
... and so many more!




In [None]:
import urllib
import matplotlib.pyplot as plt

import torch
import torchvision
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torchvision import models
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet

import sklearn
from sklearn.metrics import accuracy_score

### Training a Computer Vision model

Simple classification framework


In [None]:
# Function definitions
def evaluate_model(model, loader):
    """Run `model` over every batch of `loader` and collect the results.

    The model is switched to evaluation mode and gradients are disabled.
    For each batch, the predicted class is the index of the largest model
    output along dimension 1.

    Returns:
        A pair ``(true_labels, predictions)`` of flat lists, in loader order.
    """
    model.eval()
    true_labels, predictions = [], []
    with torch.no_grad():
        for batch_inputs, batch_labels in loader:
            logits = model(batch_inputs)
            predicted_classes = torch.max(logits, 1).indices
            predictions += list(predicted_classes.cpu().numpy())
            true_labels += list(batch_labels.cpu().numpy())
    return true_labels, predictions

In [None]:
# Normalization, step 1: build a throw-away pipeline that only resizes the
# images and converts them to tensors, so that the channel statistics can be
# measured on un-normalized data.
temp_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# download=True fetches the dataset into ./data if it is not there yet.
dataset_temp = OxfordIIITPet(root='./data', split='trainval', download=True, transform=temp_transform) # Temporary: only used to measure statistics
temp_loader = DataLoader(dataset_temp, batch_size=32, shuffle=False) # Temporary: order is irrelevant, so no shuffling

def compute_mean_std(loader):
    """Compute the per-channel mean and standard deviation of a dataset.

    The statistics are pooled over every pixel of every image (population
    standard deviation), which is what ``transforms.Normalize`` needs to
    bring each channel to zero mean and unit variance. Averaging per-image
    standard deviations instead underestimates the spread as soon as image
    brightness varies from one image to the next.

    Args:
        loader: Iterable yielding ``(images, _)`` batches, where ``images``
            is a tensor whose first dimension is the batch and whose second
            dimension is the channel.

    Returns:
        ``(mean, std)``: two float32 tensors with one entry per channel.

    Raises:
        ValueError: If the loader yields no pixels at all.
    """
    channel_sum = None
    channel_sq_sum = None
    pixels_per_channel = 0

    for images, _ in loader:
        batch_samples = images.size(0)
        # Flatten the spatial dimensions; accumulate in float64 so that the
        # E[x^2] - E[x]^2 formula stays accurate over many batches.
        flat = images.reshape(batch_samples, images.size(1), -1).to(torch.float64)
        batch_sum = flat.sum(dim=(0, 2))
        batch_sq_sum = (flat ** 2).sum(dim=(0, 2))
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum = channel_sum + batch_sum
            channel_sq_sum = channel_sq_sum + batch_sq_sum
        pixels_per_channel += batch_samples * flat.size(2)

    if pixels_per_channel == 0:
        raise ValueError("compute_mean_std() received an empty loader")

    mean = channel_sum / pixels_per_channel
    # Clamp guards against tiny negative values caused by rounding.
    variance = (channel_sq_sum / pixels_per_channel - mean ** 2).clamp(min=0)
    std = variance.sqrt()
    return mean.float(), std.float()

# Normalization, step 2: measure the channel statistics on the temporary loader...
mean, std = compute_mean_std(temp_loader)

# ...and build the transform used for the rest of the notebook: same resize
# and tensor conversion as before, followed by per-channel normalization.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean.tolist(), std.tolist())  # Normalize is given plain Python lists
])

In [None]:
# Train / validation / test split
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset

# Load official datasets (already downloaded by the normalization cell)
trainval_dataset = OxfordIIITPet(root='./data', split='trainval', transform=transform)
test_dataset = OxfordIIITPet(root='./data', split='test', transform=transform)
class_names = trainval_dataset.classes

# Split the sample *indices* rather than the dataset object itself: passing
# the dataset to train_test_split makes sklearn index every sample, which
# decodes and stores all images in memory at once. Wrapping the index lists
# in Subset keeps the images lazily loaded, one batch at a time.
train_indices, val_indices = train_test_split(
    list(range(len(trainval_dataset))), test_size=0.2, random_state=42
)
train_dataset = Subset(trainval_dataset, train_indices)
val_dataset = Subset(trainval_dataset, val_indices)

# Create DataLoaders (only the training data is shuffled)
train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)


In [None]:
# Statistics: number of classes and size of each split.
# num_classes is reused later when the classifier heads are created.
num_classes = len(class_names)
print(f"{len(train_dataset)} images (in train set) - {len(val_dataset)} images (in val set) - {len(test_dataset)} images (in test set) \n{num_classes} classes")

# Check dataset balance
def plot_class_distribution(dataset, class_names):
    """Show a bar chart with the number of samples of `dataset` per class.

    `dataset` is iterated as ``(sample, label)`` pairs; every label is used
    as an index into a list of counters that has one slot per entry of
    `class_names`.
    """
    samples_per_class = [0 for _ in class_names]
    for _sample, target in dataset:
        samples_per_class[target] += 1

    plt.figure(figsize=(10, 5))
    plt.bar(class_names, samples_per_class)
    plt.title('Class Distribution')
    plt.xlabel('Classes')
    plt.ylabel('Number of images')
    # Class names are long, so print them vertically
    plt.xticks(rotation=90)
    plt.show()

# Check that the training split is reasonably balanced across classes
plot_class_distribution(train_dataset, class_names)

#### Training

In [None]:
# Simple CNN model
class SimpleCNN(torch.nn.Module):
    """Minimal CNN classifier: one convolution, one max-pool, one linear layer.

    Args:
        num_classes: Number of output scores (one per class).
        image_size: Height and width of the square RGB input images.
            Defaults to 224, which gives the original ``16 * 112 * 112``
            input features for the linear layer.
    """

    def __init__(self, num_classes, image_size=224):
        super().__init__()
        self.conv1 = torch.nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
        #self.conv2 = torch.nn.Conv2d(16, 32, kernel_size=3, padding=1)
        self.pool = torch.nn.MaxPool2d(kernel_size=2, stride=2, padding=0)
        # conv1 keeps the spatial size (3x3 kernel, stride 1, padding 1) and
        # the 2x2 pooling halves it. If conv2 is enabled, this size has to be
        # updated (32 channels, one more halving).
        pooled_size = image_size // 2
        self.fc1 = torch.nn.Linear(16 * pooled_size * pooled_size, num_classes)

    def forward(self, x):
        """Return the raw class scores for a batch of images."""
        x = self.pool(F.relu(self.conv1(x)))
        #x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1)  # keep the batch dimension, flatten the rest
        x = self.fc1(x)
        return x
    
# One output score per class of the dataset
model = SimpleCNN(num_classes=len(class_names)) # Model instantiation

# Display the architecture of the model (printing a Module lists its layers)
print(model)

In [None]:
# Training: fit the SimpleCNN with SGD and track the mean batch loss per epoch
num_epochs = 10
criterion = torch.nn.CrossEntropyLoss()  # Loss function
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9)  # Optimizer

train_losses = []
val_losses = []
for epoch in range(num_epochs):
    # --- one optimisation pass over the training batches ---
    model.train()
    running_loss = 0.0
    for batch_inputs, batch_labels in train_loader:
        optimizer.zero_grad()  # gradients accumulate by default, reset them first
        batch_loss = criterion(model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)

    # --- validation pass: no weight update, no gradient tracking ---
    model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            val_loss += criterion(model(batch_inputs), batch_labels).item()
    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}')
print('Training complete')

# Plot training and validation loss on the same axes
epoch_axis = range(num_epochs)
for loss_curve, curve_label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    plt.plot(epoch_axis, loss_curve, label=curve_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# Testing: compare accuracy on data the model has seen (train) with data it
# has never seen (test). evaluate_model returns (true labels, predictions),
# which is the argument order accuracy_score is called with.
all_labels, all_preds = evaluate_model(model, train_loader)
print(f"Accuracy on train set: {accuracy_score(all_labels, all_preds)}")

all_labels, all_preds = evaluate_model(model, test_loader)
print(f"Accuracy on test set: {accuracy_score(all_labels, all_preds)}")

In [None]:
# TODO improve the model

# TODO change the learning rate, number of epochs, batch size

# TODO think about overfitting: regularization, dropout, data augmentation, early stopping

# TODO add more metrics: confusion matrix, F1-score

#### A more realistic use case: using a pretrained model

In [None]:
# Load pretrained ResNet18 (weights are selected by name)
resnet_model = models.resnet18(weights="ResNet18_Weights.IMAGENET1K_V1") # Model instantiation
# Replace the final fully connected layer with a new one that keeps the same
# input width but outputs one score per class of this dataset.
resnet_model.fc = torch.nn.Linear(resnet_model.fc.in_features, num_classes) # Adapt last layer to current use case

# Display the architecture of the model
print(resnet_model)

In [None]:
# Training: fine-tune the ResNet with AdamW and track the mean batch loss per epoch
num_epochs = 10
criterion = torch.nn.CrossEntropyLoss()  # Loss function
optimizer = torch.optim.AdamW(resnet_model.parameters(), lr=1e-4, weight_decay=1e-4)  # Optimizer

train_losses = []
val_losses = []
for epoch in range(num_epochs):
    # --- one optimisation pass over the training batches ---
    resnet_model.train()
    running_loss = 0.0
    for batch_inputs, batch_labels in train_loader:
        optimizer.zero_grad()  # gradients accumulate by default, reset them first
        batch_loss = criterion(resnet_model(batch_inputs), batch_labels)
        batch_loss.backward()
        optimizer.step()
        running_loss += batch_loss.item()
    epoch_loss = running_loss / len(train_loader)
    train_losses.append(epoch_loss)

    # --- validation pass: no weight update, no gradient tracking ---
    resnet_model.eval()
    val_loss = 0.0
    with torch.no_grad():
        for batch_inputs, batch_labels in val_loader:
            val_loss += criterion(resnet_model(batch_inputs), batch_labels).item()
    val_loss /= len(val_loader)
    val_losses.append(val_loss)

    print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {epoch_loss:.4f} - Val Loss: {val_loss:.4f}')
print('Training complete')

# Plot training and validation loss on the same axes
epoch_axis = range(num_epochs)
for loss_curve, curve_label in ((train_losses, 'Train Loss'), (val_losses, 'Val Loss')):
    plt.plot(epoch_axis, loss_curve, label=curve_label)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()
plt.show()

In [None]:
# TODO evaluate the ResNet model

In [None]:
# TODO Try freezing all layers but the classifier head and retrain
if False:
    # Freeze every parameter of the pretrained ResNet...
    for param in resnet_model.parameters():
        param.requires_grad = False

    # ...then install a fresh head: the parameters of a newly created layer
    # require gradients, so the head is the only part left to train.
    resnet_model.fc = torch.nn.Linear(resnet_model.fc.in_features, num_classes)
    # Re-run the ResNet training cell afterwards so that the optimizer is
    # rebuilt and picks up the parameters of the new head.

#### So what is actually going on?

In [None]:
# TODO check the output of the first layer

In [None]:
# TODO dive into class wise metrics

### Doing image segmentation

Try doing it all yourself on an image segmentation task

In [None]:
from torchvision import transforms
from torchvision.datasets import OxfordIIITPet
from torch.utils.data import DataLoader
import segmentation_models_pytorch as smp

# Per-channel normalization statistics of ImageNet
IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
IMAGE_NET_STD = [0.229, 0.224, 0.225]

def denormalize(img):
    """Undo the ImageNet normalization of an image tensor for display.

    Takes a channel-first, 3-channel image and returns it channel-last
    (the layout ``imshow`` expects), with values clipped to [0, 1].
    """
    channel_mean = torch.tensor(IMAGE_NET_MEAN)[:, None, None]
    channel_std = torch.tensor(IMAGE_NET_STD)[:, None, None]
    restored = img * channel_std + channel_mean
    return restored.permute(1, 2, 0).clamp(0, 1)

In [None]:
# Transforms for images and masks
img_transform = transforms.Compose([
    transforms.Resize((128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGE_NET_MEAN, IMAGE_NET_STD)
])

# Masks store class ids, not intensities: resize them with nearest-neighbour
# interpolation. The default (bilinear) blends neighbouring ids into values
# that no longer correspond to the original labels.
mask_transform = transforms.Compose([
    transforms.Resize((128, 128), interpolation=transforms.InterpolationMode.NEAREST),
    transforms.PILToTensor(),  # keeps the integer ids (no rescaling to [0, 1])
])

# Load dataset (with masks). Only the image goes through `transform` here;
# the mask is returned as-is and converted later with mask_transform.
dataset = OxfordIIITPet(
    root='./data',
    download=True,
    transform=img_transform,
    target_types='segmentation'
)

def collate_fn(batch):
    """Assemble a list of ``(image, mask)`` samples into two batched tensors.

    Images are stacked as they are; every mask first goes through
    ``mask_transform`` and the results are then stacked as well.
    """
    stacked_images = torch.stack([sample[0] for sample in batch])
    converted_masks = []
    for sample in batch:
        converted_masks.append(mask_transform(sample[1]))
    return stacked_images, torch.stack(converted_masks)

# The custom collate_fn is what converts and batches the masks
loader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=collate_fn)

# Load pretrained U-Net (encoder pretrained on ImageNet)
# NOTE: this rebinds `model`, which referred to the SimpleCNN in earlier cells.
# classes=1 requests a single output channel.
model = smp.Unet(encoder_name="resnet18", encoder_weights="imagenet", in_channels=3, classes=1)

In [None]:
# Visualize some predictions
model.eval()

images, masks = next(iter(loader))
# Convert masks to 0/1 (background vs pet). The masks are trimaps with ids
# 1 = pet, 2 = background, 3 = border (NOTE: confirm against the dataset's
# annotation README), so dividing by the maximum would give three grey
# levels rather than a binary mask.
masks = (masks == 1).float()
with torch.no_grad():
    images_resized = torch.nn.functional.interpolate(images, size=(128,128))
    outputs = model(images_resized)
    preds = torch.sigmoid(outputs)  # map the raw outputs to [0, 1]

# One row per image of the batch: input, ground truth, prediction.
# squeeze=False keeps `axes` two-dimensional even for a single-image batch.
num_rows = images.size(0)
fig, axes = plt.subplots(num_rows, 3, figsize=(12, 3 * num_rows), squeeze=False)
for i in range(num_rows):
    axes[i,0].imshow(denormalize(images[i]))
    axes[i,0].set_title("Image")
    axes[i,0].axis('off')

    axes[i,1].imshow(masks[i][0], cmap='gray')
    axes[i,1].set_title("True Mask")
    axes[i,1].axis('off')

    axes[i,2].imshow(preds[i][0], cmap='gray')
    axes[i,2].set_title("Pred Mask")
    axes[i,2].axis('off')

plt.tight_layout()
plt.show()

In [None]:
images, masks = next(iter(loader))

# Count how often each mask value occurs in the batch. bincount only accepts
# a 1-D tensor of non-negative integers, so the stacked masks are flattened
# (and cast to int64) first.
masks.flatten().long().bincount()