<a href="https://colab.research.google.com/github/mercadoerik1031/snn-sound-localization/blob/write_to_disk/snn_sound_localization.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **SNN Sound Localization**



---



# Pip Installs

In [None]:
! pip install snntorch brian2 brian2hears --quiet

# Imports

In [None]:
import pandas as pd
import os
import librosa
import torch
import numpy as np
from tqdm import tqdm
from snntorch import spikegen
from brian2 import *
from brian2hears import *
import gc

In [None]:
# Mount Google Drive at /content/drive; every dataset path in `config` lives
# under this mount point. `google.colab` is only importable inside Colab.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


# Config

In [None]:
# Root folders on the mounted Google Drive (Google Colab paths).
_PROJECT_DIR = "/content/drive/My Drive/Colab Notebooks/Masters Project"
_SAMPLE_DIR = _PROJECT_DIR + "/spatial_librispeech_sample"

config = {
    # Parquet table read by filter_data.
    "metadata_path": _PROJECT_DIR + "/metadata.parquet",
    # Directory listed by filter_data for the ambisonic recordings.
    "ambisonics_path": _SAMPLE_DIR + "/ambisonics_sample",
    # Directory listed by filter_data for the noise recordings.
    "noise_path": _SAMPLE_DIR + "/noise_ambisonics_sample",
    # Directory where process_batches writes its .pt / .csv outputs.
    "output_path": _SAMPLE_DIR + "/preprocessed_samples",
    # Prefer the GPU when torch can see one.
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # Number of files handled per saved batch in process_batches.
    "batch_size_pre": 5,
    # Sample rate handed to librosa.load.
    "sr": 16000,

    # True -> preprocess uses time_based_encoding, False -> rate_based_encoding.
    "time_based_encoding": True,
    # Passed as `num_steps` to the snntorch spike generators.
    "num_steps": 10,
    # Multiplier applied to the normalized signal in rate_based_encoding.
    "max_rate": 10,
    # True -> preprocess adds the noise recording to the ambisonic recording.
    "noise": True,
}

# Filter Data

In [None]:
def filter_data(metadata_path=config["metadata_path"], ambisonics_path=config["ambisonics_path"], noise_path=config["noise_path"]):
  """Load the metadata table and list the audio files available on disk.

  ``os.listdir`` returns entries in arbitrary order, while ``process_batches``
  pairs ambisonic files, noise files and metadata rows purely by position.
  To keep that pairing stable, ambisonic files and metadata rows are returned
  in ascending sample-id order and noise files in ascending file-name order.

  Args:
    metadata_path: Parquet file containing a ``sample_id`` column.
    ambisonics_path: Directory of ambisonic recordings; the part of each file
      name before the first "." is parsed as the integer sample id.
    noise_path: Directory of noise recordings.

  Returns:
    Tuple ``(filtered_metadata, ambisonic_files, noise_files)``: the metadata
    rows whose ``sample_id`` has an ambisonic file on disk, and the full paths
    of the ambisonic and noise files.
  """
  def _sample_id(filename):
    # "000123.flac" -> 123; a stem that is empty or all zeros maps to 0.
    return int(filename.split(".")[0].lstrip("0") or 0)

  def _list_files(directory):
    # Regular files only; sub-directories are ignored.
    return [f for f in os.listdir(directory) if os.path.isfile(os.path.join(directory, f))]

  # Load metadata
  metadata = pd.read_parquet(metadata_path, engine="pyarrow")

  # Deterministic ordering (os.listdir gives none).
  ambisonic_names = sorted(_list_files(ambisonics_path), key=_sample_id)
  # NOTE(review): assumes noise files use the same zero-padded names as the
  # ambisonic files, so a plain name sort lines them up - confirm.
  noise_names = sorted(_list_files(noise_path))

  # Keep only metadata rows that have audio, in the same order as the files.
  sample_ids = [_sample_id(f) for f in ambisonic_names]
  filtered_metadata = metadata[metadata["sample_id"].isin(sample_ids)].sort_values("sample_id")

  # Create full file paths
  ambisonic_files = [os.path.join(ambisonics_path, f) for f in ambisonic_names]
  noise_files = [os.path.join(noise_path, f) for f in noise_names]

  return filtered_metadata, ambisonic_files, noise_files


# Preprocess Functions

## Cochlear Filter

In [None]:
# def cochlear_filter(audio_data, sr):
#   num_channels = audio_data.shape[0]
#   processed_channels = []

#   for channel in range(num_channels):
#     sound = Sound(audio_data[channel], samplerate=sr*Hz)
#     cf = erbspace(20*Hz, 20*kHz, 32)
#     gammatone = Gammatone(sound, cf)
#     filtered_sound = gammatone.process()
#     filtered_data = filtered_sound.T
#     processed_channels.append(filtered_data)

