In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
from pathlib import Path
import re
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in monitored quantity to qualify as an improvement.
            path (str): Path for the checkpoint to be saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

class DirectionalSoundDataset(Dataset):
    def __init__(self, base_dir, transform=None, target_size=(224, 224)):
        self.base_dir = Path(base_dir)
        self.transform = transform
        self.target_size = target_size
        
        # Define class mapping for all vehicle types and directions
        self.class_to_idx = {
            'ambulance_L': 0,
            'ambulance_R': 1,
            'carhorns_L': 2,
            'carhorns_R': 3,
            'FireTruck_L': 4,
            'FireTruck_R': 5,
            'policecar_L': 6,
            'policecar_R': 7
        }
        
        # Collect all files and their labels
        self.files = []
        self.labels = []
        
        for class_name in self.class_to_idx.keys():
            class_dir = self.base_dir / class_name
            if class_dir.exists():
                class_files = list(class_dir.glob(f"{class_name}_*.png"))
                self.files.extend(class_files)
                self.labels.extend([self.class_to_idx[class_name]] * len(class_files))
        
        if len(self.files) == 0:
            raise RuntimeError(f"No spectrogram files found in {base_dir}")
            
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx]
        label = self.labels[idx]
        
        # Load and process image
        spectrogram = Image.open(img_path).convert('RGB')
        if spectrogram.size != self.target_size:
            spectrogram = spectrogram.resize(self.target_size)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        return spectrogram, label

def create_data_loaders(base_dir, batch_size=16):
    """
    Create data loaders with 70-10-20 split for training, validation, and test sets.
    
    Args:
        base_dir (str): Path to the dataset directory
        batch_size (int): Batch size for the data loaders
    
    Returns:
        tuple: (train_loader, val_loader, test_loader)
    """
    # Define transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    
    # Create full dataset
    full_dataset = DirectionalSoundDataset(
        base_dir=base_dir,
        transform=transform,
        target_size=(224, 224)
    )
    
    # Calculate sizes
    total_size = len(full_dataset)
    test_size = 0.2  # 20% for test set
    val_size = 0.1   # 10% for validation set
    train_size = 0.7 # 70% for training set
    
    # Create indices for the splits
    indices = list(range(total_size))
    
    # First split: separate out test set (20%)
    train_val_idx, test_idx = train_test_split(
        indices,
        test_size=test_size,
        random_state=42,
        shuffle=True
    )
    
    # Second split: separate training (70%) and validation (10%) from the remaining 80%
    # To get 10% validation from the remaining 80%, we need val_size/(train_size + val_size) = 0.125
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size/(train_size + val_size),  # This gives us 10% of total dataset
        random_state=42,
        shuffle=True
    )
    
    # Verify split sizes
    print(f"\nDataset split sizes:")
    print(f"Total dataset size: {total_size}")
    print(f"Training set size: {len(train_idx)} ({len(train_idx)/total_size*100:.1f}%)")
    print(f"Validation set size: {len(val_idx)} ({len(val_idx)/total_size*100:.1f}%)")
    print(f"Test set size: {len(test_idx)} ({len(test_idx)/total_size*100:.1f}%)\n")
    
    # Create samplers
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
    
    # Create data loaders
    train_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=2,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        sampler=val_sampler,
        num_workers=2,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        sampler=test_sampler,
        num_workers=2,
        pin_memory=True
    )
    
    # Print dataset information
    print("Dataset class distribution:")
    class_counts = {class_name: 0 for class_name in full_dataset.class_to_idx.keys()}
    for idx in range(len(full_dataset)):
        _, label = full_dataset[idx]
        for class_name, class_idx in full_dataset.class_to_idx.items():
            if label == class_idx:
                class_counts[class_name] += 1
    
    for class_name, count in class_counts.items():
        print(f"{class_name}: {count} samples")
    
    return train_loader, val_loader, test_loader

class DirectionalSoundViT(nn.Module):
    def __init__(self, num_classes=8):  # Updated to 12 classes
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        num_features = self.vit.heads.head.in_features
        self.vit.heads.head = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.vit(x)

import matplotlib.pyplot as plt

