# 02 — Train PlantVillage Disease Classifier


## Overview — What This Notebook Does

This notebook trains an **image classification model** (Vision Transformer) on the processed PlantVillage dataset.

### Step-by-step workflow

1. **Set paths and hyperparameters**  
   - Point to the processed dataset, e.g. `data/processed/plantVillage`.  
   - Define image size, batch size, learning rate, number of epochs, and device (CPU/GPU).

2. **Define data transforms**  
   - For **training**: apply resizing plus random augmentations (horizontal flips and rotations of up to ±10°) to increase robustness and reduce overfitting.  
   - For **validation**: apply only resizing and tensor conversion to keep evaluation consistent.

3. **Create training and validation datasets**  
   - Use `torchvision.datasets.ImageFolder` on:
     - `.../train` → training images grouped in class folders  
     - `.../val`   → validation images grouped in class folders  
   - `ImageFolder` automatically builds:
     - `classes` → list of class names  
     - `class_to_idx` → mapping from class name to numeric label.

4. **Wrap datasets in DataLoaders**  
   - `DataLoader` batches images and shuffles the training data each epoch.  
   - This gives us mini-batches like `(images, labels)` for efficient training on GPU/CPU.

5. **Initialize the model (ViT)**  
   - Load a pretrained **Vision Transformer** (`vit_b_16`) with ImageNet weights.  
   - Replace the final classification head with a new `Linear` layer whose output size equals the number of disease classes in this dataset.  
   - Move the model to the selected device (`cuda` if available).

6. **Define loss function and optimizer**  
   - Use **CrossEntropyLoss** for multi-class classification.  
   - Use **Adam** optimizer with the chosen learning rate to update all model parameters.

7. **Training loop (per epoch)**  
   - Set the model to `train()` mode.  
   - For each batch:
     1. Move images and labels to the device.  
     2. Zero the gradients left over from the previous batch (`optimizer.zero_grad()`).  
     3. Do a forward pass to get predictions.  
     4. Compute the loss between predictions and true labels.  
     5. Backpropagate (`loss.backward()`).  
     6. Update weights (`optimizer.step()`).  
   - Track the average training loss for monitoring.

8. **Validation after training (or per epoch)**  
   - Set the model to `eval()` mode and disable gradients (`torch.no_grad()`).  
   - Run the model on the validation DataLoader.  
   - Compute overall accuracy by comparing predicted labels to true labels.  
   - This tells us how well the model generalizes to unseen images.

9. **Save the trained model**  
   - Save the learned weights (e.g. `cassava_vit_best.pth`) so they can be reused later for:
     - Inference notebooks  
     - Comparison with other models  
     - Deployment in a demo or chatbot.


In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


In [14]:
DATA_DIR = "../data/processed/plantVillage"
BATCH_SIZE = 32
LR = 1e-4
EPOCHS = 15
IMG_SIZE = 224
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


In [13]:
import os

print(os.listdir("../data"))

['dataset_cards.md', 'interim', 'processed', 'raw']
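
Listing `../data` confirms that the folder exists, but not that the `train` and `val` splits agree on their class folders. A minimal sketch of such a check follows; `class_folders` and `check_splits` are hypothetical helpers, not part of this notebook, and the demo runs on a temporary directory tree.

```python
import tempfile
from pathlib import Path


def class_folders(split_dir):
    """Return the sorted names of the class sub-folders inside one split directory."""
    return sorted(p.name for p in Path(split_dir).iterdir() if p.is_dir())


def check_splits(data_dir, splits=("train", "val")):
    """Raise if a split directory is missing or the splits disagree on class folders."""
    data_dir = Path(data_dir)
    per_split = {}
    for split in splits:
        split_dir = data_dir / split
        if not split_dir.is_dir():
            raise FileNotFoundError(f"Missing split directory: {split_dir}")
        per_split[split] = class_folders(split_dir)
    reference = per_split[splits[0]]
    for split, names in per_split.items():
        if names != reference:
            raise ValueError(f"Class folders in '{split}' differ from '{splits[0]}'")
    return reference


# Demo on a throwaway tree with two hypothetical classes.
with tempfile.TemporaryDirectory() as tmp:
    for split in ("train", "val"):
        for name in ("class_a", "class_b"):
            (Path(tmp) / split / name).mkdir(parents=True)
    found = check_splits(tmp)

print(found)
```

In this notebook the call would be `check_splits(DATA_DIR)`, run before the datasets are created.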


In [18]:
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
])

val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
])


train_ds = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_tfms)
val_ds   = datasets.ImageFolder(os.path.join(DATA_DIR, "val"),   transform=val_tfms)

train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

NUM_CLASSES = len(train_ds.classes)
train_ds.classes


['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_

### Train

In [None]:
model = models.vit_b_16(weights="IMAGENET1K_V1")
model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=LR)


for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    for imgs, labels in tqdm(train_dl):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {total_loss/len(train_dl):.4f}")


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to C:\Users\User/.cache\torch\hub\checkpoints\vit_b_16-c867db91.pth


100%|██████████| 330M/330M [04:56<00:00, 1.17MB/s] 
  1%|          | 8/1188 [03:45<9:18:05, 28.38s/it]

### Evaluation

In [None]:
model.eval()
correct, total = 0, 0

with torch.no_grad():
    for imgs, labels in val_dl:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        preds = model(imgs).argmax(1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

accuracy = correct / total
accuracy
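
The evaluation above runs once, after the last epoch. To validate per epoch and keep the best weights, the same logic can be wrapped in a function and called inside the training loop. The sketch below is self-contained and uses a toy linear model on random tensors as a stand-in for the ViT and the image loaders. The names `evaluate`, `best_acc` and `best_state` are hypothetical, and the toy evaluates on its own training data purely to keep the example short.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


def evaluate(model, loader, device):
    """Return top-1 accuracy of `model` over `loader`, restoring the previous train/eval mode."""
    was_training = model.training
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            preds = model(x).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    model.train(was_training)
    return correct / max(total, 1)


# Toy stand-ins for the real model and data.
torch.manual_seed(0)
device = "cpu"
features = torch.randn(64, 4)
labels = (features[:, 0] > 0).long()
loader = DataLoader(TensorDataset(features, labels), batch_size=16)
toy_model = nn.Linear(4, 2).to(device)
optimizer = torch.optim.Adam(toy_model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()

best_acc, best_state = -1.0, None
for epoch in range(5):
    toy_model.train()
    for x, y in loader:
        optimizer.zero_grad()
        loss = criterion(toy_model(x), y)
        loss.backward()
        optimizer.step()

    acc = evaluate(toy_model, loader, device)
    if acc > best_acc:
        # Clone the tensors so later updates do not overwrite the snapshot.
        best_acc = acc
        best_state = {k: v.clone() for k, v in toy_model.state_dict().items()}

print(f"best accuracy: {best_acc:.3f}")
```

In this notebook the call would be `evaluate(model, val_dl, DEVICE)` at the end of each epoch, and `best_state` is what would be passed to `torch.save`.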


In [None]:
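# Save the final-epoch weights (state_dict only, not the full model object).
# Note: despite the "best" in the filename, no best-checkpoint selection happens above.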
torch.save(model.state_dict(), "cassava_vit_best.pth")
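
Because only the `state_dict` is saved, reusing the weights later means rebuilding the same architecture first and then loading the tensors into it. The sketch below shows that round trip with a small stand-in network so it runs quickly. `load_for_inference` and `build_toy` are hypothetical helpers, not part of this notebook.

```python
import os
import tempfile

import torch
import torch.nn as nn


def load_for_inference(build_model, path, device="cpu"):
    """Rebuild the architecture with `build_model`, load saved weights, return it in eval mode."""
    model = build_model()
    state = torch.load(path, map_location=device)  # map_location allows loading GPU weights on CPU
    model.load_state_dict(state)
    return model.to(device).eval()


def build_toy():
    """Small stand-in architecture; the real notebook would rebuild the ViT here."""
    return nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 4))


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "toy.pth")
    trained = build_toy()
    torch.save(trained.state_dict(), path)
    restored = load_for_inference(build_toy, path)

# The restored model should reproduce the original outputs exactly.
x = torch.randn(2, 3, 8, 8)
with torch.no_grad():
    same = torch.allclose(trained.eval()(x), restored(x))
print(same)
```

For this notebook, `build_model` would create `models.vit_b_16(weights=None)` and replace `heads.head` with an `nn.Linear` of `NUM_CLASSES` outputs, and `path` would be `"cassava_vit_best.pth"`.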
