In [None]:
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdm

def train_model(model, train_loader, val_loader, device, epochs=50, patience=10):
    """Train the model with early stopping on validation macro-F1.

    Uses class-weighted cross-entropy, Adam (lr=1e-3, weight_decay=1e-4) and a
    ReduceLROnPlateau scheduler that halves the learning rate after 5 epochs
    without F1 improvement. Each time validation F1 improves, a snapshot of the
    weights is kept in memory and saved to ``dsan_model_rafdb_best.pth``. When
    training ends (all epochs run or early stop), the best snapshot is restored
    and the resulting weights are saved to ``dsan_model_rafdb.pth``.

    Args:
        model: Network to train; presumably already moved to ``device`` by the caller.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Loader passed to ``test_model`` for per-epoch validation.
        device: Device each batch is moved to.
        epochs: Maximum number of epochs.
        patience: Consecutive non-improving epochs tolerated before stopping.

    Returns:
        The model with the best-validation-F1 weights loaded (or the final
        weights if validation F1 never rose above 0).
    """
    # NOTE(review): the class weights are computed by iterating train_loader (a full
    # pass over the images). If that loader already uses a class-balancing sampler,
    # the observed counts are near-uniform and this weighting is largely redundant.
    # Confirm the double balancing is intended.
    criterion = nn.CrossEntropyLoss(weight=calculate_class_weights(train_loader, device))
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, verbose=True)

    # Early-stopping bookkeeping
    best_f1 = 0.0
    best_model_state = None
    no_improve_epochs = 0

    for epoch in range(epochs):
        start_time = time.time()

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Loss is a batch mean, so weight it by batch size for an epoch average.
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against an empty loader so the epoch summary cannot divide by zero.
        denom = max(total, 1)
        train_loss = running_loss / denom
        train_acc = 100 * correct / denom

        # Validation phase
        val_acc, val_f1, _ = test_model(model, val_loader, device, verbose=False)

        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} | Time: {epoch_time:.2f}s')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Acc: {val_acc:.2f}% | Val F1: {val_f1:.4f}')

        # mode='max': the scheduler reacts to F1 failing to increase.
        scheduler.step(val_f1)

        if val_f1 > best_f1:
            best_f1 = val_f1
            # state_dict() returns references to the live tensors; a plain dict.copy()
            # would keep aliasing them and silently track every later update.
            # Clone each tensor so the snapshot really freezes the best weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            no_improve_epochs = 0
            torch.save(best_model_state, 'dsan_model_rafdb_best.pth')
            print(f"Saved new best model with F1: {best_f1:.4f}")
        else:
            no_improve_epochs += 1
            print(f"No improvement for {no_improve_epochs} epochs")

        if no_improve_epochs >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            break

    # Restore the best snapshot, if one was ever taken.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with F1: {best_f1:.4f}")

    torch.save(model.state_dict(), 'dsan_model_rafdb.pth')
    print("Saved final model")

    return model

def calculate_class_weights(train_loader, device, num_classes=7):
    """Compute normalised inverse-frequency class weights from a loader's labels.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` batches; only the labels are
            used. Iterating a DataLoader still loads and transforms every sample.
        device: Device the returned tensor is moved to.
        num_classes: Number of classes (default 7, the RAF-DB emotion set).

    Returns:
        Float tensor of shape ``(num_classes,)`` holding ``1 / count`` per class,
        rescaled so the weights sum to ``num_classes``. Classes that never occur
        get weight 0; an ``inf`` here would turn every weight into NaN after
        normalisation. If no labels are seen at all, uniform weights of 1 are
        returned.
    """
    class_counts = torch.zeros(num_classes, dtype=torch.float)
    for _, labels in train_loader:
        flat = torch.as_tensor(labels).reshape(-1).long().cpu()
        # bincount counts the whole batch in one call instead of a Python loop.
        class_counts += torch.bincount(flat, minlength=num_classes).float()

    present = class_counts > 0
    if not present.any():
        return torch.ones(num_classes).to(device)

    weights = torch.zeros(num_classes)
    weights[present] = 1.0 / class_counts[present]
    weights = weights / weights.sum() * num_classes

    return weights.to(device)

