In [3]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
from pathlib import Path
import re
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
            verbose (bool): If True, prints a message for each validation loss improvement.
            delta (float): Minimum change in monitored quantity to qualify as an improvement.
            path (str): Path for the checkpoint to be saved to.
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

class DirectionalSoundDataset(Dataset):
    def __init__(self, base_dir, transform=None, target_size=(224, 224)):
        self.base_dir = Path(base_dir)
        self.transform = transform
        self.target_size = target_size
        
        # Define class mapping for all vehicle types and directions
        self.class_to_idx = {
            'ambulance_L': 0,
            'ambulance_M': 1,
            'ambulance_R': 2,
            'carhorns_L': 3,
            'carhorns_M': 4,
            'carhorns_R': 5,
            'FireTruck_L': 6,
            'FireTruck_M': 7,
            'FireTruck_R': 8,
            'policecar_L': 9,
            'policecar_M': 10,
            'policecar_R': 11
        }
        
        # Collect all files and their labels
        self.files = []
        self.labels = []
        
        for class_name in self.class_to_idx.keys():
            class_dir = self.base_dir / class_name
            if class_dir.exists():
                class_files = list(class_dir.glob(f"{class_name}_*.png"))
                self.files.extend(class_files)
                self.labels.extend([self.class_to_idx[class_name]] * len(class_files))
        
        if len(self.files) == 0:
            raise RuntimeError(f"No spectrogram files found in {base_dir}")
            
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx]
        label = self.labels[idx]
        
        # Load and process image
        spectrogram = Image.open(img_path).convert('RGB')
        if spectrogram.size != self.target_size:
            spectrogram = spectrogram.resize(self.target_size)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        return spectrogram, label

def create_data_loaders(base_dir, batch_size=32, test_size=0.2, val_size=0.1):
    # Define transforms
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    
    # Create full dataset
    full_dataset = DirectionalSoundDataset(
        base_dir=base_dir,
        transform=transform,
        target_size=(224, 224)
    )
    
    # Calculate sizes
    total_size = len(full_dataset)
    indices = list(range(total_size))
    
    # First split into train and test
    train_idx, test_idx = train_test_split(indices, test_size=test_size, random_state=42)
    
    # Then split train into train and val
    train_idx, val_idx = train_test_split(train_idx, test_size=val_size/(1-test_size), random_state=42)
    
    # Create samplers
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
    
    # Create data loaders with reduced num_workers to prevent memory issues
    train_loader = DataLoader(
        full_dataset, batch_size=batch_size, sampler=train_sampler,
        num_workers=2, pin_memory=True
    )
    
    val_loader = DataLoader(
        full_dataset, batch_size=batch_size, sampler=val_sampler,
        num_workers=2, pin_memory=True
    )
    
    test_loader = DataLoader(
        full_dataset, batch_size=batch_size, sampler=test_sampler,
        num_workers=2, pin_memory=True
    )
    
    return train_loader, val_loader, test_loader

class DirectionalSoundViT(nn.Module):
    def __init__(self, num_classes=12):  # Updated to 12 classes
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        num_features = self.vit.heads.head.in_features
        self.vit.heads.head = nn.Linear(num_features, num_classes)

    def forward(self, x):
        return self.vit(x)

