<a href="https://colab.research.google.com/github/mercadoerik1031/snn-sound-localization/blob/new_approach/snn_sound_localization_active_reactive.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [23]:
# Mount Google Drive so that the dataset paths used in `config`
# (all located under /content/drive) are readable from this runtime.
from google.colab import drive
drive.mount("/content/drive")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [24]:
! pip install snntorch optuna --quiet

In [25]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import Resample

import snntorch as snn
from snntorch import spikegen
from snntorch import utils

from sklearn.model_selection import train_test_split

import os

import pandas as pd

import optuna

from tqdm import tqdm

In [26]:
# Central settings dictionary: every tunable value used by the data pipeline,
# the network and the training loop is read from here.
config = {
    # --- Data selection ---
    "metadata_path": "/content/drive/My Drive/Colab Notebooks/Masters Project/data/metadata.parquet",
    "speech_path": "/content/drive/My Drive/Colab Notebooks/Masters Project/data/ambisonics_lite",
    "noise_path": "/content/drive/My Drive/Colab Notebooks/Masters Project/data/noise_ambisonics_lite",
    "is_lite_version": True,

    # --- Audio loading, mixing and STFT ---
    "sr": 16_000,           # target sample rate in Hz
    "duration": 0.5,        # clip length in seconds
    "noise_ratio": None,    # optional noise gain: larger = louder noise, smaller = quieter
    "n_fft": 512,
    "hop_length": 128,
    "win_length": 512,

    # --- Train / validation split ---
    "val_size": 0.1,

    # --- Dataset / loader ---
    "num_steps": 10,        # number of spike-train time steps
    "batch_size": 128,

    # --- Spiking network ---
    "do": 0.6,              # dropout probability
    "beta": 0.65,           # membrane decay, active branch
    "beta_re": 0.65,        # membrane decay, reactive branch
    "beta_out": 0.65,       # membrane decay, fully connected head

    # Firing thresholds, active branch
    "a_thresh1": 10,
    "a_thresh2": 10,
    "a_thresh3": 10,
    "a_thresh4": 10,

    # Firing thresholds, reactive branch
    "re_thresh1": 10,
    "re_thresh2": 10,
    "re_thresh3": 10,
    "re_thresh4": 10,

    # Firing thresholds, fully connected head
    "out_thresh1": 10,
    "out_thresh2": 10,

    # --- Training ---
    "num_epochs": 10,
    "device": torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
    "lr": 0.0003,
}

In [27]:
def filter_data(
        metadata_pth=config["metadata_path"],
        speech_pth=config["speech_path"],
        noise_pth=config["noise_path"],
        lite_version=config["is_lite_version"]
):
    """
    Read the metadata table and turn each row into a sample description.

    Args:
    - metadata_pth: Path of the parquet metadata file.
    - speech_pth: Directory holding the speech recordings.
    - noise_pth: Directory holding the noise recordings; a falsy value
      disables noise (every sample then gets "noise_path": None).
    - lite_version: When truthy, keep only rows whose "lite_version"
      column equals True.

    Returns:
    - List of dicts with the keys "sample_id", "speech_path", "noise_path",
      "azimuth", "elevation" and "split".
    """
    metadata = pd.read_parquet(metadata_pth, engine="pyarrow")
    if lite_version:
        metadata = metadata.loc[metadata["lite_version"] == True]

    def _row_to_sample(row):
        # Speech and noise recordings share the same zero-padded file name.
        speech_file = os.path.join(speech_pth, f"{row['sample_id']:06}.flac")
        noise_file = os.path.join(noise_pth, f"{row['sample_id']:06}.flac") if noise_pth else None
        return {
            "sample_id": row["sample_id"],
            "speech_path": speech_file,
            "noise_path": noise_file,
            "azimuth": row["speech/azimuth"],
            "elevation": row["speech/elevation"],
            "split": row["split"]
        }

    return [_row_to_sample(row) for _, row in metadata.iterrows()]

In [28]:
def split_data(data, val_size=config["val_size"]):
    """
    Partition the samples into train, validation and test lists.

    Samples are first routed by their "split" field ("train" or "test");
    the training pool is then divided with a fixed seed so the validation
    subset is the same on every run.

    Args:
    - data: List of sample dicts, each carrying a "split" entry.
    - val_size: Share of the training pool reserved for validation.

    Returns:
    - (train_data, val_data, test_data)
    """
    train_pool, test_data = [], []
    for sample in data:
        split_name = sample["split"]
        if split_name == "train":
            train_pool.append(sample)
        elif split_name == "test":
            test_data.append(sample)

    train_data, val_data = train_test_split(train_pool, test_size=val_size, random_state=42)

    return train_data, val_data, test_data


