# **SNN Sound Localization**



---



In [1]:
# Mount Google Drive (Colab only); the data paths used by this notebook
# live under /content/drive.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# pip Installs

In [2]:
! pip install icecream optuna snntorch --quiet

# Imports

In [3]:
import librosa
import numpy as np

import matplotlib.pyplot as plt

import pandas as pd

import os

import torch
import torchaudio
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import snntorch as snn
from snntorch import spikegen

from icecream import ic

import optuna

import time

# Config

In [4]:
# Central configuration shared by the preprocessing, model and training cells.
config = {
    # Paths
    "metadata_path": r"/content/drive/My Drive/Colab Notebooks/sound_localization/data/metadata.parquet",
    "ambisonic_path": r"/content/drive/My Drive/Colab Notebooks/sound_localization/data/ambisonics_lite",
    "noise_path": r"/content/drive/My Drive/Colab Notebooks/sound_localization/data/noise_ambisonics_lite",

    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Audio Info
    "sr": 16_000,
    "n_fft": 512,
    "hop_length": 512, # 256

    # SNN (firing thresholds of the three LIF layers and their shared decay rate)
    "num_steps": 12,
    "thresh1": 17.175347400305153,
    "thresh2": 8.12828030444371,
    "thresh3": 15.5855582787769,
    "beta": 0.6210988248097677,

    # DataLoader
    "batch_size": 32,

    "seed": 42,

    # Training
    "epochs": 50, # Apple -> 20
    "lr": 0.000358787335187998,
    # Dropout probability from the Optuna trial logged at the end of this
    # notebook (the same trial that produced thresh*/beta/lr above). It was
    # previously a copy of the learning-rate value.
    "dropout": 0.7449582513653029,
}

# Preprocess

## Load Metadata

In [5]:
def filter_data(
    metadata_path=config["metadata_path"],
    ambisonic_path=config["ambisonic_path"],
    noise_path=config["noise_path"],
    lite_version=True
):
    """Load the metadata table and build index-aligned audio/noise file lists.

    Args:
        metadata_path: Path to the metadata parquet file.
        ambisonic_path: Directory holding the ambisonic ``.flac`` files.
        noise_path: Directory holding the noise ``.flac`` files.
        lite_version: When True, keep only rows whose ``lite_version``
            column is True.

    Returns:
        Tuple ``(metadata, ambi_files, noise_files)``. Both lists are sorted
        by sample id, so ``ambi_files[i]`` and ``noise_files[i]`` always
        refer to the same sample.

    Raises:
        ValueError: If any expected file is missing from either directory.
    """
    metadata = pd.read_parquet(metadata_path)

    if lite_version:
        metadata = metadata[metadata["lite_version"] == True]

    # Sorting the ids gives a deterministic order and guarantees that the two
    # path lists are paired by index (iterating two sets does not promise that).
    sample_ids = sorted(set(metadata["sample_id"]))

    # Expected file names: zero-padded 6-digit sample id + ".flac"
    expected_names = [f"{sample_id:06}.flac" for sample_id in sample_ids]
    expected_name_set = set(expected_names)

    # Actual file names on disk
    actual_ambi_files = set(os.listdir(ambisonic_path))
    actual_noise_files = set(os.listdir(noise_path))

    # Check for missing files
    missing_ambi_files = expected_name_set - actual_ambi_files
    missing_noise_files = expected_name_set - actual_noise_files

    if missing_ambi_files or missing_noise_files:
        raise ValueError(f"Missing Files: {missing_ambi_files} {missing_noise_files}")

    # Full paths, in the same (sorted) order for both lists
    ambi_files = [os.path.join(ambisonic_path, name) for name in expected_names]
    noise_files = [os.path.join(noise_path, name) for name in expected_names]

    return metadata, ambi_files, noise_files

## Load and Pad