def train_model(model, train_loader, val_loader, num_epochs=30):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    
    model = model.to(device)
    
    # Initialize optimizer and loss
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
    scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
    criterion = nn.CrossEntropyLoss()
    
    # Create directory for saving checkpoints
    save_dir = Path('model_checkpoints')
    save_dir.mkdir(exist_ok=True)
    
    best_val_acc = 0.0
    
    try:
        for epoch in range(num_epochs):
            # Training phase
            model.train()
            total_loss = 0
            correct = 0
            total = 0
            start_time = time.time()
            
            for batch_idx, (spectrograms, labels) in enumerate(train_loader):
                try:
                    spectrograms, labels = spectrograms.to(device), labels.to(device)
                    
                    outputs = model(spectrograms)
                    loss = criterion(outputs, labels)
                    
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                    
                    total_loss += loss.item()
                    
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
                    
                    if (batch_idx + 1) % 10 == 0:
                        print(f'Epoch [{epoch+1}/{num_epochs}], '
                              f'Batch [{batch_idx+1}/{len(train_loader)}], '
                              f'Loss: {loss.item():.4f}, '
                              f'Acc: {100.*correct/total:.2f}%')
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {str(e)}")
                    continue
            
            train_acc = 100.*correct/total
            avg_loss = total_loss / len(train_loader)
            epoch_time = time.time() - start_time
            
            # Validation phase
            model.eval()
            val_loss = 0
            correct = 0
            total = 0
            
            with torch.no_grad():
                for spectrograms, labels in val_loader:
                    spectrograms, labels = spectrograms.to(device), labels.to(device)
                    outputs = model(spectrograms)
                    loss = criterion(outputs, labels)
                    
                    val_loss += loss.item()
                    _, predicted = outputs.max(1)
                    total += labels.size(0)
                    correct += predicted.eq(labels).sum().item()
            
            val_acc = 100.*correct/total
            val_loss = val_loss / len(val_loader)
            
            print(f'Epoch [{epoch+1}/{num_epochs}], '
                  f'Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                  f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
                  f'Time: {epoch_time:.2f}s')
            
            # Save best model
            if val_acc > best_val_acc:
                best_val_acc = val_acc
                torch.save({
                    'epoch': epoch,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'val_acc': val_acc,
                }, save_dir / 'best_model.pth')
            
            scheduler.step()
            
    except Exception as e:
        print(f"Training error: {str(e)}")
        raise e

def evaluate_model(model, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for spectrograms, labels in test_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            outputs = model(spectrograms)
            _, predicted = outputs.max(1)
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    
    # Print per-class accuracy
    class_names = [
        'Ambulance Left', 'Ambulance Middle', 'Ambulance Right',
        'Car Horn Left', 'Car Horn Middle', 'Car Horn Right',
        'Fire Truck Left', 'Fire Truck Middle', 'Fire Truck Right',
        'Police Car Left', 'Police Car Middle', 'Police Car Right'
    ]
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    for i, class_name in enumerate(class_names):
        class_mask = (all_labels == i)
        if np.sum(class_mask) > 0:
            class_acc = 100 * np.sum((all_predictions == i) & class_mask) / np.sum(class_mask)
            print(f'{class_name} Accuracy: {class_acc:.2f}%')
    
    return accuracy, all_predictions, all_labels

def save_model(model, optimizer, epoch, val_acc, filename):
    """Save model checkpoint with all necessary state information"""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
    }, filename)
    print(f"Model saved to {filename}")

def load_best_model(model, filepath):
    """Load the best model weights"""
    if not Path(filepath).exists():
        raise FileNotFoundError(f"No model checkpoint found at {filepath}")
    
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']} with validation accuracy {checkpoint['val_acc']:.2f}%")
    return model, checkpoint['epoch'], checkpoint['val_acc']

def inference(model, image_path, device=None):
    """Run inference on a single image"""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    # Define the same transforms used during training
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    
    # Load and preprocess the image
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    # Set model to evaluation mode
    model.eval()
    model = model.to(device)
    
    # Define class names
    class_names = [
        'Ambulance Left', 'Ambulance Middle', 'Ambulance Right',
        'Car Horn Left', 'Car Horn Middle', 'Car Horn Right',
        'Fire Truck Left', 'Fire Truck Middle', 'Fire Truck Right',
        'Police Car Left', 'Police Car Middle', 'Police Car Right'
    ]
    
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class = torch.argmax(outputs, dim=1).item()
        confidence = probabilities[0][predicted_class].item()
        
    return {
        'predicted_class': class_names[predicted_class],
        'confidence': confidence * 100,
        'all_probabilities': {
            class_name: prob.item() * 100 
            for class_name, prob in zip(class_names, probabilities[0])
        }
    }

if __name__ == "__main__":
    # Set the path to your dataset directory
    base_dir = "Dataset of warning sound types and source directions"
    checkpoint_dir = Path('model_checkpoints')
    best_model_path = checkpoint_dir / 'best_model.pth'
    
    # Create data loaders
    train_loader, val_loader, test_loader = create_data_loaders(base_dir, batch_size=32)
    
    # Create model with 12 classes
    model = DirectionalSoundViT(num_classes=12)
    
    # Training phase
    print("Starting training...")
    train_model(model, train_loader, val_loader, num_epochs=30)
    
    # Load the best model for evaluation
    print("\nLoading best model for evaluation...")
    model, best_epoch, best_val_acc = load_best_model(model, best_model_path)
    
    # Evaluate the best model
    print("\nEvaluating best model...")
    accuracy, predictions, labels = evaluate_model(model, test_loader)
    
   