In [29]:
def load_mix_and_stft_foa_audio(
        speech_pth,
        noise_pth=None,
        sr=config["sr"],
        duration=config["duration"],
        noise_ratio=config["noise_ratio"],
        n_fft=config["n_fft"],
        hop_length=config["hop_length"],
        win_length=config["win_length"],
        device=config["device"]
):
    """
    Load a speech clip, optionally mix in a noise clip, and return the STFT of the mix.

    Args:
    - speech_pth: Path of the speech recording.
    - noise_pth: Optional path of the noise recording to add to the speech.
    - sr: Target sample rate in Hz; files at another rate are resampled.
    - duration: Clip length in seconds; every clip is trimmed / zero-padded
      to exactly int(duration * sr) samples.
    - noise_ratio: Optional gain applied to the noise before mixing. When
      given, the mix is peak-normalised again afterwards.
    - n_fft, hop_length, win_length: STFT parameters (Hann window).
    - device: Device on which normalisation, mixing and the STFT run.

    Returns:
    - Complex one-sided STFT of the (mixed) audio as produced by torch.stft;
      for a [channels, samples] waveform that is
      [channels, n_fft // 2 + 1, frames].
    """
    target_len = int(duration * sr)

    def preprocess_audio(audio_pth, sr, target_len, device):
        """Load one file as a peak-normalised clip of exactly target_len samples at rate sr."""
        # Fast path: assume the file is already at the target rate and read
        # only the frames that are needed.
        audio, audio_sr = torchaudio.load(audio_pth, frame_offset=0, num_frames=target_len)

        if audio_sr != sr:
            # num_frames is counted at the file's native rate, so the read
            # above covered the wrong time span. Re-read `duration` seconds
            # at the native rate (+1 frame of slack for rounding) and resample.
            native_len = int(duration * audio_sr) + 1
            audio, audio_sr = torchaudio.load(audio_pth, frame_offset=0, num_frames=native_len)
            # Resample while the waveform is still on the CPU: the transform's
            # kernel lives on the CPU and cannot be applied to a CUDA tensor.
            resample_transform = torchaudio.transforms.Resample(orig_freq=audio_sr, new_freq=sr)
            audio = resample_transform(audio)

        # Resampling can overshoot by a few samples; enforce the exact length
        # so that speech and noise clips can always be added together.
        audio = audio[:, :target_len]
        audio = audio.to(device)

        max_val = audio.abs().max()
        if max_val > 0:  # Avoid division by zero on silent clips
            audio = audio / max_val

        # Short files are zero-padded at the end.
        actual_len = audio.size(1)
        if actual_len < target_len:
            audio = torch.nn.functional.pad(audio, (0, target_len - actual_len), "constant", 0)
        return audio

    mixed_audio = preprocess_audio(speech_pth, sr, target_len, device)

    if noise_pth is not None:
        noise_audio = preprocess_audio(noise_pth, sr, target_len, device)

        if noise_ratio is not None:
            # Scale the noise relative to the speech, then bring the mix
            # back to unit peak.
            mixed_audio = mixed_audio + noise_audio * noise_ratio
            max_val = mixed_audio.abs().max()
            if max_val > 0:
                mixed_audio = mixed_audio / max_val
        else:
            mixed_audio = mixed_audio + noise_audio

    # Create the analysis window directly on the target device.
    window = torch.hann_window(win_length, device=device)

    # Compute the STFT of the mixed audio
    stft = torch.stft(mixed_audio,
                      n_fft=n_fft,
                      hop_length=hop_length,
                      win_length=win_length,
                      window=window,
                      center=True,
                      normalized=False,
                      onesided=True,
                      return_complex=True)

    return stft


