<a href="https://colab.research.google.com/github/mercadoerik1031/snn-sound-localization/blob/new_approach/snn_sound_localization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SNN Sound Localization**



---



In [29]:
# Mount Google Drive so the data paths under /content/drive/... in `config` can be read.
# NOTE: `google.colab` is only importable when running inside Google Colab.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# pip Installs

In [30]:
! pip install icecream snntorch --quiet

# Imports

In [40]:
import librosa
import numpy as np

import matplotlib.pyplot as plt

import pandas as pd

import os

import torch
import torchaudio
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from sklearn.model_selection import train_test_split

import snntorch as snn
from snntorch import spikegen

from icecream import ic

# Config

In [32]:
# Central experiment settings, looked up by key throughout the notebook.
config = dict(
    # --- Data locations on the mounted Google Drive ---
    metadata_path=r"/content/drive/My Drive/Colab Notebooks/Masters Project/metadata.parquet",
    ambisonic_path=r"/content/drive/My Drive/Colab Notebooks/Masters Project/spatial_librispeech_sample/ambisonics_sample",
    noise_path=r"/content/drive/My Drive/Colab Notebooks/Masters Project/spatial_librispeech_sample/noise_ambisonics_sample",
    # --- Compute: use the GPU whenever torch can see one ---
    device=("cuda" if torch.cuda.is_available() else "cpu"),
    # --- Audio loading / STFT parameters ---
    sr=16_000,        # target sample rate (Hz) for librosa.load
    n_fft=512,        # STFT FFT size in samples
    hop_length=256,   # STFT hop in samples
    # --- Spiking network (presumably simulation time steps; not used in this section) ---
    num_steps=20,
    # --- Data loading ---
    batch_size=32,
    seed=42,          # random_state for the train/validation split
)

# Preprocess

## Load Metadata

In [33]:
def filter_data(
    metadata_path=config["metadata_path"],
    ambisonics_path=config["ambisonic_path"],
    noise_path=config["noise_path"]
    ):
    """Load the metadata table and match it to the audio files present on disk.

    Args:
        metadata_path: Path to the metadata parquet file (read with pyarrow).
        ambisonics_path: Directory holding the ambisonic recordings.
        noise_path: Directory holding the noise recordings.

    Returns:
        A tuple ``(filtered_metadata, ambisonic_files, noise_files)``:
        ``filtered_metadata`` keeps only the rows whose ``sample_id`` matches an
        ambisonic file name; the two lists hold the full paths of the regular
        files in each directory, sorted by file name.
    """
    metadata = pd.read_parquet(metadata_path, engine="pyarrow")

    # Sort the listings: os.listdir order is arbitrary, and AudioDataset pairs
    # ambisonic and noise files by list position. Sorting makes that pairing
    # deterministic (it is only correct if both directories use matching file
    # names -- TODO confirm).
    ambisonic_names = sorted(
        f for f in os.listdir(ambisonics_path)
        if os.path.isfile(os.path.join(ambisonics_path, f))
    )
    noise_names = sorted(
        f for f in os.listdir(noise_path)
        if os.path.isfile(os.path.join(noise_path, f))
    )

    # The part of the file name before the first "." is parsed as the integer
    # sample id (leading zeros stripped; an all-zero stem maps to 0).
    sample_ids = [int(f.split(".")[0].lstrip("0") or 0) for f in ambisonic_names]
    filtered_metadata = metadata[metadata["sample_id"].isin(sample_ids)]

    ambisonic_files = [os.path.join(ambisonics_path, f) for f in ambisonic_names]
    noise_files = [os.path.join(noise_path, f) for f in noise_names]

    return filtered_metadata, ambisonic_files, noise_files

## Load and Pad

