In [6]:
# !pip install opendatasets --quiet
# import opendatasets as od
# od.download("https://www.kaggle.com/datasets/abdallahalidev/plantvillage-dataset")

In [7]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, WeightedRandomSampler
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
import os
from collections import Counter  # ADD THIS

In [None]:
# ============================
# 2. CONSTANTS (HYPERPARAMETERS)
# ============================
BATCH_SIZE = 32
IMAGE_SIZE = 380
CHANNEL = 3
EPOCHS = 20
DATA_PATH = r"Plants"

In [9]:
# ADD DEVICE SETTING
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

Using device: cuda


In [10]:
# Load dataset
dataset = datasets.ImageFolder(DATA_PATH)

# ============================
# NEW: CALCULATE CLASS WEIGHTS
# ============================
print("\n=== CALCULATING CLASS WEIGHTS FOR IMBALANCE ===")
# Get class counts
class_counts = {}
for class_name in dataset.classes:
    class_path = os.path.join(DATA_PATH, class_name)
    count = len([f for f in os.listdir(class_path)])
    class_counts[class_name] = count
    print(f"  {class_name}: {count} images")

# Calculate class weights (inverse frequency)
class_weights = []
for idx, class_name in enumerate(dataset.classes):
    weight = 1.0 / class_counts[class_name]  # Smaller class = higher weight
    class_weights.append(weight)

# Convert to tensor for loss function
class_weights_tensor = torch.tensor(class_weights, device=device)

# Calculate weights for WeightedRandomSampler
sample_weights = []
for _, label in dataset:
    sample_weights.append(class_weights[label])


=== CALCULATING CLASS WEIGHTS FOR IMBALANCE ===
  Apple___Apple_scab: 630 images
  Apple___Black_rot: 621 images
  Apple___Cedar_apple_rust: 275 images
  Apple___healthy: 1645 images
  Cherry_(including_sour)___Powdery_mildew: 1052 images
  Cherry_(including_sour)___healthy: 854 images
  Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot: 513 images
  Corn_(maize)___Common_rust_: 1192 images
  Corn_(maize)___Northern_Leaf_Blight: 985 images
  Corn_(maize)___healthy: 1162 images
  Grape___Black_rot: 1180 images
  Grape___Esca_(Black_Measles): 1383 images
  Grape___Leaf_blight_(Isariopsis_Leaf_Spot): 1076 images
  Grape___healthy: 423 images
  Peach___Bacterial_spot: 2297 images
  Peach___healthy: 360 images
  Pepper,_bell___Bacterial_spot: 997 images
  Pepper,_bell___healthy: 1478 images
  Potato___Early_blight: 1000 images
  Potato___Late_blight: 1000 images
  Potato___healthy: 152 images
  Strawberry___Leaf_scorch: 1109 images
  Strawberry___healthy: 456 images
  Tomato___Bacterial_s

In [11]:
# ============================
# NEW: STRATIFIED SPLIT
# ============================
from sklearn.model_selection import train_test_split

# Get indices and labels
indices = list(range(len(dataset)))
labels = [label for _, label in dataset]

# Split with stratification
train_indices, temp_indices = train_test_split(
    indices, test_size=0.3, random_state=42, stratify=labels
)
val_indices, test_indices = train_test_split(
    temp_indices, test_size=0.5, random_state=42,
    stratify=[labels[i] for i in temp_indices]
)

# Create subsets
train_data = torch.utils.data.Subset(dataset, train_indices)
val_data = torch.utils.data.Subset(dataset, val_indices)
test_data = torch.utils.data.Subset(dataset, test_indices)

print(f"\nDataset split:")
print(f"  Train: {len(train_data)} samples")
print(f"  Val: {len(val_data)} samples")
print(f"  Test: {len(test_data)} samples")


Dataset split:
  Train: 28000 samples
  Val: 6000 samples
  Test: 6000 samples


In [12]:
# ============================
# 3. TRANSFORMS (RESIZE + NORMALIZE)
# ============================
train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(20),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),  # ADDED
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ADDED
])

test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # ADDED
])

In [None]:
# ============================
# 4. LOAD DATASET
# ============================
train_data.dataset.transform = train_transform
val_data.dataset.transform = test_transform
test_data.dataset.transform = test_transform

class_names = dataset.classes
num_classes = len(class_names)

# ============================
# NEW: CREATE WEIGHTED SAMPLER FOR TRAINING
# ============================
# Get weights for train subset only
train_sample_weights = [sample_weights[i] for i in train_indices]
train_sampler = WeightedRandomSampler(
    weights=train_sample_weights,
    num_samples=len(train_sample_weights),
    replacement=True  # Important for oversampling
)

# Create DataLoaders
train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, sampler=train_sampler)  # Use sampler instead of shuffle
val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False)

print(f"\nClasses: {class_names}")
print(f"Number of classes: {num_classes}")


