<a href="https://colab.research.google.com/github/mercadoerik1031/snn-sound-localization/blob/new_approach/snn_sound_localization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SNN Sound Localization**



---



In [1]:
# Mount Google Drive (Colab only); the paths in the Config cell all live under /content/drive.
from google.colab import drive
drive.mount("/content/drive")

Mounted at /content/drive


# pip Installs

In [2]:
# Install icecream and snntorch, both imported in the Imports cell; --quiet trims pip's output.
! pip install icecream snntorch --quiet

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m109.0/109.0 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m76.2/76.2 kB[0m [31m10.6 MB/s[0m eta [36m0:00:00[0m
[?25h

# Imports

In [3]:
import librosa
import numpy as np

import matplotlib.pyplot as plt

import pandas as pd

import os

import torch
import torchaudio
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import snntorch as snn
from snntorch import spikegen

from icecream import ic

# Config

In [4]:
# Central settings read by the preprocessing functions, the dataset, the model and the run cells.
config = {
    # Paths
    "metadata_path": r"/content/drive/My Drive/Colab Notebooks/Masters Project/metadata.parquet",
    "ambisonic_path": r"/content/drive/My Drive/Colab Notebooks/Masters Project/spatial_librispeech_sample/ambisonics_sample",
    "noise_path": r"/content/drive/My Drive/Colab Notebooks/Masters Project/spatial_librispeech_sample/noise_ambisonics_sample",

    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Audio Info
    "sr": 16_000,  # sample rate handed to librosa.load in load_and_pad
    "n_fft": 256, # STFT n_fft used by stft(); 512 is noted as an alternative
    "hop_length": 256,  # STFT hop_length used by stft()

    # SNN
    "num_steps": 10,  # number of iterations of the loop in SNN.forward
    "thresh1": 6,  # firing threshold of the first snn.Leaky layer
    "thresh2": 3,  # firing threshold of the second snn.Leaky layer
    "thresh3": 5,  # firing threshold of the third snn.Leaky layer
    "beta": 0.2,  # beta passed to every snn.Leaky layer

    # DataLoader
    "batch_size": 16,
    "seed": 42,

    # Training
    "epochs": 5,
}

# Preprocess

## Load Metadata

In [5]:
def filter_data(
    metadata_path=config["metadata_path"],
    ambisonics_path=config["ambisonic_path"],
    noise_path=config["noise_path"]
    ):
    """Read the metadata table and restrict it to the audio files present on disk.

    Args:
        metadata_path: Parquet file holding the metadata (read with pyarrow).
        ambisonics_path: Directory containing the ambisonic recordings.
        noise_path: Directory containing the matching noise recordings.

    Returns:
        A tuple ``(metadata, ambisonic_files, noise_files)``: the metadata rows
        whose ``sample_id`` matches a file name found in ``ambisonics_path``,
        and the sorted full paths of the ambisonic and noise files.

    Raises:
        ValueError: If the two directories do not contain the same file names.
    """

    def _regular_files(folder):
        # Sorted names of the plain files (sub-directories skipped) in `folder`.
        return sorted(
            entry for entry in os.listdir(folder)
            if os.path.isfile(os.path.join(folder, entry))
        )

    metadata = pd.read_parquet(metadata_path, engine="pyarrow")

    ambisonic_file_names = _regular_files(ambisonics_path)
    noise_file_names = _regular_files(noise_path)

    # Both folders must hold identically named files so they pair up by position.
    if ambisonic_file_names != noise_file_names:
        raise ValueError("ambisonic_files and noise_files do not match")

    # The sample id is the zero-padded number before the first dot of the file name;
    # a stem that is empty once the zeros are stripped counts as id 0.
    sample_ids = []
    for file_name in ambisonic_file_names:
        stem = file_name.split(".")[0].lstrip("0")
        sample_ids.append(int(stem) if stem else 0)

    # Keep only the metadata rows that have audio available.
    metadata = metadata[metadata["sample_id"].isin(sample_ids)]

    ambisonic_files = [os.path.join(ambisonics_path, name) for name in ambisonic_file_names]
    noise_files = [os.path.join(noise_path, name) for name in noise_file_names]

    return metadata, ambisonic_files, noise_files


## Load and Pad

In [6]:
def load_and_pad(ambisonic_path, noise_path=None, sr=config["sr"], max_duration=None):
    """Load an ambisonic recording, optionally mix in noise, and fix its length.

    Args:
        ambisonic_path: Path of the recording to load (all channels are kept).
        noise_path: Optional path of a noise recording that is added sample-wise
            to the ambisonic audio after both were brought to the same length.
        sr: Sample rate passed to ``librosa.load``.
        max_duration: Target length in seconds; the audio is zero-padded or
            trimmed to ``max_duration * sr`` samples. Required.

    Returns:
        The (optionally noise-mixed) audio array returned by librosa, with its
        last axis fixed to the target number of samples.

    Raises:
        ValueError: If ``max_duration`` is not given.
    """
    if max_duration is None:
        raise ValueError("Enter a value for max_duration")

    # fix_length needs an integer sample count; a fractional duration would
    # otherwise produce a float here.
    max_samples = int(round(max_duration * sr))

    ambi_audio, _ = librosa.load(ambisonic_path, sr=sr, mono=False)
    ambi_audio = librosa.util.fix_length(ambi_audio, size=max_samples)

    if not noise_path:
        return ambi_audio

    noise_audio, _ = librosa.load(noise_path, sr=sr, mono=False)
    noise_audio = librosa.util.fix_length(noise_audio, size=max_samples)

    return ambi_audio + noise_audio

## Feature Extraction (STFT)

In [7]:
def stft(audio, n_fft=config["n_fft"], hop_length=config["hop_length"]):
    """Compute the magnitude STFT of every row of a 2-D audio array.

    Args:
        audio: Array indexed as ``audio[row, :]``; each row is transformed
            independently.
        n_fft: ``n_fft`` argument forwarded to ``librosa.stft``.
        hop_length: ``hop_length`` argument forwarded to ``librosa.stft``.

    Returns:
        A numpy array stacking the per-row magnitude spectrograms along a new
        leading axis.
    """
    magnitudes = [
        np.abs(librosa.stft(audio[row, :], n_fft=n_fft, hop_length=hop_length))
        for row in range(audio.shape[0])
    ]

    return np.array(magnitudes)


## Normalize

In [8]:
def normalize(features):
    """Standardise ``features`` along axis 0 and return a float32 tensor.

    The mean and standard deviation are taken over the first axis; a small
    constant is added to the standard deviation so constant entries do not
    divide by zero.

    Args:
        features: Array-like accepted by ``np.mean`` / ``np.std``.

    Returns:
        ``torch.Tensor`` of dtype float32 with the same shape as ``features``.
    """
    eps = 1e-10

    centered = features - np.mean(features, axis=0)
    scale = np.std(features, axis=0) + eps

    return torch.tensor(centered / scale).float()


## Split Metadata
### Train, Val, Test Split

In [9]:
def split_data(metadata, validation_size=0.2, random_state=42):
    """Add a ``set`` column that carves a validation subset out of the training rows.

    Every row starts with ``set`` equal to its existing ``split`` value; a
    random ``validation_size`` fraction of the rows whose ``split`` is
    ``'train'`` is then relabelled ``'validation'``.

    Args:
        metadata: DataFrame with a ``split`` column. It is not modified.
        validation_size: Fraction (or count) of training rows moved to
            validation, passed to ``train_test_split`` as ``test_size``.
        random_state: Seed passed to ``train_test_split``.

    Returns:
        A copy of ``metadata`` with the additional ``set`` column.
    """
    # Work on a copy: the caller typically passes a filtered slice of a larger
    # frame, and writing into it would mutate the caller's data and trigger
    # pandas' SettingWithCopyWarning.
    metadata = metadata.copy()

    # Start from the dataset's own split labels.
    metadata['set'] = metadata['split']

    # Index labels of the rows marked for training.
    train_indices = metadata[metadata['split'] == 'train'].index

    # Only the validation part is needed; the remaining rows keep 'train'.
    _, valid_idx = train_test_split(train_indices, test_size=validation_size, random_state=random_state)

    metadata.loc[valid_idx, 'set'] = 'validation'

    return metadata

## Calculate Median 3D Angle Error

In [10]:
def calculate_3d_angle_error(true_azimuth, true_elevation, pred_azimuth, pred_elevation):
    """Median angle, in degrees, between true and predicted directions.

    Each (azimuth, elevation) pair is treated as a direction on the unit
    sphere, with elevation measured from the horizontal plane and azimuth
    around the vertical axis. The angle between two directions follows from
    the dot product of their unit vectors:

        cos(angle) = sin(el_t) * sin(el_p) + cos(el_t) * cos(el_p) * cos(az_t - az_p)

    Args:
        true_azimuth, true_elevation: Tensors of ground-truth angles in radians.
        pred_azimuth, pred_elevation: Tensors of predicted angles in radians,
            broadcastable against the ground truth.

    Returns:
        Scalar tensor holding the median angular error in degrees.
    """
    cos_angle = (
        torch.sin(true_elevation) * torch.sin(pred_elevation) +
        torch.cos(true_elevation) * torch.cos(pred_elevation) * torch.cos(true_azimuth - pred_azimuth)
    )

    # Rounding can push the value slightly outside [-1, 1], where acos is NaN.
    cos_angle = torch.clamp(cos_angle, -1.0, 1.0)

    angle_error_rad = torch.acos(cos_angle)

    # Degrees are easier to interpret than radians.
    angle_error_deg = torch.rad2deg(angle_error_rad)

    median_abs_error_deg = torch.median(torch.abs(angle_error_deg))

    return median_abs_error_deg


# Custom Dataset

In [11]:
class AudioDataset(Dataset):
    """Dataset yielding ``(features, labels)`` pairs for one data split.

    ``features`` is the result of ``normalize(stft(load_and_pad(...)))`` for one
    recording and ``labels`` is a float32 tensor ``[azimuth, elevation]`` read
    from the ``speech/azimuth`` and ``speech/elevation`` metadata columns.

    Args:
        metadata: DataFrame with ``sample_id``, ``audio_info/duration`` and the
            two label columns. When it also has a ``set`` column, only rows
            whose ``set`` equals ``dataset_type`` are used.
        dataset_type: Split name to keep (e.g. ``'train'``); ``None`` keeps all.
        audio_files: Paths of the recordings; the sample id is parsed from the
            file name.
        noise_files: Optional noise paths aligned index-by-index with
            ``audio_files``.
        transform: Optional callable applied to the feature tensor.
        sr: Sample rate used when loading the audio.
    """

    def __init__(self, metadata, dataset_type, audio_files, noise_files=None, transform=None, sr=config["sr"]):
        self.dataset_type = dataset_type
        self.transform = transform
        self.sr = sr
        self.num_steps = config["num_steps"]

        # The padding length comes from the unfiltered metadata so that every
        # split produces features of the same shape.
        self.max_duration = int(np.ceil(metadata["audio_info/duration"].max()))

        # Restrict the rows and the file lists to the requested split. Without
        # a 'set' column there is nothing to filter on, so all files are kept.
        if dataset_type is not None and "set" in metadata.columns:
            metadata = metadata[metadata["set"] == dataset_type]
            wanted_ids = set(metadata["sample_id"].tolist())
            keep = [i for i, path in enumerate(audio_files) if self._sample_id(path) in wanted_ids]
            audio_files = [audio_files[i] for i in keep]
            if noise_files:
                noise_files = [noise_files[i] for i in keep]

        self.metadata = metadata
        self.audio_files = audio_files
        self.noise_files = noise_files

    @staticmethod
    def _sample_id(path):
        # The sample id is the zero-padded number before the first dot of the file name.
        return int(os.path.basename(path).split(".")[0].lstrip("0") or 0)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        noise_path = self.noise_files[idx] if self.noise_files else None

        sample_id = self._sample_id(audio_path)

        # Look up the labels belonging to this recording.
        label_rows = self.metadata[self.metadata['sample_id'] == sample_id]

        if label_rows.empty:
            raise ValueError(f"Sample ID {sample_id} not found in metadata for file {audio_path}")

        labels = label_rows[['speech/azimuth', 'speech/elevation']].values[0]
        labels_tensor = torch.from_numpy(labels.astype("float32"))

        # Load -> magnitude STFT -> standardise. The dataset's own sample rate
        # is forwarded so a non-default `sr` is honoured.
        audio = load_and_pad(audio_path, noise_path, sr=self.sr, max_duration=self.max_duration)
        audio = stft(audio)
        audio = normalize(audio)

        # No spike train is generated here: the previous rate-coded train was
        # computed but never returned, and SNN.forward feeds this same tensor
        # to the network at every step.
        if self.transform:
            audio = self.transform(audio)

        return audio, labels_tensor

# Model

## SNN

In [12]:
class SNN(nn.Module):
    """Convolutional spiking network that regresses two values per input.

    Three ``Conv2d -> BatchNorm2d -> MaxPool2d -> snn.Leaky`` stages are
    followed by a flatten and two linear layers. The same input tensor is
    presented for ``num_steps`` iterations and the per-step outputs are
    averaged.

    Args:
        config: Mapping providing ``thresh1``, ``thresh2``, ``thresh3``,
            ``beta`` and ``num_steps``. An optional ``fc1_in_features`` entry
            overrides the input width of the first linear layer, which
            otherwise defaults to 458752 and therefore assumes one specific
            input size.
    """

    def __init__(self, config):
        super().__init__()
        self.thresh1 = config["thresh1"]
        self.thresh2 = config["thresh2"]
        self.thresh3 = config["thresh3"]
        self.beta = config["beta"]
        self.num_steps = config["num_steps"]

        self.conv1 = nn.Conv2d(4, 32, kernel_size=(3, 3))
        self.batch_norm1 = nn.BatchNorm2d(32)
        self.max_pool1 = nn.MaxPool2d(kernel_size=(2, 2))
        self.lif1 = snn.Leaky(beta=self.beta, threshold=self.thresh1)

        self.conv2 = nn.Conv2d(32, 64, kernel_size=(3, 3))
        self.batch_norm2 = nn.BatchNorm2d(64)
        self.max_pool2 = nn.MaxPool2d(kernel_size=(2, 2))
        self.lif2 = snn.Leaky(beta=self.beta, threshold=self.thresh2)

        self.conv3 = nn.Conv2d(64, 128, kernel_size=(3, 3))
        self.batch_norm3 = nn.BatchNorm2d(128)
        self.max_pool3 = nn.MaxPool2d(kernel_size=(2, 2))
        self.lif3 = snn.Leaky(beta=self.beta, threshold=self.thresh3)

        self.flatten = nn.Flatten()
        # The flattened size depends on the spectrogram shape fed to the model.
        self.fc1 = nn.Linear(config.get("fc1_in_features", 458752), 256)
        self.fc2 = nn.Linear(256, 2)

    def forward(self, inpt):
        """Run ``num_steps`` iterations on ``inpt`` and return the mean output.

        Args:
            inpt: Batch of 4-channel inputs for ``conv1``.

        Returns:
            Tensor with two values per batch element, averaged over the steps.
        """
        # Fresh membrane potentials for every forward pass.
        mem1 = self.lif1.init_leaky()
        mem2 = self.lif2.init_leaky()
        mem3 = self.lif3.init_leaky()

        out_rec = []

        for step in range(self.num_steps):
            # The convolution stack stays inside the loop on purpose: in
            # training mode each batch-norm call also updates its running
            # statistics, so hoisting it would change training behaviour.
            current1 = self.conv1(inpt)
            current1 = self.batch_norm1(current1)
            current1 = self.max_pool1(current1)
            spike1, mem1 = self.lif1(current1, mem1)

            current2 = self.conv2(spike1)
            current2 = self.batch_norm2(current2)
            current2 = self.max_pool2(current2)
            spike2, mem2 = self.lif2(current2, mem2)

            current3 = self.conv3(spike2)
            current3 = self.batch_norm3(current3)
            current3 = self.max_pool3(current3)
            spike3, mem3 = self.lif3(current3, mem3)

            flatten = self.flatten(spike3)
            fc = self.fc1(flatten)
            out = self.fc2(fc)

            out_rec.append(out)

        # Average the per-step readouts into a single prediction.
        out_avg = torch.mean(torch.stack(out_rec), dim=0)

        return out_avg

## Train Function

In [13]:
def train(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Module producing at least two output columns per sample.
        train_loader: DataLoader yielding ``(data, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batches are moved to.

    Returns:
        ``(loss, true_az, true_el, pred_az, pred_el)`` where ``loss`` is the
        sample-weighted mean batch loss and the tensors concatenate column 0
        (azimuth) and column 1 (elevation) of the labels and of the detached
        model outputs over all batches.
    """
    model.train()

    running_loss = 0.0
    collected = {"true_az": [], "true_el": [], "pred_az": [], "pred_el": []}

    for data, labels in train_loader:
        data = data.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Weight the batch mean by its size so a short final batch is not over-counted.
        running_loss += loss.item() * data.size(0)

        collected["true_az"].append(labels[:, 0].detach())
        collected["true_el"].append(labels[:, 1].detach())
        collected["pred_az"].append(outputs[:, 0].detach())
        collected["pred_el"].append(outputs[:, 1].detach())

    train_loss = running_loss / len(train_loader.dataset)

    return (
        train_loss,
        torch.cat(collected["true_az"]),
        torch.cat(collected["true_el"]),
        torch.cat(collected["pred_az"]),
        torch.cat(collected["pred_el"]),
    )


## Validation Function

In [14]:
def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without updating any weights.

    Args:
        model: Module producing at least two output columns per sample.
        val_loader: DataLoader yielding ``(data, labels)`` batches.
        criterion: Loss taking ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(loss, true_az, true_el, pred_az, pred_el)`` where ``loss`` is the
        sample-weighted mean batch loss and the tensors concatenate column 0
        and column 1 of the labels and of the model outputs over all batches.
    """
    model.eval()

    weighted_loss_sum = 0.0
    az_true, el_true, az_pred, el_pred = [], [], [], []

    with torch.no_grad():
        for data, labels in val_loader:
            data, labels = data.to(device), labels.to(device)

            outputs = model(data)
            weighted_loss_sum += criterion(outputs, labels).item() * data.size(0)

            # Store each column in its own bucket.
            for bucket, column in (
                (az_true, labels[:, 0]),
                (el_true, labels[:, 1]),
                (az_pred, outputs[:, 0]),
                (el_pred, outputs[:, 1]),
            ):
                bucket.append(column.detach())

    valid_loss = weighted_loss_sum / len(val_loader.dataset)

    return valid_loss, torch.cat(az_true), torch.cat(el_true), torch.cat(az_pred), torch.cat(el_pred)


## Test Function

In [15]:
def test(model, test_loader, criterion, device):
    #print("Testing")

    test_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    model.eval()

    with torch.no_grad():
        for batch_idx, (data, labels) in enumerate(test_loader):
            data, labels = data.to(device), labels.to(device)

            outputs = model(data)

            # ic(outputs)
            # ic(labels)

            loss = criterion(outputs, labels)
            test_loss += loss.item() * data.size(0)

            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    test_loss = test_loss / len(test_loader.dataset)

    return test_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)

# Run Code

In [16]:
# Load the metadata rows that have audio on disk, then add the 'set' column
# that marks part of the training rows as validation.
metadata, ambisonic_files, noise_files = filter_data()
metadata = split_data(metadata)


In [17]:
# Apply the configured seed before the model is built so that weight
# initialisation and DataLoader shuffling are repeatable across runs.
torch.manual_seed(config["seed"])
np.random.seed(config["seed"])

device = config["device"]
model = SNN(config).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [18]:
train_set = AudioDataset(metadata, 'train', ambisonic_files, noise_files)
val_set = AudioDataset(metadata, "validation", ambisonic_files, noise_files)
test_set = AudioDataset(metadata, "test", ambisonic_files, noise_files)

# Only the training loader is shuffled; evaluation gains nothing from a random
# order, and a fixed order keeps its outputs comparable between epochs.
train_loader = DataLoader(train_set, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_set, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(test_set, batch_size=config["batch_size"], shuffle=False)

In [19]:
# Per-epoch loss history.
train_loss = []
valid_loss = []

for epoch in range(config["epochs"]):
    print(f"Epoch {epoch+1}/{config['epochs']}")

    train_epoch_loss, train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation = train(model, train_loader, criterion, optimizer, device)
    valid_epoch_loss, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    train_angle_error = calculate_3d_angle_error(train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation)
    valid_angle_error = calculate_3d_angle_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

    print(f"Train Loss: {train_epoch_loss:.4f} | Validation Loss: {valid_epoch_loss:.4f}")
    print(f"Train Angle Error: {train_angle_error:.4f} | Validation Angle Error: {valid_angle_error:.4f}\n")

# Final evaluation on the held-out test loader. The angle error must use the
# test predictions for both angles (the elevation previously came from the
# last training epoch).
test_loss, test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation = test(model, test_loader, criterion, device)
test_angle_error = calculate_3d_angle_error(test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation)
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Angle Error: {test_angle_error:.4f}\n")

Epoch 1/5
Training


ic| outputs: tensor([[ 0.0130, -0.0351],
                     [ 0.0171,  0.0064],
                     [ 0.0335, -0.0201],
                     [-0.0303,  0.0038],
                     [ 0.0254,  0.0103],
                     [ 0.0142, -0.0165],
                     [ 0.0509,  0.0313],
                     [ 0.0288,  0.0133],
                     [-0.0281,  0.0025],
                     [-0.0254, -0.0250],
                     [ 0.0630, -0.0145],
                     [ 0.0278,  0.0210],
                     [ 0.0105,  0.0141],
                     [ 0.0360,  0.0256],
                     [ 0.0534, -0.0011],
                     [ 0.0176, -0.0043]], device='cuda:0', grad_fn=<MeanBackward1>)
ic| labels: tensor([[-5.2584e-01, -2.8116e-01],
                    [ 1.8784e+00,  1.6417e-01],
                    [-1.7161e+00,  3.9511e-02],
                    [-2.0596e+00,  1.7430e-03],
                    [ 5.3497e-01, -1.3366e-01],
                    [-8.9757e-01, -6.4136e-02],
             

KeyboardInterrupt: 

## Plotting

In [None]:
# W = audio[0]
# X = audio[1]
# Y = audio[2]
# Z = audio[3]

# W_n = audio_n[0]
# Y_n = audio_n[2]
# X_n = audio_n[1]
# Z_n = audio_n[3]


In [None]:
# # Plot each channel
# fig, axs = plt.subplots(8, 1, figsize=(10, 8), sharex=True)

# axs[0].plot(W)
# axs[0].set_title('W Channel')
# axs[1].plot(W_n)
# axs[1].set_title('W_n Channel')

# axs[2].plot(X)
# axs[2].set_title('X Channel')
# axs[3].plot(X_n)
# axs[3].set_title('X_n Channel')


# axs[4].plot(Y)
# axs[4].set_title('Y Channel')
# axs[5].plot(Y_n)
# axs[5].set_title('Y_n Channel')

# axs[6].plot(Z)
# axs[6].set_title('Z Channel')
# axs[7].plot(Z_n)
# axs[7].set_title('Z_n Channel')

# # Common settings for all subplots
# for ax in axs:
#     ax.set_ylabel('Amplitude')
#     ax.label_outer()

# axs[-1].set_xlabel('Sample')

# plt.tight_layout()
# plt.show()
