Transfer learning with a mixed dataset (PlantVillage + PlantDoc).
All hyperparameters are the same as in the paper.

In [None]:
# Make Google Drive (dataset folders and checkpoint destination) visible under /content/drive.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

import os
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, ConcatDataset

from torch import nn, optim
from torchvision import models

In [None]:
# Paths
root_dir = '/content/drive/MyDrive/plant_village_dataset'

# These folders contain a mix of PlantVillage and PlantDoc images.
# The classes are reduced to only those present in both datasets.
# Each class has a maximum of 1000 images.
train = os.path.join(root_dir, 'train_downsampled')
val = os.path.join(root_dir, 'val_downsampled')


# Transforms
#
# Images from the two source datasets are not guaranteed to share one
# resolution (NOTE(review): confirm whether the *_downsampled folders were
# pre-resized). The DataLoader's default collate can only stack equally-sized
# tensors, so every image is first brought to a fixed size. Resize to 256x256
# followed by a 224x224 centre crop is the same geometry torchvision uses for
# its ImageNet-pretrained AlexNet. Normalisation uses the ImageNet channel
# statistics those pretrained weights expect.
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225]
    )
])



In [None]:
# Load datasets
train_dataset = datasets.ImageFolder(train, transform=transform)
val_dataset = datasets.ImageFolder(val, transform=transform)

# ImageFolder assigns label indices from the sorted sub-folder names of each
# root independently. A class folder missing from one split would silently
# shift every later label, so refuse to continue if the class lists differ.
if train_dataset.classes != val_dataset.classes:
    raise ValueError(
        "Train and validation folders define different classes: "
        f"{sorted(set(train_dataset.classes) ^ set(val_dataset.classes))}"
    )

# Dataloaders
batch_size = 100
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
# Evaluation metrics do not depend on sample order; keep validation deterministic.
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

In [None]:
# Report the size of each split and the number of target classes.
summary_lines = (
    f"Loaded {len(train_dataset)} training images",
    f"Number of classes: {len(train_dataset.classes)}",
    f"Loaded {len(val_dataset)} validation images",
)
for line in summary_lines:
    print(line)

In [None]:
# Load AlexNet with ImageNet-pretrained weights.
# `weights=` is the current torchvision API. The boolean `pretrained=True`
# flag is deprecated (torchvision >= 0.13) and loaded this same checkpoint.
alexnet = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)

# Freeze the convolutional feature extractor; only the classifier is fine-tuned
for param in alexnet.features.parameters():
    param.requires_grad = False

# Replace the last classifier layer so it outputs one logit per dataset class.
# Reuse the existing layer's input width instead of hard-coding 4096.
num_classes = len(train_dataset.classes)
alexnet.classifier[6] = nn.Linear(alexnet.classifier[6].in_features, num_classes)

# Move to device (GPU when available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
alexnet = alexnet.to(device)

# Loss: CrossEntropyLoss with its default reduction='mean' (averaged over the batch)
criterion = nn.CrossEntropyLoss()

# Number of training epochs
epochs = 30

# Optimizer as set in the paper. Only the classifier parameters are passed,
# consistent with the frozen feature extractor above.
optimizer = torch.optim.SGD(
    alexnet.classifier.parameters(),
    lr=0.005,
    momentum=0.9,
    weight_decay=0.0005
)

# Learning-rate scheduler: divide the LR by 10 after each third of the run
# (step_size = 30 / 3 = 10 epochs, as in the paper).
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=max(1, epochs // 3),  # StepLR needs a positive step_size
    gamma=0.1                       # factor by which to decrease LR
)


In [None]:
def _train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`.

    Returns:
        (mean loss per sample, accuracy) over every sample seen.

    Raises:
        RuntimeError: if `loader` yields no samples.
    """
    model.train()  # enable training-mode layers (e.g. dropout)
    loss_sum = 0.0
    correct = 0
    seen = 0

    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        n = labels.size(0)
        # The criterion averages over the batch (reduction='mean'), so weight by
        # batch size; otherwise a short final batch would skew the epoch mean.
        loss_sum += loss.item() * n
        correct += (outputs.argmax(dim=1) == labels).sum().item()
        seen += n

    if seen == 0:
        raise RuntimeError("Training loader yielded no samples")
    return loss_sum / seen, correct / seen


def _evaluate(model, loader, criterion, device):
    """Return (mean loss per sample, accuracy) of `model` on `loader` without updating it.

    Raises:
        RuntimeError: if `loader` yields no samples.
    """
    model.eval()  # disable dropout for evaluation
    loss_sum = 0.0
    correct = 0
    seen = 0

    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            n = labels.size(0)
            loss_sum += loss.item() * n  # same batch-size weighting as training
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += n

    if seen == 0:
        raise RuntimeError("Validation loader yielded no samples")
    return loss_sum / seen, correct / seen


# Lists to keep track of per-epoch train and validation metrics
train_losses = []
train_accuracies = []
val_losses = []
val_accuracies = []

save_every = 5   # save model every 5 epochs

for epoch in range(epochs):
    # Training
    epoch_train_loss, epoch_train_acc = _train_one_epoch(
        alexnet, train_loader, criterion, optimizer, device
    )
    train_losses.append(epoch_train_loss)
    train_accuracies.append(epoch_train_acc)

    # Validation
    epoch_val_loss, epoch_val_acc = _evaluate(alexnet, val_loader, criterion, device)
    val_losses.append(epoch_val_loss)
    val_accuracies.append(epoch_val_acc)

    # Step LR scheduler once per epoch (StepLR's step_size is counted in epochs)
    scheduler.step()

    print(f"Epoch {epoch+1}/{epochs} | "
          f"Train Loss: {epoch_train_loss:.4f}, Train Acc: {epoch_train_acc:.4f} | "
          f"Val Loss: {epoch_val_loss:.4f}, Val Acc: {epoch_val_acc:.4f}")

    # Save periodically, next to the dataset on Drive
    if (epoch + 1) % save_every == 0:
        save_path = os.path.join(root_dir, f"alexnet_tl_epoch_{epoch+1}.pth")
        torch.save(alexnet.state_dict(), save_path)
        print(f"Saved model at epoch {epoch+1} -- {save_path}")

# Final Save
torch.save(alexnet.state_dict(), os.path.join(root_dir, "alexnet_tl_final.pth"))
print("Final model saved")


In [None]:
# Plot per-epoch training and validation loss
import matplotlib.pyplot as plt

# 1-based x-axis so points line up with the "Epoch k/N" lines printed during training
epoch_axis = range(1, len(train_losses) + 1)
plt.plot(epoch_axis, train_losses, label="Train Loss")
plt.plot(epoch_axis, val_losses, label="Val Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.show()

In [None]:
# Plot per-epoch training and validation accuracy
import matplotlib.pyplot as plt

# 1-based x-axis so points line up with the "Epoch k/N" lines printed during training
epoch_axis = range(1, len(train_accuracies) + 1)
plt.plot(epoch_axis, train_accuracies, label="Train Accuracy")
plt.plot(epoch_axis, val_accuracies, label="Val Accuracy")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.show()