def plot_training_metrics(train_losses, val_losses, train_accs, val_accs, save_path='training_metrics.png'):
    """
    Plot training and validation metrics.
    
    Args:
        train_losses (list): Training losses per epoch
        val_losses (list): Validation losses per epoch
        train_accs (list): Training accuracies per epoch
        val_accs (list): Validation accuracies per epoch
        save_path (str): Path to save the plot
    """
    epochs = range(1, len(train_losses) + 1)
    
    plt.figure(figsize=(15, 5))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'b-', label='Training Loss')
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Plot accuracies
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accs, 'b-', label='Training Accuracy')
    plt.plot(epochs, val_accs, 'r-', label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def train_model(model, train_loader, val_loader, num_epochs=30):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    model = model.to(device)
    
    # Initialize optimizer and loss
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()
    
    # Create directory for saving checkpoints
    save_dir = Path('model_checkpoints')
    save_dir.mkdir(exist_ok=True)
    
    best_val_acc = 0.0
    
    # Lists to store metrics
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    try:
        for epoch in range(num_epochs):
            # Training phase
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            start_time = time.time()
            
            for batch_idx, (spectrograms, labels) in enumerate(train_loader):
                try:
                    spectrograms, labels = spectrograms.to(device), labels.to(device)
                    
                    outputs = model(spectrograms)
                    loss = criterion(outputs, labels)
                    
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    total_loss += loss.item()
                    
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    
                    if (batch_idx + 1) % 10 == 0:
                        print(f'Epoch [{epoch+1}/{num_epochs}], '
                              f'Batch [{batch_idx+1}/{len(train_loader)}], '
                              f'Loss: {loss.item():.4f}, '
                              f'Acc: {100.*correct/total:.2f}%')
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {str(e)}")
                    continue
            
            train_acc = 100.*correct/total
            avg_train_loss = total_loss / len(train_loader)
            epoch_time = time.time() - start_time
            
            # Store training metrics
            train_losses.append(avg_train_loss)
            train_accs.append(train_acc)
            
            # Validation phase
            model.eval()
            val_loss = 0
            correct = 0
            total = 0
            
            with torch.no_grad():
                for spectrograms, labels in val_loader:
                    spectrograms, labels = spectrograms.to(device), labels.to(device)
                    outputs = model(spectrograms)
                    loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            val_acc = 100.*correct/total
            avg_val_loss = val_loss / len(val_loader)
            
            # Store validation metrics
            val_losses.append(avg_val_loss)
            val_accs.append(val_acc)
            
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
                  f'Time: {epoch_time:.2f}s')
            
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, save_dir / 'short_best_model.pth')
            
            scheduler.step()
            
            # Plot and save training metrics after each epoch
            plot_training_metrics(train_losses, val_losses, train_accs, val_accs)
            
    except Exception as e:
        print(f"Training error: {str(e)}")
        raise e
    
    # Return the training history
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    }

def evaluate_model(model, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for spectrograms, labels in test_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            outputs = model(spectrograms)
            _, predicted = outputs.max(1)
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    
    # Print per-class accuracy
    class_names = [
        'Ambulance Left', 'Ambulance Right',
        'Car Horn Left', 'Car Horn Right',
        'Fire Truck Left',  'Fire Truck Right',
        'Police Car Left',  'Police Car Right'
    ]
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    for i, class_name in enumerate(class_names):
        class_mask = (all_labels == i)
        if np.sum(class_mask) > 0:
            class_acc = 100 * np.sum((all_predictions == i) & class_mask) / np.sum(class_mask)
            print(f'{class_name} Accuracy: {class_acc:.2f}%')
    
    return accuracy, all_predictions, all_labels

def save_model(model, optimizer, epoch, val_acc, filename):
    """Save model checkpoint with all necessary state information"""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
    }, filename)
    print(f"Model saved to {filename}")

def load_best_model(model, filepath):
    """Load the best model weights"""
    if not Path(filepath).exists():
        raise FileNotFoundError(f"No model checkpoint found at {filepath}")
    
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']} with validation accuracy {checkpoint['val_acc']:.2f}%")
    return model, checkpoint['epoch'], checkpoint['val_acc']

def inference(model, image_path, device=None):
    """Run inference on a single image"""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Define the same transforms used during training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    # Set model to evaluation mode
    model.eval()
    model = model.to(device)
    
    # Define class names
    class_names = [
        'Ambulance Left',  'Ambulance Right',
        'Car Horn Left',  'Car Horn Right',
        'Fire Truck Left',  'Fire Truck Right',
        'Police Car Left', 'Police Car Right'
    ]
    
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class = torch.argmax(outputs, dim=1).item()
        confidence = probabilities[0][predicted_class].item()
        
    return {
        'predicted_class': class_names[predicted_class],
        'confidence': confidence * 100,
        'all_probabilities': {
            class_name: prob.item() * 100 
            for class_name, prob in zip(class_names, probabilities[0])
        }
    }

