# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
from tqdm import tqdm

# Configurable CNN: a stack of Conv-ReLU-MaxPool stages, global average pooling, and a linear classifier.
class CustomCNN(nn.Module):
    """Configurable convolutional image classifier.

    The feature extractor has one stage per ``(nodes[i], kernels[i])`` pair:
    ``Conv2d(k x k, padding=k // 2) -> ReLU -> MaxPool2d(2)``. The resulting
    feature maps are globally average-pooled to 1x1, flattened, and mapped to
    ``num_classes`` logits by a single linear layer. The classifier head
    therefore does not depend on the input's spatial size, provided every
    pooling stage still receives a map of at least 2x2.

    Args:
        input_channels: number of channels in the input images.
        nodes: output channels of each conv stage.
        kernels: kernel size of each conv stage; must have the same length as ``nodes``.
        num_classes: length of the output logit vector.
    """

    def __init__(self, input_channels: int, nodes: list, kernels: list, num_classes: int = 10):
        super().__init__()
        assert len(nodes) == len(kernels), "nodes and kernels must have the same length"

        # Stage i reads the channels produced by stage i-1 (or the raw input for i == 0).
        stage_inputs = [input_channels, *nodes[:-1]]
        modules = []
        for c_in, c_out, size in zip(stage_inputs, nodes, kernels):
            modules.extend((
                nn.Conv2d(c_in, c_out, kernel_size=size, padding=size // 2),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ))

        self.conv_layers = nn.Sequential(*modules)
        self.flatten = nn.Flatten()
        self.fc = nn.Linear(nodes[-1], num_classes)

    def forward(self, x):
        features = self.conv_layers(x)
        pooled = F.adaptive_avg_pool2d(features, (1, 1))
        return self.fc(self.flatten(pooled))


# Dataset class
class FacialExpressionDataset(Dataset):
    """Facial-expression images laid out as ``root_dir/<person_id>/<Emotion>.jpg``.

    Every sub-directory of ``root_dir`` is scanned for ``.jpg`` files whose
    stem is one of ``emotion_labels`` (e.g. ``Happy.jpg``). The sample label is
    that emotion's index in ``emotion_labels``. Other files, unknown emotion
    names and non-directory entries directly under ``root_dir`` are ignored.

    Args:
        root_dir: directory containing one sub-directory per person.
        transform: optional callable applied to each PIL image in ``__getitem__``.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.samples = []
        self.emotion_labels = ['Anger', 'Contempt', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprised']
        label_index = {name: idx for idx, name in enumerate(self.emotion_labels)}

        # os.listdir order is filesystem-dependent; sort both levels so the sample
        # order, and therefore any seeded random split over it, is reproducible.
        for person_id in sorted(os.listdir(root_dir)):
            person_path = os.path.join(root_dir, person_id)
            if not os.path.isdir(person_path):
                continue

            for emotion_file in sorted(os.listdir(person_path)):
                if not emotion_file.endswith('.jpg'):
                    continue

                emotion_name = emotion_file[:-len('.jpg')]
                label = label_index.get(emotion_name)
                if label is not None:
                    self.samples.append((os.path.join(person_path, emotion_file), label))

        print(f"Loaded {len(self.samples)} images")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``; ``image`` is transformed if a transform was given."""
        img_path, label = self.samples[idx]
        # The context manager releases the file handle promptly. Forcing RGB makes
        # every sample have the same channel count regardless of how the JPEG was
        # encoded: PIL can decode some JPEGs as single-channel 'L' or as CMYK.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        if self.transform:
            image = self.transform(image)

        return image, label


def _build_transforms(image_size, grayscale, augment):
    """Return ``(train_transform, val_transform)`` for the given image settings.

    Both pipelines apply the same deterministic steps:
    1. optional conversion to 1-channel grayscale;
    2. resize to ``image_size x image_size``;
    3. ``ToTensor``;
    4. per-channel ``Normalize(mean=0.5, std=0.5)``.

    When ``augment`` is true, the training pipeline also applies a random
    horizontal flip, rotation and colour jitter between the resize and the
    tensor conversion.
    """
    n_channels = 1 if grayscale else 3

    def pil_steps():
        # Fresh list per call so the two pipelines never share a mutable list.
        steps = [transforms.Grayscale(num_output_channels=1)] if grayscale else []
        steps.append(transforms.Resize((image_size, image_size)))
        return steps

    def tensor_steps():
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5] * n_channels, std=[0.5] * n_channels),
        ]

    augmentation = []
    if augment:
        augmentation = [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(10),
            transforms.ColorJitter(brightness=0.2, contrast=0.2),
        ]

    train_transform = transforms.Compose(pil_steps() + augmentation + tensor_steps())
    val_transform = transforms.Compose(pil_steps() + tensor_steps())
    return train_transform, val_transform


