In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
from pathlib import Path
import re
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

class DirectionalSoundDataset(Dataset):
    """Spectrogram images labelled by warning-sound type and source direction.

    ``base_dir`` is expected to contain one sub-directory per key of
    ``class_to_idx`` (e.g. ``ambulance_L``), each holding PNG files named
    ``<class_name>_*.png``. Class directories that are missing are skipped.

    Args:
        base_dir: root directory of the dataset (str or Path).
        transform: optional callable applied to each RGB ``PIL.Image`` after
            resizing (e.g. a torchvision ``transforms.Compose``).
        target_size: ``(width, height)`` passed to ``PIL.Image.resize`` when an
            image does not already have exactly this size.

    Raises:
        RuntimeError: if no matching PNG files are found under ``base_dir``.
    """

    def __init__(self, base_dir, transform=None, target_size=(224, 224)):
        self.base_dir = Path(base_dir)
        self.transform = transform
        self.target_size = target_size

        # Vehicle type x direction (L = left, M = middle, R = right) -> label
        self.class_to_idx = {
            'ambulance_L': 0,
            'ambulance_M': 1,
            'ambulance_R': 2,
            'carhorns_L': 3,
            'carhorns_M': 4,
            'carhorns_R': 5,
            'FireTruck_L': 6,
            'FireTruck_M': 7,
            'FireTruck_R': 8,
            'policecar_L': 9,
            'policecar_M': 10,
            'policecar_R': 11
        }

        # Parallel lists: files[i] has label labels[i]
        self.files = []
        self.labels = []

        for class_name, label in self.class_to_idx.items():
            class_dir = self.base_dir / class_name
            if not class_dir.is_dir():
                continue
            # glob() order is filesystem-dependent; sorting keeps dataset
            # indices (and thus any seeded index-based split) reproducible
            class_files = sorted(class_dir.glob(f"{class_name}_*.png"))
            self.files.extend(class_files)
            self.labels.extend([label] * len(class_files))

        if len(self.files) == 0:
            raise RuntimeError(f"No spectrogram files found in {base_dir}")

    def __len__(self):
        """Return the number of spectrogram files found."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``.

        The image is loaded as RGB, resized to ``target_size`` if needed and
        then passed through ``transform`` when one was given.
        """
        img_path = self.files[idx]
        label = self.labels[idx]

        # The context manager closes the file deterministically; convert()
        # materialises the pixel data before the handle is released
        with Image.open(img_path) as img:
            spectrogram = img.convert('RGB')

        size = tuple(self.target_size)
        if spectrogram.size != size:
            spectrogram = spectrogram.resize(size)

        if self.transform:
            spectrogram = self.transform(spectrogram)

        return spectrogram, label

def create_data_loaders(base_dir, batch_size=32, test_size=0.2, val_size=0.1,
                        *, num_workers=2, random_state=42, stratify=False):
    """Build train / validation / test loaders over one DirectionalSoundDataset.

    ``test_size`` and ``val_size`` are fractions of the *whole* dataset: the
    test split is carved off first, then ``val_size / (1 - test_size)`` of the
    remainder becomes the validation split.

    Args:
        base_dir: dataset root passed to ``DirectionalSoundDataset``.
        batch_size: batch size for all three loaders.
        test_size: fraction of all samples used for testing.
        val_size: fraction of all samples used for validation.
        num_workers: worker processes per loader.
        random_state: seed for both index splits. The training shuffle draws
            from torch's global RNG (seed it with ``torch.manual_seed``).
        stratify: if True, preserve per-class label proportions in each split.

    Returns:
        ``(train_loader, val_loader, test_loader)``, all wrapping the same
        dataset object. Only the training loader shuffles; validation and test
        iterate a fixed ascending index list (available as ``loader.sampler``),
        so their outputs can be matched back to ``loader.dataset.files``.

    Raises:
        ValueError: if the split fractions are not in (0, 1) or sum to >= 1.
    """
    if not (0 < test_size < 1 and 0 < val_size < 1 and test_size + val_size < 1):
        raise ValueError(
            "test_size and val_size must be fractions in (0, 1) whose sum is < 1; "
            f"got test_size={test_size}, val_size={val_size}"
        )

    # Standard ImageNet channel mean/std normalisation
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    full_dataset = DirectionalSoundDataset(
        base_dir=base_dir,
        transform=transform,
        target_size=(224, 224),
    )

    indices = list(range(len(full_dataset)))
    labels = full_dataset.labels

    # First split off the test set, then split the remainder into train/val
    train_idx, test_idx = train_test_split(
        indices, test_size=test_size, random_state=random_state,
        stratify=labels if stratify else None,
    )
    train_idx, val_idx = train_test_split(
        train_idx, test_size=val_size / (1 - test_size), random_state=random_state,
        stratify=[labels[i] for i in train_idx] if stratify else None,
    )

    # Pinned host memory only helps host->GPU copies (and warns without CUDA)
    pin_memory = torch.cuda.is_available()

    def _make_loader(sampler):
        # All loaders share the dataset; the sampler decides which indices
        return DataLoader(
            full_dataset, batch_size=batch_size, sampler=sampler,
            num_workers=num_workers, pin_memory=pin_memory,
        )

    train_loader = _make_loader(torch.utils.data.SubsetRandomSampler(train_idx))
    # Deterministic order for evaluation: a plain index list is a valid sampler
    val_loader = _make_loader(sorted(val_idx))
    test_loader = _make_loader(sorted(test_idx))

    return train_loader, val_loader, test_loader

