# **Step 01: Model Training**

# A. Download the Dataset, Import Libraries, and Set Global Configuration

In [None]:
!wget https://github.com/ksnugroho/bncc-ai-azure/raw/refs/heads/main/leaf-disease-dataset.zip

In [None]:
from zipfile import ZipFile

# Unpack the downloaded archive into ./leaf-disease-dataset, the directory
# that DATA_PATH refers to in the configuration cell.
with ZipFile("leaf-disease-dataset.zip", mode="r") as archive:
    archive.extractall(path="leaf-disease-dataset")

In [None]:
# Report the interpreter the *kernel* is running. `!python -V` would run
# whichever `python` comes first on the shell PATH, which can differ.
import sys

print("python:", sys.version)

In [None]:
# Install torchvista, which provides `trace_model` (used in section C to visualize the model).
# NOTE(review): the version is unpinned; pin it (torchvista==X.Y.Z) for reproducible runs.
%pip -q install torchvista

In [None]:
import os
import random
import numpy as np
import pandas as pd
import matplotlib
import matplotlib.pyplot as plt
from pathlib import Path
from time import perf_counter

# PyTorch core modules
import torch
import torchvision
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models

# Evaluation modules
import sklearn
from sklearn.model_selection import train_test_split
from sklearn.metrics import (
    accuracy_score, f1_score, classification_report,
    confusion_matrix, ConfusionMatrixDisplay, RocCurveDisplay
)

from tqdm import tqdm
from torchvista import trace_model

# Fix the seed of every RNG the notebook uses (Python, NumPy, PyTorch).
# torch.manual_seed seeds the generators on all devices, including CUDA.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Seeding alone does not make GPU runs repeatable: cuDNN may choose
# non-deterministic convolution algorithms, and autotuning (benchmark mode)
# can pick different ones per run. Force deterministic cuDNN kernels
# (possibly slightly slower). Other CUDA ops may still be non-deterministic;
# see torch.use_deterministic_algorithms for the strict mode.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Select GPU if available, otherwise CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", DEVICE)

# Path to the dataset directory (the target of the unzip step)
DATA_PATH = Path("leaf-disease-dataset")

In [None]:
# Log the versions of the key libraries so a run can be tied to its environment.
library_versions = [
    ("torch", torch.__version__),
    ("torchvision", torchvision.__version__),
    ("numpy", np.__version__),
    ("scikit-learn", sklearn.__version__),
    ("matplotlib", matplotlib.__version__),
    ("pandas", pd.__version__),
]
for lib_name, lib_version in library_versions:
    print(f"{lib_name}:", lib_version)

# B. Prepare Dataset, Transforms, and DataLoaders

In [None]:
# Preprocessing applied to every image (training and validation alike):
# fixed 224x224 resize, tensor conversion, then per-channel normalization
# with the standard ImageNet mean/std used by the pretrained weights.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# ImageFolder treats each sub-directory of DATA_PATH as one class
full_dataset = datasets.ImageFolder(str(DATA_PATH), transform=transform)

# Class index of every sample (dataset order) and the index -> name mapping
targets = [class_idx for _img_path, class_idx in full_dataset.samples]
CLASS_NAMES = full_dataset.classes
print("Classes found:", CLASS_NAMES)

# 80/20 train/validation split, stratified so both parts keep the class ratio
sample_indices = np.arange(len(targets))
train_idx, val_idx = train_test_split(
    sample_indices, test_size=0.2, random_state=SEED, stratify=targets
)

# Index-based views over the same underlying dataset
train_ds = Subset(full_dataset, train_idx)
val_ds = Subset(full_dataset, val_idx)

# Mini-batch loaders; only the training split is reshuffled every epoch
BATCH_SIZE = 16
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True)
val_dl = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

print(f"Total samples: {len(full_dataset)}  Train: {len(train_ds)}  Val: {len(val_ds)}")

# C. Initialize Model and Trace Architecture

