# Complete Data Handling for Visual Emotion Recognition

This notebook consolidates the data handling code from the src/data directory for visual emotion recognition into a single file.

## Components Included:
1. **EmotionDataset** - Custom PyTorch dataset for emotion recognition
2. **Data Transformations** - Image preprocessing and augmentation
3. **Advanced Augmentations** - Albumentations integration
4. **Data Loaders** - Training, validation and test data loaders
5. **Dataset Statistics** - Computing dataset statistics and visualization


In [None]:
import os
import sys
from pathlib import Path
import pandas as pd
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from collections import Counter
import warnings
warnings.filterwarnings('ignore')

# PyTorch imports
import torch
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms

# Advanced augmentation
try:
    import albumentations as A
    from albumentations.pytorch import ToTensorV2
    ALBUMENTATIONS_AVAILABLE = True
    print("Albumentations available for advanced augmentations")
except ImportError:
    ALBUMENTATIONS_AVAILABLE = False
    print("Albumentations not available, using torchvision transforms only")

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'Using device: {device}')

## 1. Custom Emotion Dataset Class

In [None]:
class EmotionDataset(Dataset):
    """
    Custom dataset class for emotion recognition that supports both grayscale and RGB images.
    
    This dataset is designed to work with both traditional CNN models (grayscale) and 
    transfer learning models (RGB) for visual emotion recognition.
    """
    
    def __init__(self, dataframe, root_dir, transform=None, label_map=None, rgb=False):
        """
        Initialize the emotion dataset.
        
        Args:
            dataframe (pd.DataFrame): DataFrame containing image paths and labels
            root_dir (str or Path): Root directory containing images
            transform (callable, optional): Optional transform to be applied on images
            label_map (dict, optional): Dictionary mapping emotion names to indices
            rgb (bool): If True, convert images to RGB; if False, use grayscale
        """
        self.df = dataframe.reset_index(drop=True)
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.label_map = label_map
        self.rgb = rgb
        
        # Auto-detect column names for flexibility
        self._detect_columns()
        
        # Create label map if not provided
        if self.label_map is None:
            self.label_map = self._create_label_map()
        
        # Add numeric labels if not present
        if 'label_idx' not in self.df.columns:
            self.df['label_idx'] = self.df[self.label_col].map(self.label_map)
        
        print(f"Dataset initialized:")
        print(f"- Samples: {len(self.df)}")
        print(f"- Image mode: {'RGB' if rgb else 'Grayscale'}")
        print(f"- Path column: '{self.path_col}'")
        print(f"- Label column: '{self.label_col}'")
        print(f"- Classes: {list(self.label_map.keys())}")
        print(f"- Class distribution:")
        self._print_class_distribution()
    
    def _detect_columns(self):
        """Auto-detect image path and label columns."""
        # Find path/image column
        path_candidates = [c for c in self.df.columns if any(
            keyword in c.lower() for keyword in ['path', 'file', 'image', 'img']
        )]
        
        if path_candidates:
            self.path_col = path_candidates[0]
        else:
            # Fallback to first column
            self.path_col = self.df.columns[0]
        
        # Find label/emotion column
        label_candidates = [c for c in self.df.columns if any(
            keyword in c.lower() for keyword in ['label', 'emotion', 'class', 'target']
        )]
        
        if label_candidates:
            self.label_col = label_candidates[0]
        else:
            # Fallback to second column if available
            if len(self.df.columns) > 1:
                self.label_col = self.df.columns[1]
            else:
                raise ValueError("Cannot detect label column")
    
    def _create_label_map(self):
        """Create label mapping from unique labels in the dataset."""
        unique_labels = sorted(self.df[self.label_col].unique())
        return {label: idx for idx, label in enumerate(unique_labels)}
    
    def _print_class_distribution(self):
        """Print class distribution."""
        for emotion, count in self.df[self.label_col].value_counts().items():
            percentage = (count / len(self.df)) * 100
            print(f"  {emotion}: {count} ({percentage:.1f}%)")
    
    def __len__(self):
        """Return the total number of samples."""
        return len(self.df)
    
    def __getitem__(self, idx):
        """
        Get a sample from the dataset.
        
        Args:
            idx (int): Index of the sample
            
        Returns:
            tuple: (image, label) where image is a torch.Tensor and label is an int
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        
        # Get image path and label
        row = self.df.iloc[idx]
        img_path = self.root_dir / row[self.path_col]
        label = row['label_idx'] if 'label_idx' in row else self.label_map[row[self.label_col]]
        
        # Load image
        try:
            image = Image.open(img_path)
            
            # Convert to appropriate mode
            if self.rgb:
                image = image.convert('RGB')
            else:
                image = image.convert('L')  # Grayscale
            
            # Apply transforms
            if self.transform:
                image = self.transform(image)
            
            return image, label
            
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
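            # Fall back to an all-zero image with label 0 so the batch can still be built.
            # The shapes below are hard-coded and must match the transform's input_size;
            # note that label 0 is a placeholder, not the sample's true label.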
            if self.rgb:
                default_image = torch.zeros(3, 224, 224)
            else:
                default_image = torch.zeros(1, 48, 48)
            return default_image, 0
    
    def get_class_weights(self):
        """
        Calculate class weights for handling imbalanced datasets.
        
        Returns:
            torch.Tensor: Class weights for loss function
        """
        label_counts = self.df[self.label_col].value_counts()
        total_samples = len(self.df)
        num_classes = len(self.label_map)
        
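        # Calculate weights, ordered by class index so that weights[i] belongs to label i
        # (sorting by name would misalign them when a custom label_map is supplied)
        weights = []
        for emotion in sorted(self.label_map, key=self.label_map.get):
            count = label_counts.get(emotion, 1)
            weight = total_samples / (num_classes * count)
            weights.append(weight)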
        
        return torch.FloatTensor(weights)
    
    def get_sample_weights(self):
        """
        Get weights for each sample for WeightedRandomSampler.
        
        Returns:
            torch.Tensor: Sample weights
        """
        class_weights = self.get_class_weights()
        
        # Map every row to its class index, then look up all weights in one indexing step
        label_indices = self.df[self.label_col].map(self.label_map).to_numpy()
        return class_weights[torch.as_tensor(label_indices, dtype=torch.long)]
    
    def visualize_samples(self, num_samples=8, figsize=(12, 8)):
        """
        Visualize random samples from the dataset.
        
        Args:
            num_samples (int): Number of samples to visualize
            figsize (tuple): Figure size
        """
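        # Never request more samples than the dataset holds
        num_samples = min(num_samples, len(self.df))
        indices = np.random.choice(len(self.df), size=num_samples, replace=False)
        
        # Two rows; round the column count up so odd sample counts still fit
        num_cols = max(1, (num_samples + 1) // 2)
        fig, axes = plt.subplots(2, num_cols, figsize=figsize, squeeze=False)
        axes = axes.flatten()
        for ax in axes:
            ax.axis('off')  # Also hides any unused trailing axes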
        
        for i, idx in enumerate(indices):
            image, label = self[idx]
            
            # Convert tensor to numpy for visualization
            if isinstance(image, torch.Tensor):
                if self.rgb:
                    # RGB image
                    image_np = image.permute(1, 2, 0).numpy()
                    # Denormalize if needed
                    if image_np.min() < 0:  # Likely normalized
                        image_np = (image_np * 0.5) + 0.5  # Approximate denormalization
                    image_np = np.clip(image_np, 0, 1)
                else:
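                    # Grayscale image
                    image_np = image.squeeze().numpy()
                    if image_np.min() < 0:  # Likely normalized
                        image_np = (image_np * 0.5) + 0.5  # Approximate denormalization
                    image_np = np.clip(image_np, 0, 1)
            else:
                # No tensor transform was applied, so the sample is still a PIL image
                image_np = np.asarray(image, dtype=np.float32) / 255.0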
            
            # Get emotion name from label
            emotion_name = [k for k, v in self.label_map.items() if v == label][0]
            
            # Display image
            if self.rgb:
                axes[i].imshow(image_np)
            else:
                axes[i].imshow(image_np, cmap='gray')
            
            axes[i].set_title(f'{emotion_name} (idx: {idx})')
            axes[i].axis('off')
        
        plt.tight_layout()
        plt.show()


# Confirm that the class definition ran
print("EmotionDataset class created successfully!")
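
The weighting scheme behind `get_class_weights` and `get_sample_weights` can be checked by hand. The sketch below is a minimal, dependency-free illustration, not part of the notebook's code: `balanced_class_weights` and the toy labels are hypothetical. It applies the same formula, `total_samples / (num_classes * count)`, and shows that each class's total sampling mass (count times weight) comes out identical.

```python
from collections import Counter


def balanced_class_weights(labels, label_map):
    """Return one weight per class, ordered by class index.

    Hypothetical helper mirroring the formula in get_class_weights:
    weight = total_samples / (num_classes * class_count).
    """
    counts = Counter(labels)
    total = len(labels)
    num_classes = len(label_map)
    # Order class names by their index so weights[i] belongs to class i.
    names_by_index = sorted(label_map, key=label_map.get)
    # Like the notebook, treat a class absent from the data as count 1.
    return [total / (num_classes * counts.get(name, 1)) for name in names_by_index]


# Toy, made-up label column with a 6:3:1 imbalance.
labels = ["happy"] * 6 + ["sad"] * 3 + ["angry"] * 1
label_map = {"angry": 0, "happy": 1, "sad": 2}

weights = balanced_class_weights(labels, label_map)

# Total sampling mass per class = class count * class weight.
counts = Counter(labels)
mass = {name: counts[name] * weights[idx] for name, idx in label_map.items()}

for name, idx in label_map.items():
    print(f"{name}: weight={weights[idx]:.3f}, mass={mass[name]:.3f}")
```

Because every class ends up with the same mass (`total / num_classes`), a sampler that draws in proportion to these per-sample weights picks each class with roughly equal probability.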

## 2. Data Transformations

In [None]:
def get_transforms(input_size=224, rgb=True, augment=True, normalize=True):
    """
    Get data transforms for training and validation.
    
    Args:
        input_size (int): Target image size
        rgb (bool): Whether to use RGB (True) or grayscale (False)
        augment (bool): Whether to apply data augmentation for training
        normalize (bool): Whether to normalize images
        
    Returns:
        dict: Dictionary with 'train' and 'val' transforms
    """
    
    if rgb:
        # ImageNet statistics, matching ImageNet-pretrained transfer learning backbones
        mean, std = [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]
    else:
        # Placeholder grayscale statistics (you may need to adjust these for your dataset)
        mean, std = [0.5], [0.5]
    
    # Steps shared by every pipeline. Normalize is appended only when requested,
    # which avoids an identity transforms.Lambda: lambdas cannot be pickled, so they
    # break DataLoader workers on platforms that spawn processes.
    resize = [transforms.Resize((input_size, input_size))]
    to_tensor = [transforms.ToTensor()]
    if normalize:
        to_tensor.append(transforms.Normalize(mean=mean, std=std))
    
    # Training-only augmentations
    augmentations = []
    if augment:
        augmentations += [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(degrees=15),
        ]
        if rgb:
            # Color jitter is applied to color images only
            augmentations.append(
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1)
            )
        augmentations.append(
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1))
        )
    
    return {
        'train': transforms.Compose(resize + augmentations + to_tensor),
        'val': transforms.Compose(resize + to_tensor)
    }


def get_grayscale_normalization_stats(dataset_df, root_dir, sample_size=1000):
    """
    Calculate mean and std for grayscale dataset normalization.
    
    Args:
        dataset_df (pd.DataFrame): Dataset dataframe
        root_dir (str): Root directory containing images
        sample_size (int): Number of images to sample for statistics
        
    Returns:
        tuple: (mean, std) for normalization
    """
    # Sample random images
    sample_indices = np.random.choice(len(dataset_df), min(sample_size, len(dataset_df)), replace=False)
    
    pixel_values = []
    num_loaded = 0
    
    for idx in sample_indices:
        try:
            img_path = Path(root_dir) / dataset_df.iloc[idx].iloc[0]  # Assuming first column is path
            image = Image.open(img_path).convert('L')
            img_array = np.array(image) / 255.0  # Scale to [0, 1]
            pixel_values.extend(img_array.flatten())
            num_loaded += 1
        except Exception as e:
            # Skip unreadable images, but report them rather than failing silently
            print(f"Skipping sample {idx}: {e}")
            continue
    
    if num_loaded == 0:
        raise ValueError("No images could be loaded; check root_dir and the path column")
    
    pixel_values = np.array(pixel_values)
    mean = pixel_values.mean()
    std = pixel_values.std()
    
    print(f"Calculated normalization stats from {num_loaded} images:")
    print(f"Mean: {mean:.6f}")
    print(f"Std: {std:.6f}")
    
    return mean, std


# Test transforms
print("Testing data transforms...")

# RGB transforms for transfer learning
rgb_transforms = get_transforms(input_size=224, rgb=True, augment=True, normalize=True)
print("RGB transforms created:")
print(f"- Train: {len(rgb_transforms['train'].transforms)} transforms")
print(f"- Val: {len(rgb_transforms['val'].transforms)} transforms")

# Grayscale transforms for CNN baseline
gray_transforms = get_transforms(input_size=48, rgb=False, augment=True, normalize=True)
print("\nGrayscale transforms created:")
print(f"- Train: {len(gray_transforms['train'].transforms)} transforms")
print(f"- Val: {len(gray_transforms['val'].transforms)} transforms")
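
`get_grayscale_normalization_stats` collects every sampled pixel into one list before computing the mean and standard deviation. One alternative, shown here as a hedged sketch rather than the notebook's method, is to accumulate a running sum and sum of squares so that memory use stays constant. The `RunningPixelStats` class and the nested-list "images" are hypothetical stand-ins; real code would feed it image arrays scaled to [0, 1].

```python
import math


class RunningPixelStats:
    """Accumulate mean and population std of pixel values without storing them."""

    def __init__(self):
        self.count = 0
        self.total = 0.0
        self.total_sq = 0.0

    def update(self, pixels):
        """Add an iterable of pixel values already scaled to [0, 1]."""
        for value in pixels:
            self.count += 1
            self.total += value
            self.total_sq += value * value

    def finalize(self):
        """Return (mean, std); raise if no pixels were seen."""
        if self.count == 0:
            raise ValueError("no pixels accumulated")
        mean = self.total / self.count
        # Var[x] = E[x^2] - E[x]^2; clamp at 0 to absorb rounding error.
        variance = max(self.total_sq / self.count - mean * mean, 0.0)
        return mean, math.sqrt(variance)


# Two made-up 2x2 grayscale "images" with 8-bit intensities.
images = [
    [[0, 255], [255, 0]],
    [[128, 128], [128, 128]],
]

stats = RunningPixelStats()
for image in images:
    flat = [pixel / 255.0 for row in image for pixel in row]
    stats.update(flat)

mean, std = stats.finalize()
print(f"mean={mean:.6f}, std={std:.6f}")
```

The resulting values could then be passed as `mean=[mean]` and `std=[std]` to `transforms.Normalize` in place of the 0.5 placeholders.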

## 3. Advanced Albumentations Transforms

In [None]:
if ALBUMENTATIONS_AVAILABLE:
    def get_albumentations_transforms(input_size=224, rgb=True, augment=True):
        """
        Get advanced augmentation transforms using Albumentations.
        
        Args:
            input_size (int): Target image size
            rgb (bool): Whether to use RGB or grayscale
            augment (bool): Whether to apply augmentation
            
        Returns:
            dict: Dictionary with 'train' and 'val' transforms
        """
        
        if augment:
            train_transform = A.Compose([
                A.Resize(input_size, input_size),
                A.HorizontalFlip(p=0.5),
                A.Rotate(limit=15, p=0.7),
                A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.7),
                A.HueSaturationValue(hue_shift_limit=10, sat_shift_limit=15, val_shift_limit=10, p=0.5) if rgb else A.NoOp(),
                A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.1, rotate_limit=0, p=0.5),
                A.GaussNoise(var_limit=(10.0, 50.0), p=0.3),
                A.MotionBlur(blur_limit=3, p=0.2),
                A.CLAHE(clip_limit=2.0, tile_grid_size=(8, 8), p=0.3),
                A.Normalize(mean=[0.485, 0.456, 0.406] if rgb else [0.5], 
                           std=[0.229, 0.224, 0.225] if rgb else [0.5]),
                ToTensorV2()
            ])
        else:
            train_transform = A.Compose([
                A.Resize(input_size, input_size),
                A.Normalize(mean=[0.485, 0.456, 0.406] if rgb else [0.5], 
                           std=[0.229, 0.224, 0.225] if rgb else [0.5]),
                ToTensorV2()
            ])
        
        val_transform = A.Compose([
            A.Resize(input_size, input_size),
            A.Normalize(mean=[0.485, 0.456, 0.406] if rgb else [0.5], 
                       std=[0.229, 0.224, 0.225] if rgb else [0.5]),
            ToTensorV2()
        ])
        
        return {
            'train': train_transform,
            'val': val_transform
        }
    
    
    class AlbumentationsDataset(EmotionDataset):
        """
        Enhanced dataset class using Albumentations for better data augmentation.
        """
        
        def __init__(self, dataframe, root_dir, transform=None, label_map=None, rgb=True):
            super().__init__(dataframe, root_dir, transform=None, label_map=label_map, rgb=rgb)
            self.albumentations_transform = transform
            self.force_rgb = rgb
        
        def __getitem__(self, idx):
            """Get item with Albumentations transform."""
            if torch.is_tensor(idx):
                idx = idx.tolist()
            
            # Get image path and label
            row = self.df.iloc[idx]
            img_path = self.root_dir / row[self.path_col]
            label = row['label_idx'] if 'label_idx' in row else self.label_map[row[self.label_col]]
            
            # Load image
            try:
                image = Image.open(img_path)
                
                # Convert to appropriate mode
                if self.force_rgb:
                    image = image.convert('RGB')
                else:
                    image = image.convert('L')  # Grayscale
                
                # Convert to numpy array for Albumentations
                image_np = np.array(image)
                
                # Apply Albumentations transform
                if self.albumentations_transform:
                    transformed = self.albumentations_transform(image=image_np)
                    image = transformed['image']
                
                return image, label
                
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
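                # Fall back to an all-zero image with label 0 so the batch can still be built.
                # The shapes below are hard-coded and must match the transform's input_size;
                # note that label 0 is a placeholder, not the sample's true label.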
                if self.force_rgb:
                    default_image = torch.zeros(3, 224, 224)
                else:
                    default_image = torch.zeros(1, 48, 48)
                return default_image, 0
    
    
    print("Albumentations transforms and dataset created successfully!")
    
    # Test albumentations transforms
    albu_transforms = get_albumentations_transforms(input_size=224, rgb=True, augment=True)
    print(f"Albumentations train transform: {len(albu_transforms['train'])} operations")
    print(f"Albumentations val transform: {len(albu_transforms['val'])} operations")

else:
    print("Albumentations not available. Install with: pip install albumentations")
    AlbumentationsDataset = None
    get_albumentations_transforms = None
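
Since Albumentations is optional here, code that asks for it may end up on the torchvision path when the package is missing. The sketch below shows one way to make that decision explicit before any datasets are built. It is an illustration under stated assumptions: `resolve_augmentation_backend` is a hypothetical helper, and its `module_name` parameter exists only so the lookup can be exercised with other module names.

```python
import importlib.util


def resolve_augmentation_backend(use_albumentations, module_name="albumentations"):
    """Return 'albumentations' only if it was requested and can be found.

    Hypothetical helper: find_spec locates a top-level module without importing it.
    """
    available = importlib.util.find_spec(module_name) is not None
    if use_albumentations and available:
        return "albumentations"
    if use_albumentations:
        # Requested but missing: say so instead of switching silently.
        print(f"'{module_name}' is not installed; falling back to torchvision")
    return "torchvision"


# Not requested: always torchvision, whether or not the package is installed.
print(resolve_augmentation_backend(use_albumentations=False))

# Requested: the result depends on the local environment.
print(resolve_augmentation_backend(use_albumentations=True))
```

Reporting the fallback makes it easier to notice when an experiment that was meant to use the heavier augmentations actually ran without them.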

## 4. Data Loader Creation Functions

In [None]:
def create_data_loaders(train_df, val_df, test_df, root_dir, 
                       batch_size=32, num_workers=4, rgb=True, input_size=224,
                       use_weighted_sampler=False, use_albumentations=False):
    """
    Create PyTorch data loaders for training, validation, and testing.
    
    Args:
        train_df (pd.DataFrame): Training dataset dataframe
        val_df (pd.DataFrame): Validation dataset dataframe
        test_df (pd.DataFrame): Test dataset dataframe
        root_dir (str): Root directory containing images
        batch_size (int): Batch size for data loaders
        num_workers (int): Number of worker processes for data loading
        rgb (bool): Whether to use RGB or grayscale images
        input_size (int): Input image size
        use_weighted_sampler (bool): Whether to use weighted sampling for imbalanced classes
        use_albumentations (bool): Whether to use Albumentations for augmentation
        
    Returns:
        dict: Dictionary containing 'train', 'val', 'test' data loaders and 'label_map'
    """
    
    # Get transforms
    if use_albumentations and ALBUMENTATIONS_AVAILABLE:
        transforms_dict = get_albumentations_transforms(input_size=input_size, rgb=rgb, augment=True)
        dataset_class = AlbumentationsDataset
        print("Using Albumentations for data augmentation")
    else:
        transforms_dict = get_transforms(input_size=input_size, rgb=rgb, augment=True, normalize=True)
        dataset_class = EmotionDataset
        print("Using torchvision transforms")
    
    # Create datasets
    train_dataset = dataset_class(
        train_df, root_dir, transform=transforms_dict['train'], rgb=rgb
    )
    
    val_dataset = dataset_class(
        val_df, root_dir, transform=transforms_dict['val'], 
        label_map=train_dataset.label_map, rgb=rgb
    )
    
    test_dataset = dataset_class(
        test_df, root_dir, transform=transforms_dict['val'], 
        label_map=train_dataset.label_map, rgb=rgb
    )
    
    # Create samplers
    train_sampler = None
    if use_weighted_sampler:
        sample_weights = train_dataset.get_sample_weights()
        train_sampler = WeightedRandomSampler(
            weights=sample_weights, 
            num_samples=len(sample_weights),
            replacement=True
        )
        print("Using WeightedRandomSampler for balanced training")
    
    # Create data loaders
    train_loader = DataLoader(
        train_dataset, 
        batch_size=batch_size,
        sampler=train_sampler,
        shuffle=(train_sampler is None),  # Don't shuffle if using sampler
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )
    
    print(f"\nData loaders created:")
    print(f"- Train: {len(train_loader)} batches, {len(train_dataset)} samples")
    print(f"- Val: {len(val_loader)} batches, {len(val_dataset)} samples")
    print(f"- Test: {len(test_loader)} batches, {len(test_dataset)} samples")
    print(f"- Batch size: {batch_size}")
    print(f"- Input size: {input_size}x{input_size}")
    print(f"- Image mode: {'RGB' if rgb else 'Grayscale'}")
    
    return {
        'train': train_loader,
        'val': val_loader,
        'test': test_loader,
        'label_map': train_dataset.label_map,
        'class_weights': train_dataset.get_class_weights()
    }


def create_quick_loaders(csv_path, root_dir, train_ratio=0.7, val_ratio=0.15, 
                        batch_size=32, rgb=True, input_size=224):
    """
    Quick function to create data loaders from a single CSV file.
    
    Args:
        csv_path (str): Path to CSV file containing image paths and labels
        root_dir (str): Root directory containing images
        train_ratio (float): Ratio of data for training
        val_ratio (float): Ratio of data for validation
        batch_size (int): Batch size
        rgb (bool): Whether to use RGB images
        input_size (int): Input image size
        
    Returns:
        dict: Dictionary with data loaders and metadata
    """
    # Load data
    df = pd.read_csv(csv_path)
    print(f"Loaded dataset with {len(df)} samples")
    
    # Split data
    from sklearn.model_selection import train_test_split
    
    # First split: train + val vs test
    train_val_df, test_df = train_test_split(
        df, test_size=(1 - train_ratio - val_ratio), 
        stratify=df.iloc[:, 1] if len(df.columns) > 1 else None,
        random_state=42
    )
    
    # Second split: train vs val
    train_df, val_df = train_test_split(
        train_val_df, test_size=(val_ratio / (train_ratio + val_ratio)),
        stratify=train_val_df.iloc[:, 1] if len(train_val_df.columns) > 1 else None,
        random_state=42
    )
    
    print(f"Data split:")
    print(f"- Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)")
    print(f"- Val: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)")
    print(f"- Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)")
    
    # Create data loaders
    return create_data_loaders(
        train_df, val_df, test_df, root_dir,
        batch_size=batch_size, rgb=rgb, input_size=input_size
    )


print("Data loader creation functions ready!")
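
`create_quick_loaders` reaches a three-way split through two successive calls to `train_test_split`, so the second `test_size` has to be rescaled relative to what is left after the first split. The arithmetic can be sketched on its own; `two_stage_split_sizes` is a hypothetical helper written for this illustration, not a function from the notebook.

```python
def two_stage_split_sizes(train_ratio, val_ratio):
    """Return the test_size values for the two successive splits.

    Hypothetical helper reproducing the arithmetic in create_quick_loaders.
    """
    test_ratio = 1.0 - train_ratio - val_ratio
    if not (train_ratio > 0 and val_ratio > 0 and test_ratio > 0):
        raise ValueError("train, val and test ratios must all be positive")
    first_test_size = test_ratio
    # The second split sees only train + val, so val must be rescaled.
    second_test_size = val_ratio / (train_ratio + val_ratio)
    return first_test_size, second_test_size


first, second = two_stage_split_sizes(train_ratio=0.7, val_ratio=0.15)
print(f"first split test_size:  {first:.4f}")
print(f"second split test_size: {second:.4f}")

# Fractions of the full dataset that each subset ends up with.
remaining = 1.0 - first
print(f"train={remaining * (1 - second):.2f}, val={remaining * second:.2f}, test={first:.2f}")
```

With the default ratios, the second split holds out about 17.6% of the remaining 85%, which is 15% of the full dataset. The guard against non-positive ratios is an addition of this sketch; the notebook function does not validate its ratios.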

## 5. Dataset Analysis and Visualization Functions

In [None]:
def analyze_dataset(dataset_df, label_col=None, figsize=(12, 8)):
    """
    Analyze dataset distribution and statistics.
    
    Args:
        dataset_df (pd.DataFrame): Dataset dataframe
        label_col (str): Name of label column (auto-detected if None)
        figsize (tuple): Figure size for plots
    """
    if label_col is None:
        # Auto-detect label column
        label_candidates = [c for c in dataset_df.columns if any(
            keyword in c.lower() for keyword in ['label', 'emotion', 'class', 'target']
        )]
        label_col = label_candidates[0] if label_candidates else dataset_df.columns[-1]
    
    print(f"Dataset Analysis")
    print(f"{'='*50}")
    print(f"Total samples: {len(dataset_df)}")
    print(f"Label column: '{label_col}'")
    
    # Class distribution
    class_counts = dataset_df[label_col].value_counts()
    print(f"\nClass Distribution:")
    for emotion, count in class_counts.items():
        percentage = (count / len(dataset_df)) * 100
        print(f"  {emotion}: {count} samples ({percentage:.1f}%)")
    
    # Calculate class imbalance
    max_count = class_counts.max()
    min_count = class_counts.min()
    imbalance_ratio = max_count / min_count
    print(f"\nClass Imbalance Ratio: {imbalance_ratio:.2f}")
    
    if imbalance_ratio > 2.0:
        print("⚠️  Dataset is imbalanced - consider using weighted sampling or class weights")
    else:
        print("✅ Dataset is relatively balanced")
    
    # Visualize distribution
    fig, axes = plt.subplots(1, 2, figsize=figsize)
    
    # Bar plot
    class_counts.plot(kind='bar', ax=axes[0], color='skyblue', alpha=0.7)
    axes[0].set_title('Class Distribution')
    axes[0].set_xlabel('Emotion')
    axes[0].set_ylabel('Count')
    axes[0].tick_params(axis='x', rotation=45)
    
    # Pie chart
    class_counts.plot(kind='pie', ax=axes[1], autopct='%1.1f%%', startangle=90)
    axes[1].set_title('Class Distribution (Percentage)')
    axes[1].set_ylabel('')  # Remove ylabel for pie chart
    
    plt.tight_layout()
    plt.show()


def visualize_batch(data_loader, num_samples=8, figsize=(12, 8), denormalize=True):
    """
    Visualize a batch of data from a data loader.
    
    Args:
        data_loader (DataLoader): PyTorch data loader
        num_samples (int): Number of samples to visualize
        figsize (tuple): Figure size
        denormalize (bool): Whether to denormalize images for visualization
    """
    # Get one batch
    data_iter = iter(data_loader)
    images, labels = next(data_iter)
    
    # Get label map (assuming it's stored in dataset)
    label_map = data_loader.dataset.label_map
    inverse_label_map = {v: k for k, v in label_map.items()}
    
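    # Select samples to display
    num_samples = min(num_samples, len(images))
    indices = np.random.choice(len(images), num_samples, replace=False)
    
    # Two rows; round the column count up so odd sample counts still fit
    num_cols = max(1, (num_samples + 1) // 2)
    fig, axes = plt.subplots(2, num_cols, figsize=figsize, squeeze=False)
    axes = axes.flatten()
    for ax in axes:
        ax.axis('off')  # Also hides any unused trailing axes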
    
    for i, idx in enumerate(indices):
        image = images[idx]
        label = labels[idx].item()
        emotion_name = inverse_label_map[label]
        
        # Convert tensor to numpy
        if image.shape[0] == 3:  # RGB
            img_np = image.permute(1, 2, 0).numpy()
            if denormalize:
                # Approximate denormalization for ImageNet
                img_np = img_np * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img_np = np.clip(img_np, 0, 1)
            axes[i].imshow(img_np)
        else:  # Grayscale
            img_np = image.squeeze().numpy()
            if denormalize:
                img_np = img_np * 0.5 + 0.5  # Approximate denormalization
            img_np = np.clip(img_np, 0, 1)
            axes[i].imshow(img_np, cmap='gray')
        
        axes[i].set_title(f'{emotion_name}')
        axes[i].axis('off')
    
    plt.suptitle(f'Sample Batch from Data Loader')
    plt.tight_layout()
    plt.show()
    
    print(f"Batch info:")
    print(f"- Batch size: {len(images)}")
    print(f"- Image shape: {images[0].shape}")
    print(f"- Image dtype: {images[0].dtype}")
    print(f"- Displayed labels: {[labels[i].item() for i in indices]}")


def compute_dataset_statistics(dataset_df, root_dir, sample_size=1000, rgb=True):
    """
    Compute pixel statistics for the dataset.
    
    Args:
        dataset_df (pd.DataFrame): Dataset dataframe
        root_dir (str): Root directory containing images
        sample_size (int): Number of images to sample
        rgb (bool): Whether images are RGB or grayscale
        
    Returns:
        dict: Statistics dictionary
    """
    print(f"Computing dataset statistics from {min(sample_size, len(dataset_df))} images...")
    
    # Sample images
    sample_indices = np.random.choice(len(dataset_df), min(sample_size, len(dataset_df)), replace=False)
    
    pixel_values = [] if not rgb else [[], [], []]  # For RGB channels
    image_sizes = []
    
    for idx in sample_indices:
        try:
            img_path = Path(root_dir) / dataset_df.iloc[idx].iloc[0]
            image = Image.open(img_path)
            
            if rgb:
                image = image.convert('RGB')
                img_array = np.array(image) / 255.0  # Normalize to [0, 1]
                # Separate channels
                for c in range(3):
                    pixel_values[c].extend(img_array[:, :, c].flatten())
            else:
                image = image.convert('L')
                img_array = np.array(image) / 255.0
                pixel_values.extend(img_array.flatten())
            
            image_sizes.append(image.size)
            
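        except Exception as e:
            # Skip unreadable images, but report them rather than failing silently
            print(f"Skipping sample {idx}: {e}")
            continue
    
    if not image_sizes:
        raise ValueError("No images could be loaded; check root_dir and the path column")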
    
    if rgb:
        # RGB statistics
        means = [np.mean(pixel_values[c]) for c in range(3)]
        stds = [np.std(pixel_values[c]) for c in range(3)]
        print(f"RGB Statistics:")
        print(f"  Mean: {means}")
        print(f"  Std:  {stds}")
        stats = {'mean': means, 'std': stds, 'rgb': True}
    else:
        # Grayscale statistics
        mean = np.mean(pixel_values)
        std = np.std(pixel_values)
        print(f"Grayscale Statistics:")
        print(f"  Mean: {mean:.6f}")
        print(f"  Std:  {std:.6f}")
        stats = {'mean': [mean], 'std': [std], 'rgb': False}
    
    # Image size statistics
    unique_sizes = list(set(image_sizes))
    print(f"\nImage Size Statistics:")
    print(f"  Unique sizes: {unique_sizes[:10]}{'...' if len(unique_sizes) > 10 else ''}")
    print(f"  Most common size: {Counter(image_sizes).most_common(1)[0]}")
    
    stats['image_sizes'] = image_sizes
    return stats


print("Dataset analysis and visualization functions ready!")
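
`visualize_batch` undoes normalization with `x * std + mean`, assuming ImageNet statistics for three-channel images and 0.5/0.5 for single-channel ones. The round trip below is a small, dependency-free sketch of that arithmetic on one made-up pixel; `normalize_pixel` and `denormalize_pixel` are hypothetical helpers. It also shows what happens when the statistics used for display differ from those used in the transform.

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def normalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Apply per-channel (x - mean) / std, the formula used by transforms.Normalize."""
    return tuple((x - m) / s for x, m, s in zip(rgb, mean, std))


def denormalize_pixel(rgb, mean=IMAGENET_MEAN, std=IMAGENET_STD):
    """Invert normalization with x * std + mean and clip to [0, 1]."""
    return tuple(min(max(x * s + m, 0.0), 1.0) for x, m, s in zip(rgb, mean, std))


original = (0.2, 0.5, 0.9)  # made-up RGB pixel in [0, 1]
normalized = normalize_pixel(original)
restored = denormalize_pixel(normalized)

print("normalized:", tuple(round(v, 4) for v in normalized))
print("restored:  ", tuple(round(v, 4) for v in restored))

# Using the wrong statistics (0.5 / 0.5) does not recover the original pixel.
wrong = denormalize_pixel(normalized, mean=(0.5,) * 3, std=(0.5,) * 3)
print("wrong stats:", tuple(round(v, 4) for v in wrong))
```

If the loaders are built with dataset-specific statistics, the same values would need to be used when displaying images, otherwise colors and contrast appear shifted.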

## Summary

This notebook provides a complete data handling pipeline for visual emotion recognition:

### Core Components:
1. **EmotionDataset**: Flexible PyTorch dataset class supporting both RGB and grayscale images
2. **Data Transforms**: Standard torchvision transforms with ImageNet normalization for RGB and 0.5/0.5 normalization for grayscale
3. **Advanced Augmentations**: Optional Albumentations pipeline that adds noise, motion blur, and CLAHE
4. **Data Loaders**: Functions to create training, validation, and test data loaders
5. **Analysis Tools**: Dataset statistics, visualization, and imbalance analysis

### Key Features:
- **Auto-detection** of column names for flexibility
- **Class balancing** with weighted sampling
- **Comprehensive augmentation** strategies
- **Visualization tools** for data exploration
- **Statistics computation** for normalization
- **Error handling** that substitutes an all-zero image with label 0 when an image fails to load

All functionality is self-contained in this notebook and does not depend on the src folder structure; `create_quick_loaders` additionally requires scikit-learn.