<a href="https://colab.research.google.com/github/mercadoerik1031/snn-sound-localization/blob/new_approach/snn_sound_localization_active_reactive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [39]:
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [40]:
! pip install snntorch --quiet

In [41]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import Resample

import snntorch as snn
from snntorch import spikegen

from sklearn.model_selection import train_test_split

import os

import pandas as pd

In [42]:
config = {
    # Filter Data
    "metadata_path": "/content/drive/My Drive/Colab Notebooks/Masters Project/data/metadata.parquet",
    "speech_path": "/content/drive/My Drive/Colab Notebooks/Masters Project/data/ambisonics_lite",
    "noise_path": "/content/drive/My Drive/Colab Notebooks/Masters Project/data/noise_ambisonics_lite",
    "is_lite_version": True,

    # Load Mix and STFT Audio
    "sr": 16_000,
    "duration": 0.5,
    "noise_ratio": None, # Optional (Large #s make noise louder, Small #s make noise quieter)
    "n_fft": 512,
    "hop_length": 128,
    "win_length": 512,

    # Split Data
    "val_size": 0.15,

    # Dataset
    "num_steps": 10,
    "batch_size": 32,

    # SNN
    "beta": 0.65,
        # Active Branch
    "a_thresh1": 10,
    "a_thresh2": 10,
    "a_thresh3": 10,
    "a_thresh4": 10,
        # Reactive Branch
    "re_thresh1": 10,
    "re_thresh2": 10,
    "re_thresh3": 10,
    "re_thresh4": 10,

    # Training
    "num_epochs": 20,
    "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    "lr": 0.0003,
}

In [43]:
def filter_data(
        metadata_pth=config["metadata_path"],
        speech_pth=config["speech_path"],
        noise_pth=config["noise_path"],
        lite_version=config["is_lite_version"]
):

    metadata = pd.read_parquet(metadata_pth, engine="pyarrow")
    if lite_version:
        metadata = metadata[metadata["lite_version"] == True]

    data = []
    for _, row in metadata.iterrows():
        sample = {
            "sample_id": row["sample_id"],
            "speech_path": os.path.join(speech_pth, f"{row['sample_id']:06}.flac"),
            "noise_path": os.path.join(noise_pth, f"{row['sample_id']:06}.flac") if noise_pth else None,
            "azimuth": row["speech/azimuth"],
            "elevation": row["speech/elevation"],
            "split": row["split"]
        }
        data.append(sample)

    return data

In [44]:
def split_data(data, val_size=config["val_size"]):
    train_data = [sample for sample in data if sample["split"] == "train"]
    test_data = [sample for sample in data if sample["split"] == "test"]

    train_data, val_data = train_test_split(train_data, test_size=val_size, random_state=42)

    return train_data, val_data, test_data


In [45]:
def load_mix_and_stft_foa_audio(
        speech_pth,
        noise_pth=None,
        sr=config["sr"],
        duration=config["duration"],
        noise_ratio=config["noise_ratio"],
        n_fft=config["n_fft"],
        hop_length=config["hop_length"],
        win_length=config["win_length"]
):
    # Function to load, resample, normalize, and trim/pad audio
    def preprocess_audio(audio_pth, sr, target_len):
        audio, audio_sr = torchaudio.load(audio_pth)
        if audio_sr != sr:
            resample_transform = torchaudio.transforms.Resample(orig_freq=audio_sr, new_freq=sr)
            audio = resample_transform(audio)
        max_val = audio.abs().max()
        if max_val > 0:  # Avoid division by zero
            audio = audio / max_val
        if audio.size(1) > target_len:
            audio = audio[:, :target_len]
        elif audio.size(1) < target_len:
            padding_size = target_len - audio.size(1)
            audio = torch.nn.functional.pad(audio, (0, padding_size), "constant", 0)
        return audio

    target_len = int(duration * sr)
    speech_audio = preprocess_audio(speech_pth, sr, target_len)
    should_renormalize = False

    if noise_pth is not None:
        noise_audio = preprocess_audio(noise_pth, sr, target_len)

        if noise_ratio is not None:
            # Adjust noise level relative to speech
            noise_audio = noise_audio * noise_ratio
            should_renormalize = True

        # Mix speech and noise
        mixed_audio = speech_audio + noise_audio
    else:
        mixed_audio = speech_audio

    if should_renormalize:
        # Re-normalize only if noise has been adjusted and mixed
        max_val = mixed_audio.abs().max()
        if max_val > 0:
            mixed_audio = mixed_audio / max_val

    # Compute the STFT of the mixed audio
    stft = torch.stft(mixed_audio,
                      n_fft=n_fft,
                      hop_length=hop_length,
                      win_length=win_length,
                      window=torch.hann_window(win_length),
                      center=True,
                      normalized=False,
                      onesided=True,
                      return_complex=True)

    return stft


