In [None]:
# Plant Disease Classification - Data Pipeline

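This notebook demonstrates the refactored data modules for the PlantVillage dataset. It then uses the resulting loaders to fine-tune the classification head of an ImageNet-pretrained ViT-B/16, with the backbone frozen.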

**Modules:**
- `data.dataset`: Custom PyTorch Dataset and DataLoader creation
- `data.transforms`: Image augmentation and normalization
- `data.utils`: Dataset splitting and subset creation utilities

In [None]:
# Standard imports
import os
import random
import sys

import numpy as np
import torch
from torch.utils.data import DataLoader

# Add parent directory to path to import data modules
sys.path.insert(0, os.path.abspath('..'))

# Import our custom data modules
from data.dataset import MultiModalityDataset, create_dataloaders
from data.transforms import get_transforms
from data.utils import build_class_mapping, gather_samples, split_dataset, make_subset

# ===========================
# Configuration
# ===========================
DATA_DIR = "/kaggle/input/plantvillage-dataset"  # 👈 Change this to your dataset path
MODALITIES = ["color", "grayscale", "segmented"]
IMAGE_SIZE = 224  # Standard for pretrained models
BATCH_SIZE = 32
SEED = 42

# Seed the global RNGs so DataLoader shuffling, dropout and any helper that
# relies on the global Python/NumPy/PyTorch RNGs are repeatable across runs.
# NOTE(review): if split_dataset / make_subset take their own seed or RNG,
# this does not cover them - confirm in data.utils.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

print("✅ Modules imported successfully!")

In [None]:
# 1️⃣ Build class name → ID mapping (the label set is taken from the "color" modality)
class_names, class_to_idx = build_class_mapping(DATA_DIR, modality="color")

# 2️⃣ Gather all samples (paths + labels + modality)
samples = gather_samples(DATA_DIR, MODALITIES, class_to_idx)

# Fail fast with a readable message when DATA_DIR is wrong, instead of a
# cryptic error later inside the split or DataLoader code.
assert len(class_names) > 0, f"No classes found in {DATA_DIR!r} (modality 'color') - check DATA_DIR"
assert len(samples) > 0, f"No samples found in {DATA_DIR!r} for modalities {MODALITIES} - check DATA_DIR"

print(f"Total samples found: {len(samples)}")
print(f"Number of classes: {len(class_names)}")
print(f"Example classes: {class_names[:5]}")

In [None]:
# 3️⃣ Train/Val/Test split
# Presumably val_size is a fraction of the remaining 85% (0.18 × 0.85 ≈ 0.15),
# which gives the ratio below - confirm in data.utils.split_dataset.
train, val, test = split_dataset(samples, test_size=0.15, val_size=0.18)
# Final: ~70% train / 15% val / 15% test

# Every gathered sample should land in exactly one split.
assert len(train) + len(val) + len(test) == len(samples), "splits do not partition `samples`"
# NOTE(review): the color/grayscale/segmented folders of PlantVillage are
# reportedly versions of the same leaf photos. If split_dataset splits
# individual (path, modality) samples, one leaf can appear in both train and
# test (leakage). Confirm the split groups by source image.

print(f"Train: {len(train)}, Val: {len(val)}, Test: {len(test)}")

In [None]:
# 4️⃣ Build datasets - train=True only for the training split
# (data.transforms provides the augmentation + normalization pipelines)
train_dataset = MultiModalityDataset(train, get_transforms(IMAGE_SIZE, train=True))
val_dataset   = MultiModalityDataset(val, get_transforms(IMAGE_SIZE, train=False))
test_dataset  = MultiModalityDataset(test, get_transforms(IMAGE_SIZE, train=False))


# 5️⃣ DataLoaders - shuffle only the training split so evaluation order is stable.
# pin_memory speeds up host→GPU copies, so it is enabled only when CUDA is present.
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=torch.cuda.is_available())
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_kwargs)
val_loader   = DataLoader(val_dataset, shuffle=False, **_loader_kwargs)
test_loader  = DataLoader(test_dataset, shuffle=False, **_loader_kwargs)


print("✅ DataLoaders are ready!")

In [None]:
# 6️⃣ Optional: Create subsets for quick testing/prototyping


def _subset_loader(split_samples, fraction):
    """Build a training-style subset of ``split_samples``.

    Args:
        split_samples: the sample list to draw from (here: the train split).
        fraction: share of ``split_samples`` passed to ``make_subset``.

    Returns:
        ``(subset, dataset, loader)``. The dataset uses the training transforms
        (``train=True``), and the loader shuffles with the notebook's
        ``BATCH_SIZE`` and 2 workers.
    """
    subset = make_subset(split_samples, fraction)
    dataset = MultiModalityDataset(subset, get_transforms(IMAGE_SIZE, train=True))
    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
    return subset, dataset, loader


# Small subset for quick testing
train_tiny, train_tiny_dataset, train_tiny_loader = _subset_loader(train, 0.05)

# Medium subset for hyperparameter tuning
train_medium, train_medium_dataset, train_medium_loader = _subset_loader(train, 0.3)

print(f"Tiny subset: {len(train_tiny)} samples")
print(f"Medium subset: {len(train_medium)} samples")

In [None]:
import torch
import torch.nn as nn
from torchvision.models import vit_b_16, ViT_B_16_Weights
from torch.optim import Adam
from torch.optim.lr_scheduler import StepLR
import copy
import os

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)


# 1️⃣ Load an ImageNet-pretrained ViT-B/16
weights = ViT_B_16_Weights.IMAGENET1K_V1
model = vit_b_16(weights=weights)

# 2️⃣ Freeze every pretrained parameter. The replacement head is created
#    after this call, so its parameters stay trainable.
model.requires_grad_(False)

# 3️⃣ Swap the classification head for a small MLP sized to our label set
num_classes = len(class_names)  # from the class-mapping cell
embed_dim = model.heads.head.in_features  # input width of the original head
model.heads = nn.Sequential(
    nn.Linear(embed_dim, 512),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(512, num_classes),
)

model.to(device)


# 4️⃣ Loss, optimizer and LR schedule
criterion = nn.CrossEntropyLoss()
# Only the head's parameters go to Adam; the frozen backbone is never updated.
optimizer = Adam(params=model.heads.parameters(), lr=1e-4)
# Multiply the LR by 0.1 after every 3 scheduler steps (stepped once per epoch during training).
scheduler = StepLR(optimizer=optimizer, step_size=3, gamma=0.1)


# 5️⃣ Training + validation loop
SAVE_PATH = "best_vit_model.pth"  # relative path => written to the current working directory


def _run_epoch(net, loader, loss_fn, dev, opt=None):
    """Run one pass over ``loader`` and return ``(mean_loss, accuracy)``.

    With an optimizer ``opt`` the pass trains: ``net.train()``, gradients on,
    one optimizer step per batch. Without one it evaluates: ``net.eval()``,
    gradients disabled.

    Each batch must be a mapping with ``"image"`` and ``"label"`` tensors, as
    yielded by this notebook's DataLoaders.

    Raises:
        ValueError: if ``loader`` yields no samples (the means would be 0/0).
    """
    training = opt is not None
    net.train(training)
    total_loss, correct, seen = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for batch in loader:
            images = batch["image"].to(dev)
            labels = batch["label"].to(dev)

            if training:
                opt.zero_grad()
            outputs = net(images)
            loss = loss_fn(outputs, labels)
            if training:
                loss.backward()
                opt.step()

            # loss is a per-batch mean; weight by batch size for a per-sample epoch mean
            total_loss += loss.item() * images.size(0)
            correct += outputs.argmax(dim=1).eq(labels).sum().item()
            seen += labels.size(0)

    if seen == 0:
        raise ValueError("loader yielded no samples; cannot compute loss/accuracy")
    return total_loss / seen, correct / seen


def train_model(num_epochs=1000, patience=100, *, net=None, train_dl=None, val_dl=None,
                loss_fn=None, opt=None, sched=None, dev=None, save_path=None):
    """Train with per-epoch validation, checkpoint the best weights, and restore them.

    Each epoch trains on ``train_dl``, evaluates on ``val_dl``, then steps
    ``sched``. Whenever validation loss reaches a new minimum, the full
    ``state_dict`` is saved to ``save_path``.

    Args:
        num_epochs: maximum number of epochs to run.
        patience: stop after this many consecutive epochs without a lower
            validation loss; ``None`` disables early stopping.
        net, train_dl, val_dl, loss_fn, opt, sched, dev, save_path: optional
            keyword-only overrides. When ``None``, each defaults to this
            notebook's module-level ``model``, ``train_loader``, ``val_loader``,
            ``criterion``, ``optimizer``, ``scheduler``, ``device`` and
            ``SAVE_PATH``. For example, ``train_dl=train_tiny_loader`` trains
            on a subset without rebinding globals.

    Returns:
        The model, with the lowest-validation-loss weights loaded.

    Raises:
        ValueError: if the training or validation loader yields no samples.
    """
    # Resolve overrides; each global is looked up only when not supplied.
    net = model if net is None else net
    train_dl = train_loader if train_dl is None else train_dl
    val_dl = val_loader if val_dl is None else val_dl
    loss_fn = criterion if loss_fn is None else loss_fn
    opt = optimizer if opt is None else opt
    sched = scheduler if sched is None else sched
    dev = device if dev is None else dev
    save_path = SAVE_PATH if save_path is None else save_path

    best_val_loss = float("inf")
    best_model_wts = copy.deepcopy(net.state_dict())
    no_improve_epochs = 0

    for epoch in range(num_epochs):
        train_loss, train_acc = _run_epoch(net, train_dl, loss_fn, dev, opt=opt)
        val_loss, val_acc = _run_epoch(net, val_dl, loss_fn, dev)

        # The scheduler is stepped once per epoch, after the optimizer steps.
        sched.step()

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_wts = copy.deepcopy(net.state_dict())
            torch.save(best_model_wts, save_path)
            no_improve_epochs = 0
            improved = "✅ (improved & saved)"
        else:
            no_improve_epochs += 1
            improved = ""

        print(f"Epoch [{epoch+1}/{num_epochs}] "
              f"| Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} "
              f"| Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} "
              + improved)

        if patience is not None and no_improve_epochs >= patience:
            print(f"⏹ Early stopping at epoch {epoch+1} — no improvement for {patience} epochs.")
            break

    print("🏁 Training finished!")

    # Restore the best checkpoint before returning
    net.load_state_dict(best_model_wts)
    return net


# The StepLR schedule divides the LR by 10 every 3 epochs:
#   epochs 1-3: 1e-4, 4-6: 1e-5, 7-9: 1e-6, 10-12: 1e-7, 13-15: 1e-8.
# The previous budget of 1000 epochs with patience 100 could spend hundreds of
# full ViT passes at an effectively-zero LR. Tiny val-loss decreases can keep
# resetting patience, so the budget is matched to the schedule instead.
model = train_model(num_epochs=15, patience=5)
print("✅ Best model restored & ready!")
