In [6]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights
from PIL import Image
from pathlib import Path
import re
from sklearn.model_selection import train_test_split
import numpy as np
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
import time

# Early stopping class remains unchanged
class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""
    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt'):
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        self.val_loss_min = np.Inf
        self.delta = delta
        self.path = path

    def __call__(self, val_loss, model):
        score = -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

class DirectionalSoundDataset(Dataset):
    def __init__(self, base_dir, transform=None, target_size=(224, 224)):
        self.base_dir = Path(base_dir)
        self.transform = transform
        self.target_size = target_size
        
        # Updated class mapping to include bike categories
        self.class_to_idx = {
            'ambulance_L': 0,
            'ambulance_M': 1,
            'ambulance_R': 2,
            'carhorns_L': 3,
            'carhorns_M': 4,
            'carhorns_R': 5,
            'FireTruck_L': 6,
            'FireTruck_M': 7,
            'FireTruck_R': 8,
            'policecar_L': 9,
            'policecar_M': 10,
            'policecar_R': 11,
            'Bike_L': 12,
            'Bike_R': 13,
            'Bike_B': 14
        }
        
        # Rest of the DirectionalSoundDataset implementation remains the same
        self.files = []
        self.labels = []
        
        for class_name in self.class_to_idx.keys():
            class_dir = self.base_dir / class_name
            if class_dir.exists():
                class_files = list(class_dir.glob(f"{class_name}_*.png"))
                self.files.extend(class_files)
                self.labels.extend([self.class_to_idx[class_name]] * len(class_files))
        
        if len(self.files) == 0:
            raise RuntimeError(f"No spectrogram files found in {base_dir}")
            
    def __len__(self):
        return len(self.files)
    
    def __getitem__(self, idx):
        img_path = self.files[idx]
        label = self.labels[idx]
        
        spectrogram = Image.open(img_path).convert('RGB')
        if spectrogram.size != self.target_size:
            spectrogram = spectrogram.resize(self.target_size)
        
        if self.transform:
            spectrogram = self.transform(spectrogram)
        
        return spectrogram, label

def create_data_loaders(base_dir, batch_size=16):
    """
    Create data loaders with a 70-10-20 split for training, validation, and test sets.
    
    Args:
        base_dir (str): Path to the dataset directory
        batch_size (int): Batch size for the data loaders
    
    Returns:
        tuple: (train_loader, val_loader, test_loader)
    """
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    
    # Create full dataset
    full_dataset = DirectionalSoundDataset(
        base_dir=base_dir,
        transform=transform,
        target_size=(224, 224)
    )
    
    # Calculate sizes
    total_size = len(full_dataset)
    test_size = 0.2  # 20% for test set
    val_size = 0.1   # 10% for validation set
    train_size = 0.7 # 70% for training set
    
    # Create indices for the splits
    indices = list(range(total_size))
    
    # First split: separate out test set (20%)
    train_val_idx, test_idx = train_test_split(
        indices,
        test_size=test_size,
        random_state=42,
        shuffle=True
    )
    
    # Second split: separate training (70%) and validation (10%) from the remaining 80%
    # To get 10% validation from the remaining 80%, we need val_size/(train_size + val_size) = 0.125
    train_idx, val_idx = train_test_split(
        train_val_idx,
        test_size=val_size/(train_size + val_size),
        random_state=42,
        shuffle=True
    )
    
    # Verify split sizes
    print(f"Total dataset size: {total_size}")
    print(f"Training set size: {len(train_idx)} ({len(train_idx)/total_size*100:.1f}%)")
    print(f"Validation set size: {len(val_idx)} ({len(val_idx)/total_size*100:.1f}%)")
    print(f"Test set size: {len(test_idx)} ({len(test_idx)/total_size*100:.1f}%)")
    
    # Create samplers
    train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
    val_sampler = torch.utils.data.SubsetRandomSampler(val_idx)
    test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)
    
    # Create data loaders
    train_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=2,
        pin_memory=True
    )
    
    val_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        sampler=val_sampler,
        num_workers=2,
        pin_memory=True
    )
    
    test_loader = DataLoader(
        full_dataset,
        batch_size=batch_size,
        sampler=test_sampler,
        num_workers=2,
        pin_memory=True
    )
    
    return train_loader, val_loader, test_loader