In [46]:
def compute_active_reactive_intensities(stft, rho=1.21, c=343):
    """
    Compute active and reactive intensity vectors from STFT of 4-channel FOA audio.
    Args:
    - stft: STFT of the FOA audio with shape [4, Frequency Bins, Time Frames].
    - rho: Mean density of air (in kg/m^3).
    - c: Speed of sound in air (in m/s).

    Returns:
    - Ia: Active intensity vector.
    - Ir: Reactive intensity vector.
    """
    # Constants
    three = torch.tensor(3.0, dtype=torch.float, device=stft.device)
    normalization_factor = -1 / (rho * c * torch.sqrt(three))

    # Extract channels
    p = stft[0]  # Pressure (W channel)
    vx = stft[1] * normalization_factor  # Velocity X
    vy = stft[2] * normalization_factor  # Velocity Y
    vz = stft[3] * normalization_factor  # Velocity Z

    # Compute complex conjugate of pressure
    p_star = torch.conj(p)

    # Calculate active and reactive intensity vectors
    Ia_x = torch.real(p_star * vx)
    Ia_y = torch.real(p_star * vy)
    Ia_z = torch.real(p_star * vz)

    Ir_x = torch.imag(p_star * vx)
    Ir_y = torch.imag(p_star * vy)
    Ir_z = torch.imag(p_star * vz)

    # Summing up the components to get total active and reactive intensities
    # Ia = Ia_x + Ia_y + Ia_z  # Total active intensity
    # Ir = Ir_x + Ir_y + Ir_z  # Total reactive intensity

    # Create stack for each channel [3, num_samples, num_frames]
    Ia = torch.stack((Ia_x, Ia_y, Ia_z), dim=0)
    Ir = torch.stack((Ir_x, Ir_y, Ir_z), dim=0)

    return Ia, Ir


In [47]:
class AmbisonicDataset(Dataset):
    def __init__(self, data, config):
        """
        Args:
            data (list of dicts): Each dictionary contains paths and labels for a sample.
            config (dict): Configuration dictionary including sample rate (sr), duration, etc.
        """
        self.data = data
        self.config = config

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Load and process ambisonic audio
        speech_path = sample['speech_path']
        noise_path = sample['noise_path'] if 'noise_path' in sample and sample['noise_path'] is not None else None

        stft_audio = load_mix_and_stft_foa_audio(
            speech_path,
            noise_pth=noise_path,
            sr=self.config['sr'],
            duration=self.config['duration'],
            noise_ratio=self.config['noise_ratio'],
            n_fft=self.config['n_fft'],
            hop_length=self.config['hop_length'],
            win_length=self.config['win_length']
        )

        # Compute active and reactive intensities
        Ia, Ir = compute_active_reactive_intensities(stft_audio, rho=1.21, c=343)

        # Generate Spike Trains
        spikes_Ia = spikegen.rate(Ia, num_steps=self.config["num_steps"])
        spikes_Ir = spikegen.rate(Ir, num_steps=self.config["num_steps"])

        azimuth = sample['azimuth']
        elevation = sample['elevation']
        label = torch.tensor([azimuth, elevation], dtype=torch.float)

        #print(f"stft_audio.shape: {stft_audio.shape}\nIa.shape:{Ia.shape}, Ir.shape: {Ir.shape}\nspikes_Ia.shape: {spikes_Ia.shape}, spikes_Ir.shape: {spikes_Ir.shape}\nlabel.shape: {label.shape}")

        return spikes_Ia, spikes_Ir, label


In [48]:
# for spikes_active, spikes_reactive, lables in train_loader:
#     print(spikes_active.shape)
#     print(spikes_reactive.shape)
#     print(lables.shape)
#     break

In [49]:
# for spikes_Ia, spikes_Ir, _ in train_loader:
#     spikes_per_step_Ia = spikes_Ia.sum(dim=[2, 3])
#     spikes_per_step_Ir = spikes_Ir.sum(dim=[2, 3])
#     print(f"Spikes in Ia: {spikes_per_step_Ia}")
#     print(f"Spikes in Ir: {spikes_per_step_Ir}")
#     break