In [34]:
def load_and_pad(ambisonic_path, noise_path=None, sr=config["sr"], max_duration=None):
    """Load an ambisonic recording, optionally add noise, and fix its length.

    Args:
        ambisonic_path: Path to the ambisonic audio file.
        noise_path: Optional path to a noise file. When given, it is loaded at
            the same sample rate, fixed to the same length and added
            sample-wise to the recording.
        sr: Target sample rate passed to ``librosa.load`` (resamples if needed).
        max_duration: Target duration in seconds. Required.

    Returns:
        The (possibly noisy) audio array, zero-padded or trimmed along its last
        axis to ``int(max_duration * sr)`` samples. Audio is loaded with
        ``mono=False``, so multi-channel files keep one row per channel.

    Raises:
        ValueError: If ``max_duration`` is None.
    """
    if max_duration is None:
        raise ValueError("Enter a value for max_duration")

    # fix_length needs an integer size; max_duration may be a float.
    max_samples = int(max_duration * sr)

    ambi_audio, _ = librosa.load(ambisonic_path, sr=sr, mono=False)
    ambi_audio = librosa.util.fix_length(ambi_audio, size=max_samples)

    if not noise_path:
        return ambi_audio

    noise_audio, _ = librosa.load(noise_path, sr=sr, mono=False)
    noise_audio = librosa.util.fix_length(noise_audio, size=max_samples)
    return ambi_audio + noise_audio

## Feature Extraction (STFT)

In [35]:
def stft(audio, n_fft=config["n_fft"], hop_length=config["hop_length"]):
    """Compute the STFT magnitude of every channel of ``audio``.

    Args:
        audio: Signal shaped (channels, samples). A 1-D signal is treated as a
            single channel.
        n_fft: FFT size passed to ``librosa.stft``.
        hop_length: Hop length passed to ``librosa.stft``.

    Returns:
        Array stacking ``|librosa.stft(channel)|`` for each channel, i.e. shaped
        (channels, 1 + n_fft // 2, frames).
    """
    # Without this, a 1-D mono signal would be iterated sample by sample.
    channels = np.atleast_2d(audio)

    magnitudes = [
        np.abs(librosa.stft(channel, n_fft=n_fft, hop_length=hop_length))
        for channel in channels
    ]
    return np.array(magnitudes)


## Normalize

In [36]:
def normalize(features, axis=0, epsilon=1e-10):
    """Z-score ``features`` along ``axis``.

    Args:
        features: Numeric array.
        axis: Axis (or axes / None) over which mean and std are computed.
            Defaults to 0, the original behaviour.
        epsilon: Added to the std so constant slices do not divide by zero.

    Returns:
        Array with the same shape as ``features``: ``(x - mean) / (std + epsilon)``.

    NOTE(review): for ``stft()`` output shaped (channels, freq, frames), the
    default axis=0 normalizes across channels within each time-frequency bin,
    which may wash out inter-channel level cues. Consider a different axis;
    confirm intent.
    """
    # keepdims keeps the statistics broadcastable for any choice of axis.
    mean = np.mean(features, axis=axis, keepdims=True)
    std = np.std(features, axis=axis, keepdims=True)

    return (features - mean) / (std + epsilon)


## Split Metadata
### Train, Val, Test Split

In [41]:
def split_data(metadata, valid_size=0.2):
    """Label each metadata row as 'train', 'validation' or 'test' in a 'set' column.

    Rows whose 'split' is 'train' are divided into train/validation with
    ``train_test_split`` (``random_state=config["seed"]``); every other row is
    labelled 'test'.

    Args:
        metadata: DataFrame with a 'split' column. It is not modified.
        valid_size: Fraction of the 'train' rows held out for validation.

    Returns:
        A copy of ``metadata`` with the added 'set' column.
    """
    # Work on a copy. The caller's frame may be a boolean-filtered slice (pandas
    # warns on assignment to it), and in-place mutation is a hidden side effect.
    tagged = metadata.copy()

    train_indices = tagged[tagged["split"] == "train"].index
    train_idx, valid_idx = train_test_split(
        train_indices, test_size=valid_size, random_state=config["seed"]
    )

    tagged["set"] = "test"  # default for every row not drawn from 'train'
    tagged.loc[train_idx, "set"] = "train"
    tagged.loc[valid_idx, "set"] = "validation"
    return tagged

# Dataset