In [6]:
def load_and_pad(ambisonic_path, noise_path=None, sr=config["sr"], max_duration=None):
    """Load a multichannel clip, pad/trim it to a fixed length and optionally add noise.

    Args:
        ambisonic_path: Path of the audio file to load.
        noise_path: Optional path of a noise file that is padded/trimmed the
            same way and summed with the audio.
        sr: Sample rate the files are resampled to on load.
        max_duration: Target clip length in seconds (required).

    Returns:
        Array of length ``max_duration * sr`` along the last axis, as produced
        by ``librosa.load(..., mono=False)`` followed by ``fix_length``.

    Raises:
        ValueError: If ``max_duration`` is not given.
    """
    if max_duration is None:
        raise ValueError("Enter a value for max_duration")

    # fix_length needs an integer sample count; max_duration may be fractional.
    max_samples = int(round(max_duration * sr))

    ambi_audio, _ = librosa.load(ambisonic_path, sr=sr, mono=False)
    ambi_audio = librosa.util.fix_length(ambi_audio, size=max_samples)

    if noise_path:
        noise_audio, _ = librosa.load(noise_path, sr=sr, mono=False)
        noise_audio = librosa.util.fix_length(noise_audio, size=max_samples)
        # Both arrays share the same length after fix_length, so they can be summed.
        audio = ambi_audio + noise_audio

    else:
        audio = ambi_audio

    return audio

## Feature Extraction (STFT)

In [7]:
def stft(audio, n_fft=config["n_fft"], hop_length=config["hop_length"]):
    """Compute the magnitude STFT of every channel.

    Args:
        audio: Array indexed as ``audio[channel, :]``.
        n_fft: FFT window size passed to ``librosa.stft``.
        hop_length: Hop length passed to ``librosa.stft``.

    Returns:
        ``np.ndarray`` stacking one magnitude spectrogram per channel along
        the first axis.
    """
    n_channels = audio.shape[0]

    # One |STFT| per channel, kept in channel order.
    magnitudes = [
        np.abs(librosa.stft(audio[ch, :], n_fft=n_fft, hop_length=hop_length))
        for ch in range(n_channels)
    ]

    return np.array(magnitudes)


## Normalize

In [8]:
def normalize(features):
    """Standardize ``features`` along axis 0 and return a float32 tensor.

    Each element has the axis-0 mean subtracted and is divided by the axis-0
    standard deviation; a small epsilon in the denominator avoids division
    by zero.
    """
    eps = 1e-10

    centered = features - np.mean(features, axis=0)
    scale = np.std(features, axis=0) + eps

    return torch.tensor(centered / scale).float()


## Split Metadata
### Train, Val, Test Split

In [9]:
def split_data(metadata, validation_size=0.2, random_state=42):
    """Add a ``set`` column marking each row as train, validation or its original split.

    ``set`` starts as a copy of ``split``; a random ``validation_size``
    fraction of the rows whose ``split`` is ``'train'`` is then relabelled
    ``'validation'``.

    Args:
        metadata: DataFrame with a ``split`` column.
        validation_size: Fraction of the training rows moved to validation.
        random_state: Seed forwarded to ``train_test_split``.

    Returns:
        A new DataFrame with the extra ``set`` column; the input frame is
        left untouched.
    """
    # Work on a copy: the input may be a filtered slice of another frame, and
    # assigning into a slice triggers SettingWithCopyWarning / hidden mutation.
    metadata = metadata.copy()
    metadata['set'] = metadata['split']

    # Identify indices of rows marked for training
    train_indices = metadata[metadata['split'] == 'train'].index

    # Split the train_indices into training and validation sets
    train_idx, valid_idx = train_test_split(train_indices, test_size=validation_size, random_state=random_state)

    # Update the 'set' column to mark validation
    metadata.loc[valid_idx, 'set'] = 'validation'

    return metadata

## Calculate Median Absolute Angle Error

