# 02 — Train PlantVillage Disease Classifier


In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, models, datasets
from torch.optim import Adam
from tqdm import tqdm
import matplotlib.pyplot as plt
import os


In [14]:
# --- Notebook configuration -------------------------------------------
# Dataset root; the loading cell reads its "train" and "val" subfolders.
DATA_DIR = "../data/processed/plantVillage"
# Side length (pixels) every image is resized to.
IMG_SIZE = 224

# Optimisation hyper-parameters.
# NOTE: the training cell further down assigns LR and EPOCHS again, so the
# values set there are the ones in effect when the training loop runs.
BATCH_SIZE = 32
LR = 0.0001
EPOCHS = 15

# Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"


In [13]:
import os

# Sanity check: print the entries directly under ../data (path is relative
# to the notebook's working directory) before building the datasets.
print(os.listdir("../data"))

['dataset_cards.md', 'interim', 'processed', 'raw']


In [None]:
import os
from torchvision import datasets, transforms
from torch.utils.data import DataLoader



IMG_SIZE = 224
BATCH_SIZE = 32

# Per-channel statistics of ImageNet. The training cell fine-tunes a ViT
# initialised from ImageNet-pretrained weights, so inputs must be normalised
# the same way the backbone saw them during pretraining; feeding raw [0, 1]
# tensors shifts the input distribution and hurts transfer.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light augmentation, tensor conversion, normalise.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Validation pipeline: deterministic (no augmentation), same normalisation.
val_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

# Each split lives in its own subfolder of DATA_DIR, one directory per class.
train_ds = datasets.ImageFolder(os.path.join(DATA_DIR, "train"), transform=train_tfms)
val_ds   = datasets.ImageFolder(os.path.join(DATA_DIR, "val"), transform=val_tfms)

# Shuffle only the training data; keep validation order fixed.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl   = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

# NUM_CLASSES depends on dataset (auto-detected from the train class folders)
NUM_CLASSES = len(train_ds.classes)

# Last expression of the cell: displays the detected class names.
train_ds.classes


['Apple___Apple_scab',
 'Apple___Black_rot',
 'Apple___Cedar_apple_rust',
 'Apple___healthy',
 'Blueberry___healthy',
 'Cherry_(including_sour)___Powdery_mildew',
 'Cherry_(including_sour)___healthy',
 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
 'Corn_(maize)___Common_rust_',
 'Corn_(maize)___Northern_Leaf_Blight',
 'Corn_(maize)___healthy',
 'Grape___Black_rot',
 'Grape___Esca_(Black_Measles)',
 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
 'Grape___healthy',
 'Orange___Haunglongbing_(Citrus_greening)',
 'Peach___Bacterial_spot',
 'Peach___healthy',
 'Pepper,_bell___Bacterial_spot',
 'Pepper,_bell___healthy',
 'Potato___Early_blight',
 'Potato___Late_blight',
 'Potato___healthy',
 'Raspberry___healthy',
 'Soybean___healthy',
 'Squash___Powdery_mildew',
 'Strawberry___Leaf_scorch',
 'Strawberry___healthy',
 'Tomato___Bacterial_spot',
 'Tomato___Early_blight',
 'Tomato___Late_blight',
 'Tomato___Leaf_Mold',
 'Tomato___Septoria_leaf_spot',
 'Tomato___Spider_mites Two-spotted_

### Train

In [None]:
import torch

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", DEVICE)

LR = 1e-4   # learning rate
EPOCHS = 10  # NOTE: overrides the EPOCHS = 15 set in the configuration cell

# ViT-B/16 initialised from ImageNet weights; swap the classification head
# for a fresh linear layer sized to the number of classes in the dataset.
model = models.vit_b_16(weights="IMAGENET1K_V1")
model.heads.head = nn.Linear(model.heads.head.in_features, NUM_CLASSES)
model = model.to(DEVICE)

criterion = nn.CrossEntropyLoss()
# All parameters are optimised (full fine-tuning, nothing is frozen).
optimizer = Adam(model.parameters(), lr=LR)


for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0
    samples_seen = 0

    for imgs, labels in tqdm(train_dl):
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        preds = model(imgs)
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        # criterion returns the mean over the batch, so weight it by the
        # batch size; otherwise a smaller final batch is over-represented
        # in the epoch average.
        batch_count = labels.size(0)
        total_loss += loss.item() * batch_count
        samples_seen += batch_count

    # Per-sample average loss over the whole epoch.
    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {total_loss/samples_seen:.4f}")


Downloading: "https://download.pytorch.org/models/vit_b_16-c867db91.pth" to C:\Users\User/.cache\torch\hub\checkpoints\vit_b_16-c867db91.pth


100%|██████████| 330M/330M [04:56<00:00, 1.17MB/s] 
  1%|          | 8/1188 [03:45<9:18:05, 28.38s/it]

### Evaluation

In [None]:
# ======================
# Validation pass
# ======================
# Put the model in evaluation mode and count correct top-1 predictions.
model.eval()
correct, total = 0, 0

# No gradients are needed for scoring.
with torch.no_grad():
    for imgs, labels in val_dl:
        imgs = imgs.to(DEVICE)
        labels = labels.to(DEVICE)

        # Predicted class = index of the largest logit along the class axis.
        preds = model(imgs)
        preds = preds.argmax(dim=1)

        total += labels.shape[0]
        correct += int(preds.eq(labels).sum())

# Fraction of validation samples classified correctly.
val_accuracy = correct / total
print("Validation Accuracy:", val_accuracy * 100, "%")


# ======================
# Persist trained weights
# ======================
# Only the state_dict (parameters/buffers) is written, not the full module.
MODEL_NAME = "plant_village_best.pth"   # change per dataset
torch.save(model.state_dict(), MODEL_NAME)

print("Saved:", MODEL_NAME)


In [None]:
# Save the model's state_dict to the current working directory.
# NOTE(review): the evaluation cell already writes the same weights to
# "plant_village_best.pth"; this writes a second copy under a differently
# cased name — confirm which filename downstream code loads.
torch.save(model.state_dict(), "PlantVillage_best.pth")