def test_model(model, test_loader, device, verbose=True):
    """Test the model on the test dataset"""
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 7
    class_total = [0] * 7
    emotion_labels = ['Surprise', 'Fear', 'Disgust', 'Happiness', 'Sadness', 'Anger', 'Neutral']
    
    confusion_matrix = torch.zeros(7, 7)
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Per-class accuracy
            for i in range(len(labels)):
                label = labels[i].item()
                pred = predicted[i].item()
                class_total[label] += 1
                if label == pred:
                    class_correct[label] += 1
                
                # Update confusion matrix
                confusion_matrix[label][pred] += 1
    
    # Calculate overall accuracy
    accuracy = 100 * correct / total
    
    # Calculate F1 score for each class
    f1_scores = []
    for i in range(7):
        # Calculate precision and recall
        tp = confusion_matrix[i][i].item()
        fp = confusion_matrix[:, i].sum().item() - tp
        fn = confusion_matrix[i, :].sum().item() - tp
        
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        
        # Calculate F1 score
        f1 = 2 * precision * recall / max(precision + recall, 1e-6)
        f1_scores.append(f1)
    
    # Calculate mean F1 score
    mean_f1 = sum(f1_scores) / len(f1_scores)
    
    if verbose:
        print(f'Test Accuracy: {accuracy:.2f}%')
        
        # Print per-class accuracy
        print('\nPer-class accuracy:')
        for i in range(7):
            class_acc = 100 * class_correct[i] / max(class_total[i], 1)
            print(f'{emotion_labels[i]}: {class_acc:.2f}% ({class_correct[i]}/{class_total[i]})')
        
        # Print F1 scores
        print('\nPer-class F1 scores:')
        for i in range(7):
            print(f'{emotion_labels[i]}: {f1_scores[i]:.4f}')
        
        print(f'\nMean F1 Score: {mean_f1:.4f}')
    
    return accuracy, mean_f1, confusion_matrix

In [None]:
def get_transforms():
    """Build the torchvision preprocessing pipelines for training and evaluation.

    Returns:
        tuple: ``(train_transform, test_transform)``.
            Training: resize to 256x256, random 224 crop (scale 0.8-1.0), random
            horizontal flip, random rotation of up to 10 degrees, and
            brightness/contrast jitter of 0.2.
            Test: plain resize to 224x224.
            Both end with ``ToTensor`` and ImageNet mean/std normalisation.
    """
    def tensor_and_normalize():
        # A fresh ToTensor + Normalize pair (ImageNet statistics) for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    augmentation_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    train_pipeline = transforms.Compose(augmentation_steps + tensor_and_normalize())

    eval_pipeline = transforms.Compose([transforms.Resize((224, 224))] + tensor_and_normalize())

    return train_pipeline, eval_pipeline