In [10]:
def calculate_3d_angle_error(true_azimuth, true_elevation, pred_azimuth, pred_elevation):
    """Median great-circle angle (in degrees) between true and predicted directions.

    The angle between two directions on the unit sphere is computed with the
    spherical law of cosines, where elevation plays the role of latitude and
    azimuth the role of longitude:

        cos(d) = sin(el_t) * sin(el_p) + cos(el_t) * cos(el_p) * cos(az_t - az_p)

    Args:
        true_azimuth, true_elevation: Ground-truth angles as tensors.
        pred_azimuth, pred_elevation: Predicted angles as tensors of the same shape.
            All four are fed directly to ``torch.sin`` / ``torch.cos`` and are
            therefore interpreted as radians.
            NOTE(review): confirm the metadata labels are stored in radians.

    Returns:
        Scalar tensor: the median absolute angular error in degrees.
    """
    cos_angle = (
        torch.sin(true_elevation) * torch.sin(pred_elevation) +
        torch.cos(true_elevation) * torch.cos(pred_elevation) * torch.cos(true_azimuth - pred_azimuth)
    )

    # Rounding can push the value slightly outside [-1, 1], where acos is NaN.
    cos_angle = torch.clamp(cos_angle, -1.0, 1.0)
    angle_error_rad = torch.acos(cos_angle)

    # Convert the angle error to degrees for interpretability
    angle_error_deg = torch.rad2deg(angle_error_rad)

    # Calculate the median absolute error of the angle in degrees
    median_abs_error_deg = torch.median(torch.abs(angle_error_deg))

    return median_abs_error_deg


# Custom Dataset

In [11]:
class AudioDataset(Dataset):
    """Dataset yielding ``(spike_train, labels)`` pairs for one data split.

    Each item is loaded from disk, padded/trimmed to a fixed duration, turned
    into per-channel magnitude spectrograms, normalized and rate-encoded into
    ``num_steps`` spike frames. Labels are ``[azimuth, elevation]``.

    Args:
        metadata: DataFrame with ``sample_id``, ``audio_info/duration``,
            ``speech/azimuth`` and ``speech/elevation`` columns. When it also
            has a ``set`` column, only files whose sample id belongs to
            ``dataset_type`` are kept.
        dataset_type: Value of the ``set`` column to select (e.g. ``'train'``).
            ``None`` keeps every file.
        audio_files: Paths of the audio files; the file stem is the sample id.
        noise_files: Optional noise paths, index-aligned with ``audio_files``.
        transform: Stored but not applied by this class.
        sr: Sample rate used when loading audio.

    Raises:
        ValueError: If ``noise_files`` and ``audio_files`` differ in length, or
            if no file belongs to the requested ``dataset_type``.
    """

    def __init__(self, metadata, dataset_type, audio_files, noise_files=None, transform=None, sr=config["sr"]):
        self.metadata = metadata
        self.dataset_type = dataset_type
        self.transform = transform
        self.sr = sr
        self.num_steps = config["num_steps"]

        # The padding length comes from the FULL metadata so that every split
        # produces inputs of identical shape.
        durations = metadata["audio_info/duration"]
        self.max_duration = int(np.ceil(durations.mean() + durations.std()))

        audio_files = list(audio_files)
        noise_files = list(noise_files) if noise_files else None

        if noise_files is not None and len(noise_files) != len(audio_files):
            raise ValueError(
                f"audio_files ({len(audio_files)}) and noise_files ({len(noise_files)}) must have the same length"
            )

        # Keep only the files of the requested split; without this filter the
        # train/validation/test datasets would all contain every sample.
        if dataset_type is not None and "set" in metadata.columns:
            split_ids = {
                int(sample_id)
                for sample_id in metadata.loc[metadata["set"] == dataset_type, "sample_id"]
            }
            keep = [i for i, path in enumerate(audio_files) if self._sample_id(path) in split_ids]

            if not keep:
                raise ValueError(f"No audio files found for dataset_type '{dataset_type}'")

            audio_files = [audio_files[i] for i in keep]
            if noise_files is not None:
                noise_files = [noise_files[i] for i in keep]

        self.audio_files = audio_files
        self.noise_files = noise_files

    @staticmethod
    def _sample_id(path):
        """Parse the integer sample id from a file name such as ``000123.flac``."""
        return int(os.path.basename(path).split(".")[0].lstrip("0") or 0)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        noise_path = self.noise_files[idx] if self.noise_files else None

        # Extract the sample id from the audio file name
        sample_id = self._sample_id(audio_path)

        # Attempt to get the corresponding labels from the metadata
        label_rows = self.metadata[self.metadata['sample_id'] == sample_id]

        if label_rows.empty:
            raise ValueError(f"Sample ID {sample_id} not found in metadata for file {audio_path}")

        labels = label_rows[['speech/azimuth', 'speech/elevation']].values[0]

        # labels to a float32 tensor
        labels = labels.astype("float32")
        labels_tensor = torch.from_numpy(labels)

        # audio loading and preprocessing functions
        audio = load_and_pad(audio_path, noise_path, sr=self.sr, max_duration=self.max_duration)
        audio = stft(audio)
        audio = normalize(audio)

        # Rate coding adds a leading time axis of length num_steps.
        # NOTE(review): spikegen.rate treats inputs as firing probabilities;
        # confirm how it handles the negative / >1 values normalize() produces.
        spike_train = spikegen.rate(audio, num_steps=self.num_steps)

        return spike_train, labels_tensor

