Notebook header, imports, output dirs

In [13]:
# 03_train_efficientnet_vit.ipynb
# Goal:
# - Rebuild datasets & dataloaders from metadata
# - Define a shared training loop
# - Train EfficientNet and ViT
# - Evaluate on test set

import os
import time
import copy
from pathlib import Path
from collections import Counter
import json

import torch
import torch.nn as nn
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

from PIL import Image
import torchvision.transforms as T

import timm
from sklearn.metrics import f1_score, classification_report

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

# Seed PyTorch's global RNG so that weight initialisation, the weighted sampler
# draws and the `torch.rand`-based augmentation choice are repeatable from a
# fresh kernel. (This does not by itself make GPU kernels bit-for-bit
# deterministic.)
SEED = 42
torch.manual_seed(SEED)

# Paths (same as previous notebooks)
METADATA_DIR = Path("./metadata")
LABEL_MAPPING_PATH = METADATA_DIR / "label_mapping.json"
DATASET_INDEX_PATH = METADATA_DIR / "dataset_index.json"

# Checkpoints are written here; create the directory up front so saving never
# fails on a missing folder.
MODELS_DIR = Path("./models")
MODELS_DIR.mkdir(parents=True, exist_ok=True)


Using device: cuda


Rebuild num_classes, datasets, and DataLoaders from the metadata JSON files

In [21]:
# Cell 2 – Rebuild num_classes, datasets and dataloaders from JSON

# --- Load metadata ---
# JSON is read as UTF-8 explicitly so the result does not depend on the
# platform's default locale encoding.
with open(LABEL_MAPPING_PATH, "r", encoding="utf-8") as f:
    label_mapping = json.load(f)

with open(DATASET_INDEX_PATH, "r", encoding="utf-8") as f:
    dataset_index = json.load(f)

num_classes = len(label_mapping["classes"])
print("Number of classes:", num_classes)
print("Number of samples in dataset_index:", len(dataset_index))

# Helper maps: class id -> canonical label, and the inverse.
id_to_label = {c["id"]: c["canonical_label"] for c in label_mapping["classes"]}
label_to_id = {v: k for k, v in id_to_label.items()}

# --- Field-poor classes (for augmentations) ---
# A class is "field-poor" when its recorded field_count is at or below the
# threshold. A class entry without a "field_count" key is treated as 0 and is
# therefore always field-poor.
FIELD_POOR_THRESHOLD = 5
field_count_by_class = {
    c["id"]: c.get("field_count", 0)
    for c in label_mapping["classes"]
}
field_poor_classes = {
    cid for cid, cnt in field_count_by_class.items()
    if cnt <= FIELD_POOR_THRESHOLD
}
print("Field-poor classes:", len(field_poor_classes))

# --- Transforms (same logic as notebook 2) ---
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD  = [0.229, 0.224, 0.225]
IMG_SIZE = 224

# Light augmentation: resize + horizontal flip only.
transform_pv_basic = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Heavy augmentation: rotation, colour jitter, affine jitter and random erasing.
# RandomErasing is placed after ToTensor, i.e. it is applied to the tensor.
transform_pv_field_style = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.RandomRotation(degrees=20),
    T.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3, hue=0.1),
    T.RandomAffine(
        degrees=0,
        translate=(0.1, 0.1),
        scale=(0.8, 1.2)
    ),
    T.ToTensor(),
    T.RandomErasing(p=0.3, scale=(0.02, 0.15), ratio=(0.3, 3.3)),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Moderate augmentation: flip plus brightness/contrast jitter.
transform_field = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.RandomHorizontalFlip(p=0.5),
    T.ColorJitter(brightness=0.3, contrast=0.3),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Deterministic pipeline for validation / test: no random operations.