class DirectionalSoundViT(nn.Module):
    """ViT-B/16 classifier with a task-specific output head.

    The backbone is torchvision's ``vit_b_16`` loaded with
    ``ViT_B_16_Weights.IMAGENET1K_V1``; only the final linear layer
    (``heads.head``) is swapped for one emitting ``num_classes`` logits.
    """

    def __init__(self, num_classes=12):
        super().__init__()
        backbone = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        # Keep the pretrained embedding width, change only the output width
        embed_dim = backbone.heads.head.in_features
        backbone.heads.head = nn.Linear(embed_dim, num_classes)
        self.vit = backbone

    def forward(self, x):
        """Return raw (unnormalised) class logits for the input batch ``x``."""
        return self.vit(x)

def train_model(model, train_loader, val_loader, num_epochs=30, *,
                lr=1e-4, weight_decay=0.05, save_dir='model_checkpoints',
                restore_best=True):
    """Train ``model`` with AdamW + cosine LR decay, checkpointing on val accuracy.

    Each epoch runs one optimisation pass over ``train_loader`` and then a
    gradient-free pass over ``val_loader``. Whenever validation accuracy
    strictly improves, model and optimizer state are saved to
    ``<save_dir>/best_model.pth``.

    Args:
        model: ``nn.Module`` returning class logits; moved to CUDA if available.
        train_loader: yields ``(inputs, integer_labels)`` training batches.
        val_loader: yields ``(inputs, integer_labels)`` validation batches.
        num_epochs: number of epochs; also the cosine schedule's ``T_max``.
        lr: AdamW learning rate.
        weight_decay: AdamW weight decay.
        save_dir: directory for the best checkpoint (created if missing).
        restore_best: if True, load the best-validation weights back into
            ``model`` before returning, so later evaluation uses the
            checkpointed model rather than the last epoch's weights.

    Returns:
        The trained ``model`` (the same object that was passed in).

    Raises:
        RuntimeError: if every training batch of an epoch failed.
        ValueError: if ``val_loader`` yields no samples.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")

    model = model.to(device)

    optimizer = AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()

    save_dir = Path(save_dir)
    save_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_path = save_dir / 'best_model.pth'

    best_val_acc = 0.0
    best_state = None

    try:
        for epoch in range(num_epochs):
            start_time = time.time()
            avg_loss, train_acc = _run_train_epoch(
                model, train_loader, optimizer, criterion, device, epoch, num_epochs
            )
            # Timed over the training pass only, as before
            epoch_time = time.time() - start_time

            val_loss, val_acc = _run_eval_epoch(model, val_loader, criterion, device)

            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
                  f'Time: {epoch_time:.2f}s')

            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, checkpoint_path)
                # Independent CPU copy: state_dict() aliases the live
                # parameters, which later epochs would overwrite in place
                best_state = {
                    k: v.detach().to('cpu', copy=True)
                    for k, v in model.state_dict().items()
                }

            scheduler.step()

    except Exception as e:
        print(f"Training error: {str(e)}")
        raise

    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored best model (Val Acc: {best_val_acc:.2f}%)")

    return model


def _run_train_epoch(model, loader, optimizer, criterion, device, epoch, num_epochs):
    """Run one optimisation pass over ``loader``; return (mean loss, accuracy %)."""
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0
    ok_batches = 0

    for batch_idx, (spectrograms, labels) in enumerate(loader):
        try:
            spectrograms, labels = spectrograms.to(device), labels.to(device)

            outputs = model(spectrograms)
            loss = criterion(outputs, labels)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        except Exception as e:
            # Best-effort: skip a failing batch rather than abort the run
            print(f"Error in batch {batch_idx}: {str(e)}")
            continue

        ok_batches += 1
        total_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        if (batch_idx + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Batch [{batch_idx+1}/{len(loader)}], '
                  f'Loss: {loss.item():.4f}, '
                  f'Acc: {100.*correct/total:.2f}%')

    if ok_batches == 0 or total == 0:
        raise RuntimeError(
            f"No training batch succeeded in epoch {epoch+1} "
            f"({len(loader)} batches attempted)"
        )

    # Average only over batches that actually contributed a loss
    return total_loss / ok_batches, 100. * correct / total


def _run_eval_epoch(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without gradients; return (mean loss, accuracy %)."""
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    n_batches = 0

    with torch.no_grad():
        for spectrograms, labels in loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            outputs = model(spectrograms)
            loss = criterion(outputs, labels)

            total_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            n_batches += 1

    if total == 0:
        raise ValueError("Validation loader yielded no samples")

    return total_loss / n_batches, 100. * correct / total