def preprocessing(
    data_dir='archive/images',
    image_size=64,
    grayscale=False,  # Set to True for grayscale, False for RGB
    batch_size=32,
    train_split=0.8,
    augment=True,
    random_seed=42
):
    """
    Preprocess the facial expression dataset.

    Parameters:
    - data_dir: path to the images folder
    - image_size: resize images to this size (square)
    - grayscale: if True, convert to grayscale (1 channel), else RGB (3 channels)
    - batch_size: batch size for DataLoader
    - train_split: proportion of data for training (rest for validation)
    - augment: whether to apply data augmentation (training split only)
    - random_seed: seeds torch/numpy globally and the train/validation split

    Returns:
    - (train_loader, val_loader, input_channels, num_classes)

    Raises:
    - ValueError: if the split leaves no training samples (e.g. no images found).
    """
    import copy  # local: only needed to clone the dataset for validation

    torch.manual_seed(random_seed)
    np.random.seed(random_seed)

    train_transform, val_transform = _build_transforms(image_size, grayscale, augment)

    # Load full dataset (with the training transform)
    full_dataset = FacialExpressionDataset(data_dir, transform=train_transform)

    # Split into train and validation
    train_size = int(train_split * len(full_dataset))
    val_size = len(full_dataset) - train_size
    if train_size == 0:
        raise ValueError(
            f"train_split={train_split} leaves no training samples out of "
            f"{len(full_dataset)} images found under {data_dir!r}"
        )

    train_dataset, val_subset = torch.utils.data.random_split(
        full_dataset, [train_size, val_size],
        generator=torch.Generator().manual_seed(random_seed)
    )

    # Both subsets from random_split wrap the SAME dataset object. Changing its
    # transform through the validation subset would also strip augmentation from
    # the training subset. Give the validation indices their own shallow copy
    # instead: it shares the sample list but has its own transform.
    val_base = copy.copy(full_dataset)
    val_base.transform = val_transform
    val_dataset = torch.utils.data.Subset(val_base, val_subset.indices)

    # Create DataLoaders
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    input_channels = 1 if grayscale else 3
    num_classes = len(full_dataset.emotion_labels)

    print(f"\nDataset Info:")
    print(f"  Total samples: {len(full_dataset)}")
    print(f"  Training samples: {train_size}")
    print(f"  Validation samples: {val_size}")
    print(f"  Input channels: {input_channels}")
    print(f"  Number of classes: {num_classes}")
    print(f"  Image size: {image_size}x{image_size}")

    return train_loader, val_loader, input_channels, num_classes


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=None):
    """Make one pass over ``loader``; update weights only if ``optimizer`` is given.

    When ``desc`` is given the batches are wrapped in a tqdm progress bar that
    shows the current batch loss.

    Returns:
        (mean per-sample loss, accuracy in percent); both NaN if ``loader``
        yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, seen = 0.0, 0, 0

    batches = tqdm(loader, desc=desc) if desc is not None else loader
    with torch.set_grad_enabled(training):
        for images, labels in batches:
            images, labels = images.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            batch_size = labels.size(0)
            # CrossEntropyLoss() averages over the batch. Re-weight by batch size
            # so a short final batch does not count as much as a full one.
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            seen += batch_size

            if desc is not None:
                batches.set_postfix({'loss': f'{loss.item():.4f}'})

    if seen == 0:
        return float('nan'), float('nan')
    return total_loss / seen, 100 * correct / seen


def train_model(
    model,
    train_loader,
    val_loader,
    num_epochs=20,
    learning_rate=0.001,
    device='cuda' if torch.cuda.is_available() else 'cpu'
):
    """Train the model with Adam + cross-entropy and return training history.

    Each epoch trains on ``train_loader`` and then evaluates on ``val_loader``
    without gradients.

    Returns:
        dict with lists 'train_loss', 'train_acc', 'val_loss' and 'val_acc',
        one entry per epoch. Losses are mean per-sample cross-entropy and
        accuracies are percentages. A metric is NaN for an epoch in which its
        loader yielded no samples.
    """
    model = model.to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    print(f"\nTraining on {device}")
    print("=" * 60)

    for epoch in range(num_epochs):
        tag = f'Epoch {epoch+1}/{num_epochs}'

        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device,
            optimizer=optimizer, desc=f'{tag} [Train]'
        )
        val_loss, val_acc = _run_epoch(model, val_loader, criterion, device)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)

        print(f'{tag}:')
        print(f'  Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%')
        print(f'  Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%')
        print('-' * 60)

    return history


def evaluate_model(model, val_loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Run ``model`` over ``val_loader`` without gradients and collect its predictions.

    Returns:
        ``(predictions, labels)`` as NumPy arrays in loader order. ``predictions``
        holds the arg-max index of the model output for each sample.
    """
    model = model.to(device)
    model.eval()

    pred_chunks, label_chunks = [], []
    with torch.no_grad():
        for batch_images, batch_labels in val_loader:
            logits = model(batch_images.to(device))
            pred_chunks.append(torch.max(logits.data, 1)[1].cpu().numpy())
            label_chunks.append(batch_labels.numpy())

    def _flatten(chunks):
        # Flatten per-batch arrays into one sequence of scalars before building the result.
        return np.array([value for chunk in chunks for value in chunk])

    return _flatten(pred_chunks), _flatten(label_chunks)