if __name__ == "__main__":
    # Set the path to your dataset directory
    base_dir = "Dataset of warning sound types and source directions"
    checkpoint_dir = Path('model_checkpoints')
    best_model_path = checkpoint_dir / 'short_best_model.pth'
    
    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(base_dir, batch_size=16)
    
    # Create model with 12 classes
    model = DirectionalSoundViT(num_classes=8)
    
    # Training phase
    print("Starting training...")
    train_model(model, train_loader, val_loader, num_epochs=20)
    
    # Load the best model for evaluation
    print("\nLoading best model for evaluation...")
    model, best_epoch, best_val_acc = load_best_model(model, best_model_path)
    
    # Evaluate the best model
    print("\nEvaluating best model...")
    accuracy, predictions, labels = evaluate_model(model, test_loader)
    
   


Dataset split sizes:
Total dataset size: 3840
Training set size: 2687 (70.0%)
Validation set size: 385 (10.0%)
Test set size: 768 (20.0%)

Dataset class distribution:
ambulance_L: 480 samples
ambulance_R: 480 samples
carhorns_L: 480 samples
carhorns_R: 480 samples
FireTruck_L: 480 samples
FireTruck_R: 480 samples
policecar_L: 480 samples
policecar_R: 480 samples
Starting training...
Using device: cuda
Epoch [1/20], Batch [10/168], Loss: 1.5206, Acc: 21.88%
Epoch [1/20], Batch [20/168], Loss: 1.0574, Acc: 34.38%
Epoch [1/20], Batch [30/168], Loss: 0.8901, Acc: 37.50%
Epoch [1/20], Batch [40/168], Loss: 0.7260, Acc: 42.19%
Epoch [1/20], Batch [50/168], Loss: 0.9116, Acc: 44.38%
Epoch [1/20], Batch [60/168], Loss: 0.8251, Acc: 46.35%
Epoch [1/20], Batch [70/168], Loss: 0.7618, Acc: 47.50%
Epoch [1/20], Batch [80/168], Loss: 0.7926, Acc: 48.67%
Epoch [1/20], Batch [90/168], Loss: 0.8101, Acc: 49.44%
Epoch [1/20], Batch [100/168], Loss: 0.7696, Acc: 50.31%
Epoch [1/20], Batch [110/168], Lo

  checkpoint = torch.load(filepath)


Loaded model from epoch 7 with validation accuracy 95.84%

Evaluating best model...
Test Accuracy: 92.71%
Ambulance Left Accuracy: 98.33%
Ambulance Right Accuracy: 98.82%
Car Horn Left Accuracy: 97.00%
Car Horn Right Accuracy: 100.00%
Fire Truck Left Accuracy: 89.25%
Fire Truck Right Accuracy: 57.47%
Police Car Left Accuracy: 98.82%
Police Car Right Accuracy: 97.94%


In [25]:
def test_model_inference():
    # Load the model
    model = DirectionalSoundViT(num_classes=8)
    print("\nLoading best model for evaluation...")
    model, best_epoch, best_val_acc = load_best_model(model, 'model_checkpoints/short_best_model.pth')
    print(f"Loaded model from epoch {best_epoch} with validation accuracy {best_val_acc:.2f}%")

    # Run example inference
    print("\nRunning example inference...")
    test_image_path = "./Dataset of warning sound types and source directions/noise/noise_16.png"
    # test_image_path = "./test/test_output/final_stitched.png"

    
    if Path(test_image_path).exists():
        # Run multiple inference iterations
        num_runs = 10
        times = []
        
        for _ in range(num_runs):
            start_time = time.perf_counter()
            result = inference(model, test_image_path)
            end_time = time.perf_counter()
            times.append(end_time - start_time)
        
        # Calculate statistics
        import statistics
        avg_time = statistics.mean(times) * 1000
        std_dev = statistics.stdev(times) * 1000


        print(f"\nInference results:")
        print(f"Predicted class: {result['predicted_class']}")
        print(f"Confidence: {result['confidence']:.2f}%")
        print(f"\nInference Time:")
        print(f"Average: {avg_time:.4f} milliseconds")
        print(f"Standard Deviation: {std_dev:.4f} milliseconds")
        print("\nAll class probabilities:")
        for class_name, prob in result['all_probabilities'].items():
            print(f"{class_name}: {prob:.2f}%")
    else:
        print(f"Error: Test image not found at {test_image_path}")
        