In [50]:
class SNN(nn.Module):
    def __init__(self, config):
        super(SNN, self).__init__()
        self.num_steps = config["num_steps"]
        self.beta = config["beta"]

        # Active Thresholds
        self.a_thr1 = config["a_thresh1"]
        self.a_thr2 = config["a_thresh2"]
        self.a_thr3 = config["a_thresh3"]
        self.a_thr4 = config["a_thresh4"]

        # Reactive Thresholds
        self.re_thr1 = config["re_thresh1"]
        self.re_thr2 = config["re_thresh2"]
        self.re_thr3 = config["re_thresh3"]
        self.re_thr4 = config["re_thresh4"]

        # Define the active branch with LIF neurons
        self.active_branch = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3, padding=1),
            nn.BatchNorm2d(6),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr1, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.a_thr1),

            nn.Conv2d(6, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr2, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.a_thr2),

            nn.Conv2d(12, 24, kernel_size=3, padding=1),
            nn.BatchNorm2d(24),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr3, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.a_thr3),

            nn.Conv2d(24, 48, kernel_size=3, padding=1),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr4, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.a_thr4),
        )

        self.reactive_branch = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3, padding=1),
            nn.BatchNorm2d(6),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.re_thr1, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.re_thr1),

            nn.Conv2d(6, 12, kernel_size=3, padding=1),
            nn.BatchNorm2d(12),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.re_thr2, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.re_thr2),

            nn.Conv2d(12, 24, kernel_size=3, padding=1),
            nn.BatchNorm2d(24),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.re_thr3, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.re_thr3),

            nn.Conv2d(24, 48, kernel_size=3, padding=1),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.re_thr4, init_hidden=True),
            #snn.Leaky(beta=self.beta, threshold=self.re_thr4),
        )

        # MLP for concatenated features
        self.fc = nn.Sequential(
            nn.Linear(4608, 512),
            nn.ReLU(),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 2)
        )

    # def forward(self, active_input, reactive_input):
    #     active_in_permuted = active_input.permute(1, 0, 2, 3, 4) # [num_step, batch_size, channels, num_samples, num_frames]
    #     reactive_in_permuted = reactive_input.permute(1, 0, 2, 3, 4)

    #     active_outputs = []
    #     reactive_outputs = []

    #     for step in range(active_in_permuted.size(0)):
    #         active_step = self.active_branch(active_in_permuted[step])
    #         reactive_step = self.reactive_branch(reactive_in_permuted[step])

    #         active_outputs.append(active_step)
    #         reactive_outputs.append(reactive_step)

    #     # Stack the outputs across time steps to form tensors of shape [num_steps, batch_size, channels, height, width]
    #     active_stacked = torch.stack(active_outputs, dim=0)
    #     reactive_stacked = torch.stack(reactive_outputs, dim=0)

    #     # Aggregate across time steps, e.g., by taking the mean or sum
    #     active_agg = torch.mean(active_stacked, dim=0)
    #     reactive_agg = torch.mean(reactive_stacked, dim=0)

    #     # Flatten and concatenate the aggregated outputs for the final MLP
    #     active_flat = active_agg.view(active_agg.size(0), -1)
    #     reactive_flat = reactive_agg.view(reactive_agg.size(0), -1)
    #     combined = torch.cat((active_flat, reactive_flat), dim=1)

    #     output = self.fc(combined)

    #     return output

    def forward(self, active_input, reactive_input):
        active_in_permuted = active_input.permute(1, 0, 2, 3, 4)  # [num_step, batch_size, channels, height, width]
        reactive_in_permuted = reactive_input.permute(1, 0, 2, 3, 4)

        step_outputs = []

        for step in range(active_in_permuted.size(0)):
            active_step = self.active_branch(active_in_permuted[step])
            reactive_step = self.reactive_branch(reactive_in_permuted[step])

            active_step = active_step.view(active_step.size(0), -1) # Flatten
            reactive_step = reactive_step.view(reactive_step.size(0), -1) # Flatten

            combined = torch.cat((active_step, reactive_step), dim=1)

            out = self.fc(combined)

            step_outputs.append(out)

        step_outputs_tensor = torch.stack(step_outputs, dim=0)
        output = torch.mean(step_outputs_tensor, dim=0)

        return output




In [51]:
def calc_median_absolute_error(
        t_azimuth,
        t_elevation,
        p_azimuth,
        p_elevation
):
    """
    From Section 4: https://www.isca-archive.org/interspeech_2023/sarabia23_interspeech.html

    Parameters:
    t_azimuth (float): Azimuth angle of the true point in radians.
    t_elevation (float): Elevation angle of the true point in radians.
    p_azimuth (float): Azimuth angle of the predicted point in radians.
    p_elevation (float): Elevation angle of the predicted point in radians.

    Returns:
    float: The angular distance in degrees.
    """

    # Calc angle error in radians
    error_rad = torch.acos(
        torch.sin(t_azimuth) * torch.sin(p_azimuth) +
        torch.cos(t_azimuth) * torch.cos(p_azimuth) * torch.cos(t_elevation - p_elevation)
    )

    # Convert radians to degrees
    error_deg = torch.rad2deg(error_rad)

    # Calc median
    median_error = torch.median(torch.abs(error_deg))

    return median_error




