In [5]:
# Import OS module
import os

# Imports
import numpy as np
import pandas as pd

# Visualization
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models
from torchvision.transforms import RandAugment
from sklearn.metrics import classification_report
import numpy as np
import os

In [6]:
dataset_root = "/kaggle/input/visual-plant-disease-detection/dataset"

train_dir = os.path.join(dataset_root,"train")
test_dir = os.path.join(dataset_root,"test")

### Trainning with pure train dataset

In [7]:
# Transforms
image_size = 224

train_transform = transforms.Compose([
    transforms.RandomResizedCrop(image_size),
    transforms.RandomHorizontalFlip(),
    RandAugment(num_ops=2, magnitude=9),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_transform = transforms.Compose([
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor()
])

In [8]:
def train_transfer_model(
    train_dir,
    test_dir,
    train_transform,
    test_transform,
    image_size=224,
    batch_size=32,
    epochs=30,
    patience=3,
    base_model_trainable=False,
    optimizer_name="adam",
    loss_fn_name="cross_entropy",
    steps_per_epoch=150,
    validation_split=0.2
):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Dataset and split
    full_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
    num_val = int(validation_split * len(full_dataset))
    num_train = len(full_dataset) - num_val
    train_dataset, val_dataset = random_split(full_dataset, [num_train, num_val])
    val_dataset.dataset.transform = test_transform  # Use test transforms for validation

    test_dataset = datasets.ImageFolder(test_dir, transform=test_transform)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    class_names = full_dataset.classes
    num_classes = len(class_names)

    # Load base model
    base_model = models.mobilenet_v2(pretrained=True)
    base_model.classifier = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(base_model.last_channel, num_classes)
    )
    if not base_model_trainable:
        for param in base_model.features.parameters():
            param.requires_grad = False

    base_model.to(device)

    # Loss and optimizer
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(base_model.parameters()) if optimizer_name == "adam" else optim.SGD(base_model.parameters(), lr=0.01)

    # Early stopping
    best_loss = float("inf")
    patience_counter = 0

    history = {"train_loss": [], "val_loss": [], "val_accuracy": []}

    for epoch in range(epochs):
        base_model.train()
        train_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            if i >= steps_per_epoch:
                break
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = base_model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()

        # Validation
        base_model.eval()
        val_loss = 0.0
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = base_model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()

                _, preds = torch.max(outputs, 1)
                correct += (preds == labels).sum().item()
                total += labels.size(0)

        avg_train_loss = train_loss / steps_per_epoch
        avg_val_loss = val_loss / len(val_loader)
        val_acc = correct / total
        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.4f}")

        history["train_loss"].append(avg_train_loss)
        history["val_loss"].append(avg_val_loss)
        history["val_accuracy"].append(val_acc)

        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            patience_counter = 0
            best_model_state = base_model.state_dict()
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print("Early stopping triggered.")
                break

    # Load best model
    base_model.load_state_dict(best_model_state)

    # Test
    base_model.eval()
    all_preds = []
    all_labels = []
    test_loss = 0.0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = base_model(inputs)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, preds = torch.max(outputs, 1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    test_loss /= len(test_loader)
    test_accuracy = np.mean(np.array(all_preds) == np.array(all_labels))

    print(f"\n✅ Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.4f}")
    print("\nClassification Report:")
    print(classification_report(all_labels, all_preds, target_names=class_names))

    return {
        "model": base_model,
        "history": history,
        "test_loss": test_loss,
        "test_accuracy": test_accuracy,
        "true_classes": all_labels,
        "predicted_classes": all_preds,
        "class_labels": class_names
    }

In [9]:
result = train_transfer_model(
    train_dir = train_dir,
    test_dir = test_dir,
    train_transform=train_transform,
    test_transform=test_transform,
)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-b0353104.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-b0353104.pth
100%|██████████| 13.6M/13.6M [00:00<00:00, 177MB/s]


Epoch 1/30, Train Loss: 1.0568, Val Loss: 2.0997, Val Acc: 0.4077
Epoch 2/30, Train Loss: 0.7155, Val Loss: 1.7244, Val Acc: 0.4828
Epoch 3/30, Train Loss: 0.5833, Val Loss: 1.5987, Val Acc: 0.5193
Epoch 4/30, Train Loss: 0.5055, Val Loss: 1.5151, Val Acc: 0.5322
Epoch 5/30, Train Loss: 0.4521, Val Loss: 1.5294, Val Acc: 0.5107
Epoch 6/30, Train Loss: 0.4132, Val Loss: 1.4680, Val Acc: 0.5172
Epoch 7/30, Train Loss: 0.3833, Val Loss: 1.4424, Val Acc: 0.5279
Epoch 8/30, Train Loss: 0.3490, Val Loss: 1.4566, Val Acc: 0.5429
Epoch 9/30, Train Loss: 0.3290, Val Loss: 1.4181, Val Acc: 0.5579
Epoch 10/30, Train Loss: 0.3211, Val Loss: 1.4192, Val Acc: 0.5730
Epoch 11/30, Train Loss: 0.3012, Val Loss: 1.4250, Val Acc: 0.5322
Epoch 12/30, Train Loss: 0.2897, Val Loss: 1.4292, Val Acc: 0.5558
Early stopping triggered.

✅ Test Loss: 1.4671, Test Accuracy: 0.5000

Classification Report:
                            precision    recall  f1-score   support

           Apple Scab Leaf       0.50     