Starting training...
Using device: cuda
Epoch [1/30], Batch [10/116], Loss: 2.0621, Acc: 14.69%
Epoch [1/30], Batch [20/116], Loss: 1.7429, Acc: 22.19%
Epoch [1/30], Batch [30/116], Loss: 1.3671, Acc: 25.62%
Epoch [1/30], Batch [40/116], Loss: 1.1493, Acc: 28.12%
Epoch [1/30], Batch [50/116], Loss: 1.1407, Acc: 30.06%
Epoch [1/30], Batch [60/116], Loss: 1.1808, Acc: 31.09%
Epoch [1/30], Batch [70/116], Loss: 1.1746, Acc: 32.14%
Epoch [1/30], Batch [80/116], Loss: 1.2092, Acc: 32.77%
Epoch [1/30], Batch [90/116], Loss: 1.0133, Acc: 34.10%
Epoch [1/30], Batch [100/116], Loss: 1.2345, Acc: 34.50%
Epoch [1/30], Batch [110/116], Loss: 1.0898, Acc: 34.86%
Epoch [1/30], Train Loss: 1.3790, Train Acc: 35.28%, Val Loss: 1.1333, Val Acc: 41.67%, Time: 71.76s
Epoch [2/30], Batch [10/116], Loss: 1.0982, Acc: 43.12%
Epoch [2/30], Batch [20/116], Loss: 0.9869, Acc: 45.94%
Epoch [2/30], Batch [30/116], Loss: 1.3485, Acc: 46.56%
Epoch [2/30], Batch [40/116], Loss: 0.9393, Acc: 47.50%
Epoch [2/30], Bat

  checkpoint = torch.load(filepath)


Loaded model from epoch 14 with validation accuracy 97.16%

Evaluating best model...
Test Accuracy: 97.63%
Ambulance Left Accuracy: 99.06%
Ambulance Middle Accuracy: 100.00%
Ambulance Right Accuracy: 94.85%
Car Horn Left Accuracy: 100.00%
Car Horn Middle Accuracy: 97.75%
Car Horn Right Accuracy: 100.00%
Fire Truck Left Accuracy: 98.15%
Fire Truck Middle Accuracy: 95.08%
Fire Truck Right Accuracy: 96.55%
Police Car Left Accuracy: 100.00%
Police Car Middle Accuracy: 96.77%
Police Car Right Accuracy: 92.08%


In [14]:
def test_model_inference():
    # Load the model
    model = DirectionalSoundViT(num_classes=12)
    print("\nLoading best model for evaluation...")
    model, best_epoch, best_val_acc = load_best_model(model, 'model_checkpoints/best_model.pth')
    print(f"Loaded model from epoch {best_epoch} with validation accuracy {best_val_acc:.2f}%")

    # Run example inference
    print("\nRunning example inference...")
    # test_image_path = "./Dataset of warning sound types and source directions/noise/noise_22.png"
    test_image_path = "./test/test_output/final_stitched.png"

    
    if Path(test_image_path).exists():
        result = inference(model, test_image_path)
        
        print(f"\nInference results:")
        print(f"Predicted class: {result['predicted_class']}")
        print(f"Confidence: {result['confidence']:.2f}%")
        print("\nAll class probabilities:")
        for class_name, prob in result['all_probabilities'].items():
            print(f"{class_name}: {prob:.2f}%")
    else:
        print(f"Error: Test image not found at {test_image_path}")

if __name__ == "__main__":
    test_model_inference()


Loading best model for evaluation...


  checkpoint = torch.load(filepath)


Loaded model from epoch 14 with validation accuracy 97.16%
Loaded model from epoch 14 with validation accuracy 97.16%

Running example inference...

Inference results:
Predicted class: Car Horn Left
Confidence: 39.18%

All class probabilities:
Ambulance Left: 1.63%
Ambulance Middle: 5.02%
Ambulance Right: 3.26%
Car Horn Left: 39.18%
Car Horn Middle: 5.32%
Car Horn Right: 4.40%
Fire Truck Left: 0.33%
Fire Truck Middle: 38.82%
Fire Truck Right: 0.84%
Police Car Left: 0.05%
Police Car Middle: 0.92%
Police Car Right: 0.23%