# Optional: Function to verify the split distribution across classes
def verify_split_distribution(train_loader, val_loader, test_loader, num_classes=15):
    """
    Verify the distribution of classes across the different splits.
    """
    def count_samples_per_class(loader):
        class_counts = torch.zeros(num_classes)
        for _, labels in loader:
            for label in labels:
                class_counts[label] += 1
        return class_counts
    
    train_dist = count_samples_per_class(train_loader)
    val_dist = count_samples_per_class(val_loader)
    test_dist = count_samples_per_class(test_loader)
    
    print("\nClass distribution in splits:")
    print(f"{'Class':<15} {'Train':>10} {'Val':>10} {'Test':>10}")
    print("-" * 45)
    
    for i in range(num_classes):
        total = train_dist[i] + val_dist[i] + test_dist[i]
        if total > 0:  # Only print if class has samples
            train_pct = train_dist[i] / total * 100
            val_pct = val_dist[i] / total * 100
            test_pct = test_dist[i] / total * 100
            print(f"{i:<15} {train_pct:>9.1f}% {val_pct:>9.1f}% {test_pct:>9.1f}%")

class DirectionalSoundViT(nn.Module):
    def __init__(self, num_classes=15):  # Updated to 15 classes to include bike categories
        super().__init__()
        self.vit = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_V1)
        num_features = self.vit.heads.head.in_features
        self.vit.heads.head = nn.Linear(num_features, num_classes)
    def forward(self, x):
        return self.vit(x)

def train_model(model, train_loader, val_loader, num_epochs=30):
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {device}")
        
        model = model.to(device)
        
        # Initialize optimizer and loss
        optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=0.05)
        scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs)
        criterion = nn.CrossEntropyLoss()
        
        # Create directory for saving checkpoints
        save_dir = Path('model_checkpoints')
        save_dir.mkdir(exist_ok=True)
        
        best_val_acc = 0.0
        
        try:
            for epoch in range(num_epochs):
                # Training phase
                model.train()
                total_loss = 0
                correct = 0
                total = 0
                start_time = time.time()
                
                for batch_idx, (spectrograms, labels) in enumerate(train_loader):
                    try:
                        spectrograms, labels = spectrograms.to(device), labels.to(device)
                        
                        outputs = model(spectrograms)
                        loss = criterion(outputs, labels)
                        
                        optimizer.zero_grad()
                        loss.backward()
                        optimizer.step()
                        
                        total_loss += loss.item()
                        
                        _, predicted = outputs.max(1)
                        total += labels.size(0)
                        correct += predicted.eq(labels).sum().item()
                        
                        if (batch_idx + 1) % 10 == 0:
                            print(f'Epoch [{epoch+1}/{num_epochs}], '
                                f'Batch [{batch_idx+1}/{len(train_loader)}], '
                                f'Loss: {loss.item():.4f}, '
                                f'Acc: {100.*correct/total:.2f}%')
                    except Exception as e:
                        print(f"Error in batch {batch_idx}: {str(e)}")
                        continue
                
                train_acc = 100.*correct/total
                avg_loss = total_loss / len(train_loader)
                epoch_time = time.time() - start_time
                
                # Validation phase
                model.eval()
                val_loss = 0
                correct = 0
                total = 0
                
                with torch.no_grad():
                    for spectrograms, labels in val_loader:
                        spectrograms, labels = spectrograms.to(device), labels.to(device)
                        outputs = model(spectrograms)
                        loss = criterion(outputs, labels)
                        
                        val_loss += loss.item()
                        _, predicted = outputs.max(1)
                        total += labels.size(0)
                        correct += predicted.eq(labels).sum().item()
                
                val_acc = 100.*correct/total
                val_loss = val_loss / len(val_loader)
                
                print(f'Epoch [{epoch+1}/{num_epochs}], '
                    f'Train Loss: {avg_loss:.4f}, Train Acc: {train_acc:.2f}%, '
                    f'Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, '
                    f'Time: {epoch_time:.2f}s')
                
                # Save best model
                if val_acc > best_val_acc:
                    best_val_acc = val_acc
                    torch.save({
                        'epoch': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'val_acc': val_acc,
                    }, save_dir / 'updated_best_model.pth')
                
                scheduler.step()
                
        except Exception as e:
            print(f"Training error: {str(e)}")
            raise e


# Training and evaluation functions remain largely the same, but with updated class names
def evaluate_model(model, test_loader):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    
    with torch.no_grad():
        for spectrograms, labels in test_loader:
            spectrograms, labels = spectrograms.to(device), labels.to(device)
            outputs = model(spectrograms)
            _, predicted = outputs.max(1)
            
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
            
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    
    # Updated class names to include bike categories
    class_names = [
        'Ambulance Left', 'Ambulance Middle', 'Ambulance Right',
        'Car Horn Left', 'Car Horn Middle', 'Car Horn Right',
        'Fire Truck Left', 'Fire Truck Middle', 'Fire Truck Right',
        'Police Car Left', 'Police Car Middle', 'Police Car Right',
        'Bike Left', 'Bike Right', 'Bike Middle'
    ]
    
    all_predictions = np.array(all_predictions)
    all_labels = np.array(all_labels)
    
    for i, class_name in enumerate(class_names):
        class_mask = (all_labels == i)
        if np.sum(class_mask) > 0:
            class_acc = 100 * np.sum((all_predictions == i) & class_mask) / np.sum(class_mask)
            print(f'{class_name} Accuracy: {class_acc:.2f}%')
    
    return accuracy, all_predictions, all_labels
