Import all dependancies

In [2]:
import os
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

# TODO: Handle relative paths correctly so we can all run it independantly and add to the readme where the dataset needs to be stored
dataset_path = r"C:\Users\chris\Desktop\University\Code\ComputerVision\ForestNetDataset"

Load the dataset into memory

In [20]:
# Join the directory with each CSV filename.
test_path = os.path.join(dataset_path, "test.csv")
train_path = os.path.join(dataset_path, "train.csv")
validation_path = os.path.join(dataset_path, "val.csv")

# Read the CSV files into pandas DataFrames.
test_df = pd.read_csv(test_path)
train_df = pd.read_csv(train_path)
val_df = pd.read_csv(validation_path)

# Create a mapping from the string labels to integers based on the training data.
labels = sorted(train_df["merged_label"].unique())
label_to_index = {label: idx for idx, label in enumerate(labels)}
print("Label mapping:", label_to_index)


# FOR MODEL DEVELOPMENT JUST USE THE FIRST 128 SAMPLES FROM THE TRAINING SET
train_df = train_df.head(128)

Label mapping: {'Grassland shrubland': 0, 'Other': 1, 'Plantation': 2, 'Smallholder agriculture': 3}


Define DataLoaders for the training and test set

In [28]:

# This class implements the function __getitem__ which means it can be passed into the DataLoader class from pytorch 
# which makes the batch processing much more seamless.
class ForestNetDataset(Dataset):
    def __init__(self, df, dataset_path, transform=None, label_map=None):
        """
        Args:
            df (pd.DataFrame): DataFrame containing the image paths and labels.
            dataset_path (str): The base directory for the images.
            transform (callable, optional): A function/transform to apply to the images.
            label_map (dict, optional): Mapping from label names to integers.
        """
        self.df = df
        self.dataset_path = dataset_path
        self.transform = transform
        self.label_map = label_map

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        # Get the row for this index
        row = self.df.iloc[idx]
        # Build the full image path
        image_rel_path = row["example_path"] + "/images/visible/composite.png"
        image_path = os.path.join(self.dataset_path, image_rel_path)
        # Open the image (convert to RGB in case it's grayscale)
        image = Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        # Get the label and convert it to an integer using the provided map
        label = row["merged_label"]
        if self.label_map is not None:
            label = self.label_map[label]
        return image, label

# --- Image Transforms ---
# Resize images to 224x224, convert them to tensors, and normalize.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),

    # TODO: Look into calculating these values for our dataset. It probably has a lot more green than other
    # datasets.
    # These normalization values are typical for natural images.
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# --- Create Datasets and DataLoaders ---
train_dataset = ForestNetDataset(train_df, dataset_path, transform=transform, label_map=label_to_index)
test_dataset = ForestNetDataset(test_df, dataset_path, transform=transform, label_map=label_to_index)

batch_size = 1
# TO DO: Experiment with different num_workers (I don't know what this does)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=1)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=1)

In [22]:
print(train_dataset.__getitem__(121))

(tensor([[[-1.8439, -1.8439, -1.7925,  ..., -1.9124, -1.9295, -1.9295],
         [-1.8782, -1.8782, -1.7925,  ..., -1.9124, -1.9467, -1.9467],
         [-1.9124, -1.8782, -1.8097,  ..., -1.9467, -1.9638, -1.9638],
         ...,
         [-1.8953, -1.9295, -1.9124,  ..., -1.8782, -1.8782, -1.8953],
         [-1.8953, -1.9295, -1.9124,  ..., -1.8953, -1.8439, -1.8610],
         [-1.9124, -1.9295, -1.9124,  ..., -1.8953, -1.8268, -1.8439]],

        [[-1.6681, -1.6506, -1.6155,  ..., -1.6331, -1.6681, -1.6681],
         [-1.6856, -1.6681, -1.6155,  ..., -1.6506, -1.6856, -1.6856],
         [-1.7031, -1.6681, -1.6331,  ..., -1.6856, -1.6856, -1.7031],
         ...,
         [-1.6155, -1.6506, -1.6331,  ..., -1.6331, -1.6331, -1.6506],
         [-1.5980, -1.6506, -1.6331,  ..., -1.6506, -1.5980, -1.6155],
         [-1.5980, -1.6331, -1.6331,  ..., -1.6331, -1.5805, -1.5805]],

        [[-1.5953, -1.5953, -1.5779,  ..., -1.6302, -1.6650, -1.6650],
         [-1.6127, -1.6127, -1.5604,  ..., -