In [None]:
class AudioDataset(Dataset):
    """Map-style dataset yielding normalized STFT features and (azimuth, elevation) labels.

    Only the audio files whose sample id appears in the ``metadata`` rows with
    ``set == dataset_type`` belong to the dataset.

    Args:
        metadata: DataFrame with 'set', 'sample_id', 'speech/azimuth',
            'speech/elevation' and 'audio_info/duration' columns.
        dataset_type: Value of the 'set' column to keep.
        audio_files: Paths to ambisonic files. The file-name part before the
            first '.' is parsed as the integer sample id.
        noise_files: Optional noise paths, paired with ``audio_files`` by list
            position.
        transform: Optional callable applied to the normalized features.
        sr: Sample rate used when loading audio.
    """

    def __init__(self, metadata, dataset_type, audio_files, noise_files=None, transform=None, sr=config["sr"]):
        # Rows of the requested split only.
        self.metadata = metadata[metadata["set"] == dataset_type]
        self.transform = transform
        self.sr = sr

        # Every clip is padded to the split's longest duration, rounded up to
        # whole seconds. An empty split has nothing to pad, and NaN from .max()
        # would make int() fail.
        durations = self.metadata["audio_info/duration"]
        self.max_duration = int(np.ceil(durations.max())) if len(durations) else 0

        # (azimuth, elevation) per sample id. The first row wins for duplicate ids.
        # Labels are float32 so they collate into float32 tensors.
        first_rows = self.metadata.drop_duplicates("sample_id")
        self.labels = dict(zip(
            first_rows["sample_id"].astype(int).tolist(),
            first_rows[["speech/azimuth", "speech/elevation"]].to_numpy(dtype=np.float32),
        ))

        # Keep only this split's files, so __len__ and indexing cover only
        # samples that have labels. Noise files keep their positional pairing.
        self.audio_files = []
        self.noise_files = [] if noise_files else None
        for position, audio_path in enumerate(audio_files):
            if self._sample_id(audio_path) not in self.labels:
                continue
            self.audio_files.append(audio_path)
            if noise_files:
                self.noise_files.append(noise_files[position])

    @staticmethod
    def _sample_id(path):
        """Parse the integer sample id from the file name (text before the first '.')."""
        return int(os.path.basename(path).split(".")[0].lstrip("0") or 0)

    def __len__(self):
        return len(self.audio_files)

    def __getitem__(self, idx):
        audio_path = self.audio_files[idx]
        noise_path = self.noise_files[idx] if self.noise_files else None

        # Copy so callers cannot mutate the cached label array.
        labels = self.labels[self._sample_id(audio_path)].copy()

        audio = load_and_pad(audio_path, noise_path, self.sr, self.max_duration)
        audio = normalize(stft(audio))

        if self.transform is not None:
            audio = self.transform(audio)

        return audio, labels

# Dataloader

In [42]:
# Build the per-split datasets and their loaders.
metadata, ambisonic_files, noise_files = filter_data()
metadata = split_data(metadata)

train_dataset = AudioDataset(metadata, "train", ambisonic_files, noise_files)
valid_dataset = AudioDataset(metadata, "validation", ambisonic_files, noise_files)
test_dataset = AudioDataset(metadata, "test", ambisonic_files, noise_files)

# Only training needs shuffling. A seeded generator makes the shuffle order
# reproducible across runs. Evaluation loaders keep a fixed order, so their
# batches are reproducible.
train_loader = DataLoader(
    train_dataset,
    batch_size=config["batch_size"],
    shuffle=True,
    generator=torch.Generator().manual_seed(config["seed"]),
)
valid_loader = DataLoader(valid_dataset, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False)


## Plotting

In [None]:
# W = audio[0]
# X = audio[1]
# Y = audio[2]
# Z = audio[3]

# W_n = audio_n[0]
# Y_n = audio_n[2]
# X_n = audio_n[1]
# Z_n = audio_n[3]


In [None]:
# # Plot each channel
# fig, axs = plt.subplots(8, 1, figsize=(10, 8), sharex=True)

# axs[0].plot(W)
# axs[0].set_title('W Channel')
# axs[1].plot(W_n)
# axs[1].set_title('W_n Channel')

# axs[2].plot(X)
# axs[2].set_title('X Channel')
# axs[3].plot(X_n)
# axs[3].set_title('X_n Channel')


# axs[4].plot(Y)
# axs[4].set_title('Y Channel')
# axs[5].plot(Y_n)
# axs[5].set_title('Y_n Channel')

# axs[6].plot(Z)
# axs[6].set_title('Z Channel')
# axs[7].plot(Z_n)
# axs[7].set_title('Z_n Channel')

# # Common settings for all subplots
# for ax in axs:
#     ax.set_ylabel('Amplitude')
#     ax.label_outer()

# axs[-1].set_xlabel('Sample')

# plt.tight_layout()
# plt.show()