# Model

## SNN

In [12]:
class SNN(nn.Module):
    """Convolutional spiking network that regresses two values per sample.

    Three ``Conv2d -> BatchNorm2d -> MaxPool2d -> Leaky`` stages are followed
    by ``Flatten -> Linear -> Dropout -> Linear``. The network is evaluated
    once per time step and the per-step outputs are averaged.

    Args:
        config: Mapping providing ``thresh1``, ``thresh2``, ``thresh3``
            (LIF thresholds), ``beta`` (LIF decay), ``num_steps`` and
            ``dropout`` (dropout probability).
    """

    def __init__(self, config):
        super(SNN, self).__init__()
        self.thresh1 = config["thresh1"]
        self.thresh2 = config["thresh2"]
        self.thresh3 = config["thresh3"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]
        self.dropout_prob = config["dropout"]

        # Stage 1: 4 input channels -> 32 feature maps
        self.conv1 = nn.Conv2d(4, 32, kernel_size=(3, 3))
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.max_pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.lif1 = snn.Leaky(beta=self.beta, threshold=self.thresh1)

        # Stage 2: 32 -> 64 feature maps
        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.max_pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.lif2 = snn.Leaky(beta=self.beta, threshold=self.thresh2)

        # Stage 3: 64 -> 128 feature maps
        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batch_norm3 = nn.BatchNorm2d(128)
        self.max_pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.lif3 = snn.Leaky(beta=self.beta, threshold=self.thresh3)

        # Regression head. 215040 is the flattened size of the stage-3 output
        # and is therefore tied to the spectrogram shape fed to the network.
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(215040, 256)
        self.dropout_layer = nn.Dropout(self.dropout_prob)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, inpt):
        """Run the network over all time steps and average the outputs.

        Args:
            inpt: 5-D tensor whose dim 1 is the time-step axis; after the
                permute each step has shape ``[batch, 4, H, W]``.

        Returns:
            Tensor of shape ``[batch, 2]``: the mean of the per-step outputs.
        """
        # Fresh membrane potentials for every forward pass
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        step_outputs = []

        # Move the time-step axis to the front: [num_steps, batch, channels, H, W]
        inpt = inpt.permute(1, 0, 2, 3, 4)

        for step in range(inpt.size(0)):
            current1 = self.conv1(inpt[step])
            current1 = self.batch_norm1(current1)
            current1 = self.max_pool1(current1)
            spike1, mem1 = self.lif1(current1, mem1)

            current2 = self.conv2(spike1)
            current2 = self.batch_norm2(current2)
            current2 = self.max_pool2(current2)
            spike2, mem2 = self.lif2(current2, mem2)

            current3 = self.conv3(spike2)
            current3 = self.batch_norm3(current3)
            current3 = self.max_pool3(current3)
            spike3, mem3 = self.lif3(current3, mem3)

            flatten = self.flatten(spike3)

            fc = self.fc1(flatten)
            fc = self.dropout_layer(fc)

            step_outputs.append(self.fc2(fc))

        # Average the outputs over the steps
        out_avg = torch.mean(torch.stack(step_outputs), dim=0)

        return out_avg


