In [4]:
import os
import time
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, models, transforms
from sklearn.metrics import classification_report, confusion_matrix

Import Data for CV modeling. 

Data: "Cassava Leaf Disease Classification"

21 367 images 5 classes

Unbalanced classes, natural noise

In [7]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("nirmalsankalana/cassava-leaf-disease-classification")

print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/nirmalsankalana/cassava-leaf-disease-classification?dataset_version_number=2...


100%|██████████| 2.39G/2.39G [11:25<00:00, 3.74MB/s]

Extracting files...





Path to dataset files: C:\Users\OlesiaBrusentseva\.cache\kagglehub\datasets\nirmalsankalana\cassava-leaf-disease-classification\versions\2


In [5]:
# ───────────────────────────  Config & Paths  ─────────────────────────
DATA_DIR       = Path("dataLab1")   # root folder with 5 sub-folders
BATCH_SIZE     = 32
NUM_WORKERS    = os.cpu_count() or 2
NUM_EPOCHS     = 5                         # quick demo – increase if GPU budget allows
VAL_SPLIT      = 0.20
SEED           = 42
DEVICE         = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(SEED)

<torch._C.Generator at 0x1fa78b1cd30>

In [None]:
IMG_HEIGHT = 224
IMG_WIDTH = 224
BATCH_SIZE = 32
TRAIN_SPLIT = 0.7 # 70% for training
VALIDATION_SPLIT = 0.15 # 15% for validation
TEST_SPLIT = 0.15 # 15% for testing
SEED = 42 # for reproducibility

In [6]:
# ────────────────────────  Data & Transforms  ─────────────────────────
train_tfms = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])
val_tfms = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406],
                         [0.229, 0.224, 0.225]),
])

full_ds = datasets.ImageFolder(DATA_DIR, transform=train_tfms)
num_classes = len(full_ds.classes)

# Custom split (train/val) so the same images aren’t reused.
val_len = int(len(full_ds) * VAL_SPLIT)
train_len = len(full_ds) - val_len
train_ds, val_ds = random_split(full_ds, [train_len, val_len],
                                generator=torch.Generator().manual_seed(SEED))
# Validation uses deterministic transforms
val_ds.dataset.transform = val_tfms

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE,
                          shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH_SIZE,
                          shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)


In [9]:
# ────────────────────────  Helper Functions  ──────────────────────────
def train_one_epoch(model: nn.Module, loader: DataLoader,
                    criterion, optimizer) -> float:
    model.train()
    running_loss, correct, total = 0.0, 0, 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
        _, preds = outputs.max(1)
        correct += (preds == labels).sum().item()
        total   += labels.size(0)

    return running_loss / total, correct / total


@torch.no_grad()
def evaluate(model: nn.Module, loader: DataLoader, criterion) -> Tuple[float, float]:
    model.eval()
    running_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels = [], []
    for inputs, labels in loader:
        inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)
        outputs = model(inputs)
        loss = criterion(outputs, labels)

        running_loss += loss.item() * inputs.size(0)
        _, preds = outputs.max(1)
        correct    += (preds == labels).sum().item()
        total      += labels.size(0)
        all_preds.extend(preds.cpu().tolist())
        all_labels.extend(labels.cpu().tolist())

    report = classification_report(all_labels, all_preds,
                                   target_names=full_ds.classes,
                                   digits=3, zero_division=0)
    return running_loss / total, correct / total, report

def run_training(model, epochs=NUM_EPOCHS):
    model.to(DEVICE)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    scheduler = lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.3)

    best_acc = 0.0
    for epoch in range(1, epochs + 1):
        t0 = time.time()
        tr_loss, tr_acc = train_one_epoch(model, train_loader, criterion, optimizer)
        vl_loss, vl_acc, _ = evaluate(model, val_loader, criterion)
        scheduler.step()

        if vl_acc > best_acc:
            best_acc = vl_acc
            best_wts = model.state_dict()

        print(f"[{epoch:02}/{epochs}] "
              f"Train loss={tr_loss:.4f} acc={tr_acc:.3f} | "
              f"Val loss={vl_loss:.4f} acc={vl_acc:.3f} "
              f"({time.time() - t0:.1f}s)")

    model.load_state_dict(best_wts)
    _, _, final_report = evaluate(model, val_loader, criterion)
    print("\nBest validation accuracy: {:.3f}".format(best_acc))
    print(final_report)
    return best_acc, final_report


In [10]:
# ────────────────────────────  Baseline CNN  ──────────────────────────
class SmallCNN(nn.Module):
    """A lightweight CNN built from scratch – serves as the baseline."""
    def __init__(self, n_classes: int):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.ReLU(), nn.AdaptiveAvgPool2d(1),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(256, n_classes)
        )

    def forward(self, x):
        x = self.features(x).flatten(1)
        return self.classifier(x)