def plot_results(history, predictions, true_labels, emotion_labels):
    """Plot training curves and a confusion matrix, then print a classification report.

    Parameters:
    - history: dict with per-epoch lists 'train_loss', 'val_loss', 'train_acc'
      and 'val_acc'
    - predictions: predicted class indices, one per sample
    - true_labels: ground-truth class indices, aligned with ``predictions``
    - emotion_labels: class names; position ``i`` names class index ``i``
    """
    # Fix the class set explicitly. Otherwise a class absent from both the true
    # and predicted labels (easy with a small validation split) is dropped: the
    # confusion matrix shrinks so its tick labels no longer line up, and
    # classification_report raises on the target_names length mismatch.
    class_ids = list(range(len(emotion_labels)))
    # Training logs count epochs from 1 ("Epoch 1/N"); match that on the x-axis.
    epochs = range(1, len(history['train_loss']) + 1)

    fig, (loss_ax, acc_ax, cm_ax) = plt.subplots(1, 3, figsize=(18, 5))

    # Plot training history
    loss_ax.plot(epochs, history['train_loss'], label='Train Loss', marker='o')
    loss_ax.plot(epochs, history['val_loss'], label='Val Loss', marker='s')
    loss_ax.set_xlabel('Epoch')
    loss_ax.set_ylabel('Loss')
    loss_ax.set_title('Training and Validation Loss')
    loss_ax.legend()
    loss_ax.grid(True)

    acc_ax.plot(epochs, history['train_acc'], label='Train Acc', marker='o')
    acc_ax.plot(epochs, history['val_acc'], label='Val Acc', marker='s')
    acc_ax.set_xlabel('Epoch')
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.set_title('Training and Validation Accuracy')
    acc_ax.legend()
    acc_ax.grid(True)

    # Plot confusion matrix
    cm = confusion_matrix(true_labels, predictions, labels=class_ids)
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=emotion_labels,
                yticklabels=emotion_labels,
                cbar_kws={'label': 'Count'},
                ax=cm_ax)
    cm_ax.set_xlabel('Predicted')
    cm_ax.set_ylabel('True')
    cm_ax.set_title('Confusion Matrix')

    fig.tight_layout()
    plt.show()

    # Print classification report
    print("\nClassification Report:")
    print("=" * 60)
    print(classification_report(true_labels, predictions,
                                labels=class_ids,
                                target_names=emotion_labels,
                                digits=3,
                                zero_division=0))


# =============================================================================
# MAIN EXECUTION
# =============================================================================

if __name__ == "__main__":
    # ------------------------------------------------------------------
    # Configurable parameters
    # ------------------------------------------------------------------
    IMAGE_SIZE = 64          # images are resized to IMAGE_SIZE x IMAGE_SIZE
    GRAYSCALE = False        # True -> 1 input channel, False -> 3 (RGB)
    BATCH_SIZE = 32          # samples per mini-batch
    NUM_EPOCHS = 20          # passes over the training split
    LEARNING_RATE = 0.001    # optimizer learning rate
    TRAIN_SPLIT = 0.8        # fraction of images used for training
    AUGMENT = True           # enable training-time data augmentation

    # CNN architecture: one conv stage per (filters, kernel size) pair
    CNN_NODES = [32, 64, 128]
    CNN_KERNELS = [3, 3, 5]

    RULE = "=" * 60
    print(RULE, "FACIAL EXPRESSION CLASSIFICATION", RULE, sep="\n")

    # Step 1: build the train/validation loaders
    print("\n[1/4] Preprocessing data...")
    train_loader, val_loader, input_channels, num_classes = preprocessing(
        data_dir='archive/images',
        image_size=IMAGE_SIZE,
        grayscale=GRAYSCALE,
        batch_size=BATCH_SIZE,
        train_split=TRAIN_SPLIT,
        augment=AUGMENT,
        random_seed=42,
    )

    # Step 2: instantiate the network and report its size
    print("\n[2/4] Creating model...")
    model = CustomCNN(
        input_channels=input_channels,
        nodes=CNN_NODES,
        kernels=CNN_KERNELS,
        num_classes=num_classes,
    )

    param_counts = [(p.numel(), p.requires_grad) for p in model.parameters()]
    total_params = sum(count for count, _ in param_counts)
    trainable_params = sum(count for count, needs_grad in param_counts if needs_grad)
    print(f"Total parameters: {total_params:,}")
    print(f"Trainable parameters: {trainable_params:,}")
    print("\nModel architecture:")
    print(model)

    # Step 3: fit the model
    print("\n[3/4] Training model...")
    history = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        num_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
    )

    # Step 4: evaluate on the validation split and visualise
    print("\n[4/4] Evaluating model...")
    predictions, true_labels = evaluate_model(model, val_loader)

    final_val_acc = history['val_acc'][-1]
    print(f"\nFinal Validation Accuracy: {final_val_acc:.2f}%")

    emotion_labels = ['Anger', 'Contempt', 'Disgust', 'Fear', 'Happy', 'Neutral', 'Sad', 'Surprised']
    plot_results(history, predictions, true_labels, emotion_labels)

    print("\n" + RULE, "TRAINING COMPLETE!", RULE, sep="\n")