In [None]:
# More vision architectures in PyTorch (CNNs, Vision Transformers, and hybrids):
# https://docs.pytorch.org/vision/main/models.html

# Example model: EfficientNet (CNN-based)
# Original paper: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"
# https://arxiv.org/abs/1905.11946

# Load pre-trained EfficientNet-B0 model with ImageNet weights:
# https://docs.pytorch.org/vision/main/models/generated/torchvision.models.efficientnet_b0.html#torchvision.models.efficientnet_b0
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)

# Replace the classification head with a single Linear layer sized to our classes.
# torchvision's head is a Sequential whose index 1 is the Linear layer (index 0
# is a Dropout in torchvision's implementation). Assigning to `model.classifier`
# swaps out the WHOLE head, so that dropout is no longer applied.
# NOTE(review): kept as-is because the saved state_dict keys
# ("classifier.weight"/"classifier.bias") depend on this layout and deployment
# code must rebuild the head identically. `model.classifier[1] = nn.Linear(...)`
# would keep the dropout but change the keys to "classifier.1.*".
model.classifier = nn.Linear(model.classifier[1].in_features, len(CLASS_NAMES))

model = model.to(DEVICE)  # Move model to GPU/CPU device
model.eval()              # Set model to evaluation mode for tracing

# Dummy input for tracing/visualization. Its spatial size must match the
# Resize((224, 224)) in the data transform so the traced shapes reflect the
# inputs the model actually receives.
example_input = torch.randn(1, 3, 224, 224, device=DEVICE)

# Generate model trace visualization
trace_model(model, example_input)

# D. Training Setup

## Define Loss Function and Optimizer

In [None]:
# Cross-entropy loss for single-label classification, and Adam with a small
# learning rate (1e-4), a common choice when fine-tuning pretrained weights.
LOSS_FUNCTION = nn.CrossEntropyLoss()
OPTIMIZER = torch.optim.Adam(params=model.parameters(), lr=0.0001)

## Define Training Loop Function

