In [45]:
import google.colab.drive as drive  # Colab-only helper module exposing drive.mount()

# Mount Google Drive so the data paths under /content/drive used by `config` resolve.
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [46]:
! pip install snntorch --quiet

In [47]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import Resample

import snntorch as snn
from snntorch import spikegen

from sklearn.model_selection import train_test_split

import os

import pandas as pd

In [48]:
# Root folder of the project data on the mounted Google Drive; every data path below derives from it.
DATA_ROOT = "/content/drive/My Drive/Colab Notebooks/Masters Project/data"

config = {
    # Filter Data
    "metadata_path": f"{DATA_ROOT}/metadata.parquet",
    "speech_path": f"{DATA_ROOT}/ambisonics_lite",
    "noise_path": f"{DATA_ROOT}/noise_ambisonics_lite",
    "is_lite_version": True,  # keep only metadata rows whose "lite_version" column is True

    # Load Mix and STFT Audio
    "sr": 16_000,         # target sample rate (Hz); audio at other rates is resampled to this
    "duration": 0.5,      # seconds kept per clip (longer clips trimmed, shorter zero-padded)
    "noise_ratio": None,  # Optional gain on the peak-normalized noise (Large #s make noise louder, Small #s make noise quieter)
    "n_fft": 512,
    "hop_length": 128,
    "win_length": 512,

    # Split Data
    "val_size": 0.15,     # fraction of the "train" split held out for validation

    # Dataset
    "num_steps": 10,      # number of rate-coding time steps per sample

    # SNN
    "beta": 0.65,         # decay factor of the snntorch Leaky (LIF) neurons
}

In [49]:
def filter_data(
        metadata_pth=config["metadata_path"],
        speech_pth=config["speech_path"],
        noise_pth=config["noise_path"],
        lite_version=config["is_lite_version"]
):
    """Turn the metadata table into a list of per-sample records.

    Args:
        metadata_pth: Path of the metadata parquet file (read with the pyarrow engine).
        speech_pth: Directory holding the speech FLAC files, named ``{sample_id:06}.flac``.
        noise_pth: Directory holding the matching noise FLAC files; when falsy,
            every record gets ``noise_path=None``.
        lite_version: If truthy, keep only rows whose ``lite_version`` column equals True.

    Returns:
        list of dicts with keys sample_id, speech_path, noise_path, azimuth,
        elevation and split, in metadata row order.
    """
    table = pd.read_parquet(metadata_pth, engine="pyarrow")
    if lite_version:
        table = table[table["lite_version"].eq(True)]

    def flac_path(folder, sample_id):
        # Files are named by the sample id zero-padded to width 6.
        return os.path.join(folder, f"{sample_id:06}.flac")

    return [
        {
            "sample_id": entry["sample_id"],
            "speech_path": flac_path(speech_pth, entry["sample_id"]),
            "noise_path": flac_path(noise_pth, entry["sample_id"]) if noise_pth else None,
            "azimuth": entry["speech/azimuth"],
            "elevation": entry["speech/elevation"],
            "split": entry["split"],
        }
        for _, entry in table.iterrows()
    ]

In [50]:
def split_data(data, val_size=config["val_size"], random_state=42):
    """Partition sample records into train / validation / test lists.

    Records whose ``"split"`` is "test" form the test set unchanged. A share
    ``val_size`` of the "train" records is held out, reproducibly via
    ``random_state``, as the validation set. Records with any other split
    value are not returned.

    Args:
        data: list of dicts, each with a "split" key (as built by filter_data).
        val_size: passed to sklearn's ``train_test_split`` as ``test_size``
            (float fraction or int count; None defers to sklearn's default).
            0 keeps every training record and returns an empty validation list.
        random_state: seed for the train/validation shuffle (default 42, as before).

    Returns:
        (train_data, val_data, test_data) lists of records.
    """
    train_data = [sample for sample in data if sample["split"] == "train"]
    test_data = [sample for sample in data if sample["split"] == "test"]

    # train_test_split rejects test_size=0, so "no validation set" is handled explicitly.
    if val_size == 0:
        return train_data, [], test_data

    train_data, val_data = train_test_split(train_data, test_size=val_size, random_state=random_state)

    return train_data, val_data, test_data


In [51]:
# Resample transforms keyed by (orig_freq, new_freq). Constructing a Resample
# precomputes its filter kernel, so each rate pair is built once and reused
# instead of being rebuilt for every file. (Each DataLoader worker process
# holds its own copy of this dict.)
_RESAMPLERS = {}