In [30]:
def compute_active_reactive_intensities(stft, rho=1.21, c=343):
    """
    Derive the active and reactive intensity vectors of a 4-channel FOA STFT.

    Channel 0 is used as the pressure signal and channels 1-3 as the x, y, z
    velocity signals (each scaled by -1 / (rho * c * sqrt(3))). For every
    axis the product conj(pressure) * velocity is formed; its real part is
    the active intensity and its imaginary part the reactive intensity.

    Args:
    - stft: STFT of the FOA audio with shape [4, Frequency Bins, Time Frames].
    - rho: Mean density of air (in kg/m^3).
    - c: Speed of sound in air (in m/s).

    Returns:
    - Ia: Active intensity, shape [3, Frequency Bins, Time Frames].
    - Ir: Reactive intensity, shape [3, Frequency Bins, Time Frames].
    """
    # Velocity scaling, kept as a float tensor on the same device as the input.
    sqrt_three = torch.sqrt(torch.tensor(3.0, dtype=torch.float, device=stft.device))
    velocity_scale = -1 / (rho * c * sqrt_three)

    # Complex conjugate of the pressure (W) channel, shared by all three axes.
    pressure_conj = torch.conj(stft[0])

    active_axes = []
    reactive_axes = []
    for channel in (1, 2, 3):  # X, Y, Z velocity channels
        cross_spectrum = pressure_conj * (stft[channel] * velocity_scale)
        active_axes.append(torch.real(cross_spectrum))
        reactive_axes.append(torch.imag(cross_spectrum))

    # One leading entry per spatial axis.
    Ia = torch.stack(active_axes, dim=0)
    Ir = torch.stack(reactive_axes, dim=0)

    return Ia, Ir


In [31]:
class AmbisonicDataset(Dataset):
    """
    Map-style dataset that turns sample descriptions into spike trains.

    Each item is loaded from disk on access: the audio is mixed and
    transformed with load_mix_and_stft_foa_audio, converted to active /
    reactive intensities, and rate-encoded into spikes.
    """

    def __init__(self, data, config):
        """
        Args:
            data (list of dicts): One dict per sample with its file paths and labels.
            config (dict): Settings dictionary (sample rate, duration, STFT
                parameters, device, number of spike steps, ...).
        """
        self.data = data
        self.config = config

    def __len__(self):
        """Number of samples available."""
        return len(self.data)

    def __getitem__(self, idx):
        """
        Build the spike trains and label of sample `idx`.

        Returns:
            (spikes_Ia, spikes_Ir, label) where label is a float tensor
            holding [azimuth, elevation].
        """
        record = self.data[idx]
        cfg = self.config

        # A missing or None "noise_path" entry means "speech only".
        stft_audio = load_mix_and_stft_foa_audio(
            record["speech_path"],
            noise_pth=record.get("noise_path"),
            sr=cfg["sr"],
            duration=cfg["duration"],
            noise_ratio=cfg["noise_ratio"],
            n_fft=cfg["n_fft"],
            hop_length=cfg["hop_length"],
            win_length=cfg["win_length"],
            device=cfg["device"]
        )

        # Active / reactive intensity vectors of the mixture
        Ia, Ir = compute_active_reactive_intensities(stft_audio, rho=1.21, c=343)

        # Rate-encode both intensity maps (active first, then reactive).
        # NOTE(review): the intensities are signed and not scaled to [0, 1];
        # confirm that spikegen.rate treats such values as intended.
        num_steps = cfg["num_steps"]
        spikes_Ia = spikegen.rate(Ia, num_steps=num_steps)
        spikes_Ir = spikegen.rate(Ir, num_steps=num_steps)

        label = torch.tensor([record['azimuth'], record['elevation']], dtype=torch.float)

        return spikes_Ia, spikes_Ir, label


In [32]:
# for spikes_active, spikes_reactive, labels in train_loader:
#     print(spikes_active.shape)
#     print(spikes_reactive.shape)
#     print(labels.shape)
#     break

In [33]:
# for spikes_Ia, spikes_Ir, _ in train_loader:
#     spikes_per_step_Ia = spikes_Ia.sum(dim=[2, 3])
#     spikes_per_step_Ir = spikes_Ir.sum(dim=[2, 3])
#     print(f"Spikes in Ia: {spikes_per_step_Ia}")
#     print(f"Spikes in Ir: {spikes_per_step_Ir}")
#     break