print("\n─────────  Baseline (training from scratch)  ─────────")
baseline_model = SmallCNN(num_classes)
baseline_acc, baseline_report = run_training(baseline_model)



─────────  Baseline (training from scratch)  ─────────




[01/5] Train loss=1.0611 acc=0.626 | Val loss=0.9641 acc=0.636 (959.5s)




[02/5] Train loss=0.9499 acc=0.647 | Val loss=0.8964 acc=0.670 (974.5s)




[03/5] Train loss=0.9017 acc=0.662 | Val loss=0.8681 acc=0.678 (955.8s)




[04/5] Train loss=0.8318 acc=0.686 | Val loss=0.7929 acc=0.694 (948.8s)




[05/5] Train loss=0.7988 acc=0.699 | Val loss=0.7532 acc=0.711 (946.9s)





Best validation accuracy: 0.711
                                precision    recall  f1-score   support

    Cassava___bacterial_blight      0.394     0.179     0.246       207
Cassava___brown_streak_disease      0.576     0.322     0.413       401
        Cassava___green_mottle      0.660     0.127     0.214       502
             Cassava___healthy      0.437     0.579     0.498       539
      Cassava___mosaic_disease      0.794     0.951     0.866      2630

                      accuracy                          0.711      4279
                     macro avg      0.572     0.432     0.447      4279
                  weighted avg      0.694     0.711     0.670      4279



In [11]:
# ─────────────────────────  Fine-Tuned Model  ─────────────────────────
print("\n─────────  Fine-tuning EfficientNet-B0  ─────────")
ft_model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)
ft_model.classifier[1] = nn.Linear(ft_model.classifier[1].in_features, num_classes)
# Option A: fine-tune *all* layers
for p in ft_model.parameters():
    p.requires_grad = True
fine_tune_acc, fine_tune_report = run_training(ft_model)


─────────  Fine-tuning EfficientNet-B0  ─────────
Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to C:\Users\OlesiaBrusentseva/.cache\torch\hub\checkpoints\efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:03<00:00, 6.05MB/s]


[01/5] Train loss=0.6153 acc=0.781 | Val loss=0.5471 acc=0.823 (2104.3s)




[02/5] Train loss=0.4669 acc=0.840 | Val loss=0.4537 acc=0.845 (2159.6s)




[03/5] Train loss=0.3925 acc=0.866 | Val loss=0.4374 acc=0.852 (2259.7s)




[04/5] Train loss=0.2347 acc=0.921 | Val loss=0.4238 acc=0.860 (2178.9s)




[05/5] Train loss=0.1430 acc=0.951 | Val loss=0.4914 acc=0.856 (2096.0s)





Best validation accuracy: 0.860
                                precision    recall  f1-score   support

    Cassava___bacterial_blight      0.598     0.488     0.537       207
Cassava___brown_streak_disease      0.681     0.810     0.740       401
        Cassava___green_mottle      0.818     0.743     0.779       502
             Cassava___healthy      0.680     0.653     0.666       539
      Cassava___mosaic_disease      0.944     0.954     0.949      2630

                      accuracy                          0.856      4279
                     macro avg      0.744     0.730     0.734      4279
                  weighted avg      0.855     0.856     0.854      4279



In [12]:
# ───────────────────────────  Comparison  ─────────────────────────────
print("━━━━━━━━━━━━━━━━  Summary  ━━━━━━━━━━━━━━━━")
print(f"Baseline accuracy:    {baseline_acc*100:.2f} %")
print(f"Fine-tuned accuracy:  {fine_tune_acc*100:.2f} %")
print("\nBaseline confusion / precision / recall:\n", baseline_report)
print("\nFine-tuned confusion / precision / recall:\n", fine_tune_report)

━━━━━━━━━━━━━━━━  Summary  ━━━━━━━━━━━━━━━━
Baseline accuracy:    71.14 %
Fine-tuned accuracy:  85.98 %

Baseline confusion / precision / recall:
                                 precision    recall  f1-score   support

    Cassava___bacterial_blight      0.394     0.179     0.246       207
Cassava___brown_streak_disease      0.576     0.322     0.413       401
        Cassava___green_mottle      0.660     0.127     0.214       502
             Cassava___healthy      0.437     0.579     0.498       539
      Cassava___mosaic_disease      0.794     0.951     0.866      2630

                      accuracy                          0.711      4279
                     macro avg      0.572     0.432     0.447      4279
                  weighted avg      0.694     0.711     0.670      4279


Fine-tuned confusion / precision / recall:
                                 precision    recall  f1-score   support

    Cassava___bacterial_blight      0.598     0.488     0.537       207
Cassava___br