#   combined_processed_data = np.stack(processed_channels, axis=0)
#   return combined_processed_data


## Normalize

In [None]:
def normalize(audio_data, device=config["device"]):
  """Min-max scale a tensor to the range [0, 1].

  Args:
    audio_data: Tensor to scale.
    device: Device the tensor is moved to before scaling.

  Returns:
    ``(x - min) / (max - min)`` computed over the whole tensor. When every
    element is equal (for example pure silence) the range is zero, so a
    tensor of zeros is returned instead of the NaNs a 0/0 division would give.
  """
  audio_data = audio_data.to(device)
  lowest = audio_data.min()
  span = audio_data.max() - lowest
  if span == 0:
    # Same dtype the division below would have produced.
    return torch.zeros_like(audio_data, dtype=torch.result_type(audio_data, 1.0))
  return (audio_data - lowest) / span


## Rate Based Encoding

In [None]:
def rate_based_encoding(audio_data, max_rate=config["max_rate"], num_steps=config["num_steps"], device=config["device"]):
  """Convert an audio array to a spike train with ``spikegen.rate``.

  Args:
    audio_data: NumPy array of audio samples.
    max_rate: Factor the normalized signal is multiplied by before encoding.
    num_steps: Number of time steps passed to ``spikegen.rate``.
    device: Device on which the tensor is created and normalized.

  Returns:
    The tensor returned by ``spikegen.rate``.

  Raises:
    ValueError: If ``audio_data`` is None.
  """
  if audio_data is None:
    raise ValueError("Input data is None.")

  audio_tensor = torch.from_numpy(audio_data).float().to(device)

  # Forward `device`; otherwise normalize() moves the tensor back to
  # config["device"] and the caller's choice is silently ignored.
  normalized_data = normalize(audio_tensor, device=device)

  # NOTE(review): spikegen.rate treats its input as per-step spike
  # probabilities; with max_rate > 1 most values presumably saturate - confirm
  # this scaling is intended.
  spike_rates = normalized_data * max_rate

  spike_train = spikegen.rate(spike_rates, num_steps=num_steps)

  return spike_train


## Time Based Encoding

In [None]:
def time_based_encoding(audio_data, num_steps=config["num_steps"], device=config["device"]):
  """Convert an audio array to a spike train with ``spikegen.latency``.

  The signal is min-max normalized, thresholded at 0.5 into a 0/1 tensor and
  that binary tensor is handed to ``spikegen.latency``.

  Args:
    audio_data: NumPy array of audio samples.
    num_steps: Number of time steps passed to ``spikegen.latency``.
    device: Device on which the tensor is created and normalized.

  Returns:
    The tensor returned by ``spikegen.latency``.

  Raises:
    ValueError: If ``audio_data`` is None.
  """
  if audio_data is None:
    raise ValueError("Input data is None.")

  audio_tensor = torch.from_numpy(audio_data).float().to(device)

  # Forward `device`; otherwise normalize() moves the tensor back to
  # config["device"] and the caller's choice is silently ignored.
  normalized_data = normalize(audio_tensor, device=device)

  # Binary mask: 1 where the normalized sample exceeds 0.5, else 0.
  spike_times = torch.where(normalized_data > 0.5, 1, 0)

  spike_train = spikegen.latency(spike_times, num_steps=num_steps, bypass=True)

  print(f"spike_train.shape: {spike_train.shape}")

  return spike_train


## Preprocess Function

In [None]:
def preprocess(ambisonic_file, noise_file, duration, sr=config["sr"]):
    """Load one recording, optionally add noise, and encode it as spikes.

    Args:
        ambisonic_file: Path of the ambisonic recording.
        noise_file: Path of the matching noise recording; a falsy value
            skips the noise mix.
        duration: Target length in seconds. Audio is loaded for at most this
            long and then padded/trimmed to exactly ``round(duration * sr)``
            samples.
        sr: Sample rate passed to ``librosa.load``.

    Returns:
        The spike train from ``time_based_encoding`` when
        ``config["time_based_encoding"]`` is true, otherwise from
        ``rate_based_encoding``.
    """
    length = int(np.round(duration * sr))

    # Load and Pad Ambisonic File. axis=-1 is the sample axis for both
    # multi-channel (channels, samples) and mono (samples,) arrays.
    audio, _ = librosa.load(ambisonic_file, sr=sr, mono=False, duration=duration)
    combined_audio = librosa.util.fix_length(audio, size=length, axis=-1)

    # Combine Noise (Optional)
    if config["noise"] and noise_file:
        noise_audio, _ = librosa.load(noise_file, sr=sr, mono=False, duration=duration)
        padded_noise = librosa.util.fix_length(noise_audio, size=length, axis=-1)
        combined_audio = combined_audio + padded_noise

    # Cochlear filtering is currently disabled, so both encoders receive the
    # (optionally noisy) mix directly. The rate branch used to reference an
    # undefined `processed_audio` and raised NameError.
    if config["time_based_encoding"]:
        return time_based_encoding(combined_audio)
    return rate_based_encoding(combined_audio)


