# Cassava Leaf Disease Classification Using Deep Learning Models

## 1. Introduction

### 1.1 Problem Statement
This research focuses on automated detection and classification of cassava leaf diseases using deep learning approaches. Cassava, a vital food security crop in Africa, faces significant yield losses due to various diseases. Early and accurate disease detection is crucial for effective crop management.

### 1.2 Dataset Overview
We utilize the [Cassava Leaf Disease Classification](https://www.kaggle.com/c/cassava-leaf-disease-classification) dataset, which contains images of cassava leaves affected by different diseases:
- Cassava Bacterial Blight (CBB)
- Cassava Brown Streak Disease (CBSD)
- Cassava Green Mottle (CGM)
- Cassava Mosaic Disease (CMD)
- Healthy specimens

### 1.3 Research Objectives
1. Develop and compare multiple deep learning architectures for disease classification
2. Evaluate model performance with emphasis on balanced metrics (macro-F1)
3. Analyze model robustness and generalization capabilities
4. Provide insights into architectural choices for plant disease detection
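
To make the emphasis on macro-F1 concrete, the following minimal sketch contrasts it with plain accuracy on an imbalanced toy problem. The labels and the helper `macro_f1` are illustrative assumptions, not part of the notebook, which uses `sklearn.metrics.f1_score(average='macro')` instead.

```python
def macro_f1(y_true, y_pred, num_classes):
    """Unweighted mean of per-class F1 scores (a class with no hits scores 0)."""
    f1_scores = []
    for c in range(num_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)
    return sum(f1_scores) / num_classes

# Hypothetical two-class toy data: 8 samples of the majority class, 2 of the minority class.
y_true = [1] * 8 + [0] * 2
y_pred = [1] * 10  # a classifier that always predicts the majority class

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
score = macro_f1(y_true, y_pred, num_classes=2)
print(f"accuracy={accuracy:.3f}, macro-F1={score:.3f}")
```

Accuracy rewards the majority-class guess, while macro-F1 averages over classes and is pulled down by the class that is never predicted. That is the behaviour we want on a dataset where one disease dominates.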

### 1.4 Methodological Overview
We implement and compare three state-of-the-art architectures:
- ResNet50 with progressive unfreezing
- EfficientNet-B0 with advanced augmentation
- Vision Transformer (ViT) with patch-based learning

## Kaggle Setup

First, let's set up our Kaggle environment and access the dataset:

In [None]:
# Import Kaggle dataset
import os
import json
import pandas as pd
from pathlib import Path

# Kaggle competition dataset
competition_name = "cassava-leaf-disease-classification"
dataset_path = Path("/kaggle/input/cassava-leaf-disease-classification")

# Load dataset metadata
train_df = pd.read_csv(dataset_path / "train.csv")
print("Dataset Overview:")
print(f"Total images: {len(train_df)}")
print("\nClass distribution:")
print(train_df['label'].value_counts())

# Map numeric labels to disease names
label_names = {
    0: "Cassava Bacterial Blight (CBB)",
    1: "Cassava Brown Streak Disease (CBSD)",
    2: "Cassava Green Mottle (CGM)",
    3: "Cassava Mosaic Disease (CMD)",
    4: "Healthy"
}

train_df['disease_name'] = train_df['label'].map(label_names)
print("\nSample data:")
print(train_df.head())

## 2. Background and Literature Review

### 2.1 Cassava Disease Impact
Cassava diseases threaten food security in developing regions. Early detection through computer vision can improve crop management and help protect yields.

### 2.2 Deep Learning in Plant Pathology
Recent advances in deep learning have shown promising results in plant disease detection:
- Convolutional Neural Networks (CNNs) for image classification
- Transfer learning for limited dataset scenarios
- Real-time detection systems for field applications

### 2.3 Technical Background
Our implementation utilizes:
- PyTorch and torchvision for deep learning models
- Advanced data augmentation techniques
- Evaluation metrics focused on balanced performance
- Modern training strategies (progressive unfreezing, learning rate scheduling)

In [None]:
# Import required libraries
import os
import time
from pathlib import Path
from typing import Dict, List, Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import datasets, models, transforms
from tqdm.notebook import tqdm
from sklearn.metrics import confusion_matrix, classification_report, f1_score, accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from PIL import Image
from transformers import get_cosine_schedule_with_warmup

# Set random seeds for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Set device - Kaggle provides GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Create directories for saving models and results
os.makedirs("/kaggle/working/models", exist_ok=True)
os.makedirs("/kaggle/working/results", exist_ok=True)

# Competition dataset paths
DATASET_PATH = Path("/kaggle/input/cassava-leaf-disease-classification")
TRAIN_PATH = DATASET_PATH / "train_images"
TEST_PATH = DATASET_PATH / "test_images"

## 3. Methodology

### 3.1 Data Preparation and Analysis
We implement a data pipeline that includes:
1. A custom `Dataset` implementation
2. Augmentation (random resized crops, flips, rotation, color jitter)
3. Stratified 80/20 train/validation splitting
4. Class-balanced batch sampling (`WeightedRandomSampler`) and class weights for the loss
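
As a rough illustration of how balanced class weights relate to per-sample sampling weights, the sketch below applies the formula `n_samples / (n_classes * count_c)`, which is what scikit-learn's `compute_class_weight(class_weight='balanced', ...)` computes. The label counts and the helper name `balanced_class_weights` are hypothetical.

```python
from collections import Counter

def balanced_class_weights(labels):
    """Return {class: n_samples / (n_classes * count_of_class)}."""
    counts = Counter(labels)
    n_samples, n_classes = len(labels), len(counts)
    return {c: n_samples / (n_classes * counts[c]) for c in sorted(counts)}

# Hypothetical imbalanced label list (not the real Cassava counts).
labels = [0] * 10 + [1] * 30 + [2] * 60
weights = balanced_class_weights(labels)

# One weight per sample, as a weighted sampler expects.
sample_weights = [weights[y] for y in labels]

# Total sampling mass per class: equal for all classes, so draws are balanced in expectation.
mass = {c: sum(w for w, y in zip(sample_weights, labels) if y == c) for c in weights}
print(weights)
print(mass)
```

Because every class ends up with the same total weight, sampling with replacement draws each class about equally often, regardless of how many images it has.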

The following section details our data preparation approach:

In [None]:
# Custom Dataset class for Cassava
class CassavaDataset(Dataset):
    def __init__(self, df, transform=None):
        self.df = df
        self.transform = transform
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):
        img_name = self.df.iloc[idx]['image_id']
        label = self.df.iloc[idx]['label']
        
        # Load image
        img_path = TRAIN_PATH / img_name
        image = Image.open(img_path).convert('RGB')
        
        if self.transform:
            image = self.transform(image)
            
        return image, label

# Define data transforms
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
}

# Split data into train and validation
train_df, val_df = train_test_split(
    train_df, 
    test_size=0.2, 
    stratify=train_df['label'],
    random_state=42
)

# Create datasets
train_dataset = CassavaDataset(train_df, transform=data_transforms['train'])
val_dataset = CassavaDataset(val_df, transform=data_transforms['val'])

# Create data loaders (we will optionally replace the train loader with a WeightedRandomSampler)
batch_size = 32

# Compute class distribution and class weights
class_counts = train_df['label'].value_counts().sort_index()
print("Training class distribution:\n", class_counts)

plt.figure(figsize=(8,4))
sns.barplot(x=class_counts.index, y=class_counts.values, palette='viridis')
plt.xticks(ticks=np.arange(len(class_counts)), labels=[label_names[i] for i in class_counts.index], rotation=45, ha='right')
plt.title('Original Training Set Class Distribution')
plt.ylabel('Count')
plt.tight_layout()
plt.show()

# Compute class weights (balanced)
classes = np.unique(train_df['label'])
class_weights_np = compute_class_weight(class_weight='balanced', classes=classes, y=train_df['label'].values)
class_weights = torch.tensor(class_weights_np, dtype=torch.float)
print('Class weights (balanced):', class_weights_np)

# Option: use WeightedRandomSampler to balance batches
use_weighted_sampler = True
if use_weighted_sampler:
    # For the sampler we need a weight per sample (based on its label)
    sample_weights = [class_weights_np[label] for label in train_df['label'].tolist()]
    sampler = WeightedRandomSampler(weights=sample_weights, num_samples=len(sample_weights), replacement=True)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=2)
    print('Using WeightedRandomSampler for train_loader')
else:
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=2)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Display sample images
def show_batch(loader, num_images=8):
    batch = next(iter(loader))
    images, labels = batch
    
    fig, axes = plt.subplots(2, 4, figsize=(15, 8))
    for i, ax in enumerate(axes.flat):
        if i < min(num_images, len(images)):
            img = images[i].permute(1, 2, 0).numpy()
            img = img * [0.229, 0.224, 0.225] + [0.485, 0.456, 0.406]  # Denormalize
            img = np.clip(img, 0, 1)
            ax.imshow(img)
            ax.axis('off')
            ax.set_title(label_names[labels[i].item()])
    plt.tight_layout()
    plt.show()

print("\nShowing a batch of training images with augmentation:")
show_batch(train_loader)

# Verify class distribution after augmentation/sampling
def plot_augmented_distribution(train_loader, num_classes, label_names, num_batches=100):
    """
    Plot the class distribution after augmentation by sampling from the train_loader
    
    Args:
        train_loader: DataLoader with augmentation/sampling
        num_classes: Number of classes
        label_names: Dictionary mapping class indices to names
        num_batches: Number of batches to sample (default=100)
    """
    augmented_counts = np.zeros(num_classes)
    
    # Sample batches and count classes
    print("\nSampling batches to verify class distribution...")
    for i, (_, labels) in enumerate(tqdm(train_loader, total=num_batches)):
        if i >= num_batches:
            break
        labels_np = labels.cpu().numpy()
        for label in range(num_classes):
            augmented_counts[label] += np.sum(labels_np == label)
    
    # Plot original vs augmented distribution
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 8))
    
    # Original distribution
    original_counts = train_df['label'].value_counts().sort_index()
    sns.barplot(x=np.arange(num_classes), y=original_counts, ax=ax1, palette='viridis')
    ax1.set_title('Original Class Distribution')
    ax1.set_xticks(np.arange(num_classes))
    ax1.set_xticklabels([label_names[i] for i in range(num_classes)], rotation=45, ha='right')
    ax1.set_ylabel('Count')
    
    # Augmented distribution
    sns.barplot(x=np.arange(num_classes), y=augmented_counts, ax=ax2, palette='viridis')
    ax2.set_title('Distribution After Augmentation/Sampling')
    ax2.set_xticks(np.arange(num_classes))
    ax2.set_xticklabels([label_names[i] for i in range(num_classes)], rotation=45, ha='right')
    ax2.set_ylabel('Count')
    
    plt.tight_layout()
    plt.show()
    
    # Print class proportions
    print("\nClass proportions after augmentation:")
    proportions = augmented_counts / augmented_counts.sum()
    for i in range(num_classes):
        print(f"{label_names[i]}: {proportions[i]:.3f}")

# Verify class balance after augmentation/sampling
print("\nVerifying class distribution after augmentation and sampling:")
plot_augmented_distribution(
    train_loader,
    len(label_names),
    label_names,
    num_batches=100  # Sample 100 batches to check distribution
)

### 3.2 Model Architecture: ResNet50

#### 3.2.1 Implementation Details
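We implement ResNet50 with the following enhancements:
1. Progressive layer unfreezing, so the pretrained features are not overwritten early in training
2. Mixup augmentation for improved generalization
3. One-cycle learning rate scheduling (`OneCycleLR`), restarted at each unfreezing stage
4. Regularization through dropout in the classifier head and weight decay (AdamW)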
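
The idea behind mixup can be sketched on a single pair of samples. This is a simplified, framework-free illustration: the function `mixup_pair` and the toy inputs are assumptions, whereas the notebook mixes whole batches of tensors and combines two cross-entropy terms.

```python
import random

def mixup_pair(x_a, x_b, y_a, y_b, num_classes, alpha=1.0, rng=random):
    """Blend two samples with a Beta(alpha, alpha) coefficient and build the matching soft target."""
    lam = rng.betavariate(alpha, alpha) if alpha > 0 else 1.0
    mixed_x = [lam * a + (1 - lam) * b for a, b in zip(x_a, x_b)]
    soft_target = [0.0] * num_classes
    soft_target[y_a] += lam
    soft_target[y_b] += 1 - lam
    return mixed_x, soft_target, lam

random.seed(0)
# Hypothetical two-feature "images" with labels 3 (CMD) and 0 (CBB).
mixed_x, soft_target, lam = mixup_pair([1.0, 0.0], [0.0, 1.0], y_a=3, y_b=0, num_classes=5)
print(lam, mixed_x, soft_target)
```

Since cross-entropy is linear in the target distribution, training against this soft target gives the same loss as `lam * CE(pred, y_a) + (1 - lam) * CE(pred, y_b)`.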

#### 3.2.2 Training Strategy
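The model is trained in three stages of `num_epochs` epochs each:
- Initial stage: only the new classifier head (`fc`) is trained
- Intermediate stage: the head and the last residual block (`layer4`)
- Final stage: the head and the last two residual blocks (`layer4`, `layer3`); earlier layers stay frozen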
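
The stage logic comes down to deciding, per parameter name, whether it is trainable. The sketch below is one possible way to express that, assuming a prefix match on the top-level module name; the notebook itself uses a substring match, and the parameter names listed here are only a representative handful.

```python
def trainable_names(param_names, prefixes):
    """Select parameters whose top-level module is one of the given prefixes."""
    return [n for n in param_names
            if any(n == p or n.startswith(p + ".") for p in prefixes)]

# Representative parameter names (illustrative subset, not the full model).
names = [
    "conv1.weight",
    "layer3.0.conv1.weight",
    "layer4.0.conv1.weight",
    "fc.0.weight",
    "fc.0.bias",
]

stages = [
    ("Initial (FC only)", ("fc",)),
    ("Stage 2 (Last block)", ("fc", "layer4")),
    ("Stage 3 (Last 2 blocks)", ("fc", "layer4", "layer3")),
]

for stage_name, prefixes in stages:
    print(stage_name, "->", trainable_names(names, prefixes))
```

Matching on a prefix rather than a substring avoids accidentally unfreezing a parameter whose name merely contains the stage keyword.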

Implementation and training code follows:

In [None]:
# Mixup augmentation
def mixup_data(x, y, alpha=1.0):
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1

    batch_size = x.size()[0]
    index = torch.randperm(batch_size).to(device)

    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# Initialize ResNet50
def create_resnet_model(num_classes):
    model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
    
    # Freeze all parameters initially
    for param in model.parameters():
        param.requires_grad = False
        
    # Modify final classifier
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(num_features, 1024),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(512, num_classes)
    )
    
    return model.to(device)

# Training utilities
def train_epoch(model, loader, criterion, optimizer, scheduler=None, mixup=True):
    model.train()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    for inputs, labels in tqdm(loader, leave=False):
        inputs, labels = inputs.to(device), labels.to(device)
        
        if mixup:
            inputs, labels_a, labels_b, lam = mixup_data(inputs, labels)
            
        optimizer.zero_grad()
        outputs = model(inputs)
        
        if mixup:
            loss = mixup_criterion(criterion, outputs, labels_a, labels_b, lam)
            # For metrics, use the dominant label (higher lambda weight)
            true_labels = labels_a if lam > 0.5 else labels_b
        else:
            loss = criterion(outputs, labels)
            true_labels = labels
            
        loss.backward()
        optimizer.step()
        # OneCycleLR is configured with steps_per_epoch, so it must advance once per batch
        if scheduler:
            scheduler.step()
        running_loss += loss.item()
        
        # Store predictions and true labels for metric calculation
        _, preds = outputs.max(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(true_labels.cpu().numpy())
    
    # Calculate accuracy and macro-F1
    # (with mixup these are approximate: each batch is scored against its dominant label)
    accuracy = 100. * accuracy_score(all_labels, all_preds)
    macro_f1 = 100. * f1_score(all_labels, all_preds, average='macro')
        
    return running_loss / len(loader), accuracy, macro_f1

def validate(model, loader, criterion):
    model.eval()
    running_loss = 0.0
    all_preds = []
    all_labels = []
    
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            
            running_loss += loss.item()
            _, preds = outputs.max(1)
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    accuracy = 100. * accuracy_score(all_labels, all_preds)
    macro_f1 = 100. * f1_score(all_labels, all_preds, average='macro')
            
    return running_loss / len(loader), accuracy, macro_f1

# Initialize model and training parameters
num_classes = len(label_names)
resnet_model = create_resnet_model(num_classes)
# Use class_weights if computed; fall back to unweighted if not
try:
    class_weights_tensor = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
    print('Using class-weighted CrossEntropyLoss for ResNet')
except NameError:
    criterion = nn.CrossEntropyLoss()
    print('Class weights not found; using standard CrossEntropyLoss')

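# Epochs per unfreezing stage; defined here because the scheduler below needs it
num_epochs = 15

# Initial optimizer and scheduler for the classifier head (rebuilt at the start of each stage)
optimizer = optim.AdamW(resnet_model.fc.parameters(), lr=0.001, weight_decay=0.01)
scheduler = optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=0.001,
    epochs=num_epochs,
    steps_per_epoch=len(train_loader)
)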

print("Stage 1: Training only the classifier")

In [None]:
# Train ResNet50
num_epochs = 15
best_val_f1 = 0  # Best validation macro-F1 seen so far
train_losses, val_losses = [], []
train_acc, val_acc = [], []
train_f1, val_f1 = [], []  # Macro-F1 history per epoch

# Training loop with progressive unfreezing
stages = [
    ('Initial (FC only)', ['fc']),
    ('Stage 2 (Last block)', ['fc', 'layer4']),
    ('Stage 3 (Last 2 blocks)', ['fc', 'layer4', 'layer3'])
]

for stage_name, layers_to_unfreeze in stages:
    print(f"\n{stage_name}")
    
    # Unfreeze specified layers
    for name, param in resnet_model.named_parameters():
        param.requires_grad = any(layer in name for layer in layers_to_unfreeze)
    
    # Update optimizer with unfrozen parameters
    optimizer = optim.AdamW(
        [p for p in resnet_model.parameters() if p.requires_grad],
        lr=0.001,
        weight_decay=0.01
    )
    
    scheduler = optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=0.001,
        epochs=num_epochs,
        steps_per_epoch=len(train_loader)
    )
    
    for epoch in range(num_epochs):
        train_loss, train_accuracy, train_macro_f1 = train_epoch(
            resnet_model, train_loader, criterion, optimizer, scheduler
        )
        val_loss, val_accuracy, val_macro_f1 = validate(resnet_model, val_loader, criterion)
        
        train_losses.append(train_loss)
        val_losses.append(val_loss)
        train_acc.append(train_accuracy)
        val_acc.append(val_accuracy)
        train_f1.append(train_macro_f1)
        val_f1.append(val_macro_f1)
        
        if val_macro_f1 > best_val_f1:  # Save model based on F1 score
            best_val_f1 = val_macro_f1
            torch.save(resnet_model.state_dict(), '/kaggle/working/models/resnet50_best.pth')
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Train Macro-F1: {train_macro_f1:.2f}%')
        print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, Val Macro-F1: {val_macro_f1:.2f}%')
    
# Plot training progress
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title('ResNet50 Loss Progress')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(train_acc, label='Train Accuracy')
plt.plot(val_acc, label='Val Accuracy')
plt.title('ResNet50 Accuracy Progress')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(train_f1, label='Train Macro-F1')
plt.plot(val_f1, label='Val Macro-F1')
plt.title('ResNet50 Macro-F1 Progress')
plt.xlabel('Epoch')
plt.ylabel('Macro-F1 (%)')
plt.legend()

plt.tight_layout()
plt.show()

### 3.3 Model Architecture: EfficientNet-B0

#### 3.3.1 Implementation Details
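EfficientNet-B0 is implemented with:
1. The same augmentation pipeline as ResNet50, plus mixup
2. Linear learning rate warmup over the first 3 epochs, followed by cosine annealing
3. Discriminative learning rates (1e-4 for the backbone, 1e-3 for the classifier)
4. A classifier head with batch normalization and dropout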

#### 3.3.2 Training Strategy
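Our training approach includes:
- Gradual layer unfreezing (last 2 feature blocks after 5 epochs, last 4 after 10 epochs)
- Warmup followed by cosine learning rate scheduling
- Regularization through dropout, weight decay, and mixup
- Per-epoch monitoring of loss, accuracy, and macro-F1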
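
The shape of a warmup-then-cosine schedule can be written in closed form. The sketch below is a simplified stand-in, not the scheduler class used in the notebook: the function `warmup_cosine_lr` is hypothetical, and the values (base rate 1e-3, peak 8e-3, 3 warmup epochs, 15 epochs in total) mirror the classifier group's configuration only for illustration.

```python
import math

def warmup_cosine_lr(epoch, base_lr, peak_lr, warmup_epochs, total_epochs, eta_min=1e-6):
    """Linear warmup from base_lr to peak_lr, then cosine decay towards eta_min."""
    if epoch < warmup_epochs:
        return base_lr + (peak_lr - base_lr) * epoch / warmup_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return eta_min + 0.5 * (peak_lr - eta_min) * (1 + math.cos(math.pi * progress))

schedule = [warmup_cosine_lr(e, base_lr=1e-3, peak_lr=8e-3, warmup_epochs=3, total_epochs=15)
            for e in range(15)]

for epoch, lr in enumerate(schedule):
    print(f"epoch {epoch:2d}: lr = {lr:.6f}")
```

The rate rises during the first three epochs, peaks when warmup ends, and then decreases monotonically, which keeps early updates to the freshly initialized head small.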

Implementation and training code follows:

In [None]:
# Initialize EfficientNet
def create_efficient_model(num_classes):
    model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.DEFAULT)
    
    # Freeze all parameters
    for param in model.parameters():
        param.requires_grad = False
        
    # Modify classifier with dropout and batch norm
    model.classifier = nn.Sequential(
        nn.Dropout(p=0.5),
        nn.Linear(1280, 1024),
        nn.ReLU(),
        nn.BatchNorm1d(1024),
        nn.Dropout(p=0.3),
        nn.Linear(1024, 512),
        nn.ReLU(),
        nn.BatchNorm1d(512),
        nn.Dropout(p=0.2),
        nn.Linear(512, num_classes)
    )
    
    return model.to(device)

class GradualWarmupScheduler(torch.optim.lr_scheduler._LRScheduler):
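    def __init__(self, optimizer, multiplier, warmup_epochs, total_epochs, after_scheduler=None):
        self.multiplier = multiplier
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.after_scheduler = after_scheduler
        super().__init__(optimizer)
        # Start the follow-up schedule from the warmed-up rate so the learning rate
        # does not fall back to base_lr when warmup ends
        if self.after_scheduler is not None:
            self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]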

    def get_lr(self):
        if self.last_epoch < self.warmup_epochs:
            return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.warmup_epochs + 1.) 
                    for base_lr in self.base_lrs]
        else:
            if self.after_scheduler:
                return self.after_scheduler.get_last_lr()
            return [base_lr * self.multiplier for base_lr in self.base_lrs]

    def step(self, epoch=None):
        if epoch is None:
            epoch = self.last_epoch + 1
        self.last_epoch = epoch
        if self.last_epoch < self.warmup_epochs:
            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
                param_group['lr'] = lr
        else:
            if self.after_scheduler:
                self.after_scheduler.step(epoch - self.warmup_epochs)

# Initialize model and training parameters
efficient_model = create_efficient_model(num_classes)
# Use class_weights if computed; fall back to unweighted if not
try:
    class_weights_tensor = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
    print('Using class-weighted CrossEntropyLoss for EfficientNet')
except NameError:
    criterion = nn.CrossEntropyLoss()
    print('Class weights not found; using standard CrossEntropyLoss')

# Set different learning rates for different layers
params = [
    {'params': efficient_model.features.parameters(), 'lr': 1e-4},
    {'params': efficient_model.classifier.parameters(), 'lr': 1e-3}
]
optimizer = optim.AdamW(params, weight_decay=0.01)

# Create schedulers
main_scheduler = optim.lr_scheduler.CosineAnnealingLR(
    optimizer, 
    T_max=num_epochs-3,
    eta_min=1e-6
)
scheduler = GradualWarmupScheduler(
    optimizer,
    multiplier=8,
    warmup_epochs=3,
    total_epochs=num_epochs,
    after_scheduler=main_scheduler
)

print("Starting EfficientNet training...")

In [None]:
# Train EfficientNet
num_epochs = 15
best_val_f1 = 0  # Best validation macro-F1 seen so far
train_losses, val_losses = [], []
train_acc, val_acc = [], []
train_f1, val_f1 = [], []  # Macro-F1 history per epoch

for epoch in range(num_epochs):
    # Gradually unfreeze layers
    if epoch == 5:
        print("Unfreezing last 2 blocks...")
        for param in efficient_model.features[-2:].parameters():
            param.requires_grad = True
        
    if epoch == 10:
        print("Unfreezing last 4 blocks...")
        for param in efficient_model.features[-4:].parameters():
            param.requires_grad = True
    
    train_loss, train_accuracy, train_macro_f1 = train_epoch(
        efficient_model, train_loader, criterion, optimizer, None, mixup=True
    )
    val_loss, val_accuracy, val_macro_f1 = validate(efficient_model, val_loader, criterion)
    
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_acc.append(train_accuracy)
    val_acc.append(val_accuracy)
    train_f1.append(train_macro_f1)
    val_f1.append(val_macro_f1)
    
    scheduler.step()
    
    if val_macro_f1 > best_val_f1:  # Save model based on F1 score
        best_val_f1 = val_macro_f1
        torch.save(efficient_model.state_dict(), '/kaggle/working/models/efficientnet_best.pth')
    
    print(f'Epoch {epoch+1}/{num_epochs}:')
    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.2f}%, Train Macro-F1: {train_macro_f1:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.2f}%, Val Macro-F1: {val_macro_f1:.2f}%')
    for param_group in optimizer.param_groups:
        print(f'Learning rate: {param_group["lr"]:.6f}')

# Plot training progress
plt.figure(figsize=(15, 5))
plt.subplot(1, 3, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title('EfficientNet Loss Progress')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

plt.subplot(1, 3, 2)
plt.plot(train_acc, label='Train Accuracy')
plt.plot(val_acc, label='Val Accuracy')
plt.title('EfficientNet Accuracy Progress')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.subplot(1, 3, 3)
plt.plot(train_f1, label='Train Macro-F1')
plt.plot(val_f1, label='Val Macro-F1')
plt.title('EfficientNet Macro-F1 Progress')
plt.xlabel('Epoch')
plt.ylabel('Macro-F1 (%)')
plt.legend()

plt.tight_layout()
plt.show()

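# Plot learning rate schedule
# Replay the schedule on a throwaway optimizer: stepping the real scheduler again after
# training would plot learning rates from beyond the end of the run.
dummy_optimizer = optim.AdamW([
    {'params': [nn.Parameter(torch.zeros(1))], 'lr': 1e-4},  # features
    {'params': [nn.Parameter(torch.zeros(1))], 'lr': 1e-3}   # classifier
], weight_decay=0.01)
dummy_scheduler = GradualWarmupScheduler(
    dummy_optimizer,
    multiplier=8,
    warmup_epochs=3,
    total_epochs=num_epochs,
    after_scheduler=optim.lr_scheduler.CosineAnnealingLR(dummy_optimizer, T_max=num_epochs-3, eta_min=1e-6)
)

plt.figure(figsize=(10, 4))
lrs = []
for _ in range(num_epochs):
    lrs.append([param_group['lr'] for param_group in dummy_optimizer.param_groups])
    dummy_optimizer.step()
    dummy_scheduler.step()
    
plt.plot(np.array(lrs))
plt.title('Learning Rate Schedule')
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.legend(['Features', 'Classifier'])
plt.show()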

### 3.4 Model Architecture: Vision Transformer (ViT)

#### 3.4.1 Implementation Details
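Our ViT implementation includes:
1. A pretrained ViT-B/16 backbone (ImageNet-1k weights from torchvision)
2. Patch-based image tokenization (16x16 patches)
3. A custom classification head (LayerNorm, dropout, and two linear layers)
4. The pretrained learned positional embeddings, fine-tuned with the rest of the model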
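
Patch-based tokenization fixes the sequence length the transformer sees. As a small worked example, the helper below (a hypothetical name, `vit_sequence_shape`) computes the number of patches, the sequence length including the class token, and the flattened size of one patch for the 224x224 inputs and 16x16 patches used here.

```python
def vit_sequence_shape(image_size=224, patch_size=16, channels=3):
    """Return (num_patches, sequence_length_with_class_token, flattened_patch_dim)."""
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side ** 2
    patch_dim = channels * patch_size ** 2
    return num_patches, num_patches + 1, patch_dim

num_patches, seq_len, patch_dim = vit_sequence_shape()
print(num_patches, seq_len, patch_dim)
```

A 224x224 image is cut into a 14x14 grid, so the encoder attends over 196 patch tokens plus one class token, and each raw patch holds 768 values before projection.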

#### 3.4.2 Training Strategy
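The training process involves:
- Fine-tuning all layers from the first epoch
- A 10x higher learning rate for the classification head than for the backbone
- Regularization through weight decay, dropout in the head, and gradient clipping
- Linear warmup for 2 epochs followed by cosine decay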

Implementation and training code follows:

In [None]:
# Initialize ViT model
def create_vit_model(num_classes):
    """
    Create and initialize a Vision Transformer model with custom classification head.
    """
    # Load pretrained ViT-B/16
    model = models.vit_b_16(weights=models.ViT_B_16_Weights.IMAGENET1K_V1)
    
    # Modify the head for our classification task
    num_features = model.heads.head.in_features
    model.heads = nn.Sequential(
        nn.LayerNorm(num_features),
        nn.Dropout(0.2),
        nn.Linear(num_features, 512),
        nn.GELU(),
        nn.Dropout(0.1),
        nn.Linear(512, num_classes)
    )
    
    return model.to(device)

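# Parameter groups: no weight decay for biases and LayerNorm, higher learning rate for the head
def get_optimizer_params(model, weight_decay, lr_init):
    """
    Build optimizer parameter groups: the backbone trains at lr_init, the head at 10x lr_init.
    """
    param_optimizer = list(model.named_parameters())
    # torchvision's ViT names its LayerNorm modules ln_1, ln_2 and ln
    no_decay = ["bias", "ln_1.weight", "ln_2.weight", "ln.weight"]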
    
    # Separate parameters into decay and no_decay groups
    optimizer_parameters = [
        {
            "params": [
                p for n, p in param_optimizer 
                if not any(nd in n for nd in no_decay) and "heads" not in n
            ],
            "weight_decay": weight_decay,
            "lr": lr_init
        },
        {
            "params": [
                p for n, p in param_optimizer 
                if any(nd in n for nd in no_decay) and "heads" not in n
            ],
            "weight_decay": 0.0,
            "lr": lr_init
        },
        {
            "params": [p for n, p in param_optimizer if "heads" in n],
            "weight_decay": weight_decay,
            "lr": lr_init * 10  # Higher learning rate for classification head
        },
    ]
    
    return optimizer_parameters

# Initialize model and training parameters
print("Initializing ViT model...")
vit_model = create_vit_model(num_classes)
# Use class_weights if computed; fall back to unweighted if not
try:
    class_weights_tensor = class_weights.to(device)
    criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
    print('Using class-weighted CrossEntropyLoss for ViT')
except NameError:
    criterion = nn.CrossEntropyLoss()
    print('Class weights not found; using standard CrossEntropyLoss')

# Set up optimizer with separate backbone and head learning rates
optimizer_params = get_optimizer_params(
    vit_model,
    weight_decay=0.01,
    lr_init=1e-4
)
optimizer = optim.AdamW(optimizer_params)

# Create scheduler with warm-up and cosine decay
scheduler = get_cosine_schedule_with_warmup(
    optimizer,
    num_warmup_steps=len(train_loader) * 2,  # 2 epochs of warmup
    num_training_steps=len(train_loader) * num_epochs
)

# Training configuration
print("\nTraining Configuration:")
print(f"Number of classes: {num_classes}")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Batch size: {batch_size}")
print(f"Number of epochs: {num_epochs}")
print(f"Device: {device}")

# Training loop
print("\nStarting ViT training...")
best_val_f1 = 0
train_losses, val_losses = [], []
train_f1s, val_f1s = [], []

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")
    
    # Training phase
    vit_model.train()
    train_loss = 0
    train_preds, train_labels = [], []
    
    for batch_idx, (images, labels) in enumerate(tqdm(train_loader, desc="Training")):
        images, labels = images.to(device), labels.to(device)
        
        optimizer.zero_grad()
        outputs = vit_model(images)
        loss = criterion(outputs, labels)
        
        loss.backward()
        torch.nn.utils.clip_grad_norm_(vit_model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()
        
        train_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        train_preds.extend(predicted.cpu().numpy())
        train_labels.extend(labels.cpu().numpy())
        
    # Calculate training metrics
    train_loss = train_loss / len(train_loader)
    train_f1 = f1_score(train_labels, train_preds, average='macro')
    
    # Validation phase
    vit_model.eval()
    val_loss = 0
    val_preds, val_labels = [], []
    
    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc="Validation"):
            images, labels = images.to(device), labels.to(device)
            outputs = vit_model(images)
            loss = criterion(outputs, labels)
            
            val_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            val_preds.extend(predicted.cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    
    # Calculate validation metrics
    val_loss = val_loss / len(val_loader)
    val_f1 = f1_score(val_labels, val_preds, average='macro')
    
    # Store metrics
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_f1s.append(train_f1)
    val_f1s.append(val_f1)
    
    # Save best model
    if val_f1 > best_val_f1:
        best_val_f1 = val_f1
        print("Saving best model...")
        torch.save(vit_model.state_dict(), '/kaggle/working/models/vit_best.pth')
    
    # Print epoch metrics
    print(f"Train Loss: {train_loss:.4f}, Train Macro-F1: {train_f1:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Macro-F1: {val_f1:.4f}")
    
    # Print learning rates
    print("Learning rates:")
    for i, param_group in enumerate(optimizer.param_groups):
        print(f"Group {i}: {param_group['lr']:.6f}")

# Plot training progress
plt.figure(figsize=(15, 5))

# Loss plot
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(val_losses, label='Val Loss')
plt.title('ViT Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()

# F1 score plot
plt.subplot(1, 2, 2)
plt.plot(train_f1s, label='Train Macro-F1')
plt.plot(val_f1s, label='Val Macro-F1')
plt.title('ViT Training and Validation Macro-F1')
plt.xlabel('Epoch')
plt.ylabel('Macro-F1 Score')
plt.legend()

plt.tight_layout()
plt.show()

print("\nVision Transformer training completed!")
print(f"Best validation Macro-F1: {best_val_f1:.4f}")

## 4. Results and Analysis

### 4.1 Experimental Setup
- Hardware: Kaggle GPU runtime
- Software: PyTorch, torchvision, scikit-learn, Hugging Face Transformers (learning rate scheduler)
- Metrics: Accuracy, Macro-F1, Inference Time
- Validation: stratified 80/20 hold-out split

### 4.2 Quantitative Analysis
We evaluate models on multiple metrics:
1. Macro-F1 scores for balanced performance
2. Per-class precision and recall
3. Confusion matrices
4. Inference speed benchmarks
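
Per-class precision and recall can be read directly off a confusion matrix. The sketch below assumes the scikit-learn convention (rows are true classes, columns are predicted classes); the 3x3 matrix is made up for illustration and the helper name is hypothetical.

```python
def per_class_precision_recall(cm):
    """Return a list of (precision, recall) tuples, one per class."""
    n = len(cm)
    stats = []
    for c in range(n):
        tp = cm[c][c]
        predicted_as_c = sum(cm[r][c] for r in range(n))  # column sum
        actually_c = sum(cm[c])                           # row sum
        precision = tp / predicted_as_c if predicted_as_c else 0.0
        recall = tp / actually_c if actually_c else 0.0
        stats.append((precision, recall))
    return stats

# Hypothetical confusion matrix for three classes.
cm = [
    [8, 2, 0],
    [1, 6, 3],
    [0, 0, 10],
]

stats = per_class_precision_recall(cm)
for c, (precision, recall) in enumerate(stats):
    print(f"class {c}: precision={precision:.3f}, recall={recall:.3f}")
```

Reading the matrix this way shows which diseases are confused with each other, which a single aggregate score hides.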

### 4.3 Comparative Analysis
Detailed comparison of models across:
- Overall performance metrics
- Disease-specific accuracy
- Computational efficiency
- Model robustness
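
For the computational-efficiency comparison, latency numbers are only comparable if they are measured consistently. The following framework-free sketch shows one way to do that, with warm-up calls and an optional synchronization hook. The helper `mean_latency_ms` and the toy workload are assumptions; on a GPU, `torch.cuda.synchronize` would be the natural hook to pass as `sync`.

```python
import time

def mean_latency_ms(fn, batches, warmup=2, sync=None):
    """Average wall-clock time of fn(batch) in milliseconds, after a few warm-up calls."""
    for batch in batches[:warmup]:
        fn(batch)  # warm-up: excluded from the timings
    if sync:
        sync()
    timings = []
    for batch in batches:
        start = time.perf_counter()
        fn(batch)
        if sync:
            sync()  # wait for asynchronous work to finish before stopping the clock
        timings.append(time.perf_counter() - start)
    return 1000.0 * sum(timings) / len(timings)

# Toy stand-in for a model forward pass.
toy_batches = [list(range(1000))] * 5
latency = mean_latency_ms(lambda batch: sum(v * v for v in batch), toy_batches)
print(f"mean latency: {latency:.4f} ms per batch")
```

Warm-up runs keep one-off costs such as lazy initialization out of the average, and the synchronization hook ensures that asynchronous device work is counted.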

The following section presents our comprehensive evaluation:

In [None]:
# Evaluation utilities
def evaluate_model(model, loader, name=""):
    model.eval()
    all_preds = []
    all_labels = []
    all_probs = []
    inference_times = []
    
    with torch.no_grad():
        for inputs, labels in tqdm(loader, desc=f"Evaluating {name}"):
            inputs, labels = inputs.to(device), labels.to(device)
            
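            # Measure inference time; CUDA kernels run asynchronously, so
            # synchronize before reading the clock on GPU
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            start_time = time.perf_counter()
            outputs = model(inputs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            inference_times.append(time.perf_counter() - start_time)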
            
            probs = torch.softmax(outputs, dim=1)
            _, preds = outputs.max(1)
            
            all_probs.extend(probs.cpu().numpy())
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    # Calculate metrics
    report = classification_report(all_labels, all_preds, 
                                target_names=list(label_names.values()), 
                                output_dict=True)
    cm = confusion_matrix(all_labels, all_preds)
    avg_inference_time = np.mean(inference_times) * 1000  # Convert to ms
    
    return report, cm, avg_inference_time, np.array(all_probs), np.array(all_labels)

# Evaluate CNN models
models = {
    'ResNet50': resnet_model,
    'EfficientNet': efficient_model
}

results = {}
for name, model in models.items():
    report, cm, inf_time, probs, labels = evaluate_model(model, val_loader, name)
    results[name] = {
        'report': report,
        'confusion_matrix': cm,
        'inference_time': inf_time,
        'probabilities': probs,
        'true_labels': labels
    }

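# Add YOLO results
yolo_metrics = yolo_model.val()
n_cls = len(label_names)
# Ultralytics appears to index its matrix as [predicted, true]; transpose to
# match sklearn's [true, predicted] layout and drop any extra background row
yolo_cm = np.asarray(yolo_metrics.confusion_matrix.matrix)[:n_cls, :n_cls].T
# results_dict may lack the F1/accuracy keys used below (classification runs
# report top-1/top-5 accuracy), so fill missing ones from the confusion matrix
tp = np.diag(yolo_cm).astype(float)
row_col_totals = yolo_cm.sum(axis=0) + yolo_cm.sum(axis=1)  # 2*TP + FP + FN per class
yolo_f1 = np.divide(2 * tp, row_col_totals, out=np.zeros_like(tp), where=row_col_totals > 0)
yolo_report = dict(yolo_metrics.results_dict)
yolo_report.setdefault('metrics/f1', float(yolo_f1.mean()))
yolo_report.setdefault('metrics/accuracy', float(tp.sum() / max(yolo_cm.sum(), 1)))
for cls_idx, score in enumerate(yolo_f1):
    yolo_report.setdefault(f'metrics/f1_class{cls_idx}', float(score))
results['YOLOv8'] = {
    'report': yolo_report,
    'confusion_matrix': yolo_cm.astype(int),
    'inference_time': yolo_metrics.speed['inference']
}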

# Create comparison visualizations
plt.figure(figsize=(15, 5))

# 1. Macro F1-Score Comparison
plt.subplot(131)
macro_f1_scores = []
for model in models:
    macro_f1_scores.append(results[model]['report']['macro avg']['f1-score'] * 100)
macro_f1_scores.append(results['YOLOv8']['report']['metrics/f1'] * 100)  # YOLO's F1

plt.bar(['ResNet50', 'EfficientNet', 'YOLOv8'], macro_f1_scores)
plt.title('Macro F1-Score Comparison')
plt.ylabel('Macro F1-Score (%)')
plt.ylim(0, 100)

# 2. Accuracy Comparison
plt.subplot(132)
accuracies = [results[model]['report']['accuracy'] * 100 for model in models]
accuracies.append(results['YOLOv8']['report']['metrics/accuracy'] * 100)
plt.bar(['ResNet50', 'EfficientNet', 'YOLOv8'], accuracies)
plt.title('Model Accuracy Comparison')
plt.ylabel('Accuracy (%)')
plt.ylim(0, 100)

# 3. Inference Time Comparison
plt.subplot(133)
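# CNN timings are per batch while Ultralytics reports per-image speed (ms),
# so normalize the CNN timings by the loader's batch size before plotting
inf_times = [results[name]['inference_time'] / val_loader.batch_size for name in models]
inf_times.append(results['YOLOv8']['inference_time'])
plt.bar(['ResNet50', 'EfficientNet', 'YOLOv8'], inf_times)
plt.title('Inference Time Comparison')
plt.ylabel('Time per image (ms)')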

plt.tight_layout()
plt.show()

# Plot per-class F1 scores
plt.figure(figsize=(12, 6))
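# Use the abbreviation in parentheses (e.g. "CBB"), or the full name when there is none
disease_names = [name.split('(')[-1].rstrip(')') for name in label_names.values()]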
x = np.arange(len(disease_names))
width = 0.25

for i, (name, model_results) in enumerate(results.items()):
    if name != 'YOLOv8':
        f1_scores = [model_results['report'][cls]['f1-score'] * 100 for cls in label_names.values()]
    else:
        # Assumes YOLO exposes per-class F1 under these keys; missing keys fall back to 0
        f1_scores = [model_results['report'].get(f'metrics/f1_class{cls_idx}', 0) * 100
                     for cls_idx in range(len(disease_names))]
    
    plt.bar(x + i*width, f1_scores, width, label=name)

plt.xlabel('Disease Categories')
plt.ylabel('F1-Score (%)')
plt.title('Per-Class F1 Scores')
plt.xticks(x + width, disease_names, rotation=45, ha='right')
plt.legend()
plt.tight_layout()
plt.show()

# Plot confusion matrices
# Tick labels use the abbreviation in parentheses; name[:3] would yield "Cas" for every disease
short_names = [n.split('(')[-1].rstrip(')') for n in label_names.values()]
fig, axes = plt.subplots(1, len(results), figsize=(20, 6))
for idx, (name, model_results) in enumerate(results.items()):
    # Cast to int so fmt='d' also works for float-valued matrices
    sns.heatmap(np.asarray(model_results['confusion_matrix']).astype(int),
                annot=True,
                fmt='d',
                ax=axes[idx],
                cmap='Blues',
                xticklabels=short_names,
                yticklabels=short_names)
    axes[idx].set_title(f'{name} Confusion Matrix')
    axes[idx].set_xlabel('Predicted')
    axes[idx].set_ylabel('True')
    plt.setp(axes[idx].get_xticklabels(), rotation=45, ha='right')

plt.tight_layout()
plt.show()

# Print detailed results with emphasis on macro-F1
print("\nDetailed Model Comparison:")
print("-" * 80)
comparison_data = []
for name, model_results in results.items():
    if name != 'YOLOv8':
        macro_f1 = model_results['report']['macro avg']['f1-score'] * 100
        accuracy = model_results['report']['accuracy'] * 100
        weighted_f1 = model_results['report']['weighted avg']['f1-score'] * 100
    else:
        macro_f1 = model_results['report']['metrics/f1'] * 100
        accuracy = model_results['report']['metrics/accuracy'] * 100
        # Fall back to macro-F1 when no weighted score is reported; scale to percent either way
        weighted_f1 = model_results['report'].get('metrics/f1_weighted', macro_f1 / 100) * 100
    
    comparison_data.append({
        'Model': name,
        'Macro F1-Score (%)': f"{macro_f1:.2f}",
        'Accuracy (%)': f"{accuracy:.2f}",
        'Weighted F1-Score (%)': f"{weighted_f1:.2f}",
        'Inference Time (ms)': f"{model_results['inference_time']:.2f}"
    })

# Display comparison table
comparison_df = pd.DataFrame(comparison_data)
print("\nModel Performance Summary:")
print(comparison_df.to_string(index=False))

# Save detailed results
comparison_df.to_csv('/kaggle/working/model_comparison_results.csv', index=False)
print("\nDetailed results saved to 'model_comparison_results.csv'")

# Print per-class performance
print("\nPer-Class Performance:")
print("-" * 80)
for name, model_results in results.items():
    if name != 'YOLOv8':
        print(f"\n{name} Per-Class Metrics:")
        class_metrics = pd.DataFrame({
            'F1-Score (%)': {k: v['f1-score']*100 for k, v in model_results['report'].items() 
                            if k in label_names.values()},
            'Precision (%)': {k: v['precision']*100 for k, v in model_results['report'].items() 
                             if k in label_names.values()},
            'Recall (%)': {k: v['recall']*100 for k, v in model_results['report'].items() 
                          if k in label_names.values()}
        })
        print(class_metrics.round(2))

## 5. Discussion and Future Work

### 5.1 Key Findings
The evaluation in Section 4 is intended to inform the following points:
1. Model Performance Trade-offs
   - Accuracy vs. computational efficiency
   - Architecture-specific strengths
   - Disease-specific detection capabilities

2. Technical Contributions
   - Effective training strategies
   - Optimal hyperparameter configurations
   - Performance optimization techniques

### 5.2 Limitations
Current limitations include:
- Dataset constraints
- Computational resource requirements
- Real-world deployment considerations

### 5.3 Future Directions
Potential areas for future research:
1. Model Improvements
   - Ensemble methods
   - Architecture modifications
   - Optimization techniques

2. Practical Applications
   - Mobile deployment
   - Real-time processing
   - Edge device implementation
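
As one possible starting point for the ensemble direction, the sketch below averages class probabilities from several models (soft voting) and takes the arg-max of the mean. The `soft_vote` helper and the toy probabilities are assumptions for illustration; in this notebook the inputs could be the `probabilities` arrays that `evaluate_model` stores in `results`, provided every model was evaluated on the same samples in the same order.

```python
def soft_vote(prob_sets):
    """Average per-model probability rows; return (avg_probs, predicted_labels)."""
    n_models = len(prob_sets)
    n_samples = len(prob_sets[0])
    averaged, preds = [], []
    for i in range(n_samples):
        rows = [probs[i] for probs in prob_sets]
        # Column-wise mean across models for sample i
        mean_row = [sum(col) / n_models for col in zip(*rows)]
        averaged.append(mean_row)
        preds.append(max(range(len(mean_row)), key=mean_row.__getitem__))
    return averaged, preds

# Toy probabilities for 2 samples and 3 classes from two hypothetical models
model_a = [[0.6, 0.3, 0.1], [0.2, 0.5, 0.3]]
model_b = [[0.2, 0.7, 0.1], [0.1, 0.3, 0.6]]
avg, preds = soft_vote([model_a, model_b])
print(preds)
```

Soft voting needs no retraining, so it is a cheap first experiment before heavier options such as stacking. The cost is that inference time grows with the number of ensemble members, which matters for the mobile and edge targets listed above.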

### 5.4 Conclusion
This study applies deep learning to cassava disease classification and compares architectures on macro-F1, accuracy, and inference time, so that both class imbalance and deployment cost are reflected in the comparison. The results can inform the design of similar systems in agricultural applications.

### References
1. He, K., et al. (2016). Deep Residual Learning for Image Recognition. CVPR.
2. Tan, M., & Le, Q. (2019). EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks. ICML.
3. Jocher, G., et al. (2023). Ultralytics YOLOv8.