def evaluate_model(model, test_loader, class_names=None):
    """Report overall and per-class accuracy of ``model`` on ``test_loader``.

    Args:
        model: ``nn.Module`` returning class logits; moved to CUDA if available
            and switched to eval mode.
        test_loader: yields ``(inputs, integer_labels)`` batches.
        class_names: display names indexed by label id. Defaults to the twelve
            vehicle/direction names in label order 0..11.

    Returns:
        ``(accuracy, predictions, labels)``: accuracy in percent, plus 1-D
        NumPy arrays of predicted and true labels in loader order.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    if class_names is None:
        class_names = [
            'Ambulance Left', 'Ambulance Middle', 'Ambulance Right',
            'Car Horn Left', 'Car Horn Middle', 'Car Horn Right',
            'Fire Truck Left', 'Fire Truck Middle', 'Fire Truck Right',
            'Police Car Left', 'Police Car Middle', 'Police Car Right'
        ]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for spectrograms, labels in test_loader:
            outputs = model(spectrograms.to(device))
            _, predicted = outputs.max(1)
            pred_chunks.append(predicted.cpu().numpy())
            # Labels are only needed on the host for the metrics below
            label_chunks.append(labels.cpu().numpy())

    if not label_chunks:
        raise ValueError("test_loader yielded no samples")

    all_predictions = np.concatenate(pred_chunks)
    all_labels = np.concatenate(label_chunks)

    total = all_labels.size
    if total == 0:
        raise ValueError("test_loader yielded no samples")

    correct = int(np.sum(all_predictions == all_labels))
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')

    # Per-class accuracy (recall), skipping classes absent from the test set
    for i, class_name in enumerate(class_names):
        class_mask = (all_labels == i)
        n_class = np.sum(class_mask)
        if n_class > 0:
            class_acc = 100 * np.sum((all_predictions == i) & class_mask) / n_class
            print(f'{class_name} Accuracy: {class_acc:.2f}%')

    return accuracy, all_predictions, all_labels

if __name__ == "__main__":
    # Dataset root containing one sub-folder per "<vehicle>_<L|M|R>" class
    dataset_root = "Dataset of warning sound types and source directions"

    train_dl, val_dl, test_dl = create_data_loaders(dataset_root, batch_size=32)

    # Pretrained ViT with a 12-way head (4 vehicle types x 3 directions)
    vit = DirectionalSoundViT(num_classes=12)

    train_model(vit, train_dl, val_dl, num_epochs=30)

    test_acc, test_preds, test_labels = evaluate_model(vit, test_dl)

Using device: cuda
Epoch [1/30], Batch [10/116], Loss: 2.3690, Acc: 11.25%
Epoch [1/30], Batch [20/116], Loss: 1.5902, Acc: 19.53%
Epoch [1/30], Batch [30/116], Loss: 1.5932, Acc: 21.56%
Epoch [1/30], Batch [40/116], Loss: 1.4731, Acc: 24.22%
Epoch [1/30], Batch [50/116], Loss: 1.1913, Acc: 25.31%
Epoch [1/30], Batch [60/116], Loss: 1.1824, Acc: 26.72%
Epoch [1/30], Batch [70/116], Loss: 1.1839, Acc: 29.06%
Epoch [1/30], Batch [80/116], Loss: 1.3216, Acc: 29.65%
Epoch [1/30], Batch [90/116], Loss: 1.0653, Acc: 30.17%
Epoch [1/30], Batch [100/116], Loss: 1.0722, Acc: 30.81%
Epoch [1/30], Batch [110/116], Loss: 1.0634, Acc: 31.70%
Epoch [1/30], Train Loss: 1.4147, Train Acc: 32.49%, Val Loss: 1.1910, Val Acc: 37.31%, Time: 71.76s
Epoch [2/30], Batch [10/116], Loss: 1.1309, Acc: 44.38%
Epoch [2/30], Batch [20/116], Loss: 0.9869, Acc: 47.03%
Epoch [2/30], Batch [30/116], Loss: 1.0407, Acc: 47.60%
Epoch [2/30], Batch [40/116], Loss: 0.8112, Acc: 48.98%
Epoch [2/30], Batch [50/116], Loss: 0.