## Append to File

In [None]:
def append_to_file(data, filepath):
  """Save `data` to `filepath`, extending whatever tensor is already stored.

  If the file is missing, `data` is written as-is. Otherwise the stored
  tensor is loaded, `data` is concatenated after it along dim 0, and the
  result overwrites the file.
  """
  if not os.path.exists(filepath):
    torch.save(data, filepath)
    return

  stored = torch.load(filepath)
  torch.save(torch.cat((stored, data), dim=0), filepath)

## Process & Save Batches

In [None]:
import gc  # Import Python's garbage collection module

def process_batches(metadata, ambisonic_files, noise_files, duration, batch_size=config["batch_size_pre"], output_path=config["output_path"], sr=config["sr"]):
    """Preprocess the recordings batch by batch and write each batch to disk.

    For every slice of ``batch_size`` files, two files are written to
    ``output_path``: ``processed_batch_{i}.pt`` (the stacked spike trains) and
    ``labels_batch_{i}.csv`` (sample id, split, azimuth, elevation), where
    ``i`` is the index of the first file in the slice.

    Files, noise files and metadata rows are paired by position; ``zip`` stops
    at the shortest of the three, so surplus entries are silently dropped.

    Args:
        metadata: DataFrame with ``sample_id``, ``split``, ``speech/azimuth``
            and ``speech/elevation`` columns.
        ambisonic_files: Paths of the ambisonic recordings.
        noise_files: Paths of the noise recordings.
        duration: Target clip length in seconds, forwarded to ``preprocess``.
        batch_size: Number of files per saved batch.
        output_path: Directory receiving the output files; created if missing.
        sr: Sample rate forwarded to ``preprocess``.
    """
    # torch.save / DataFrame.to_csv do not create missing directories.
    os.makedirs(output_path, exist_ok=True)

    # Iterate over batches
    for i in range(0, len(ambisonic_files), batch_size):
        print(f"Processing batch {i} to {i+batch_size}")
        batch_ambisonic_files = ambisonic_files[i:i+batch_size]
        batch_noise_files = noise_files[i:i+batch_size]
        batch_metadata = metadata.iloc[i:i+batch_size]

        processed_data = []
        labels = []

        # Process each file in the batch
        for ambisonic_file, noise_file, meta_row in zip(batch_ambisonic_files, batch_noise_files, batch_metadata.itertuples()):
            # Forward `sr` so a caller-supplied sample rate is honoured.
            spike_trains = preprocess(ambisonic_file, noise_file, duration, sr=sr)
            processed_data.append(spike_trains)

            # Column names containing "/" are not valid attribute names on the
            # itertuples row, hence the .at lookups by index label.
            labels.append({
                'sample_id': meta_row.sample_id,
                'split': meta_row.split,
                'azimuth': batch_metadata.at[meta_row.Index, 'speech/azimuth'],
                'elevation': batch_metadata.at[meta_row.Index, 'speech/elevation']
            })

        # torch.stack raises on an empty list (e.g. fewer metadata rows or
        # noise files than ambisonic files), so skip such batches.
        if not processed_data:
            print(f"Batch {i} to {i+batch_size} is empty; skipping.")
            continue

        # Save processed data and labels
        torch.save(torch.stack(processed_data), os.path.join(output_path, f'processed_batch_{i}.pt'))
        pd.DataFrame(labels).to_csv(os.path.join(output_path, f'labels_batch_{i}.csv'), index=False)

        # Clear memory
        del processed_data, labels, batch_ambisonic_files, batch_noise_files, batch_metadata
        gc.collect()  # Trigger garbage collection

        print(f"Batch {i} to {i+batch_size} processed and saved.")

    print("All batches processed and saved.")


In [None]:
# Collect the metadata rows and audio files that are present on disk.
filtered_metadata, ambisonic_files, noise_files = filter_data()
# Common clip length: mean duration plus one standard deviation. preprocess
# loads at most this much audio and pads shorter clips up to it.
duration = filtered_metadata["audio_info/duration"].mean() + filtered_metadata["audio_info/duration"].std()
# Pass the filtered metadata (the previous name `filtered_data` was undefined).
process_batches(filtered_metadata, ambisonic_files, noise_files, duration)