In [34]:
class SNN(nn.Module):
    """
    Two-branch spiking CNN that regresses [azimuth, elevation].

    The active and reactive intensity spike trains are processed by two
    structurally identical convolutional branches of leaky
    integrate-and-fire (LIF) neurons. At every time step the branch outputs
    are concatenated, flattened and passed through a fully connected head;
    the head outputs are averaged over time.
    """

    # Output channels of the four convolution stages of each branch.
    _BRANCH_CHANNELS = (6, 12, 24, 48)

    def __init__(self, config):
        """
        Args:
            config (dict): Must provide "num_steps", "beta", "beta_re",
                "beta_out", "a_thresh1".."a_thresh4",
                "re_thresh1".."re_thresh4", "out_thresh1", "out_thresh2"
                and "do" (dropout probability).
        """
        super(SNN, self).__init__()
        self.num_steps = config["num_steps"]
        self.beta = config["beta"]
        self.beta_re = config["beta_re"]
        self.beta_out = config["beta_out"]

        # Active Thresholds
        self.a_thr1 = config["a_thresh1"]
        self.a_thr2 = config["a_thresh2"]
        self.a_thr3 = config["a_thresh3"]
        self.a_thr4 = config["a_thresh4"]

        # Reactive Thresholds
        self.re_thr1 = config["re_thresh1"]
        self.re_thr2 = config["re_thresh2"]
        self.re_thr3 = config["re_thresh3"]
        self.re_thr4 = config["re_thresh4"]

        # FC head
        self.out_thr1 = config["out_thresh1"]
        self.out_thr2 = config["out_thresh2"]
        self.do = config["do"]

        # Both branches share one layout; only decay and thresholds differ.
        # The active branch is built first, then the reactive one.
        self.active_branch = self._make_branch(
            self.beta, (self.a_thr1, self.a_thr2, self.a_thr3, self.a_thr4)
        )
        self.reactive_branch = self._make_branch(
            self.beta_re, (self.re_thr1, self.re_thr2, self.re_thr3, self.re_thr4)
        )

        # 4608 = flattened size of the two concatenated branch outputs for
        # the input resolution this notebook was written for.
        # NOTE(review): must be updated if the STFT settings or clip
        # duration change.
        self.fc = nn.Sequential(
            nn.Linear(4608, 512),
            snn.Leaky(beta=self.beta_out, threshold=self.out_thr1, init_hidden=True),
            nn.Linear(512, 256),
            snn.Leaky(beta=self.beta_out, threshold=self.out_thr2, init_hidden=True),
            nn.Linear(256, 128),
            nn.Dropout(p=self.do),
            nn.Linear(128, 2)
        )

    @classmethod
    def _make_branch(cls, beta, thresholds):
        """Build one conv branch: four stages of Conv -> BatchNorm -> MaxPool -> LIF."""
        layers = []
        in_channels = 3  # x, y, z intensity components
        for out_channels, threshold in zip(cls._BRANCH_CHANNELS, thresholds):
            layers.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.MaxPool2d(2),
                snn.Leaky(beta=beta, threshold=threshold, init_hidden=True),
            ])
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, active_input, reactive_input):
        """
        Run both spike trains through the network and average over time.

        Args:
            active_input: Active-intensity spikes, indexed as
                [batch, time step, channel, frequency, frame].
            reactive_input: Reactive-intensity spikes with the same layout.

        Returns:
            Tensor of shape [batch, 2]: the head output averaged over all
            time steps.

        Raises:
            ValueError: If the two inputs have a different number of time steps.
        """
        if active_input.size(1) != reactive_input.size(1):
            raise ValueError("The time steps of the active and reactive inputs do not match")

        # Every call processes a complete, independent spike train, so the
        # hidden LIF state (init_hidden=True) must start from zero. Without
        # this, membrane potentials leak from one batch into the next
        # whenever the caller forgets to reset.
        utils.reset(self)

        step_outputs = []
        # Iterate over the time dimension; each slice is one batch of frames.
        for current_active, current_reactive in zip(
                active_input.unbind(dim=1), reactive_input.unbind(dim=1)
        ):
            # Process inputs through active and reactive branches
            active_out = self.active_branch(current_active)
            reactive_out = self.reactive_branch(current_reactive)

            # Concatenate along the channel axis, then flatten per sample
            combined = torch.cat((active_out, reactive_out), dim=1)
            combined = combined.view(combined.size(0), -1)

            step_outputs.append(self.fc(combined))

        # Average the per-step predictions
        return torch.stack(step_outputs, dim=0).mean(dim=0)