Classes: ['Apple___Apple_scab', 'Apple___Black_rot', 'Apple___Cedar_apple_rust', 'Apple___healthy', 'Cherry_(including_sour)___Powdery_mildew', 'Cherry_(including_sour)___healthy', 'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', 'Corn_(maize)___Common_rust_', 'Corn_(maize)___Northern_Leaf_Blight', 'Corn_(maize)___healthy', 'Grape___Black_rot', 'Grape___Esca_(Black_Measles)', 'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', 'Grape___healthy', 'Peach___Bacterial_spot', 'Peach___healthy', 'Pepper,_bell___Bacterial_spot', 'Pepper,_bell___healthy', 'Potato___Early_blight', 'Potato___Late_blight', 'Potato___healthy', 'Strawberry___Leaf_scorch', 'Strawberry___healthy', 'Tomato___Bacterial_spot', 'Tomato___Early_blight', 'Tomato___Late_blight', 'Tomato___Leaf_Mold', 'Tomato___Septoria_leaf_spot', 'Tomato___Spider_mites Two-spotted_spider_mite', 'Tomato___Target_Spot', 'Tomato___Tomato_Yellow_Leaf_Curl_Virus', 'Tomato___Tomato_mosaic_virus', 'Tomato___healthy']
Number of classes: 33


In [14]:
# ============================
# 5. MODEL
# ============================
import torchvision.models as models

# Load pretrained EfficientNet
model = models.efficientnet_b4(pretrained=True)

# Freeze all layers
for param in model.parameters():
    param.requires_grad = False

# Replace classifier
num_ftrs = model.classifier[1].in_features
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(num_ftrs, len(class_names))
)

# Move model to device
model = model.to(device)



In [15]:
# ============================
# 6. LOSS & OPTIMIZER WITH CLASS WEIGHTS
# ============================
# Use weighted CrossEntropyLoss
criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)  # ADDED CLASS WEIGHTS
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [16]:
# ============================
# 7. TRAINING LOOP
# ============================
print("\n=== STARTING TRAINING ===")
for epoch in range(EPOCHS):
    model.train()
    running_loss = 0
    correct = 0
    total = 0  # ADDED

    for imgs, labels in train_loader:
        imgs, labels = imgs.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0) # ADDED(32,)

    train_acc = correct / total  # CHANGED: divide by total samples, not batches
    train_loss = running_loss / len(train_loader)

    # Validation
    model.eval()
    val_correct = 0
    val_loss_total = 0
    val_total = 0  # ADDED
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss_total += loss.item()

            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)  # ADDED

    val_acc = val_correct / val_total  # CHANGED
    val_loss = val_loss_total / len(val_loader)

    print(f"Epoch {epoch+1}/{EPOCHS} | "
          f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
          f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f}")


=== STARTING TRAINING ===
Epoch 1/20 | Train Loss: 1.0879 | Train Acc: 0.6924 | Val Loss: 0.6522 | Val Acc: 0.8057
Epoch 2/20 | Train Loss: 0.3862 | Train Acc: 0.8658 | Val Loss: 0.4349 | Val Acc: 0.8572
Epoch 3/20 | Train Loss: 0.2751 | Train Acc: 0.8927 | Val Loss: 0.3499 | Val Acc: 0.8765
Epoch 4/20 | Train Loss: 0.2255 | Train Acc: 0.9063 | Val Loss: 0.2952 | Val Acc: 0.8962


KeyboardInterrupt: 

In [None]:
# ============================
# 8. EVALUATE MODEL (WITH PER-CLASS ACCURACY)
# ============================
model.eval()
correct = 0
total = 0

# ADDED: Track per-class accuracy
class_correct = [0] * num_classes
class_total = [0] * num_classes

with torch.no_grad():
    for imgs, labels in test_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        _, preds = torch.max(outputs, 1)

        correct += (preds == labels).sum().item()
        total += labels.size(0)#â†’ the first dimension, which is the batch size(32,)

        # ADDED: Count per-class
        for i in range(num_classes):
            idx = (labels == i)
            if idx.sum().item() > 0:
                class_correct[i] += (preds[idx] == labels[idx]).sum().item()
                class_total[i] += idx.sum().item()

print(f"\n=== TEST RESULTS ===")
print(f"Overall Test Accuracy: {correct / total:.4f}")

# ADDED: Print per-class accuracy
print("\nPer-class Test Accuracy:")
for i in range(num_classes):
    if class_total[i] > 0:
        acc = class_correct[i] / class_total[i]
        print(f"  {class_names[i]}: {acc:.4f} ({class_total[i]} samples)")


=== TEST RESULTS ===
Overall Test Accuracy: 0.9507

Per-class Test Accuracy:
  Apple___Apple_scab: 0.9681 (94 samples)
  Apple___Black_rot: 0.9785 (93 samples)
  Apple___Cedar_apple_rust: 0.9756 (41 samples)
  Apple___healthy: 0.9756 (246 samples)
  Blueberry___healthy: 1.0000 (226 samples)
  Cherry_(including_sour)___Powdery_mildew: 0.9937 (158 samples)
  Cherry_(including_sour)___healthy: 0.9844 (128 samples)
  Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot: 0.9351 (77 samples)
  Corn_(maize)___Common_rust_: 0.9944 (179 samples)
  Corn_(maize)___Northern_Leaf_Blight: 0.8571 (147 samples)
  Corn_(maize)___healthy: 1.0000 (174 samples)
  Grape___Black_rot: 0.9831 (177 samples)
  Grape___Esca_(Black_Measles): 0.9758 (207 samples)
  Grape___Leaf_blight_(Isariopsis_Leaf_Spot): 1.0000 (162 samples)
  Grape___healthy: 0.9844 (64 samples)
  Orange___Haunglongbing_(Citrus_greening): 0.9927 (826 samples)
  Peach___Bacterial_spot: 0.9448 (344 samples)
  Peach___healthy: 1.0000 (54 samples)

In [None]:
# ============================
# 9. SAVE MODEL
# ============================
torch.save(model.state_dict(), "model_imbalanced.pth")
print("\nModel saved!")


Model saved!