def _peak_normalize(audio):
    """Scale ``audio`` so its largest absolute sample is 1; all-zero input is returned unchanged."""
    max_val = audio.abs().max()
    if max_val > 0:  # Avoid division by zero
        audio = audio / max_val
    return audio


def load_mix_and_stft_foa_audio(
        speech_pth,
        noise_pth=None,
        sr=config["sr"],
        duration=config["duration"],
        noise_ratio=config["noise_ratio"],
        n_fft=config["n_fft"],
        hop_length=config["hop_length"],
        win_length=config["win_length"]
):
    """Load speech, optionally mix in noise, and return the multichannel STFT.

    Each file is loaded with torchaudio and resampled to ``sr`` when its rate
    differs. It is peak-normalized over all channels jointly, then trimmed or
    zero-padded to ``round(duration * sr)`` samples.

    Args:
        speech_pth: Path of the speech audio file.
        noise_pth: Optional path of a noise file with the same channel count.
        sr: Target sample rate in Hz.
        duration: Length to keep, in seconds.
        noise_ratio: Optional gain applied to the peak-normalized noise before
            mixing. When given, the mixture is re-normalized to a peak of 1.
        n_fft, hop_length, win_length: STFT parameters (Hann window).

    Returns:
        Complex tensor from ``torch.stft`` on the (channels, samples) signal,
        with onesided=True and center=True. Its shape is
        (channels, n_fft // 2 + 1, samples // hop_length + 1).
    """
    def preprocess_audio(audio_pth, sr, target_len):
        """Load one file, resample to ``sr``, peak-normalize, then trim/zero-pad to ``target_len`` samples."""
        audio, audio_sr = torchaudio.load(audio_pth)
        if audio_sr != sr:
            resampler = _RESAMPLERS.get((audio_sr, sr))
            if resampler is None:
                resampler = _RESAMPLERS[(audio_sr, sr)] = Resample(orig_freq=audio_sr, new_freq=sr)
            audio = resampler(audio)
        audio = _peak_normalize(audio)
        n_samples = audio.size(1)
        if n_samples > target_len:
            audio = audio[:, :target_len]
        elif n_samples < target_len:
            audio = torch.nn.functional.pad(audio, (0, target_len - n_samples), "constant", 0)
        return audio

    # round() rather than int(): int() truncates products such as 0.57 * 100 = 56.99... one sample short.
    target_len = round(duration * sr)
    mixed_audio = preprocess_audio(speech_pth, sr, target_len)

    if noise_pth is not None:
        noise_audio = preprocess_audio(noise_pth, sr, target_len)
        if noise_ratio is not None:
            # Scale the peak-normalized noise relative to the speech, mix, and
            # re-normalize the mixture back to a peak of 1.
            mixed_audio = _peak_normalize(mixed_audio + noise_audio * noise_ratio)
        else:
            # NOTE(review): without noise_ratio the mix is not re-normalized, so its
            # peak can reach 2 (both inputs are peak-normalized). Confirm this is intended.
            mixed_audio = mixed_audio + noise_audio

    # The window uses the audio's dtype/device so the STFT also works for GPU or float64 input.
    window = torch.hann_window(win_length, dtype=mixed_audio.dtype, device=mixed_audio.device)

    return torch.stft(mixed_audio,
                      n_fft=n_fft,
                      hop_length=hop_length,
                      win_length=win_length,
                      window=window,
                      center=True,
                      normalized=False,
                      onesided=True,
                      return_complex=True)