In [35]:
def calc_median_absolute_error(
        t_azimuth,
        t_elevation,
        p_azimuth,
        p_elevation
):
    """
    Median angular distance between true and predicted directions.

    From Section 4: https://www.isca-archive.org/interspeech_2023/sarabia23_interspeech.html

    The angle between two directions given as (azimuth, elevation) follows
    the spherical law of cosines, with elevation in the role of latitude and
    azimuth in the role of longitude:

        cos(angle) = sin(el_t) * sin(el_p)
                   + cos(el_t) * cos(el_p) * cos(az_t - az_p)

    Parameters:
    t_azimuth (Tensor): Azimuth angles of the true points in radians.
    t_elevation (Tensor): Elevation angles of the true points in radians.
    p_azimuth (Tensor): Azimuth angles of the predicted points in radians.
    p_elevation (Tensor): Elevation angles of the predicted points in radians.

    Returns:
    Tensor: Scalar tensor with the median angular distance in degrees.
    """
    cos_angle = (
        torch.sin(t_elevation) * torch.sin(p_elevation) +
        torch.cos(t_elevation) * torch.cos(p_elevation) * torch.cos(t_azimuth - p_azimuth)
    )

    # Floating-point rounding can push the value marginally outside [-1, 1],
    # where acos returns NaN; clamp before inverting.
    cos_angle = torch.clamp(cos_angle, -1.0, 1.0)

    # acos yields values in [0, pi], so the error is already non-negative.
    error_deg = torch.rad2deg(torch.acos(cos_angle))

    return torch.median(error_deg)




In [36]:
def train(model, train_loader, criterion, optimizer, device):
    """
    Run one training epoch.

    Args:
        model: Network taking (active_input, reactive_input) and returning
            a [batch, 2] prediction of [azimuth, elevation].
        train_loader: Iterable of (active_input, reactive_input, labels) batches.
        criterion: Loss function applied to (outputs, labels).
        optimizer: Optimizer updating the model parameters.
        device: Device the batches are moved to.

    Returns:
        (train_loss, true_azimuths, true_elevations, pred_azimuths,
        pred_elevations) where train_loss is the mean per-sample loss over
        the samples actually processed and the rest are 1-D tensors
        concatenated over all batches.
    """
    model.train()  # Set model to training mode

    running_loss = 0.0
    num_seen = 0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    for active_input, reactive_input, labels in tqdm(train_loader, desc="Training"):
        # Move data and labels to the device
        active_input = active_input.to(device)
        reactive_input = reactive_input.to(device)
        labels = labels.to(device)

        # Reset the hidden state of every spiking layer (required for
        # init_hidden=True). Resetting the whole model also covers the LIF
        # layers of the fully connected head explicitly.
        utils.reset(model)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(active_input, reactive_input)

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Accumulate the loss weighted by the batch size
        batch_size = active_input.size(0)
        running_loss += loss.item() * batch_size
        num_seen += batch_size

        true_azimuths.append(labels[:, 0].detach())
        true_elevations.append(labels[:, 1].detach())
        pred_azimuths.append(outputs[:, 0].detach())
        pred_elevations.append(outputs[:, 1].detach())

    # Average over the samples that were actually processed. A loader built
    # with drop_last=True skips the final partial batch, so dividing by
    # len(train_loader.dataset) would under-report the loss.
    train_loss = running_loss / max(num_seen, 1)

    return train_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [37]:
def validate(model, val_loader, criterion, device):
    """
    Evaluate the model on the validation set without updating it.

    Args:
        model: Network taking (active_input, reactive_input) and returning
            a [batch, 2] prediction of [azimuth, elevation].
        val_loader: Iterable of (active_input, reactive_input, labels) batches.
        criterion: Loss function applied to (outputs, labels).
        device: Device the batches are moved to.

    Returns:
        (valid_loss, true_azimuths, true_elevations, pred_azimuths,
        pred_elevations) where valid_loss is the mean per-sample loss over
        the samples actually processed and the rest are 1-D tensors
        concatenated over all batches.
    """
    model.eval()  # Set the model to evaluation mode

    running_loss = 0.0
    num_seen = 0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    with torch.no_grad():  # No gradients needed
        for active_input, reactive_input, labels in tqdm(val_loader, desc="Validation"):
            # Move the inputs and labels to the specified device
            active_input = active_input.to(device)
            reactive_input = reactive_input.to(device)
            labels = labels.to(device)

            # Start every batch from a zero hidden state, exactly as during
            # training; otherwise membrane potentials left over from the
            # previous batch would influence this batch's predictions.
            utils.reset(model)

            # Forward pass: compute the model output
            outputs = model(active_input, reactive_input)

            # Compute the loss, weighted by the batch size
            loss = criterion(outputs, labels)
            batch_size = active_input.size(0)
            running_loss += loss.item() * batch_size
            num_seen += batch_size

            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    # Average over the samples that were actually processed (a loader with
    # drop_last=True skips the final partial batch).
    valid_loss = running_loss / max(num_seen, 1)

    return valid_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [38]:
def test(model, test_loader, criterion, device):
    """
    Evaluate the model on the test set without updating it.

    Args:
        model: Network taking (active_input, reactive_input) and returning
            a [batch, 2] prediction of [azimuth, elevation].
        test_loader: Iterable of (active_input, reactive_input, labels) batches.
        criterion: Loss function applied to (outputs, labels).
        device: Device the batches are moved to.

    Returns:
        (test_loss, true_azimuths, true_elevations, pred_azimuths,
        pred_elevations) where test_loss is the mean per-sample loss over
        the samples actually processed and the rest are 1-D tensors
        concatenated over all batches.
    """
    model.eval()  # Set the model to evaluation mode

    running_loss = 0.0
    num_seen = 0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    with torch.no_grad():  # No gradients needed during testing
        for active_input, reactive_input, labels in tqdm(test_loader, desc="Testing"):
            active_input = active_input.to(device)
            reactive_input = reactive_input.to(device)
            labels = labels.to(device)

            # Start every batch from a zero hidden state, exactly as during
            # training; otherwise membrane potentials left over from the
            # previous batch would influence this batch's predictions.
            utils.reset(model)

            # Forward pass: compute the model output
            outputs = model(active_input, reactive_input)

            # Compute the loss, weighted by the batch size
            loss = criterion(outputs, labels)
            batch_size = active_input.size(0)
            running_loss += loss.item() * batch_size
            num_seen += batch_size

            # Keep labels and predictions for the angular-error metric
            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    # Average over the samples that were actually processed (a loader with
    # drop_last=True skips the final partial batch).
    test_loss = running_loss / max(num_seen, 1)

    return test_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [39]:
# data = filter_data()
# train_data, val_data, test_data = split_data(data)

# train_dataset = AmbisonicDataset(data=train_data, config=config)
# val_dataset = AmbisonicDataset(data=val_data, config=config)
# test_dataset = AmbisonicDataset(data=test_data, config=config)

# train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
# val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)
# test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)

In [40]:
# print(f"len(train_dataset): {len(train_dataset)}")
# print(f"len(val_dataset): {len(val_dataset)}")
# print(f"len(test_dataset): {len(test_dataset)}")

In [41]:
# device = config["device"]

# model = SNN(config)
# if torch.cuda.device_count() > 1:
#     print(f"Let's use {torch.cuda.device_count()} GPUs!")
#     model = torch.nn.DataParallel(model)
# model.to(device)
# criterion = nn.MSELoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