if __name__ == "__main__":
    test_model_inference()


Loading best model for evaluation...


  checkpoint = torch.load(filepath)


Loaded model from epoch 14 with validation accuracy 97.16%
Loaded model from epoch 14 with validation accuracy 97.16%

Running example inference...

Inference results:
Predicted class: Car Horn Right
Confidence: 71.23%

Inference Time:
Average: 18.3043 milliseconds
Standard Deviation: 13.8734 milliseconds

All class probabilities:
Ambulance Left: 0.37%
Ambulance Middle: 0.90%
Ambulance Right: 7.53%
Car Horn Left: 0.07%
Car Horn Middle: 0.02%
Car Horn Right: 71.23%
Fire Truck Left: 0.11%
Fire Truck Middle: 0.09%
Fire Truck Right: 0.16%
Police Car Left: 0.38%
Police Car Middle: 0.54%
Police Car Right: 18.59%


In [None]:
import matplotlib.pyplot as plt

def plot_training_metrics(train_losses, val_losses, train_accs, val_accs, save_path='training_metrics.png'):
    """
    Plot training and validation metrics.
    
    Args:
        train_losses (list): Training losses per epoch
        val_losses (list): Validation losses per epoch
        train_accs (list): Training accuracies per epoch
        val_accs (list): Validation accuracies per epoch
        save_path (str): Path to save the plot
    """
    epochs = range(1, len(train_losses) + 1)
    
    plt.figure(figsize=(15, 5))
    
    # Plot losses
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, 'b-', label='Training Loss')
    plt.plot(epochs, val_losses, 'r-', label='Validation Loss')
    plt.title('Training and Validation Loss')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    
    # Plot accuracies
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accs, 'b-', label='Training Accuracy')
    plt.plot(epochs, val_accs, 'r-', label='Validation Accuracy')
    plt.title('Training and Validation Accuracy')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.grid(True)
    
    plt.tight_layout()
    plt.savefig(save_path)
    plt.close()

def train_model(model, train_loader, val_loader, num_epochs=30):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    model = model.to(device)
    
    # Initialize optimizer and loss
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()
    
    # Create directory for saving checkpoints
    save_dir = Path('model_checkpoints')
    save_dir.mkdir(exist_ok=True)
    
    best_val_acc = 0.0
    
    # Lists to store metrics
    train_losses = []
    val_losses = []
    train_accs = []
    val_accs = []
    
    try:
        for epoch in range(num_epochs):
            # Training phase
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            start_time = time.time()
            
            for batch_idx, (spectrograms, labels) in enumerate(train_loader):
                try:
                    spectrograms, labels = spectrograms.to(device), labels.to(device)
                    
                    outputs = model(spectrograms)
                    loss = criterion(outputs, labels)
                    
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    total_loss += loss.item()
                    
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    
                    if (batch_idx + 1) % 10 == 0:
                        print(f'Epoch [{epoch+1}/{num_epochs}], '
                              f'Batch [{batch_idx+1}/{len(train_loader)}], '
                              f'Loss: {loss.item():.4f}, '
                              f'Acc: {100.*correct/total:.2f}%')
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {str(e)}")
                    continue
            
            train_acc = 100.*correct/total
            avg_train_loss = total_loss / len(train_loader)
            epoch_time = time.time() - start_time
            
            # Store training metrics
            train_losses.append(avg_train_loss)
            train_accs.append(train_acc)
            
            # Validation phase
            model.eval()
            val_loss = 0
            correct = 0
            total = 0
            
            with torch.no_grad():
                for spectrograms, labels in val_loader:
                    spectrograms, labels = spectrograms.to(device), labels.to(device)
                    outputs = model(spectrograms)
                    loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            val_acc = 100.*correct/total
            avg_val_loss = val_loss / len(val_loader)
            
            # Store validation metrics
            val_losses.append(avg_val_loss)
            val_accs.append(val_acc)
            
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
                  f'Time: {epoch_time:.2f}s')
            
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, save_dir / 'short_best_model.pth')
            
            scheduler.step()
            
            # Plot and save training metrics after each epoch
            plot_training_metrics(train_losses, val_losses, train_accs, val_accs)
            
    except Exception as e:
        print(f"Training error: {str(e)}")
        raise e
    
    # Return the training history
    return {
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    }