In [52]:
def compute_active_reactive_intensities(stft, rho=1.21, c=343, sum_components=True):
    """
    Compute active and reactive intensity from the STFT of 4-channel FOA audio.

    Channel 0 is used as pressure p (W). Channels 1-3 are used as the X, Y, Z
    velocity components, scaled by -1 / (rho * c * sqrt(3)).
    NOTE(review): the (W, X, Y, Z) ordering and the sqrt(3) normalization are
    assumed by this code; confirm they match the dataset's ambisonic format.

    Args:
    - stft: Complex STFT with shape [4, Frequency Bins, Time Frames]. Any
      channels beyond index 3 are ignored.
    - rho: Mean density of air (in kg/m^3).
    - c: Speed of sound in air (in m/s).
    - sum_components: If True (default), return the sum of the x, y, z
      components, shape [F, T]. If False, return the per-axis vectors stacked
      as [3, F, T], which keeps the directional information the sum discards.

    Returns:
    - Ia: Active intensity, Re{conj(p) * v}.
    - Ir: Reactive intensity, Im{conj(p) * v}.

    Raises:
    - ValueError: if stft has fewer than 4 channels.
    """
    if stft.shape[0] < 4:
        raise ValueError(f"expected at least 4 FOA channels, got {stft.shape[0]}")

    normalization_factor = -1.0 / (rho * c * 3 ** 0.5)

    p_star = torch.conj(stft[0])                       # conj(pressure), [F, T]
    velocity = stft[1:4] * normalization_factor        # X, Y, Z velocity, [3, F, T]

    # Cross-spectrum computed once; its real/imag parts are the active/reactive intensity.
    cross = p_star * velocity                          # broadcasts to [3, F, T]
    Ia = cross.real
    Ir = cross.imag

    if sum_components:
        Ia = Ia.sum(dim=0)  # Total active intensity
        Ir = Ir.sum(dim=0)  # Total reactive intensity

    return Ia, Ir


In [53]:
class AmbisonicDataset(Dataset):
    """Map-style dataset yielding rate-coded intensity spike trains and an (azimuth, elevation) label."""

    def __init__(self, data, config):
        """
        Args:
            data (list of dicts): Each dictionary contains paths and labels for a sample
                ("speech_path", optional "noise_path", "azimuth", "elevation").
            config (dict): Configuration dictionary. Required keys: sr, duration,
                noise_ratio, n_fft, hop_length, win_length, num_steps.
                Optional keys: rho (default 1.21), c (default 343),
                normalize_intensities (default False).
        """
        self.data = data
        self.config = config

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _normalize_for_rate_coding(x):
        """Min-max scale ``x`` into [0, 1]; a constant map becomes all zeros."""
        x_min, x_max = x.min(), x.max()
        span = x_max - x_min
        if span > 0:
            return (x - x_min) / span
        return torch.zeros_like(x)

    def __getitem__(self, idx):
        """Return (spikes_Ia, spikes_Ir, label) for sample ``idx``; label is float [azimuth, elevation]."""
        sample = self.data[idx]
        cfg = self.config

        # Load and process ambisonic audio (a missing or None "noise_path" means speech only)
        stft_audio = load_mix_and_stft_foa_audio(
            sample['speech_path'],
            noise_pth=sample.get('noise_path'),
            sr=cfg['sr'],
            duration=cfg['duration'],
            noise_ratio=cfg['noise_ratio'],
            n_fft=cfg['n_fft'],
            hop_length=cfg['hop_length'],
            win_length=cfg['win_length']
        )

        # Compute active and reactive intensities
        Ia, Ir = compute_active_reactive_intensities(
            stft_audio, rho=cfg.get('rho', 1.21), c=cfg.get('c', 343)
        )

        # snntorch's rate coder treats its input as per-element firing probabilities
        # (documented range [0, 1]). Raw intensities are signed and unscaled, so
        # negative bins cannot fire. Opt in to per-map min-max scaling via
        # config["normalize_intensities"].
        if cfg.get('normalize_intensities', False):
            Ia = self._normalize_for_rate_coding(Ia)
            Ir = self._normalize_for_rate_coding(Ir)

        # Generate Spike Trains (num_steps is prepended as the leading dimension)
        spikes_Ia = spikegen.rate(Ia, num_steps=cfg["num_steps"])
        spikes_Ir = spikegen.rate(Ir, num_steps=cfg["num_steps"])

        label = torch.tensor([sample['azimuth'], sample['elevation']], dtype=torch.float)

        return spikes_Ia, spikes_Ir, label


In [54]:
data = filter_data()
train_data, val_data, test_data = split_data(data)

train_dataset = AmbisonicDataset(data=train_data, config=config)
val_dataset = AmbisonicDataset(data=val_data, config=config)
test_dataset = AmbisonicDataset(data=test_data, config=config)

# Loader settings; overridable via config, defaults are the previous hard-coded values.
batch_size = config.get("batch_size", 32)
num_workers = config.get("num_workers", 4)
loader_seed = config.get("seed", 42)