In [None]:
def main():
    """Train (or load) a DSAN model on RAF-DB and evaluate it on the test split.

    Steps:
        1. Seed the RNGs.
        2. Split the RAF-DB ``train`` folder 80/20 into train/validation with a fixed
           seed. The validation subset gets the test transform.
        3. Balance training batches with a WeightedRandomSampler.
        4. Train, or load ``./dsan_model_rafdb.pth`` when ``train_model_flag`` is False
           and the file exists.
        5. Evaluate on the ``test`` folder and save the confusion matrix, sample
           predictions, attention maps and a text summary.

    Any exception is printed with its traceback instead of being propagated.
    """
    import copy  # local import: `copy` is otherwise only imported in a later cell

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    train_transform, test_transform = get_transforms()

    raf_db_root = "./data/rafdb/DATASET"  # Path to the dataset root

    try:
        train_dataset = RAFDBFolderDataset(
            root_dir=raf_db_root,
            split='train',
            transform=train_transform
        )

        # Validation split: 20% of the training folder, seeded for reproducibility.
        train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

        # Both subsets wrap the same dataset object. Give the validation subset its
        # own copy (keeping its own indices) so it can use the test transform.
        val_dataset.dataset = copy.deepcopy(train_dataset.dataset)
        val_dataset.dataset.transform = test_transform

        test_dataset = RAFDBFolderDataset(
            root_dir=raf_db_root,
            split='test',
            transform=test_transform
        )

        # Read labels from the dataset's sample index. Iterating the subset would
        # decode and augment every training image just to obtain its label.
        base_samples = train_dataset.dataset.samples
        train_labels = [base_samples[i][1] for i in train_dataset.indices]

        class_counts = [0] * 7
        for label in train_labels:
            class_counts[label] += 1

        # Weight = 1 / class_count, so every class is drawn about equally often.
        weights = torch.tensor([1.0 / class_counts[label] for label in train_labels])

        sampler = torch.utils.data.WeightedRandomSampler(
            weights=weights,
            num_samples=len(train_dataset),
            replacement=True
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=32,
            sampler=sampler,
            num_workers=4,
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=4,
            pin_memory=True
        )

        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Validation dataset size: {len(val_dataset)}")
        print(f"Test dataset size: {len(test_dataset)}")

        model = DSAN(num_classes=7, pretrained=True)
        model = model.to(device)
        print(f"Model created with {count_parameters(model):,} trainable parameters")

        # Set train_model_flag = False to evaluate an existing checkpoint instead.
        model_path = "./dsan_model_rafdb.pth"
        train_model_flag = True

        if os.path.exists(model_path) and not train_model_flag:
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded pretrained model from {model_path}")
        else:
            print("Training new model...")
            result = train_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                epochs=50,
                patience=10
            )
            # One train_model definition returns the model, another returns
            # (model, history); accept either.
            model = result[0] if isinstance(result, tuple) else result

        print("\nEvaluating model on test set:")
        accuracy, mean_f1, conf_matrix = test_model(model, test_loader, device)

        cm_fig = visualize_confusion_matrix(conf_matrix)
        cm_fig.savefig("confusion_matrix_rafdb.png")
        print("Saved confusion matrix visualization to confusion_matrix_rafdb.png")

        sample_fig = visualize_sample_predictions(model, test_loader, device)
        sample_fig.savefig("sample_predictions_rafdb.png")
        print("Saved sample predictions visualization to sample_predictions_rafdb.png")

        visualize_attention_maps(model, test_loader, device)

        with open("evaluation_results_rafdb.txt", "w") as f:
            f.write(f"Test Accuracy: {accuracy:.2f}%\n")
            f.write(f"Mean F1 Score: {mean_f1:.4f}\n")

        print("Evaluation completed!")

    except Exception as e:
        # Top-level boundary of a notebook entry point: report instead of crashing.
        print(f"Error during dataset loading or evaluation: {e}")
        import traceback
        traceback.print_exc()

