# In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
import time
import matplotlib.pyplot as plt
import numpy as np
import os
from tqdm import tqdm

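# main() below relies on DSAN, RAFDBFolderDataset, count_parameters and
# test_model, none of which are defined in this file. When running this as a
# standalone file, import them first (uncomment/adjust the lines below); when
# adding this code to an existing file, they should already be in scope.
# from model import DSAN
# from dataset import RAFDBFolderDataset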

def train_model(model, train_loader, test_loader, criterion, optimizer, scheduler,
                device, num_epochs=25, save_path='./dsan_model_rafdb.pth',
                log_interval=10):
    """
    Train a model and return it with the weights of its best epoch loaded.

    Every epoch runs one optimization pass over ``train_loader`` followed by
    a gradient-free evaluation pass over ``test_loader``. Whenever the
    evaluation accuracy improves (and always after the first epoch) the
    model's ``state_dict`` is written to ``save_path``; after the last epoch
    that checkpoint is loaded back into ``model``.

    Args:
        model: Model to train
        train_loader: DataLoader for training data (must support ``len()``)
        test_loader: DataLoader used for the per-epoch evaluation pass
        criterion: Loss function, called as ``criterion(outputs, labels)``
        optimizer: Optimizer
        scheduler: Learning rate scheduler, stepped once per epoch with no
            arguments; pass ``None`` to disable
        device: Device to train on (cuda/cpu)
        num_epochs: Number of epochs to train for
        save_path: Path to save best model weights
        log_interval: How often (in batches) to refresh the progress-bar text

    Returns:
        Tuple ``(model, history)``. ``history`` is a dict of per-epoch lists
        with keys ``'train_loss'``, ``'train_acc'``, ``'test_loss'``,
        ``'test_acc'`` and ``'time_per_epoch'`` (seconds).

    Raises:
        ZeroDivisionError: If either loader yields no samples in an epoch.
    """
    start_time = time.time()
    best_acc = 0.0
    # Set once *this* run has written save_path. Without it, the final reload
    # could pick up a stale checkpoint from an earlier run, or fail because
    # the file was never created (e.g. num_epochs == 0).
    checkpoint_written = False
    history = {
        'train_loss': [],
        'train_acc': [],
        'test_loss': [],
        'test_acc': [],
        'time_per_epoch': []
    }

    for epoch in range(num_epochs):
        epoch_start = time.time()
        print(f'Epoch {epoch+1}/{num_epochs}')
        print('-' * 10)

        # Training phase
        model.train()
        running_loss = 0.0
        running_corrects = 0
        total_samples = 0

        # Use tqdm for progress bar
        pbar = tqdm(enumerate(train_loader), total=len(train_loader))
        for batch_idx, (inputs, labels) in pbar:
            inputs = inputs.to(device)
            labels = labels.to(device)
            batch_size = inputs.size(0)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            # Statistics. The loss is re-weighted by batch size so that the
            # epoch average is per-sample (assumes criterion returns a batch
            # mean -- TODO confirm for non-default reductions).
            _, preds = torch.max(outputs, 1)
            batch_loss = loss.item() * batch_size
            batch_corrects = torch.sum(preds == labels).item()

            running_loss += batch_loss
            running_corrects += batch_corrects
            total_samples += batch_size

            # Update progress bar
            if batch_idx % log_interval == 0:
                batch_acc = batch_corrects / batch_size
                pbar.set_description(f'Train Loss: {loss.item():.4f} Acc: {batch_acc:.4f}')

        # Calculate epoch stats
        epoch_loss = running_loss / total_samples
        epoch_acc = running_corrects / total_samples

        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)

        print(f'Train Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

        # Validation phase
        model.eval()
        val_running_loss = 0.0
        val_running_corrects = 0
        val_total_samples = 0

        with torch.no_grad():
            for inputs, labels in test_loader:
                inputs = inputs.to(device)
                labels = labels.to(device)
                batch_size = inputs.size(0)

                # Forward pass
                outputs = model(inputs)
                loss = criterion(outputs, labels)

                # Statistics
                _, preds = torch.max(outputs, 1)
                val_running_loss += loss.item() * batch_size
                val_running_corrects += torch.sum(preds == labels).item()
                val_total_samples += batch_size

        # Calculate validation stats
        val_epoch_loss = val_running_loss / val_total_samples
        val_epoch_acc = val_running_corrects / val_total_samples

        history['test_loss'].append(val_epoch_loss)
        history['test_acc'].append(val_epoch_acc)

        # Time per epoch
        epoch_time = time.time() - epoch_start
        history['time_per_epoch'].append(epoch_time)

        print(f'Val Loss: {val_epoch_loss:.4f} Acc: {val_epoch_acc:.4f}')
        print(f'Epoch Time: {epoch_time:.2f}s')

        # Step the scheduler
        if scheduler is not None:
            scheduler.step()

        # Save best model. The first epoch is always checkpointed so that
        # save_path reflects this run even if accuracy never rises above 0.
        if not checkpoint_written or val_epoch_acc > best_acc:
            best_acc = val_epoch_acc
            torch.save(model.state_dict(), save_path)
            checkpoint_written = True
            print(f'Saved new best model with accuracy {best_acc:.4f}')

    # Load best model weights, but only a checkpoint produced by this run.
    # map_location keeps the reload on the training device.
    if checkpoint_written:
        model.load_state_dict(torch.load(save_path, map_location=device))
    total_time = time.time() - start_time
    print(f'Training complete in {total_time // 60:.0f}m {total_time % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    return model, history


def plot_training_history(history, save_path='training_history.png'):
    """
    Plot accuracy and loss curves side by side and optionally save the figure.

    Args:
        history: Mapping with the per-epoch sequences ``'train_acc'``,
            ``'test_acc'``, ``'train_loss'`` and ``'test_loss'``. The
            ``'test_*'`` series are labelled "Validation" in the legends.
        save_path: Where to write the rendered figure. Defaults to
            ``'training_history.png'``; pass ``None`` to skip saving.

    Returns:
        The matplotlib figure holding both plots.
    """
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(15, 5))

    # (axis, history key suffix, human-readable metric name)
    panels = (
        (acc_ax, 'acc', 'Accuracy'),
        (loss_ax, 'loss', 'Loss'),
    )
    for ax, key, metric in panels:
        ax.plot(history[f'train_{key}'], label=f'Train {metric}')
        ax.plot(history[f'test_{key}'], label=f'Validation {metric}')
        ax.set_title(f'Model {metric}')
        ax.set_ylabel(metric)
        ax.set_xlabel('Epoch')
        ax.legend()
        ax.grid(True)

    # Operate on this figure explicitly rather than on pyplot's implicit
    # "current figure", which other plotting code could have changed.
    fig.tight_layout()
    if save_path is not None:
        fig.savefig(save_path)
    return fig


def main():
    """
    Train and evaluate DSAN on RAF-DB from start to finish.

    Builds the train/test transforms, datasets and loaders, trains the model
    with SGD and a step learning-rate schedule, plots the training history,
    evaluates on the test loader and writes a short summary to
    ``final_evaluation_results.txt``. Any exception raised after the
    transforms are built is printed together with its traceback rather than
    propagated.
    """
    # Pick the GPU when one is available
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    print(f"Using device: {device}")

    # Normalization statistics shared by both pipelines
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    # Training pipeline: resize, then random crop/flip/color jitter
    train_transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.RandomCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # Test pipeline: deterministic resize only
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=norm_mean, std=norm_std),
    ])

    # Dataset root directory
    raf_db_root = "./data/rafdb/DATASET"

    try:
        # Datasets for both splits
        train_dataset = RAFDBFolderDataset(
            root_dir=raf_db_root, split='train', transform=train_transform)
        test_dataset = RAFDBFolderDataset(
            root_dir=raf_db_root, split='test', transform=test_transform)

        # Loaders share everything except shuffling
        loader_kwargs = {'batch_size': 32, 'num_workers': 4}
        train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
        test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

        print(f"Train dataset size: {len(train_dataset)}")
        print(f"Test dataset size: {len(test_dataset)}")

        # Model
        model = DSAN(num_classes=7, pretrained=True).to(device)
        print(f"Model created with {count_parameters(model):,} trainable parameters")

        # Loss, optimizer and learning-rate schedule
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(
            model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
        scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)

        # Training run (epoch count can be adjusted here)
        model, history = train_model(
            model, train_loader, test_loader, criterion, optimizer, scheduler,
            device, num_epochs=30, save_path='./dsan_model_rafdb.pth')

        # Curves for accuracy and loss
        plot_training_history(history)

        # Final evaluation on the test loader
        accuracy, mean_f1, conf_matrix = test_model(model, test_loader, device)

        # Persist the summary
        with open("final_evaluation_results.txt", "w") as report:
            report.write(f"Test Accuracy: {accuracy:.2f}%\n")
            report.write(f"Mean F1 Score: {mean_f1:.4f}\n")
            report.write(f"Average time per epoch: {np.mean(history['time_per_epoch']):.2f}s\n")

        print("Training and evaluation completed!")

    except Exception as exc:
        # Report the failure with its traceback instead of propagating it
        print(f"Error during training: {exc}")
        import traceback
        traceback.print_exc()


# Run the training pipeline only when this file is executed directly as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()