In [None]:
def train(model, dataloader, optimizer, loss_function, device):
    """Run one training epoch and return metrics aggregated over the epoch.

    Args:
        model: network to optimize; switched to training mode here.
        dataloader: iterable of ``(inputs, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters; stepped once per batch.
        loss_function: criterion called as ``loss_function(logits, labels)``;
            assumed to average over the batch (the ``nn.CrossEntropyLoss``
            default), since each batch loss is re-weighted by batch size.
        device: device each batch is moved to before the forward pass.

    Returns:
        dict with the per-sample mean ``"loss"``, ``"accuracy"``, macro
        ``"f1_macro"`` (all over the whole epoch) and wall-clock ``"time_s"``.
    """
    model.train()

    loss_sum = 0.0
    batch_preds = []
    batch_labels = []
    tic = perf_counter()

    for inputs, labels in dataloader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: reset grads, forward, loss, backward, update
        optimizer.zero_grad()
        logits = model(inputs)
        loss = loss_function(logits, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch's mean loss by its size so the epoch average is
        # per-sample even when the final batch is smaller.
        loss_sum += loss.item() * inputs.size(0)
        batch_preds.append(logits.argmax(dim=1).detach().cpu().numpy())
        batch_labels.append(labels.detach().cpu().numpy())

    elapsed = perf_counter() - tic

    y_pred = np.concatenate(batch_preds)
    y_true = np.concatenate(batch_labels)

    return {
        "loss": loss_sum / len(y_true),
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "time_s": elapsed,
    }

## Define Validation Loop Function

In [None]:
def validate(model, dataloader, loss_function, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: network to evaluate; switched to evaluation mode here.
        dataloader: iterable of ``(inputs, labels)`` batches.
        loss_function: criterion called as ``loss_function(logits, labels)``;
            assumed to average over the batch (the ``nn.CrossEntropyLoss``
            default), since each batch loss is re-weighted by batch size.
        device: device each batch is moved to before the forward pass.

    Returns:
        dict with the per-sample mean ``"loss"``, ``"accuracy"``, macro
        ``"f1_macro"`` and wall-clock ``"time_s"``. If the dataloader yields
        no samples, ``"loss"`` is 0.0 and ``"accuracy"``/``"f1_macro"`` are
        NaN, since there is nothing to score.
    """
    model.eval()  # Set model to evaluation mode

    epoch_loss = 0.0
    preds_all, labels_all = [], []
    start_time = perf_counter()  # Measure validation duration

    with torch.no_grad():        # Disable gradient computation
        for x, y in dataloader:
            x, y = x.to(device), y.to(device)

            output = model(x)                           # Forward pass only
            batch_loss = loss_function(output, y)       # Compute validation loss
            epoch_loss += batch_loss.item() * x.size(0)

            preds_all.append(output.argmax(dim=1).cpu().numpy())  # Predictions
            labels_all.append(y.cpu().numpy())                    # Ground truth

    end_time = perf_counter() - start_time

    # Count samples BEFORE aggregating: np.concatenate raises on an empty list,
    # so an empty dataloader has to be handled up front.
    n = sum(len(labels) for labels in labels_all)
    if n == 0:
        return {
            "loss": 0.0,
            "accuracy": float("nan"),
            "f1_macro": float("nan"),
            "time_s": end_time,
        }

    y_pred = np.concatenate(preds_all)
    y_true = np.concatenate(labels_all)

    return {
        "loss": epoch_loss / n,
        "accuracy": accuracy_score(y_true, y_pred),
        "f1_macro": f1_score(y_true, y_pred, average="macro"),
        "time_s": end_time,
    }

# E. Run Training and Validation Loop

In [None]:
%%time

EPOCHS = 3
history = {"train": [], "val": []}   # Store metrics for monitoring

for ep in range(1, EPOCHS + 1):
    # Perform one epoch of training and validation
    train_stats = train(model, train_dl, OPTIMIZER, LOSS_FUNCTION, DEVICE)
    val_stats   = validate(model, val_dl, LOSS_FUNCTION, DEVICE)

    # Save metrics
    history["train"].append(train_stats)
    history["val"].append(val_stats)

    # Print progress summary for the current epoch
    print(f"Epoch {ep} â€” Train loss: {train_stats['loss']:.4f}  "
          f"acc: {train_stats['accuracy']:.4f}  f1: {train_stats['f1_macro']:.4f}  "
          f"time: {train_stats['time_s']:.1f}s")
    
    print(f"           Val   loss: {val_stats['loss']:.4f}  "
          f"acc: {val_stats['accuracy']:.4f}  f1: {val_stats['f1_macro']:.4f}  "
          f"time: {val_stats['time_s']:.1f}s")

## Plot Training and Validation Metrics

In [None]:
# Learning curves: loss and accuracy per epoch for the training and validation splits.
train_loss = [h['loss'] for h in history['train']]
val_loss = [h['loss'] for h in history['val']]

train_acc = [h['accuracy'] for h in history['train']]
val_acc = [h['accuracy'] for h in history['val']]

epochs_range = range(1, len(train_loss) + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(10, 4))

ax_loss.plot(epochs_range, train_loss, label="Training Loss")
ax_loss.plot(epochs_range, val_loss, label="Validation Loss")
ax_loss.set(title="Loss per Epoch", xlabel="Epoch", ylabel="Loss")
ax_loss.legend()

ax_acc.plot(epochs_range, train_acc, label="Training Accuracy")
ax_acc.plot(epochs_range, val_acc, label="Validation Accuracy")
ax_acc.set(title="Accuracy per Epoch", xlabel="Epoch", ylabel="Accuracy")
ax_acc.legend()

fig.tight_layout()
plt.show()

# F. Evaluation
Further reading on confusion matrices (article in Indonesian): https://ksnugroho.medium.com/confusion-matrix-untuk-evaluasi-model-pada-unsupervised-machine-learning-bc4b1ae9ae3f

## Define Prediction Extraction Function

In [None]:
def get_preds_and_trues(model, dataloader, device, score_class=1):
    """Run ``model`` over ``dataloader`` and collect labels, predictions and scores.

    Args:
        model: classifier whose output is treated as logits with classes on
            dim 1 (softmax/argmax are taken over dim 1); put in eval mode here.
        dataloader: iterable of ``(inputs, labels)`` batches.
        device: device the inputs are moved to before the forward pass.
        score_class: index of the class whose softmax probability is returned
            as the score. Defaults to 1, i.e. the positive class in a binary
            setting, which is what the ROC curve consumes. Pass ``None`` to get
            the full probability matrix of shape (n_samples, n_classes), e.g.
            for one-vs-rest metrics on multi-class data.

    Returns:
        Tuple ``(y_true, y_pred, y_score)`` of NumPy arrays: ground-truth
        labels, argmax predictions, and the probability of ``score_class``
        (1-D), or all class probabilities (2-D) when ``score_class`` is None.
    """
    model.eval()

    labels_all, preds_all, probs_all = [], [], []

    with torch.no_grad():
        for x, y in dataloader:
            # Only the inputs need to be on `device`; labels are consumed on
            # the CPU, so they skip the round-trip.
            output = model(x.to(device))

            probs = torch.softmax(output, dim=1)
            if score_class is not None:
                probs = probs[:, score_class]

            labels_all.append(y.cpu().numpy())
            preds_all.append(output.argmax(dim=1).cpu().numpy())
            probs_all.append(probs.cpu().numpy())

    y_true = np.concatenate(labels_all)
    y_pred = np.concatenate(preds_all)
    y_score = np.concatenate(probs_all)

    return y_true, y_pred, y_score

In [None]:
# Gather true labels, predicted classes and class-1 probabilities on the validation split
val_outputs = get_preds_and_trues(model, val_dl, DEVICE)
y_true, y_pred, y_score = val_outputs

## Classification Report

In [None]:
# Per-class precision/recall/F1 on the validation split. Passing `labels`
# keeps the rows aligned with CLASS_NAMES even if a class never appears in
# y_true or y_pred; without it, sklearn infers the classes from the data and
# raises when their count differs from len(target_names).
print(classification_report(
    y_true,
    y_pred,
    labels=list(range(len(CLASS_NAMES))),
    target_names=CLASS_NAMES,
    digits=4,
))

## Plot Confusion Matrix

In [None]:
# Confusion matrix over all classes. `labels` fixes the row/column order to
# CLASS_NAMES and keeps the matrix square even if a class is absent from both
# y_true and y_pred (otherwise its size would not match display_labels).
disp = ConfusionMatrixDisplay(
    confusion_matrix=confusion_matrix(
        y_true, y_pred, labels=list(range(len(CLASS_NAMES)))
    ),
    display_labels=CLASS_NAMES
)

disp.plot(cmap=plt.cm.Blues)  # Use the blue color scheme
plt.title("Confusion Matrix")
plt.show()

## Plot ROC Curve

In [None]:
# y_score is the predicted probability of class index 1 (see
# get_preds_and_trues), so this is the ROC curve of CLASS_NAMES[1] against all
# other classes. For a binary task it is the usual ROC curve. Passing a raw
# multi-class y_true would make RocCurveDisplay raise, so binarize explicitly.
RocCurveDisplay.from_predictions(
    (y_true == 1).astype(int), y_score, pos_label=1, name=CLASS_NAMES[1]
)
plt.title(f"ROC Curve ({CLASS_NAMES[1]} vs. rest)")
plt.show()

# G. Save Trained Model for Deployment

In [None]:
# Persist the fine-tuned weights together with the class-name list, so
# deployment code can restore the model and map output indices to labels.
# The classifier head must be rebuilt the same way before loading the weights.
os.makedirs("artifacts", exist_ok=True)

checkpoint = {
    "model_state_dict": model.state_dict(),
    "classes": CLASS_NAMES,
}
torch.save(checkpoint, "artifacts/model.pth")