In [None]:
class RAFDBFolderDataset(Dataset):
    """
    RAF-DB dataset loader for a folder-based layout.

    Images are read from ``<root_dir>/<split>/<folder>/``, where ``folder`` is the
    RAF-DB emotion id '1'..'7'. The ids map to 0-based class indices:
    1: Surprise, 2: Fear, 3: Disgust, 4: Happiness, 5: Sadness, 6: Anger, 7: Neutral
    Other folders, and files without a .png/.jpg/.jpeg extension (case-insensitive),
    are skipped.
    """
    def __init__(self, root_dir, split='test', transform=None):
        """
        Args:
            root_dir (string): Root directory of the RAF-DB dataset.
            split (string): Sub-directory to load, e.g. 'train' or 'test'.
            transform (callable, optional): Applied to each RGB PIL image in __getitem__.

        Raises:
            RuntimeError: If ``<root_dir>/<split>`` does not exist.
        """
        self.root_dir = os.path.join(root_dir, split)
        self.transform = transform
        self.samples = []  # (image_path, class_idx) pairs

        if not os.path.exists(self.root_dir):
            raise RuntimeError(f"Dataset directory not found: {self.root_dir}")

        # Class mapping based on RAF-DB folder numbering
        self.class_to_idx = {
            '1': 0,  # Surprise
            '2': 1,  # Fear
            '3': 2,  # Disgust
            '4': 3,  # Happiness
            '5': 4,  # Sadness
            '6': 5,  # Anger
            '7': 6,  # Neutral
        }

        class_samples = {cls: 0 for cls in self.class_to_idx.values()}

        for class_folder in sorted(os.listdir(self.root_dir)):
            class_path = os.path.join(self.root_dir, class_folder)
            if os.path.isdir(class_path) and class_folder in self.class_to_idx:
                class_idx = self.class_to_idx[class_folder]
                # os.listdir order is filesystem-dependent. Sorting makes sample
                # indices, and therefore any seeded random_split over them,
                # reproducible across machines.
                img_files = sorted(
                    f for f in os.listdir(class_path)
                    if f.lower().endswith(('.png', '.jpg', '.jpeg'))
                )
                class_samples[class_idx] = len(img_files)
                print(f"Class {class_folder} ({self.get_class_name(class_idx)}): {len(img_files)} images")
                self.samples.extend(
                    (os.path.join(class_path, img_file), class_idx) for img_file in img_files
                )

        print(f"Total samples in {split} set: {len(self.samples)}")

        # Balanced inverse-frequency weights: total / (num_classes * count),
        # or 0 for a class with no images.
        total_samples = sum(class_samples.values())
        self.class_weights = {cls: total_samples / (len(class_samples) * count)
                              if count > 0 else 0
                              for cls, count in class_samples.items()}

    def get_class_weights(self):
        """Return the ``{class_idx: weight}`` dict computed in ``__init__``."""
        return self.class_weights

    def get_class_name(self, class_idx):
        """Get emotion name from a 0-based class index."""
        emotion_labels = ['Surprise', 'Fear', 'Disgust', 'Happiness', 'Sadness', 'Anger', 'Neutral']
        return emotion_labels[class_idx]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return ``(image, label)``.

        If loading or transforming fails, the error is printed and a blank
        placeholder is returned with the same label (best-effort by design):
        a 3x224x224 zero tensor when a transform is set, otherwise a black
        224x224 PIL image.
        """
        img_path, label = self.samples[idx]
        try:
            # The context manager closes the file handle as soon as decoding is done;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as img:
                image = img.convert('RGB')

            if self.transform:
                image = self.transform(image)

            return image, label
        except Exception as e:
            print(f"Error loading image {img_path}: {e}")
            # Return a placeholder image and the same label
            placeholder = torch.zeros(3, 224, 224) if self.transform else Image.new('RGB', (224, 224))
            return placeholder, label

In [None]:
# Updated model class with additional improvements
class DSANPlus(nn.Module):
    """
    Enhanced Dual Stream Attention Network for Facial Emotion Recognition

    A ResNet18 backbone split into two streams:
    - GFE-AN (layer1, layer2): multiplicative attention with a residual add,
      ``attn(x) * x + x``, followed by BatchNorm.
    - MFF-AN (layer3, layer4): additive attention, ``x + attn(x)``, followed by
      BatchNorm.
    The classifier head is: global average pool, dropout,
    Linear(512, 256), BatchNorm1d, ReLU, dropout, Linear(256, num_classes).

    In eval mode, ``forward`` records a (pre-attention, post-BatchNorm) feature-map
    pair per stage in ``self.attention_maps`` for visualisation.
    """
    def __init__(self, num_classes=7, pretrained=True, dropout_rate=0.5):
        super().__init__()

        # ResNet18 backbone.
        # NOTE(review): torchvision >= 0.13 deprecates `pretrained` in favour of
        # `weights=`; confirm against the installed torchvision version.
        resnet = models.resnet18(pretrained=pretrained)

        # GFE-AN Stream (Global Feature Extraction with Attention Network)
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool

        self.layer1 = resnet.layer1
        self.gfe_attention1 = SparseAttention(64)
        self.bn_gfe1 = nn.BatchNorm2d(64)

        self.layer2 = resnet.layer2
        self.gfe_attention2 = SparseAttention(128)
        self.bn_gfe2 = nn.BatchNorm2d(128)

        # MFF-AN Stream (Multi-scale Feature Fusion with Attention Network)
        self.layer3 = resnet.layer3
        self.mff_attention1 = LocalFeatureAttention(256)
        self.bn_mff1 = nn.BatchNorm2d(256)

        self.layer4 = resnet.layer4
        self.mff_attention2 = LocalFeatureAttention(512)
        self.bn_mff2 = nn.BatchNorm2d(512)

        # Global pooling and classification
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout_rate)

        # Two-stage classifier
        self.fc1 = nn.Linear(512, 256)
        self.bn_fc = nn.BatchNorm1d(256)
        self.relu_fc = nn.ReLU(inplace=True)
        self.dropout_fc = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, num_classes)

        # (pre_attention, post_bn) feature-map pairs from the last eval-mode forward
        self.attention_maps = []

    def _stage(self, x, layer, attention, norm, gated, record):
        """Run one backbone stage, then attention with a residual add, then BatchNorm.

        ``gated=True`` uses the GFE form ``attn(x) * x + x``; otherwise the MFF form
        ``x + attn(x)``. When ``record`` is True, detached copies of the feature map
        before attention and after BatchNorm are appended to ``self.attention_maps``.
        """
        x = layer(x)
        pre_attn = x.detach().clone() if record else None
        x_attn = attention(x)
        x = x_attn * x + x if gated else x + x_attn
        x = norm(x)
        if record:
            self.attention_maps.append((pre_attn, x.detach().clone()))
        return x

    def forward(self, x):
        # Snapshot feature maps only in eval mode. During training nothing reads
        # them, and cloning eight activations per step wastes memory and time.
        record = not self.training
        self.attention_maps = []

        # Stem
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        # GFE-AN stream (multiplicative attention)
        x = self._stage(x, self.layer1, self.gfe_attention1, self.bn_gfe1, True, record)
        x = self._stage(x, self.layer2, self.gfe_attention2, self.bn_gfe2, True, record)

        # MFF-AN stream (additive attention)
        x = self._stage(x, self.layer3, self.mff_attention1, self.bn_mff1, False, record)
        x = self._stage(x, self.layer4, self.mff_attention2, self.bn_mff2, False, record)

        # Classification head
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.dropout(x)

        x = self.fc1(x)
        x = self.bn_fc(x)
        x = self.relu_fc(x)
        x = self.dropout_fc(x)
        x = self.fc2(x)

        return x

    def visualize_attention(self, input_img):
        """
        Generate attention visualizations for a given input image.

        Switches the model to eval mode (and leaves it there), runs one forward
        pass and plots, per recorded stage, the channel-mean feature map before
        attention, after attention + BatchNorm, and their difference.

        Args:
            input_img: Input tensor of shape [1, 3, H, W]; only the first batch
                element is plotted.

        Returns:
            List of matplotlib figures, one per stage (not closed; caller owns them).
        """
        self.eval()

        with torch.no_grad():
            # Forward pass (via __call__ so module hooks run) to populate attention_maps
            _ = self(input_img)

            visualizations = []

            for i, (pre_attn, post_attn) in enumerate(self.attention_maps):
                pre_feature = pre_attn[0].cpu().numpy()  # first image in batch
                post_feature = post_attn[0].cpu().numpy()

                # Average across channels to get a 2-D heatmap
                pre_feature_map = np.mean(pre_feature, axis=0)
                post_feature_map = np.mean(post_feature, axis=0)

                # Difference map highlights the attention effect
                diff_map = post_feature_map - pre_feature_map

                fig, axes = plt.subplots(1, 3, figsize=(15, 5))

                im1 = axes[0].imshow(pre_feature_map, cmap='viridis')
                axes[0].set_title(f'Layer {i+1}: Pre-Attention')
                axes[0].axis('off')
                plt.colorbar(im1, ax=axes[0], fraction=0.046, pad=0.04)

                im2 = axes[1].imshow(post_feature_map, cmap='viridis')
                axes[1].set_title(f'Layer {i+1}: Post-Attention')
                axes[1].axis('off')
                plt.colorbar(im2, ax=axes[1], fraction=0.046, pad=0.04)

                im3 = axes[2].imshow(diff_map, cmap='RdBu_r')
                axes[2].set_title(f'Layer {i+1}: Attention Effect')
                axes[2].axis('off')
                plt.colorbar(im3, ax=axes[2], fraction=0.046, pad=0.04)

                plt.tight_layout()
                visualizations.append(fig)

            return visualizations

In [None]:
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdm

In [None]:
# Import missing modules
import copy
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import time
from tqdm import tqdm

# Function to get data transformations
def get_transforms():
    """Build the torchvision preprocessing pipelines for training and evaluation.

    Returns:
        tuple: ``(train_transform, test_transform)``.
            Training: resize to 256x256, random 224 crop (scale 0.8-1.0), random
            horizontal flip, random rotation of up to 10 degrees, and
            brightness/contrast jitter of 0.2.
            Test: plain resize to 224x224.
            Both end with ``ToTensor`` and ImageNet mean/std normalisation.
    """
    def tensor_and_normalize():
        # A fresh ToTensor + Normalize pair (ImageNet statistics) for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ]

    augmentation_steps = [
        transforms.Resize((256, 256)),
        transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
    ]
    train_pipeline = transforms.Compose(augmentation_steps + tensor_and_normalize())

    eval_pipeline = transforms.Compose([transforms.Resize((224, 224))] + tensor_and_normalize())

    return train_pipeline, eval_pipeline

# Function to calculate class weights
def calculate_class_weights(train_loader, device, num_classes=7):
    """Compute normalised inverse-frequency class weights from a loader's labels.

    Args:
        train_loader: Iterable of ``(inputs, labels)`` batches; only the labels are
            used. Iterating a DataLoader still loads and transforms every sample.
        device: Device the returned tensor is moved to.
        num_classes: Number of classes (default 7, the RAF-DB emotion set).

    Returns:
        Float tensor of shape ``(num_classes,)`` holding ``1 / count`` per class,
        rescaled so the weights sum to ``num_classes``. Classes that never occur
        get weight 0; an ``inf`` here would turn every weight into NaN after
        normalisation. If no labels are seen at all, uniform weights of 1 are
        returned.
    """
    class_counts = torch.zeros(num_classes, dtype=torch.float)
    for _, labels in train_loader:
        flat = torch.as_tensor(labels).reshape(-1).long().cpu()
        # bincount counts the whole batch in one call instead of a Python loop.
        class_counts += torch.bincount(flat, minlength=num_classes).float()

    present = class_counts > 0
    if not present.any():
        return torch.ones(num_classes).to(device)

    weights = torch.zeros(num_classes)
    weights[present] = 1.0 / class_counts[present]
    weights = weights / weights.sum() * num_classes

    return weights.to(device)

# Modified test function to allow silent operation when needed
def test_model(model, test_loader, device, verbose=True):
    """Test the model on the test dataset"""
    model.eval()
    correct = 0
    total = 0
    class_correct = [0] * 7
    class_total = [0] * 7
    emotion_labels = ['Surprise', 'Fear', 'Disgust', 'Happiness', 'Sadness', 'Anger', 'Neutral']
    
    confusion_matrix = torch.zeros(7, 7)
    
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            
            # Per-class accuracy
            for i in range(len(labels)):
                label = labels[i].item()
                pred = predicted[i].item()
                class_total[label] += 1
                if label == pred:
                    class_correct[label] += 1
                
                # Update confusion matrix
                confusion_matrix[label][pred] += 1
    
    # Calculate overall accuracy
    accuracy = 100 * correct / total
    
    # Calculate F1 score for each class
    f1_scores = []
    for i in range(7):
        # Calculate precision and recall
        tp = confusion_matrix[i][i].item()
        fp = confusion_matrix[:, i].sum().item() - tp
        fn = confusion_matrix[i, :].sum().item() - tp
        
        precision = tp / max(tp + fp, 1)
        recall = tp / max(tp + fn, 1)
        
        # Calculate F1 score
        f1 = 2 * precision * recall / max(precision + recall, 1e-6)
        f1_scores.append(f1)
    
    # Calculate mean F1 score
    mean_f1 = sum(f1_scores) / len(f1_scores)
    
    if verbose:
        print(f'Test Accuracy: {accuracy:.2f}%')
        
        # Print per-class accuracy
        print('\nPer-class accuracy:')
        for i in range(7):
            class_acc = 100 * class_correct[i] / max(class_total[i], 1)
            print(f'{emotion_labels[i]}: {class_acc:.2f}% ({class_correct[i]}/{class_total[i]})')
        
        # Print F1 scores
        print('\nPer-class F1 scores:')
        for i in range(7):
            print(f'{emotion_labels[i]}: {f1_scores[i]:.4f}')
        
        print(f'\nMean F1 Score: {mean_f1:.4f}')
    
    return accuracy, mean_f1, confusion_matrix

# Training function with early stopping
def train_model(model, train_loader, val_loader, device, epochs=50, patience=10):
    """Train the model with early stopping on validation macro-F1 and record history.

    Uses class-weighted cross-entropy, Adam (lr=1e-3, weight_decay=1e-4) and a
    ReduceLROnPlateau scheduler that halves the learning rate after 5 epochs
    without F1 improvement. Each time validation F1 improves, a snapshot of the
    weights is kept in memory and saved to ``dsan_model_rafdb_best.pth``.

    After the loop:
        - training loss and train/validation accuracy curves are saved to
          ``training_history.png`` and shown;
        - the best snapshot is restored;
        - the weights are saved to ``dsan_model_rafdb.pth``.

    Args:
        model: Network to train; presumably already moved to ``device`` by the caller.
        train_loader: Iterable of ``(images, labels)`` training batches.
        val_loader: Loader passed to ``test_model`` for per-epoch validation.
        device: Device each batch is moved to.
        epochs: Maximum number of epochs.
        patience: Consecutive non-improving epochs tolerated before stopping.

    Returns:
        tuple: ``(model, history)``.
            ``model`` has the best-validation-F1 weights loaded (or the final
            weights if F1 never rose above 0).
            ``history`` maps 'train_loss', 'train_acc', 'val_acc' and 'val_f1' to
            per-epoch lists.
    """
    # NOTE(review): the class weights are computed by iterating train_loader (a full
    # pass over the images). If that loader already uses a class-balancing sampler,
    # the observed counts are near-uniform and this weighting is largely redundant.
    # Confirm the double balancing is intended.
    criterion = nn.CrossEntropyLoss(weight=calculate_class_weights(train_loader, device))
    optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-4)
    # NOTE(review): `verbose` is deprecated for LR schedulers in recent PyTorch releases.
    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5, patience=5, verbose=True)

    # Early-stopping bookkeeping
    best_f1 = 0.0
    best_model_state = None
    no_improve_epochs = 0

    history = {
        'train_loss': [],
        'train_acc': [],
        'val_acc': [],
        'val_f1': []
    }

    for epoch in range(epochs):
        start_time = time.time()

        # Training phase
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0

        for images, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}"):
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Loss is a batch mean, so weight it by batch size for an epoch average.
            running_loss += loss.item() * images.size(0)
            _, predicted = torch.max(outputs.detach(), 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        # Guard against an empty loader so the epoch summary cannot divide by zero.
        denom = max(total, 1)
        train_loss = running_loss / denom
        train_acc = 100 * correct / denom

        # Validation phase
        val_acc, val_f1, _ = test_model(model, val_loader, device, verbose=False)

        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_acc'].append(val_acc)
        history['val_f1'].append(val_f1)

        epoch_time = time.time() - start_time
        print(f'Epoch {epoch+1}/{epochs} | Time: {epoch_time:.2f}s')
        print(f'Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%')
        print(f'Val Acc: {val_acc:.2f}% | Val F1: {val_f1:.4f}')

        # mode='max': the scheduler reacts to F1 failing to increase.
        scheduler.step(val_f1)

        if val_f1 > best_f1:
            best_f1 = val_f1
            # state_dict() returns references to the live tensors; a plain dict.copy()
            # would keep aliasing them and silently track every later update.
            # Clone each tensor so the snapshot really freezes the best weights.
            best_model_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            no_improve_epochs = 0
            torch.save(best_model_state, 'dsan_model_rafdb_best.pth')
            print(f"Saved new best model with F1: {best_f1:.4f}")
        else:
            no_improve_epochs += 1
            print(f"No improvement for {no_improve_epochs} epochs")

        if no_improve_epochs >= patience:
            print(f"Early stopping after {epoch+1} epochs")
            break

    # Plot training history
    plt.figure(figsize=(12, 5))

    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train Loss')
    plt.title('Training Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train Accuracy')
    plt.plot(history['val_acc'], label='Validation Accuracy')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()

    plt.tight_layout()
    plt.savefig('training_history.png')
    plt.show()

    # Restore the best snapshot, if one was ever taken.
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"Loaded best model with F1: {best_f1:.4f}")

    torch.save(model.state_dict(), 'dsan_model_rafdb.pth')
    print("Saved final model")

    return model, history

# Run the training and evaluation
def run_experiment():
    """Train DSANPlus on RAF-DB, evaluate it on the test split and save artefacts.

    Steps:
        1. Seed the RNGs.
        2. Split the RAF-DB ``train`` folder 80/20 into train/validation with a fixed
           seed. The validation subset gets the test transform.
        3. Balance training batches with a WeightedRandomSampler.
        4. Train, or load ``./dsan_model_rafdb.pth`` when ``train_model_flag`` is False
           and the file exists.
        5. Evaluate on the ``test`` folder and save the confusion matrix, sample
           predictions, attention maps and a text summary.

    Any exception is printed with its traceback instead of being propagated.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Set random seed for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    train_transform, test_transform = get_transforms()

    raf_db_root = "./data/rafdb/DATASET"  # Path to the dataset root

    try:
        print("Loading training dataset...")
        train_dataset = RAFDBFolderDataset(
            root_dir=raf_db_root,
            split='train',
            transform=train_transform
        )

        # Validation split: 20% of the training folder, seeded for reproducibility.
        train_size = int(0.8 * len(train_dataset))
        val_size = len(train_dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(
            train_dataset,
            [train_size, val_size],
            generator=torch.Generator().manual_seed(42)
        )

        # Both subsets wrap the same dataset object. Give the *validation* subset
        # its own copy, keeping its own held-out indices, so it can use the test
        # transform. Deep-copying the whole training subset here (as before) made
        # validation run on training samples and discarded the held-out split.
        val_dataset.dataset = copy.deepcopy(train_dataset.dataset)
        val_dataset.dataset.transform = test_transform

        print("Loading test dataset...")
        test_dataset = RAFDBFolderDataset(
            root_dir=raf_db_root,
            split='test',
            transform=test_transform
        )

        # Read labels from the dataset's sample index. Indexing the subset would
        # decode and augment every training image (twice) just to obtain labels.
        base_samples = train_dataset.dataset.samples
        train_labels = [base_samples[i][1] for i in train_dataset.indices]

        class_counts = [0] * 7
        for label in train_labels:
            class_counts[label] += 1

        # Weight = 1 / class_count, so every class is drawn about equally often.
        # Every label in train_labels was counted, so its count is non-zero.
        weights = torch.tensor([1.0 / class_counts[label] for label in train_labels])

        sampler = torch.utils.data.WeightedRandomSampler(
            weights=weights,
            num_samples=len(train_dataset),
            replacement=True
        )

        print("Creating data loaders...")
        train_loader = DataLoader(
            train_dataset,
            batch_size=32,
            sampler=sampler,
            num_workers=2,  # Reduced from 4 to avoid warnings
            pin_memory=True
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=32,
            shuffle=False,
            num_workers=2,
            pin_memory=True
        )

        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Validation dataset size: {len(val_dataset)}")
        print(f"Test dataset size: {len(test_dataset)}")

        print("Creating model...")
        model = DSANPlus(num_classes=7, pretrained=True, dropout_rate=0.5)
        model = model.to(device)
        print(f"Model created with {count_parameters(model):,} trainable parameters")

        # Set train_model_flag = False to evaluate an existing checkpoint instead.
        train_model_flag = True
        model_path = "./dsan_model_rafdb.pth"

        if os.path.exists(model_path) and not train_model_flag:
            model.load_state_dict(torch.load(model_path, map_location=device))
            print(f"Loaded pretrained model from {model_path}")

            print("\nEvaluating model on validation set:")
            test_model(model, val_loader, device)
        else:
            print("\nTraining new model...")
            result = train_model(
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                device=device,
                epochs=30,  # Reduced for notebook environment
                patience=7   # Reduced for notebook environment
            )
            # One train_model definition returns (model, history), another returns
            # just the model; accept either.
            model = result[0] if isinstance(result, tuple) else result

        print("\nEvaluating model on test set:")
        accuracy, mean_f1, conf_matrix = test_model(model, test_loader, device)

        print("\nGenerating confusion matrix...")
        cm_fig = visualize_confusion_matrix(conf_matrix)
        cm_fig.savefig("confusion_matrix_rafdb.png")
        print("Saved confusion matrix visualization to confusion_matrix_rafdb.png")

        print("\nGenerating sample predictions...")
        sample_fig = visualize_sample_predictions(model, test_loader, device)
        sample_fig.savefig("sample_predictions_rafdb.png")
        print("Saved sample predictions visualization to sample_predictions_rafdb.png")

        print("\nGenerating attention maps...")
        visualize_attention_maps(model, test_loader, device, num_samples=2)  # Reduced to 2 samples

        with open("evaluation_results_rafdb.txt", "w") as f:
            f.write(f"Test Accuracy: {accuracy:.2f}%\n")
            f.write(f"Mean F1 Score: {mean_f1:.4f}\n")

        print("\nEvaluation completed!")

    except Exception as e:
        # Top-level boundary of a notebook entry point: report instead of crashing.
        print(f"Error during experiment: {e}")
        import traceback
        traceback.print_exc()

# Run the full experiment (data loading, training/evaluation, artefact saving)
# when this module is executed directly. In an IPython/Jupyter kernel
# `__name__` is also "__main__", so running this cell starts a full run.
if __name__ == "__main__":
    run_experiment()