In [None]:
import torch
from torch.utils.data import Dataset, DataLoader, random_split
from PIL import Image
import pandas as pd
import os
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
import numpy as np

class AMHCDDataset(Dataset):
    """Image dataset indexed by a header-less two-column CSV file.

    Every CSV row holds an image path (joined onto ``root_dir`` when the
    sample is loaded) and a class label. Distinct labels are sorted and
    numbered from zero; the mapping is exposed in both directions through
    ``label_map`` (label -> index) and ``idx_to_label`` (index -> label).
    """

    def __init__(self, csv_file, root_dir, transform=None):
        """Read the CSV index and build the label <-> index lookups.

        Args:
            csv_file: path of the CSV file (no header row; columns are
                image path and label).
            root_dir: directory the image paths are relative to.
            transform: optional callable applied to each loaded PIL image.
        """
        self.labels_df = pd.read_csv(csv_file, header=None, names=['image_path', 'label'])
        self.root_dir = root_dir
        self.transform = transform
        # Sorting the distinct labels makes the numbering deterministic
        ordered_labels = sorted(self.labels_df['label'].unique())
        self.label_map = {name: position for position, name in enumerate(ordered_labels)}
        self.idx_to_label = dict(enumerate(ordered_labels))

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.labels_df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label_index)``.

        The image is opened in grayscale mode ('L'); the transform, when
        one was supplied, is applied before returning.
        """
        record = self.labels_df.iloc[idx]
        full_path = os.path.join(self.root_dir, record['image_path'])
        image = Image.open(full_path).convert('L')  # Grayscale
        label_idx = self.label_map[record['label']]

        if self.transform:
            image = self.transform(image)

        return image, label_idx

class LeNet5Tifinagh(nn.Module):
    """LeNet-5 style CNN for single-channel character images.

    The feature extractor is three 5x5 convolutions (6, 16 and 120 output
    channels, no padding) with Tanh activations, the first two each followed
    by 2x2 average pooling. The classifier flattens the feature maps and
    applies Linear -> Tanh -> Linear, producing ``num_classes`` raw logits.

    Args:
        num_classes: number of output logits.
        input_size: side length (int, square input) or ``(height, width)``
            of the images the network will receive. Defaults to 64. It
            determines the width of the first fully connected layer.

    Raises:
        ValueError: if either input side is smaller than 32, because the
            third 5x5 convolution would no longer fit its input.
    """

    # Smallest side for which C5 still fits: 32 -> 28 -> 14 -> 10 -> 5 -> 1
    MIN_INPUT_SIDE = 32

    def __init__(self, num_classes, input_size=64):
        super().__init__()

        # Accept a single side length as shorthand for a square input
        if isinstance(input_size, int):
            input_size = (input_size, input_size)
        height, width = (int(side) for side in input_size)
        if min(height, width) < self.MIN_INPUT_SIDE:
            raise ValueError(
                f"input_size must be at least "
                f"{self.MIN_INPUT_SIDE}x{self.MIN_INPUT_SIDE}, got {height}x{width}"
            )
        self.input_size = (height, width)

        # Convolutional layers (feature extraction)
        self.conv_layers = nn.Sequential(
            # C1: First convolutional layer
            nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, stride=1, padding=0),
            nn.Tanh(),  # Original LeNet-5 uses Tanh

            # S2: First subsampling layer (average pooling)
            nn.AvgPool2d(kernel_size=2, stride=2),

            # C3: Second convolutional layer
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5, stride=1, padding=0),
            nn.Tanh(),

            # S4: Second subsampling layer
            nn.AvgPool2d(kernel_size=2, stride=2),

            # C5: Third convolutional layer (fully connected only for 32x32 input)
            nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5, stride=1, padding=0),
            nn.Tanh()
        )

        # Spatial size per side through the conv stack (kernel=5, no padding):
        #   64x64 input: 64 -> 60 -> 30 -> 26 -> 13 -> 9  => 120 * 9 * 9 features
        #   32x32 input: 32 -> 28 -> 14 -> 10 -> 5 -> 1   => 120 features
        # Measured with a dummy forward pass so any valid input_size works.
        self.conv_output_size = self._get_conv_output_size()

        # Fully connected layers
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            # F6: First fully connected layer
            nn.Linear(self.conv_output_size, 84),
            nn.Tanh(),

            # Output layer
            nn.Linear(84, num_classes)
        )

    def _get_conv_output_size(self):
        """Return the flattened feature count the conv stack yields per sample."""
        height, width = self.input_size
        with torch.no_grad():
            # Batch of one, so numel() is exactly the per-sample feature count
            dummy_input = torch.zeros(1, 1, height, width)
            dummy_output = self.conv_layers(dummy_input)
            return dummy_output.numel()

    def forward(self, x):
        """Map a ``(batch, 1, H, W)`` tensor to ``(batch, num_classes)`` logits.

        ``H`` and ``W`` must match the ``input_size`` given at construction,
        otherwise the first fully connected layer rejects the flattened features.
        """
        x = self.conv_layers(x)
        x = self.fc_layers(x)
        return x

# Preprocessing pipeline: fixed 64x64 size, tensor conversion, then
# (x - 0.5) / 0.5 so pixel values end up in [-1, 1]
transform = transforms.Compose([
    transforms.Resize(size=(64, 64)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

# Dataset backed by the CSV index that sits inside the image directory
dataset = AMHCDDataset(
    'amhcd-data-64/labels-map.csv',
    'amhcd-data-64',
    transform=transform,
)

# 80/20 train/validation split; validation takes whatever training leaves over
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, lengths=[train_size, val_size])

# Only the training loader reshuffles; both load in the main process
train_loader = DataLoader(
    train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
)

# One output logit per distinct label found in the CSV
num_classes = len(dataset.label_map)
print(f"Number of classes: {num_classes}")
print(f"Training samples: {train_size}, Validation samples: {val_size}")

# Use the GPU when one is available, otherwise stay on the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Build the network directly on the selected device
model = LeNet5Tifinagh(num_classes=num_classes).to(device)

# Show the layer structure
print("\nLeNet-5 Architecture:")
print(model)

# Parameter counts: all parameters, then only those that receive gradients
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    weight.numel() for weight in model.parameters() if weight.requires_grad
)
print(f"\nTotal parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")

# Cross-entropy over raw logits, optimised with momentum SGD
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(params=model.parameters(), lr=0.01, momentum=0.9)

# Training function
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: network to train; switched to training mode.
        dataloader: sized iterable yielding ``(images, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        optimizer: optimizer that updates ``model``'s parameters.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses and the
        accuracy in percent. ``(0.0, 0.0)`` when the dataloader yields no
        samples.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (images, labels) in enumerate(dataloader):
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        # detach() replaces the legacy `.data` access: predictions need no graph
        _, predicted = torch.max(outputs.detach(), 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

        if batch_idx % 50 == 0:
            print(f'  Batch [{batch_idx}/{len(dataloader)}], Loss: {loss.item():.4f}')

    # Nothing was processed: report zeros instead of dividing by zero
    if total == 0:
        return 0.0, 0.0

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc

# Validation function
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without updating its weights.

    Args:
        model: network to evaluate; switched to evaluation mode.
        dataloader: sized iterable yielding ``(images, labels)`` batches.
        criterion: loss callable taking ``(outputs, labels)``.
        device: device every batch is moved to before the forward pass.

    Returns:
        ``(epoch_loss, epoch_acc, all_predictions, all_labels)``: the mean of
        the per-batch losses, the accuracy in percent, and the flat lists of
        predicted and true class indices in iteration order.
        ``(0.0, 0.0, [], [])`` when the dataloader yields no samples.
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []

    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

            # Moved to CPU so the lists can be handed to NumPy-based tooling
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Nothing was processed: report zeros instead of dividing by zero
    if total == 0:
        return 0.0, 0.0, all_predictions, all_labels

    epoch_loss = running_loss / len(dataloader)
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc, all_predictions, all_labels

# Training loop
num_epochs = 10  # Slightly more epochs for LeNet-5

# Per-epoch metric histories, one entry appended after every epoch
train_losses, train_accuracies = [], []
val_losses, val_accuracies = [], []

print("\nStarting LeNet-5 Training...")
print("=" * 60)

for epoch in range(num_epochs):
    print(f'\nEpoch [{epoch+1}/{num_epochs}]')

    # Optimise on the training split, then score the validation split
    train_loss, train_acc = train_epoch(
        model, train_loader, criterion, optimizer, device
    )
    val_loss, val_acc, val_predictions, val_labels = validate_epoch(
        model, val_loader, criterion, device
    )

    # Record this epoch's numbers
    train_losses += [train_loss]
    train_accuracies += [train_acc]
    val_losses += [val_loss]
    val_accuracies += [val_acc]

    print(f'Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
    print(f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
    print('-' * 50)

# Plot training curves
def plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies):
    """Show the loss and accuracy histories side by side.

    Args:
        train_losses: per-epoch training losses.
        val_losses: per-epoch validation losses.
        train_accuracies: per-epoch training accuracies, in percent.
        val_accuracies: per-epoch validation accuracies, in percent.

    The left panel plots the losses, the right panel the accuracies (y axis
    fixed to 0-100). Epochs are numbered from 1 on the x axis. The figure is
    displayed with ``plt.show()``; nothing is returned.
    """
    plt.style.use('seaborn-v0_8')
    fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(15, 6))

    epochs = range(1, len(train_losses) + 1)

    def draw_panel(axis, train_series, val_series, train_label, val_label, title, y_label):
        # Training curve in blue with circles, validation in red with squares
        axis.plot(epochs, train_series, 'b-', label=train_label, linewidth=2, marker='o')
        axis.plot(epochs, val_series, 'r-', label=val_label, linewidth=2, marker='s')
        axis.set_title(title, fontsize=14, fontweight='bold')
        axis.set_xlabel('Epoch', fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.legend(fontsize=11)
        axis.grid(True, alpha=0.3)
        axis.set_xlim(1, len(train_series))

    # Loss curve
    draw_panel(
        loss_ax, train_losses, val_losses,
        'Training Loss', 'Validation Loss',
        'LeNet-5 Training and Validation Loss', 'Loss',
    )

    # Accuracy curve
    draw_panel(
        acc_ax, train_accuracies, val_accuracies,
        'Training Accuracy', 'Validation Accuracy',
        'LeNet-5 Training and Validation Accuracy', 'Accuracy (%)',
    )
    acc_ax.set_ylim(0, 100)

    plt.tight_layout()
    plt.show()

# Plot confusion matrix
def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw an annotated confusion-matrix heatmap and return the matrix.

    Args:
        y_true: true class labels, one per sample.
        y_pred: predicted class labels, aligned with ``y_true``.
        class_names: tick labels for the rows and columns, in matrix order.

    Returns:
        The confusion matrix computed by ``confusion_matrix(y_true, y_pred)``
        (rows: true label, columns: predicted label).

    Every cell is annotated with its count and with its share of the row
    total in percent. A row with no true samples (a class that only occurs
    in ``y_pred``) is shown as 0.0% rather than NaN.
    """
    cm = confusion_matrix(y_true, y_pred)

    plt.figure(figsize=(14, 12))

    # Row-wise percentages; rows that sum to zero are left at 0 so the
    # division neither warns nor produces NaN annotations
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_percent = np.divide(
        cm.astype('float'),
        row_totals,
        out=np.zeros(cm.shape, dtype=float),
        where=row_totals != 0,
    ) * 100

    # Build the "count\n(percent%)" labels directly as strings so no
    # fixed-width string dtype can truncate them
    annotations = np.array([
        [f'{count}\n({percent:.1f}%)' for count, percent in zip(count_row, percent_row)]
        for count_row, percent_row in zip(cm, cm_percent)
    ])

    sns.heatmap(cm, annot=annotations, fmt='', cmap='Blues',
                xticklabels=class_names, yticklabels=class_names,
                cbar_kws={'label': 'Number of Samples'})

    plt.title('LeNet-5 Confusion Matrix\n(Count and Percentage)', fontsize=16, fontweight='bold', pad=20)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.ylabel('True Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    return cm

# Generate detailed evaluation report
def generate_evaluation_report(y_true, y_pred, class_names):
    """Print a multi-section evaluation report and return the metrics dict.

    Args:
        y_true: true class labels, one per sample.
        y_pred: predicted class labels, aligned with ``y_true``.
        class_names: display names passed to ``classification_report`` as
            ``target_names``.

    Returns:
        The dictionary produced by ``classification_report(...,
        output_dict=True)``.

    Printed sections: the textual classification report, overall / macro /
    weighted summary metrics, a per-class table sorted by descending
    F1-score, and the best and worst class by F1-score.
    """
    report = classification_report(y_true, y_pred, target_names=class_names, output_dict=True)

    banner = "=" * 80
    print("\n" + banner)
    print("LENET-5 TIFINAGH CHARACTER RECOGNITION - EVALUATION REPORT")
    print(banner)

    print("\nDETAILED CLASSIFICATION REPORT:")
    print("-" * 50)
    print(classification_report(y_true, y_pred, target_names=class_names))

    # Aggregate entries of the report dictionary
    accuracy = report['accuracy']
    macro_avg = report['macro avg']
    weighted_avg = report['weighted avg']

    print("SUMMARY METRICS:")
    print("-" * 30)
    print(f"Overall Accuracy: {accuracy:.4f} ({accuracy*100:.2f}%)")
    print(f"Macro Average Precision: {macro_avg['precision']:.4f}")
    print(f"Macro Average Recall: {macro_avg['recall']:.4f}")
    print(f"Macro Average F1-Score: {macro_avg['f1-score']:.4f}")
    print(f"Weighted Average Precision: {weighted_avg['precision']:.4f}")
    print(f"Weighted Average Recall: {weighted_avg['recall']:.4f}")
    print(f"Weighted Average F1-Score: {weighted_avg['f1-score']:.4f}")

    print("\nPER-CLASS PERFORMANCE ANALYSIS:")
    print("-" * 40)

    # One row per class that has an entry in the report, best F1 first
    per_class = [
        {
            'class': name,
            'precision': report[name]['precision'],
            'recall': report[name]['recall'],
            'f1': report[name]['f1-score'],
            'support': report[name]['support'],
        }
        for name in class_names
        if name in report
    ]
    class_metrics = sorted(per_class, key=lambda entry: entry['f1'], reverse=True)

    print(f"{'Class':<15} {'Precision':<10} {'Recall':<10} {'F1-Score':<10} {'Support':<10}")
    print("-" * 65)
    for entry in class_metrics:
        print(f"{entry['class']:<15} {entry['precision']:<10.3f} {entry['recall']:<10.3f} "
              f"{entry['f1']:<10.3f} {entry['support']:<10}")

    # The sorted list puts the best class first and the worst last
    best_class, worst_class = class_metrics[0], class_metrics[-1]

    print(f"\nBEST PERFORMING CLASS: {best_class['class']} (F1: {best_class['f1']:.3f})")
    print(f"WORST PERFORMING CLASS: {worst_class['class']} (F1: {worst_class['f1']:.3f})")

    print(banner)

    return report

# Final evaluation
print("\n" + "="*80)
print("FINAL EVALUATION - LENET-5 ON TIFINAGH CHARACTERS")
print("="*80)

# Get final predictions on the validation set. validate_epoch already puts
# the model in eval mode, disables gradients and collects predictions and
# labels on the CPU, so it is reused here instead of repeating that loop.
final_val_loss, final_val_acc, final_predictions, final_labels = validate_epoch(
    model, val_loader, criterion, device
)

# Class names ordered by class index, so position i names class i
class_names = [dataset.idx_to_label[i] for i in range(num_classes)]

# Create all visualizations
print("Generating visualizations...")

# 1. Training curves
plot_training_curves(train_losses, val_losses, train_accuracies, val_accuracies)

# 2. Confusion matrix
cm = plot_confusion_matrix(final_labels, final_predictions, class_names)

# 3. Evaluation report
report = generate_evaluation_report(final_labels, final_predictions, class_names)

# Training summary
print("\nTRAINING SUMMARY:")
print(f"Best validation accuracy: {max(val_accuracies):.2f}%")
print(f"Final validation accuracy: {val_accuracies[-1]:.2f}%")
print(f"Lowest validation loss: {min(val_losses):.4f}")
print(f"Total parameters: {total_params:,}")

Number of classes: 33
Training samples: 22545, Validation samples: 5637
Using device: cpu

LeNet-5 Architecture:
LeNet5Tifinagh(
  (conv_layers): Sequential(
    (0): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))
    (1): Tanh()
    (2): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (3): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
    (4): Tanh()
    (5): AvgPool2d(kernel_size=2, stride=2, padding=0)
    (6): Conv2d(16, 120, kernel_size=(5, 5), stride=(1, 1))
    (7): Tanh()
  )
  (fc_layers): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=9720, out_features=84, bias=True)
    (2): Tanh()
    (3): Linear(in_features=84, out_features=33, bias=True)
  )
)

Total parameters: 870,061
Trainable parameters: 870,061

Starting LeNet-5 Training...

Epoch [1/10]
  Batch [0/705], Loss: 3.4762
  Batch [50/705], Loss: 2.6460
  Batch [100/705], Loss: 2.3204
  Batch [150/705], Loss: 1.4223
  Batch [200/705], Loss: 1.4796
  Batch [250/705], Loss: 1.2987
  