### Preliminaries

Import all the required libraries for setting up data, model training, and validation. 

In [None]:
import copy
import numpy as np
import random
import timm
import torch
import torch.nn as nn

from pathlib import Path
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import f1_score
from torch.nn.functional import mse_loss
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
from tqdm import tqdm

Set all necessary constants to be used in setting up data, model training, and validation.

In [3]:
# Version number embedded in the output directory name for logs and checkpoints.
TRAIN_VERSION = 6

# Seed shared by every random number generator in this notebook.
MANUAL_SEED = 27
NUMBER_OF_CLASSES = 4

# Root directory of the raw images read by ImageFolder.
RAW_DATA_DIR = "../training_data"

EPOCHS = 30
# Weight of the consistency loss relative to the cross-entropy loss.
LAMBDA_WEIGHT = 0.5

# Per-channel statistics passed to transforms.Normalize (these values match
# the commonly used ImageNet mean and standard deviation).
NORMALIZE_MEAN = [0.485, 0.456, 0.406]
NORMALIZE_STD = [0.229, 0.224, 0.225]

# Backward-compatible alias for the original, misspelled constant name so that
# existing references keep working.
NOMRALIZE_STD = NORMALIZE_STD

Set manual seeds in random libraries to ensure reproducibility of results. 

In [4]:
# Request deterministic cuDNN kernels and disable cuDNN benchmark mode.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Seed Python, NumPy and PyTorch (CPU and every CUDA device) with one value.
for _seed_rng in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(MANUAL_SEED)

### Data Preprocessing

Start by loading the dataset and determining the sizes of the training and validation sets: 80% of the dataset will be used for training and the remaining 20% for validation.

Based on these sizes, the dataset is randomly split into the two subsets (using a pre-determined seed).

In [5]:
# Read the images under RAW_DATA_DIR as an ImageFolder dataset.
base_dataset = datasets.ImageFolder(root=RAW_DATA_DIR)

# 80/20 split: the training share is truncated to an integer and the
# validation set receives whatever remains.
train_size = int(len(base_dataset) * 0.8)
validation_size = len(base_dataset) - train_size

# A dedicated, seeded generator makes the random split repeatable.
generator = torch.Generator()
generator.manual_seed(MANUAL_SEED)
train_dataset, validation_dataset = random_split(
    base_dataset,
    [train_size, validation_size],
    generator=generator,
)

To address class imbalance in the dataset, data augmentation will be applied to underrepresented classes to match the number of samples in the most represented class. Specifically, images from the *Cordana*, *Healthy*, and *Pestalotiopsis* classes will be used to create augmented images to achieve a more balanced class distribution.

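First, define the transform functions that will be applied to images of the training dataset. These transformations include resizing and data augmentation techniques such as flipping, rotation, color jitter, and random resized cropping. Conversion to Tensors and normalization are applied later, in the training transforms.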

In [6]:
# Resize-only pipeline; it is applied to every original training image.
transform_self_base = transforms.Compose([transforms.Resize(size=(224, 224))])

# Randomized pipeline used to synthesize additional images for the
# underrepresented classes: flips, a small rotation, color jitter and a
# random crop that is resized back to 224x224.
transform_self_aug = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.3),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.05),
    transforms.RandomResizedCrop(size=(224, 224), scale=(0.8, 1.0)),
])

Next, define a subset processing function that will load original images (with base transform) for all classes but will also make use of the augmented transform function to produce augmented images **only for** underrepresented classes.

In [7]:
def load_class_dataset(target_class: str, num_variants: int):
    """Collect the training images of one class, plus augmented copies.

    Every training image of ``target_class`` is passed through
    ``transform_self_base`` and kept. Unless the class is ``"sigatoka"``,
    ``num_variants`` extra images per original are produced with
    ``transform_self_aug``.

    Only samples that belong to the training split (``train_dataset.indices``)
    are used, so images held out for validation never end up in the training
    data.

    Args:
        target_class: Class name as it appears in ``class_to_idx`` of the
            dataset underlying ``train_dataset``.
        num_variants: Number of augmented images to create per original image.

    Returns:
        A list of ``(image, label)`` tuples.

    Raises:
        KeyError: If ``target_class`` is not a key of ``class_to_idx``.
    """
    full_dataset = train_dataset.dataset
    class_index = full_dataset.class_to_idx[target_class]

    # Look only at the indices of the training split. Scanning every sample of
    # the underlying dataset would also pick up the validation images.
    indices = [
        i for i in train_dataset.indices
        if full_dataset.samples[i][1] == class_index
    ]

    # Every selected sample has the same class, so decide once whether
    # augmented images are needed (Sigatoka is never augmented).
    augment = target_class != "sigatoka"

    class_images = []
    for i in indices:
        image, label = full_dataset[i]

        # Keep the original image with the base transform applied.
        class_images.append((transform_self_base(image), label))

        if not augment:
            continue

        # Produce num_variants augmented images from the same original.
        class_images.extend(
            (transform_self_aug(image), label) for _ in range(num_variants)
        )

    return class_images

Finally, call this processing function on each of the classes and collect the resulting original and augmented images in *augemented_train_dataset*.

In [None]:
# Gather the processed images of every class into one flat list. The same
# variant count is requested for each class.
classes = ["sigatoka", "cordana", "healthy", "pestalotiopsis"]
augemented_train_dataset = []
for class_name in classes:
    class_dataset = load_class_dataset(class_name, num_variants=4)
    augemented_train_dataset.extend(class_dataset)

### Training Proper

Now that the training set is ready and all classes have been balanced to have approximately equal sample sizes, model training can begin.

For this AI model, a pre-trained Vision Transformer (ViT) will be used, but with a twist: self-ensembling. Two ViTs will be deployed and trained, where one acts as the teacher and the other as the student. The teacher model is exposed to weakly augmented images, allowing it to remain stable and consistent. Meanwhile, the student model is trained on strongly augmented images, encouraging it to learn more robust features. This setup enables the teacher to guide and correct the student by enforcing consistency between their outputs.

Additionally, cross-entropy loss is used to train the model on labeled data. To help the student learn from the teacher, a consistency loss (like Mean Squared Error) is also used to compare their predictions on augmented images. Using both losses together helps the student become more accurate and better at handling noisy or challenging inputs.

To start this process, the weak (teacher) and strong (student) augmentation functions are first defined. For the weak augmentation, only a random horizontal flip is applied before converting the image to a Tensor and normalizing its values. The strong augmentation, in contrast, combines a random resized crop, color jitter, and random erasing.

Additionally, a validation transform function is also defined as it will be used to convert validation set images to Tensors and normalize their values.

In [9]:
# Weak transform for teacher training -> more stable and consistent output.
weak_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=NORMALIZE_MEAN,
        std=NOMRALIZE_STD
    ),
])

# Strong transform for student training -> learn robust features.
strong_transform = transforms.Compose([
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    # NOTE(review): ColorJitter() with its default arguments leaves the image
    # unchanged; pass explicit strengths if color jitter is intended here.
    transforms.ColorJitter(),
    transforms.ToTensor(),
    transforms.RandomErasing(p=0.5),
    transforms.Normalize(
        mean=NORMALIZE_MEAN,
        std=NOMRALIZE_STD
    ),
])

# Resize validation images to the 224x224 model input size, convert them to
# Tensors and normalize them. The resize is required because validation images
# come straight from the image folder and are not otherwise resized; it also
# gives every image the same shape so they can be batched.
validation_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=NORMALIZE_MEAN,
        std=NOMRALIZE_STD
    ),
])

A DualTransformDataset is then defined to return both a weakly and a strongly augmented version of each base image.

The new augmented training dataset, wrapped using the DualTransformDataset, is instantiated using the balanced dataset prepared in the Data Preprocessing section. Additionally, a new validation dataset is also created, now ensuring that all images are properly formatted and compatible for model evaluation.

To enable batch processing during training and validation, a DataLoader is created for the augmented training dataset and another for the validation dataset.

In [10]:
# Dataset wrapper that yields two differently transformed views of each sample.
class DualTransformDataset(torch.utils.data.Dataset):
    """Pair every sample of a wrapped dataset with a weak and a strong view.

    The wrapped ``dataset`` must yield ``(image, label)`` pairs. Indexing this
    wrapper returns ``(weak_view, strong_view, label)``, where the two views
    are the results of ``weak_transform`` and ``strong_transform`` applied to
    the same image.
    """

    def __init__(self, dataset, weak_transform, strong_transform):
        self.dataset, self.weak_transform, self.strong_transform = (
            dataset,
            weak_transform,
            strong_transform,
        )

    def __len__(self):
        # The wrapper has exactly as many items as the wrapped dataset.
        return len(self.dataset)

    def __getitem__(self, idx):
        image, label = self.dataset[idx]
        # The weak view is computed before the strong view.
        weak_view = self.weak_transform(image)
        strong_view = self.strong_transform(image)
        return weak_view, strong_view, label

# Wrap the collected training images so that every item yields a weakly and a
# strongly augmented view together with its label.
augemented_train_dataset = DualTransformDataset(
    dataset=augemented_train_dataset,
    weak_transform=weak_transform,
    strong_transform=strong_transform,
)

# Apply the validation transform once, up front, and keep the results in a list.
validation_dataset = [
    (validation_transform(sample), target)
    for sample, target in validation_dataset
]

# Batch loaders: training batches are shuffled, validation batches are not.
train_loader = DataLoader(
    dataset=augemented_train_dataset,
    batch_size=32,
    shuffle=True,
)
validation_loader = DataLoader(
    dataset=validation_dataset,
    batch_size=32,
    shuffle=False,
)

Class weights are then computed based on the distribution of labels in the training dataset to address any lingering class imbalances. The loss function is also defined using cross-entropy with label smoothing and the computed class weights to improve fairness.

In [11]:
# Use GPU if available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Collect class labels directly from the wrapped (image, label) list.
# Iterating the DualTransformDataset itself would run both augmentation
# pipelines on every image only to throw the images away.
# NOTE: skipping that pass also means no random numbers are drawn here, so
# later random draws differ from earlier runs of this notebook (results stay
# reproducible for a fixed seed).
targets = [label for _, label in augemented_train_dataset.dataset]

# "balanced" weights are inversely proportional to class frequency.
class_weights = compute_class_weight("balanced", classes=np.unique(targets), y=targets)

# Convert weights to a Tensor and move it to the computing device.
class_weights = torch.tensor(class_weights, dtype=torch.float).to(device)

# Define loss function: class-weighted cross-entropy with label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1, weight=class_weights)

Next, the student and teacher models are instantiated using a pretrained ViT model. 

In [12]:
# Build the student from the pretrained "vit_base_patch16_224" checkpoint with
# a classification head sized for NUMBER_OF_CLASSES.
student = timm.create_model(
    "vit_base_patch16_224",
    pretrained=True,
    num_classes=NUMBER_OF_CLASSES,
)

# The teacher starts out as an independent copy of the student.
teacher = copy.deepcopy(student)

# Place both networks on the computing device.
student, teacher = student.to(device), teacher.to(device)

The *update_teacher()* function is also defined to update the teacher model’s weights using an exponential moving average (EMA) of the student model’s weights. The teacher is updated gradually to ensure it remains a stable target during training. The *ema_decay* parameter controls how much the teacher retains its previous values versus adopting the student’s current parameters.

In [13]:
def update_teacher(student, teacher, ema_decay=0.99):
    """Update the teacher as an exponential moving average of the student.

    Each teacher parameter becomes
    ``ema_decay * teacher + (1 - ema_decay) * student``. The update is done in
    place and without gradient tracking, so no new parameter tensors are
    allocated on each call and the student's parameters are left untouched.

    Args:
        student: Module whose current parameters are blended in.
        teacher: Module whose parameters are updated in place. Its parameters
            are paired with the student's in iteration order.
        ema_decay: Fraction of the teacher's previous value that is retained.
    """
    with torch.no_grad():
        for student_param, teacher_param in zip(student.parameters(), teacher.parameters()):
            # Exponential Moving Average (EMA), written as an in-place update.
            teacher_param.mul_(ema_decay).add_(student_param, alpha=1 - ema_decay)

The model training loop can now begin, after everything has been set.

The training loop runs for a defined number of epochs, using the Adam optimizer to update the student model’s parameters. During each epoch, the student model is trained on strongly augmented images, while the teacher model remains in evaluation mode and predicts on weakly augmented images. Two losses are computed: cross-entropy loss for supervised learning, and a consistency loss (mean squared error) between student and teacher predictions. These are combined with a weighting factor.

After each training step, the teacher model is updated using the Exponential Moving Average (EMA) of the student’s parameters. At the end of every epoch, the student model is evaluated on the validation set, and key metrics such as loss, accuracy, and macro F1 score are logged.

In [None]:
optimizer = torch.optim.Adam(student.parameters(), lr=3e-4, weight_decay=1e-4)  # Set up optimizer for the student model.
best_val_acc = 0.0                                                              # Initialize best validation accuracy to 0.0%.

# The training log and the checkpoints share one output directory; create it
# once instead of rebuilding the paths on every epoch.
save_path = Path("../models/model1") / f"version_{TRAIN_VERSION}"
save_path.mkdir(parents=True, exist_ok=True)
log_path = save_path / "training_log.txt"

for epoch in range(EPOCHS):
    # Set student to training mode and teacher to evaluation mode.
    student.train()
    teacher.eval()

    running_loss = 0.0
    total, correct = 0, 0

    # Create progress bar for current epoch.
    loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")

    # batch_count is the number of batches processed so far in this epoch.
    for batch_count, (weak_img, strong_img, labels) in enumerate(loop, start=1):
        # Move images and labels to the computing device.
        weak_img, strong_img, labels = weak_img.to(device), strong_img.to(device), labels.to(device)

        # Forward pass through teacher (no gradient computation).
        with torch.no_grad():
            teacher_output = teacher(weak_img)

        student_output = student(strong_img)                                    # Forward pass through student.
        loss_ce = criterion(student_output, labels)                             # Compute cross-entropy loss on labeled data.
        loss_consistency = mse_loss(student_output, teacher_output.detach())    # Compute consistency loss between student and teacher outputs.
        loss = loss_ce + LAMBDA_WEIGHT * loss_consistency                       # Combine losses.

        # Backpropagation and optimization step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        update_teacher(student, teacher)                                        # Update teacher model via exponential moving average of student.

        # Track training loss and accuracy.
        running_loss += loss.item()
        _, predicted = torch.max(student_output.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        # Update progress bar with the mean loss per batch and the accuracy.
        # Dividing by the actual batch count (rather than total // 32) avoids a
        # division by zero when the first batch holds fewer than 32 samples and
        # stays correct for a final, partially filled batch.
        loop.set_postfix(loss=running_loss / batch_count, accuracy=100. * correct / total)

    # Evaluation phase, set student model to evaluation mode.
    student.eval()

    # Initialize evaluation variables.
    val_loss = 0.0
    val_correct = 0
    val_total = 0
    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for image, labels in validation_loader:
            image, labels = image.to(device), labels.to(device)

            outputs = student(image)

            # Compute loss.
            loss = criterion(outputs, labels)
            val_loss += loss.item()

            # Compute accuracy.
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

            # Collect labels for F1 score.
            true_labels.extend(labels.cpu().numpy())
            predicted_labels.extend(predicted.cpu().numpy())

    # Calculate mean validation loss per batch and validation accuracy.
    avg_val_loss = val_loss / len(validation_loader)
    val_accuracy = 100. * val_correct / val_total

    # Print validation accuracy and f1 score
    f1 = f1_score(true_labels, predicted_labels, average='macro')
    print(f"Validation Loss: {avg_val_loss:.4f} | Accuracy: {val_accuracy:.2f}% | F1 Score (macro): {f1:.4f}")

    # Append this epoch's metrics to the text log.
    with open(log_path, "a", encoding="utf-8") as f:
        f.write(f"Epoch {epoch+1}: Val Loss = {avg_val_loss:.4f}, Accuracy = {val_accuracy:.2f}%, F1 = {f1:.4f}\n")

    # Save or overwrite file if current model (student and teacher) has better accuracy
    if val_accuracy > best_val_acc:
        best_val_acc = val_accuracy

        torch.save(student.state_dict(), save_path / "student.pth")
        torch.save(teacher.state_dict(), save_path / "teacher.pth")
        print(f"Saved new best model at Epoch {epoch+1}")