# Seeded generators make the shuffle order and the worker base seeds reproducible.
# With num_workers > 0, that also fixes the random spike trains drawn in the workers.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                          generator=torch.Generator().manual_seed(loader_seed))
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                        generator=torch.Generator().manual_seed(loader_seed))
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                         generator=torch.Generator().manual_seed(loader_seed))


In [55]:
# for spikes_active, spikes_reactive, labels in train_loader:
#     print(spikes_active.shape)
#     print(spikes_reactive.shape)
#     print(labels.shape)
#     break

torch.Size([32, 10, 257, 63])
torch.Size([32, 10, 257, 63])
torch.Size([32, 2])


In [59]:
# for spikes_Ia, spikes_Ir, _ in train_loader:
#     spikes_per_step_Ia = spikes_Ia.sum(dim=[2, 3])
#     spikes_per_step_Ir = spikes_Ir.sum(dim=[2, 3])
#     print(f"Spikes in Ia: {spikes_per_step_Ia}")
#     print(f"Spikes in Ir: {spikes_per_step_Ir}")
#     break


Spikes in Ia: tensor([[155., 159., 155., 147., 148., 144., 156., 135., 165., 151.],
        [ 89.,  95.,  87.,  92.,  96., 100.,  85.,  90.,  78., 115.],
        [ 44.,  48.,  51.,  54.,  50.,  44.,  46.,  53.,  50.,  47.],
        [114., 109., 135., 126., 112., 131., 130., 125., 135., 118.],
        [115., 122., 134., 131., 129., 140., 123., 140., 130., 136.],
        [123., 116., 140., 136., 122., 117., 131., 120., 129., 139.],
        [ 75.,  58.,  80.,  71.,  77.,  57.,  71.,  65.,  68.,  87.],
        [109., 115., 100., 109., 110., 114.,  88., 112.,  96., 117.],
        [ 10.,  17.,  13.,  14.,  14.,  12.,  10.,  12.,  15.,  11.],
        [150., 179., 162., 148., 172., 172., 146., 161., 161., 154.],
        [109., 133., 130., 112., 111., 109., 101., 114., 114., 123.],
        [ 52.,  51.,  49.,  47.,  59.,  53.,  52.,  52.,  51.,  58.],
        [ 10.,   9.,   7.,  11.,   8.,  12.,  11.,  10.,   7.,   8.],
        [100.,  95.,  99., 101.,  98.,  92.,  90.,  99., 101.,  98.],
      

In [None]:
class SNN(nn.Module):
    """Two-branch convolutional spiking network that regresses two outputs (azimuth, elevation).

    Each branch (active / reactive intensity) stacks Conv2d -> BatchNorm2d ->
    MaxPool2d(2) stages. Every stage feeds its own snntorch Leaky (LIF) layer.
    The spike-train inputs are processed one time step at a time. At each step
    the last-layer spikes of both branches are concatenated, flattened and
    passed through an MLP head, and the head outputs are averaged over steps.

    Required config keys: num_steps, beta, n_fft, sr, duration, hop_length.
    Optional: branch_channels (default (2, 4, 8, 16)).
    """

    def __init__(self, config):
        super(SNN, self).__init__()
        self.num_steps = config["num_steps"]
        self.beta = config["beta"]
        channels = tuple(config.get("branch_channels", (2, 4, 8, 16)))

        # Conv stages and LIF layers are kept in parallel ModuleLists. Each LIF
        # layer has its own membrane state, which forward() threads through time.
        self.active_branch = self._make_conv_stages(channels)
        self.reactive_branch = self._make_conv_stages(channels)
        self.active_lifs = nn.ModuleList(snn.Leaky(beta=self.beta) for _ in channels)
        self.reactive_lifs = nn.ModuleList(snn.Leaky(beta=self.beta) for _ in channels)

        # Input spatial size: a onesided STFT has n_fft // 2 + 1 bins, and
        # center=True framing gives samples // hop_length + 1 frames.
        # NOTE(review): must match the STFT settings of the data pipeline.
        n_freq = config["n_fft"] // 2 + 1
        n_frames = int(config["duration"] * config["sr"]) // config["hop_length"] + 1
        # Each MaxPool2d(2) floor-halves both spatial dims.
        shrink = 2 ** len(channels)
        out_h, out_w = n_freq // shrink, n_frames // shrink
        if out_h == 0 or out_w == 0:
            raise ValueError(
                f"input of {n_freq}x{n_frames} is too small for {len(channels)} pooling stages"
            )
        in_features = 2 * channels[-1] * out_h * out_w  # both branches concatenated

        # MLP for concatenated features
        self.fc = nn.Sequential(
            nn.Linear(in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    @staticmethod
    def _make_conv_stages(channels):
        """One Conv2d -> BatchNorm2d -> MaxPool2d(2) stage per entry of ``channels``; first input has 1 channel."""
        in_channels = (1,) + tuple(channels[:-1])
        return nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.MaxPool2d(2),
            )
            for c_in, c_out in zip(in_channels, channels)
        )

    @staticmethod
    def _step_branch(x, stages, lifs, mems):
        """Advance one branch by one time step; membrane potentials in ``mems`` are updated in place."""
        for i, (stage, lif) in enumerate(zip(stages, lifs)):
            x, mems[i] = lif(stage(x), mems[i])
        return x

    def forward(self, active_input, reactive_input):
        """
        Args:
            active_input, reactive_input: spike trains, assumed shaped
                (batch, time_steps, freq_bins, frames) — the layout a DataLoader
                over AmbisonicDataset yields. TODO confirm if the pipeline changes.

        Returns:
            Tensor (batch, 2): MLP output averaged over the time steps.
        """
        # Fresh membrane state for every call: each batch is an independent sequence.
        active_mems = [lif.init_leaky() for lif in self.active_lifs]
        reactive_mems = [lif.init_leaky() for lif in self.reactive_lifs]

        step_outputs = []
        for step in range(active_input.size(1)):
            # Conv2d expects (batch, channels, H, W): add a singleton channel axis.
            active_out = self._step_branch(
                active_input[:, step].unsqueeze(1), self.active_branch, self.active_lifs, active_mems
            )
            reactive_out = self._step_branch(
                reactive_input[:, step].unsqueeze(1), self.reactive_branch, self.reactive_lifs, reactive_mems
            )

            # Concatenate, flatten, and pass through MLP
            combined = torch.cat((active_out, reactive_out), dim=1)
            combined = combined.flatten(start_dim=1)
            step_outputs.append(self.fc(combined))

        return torch.stack(step_outputs, dim=0).mean(dim=0)


In [56]:
def calc_median_absolute_error(
        t_azimuth,
        t_elevation,
        p_azimuth,
        p_elevation
):
    """
    Median great-circle angular error between true and predicted directions.
    From Section 4: https://www.isca-archive.org/interspeech_2023/sarabia23_interspeech.html

    Parameters:
    t_azimuth (Tensor): Azimuth angles of the true directions, in radians.
    t_elevation (Tensor): Elevation angles of the true directions, in radians.
    p_azimuth (Tensor): Azimuth angles of the predicted directions, in radians.
    p_elevation (Tensor): Elevation angles of the predicted directions, in radians.
    All four must be floating-point tensors of broadcast-compatible shape.

    Returns:
    Tensor (0-dim): The median angular error in degrees.

    NOTE(review): confirm that the dataset's azimuth/elevation labels are in radians.
    """
    # Spherical law of cosines, with elevation as latitude and azimuth as longitude:
    # cos(d) = sin(el_t) sin(el_p) + cos(el_t) cos(el_p) cos(az_t - az_p)
    cos_angle = (
        torch.sin(t_elevation) * torch.sin(p_elevation)
        + torch.cos(t_elevation) * torch.cos(p_elevation) * torch.cos(t_azimuth - p_azimuth)
    )

    # Rounding can push the cosine slightly outside [-1, 1], where acos returns NaN.
    error_rad = torch.acos(cos_angle.clamp(-1.0, 1.0))

    # Convert radians to degrees
    error_deg = torch.rad2deg(error_rad)

    # torch.median returns the lower middle value for even counts; quantile(0.5)
    # averages the two middle values (the conventional median).
    return torch.quantile(error_deg.flatten(), 0.5)




In [57]:
# Transform and normalize segments of spatial audio encoded as 4-channel first-order ambisonics into active and reactive components.
    # The segment duration varies by task; for 3D source localization it is 0.5 s.

# The active and reactive components are fed through two independent branches of four 3D-convolutional layers each, with 2, 4, 8 and 16 channels respectively,
    # and with max-pooling, batch normalization and exponential linear units between the convolutional layers.
    # The outputs of both branches are flattened, concatenated and fed to a 3-layer multilayer perceptron.

# Trained for 20 Epochs