In [52]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()  # Set model to training mode

    train_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    for batch_idx, (active_input, reactive_input, labels) in enumerate(train_loader):
        print(f"Batch {batch_idx+1}/{len(train_loader)}")

        # Move data and labels to the device
        active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(active_input, reactive_input)  # Pass both inputs to the model

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward(retain_graph=True)
        # loss.backward()
        optimizer.step()

        # Update training loss
        train_loss = train_loss + loss.item() * active_input.size(0)

        true_azimuths.append(labels[:, 0].detach())
        true_elevations.append(labels[:, 1].detach())
        pred_azimuths.append(outputs[:, 0].detach())
        pred_elevations.append(outputs[:, 1].detach())

    # Calculate average loss over the dataset
    train_loss = train_loss / len(train_loader.dataset)

    return train_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [53]:
def validate(model, val_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode

    valid_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    with torch.no_grad():  # No gradients needed
        for batch_idx, (active_input, reactive_input, labels) in enumerate(val_loader):
            print(f"Batch {batch_idx+1}/{len(val_loader)}")

            # Move the inputs and labels to the specified device
            active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

            # Forward pass: compute the model output
            outputs = model(active_input, reactive_input)

            # Compute the loss
            loss = criterion(outputs, labels)
            valid_loss = valid_loss + loss.item() * active_input.size(0)

            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    # Calculate the average loss over the dataset
    valid_loss = valid_loss / len(val_loader.dataset)

    return valid_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [54]:
def test(model, test_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode

    test_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    with torch.no_grad():  # No gradients needed during testing
        for batch_idx, (active_input, reactive_input, labels) in enumerate(test_loader):
            active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

            # Forward pass: compute the model output
            outputs = model(active_input, reactive_input)

            # Compute the loss
            loss = criterion(outputs, labels)
            test_loss = test_loss + loss.item() * active_input.size(0)

            # Optionally, accumulate metrics here
            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    # Calculate the average loss over the dataset
    test_loss = test_loss / len(test_loader.dataset)

    return test_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [55]:
data = filter_data()
train_data, val_data, test_data = split_data(data)

train_dataset = AmbisonicDataset(data=train_data, config=config)
val_dataset = AmbisonicDataset(data=val_data, config=config)
test_dataset = AmbisonicDataset(data=test_data, config=config)

train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, num_workers=2)

In [56]:
device = config["device"]
model = SNN(config).to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

In [57]:
torch.autograd.set_detect_anomaly(True)

train_loss = []
valid_loss = []

for epoch in range(config["num_epochs"]):
    print(f"Epoch {epoch+1}/{config['num_epochs']}")

    train_epoch_loss, train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation = train(model, train_loader, criterion, optimizer, device)
    valid_epoch_loss, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)

    train_loss.append(train_epoch_loss)
    valid_loss.append(valid_epoch_loss)

    train_angle_error = calc_median_absolute_error(train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation)
    valid_angle_error = calc_median_absolute_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

    print(f"Train Loss: {train_epoch_loss:.4f} | Validation Loss: {valid_epoch_loss:.4f}")
    print(f"Train Angle Error: {train_angle_error:.4f}° | Validation Angle Error: {valid_angle_error:.4f}°\n")

test_loss, test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation = test(model, test_loader, criterion, device)
test_angle_error = calc_median_absolute_error(test_true_azimuth, test_true_elevation, test_pred_azimuth, train_pred_elevation)

print(f"Test Loss: {test_loss:.4f}")
print(f"Test Angle Error: {test_angle_error:.4f}\n")

Epoch 1/20
Batch 1/457
Batch 2/457


  File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/usr/local/lib/python3.10/dist-packages/colab_kernel_launcher.py", line 37, in <module>
    ColabKernelApp.launch_instance()
  File "/usr/local/lib/python3.10/dist-packages/traitlets/config/application.py", line 992, in launch_instance
    app.start()
  File "/usr/local/lib/python3.10/dist-packages/ipykernel/kernelapp.py", line 619, in start
    self.io_loop.start()
  File "/usr/local/lib/python3.10/dist-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 603, in run_forever
    self._run_once()
  File "/usr/lib/python3.10/asyncio/base_events.py", line 1909, in _run_once
    handle._run()
  File "/usr/lib/python3.10/asyncio/events.py", line 80, in _run
    self._context.run(s

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [48]] is at version 3; expected version 1 instead. Hint: the backtrace further above shows the operation that failed to compute its gradient. The variable in question was changed in there or anywhere later. Good luck!