def load_best_model(model, filepath):
    """Load the best model weights"""
    if not Path(filepath).exists():
        raise FileNotFoundError(f"No model checkpoint found at {filepath}")
    
    checkpoint = torch.load(filepath)
    model.load_state_dict(checkpoint['model_state_dict'])
    print(f"Loaded model from epoch {checkpoint['epoch']} with validation accuracy {checkpoint['val_acc']:.2f}%")
    return model, checkpoint['epoch'], checkpoint['val_acc']

def save_model(model, optimizer, epoch, val_acc, filename):
    """Save model checkpoint with all necessary state information"""
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'val_acc': val_acc,
    }, filename)
    print(f"Model saved to {filename}")

def inference(model, image_path, device=None):
    """Run inference on a single image"""
    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                           std=[0.229, 0.224, 0.225]),
    ])
    
    image = Image.open(image_path).convert('RGB')
    image = transform(image).unsqueeze(0).to(device)
    
    model.eval()
    model = model.to(device)
    
    # Updated class names to include bike categories
    class_names = [
        'Ambulance Left', 'Ambulance Middle', 'Ambulance Right',
        'Car Horn Left', 'Car Horn Middle', 'Car Horn Right',
        'Fire Truck Left', 'Fire Truck Middle', 'Fire Truck Right',
        'Police Car Left', 'Police Car Middle', 'Police Car Right',
        'Bike Left', 'Bike Right', 'Bike Middle'
    ]
    
    with torch.no_grad():
        outputs = model(image)
        probabilities = torch.nn.functional.softmax(outputs, dim=1)
        predicted_class = torch.argmax(outputs, dim=1).item()
        confidence = probabilities[0][predicted_class].item()
        
    return {
        'predicted_class': class_names[predicted_class],
        'confidence': confidence * 100,
        'all_probabilities': {
            class_name: prob.item() * 100 
            for class_name, prob in zip(class_names, probabilities[0])
        }
    }

if __name__ == "__main__":
    base_dir = "Dataset of warning sound types and source directions"
    checkpoint_dir = Path('model_checkpoints')
    best_model_path = checkpoint_dir / 'updated_best_model.pth'
    
    train_loader, val_loader, test_loader = create_data_loaders(base_dir, batch_size=16)
    
    # Verify the split distribution
    verify_split_distribution(train_loader, val_loader, test_loader)    
    # Create model with 15 classes (including bike categories)
    model = DirectionalSoundViT(num_classes=15)
    
    print("Starting training...")
    train_model(model, train_loader, val_loader, num_epochs=30)
    
    print("\nLoading best model for evaluation...")
    model, best_epoch, best_val_acc = load_best_model(model, best_model_path)
    
    print("\nEvaluating best model...")
    accuracy, predictions, labels = evaluate_model(model, test_loader)

Total dataset size: 5904
Training set size: 4132 (70.0%)
Validation set size: 591 (10.0%)
Test set size: 1181 (20.0%)

Class distribution in splits:
Class                Train        Val       Test
---------------------------------------------
0                    67.9%       8.1%      24.0%
1                    69.4%      11.2%      19.4%
2                    68.5%       9.6%      21.9%
3                    72.3%       8.5%      19.2%
4                    72.7%       8.5%      18.8%
5                    68.8%      11.5%      19.8%
6                    73.3%       9.8%      16.9%
7                    71.0%       9.8%      19.2%
8                    70.4%      11.0%      18.5%
9                    68.5%       9.2%      22.3%
10                   66.7%      10.6%      22.7%
11                   71.2%      11.0%      17.7%
12                   66.7%      14.6%      18.8%
13                   68.0%      10.0%      22.0%
14                   65.2%      17.4%      17.4%
Starting training...


  checkpoint = torch.load(filepath)


Loaded model from epoch 21 with validation accuracy 95.94%

Evaluating best model...
Test Accuracy: 94.83%
Ambulance Left Accuracy: 97.39%
Ambulance Middle Accuracy: 100.00%
Ambulance Right Accuracy: 96.19%
Car Horn Left Accuracy: 98.91%
Car Horn Middle Accuracy: 100.00%
Car Horn Right Accuracy: 98.95%
Fire Truck Left Accuracy: 93.83%
Fire Truck Middle Accuracy: 100.00%
Fire Truck Right Accuracy: 60.67%
Police Car Left Accuracy: 89.72%
Police Car Middle Accuracy: 99.08%
Police Car Right Accuracy: 100.00%
Bike Left Accuracy: 100.00%
Bike Right Accuracy: 100.00%
Bike Middle Accuracy: 100.00%