In [42]:
def objective(trial):
    """
    Optuna objective: train an SNN with sampled hyper-parameters.

    Args:
        trial: Optuna trial used to sample the hyper-parameters.

    Returns:
        float: Median angular validation error (degrees) after the last
        epoch; Optuna minimizes this value.

    Raises:
        ValueError: If config["num_epochs"] is smaller than 1, because no
            validation error would be available.
    """
    # Work on a copy so that a trial never overwrites the notebook-wide
    # defaults stored in `config`.
    trial_config = dict(config)

    if trial_config["num_epochs"] < 1:
        raise ValueError("config['num_epochs'] must be at least 1")

    # Active branch
    for layer in range(1, 5):
        trial_config[f"a_thresh{layer}"] = trial.suggest_float(f"a_thresh{layer}", 1, 30)
    trial_config["beta"] = trial.suggest_float("beta", 0.1, 0.9)

    # Reactive branch
    for layer in range(1, 5):
        trial_config[f"re_thresh{layer}"] = trial.suggest_float(f"re_thresh{layer}", 1, 30)
    trial_config["beta_re"] = trial.suggest_float("beta_re", 0.1, 0.9)

    # Fully connected head. The decay is sampled under its own name;
    # re-using "beta_re" would make Optuna hand back the reactive-branch
    # value, so beta_out would never be tuned independently.
    trial_config["out_thresh1"] = trial.suggest_float("out_thresh1", 1, 30)
    trial_config["out_thresh2"] = trial.suggest_float("out_thresh2", 1, 30)
    trial_config["beta_out"] = trial.suggest_float("beta_out", 0.1, 0.9)

    trial_config["do"] = trial.suggest_float("do", 0.1, 0.9)
    trial_config["num_steps"] = trial.suggest_int("num_steps", 5, 20)
    # NOTE(review): the learning rate is sampled on a linear scale, so almost
    # all draws fall in the upper decades of the range. log=True would cover
    # the range more evenly, but a stored study is expected to reject a
    # changed distribution for an existing parameter name, so it is left
    # unchanged here - confirm before switching.
    trial_config["lr"] = trial.suggest_float("lr", 1e-10, 1e-3)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SNN(trial_config)
    num_gpus = torch.cuda.device_count()
    if num_gpus > 1:
        # NOTE(review): the model is not wrapped in DataParallel, so only
        # `device` is actually used for training.
        print(f"Using {num_gpus} GPUs.")
    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=trial_config["lr"])

    data = filter_data()
    # The test split is not needed for hyper-parameter selection.
    train_data, val_data, _ = split_data(data)

    train_dataset = AmbisonicDataset(data=train_data, config=trial_config)
    val_dataset = AmbisonicDataset(data=val_data, config=trial_config)

    train_loader = DataLoader(train_dataset, batch_size=trial_config["batch_size"], shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=trial_config["batch_size"], shuffle=False, drop_last=True)

    for epoch in range(trial_config["num_epochs"]):
        train(model, train_loader, criterion, optimizer, device)
        _, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)
        valid_angle_error = calc_median_absolute_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

    return valid_angle_error.item()  # Optuna minimizes this value

In [43]:
def run_study(n_trials=10, study_name="snn_study"):
    """
    Run (or resume) the Optuna hyper-parameter search and report the best trial.

    The study is persisted in a SQLite file named "<study_name>.db" in the
    working directory and is re-opened if it already exists, so repeated
    calls add trials to the same study. After optimization the trial table
    is written to "study_results.csv".

    Args:
        n_trials (int): Number of trials to run in this call.
        study_name (str): Name of the study; also determines the database
            file name.
    """
    storage_name = f"sqlite:///{study_name}.db"

    study = optuna.create_study(
        study_name=study_name,
        storage=storage_name,
        load_if_exists=True,
        direction="minimize",
    )
    study.optimize(objective, n_trials=n_trials)

    # Save the study to disk
    study.trials_dataframe().to_csv("study_results.csv")

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for key, value in trial.params.items():
        print(f"      {key}: {value}")



In [None]:
# Launch the hyper-parameter search. NOTE: this is a long-running cell - in
# the recorded run below a single training epoch took more than three hours.
run_study()

[I 2024-02-26 22:05:18,035] A new study created in RDB with name: snn_study
Training: 100%|██████████| 120/120 [3:14:58<00:00, 97.49s/it]
Validation: 100%|██████████| 13/13 [21:13<00:00, 97.96s/it]
Training: 100%|██████████| 120/120 [2:17:30<00:00, 68.75s/it]
Validation:  31%|███       | 4/13 [06:15<14:01, 93.53s/it]

In [None]:
# torch.autograd.set_detect_anomaly(True)

# train_loss = []
# valid_loss = []

# for epoch in range(config["num_epochs"]):
#     print(f"Epoch {epoch+1}/{config['num_epochs']}")

#     train_epoch_loss, train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation = train(model, train_loader, criterion, optimizer, device)
#     valid_epoch_loss, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)

#     train_loss.append(train_epoch_loss)
#     valid_loss.append(valid_epoch_loss)

#     train_angle_error = calc_median_absolute_error(train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation)
#     valid_angle_error = calc_median_absolute_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

#     print(f"Train Loss: {train_epoch_loss:.4f} | Validation Loss: {valid_epoch_loss:.4f}")
#     print(f"Train Angle Error: {train_angle_error:.4f}° | Validation Angle Error: {valid_angle_error:.4f}°\n")

# test_loss, test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation = test(model, test_loader, criterion, device)
# test_angle_error = calc_median_absolute_error(test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation)

# print(f"Test Loss: {test_loss:.4f}")
# print(f"Test Angle Error: {test_angle_error:.4f}\n")