## Train Function

In [13]:
def train(model, train_loader, criterion, optimizer, device):
    """Train ``model`` for one epoch.

    Args:
        model: Network mapping a batch of inputs to ``[batch, 2]`` outputs.
        train_loader: DataLoader yielding ``(data, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over the model parameters.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(train_loss, true_azimuths, true_elevations, pred_azimuths,
        pred_elevations)``: the sample-weighted mean loss and the concatenated
        column 0 / column 1 of labels and outputs.
    """
    train_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    model.train()

    num_batches = len(train_loader)

    for batch_idx, (data, labels) in enumerate(train_loader):
        # Report progress against the loader actually being iterated.
        print(f"Batch {batch_idx+1}/{num_batches}")

        data, labels = data.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(data)

        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Weight by batch size so a smaller last batch does not skew the mean.
        train_loss += loss.item() * data.size(0)

        true_azimuths.append(labels[:, 0].detach())
        true_elevations.append(labels[:, 1].detach())
        pred_azimuths.append(outputs[:, 0].detach())
        pred_elevations.append(outputs[:, 1].detach())

    train_loss = train_loss / len(train_loader.dataset)

    return train_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


## Validation Function

In [14]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on the validation loader without updating weights.

    Args:
        model: Network mapping a batch of inputs to ``[batch, 2]`` outputs.
        val_loader: DataLoader yielding ``(data, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        Tuple ``(valid_loss, true_azimuths, true_elevations, pred_azimuths,
        pred_elevations)``: the sample-weighted mean loss and the concatenated
        column 0 / column 1 of labels and outputs.
    """
    valid_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    model.eval()

    num_batches = len(val_loader)

    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(val_loader):
            # Report progress against the loader actually being iterated.
            print(f"Batch {batch_idx+1}/{num_batches}")
            data, labels = data.to(device), labels.to(device)

            outputs = model(data)

            loss = criterion(outputs, labels)

            # Weight by batch size so a smaller last batch does not skew the mean.
            valid_loss += loss.item() * data.size(0)

            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    valid_loss = valid_loss / len(val_loader.dataset)

    return valid_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


## Test Function

In [15]:
def test(model, test_loader, criterion, device):
    #print("Testing")

    test_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    model.eval()

    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(test_loader):
            data, labels = data.to(device), labels.to(device)

            outputs = model(data)

            # ic(outputs)
            # ic(labels)

            loss = criterion(outputs, labels)
            test_loss += loss.item() * data.size(0)

            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    test_loss = test_loss / len(test_loader.dataset)

    return test_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)

# Run Code

In [16]:
# from concurrent.futures import ProcessPoolExecutor, as_completed
# import functools

In [17]:
# def filter_and_split_data():
#     # Assuming filter_data() can be run independently and its results
#     # are used by split_data().
#     metadata, ambisonic_files, noise_files = filter_data()
#     metadata = split_data(metadata)
#     return metadata, ambisonic_files, noise_files

# num_workers = 10

# with ProcessPoolExecutor(max_workers=num_workers) as executor:
#     future_to_task = {executor.submit(filter_and_split_data): 'Task 1'}
#     for future in as_completed(future_to_task):
#         try:
#             result = future.result()
#             # Process your result here
#             print(f"Result: {result}")
#         except Exception as exc:
#             print(f"Generated an exception: {exc}")

In [18]:
# Load the metadata table and the ambisonic/noise file lists (raises if files
# are missing), then add the 'set' column that carves a validation subset out
# of the training rows.
metadata, ambisonic_files, noise_files = filter_data()
metadata = split_data(metadata)


In [19]:
# Seed the RNGs before the model is built so weight initialisation (and the
# later torch-based sampling) is reproducible; config["seed"] was otherwise unused.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

device = config["device"]
model = SNN(config).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

In [20]:
# One dataset object per split name.
train_set = AudioDataset(metadata, 'train', ambisonic_files, noise_files)
val_set = AudioDataset(metadata, "validation", ambisonic_files, noise_files)
test_set = AudioDataset(metadata, "test", ambisonic_files, noise_files)

# Only the training loader shuffles; evaluation loaders keep a fixed order so
# runs are directly comparable.
train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_set, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(test_set, batch_size=config["batch_size"], shuffle=False)

In [None]:
# Per-epoch loss history
train_loss = []
valid_loss = []

for epoch in range(config["epochs"]):
    print(f"Epoch {epoch+1}/{config['epochs']}")

    train_epoch_loss, train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation = train(model, train_loader, criterion, optimizer, device)
    valid_epoch_loss, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    train_angle_error = calculate_3d_angle_error(train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation)
    valid_angle_error = calculate_3d_angle_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

    print(f"Train Loss: {train_epoch_loss:.4f} | Validation Loss: {valid_epoch_loss:.4f}")
    print(f"Train Angle Error: {train_angle_error:.4f}° | Validation Angle Error: {valid_angle_error:.4f}°\n")

# Final evaluation on the test loader. The angle error must use the TEST
# predictions for both azimuth and elevation.
test_loss, test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation = test(model, test_loader, criterion, device)
test_angle_error = calculate_3d_angle_error(test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Angle Error: {test_angle_error:.4f}\n")

# Hyper Param Sweep (Optuna)

In [None]:
# def objective(trial):
#     # Hyperparameters to optimize
#     config["thresh1"] = trial.suggest_float("thresh1", 1, 30)
#     config["thresh2"] = trial.suggest_float("thresh2", 1, 30)
#     config["thresh3"] = trial.suggest_float("thresh3", 1, 30)
#     config["beta"] = trial.suggest_float("beta", 0.1, 0.9)
#     config["lr"] = trial.suggest_float("lr", 1e-10, 1e-3)
#     config["dropout"] = trial.suggest_float("droput", 0, 0.8)
#     config["num_steps"] = trial.suggest_int("num_steps", 5, 30)

#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = SNN(config).to(device)
#     criterion = nn.MSELoss()
#     optimizer = torch.optim.Adam(model.parameters(), lr=config["lr"])

#     # Data loading code here
#     # Assuming train_loader and val_loader are defined

#     for epoch in range(config["epochs"]):
#         train(model, train_loader, criterion, optimizer, device)
#         valid_epoch_loss, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)
#         valid_angle_error = calculate_3d_angle_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

#     return valid_angle_error.item()  # Optuna minimizes this value


In [None]:
# def run_study():
#     study_name = "snn_study"
#     storage_name = "sqlite:///{}.db".format(study_name)

#     study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True, direction="minimize")
#     study.optimize(objective, n_trials=100)

#     # Save the study to disk
#     study.trials_dataframe().to_csv("study_results.csv")

#     print("Study statistics: ")
#     print("  Number of finished trials: ", len(study.trials))
#     print("  Best trial:")
#     trial = study.best_trial
#     print("    Value: ", trial.value)
#     print("    Params: ")
#     for key, value in trial.params.items():
#         print(f"      {key}: {value}")



In [None]:
# run_study()

[I 2024-02-12 23:15:10,615] Trial 1 finished with value: 33.963905334472656 and parameters: {'thresh1': 17.175347400305153, 'thresh2': 8.12828030444371, 'thresh3': 15.5855582787769, 'beta': 0.6210988248097677, 'lr': 0.000358787335187998, 'droput': 0.7449582513653029}. Best is trial 1 with value: 33.963905334472656.