<a href="https://colab.research.google.com/github/mercadoerik1031/snn-sound-localization/blob/new_approach/snn_sound_localization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SNN Sound Localization**



---



In [None]:
# Mount Google Drive at /content/drive so files stored there can be read by this notebook.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# pip Installs

In [None]:
! pip install icecream snntorch --quiet

# Imports

In [None]:
import librosa
import numpy as np

import matplotlib.pyplot as plt

import pandas as pd

import os

import torch
import torchaudio

import snntorch as snn
from snntorch import spikegen

from icecream import ic

# Config

In [None]:
# Single root for every dataset path, so relocating the project is a one-line change.
PROJECT_ROOT = "/content/drive/My Drive/Colab Notebooks/Masters Project"
SAMPLE_ROOT = f"{PROJECT_ROOT}/spatial_librispeech_sample"

# Central settings shared by the preprocessing functions and the spike encoder.
config = {
    # Paths
    "metadata_path": f"{PROJECT_ROOT}/metadata.parquet",
    "ambisonic_path": f"{SAMPLE_ROOT}/ambisonics_sample",
    "noise_path": f"{SAMPLE_ROOT}/noise_ambisonics_sample",

    # Device: prefer the GPU when one is visible to torch
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Audio Info
    "sr": 16_000,       # sample rate requested from librosa.load
    "n_fft": 2048,      # FFT window size passed to librosa.stft
    "hop_length": 512,  # hop between STFT frames passed to librosa.stft

    # SNN
    "num_steps": 20,    # number of time steps passed to spikegen.rate
}

# Preprocess

## Load Metadata

In [None]:
def _list_regular_files(directory):
    """Return the names of the regular files in `directory`, sorted by name."""
    # os.listdir yields entries in arbitrary order; sorting makes the result
    # reproducible so that index i means the same file on every run.
    return sorted(
        name for name in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, name))
    )


def filter_data(
    metadata_path=config["metadata_path"],
    ambisonics_path=config["ambisonic_path"],
    noise_path=config["noise_path"]
    ):
    """Load the metadata table and restrict it to the samples present on disk.

    Args:
        metadata_path: Path of the parquet file holding the metadata table.
        ambisonics_path: Directory containing the ambisonic recordings. The
            part of each file name before the first "." is parsed as the
            integer sample id (leading zeros are ignored).
        noise_path: Directory containing the noise recordings.

    Returns:
        A tuple ``(filtered_metadata, ambisonic_files, noise_files)``:
        the metadata rows whose ``sample_id`` matches a file in
        ``ambisonics_path``, and the full paths of the regular files in each
        directory, sorted by file name.
    """
    metadata = pd.read_parquet(metadata_path, engine="pyarrow")

    # Get list of all files in directories (sorted for a deterministic order)
    ambisonic_names = _list_regular_files(ambisonics_path)
    noise_names = _list_regular_files(noise_path)

    sample_ids = []
    for name in ambisonic_names:
        stem = name.split(".")[0]
        try:
            sample_ids.append(int(stem))
        except ValueError:
            # Names without a numeric stem (e.g. dotfiles) carry no sample id;
            # skip them instead of crashing or mapping them to id 0.
            continue

    filtered_metadata = metadata[metadata["sample_id"].isin(sample_ids)]

    # Create full file paths
    ambisonic_files = [os.path.join(ambisonics_path, name) for name in ambisonic_names]
    noise_files = [os.path.join(noise_path, name) for name in noise_names]

    return filtered_metadata, ambisonic_files, noise_files

## Load and Pad

In [None]:
def load_and_pad(ambisonic_path, noise_path=None, sr=config["sr"], max_duration=None):
    """Load a recording, force it to a fixed length, and optionally add noise.

    Args:
        ambisonic_path: Path of the recording to load.
        noise_path: Optional path of a noise recording. When given, it is
            loaded and length-fixed the same way and added sample-wise to the
            main recording.
        sr: Sample rate passed to ``librosa.load``.
        max_duration: Target duration; multiplied by ``sr`` to get the target
            number of samples. Required.

    Returns:
        The length-fixed audio array (with the noise added when ``noise_path``
        is given). Loading uses ``mono=False``, so the channel layout of the
        file is kept.

    Raises:
        ValueError: If ``max_duration`` is not provided.
    """
    if max_duration is None:
        raise ValueError("Enter a value for max_duration")

    # The target length is a sample count, so it must be an integer even when
    # max_duration is fractional.
    max_samples = int(round(max_duration * sr))

    ambi_audio, _ = librosa.load(ambisonic_path, sr=sr, mono=False)
    ambi_audio = librosa.util.fix_length(ambi_audio, size=max_samples)

    if not noise_path:
        return ambi_audio

    noise_audio, _ = librosa.load(noise_path, sr=sr, mono=False)
    noise_audio = librosa.util.fix_length(noise_audio, size=max_samples)

    # Both arrays now have the same number of samples, so they can be summed.
    return ambi_audio + noise_audio

## Feature Extraction (STFT)

In [None]:
def stft(audio, n_fft=config["n_fft"], hop_length=config["hop_length"]):
    """Compute the short-time Fourier transform of each audio channel.

    Args:
        audio: Array whose first axis indexes channels and whose second axis
            indexes samples. A 1-D signal is treated as a single channel.
        n_fft: FFT window size passed to ``librosa.stft``.
        hop_length: Hop between frames passed to ``librosa.stft``.

    Returns:
        A NumPy array stacking the ``librosa.stft`` output of every channel
        along a new leading axis.
    """
    # A 1-D signal has no channel axis; promote it to (1, samples) so the
    # per-channel indexing below works instead of failing on audio[i, :].
    audio = np.atleast_2d(audio)

    features = [
        librosa.stft(audio[i, :], n_fft=n_fft, hop_length=hop_length)
        for i in range(audio.shape[0])
    ]

    return np.array(features)


## Normalize

In [None]:
def normalize(features, axis=0, epsilon=1e-10):
    """Standardize `features` to zero mean and unit standard deviation.

    Args:
        features: Array-like of values to normalize.
        axis: Axis (or axes) along which the mean and standard deviation are
            computed. Defaults to 0, i.e. statistics are taken across the
            first axis for every remaining position.
        epsilon: Small constant added to the standard deviation so that
            constant slices do not cause a division by zero.

    Returns:
        An array of the same shape as `features` holding
        ``(features - mean) / (std + epsilon)``.
    """
    features = np.asarray(features)

    # keepdims keeps the reduced axis with length 1 so the statistics
    # broadcast against `features` for any choice of axis.
    mean = np.mean(features, axis=axis, keepdims=True)
    std = np.std(features, axis=axis, keepdims=True)

    return (features - mean) / (std + epsilon)


In [None]:
# Load the metadata rows for the samples found on disk, plus the ambisonic and noise file paths.
filtered_metadata, ambisonic_files, noise_files = filter_data()

In [None]:
# Pad every clip to the longest duration listed in the metadata, rounded up to a whole number.
max_duration = int(np.ceil(filtered_metadata["audio_info/duration"].max()))

# Run the preprocessing chain on the first ambisonic/noise pair.
audio = load_and_pad(ambisonic_path=ambisonic_files[0], noise_path=noise_files[0], max_duration=max_duration)
features = stft(audio)

# The STFT coefficients are complex. Take the magnitude explicitly: casting
# complex values to a float tensor would silently keep only the real part.
features_mag = np.abs(features)
norm_features = normalize(features_mag)

norm_features_tensor = torch.tensor(norm_features, dtype=torch.float)

# Generate spike train
# NOTE(review): rate coding presumably interprets inputs as firing probabilities,
# while z-scored features fall outside [0, 1] -- confirm whether a rescale is needed.
spike_train = spikegen.rate(norm_features_tensor, num_steps=config["num_steps"])


## Plotting

In [None]:
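# NOTE: disabled scratch code. `audio_n` (presumably a noise-only recording) is
# not defined in the cells above; define it before re-enabling these lines.
# W = audio[0]
# X = audio[1]
# Y = audio[2]
# Z = audio[3]

# W_n = audio_n[0]
# Y_n = audio_n[2]
# X_n = audio_n[1]
# Z_n = audio_n[3]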


In [None]:
# # Plot each channel
# fig, axs = plt.subplots(8, 1, figsize=(10, 8), sharex=True)

# axs[0].plot(W)
# axs[0].set_title('W Channel')
# axs[1].plot(W_n)
# axs[1].set_title('W_n Channel')

# axs[2].plot(X)
# axs[2].set_title('X Channel')
# axs[3].plot(X_n)
# axs[3].set_title('X_n Channel')


# axs[4].plot(Y)
# axs[4].set_title('Y Channel')
# axs[5].plot(Y_n)
# axs[5].set_title('Y_n Channel')

# axs[6].plot(Z)
# axs[6].set_title('Z Channel')
# axs[7].plot(Z_n)
# axs[7].set_title('Z_n Channel')

# # Common settings for all subplots
# for ax in axs:
#     ax.set_ylabel('Amplitude')
#     ax.label_outer()

# axs[-1].set_xlabel('Sample')

# plt.tight_layout()
# plt.show()