transform_eval = T.Compose([
    T.Resize((IMG_SIZE, IMG_SIZE)),
    T.ToTensor(),
    T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# --- Dataset class (same as notebook 2) ---
class PlantDataset(Dataset):
    """Image dataset backed by a list of metadata entries.

    Each entry is a dict read with the keys ``"path"`` (image file),
    ``"class_id"`` (label) and, optionally, ``"domain"`` (defaults to
    ``"pv"``). The transform pipelines and the ``field_poor_classes`` set are
    taken from module-level names defined earlier in the notebook.

    Args:
        entries: Sequence of entry dicts.
        transform_train: When true, apply the randomised training
            augmentations; otherwise apply ``transform_eval``.
    """

    def __init__(self, entries, transform_train=True):
        self.entries = entries
        self.transform_train = transform_train

    def __len__(self):
        return len(self.entries)

    def _select_transform(self, domain, class_id):
        """Return the transform pipeline to use for one sample."""
        if not self.transform_train:
            return transform_eval
        if domain == "field":
            return transform_field
        if domain == "pv":
            # Field-poor classes get the heavy augmentation for roughly half
            # of their samples. The random draw only happens for those classes.
            if class_id in field_poor_classes and torch.rand(1).item() < 0.5:
                return transform_pv_field_style
            return transform_pv_basic
        # Any other domain value falls back to the heavy augmentation.
        return transform_pv_field_style

    def __getitem__(self, idx):
        """Load, decode and transform sample ``idx``; returns ``(image, class_id)``."""
        item = self.entries[idx]
        img_path = item["path"]
        class_id = item["class_id"]
        domain = item.get("domain", "pv")

        # The context manager guarantees the file handle is released even if
        # decoding raises; `convert` returns a new in-memory image, so the
        # result stays valid after the file is closed.
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB")

        img = self._select_transform(domain, class_id)(img)

        return img, class_id

# --- Split entries into train/val/test ---
train_entries = [e for e in dataset_index if e["split"] == "train"]
val_entries   = [e for e in dataset_index if e["split"] == "val"]
test_entries  = [e for e in dataset_index if e["split"] == "test"]

print("Train:", len(train_entries), "Val:", len(val_entries), "Test:", len(test_entries))

# Fail with a clear message instead of the opaque "max() arg is an empty
# sequence" raised further down when no entry is tagged "train".
if not train_entries:
    raise ValueError("dataset_index contains no entries with split == 'train'")

train_dataset = PlantDataset(train_entries, transform_train=True)
val_dataset   = PlantDataset(val_entries,   transform_train=False)
test_dataset  = PlantDataset(test_entries,  transform_train=False)

# --- WeightedRandomSampler for class balancing ---
# Every sample is weighted by max_count / (count of its class), so each class
# has the same total weight and is drawn about equally often.
train_class_counts = Counter(e["class_id"] for e in train_entries)
max_count = max(train_class_counts.values())
class_weights = {cid: max_count / cnt for cid, cnt in train_class_counts.items()}
sample_weights = [class_weights[e["class_id"]] for e in train_entries]

sampler = WeightedRandomSampler(
    # torch.as_tensor replaces the legacy torch.DoubleTensor(list) constructor.
    weights=torch.as_tensor(sample_weights, dtype=torch.double),
    num_samples=len(sample_weights),
    replacement=True
)

BATCH_SIZE = 32
NUM_WORKERS = 4
# Pinned host memory only helps host-to-GPU copies; requesting it on a
# CPU-only machine just makes DataLoader emit a warning.
PIN_MEMORY = DEVICE == "cuda"

# The sampler already randomises the order, so `shuffle` must not be set here.
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    sampler=sampler,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY
)

print("Batches -> Train:", len(train_loader), "Val:", len(val_loader), "Test:", len(test_loader))


Number of classes: 39
Number of samples in dataset_index: 55448
Field-poor classes: 39
Train: 44343 Val: 5542 Test: 5563
Batches -> Train: 1386 Val: 174 Test: 174


Helper: training and evaluation functions

In [22]:
def train_one_epoch(model, loader, criterion, optimizer, device, log_every=50):
    """Run one optimisation pass over ``loader`` and report epoch metrics.

    Args:
        model: Network to train; switched to train mode here.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            The per-batch value is re-weighted by batch size, which assumes a
            mean-reduced loss (the default for ``nn.CrossEntropyLoss``).
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        log_every: Print a progress line every ``log_every`` batches. Pass
            ``0`` or ``None`` to silence the progress output.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_f1)`` as plain Python floats, where
        ``epoch_f1`` is the macro-averaged F1 score.
    """
    model.train()
    loss_sum = 0.0
    samples_seen = 0
    target_chunks = []
    pred_chunks = []
    total_batches = len(loader)

    for batch_idx, (images, targets) in enumerate(loader):
        if log_every and batch_idx % log_every == 0:
            print(f"  [train] batch {batch_idx}/{total_batches}")

        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size

        preds = outputs.argmax(dim=1)
        target_chunks.append(targets.detach().cpu())
        pred_chunks.append(preds.detach().cpu())

    all_targets = torch.cat(target_chunks).numpy()
    all_preds = torch.cat(pred_chunks).numpy()

    # Normalise by the number of samples actually processed. With a sampler
    # (or drop_last) that number can differ from len(loader.dataset).
    epoch_loss = loss_sum / samples_seen
    # Cast NumPy scalars to float so callers can store and serialise them freely.
    epoch_acc = float((all_targets == all_preds).mean())
    epoch_f1 = float(f1_score(all_targets, all_preds, average="macro"))

    return epoch_loss, epoch_acc, epoch_f1



@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Score ``model`` on ``loader`` without computing gradients.

    Args:
        model: Network to evaluate; switched to eval mode here.
        loader: Iterable yielding ``(images, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
            The per-batch value is re-weighted by batch size, which assumes a
            mean-reduced loss (the default for ``nn.CrossEntropyLoss``).
        device: Device the batches are moved to.

    Returns:
        ``(epoch_loss, epoch_acc, epoch_f1, all_targets, all_preds)``. The
        first three are plain Python floats (``epoch_f1`` is macro-averaged);
        the last two are NumPy arrays of true and predicted class ids.
    """
    model.eval()
    loss_sum = 0.0
    samples_seen = 0
    target_chunks = []
    pred_chunks = []

    for images, targets in loader:
        images = images.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(images)
        loss = criterion(outputs, targets)

        batch_size = images.size(0)
        loss_sum += loss.item() * batch_size
        samples_seen += batch_size

        preds = outputs.argmax(dim=1)
        target_chunks.append(targets.detach().cpu())
        pred_chunks.append(preds.detach().cpu())

    all_targets = torch.cat(target_chunks).numpy()
    all_preds = torch.cat(pred_chunks).numpy()

    # Normalise by the samples actually processed rather than
    # len(loader.dataset), which is only equal when every item is visited once.
    epoch_loss = loss_sum / samples_seen
    # Cast NumPy scalars to float so callers can store and serialise them freely.
    epoch_acc = float((all_targets == all_preds).mean())
    epoch_f1 = float(f1_score(all_targets, all_preds, average="macro"))

    return epoch_loss, epoch_acc, epoch_f1, all_targets, all_preds


Helper: create model, optimizer, scheduler

In [16]:
def create_model_and_optim(
    model_name,
    num_classes,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    t_max=20,
    pretrained=False,
):
    """Build a timm model together with its loss, optimizer and LR scheduler.

    Args:
        model_name: Any timm model name, e.g. ``'efficientnet_b0'`` or
            ``'vit_base_patch16_224'``.
        num_classes: Size of the classification head.
        lr: Learning rate for AdamW.
        weight_decay: Weight decay for AdamW.
        device: Device the model is moved to.
        t_max: ``T_max`` of the cosine schedule. Set it to the number of
            epochs you intend to train; the default of 20 preserves the
            previous hard-coded value.
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``False``
            so that no weight download is attempted.

    Returns:
        ``(model, criterion, optimizer, scheduler)``.
    """
    model = timm.create_model(
        model_name,
        pretrained=pretrained,
        num_classes=num_classes
    )

    model.to(device)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=t_max)

    criterion = nn.CrossEntropyLoss()

    return model, criterion, optimizer, scheduler


Generic training loop for one model

In [17]:
def train_model(
    model_name,
    num_classes,
    train_loader,
    val_loader,
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=5
):
    """Train one timm model with early stopping on validation macro F1.

    The model, loss, optimizer and scheduler come from
    ``create_model_and_optim``. Note that the cosine schedule's horizon is
    chosen there and is not derived from ``max_epochs``.

    A checkpoint is written to ``MODELS_DIR / f"{model_name}_best.pt"`` every
    time the validation F1 improves, and once more at the end, so an
    interrupted run still leaves the best weights seen so far on disk.

    Args:
        model_name: timm model name; also used in the checkpoint file name.
        num_classes: Number of output classes.
        train_loader: Loader passed to ``train_one_epoch``.
        val_loader: Loader passed to ``evaluate``.
        max_epochs: Upper bound on the number of epochs.
        lr: Learning rate forwarded to ``create_model_and_optim``.
        weight_decay: Weight decay forwarded to ``create_model_and_optim``.
        device: Device used for the model and the batches.
        early_stopping_patience: Stop after this many consecutive epochs
            without a new best validation F1.

    Returns:
        ``(model, history, best_val_f1)``: the model with its best weights
        restored, a dict of per-epoch metric lists, and the best validation
        macro F1 (``-1.0`` if no epoch ran).
    """
    print(f"Starting training for model: {model_name}")
    print("  -> Creating model and optimizer...")

    model, criterion, optimizer, scheduler = create_model_and_optim(
        model_name=model_name,
        num_classes=num_classes,
        lr=lr,
        weight_decay=weight_decay,
        device=device
    )

    print("  -> Model created, starting epochs...")

    ckpt_path = MODELS_DIR / f"{model_name}_best.pt"

    def cpu_state_dict():
        # Detached CPU copies: the snapshot does not occupy GPU memory and is
        # not affected by later optimizer steps.
        return {
            name: tensor.detach().cpu().clone()
            for name, tensor in model.state_dict().items()
        }

    def save_checkpoint(state_dict, val_f1_so_far, history_so_far):
        torch.save({
            "model_name": model_name,
            "state_dict": state_dict,
            "num_classes": num_classes,
            "best_val_f1": val_f1_so_far,
            "history": history_so_far
        }, ckpt_path)

    best_val_f1 = -1.0
    best_state = None
    history = {
        "train_loss": [],
        "train_acc": [],
        "train_f1": [],
        "val_loss": [],
        "val_acc": [],
        "val_f1": []
    }

    epochs_without_improvement = 0

    for epoch in range(1, max_epochs + 1):
        start_time = time.time()

        train_loss, train_acc, train_f1 = train_one_epoch(
            model, train_loader, criterion, optimizer, device
        )
        val_loss, val_acc, val_f1, _, _ = evaluate(
            model, val_loader, criterion, device
        )

        scheduler.step()

        # Store plain Python floats: the metric helpers may hand back NumPy
        # scalars, which would otherwise be pickled into the checkpoint and
        # prevent loading it with torch.load(..., weights_only=True).
        epoch_metrics = {
            "train_loss": float(train_loss),
            "train_acc": float(train_acc),
            "train_f1": float(train_f1),
            "val_loss": float(val_loss),
            "val_acc": float(val_acc),
            "val_f1": float(val_f1),
        }
        for key, value in epoch_metrics.items():
            history[key].append(value)

        elapsed = time.time() - start_time
        print(
            f"Epoch {epoch:02d} | "
            f"Train loss: {train_loss:.4f}, acc: {train_acc:.3f}, F1: {train_f1:.3f} | "
            f"Val loss: {val_loss:.4f}, acc: {val_acc:.3f}, F1: {val_f1:.3f} | "
            f"time: {elapsed:.1f}s"
        )

        # Track best by validation macro F1
        if epoch_metrics["val_f1"] > best_val_f1:
            best_val_f1 = epoch_metrics["val_f1"]
            best_state = cpu_state_dict()
            epochs_without_improvement = 0
            print(f"New best val F1: {best_val_f1:.4f}")
            # Persist right away so a crash or manual interrupt later in the
            # run does not lose these weights.
            save_checkpoint(best_state, best_val_f1, history)
        else:
            epochs_without_improvement += 1

        # Early stopping
        if epochs_without_improvement >= early_stopping_patience:
            print("Early stopping triggered.")
            break

    # Restore best weights if we have them
    if best_state is not None:
        model.load_state_dict(best_state)

    # Final save so the checkpoint carries the complete history, including
    # epochs after the best one. Falls back to the current weights when no
    # epoch produced a best state.
    final_state = best_state if best_state is not None else cpu_state_dict()
    save_checkpoint(final_state, best_val_f1, history)
    print(f"Saved best checkpoint to: {ckpt_path}")

    return model, history, best_val_f1


Train EfficientNet

In [18]:
# timm identifier of the EfficientNet variant to train.
eff_name = "efficientnet_b0"

# The first four arguments are passed positionally in train_model's declared
# order; the hyper-parameters stay keyword arguments so they are easy to tune.
model_eff, history_eff, best_val_f1_eff = train_model(
    eff_name,
    num_classes,
    train_loader,
    val_loader,
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=5,
)

print("Best validation F1 (EfficientNet-B0):", best_val_f1_eff)


Starting training for model: efficientnet_b0
  -> Creating model and optimizer...
  -> Model created, starting epochs...
  [train] batch 0/1386
  [train] batch 50/1386
  [train] batch 100/1386
  [train] batch 150/1386
  [train] batch 200/1386
  [train] batch 250/1386
  [train] batch 300/1386
  [train] batch 350/1386
  [train] batch 400/1386
  [train] batch 450/1386
  [train] batch 500/1386
  [train] batch 550/1386
  [train] batch 600/1386
  [train] batch 650/1386
  [train] batch 700/1386
  [train] batch 750/1386
  [train] batch 800/1386
  [train] batch 850/1386
  [train] batch 900/1386
  [train] batch 950/1386
  [train] batch 1000/1386
  [train] batch 1050/1386
  [train] batch 1100/1386
  [train] batch 1150/1386
  [train] batch 1200/1386
  [train] batch 1250/1386
  [train] batch 1300/1386
  [train] batch 1350/1386
Epoch 01 | Train loss: 2.1528, acc: 0.401, F1: 0.395 | Val loss: 0.6751, acc: 0.790, F1: 0.734 | time: 377.4s
New best val F1: 0.7344
  [train] batch 0/1386
  [train] batch 5

Evaluate EfficientNet on test set

In [20]:
# Fresh loss object for scoring the held-out test split.
criterion = nn.CrossEntropyLoss()

(
    test_loss_eff,
    test_acc_eff,
    test_f1_eff,
    y_true_eff,
    y_pred_eff,
) = evaluate(model_eff, test_loader, criterion, DEVICE)

# One print call, one line per metric.
print(
    f"EfficientNet-B0 test loss: {test_loss_eff:.4f}",
    f"EfficientNet-B0 test acc:  {test_acc_eff:.3f}",
    f"EfficientNet-B0 test F1:   {test_f1_eff:.3f}",
    sep="\n",
)

# Per-class precision / recall / F1, keyed by numeric class id.
print("\nClassification report (EfficientNet-B0):")
print(classification_report(y_true_eff, y_pred_eff))


EfficientNet-B0 test loss: 0.0218
EfficientNet-B0 test acc:  0.994
EfficientNet-B0 test F1:   0.993

Classification report (EfficientNet-B0):
              precision    recall  f1-score   support

           0       0.98      0.98      0.98        63
           1       1.00      1.00      1.00        63
           2       1.00      1.00      1.00        28
           3       0.98      0.99      0.98       165
           4       1.00      1.00      1.00       115
           5       0.99      1.00      1.00       151
           6       1.00      1.00      1.00        86
           7       1.00      1.00      1.00       106
           8       1.00      0.94      0.97        52
           9       1.00      1.00      1.00       120
          10       0.99      1.00      1.00       117
          11       0.97      0.99      0.98        99
          12       0.98      1.00      0.99       118
          13       1.00      0.99      0.99       139
          14       1.00      1.00      1.00    

Train ViT

In [None]:
# ViT-Base with 16x16 patches at 224x224 input resolution (timm identifier).
vit_name = "vit_base_patch16_224"

# The first four arguments are passed positionally in train_model's declared
# order; the hyper-parameters stay keyword arguments so they are easy to tune.
model_vit, history_vit, best_val_f1_vit = train_model(
    vit_name,
    num_classes,
    train_loader,
    val_loader,
    max_epochs=20,
    lr=3e-4,
    weight_decay=1e-4,
    device=DEVICE,
    early_stopping_patience=5,
)

print("Best validation F1 (ViT):", best_val_f1_vit)


Evaluate ViT on test set

In [None]:
# Fresh loss object for scoring the held-out test split.
criterion = nn.CrossEntropyLoss()

(
    test_loss_vit,
    test_acc_vit,
    test_f1_vit,
    y_true_vit,
    y_pred_vit,
) = evaluate(model_vit, test_loader, criterion, DEVICE)

# One print call, one line per metric.
print(
    f"ViT test loss: {test_loss_vit:.4f}",
    f"ViT test acc:  {test_acc_vit:.3f}",
    f"ViT test F1:   {test_f1_vit:.3f}",
    sep="\n",
)

# Per-class precision / recall / F1, keyed by numeric class id.
print("\nClassification report (ViT):")
print(classification_report(y_true_vit, y_pred_vit))


Plot training curves

In [None]:
import matplotlib.pyplot as plt

def plot_history(history, title_prefix="Model"):
    """Draw train-vs-validation curves for loss, accuracy and macro F1.

    Args:
        history: Dict with the list-valued keys ``train_loss``, ``val_loss``,
            ``train_acc``, ``val_acc``, ``train_f1`` and ``val_f1``, one value
            per epoch.
        title_prefix: Text placed in front of each panel title.
    """
    epochs = range(1, len(history["train_loss"]) + 1)

    # (history key suffix, panel title suffix, y-axis label), left to right.
    panels = (
        ("loss", "loss", "Loss"),
        ("acc", "accuracy", "Accuracy"),
        ("f1", "macro F1", "F1"),
    )

    plt.figure(figsize=(12, 4))
    for position, (metric, title_suffix, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.plot(epochs, history[f"train_{metric}"], label="train")
        plt.plot(epochs, history[f"val_{metric}"], label="val")
        plt.title(f"{title_prefix} {title_suffix}")
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.legend()

    plt.tight_layout()
    plt.show()

# One figure per trained model.
plot_history(history_eff, title_prefix="EfficientNet-B0")
plot_history(history_vit, title_prefix="ViT-B16")
