## Imports

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

import torchaudio
from torchaudio.transforms import Resample

import snntorch as snn
from snntorch import spikegen
from snntorch import utils

from sklearn.model_selection import train_test_split

import os

import pandas as pd

import optuna

from tqdm import tqdm

## Config

In [2]:
config = {
    # Filter Data
    "metadata_path": "/home/emercad3/sound_localization/data/metadata.parquet",
    "speech_path": "/home/emercad3/sound_localization/data/speech_ambi",
    #"noise_path": "/home/emercad3/sound_localization/data/noise_ambi",
    "noise_path": None,
    "is_lite_version": True,

    # Load Mix and STFT Audio
    "sr": 16_000,
    "duration": 0.5,
    "noise_ratio": None, # Optional (Large #s make noise louder, Small #s make noise quieter)
    "n_fft": 512,
    "hop_length": 128,
    "win_length": 512,

    # Split Data
    "val_size": 0.1,

    # Dataset
    "num_steps": 40,
    "batch_size": 128,

    # SNN
    "do": 0.6,
    "beta": 0.65,
    "beta_re": 0.65,
    "beta_out": 0.65,
        # Active Branch
    "a_thresh1": 10, "a_thresh2": 10, "a_thresh3": 10, "a_thresh4": 10,
        # Reactive Branch
    "re_thresh1": 10, "re_thresh2": 10, "re_thresh3": 10, "re_thresh4": 10,
        # Output Layer
    "out_thresh1": 10, "out_thresh2": 10,

    # Training
    "num_epochs": 50,
    "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    "lr": 0.0003,
    
    # Optuna
    "n_trials": 50,
    "early_stop_epochs": 5,
}

## Preprocessing

### Filter Data

In [3]:
def filter_data(
        metadata_pth=config["metadata_path"],
        speech_pth=config["speech_path"],
        noise_pth=config["noise_path"],
        lite_version=config["is_lite_version"]
):

    metadata = pd.read_parquet(metadata_pth, engine="pyarrow")
    if lite_version:
        metadata = metadata[metadata["lite_version"] == True]

    data = []
    for _, row in metadata.iterrows():
        sample = {
            "sample_id": row["sample_id"],
            "speech_path": os.path.join(speech_pth, f"{row['sample_id']:06}.flac"),
            "noise_path": os.path.join(noise_pth, f"{row['sample_id']:06}.flac") if noise_pth else None,
            "azimuth": row["speech/azimuth"],
            "elevation": row["speech/elevation"],
            "split": row["split"]
        }
        data.append(sample)

    return data

### Split Data

In [4]:
def split_data(data, val_size=config["val_size"]):
    train_data = [sample for sample in data if sample["split"] == "train"]
    test_data = [sample for sample in data if sample["split"] == "test"]

    train_data, val_data = train_test_split(train_data, test_size=val_size, random_state=42)

    return train_data, val_data, test_data


### Load, Trim/Pad, STFT

In [5]:
def load_mix_and_stft_foa_audio(
        speech_pth,
        noise_pth=None,
        sr=config["sr"],
        duration=config["duration"],
        noise_ratio=config["noise_ratio"],
        n_fft=config["n_fft"],
        hop_length=config["hop_length"],
        win_length=config["win_length"],
        device=config["device"]
):

    def preprocess_audio(audio_pth, sr, target_len, device):
        num_frames = target_len
        audio, audio_sr = torchaudio.load(audio_pth, frame_offset=0, num_frames=num_frames)
        audio = audio.to(device)  # Move to GPU
        if audio_sr != sr:
            resample_transform = torchaudio.transforms.Resample(orig_freq=audio_sr, new_freq=sr)
            audio = resample_transform(audio)
        max_val = audio.abs().max()
        if max_val > 0:  # Avoid division by zero
            audio = audio / max_val
        # Check if padding is necessary (after resampling, the actual number of samples might change)
        actual_len = audio.size(1)
        if actual_len < target_len:
            padding_size = target_len - actual_len
            audio = torch.nn.functional.pad(audio, (0, padding_size), "constant", 0)
        return audio

    device = device
    target_len = int(duration * sr)
    speech_audio = preprocess_audio(speech_pth, sr, target_len, device)
    should_renormalize = False

    if noise_pth is not None:
        noise_audio = preprocess_audio(noise_pth, sr, target_len, device)

        if noise_ratio is not None:
            # Adjust noise level relative to speech
            noise_audio = noise_audio * noise_ratio
            should_renormalize = True

        # Mix speech and noise
        mixed_audio = speech_audio + noise_audio
    else:
        mixed_audio = speech_audio

    if should_renormalize:
        # Re-normalize only if noise has been adjusted and mixed
        max_val = mixed_audio.abs().max()
        if max_val > 0:
            mixed_audio = mixed_audio / max_val

    # Move Window to device
    window = torch.hann_window(win_length).to(device)

    # Compute the STFT of the mixed audio
    stft = torch.stft(mixed_audio,
                      n_fft=n_fft,
                      hop_length=hop_length,
                      win_length=win_length,
                      window=window,
                      center=True,
                      normalized=True,
                      onesided=False,
                      return_complex=True)

    return stft


### Calculate Active and Reactive Intensity

In [6]:
def compute_active_reactive_intensities(stft, rho=1.21, c=343):
    """
    Compute active and reactive intensity vectors from STFT of 4-channel FOA audio.
    Args:
    - stft: STFT of the FOA audio with shape [4, Frequency Bins, Time Frames].
    - rho: Mean density of air (in kg/m^3).
    - c: Speed of sound in air (in m/s).

    Returns:
    - Ia: Active intensity vector.
    - Ir: Reactive intensity vector.
    """
    # Constants
    three = torch.tensor(3.0, dtype=torch.float, device=stft.device)
    normalization_factor = -1 / (rho * c * torch.sqrt(three))

    # Extract channels
    p = stft[0]  # Pressure (W channel)
    vx = stft[1] * normalization_factor  # Velocity X
    vy = stft[2] * normalization_factor  # Velocity Y
    vz = stft[3] * normalization_factor  # Velocity Z

    # Compute complex conjugate of pressure
    p_star = torch.conj(p)

    # Calculate active and reactive intensity vectors
    Ia_x = torch.real(p_star * vx)
    Ia_y = torch.real(p_star * vy)
    Ia_z = torch.real(p_star * vz)

    Ir_x = torch.imag(p_star * vx)
    Ir_y = torch.imag(p_star * vy)
    Ir_z = torch.imag(p_star * vz)

    # Create stack for each channel [3, num_samples, num_frames]
    Ia = torch.stack((Ia_x, Ia_y, Ia_z), dim=0)
    Ir = torch.stack((Ir_x, Ir_y, Ir_z), dim=0)

    return Ia, Ir


## Custom Dataset Class

In [7]:
class AmbisonicDataset(Dataset):
    def __init__(self, data, config):
        """
        Args:
            data (list of dicts): Each dictionary contains paths and labels for a sample.
            config (dict): Configuration dictionary including sample rate (sr), duration, etc.
        """
        self.data = data
        self.config = config

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]

        # Load and process ambisonic audio
        speech_path = sample["speech_path"]
        noise_path = sample["noise_path"] if "noise_path" in sample and sample["noise_path"] is not None else None

        stft_audio = load_mix_and_stft_foa_audio(
            speech_path,
            noise_pth=noise_path,
            sr=self.config["sr"],
            duration=self.config["duration"],
            noise_ratio=self.config["noise_ratio"],
            n_fft=self.config["n_fft"],
            hop_length=self.config["hop_length"],
            win_length=self.config["win_length"],
            device=self.config["device"]
        )

        # Compute active and reactive intensities
        Ia, Ir = compute_active_reactive_intensities(stft_audio, rho=1.21, c=343)

        # Generate Spike Trains
        spikes_Ia = spikegen.rate(Ia, num_steps=self.config["num_steps"])
        spikes_Ir = spikegen.rate(Ir, num_steps=self.config["num_steps"])

        azimuth = sample['azimuth']
        elevation = sample['elevation']
        label = torch.tensor([azimuth, elevation], dtype=torch.float)

        #print(f"stft_audio.shape: {stft_audio.shape}\nIa.shape:{Ia.shape}, Ir.shape: {Ir.shape}\nspikes_Ia.shape: {spikes_Ia.shape}, spikes_Ir.shape: {spikes_Ir.shape}\nlabel.shape: {label.shape}")

        return spikes_Ia, spikes_Ir, label


## Model

In [8]:
class SNN(nn.Module):
    def __init__(self, config):
        super(SNN, self).__init__()
        self.num_steps = config["num_steps"]
        self.beta = config["beta"]
        self.beta_re = config["beta_re"]
        self.beta_out = config["beta_out"]

        # Active Thresholds
        self.a_thr1 = config["a_thresh1"]
        self.a_thr2 = config["a_thresh2"]
        self.a_thr3 = config["a_thresh3"]
        self.a_thr4 = config["a_thresh4"]

        # Reactive Thresholds
        self.re_thr1 = config["re_thresh1"]
        self.re_thr2 = config["re_thresh2"]
        self.re_thr3 = config["re_thresh3"]
        self.re_thr4 = config["re_thresh4"]

        ## FC Layer
        self.beta_out = config["beta_out"]
        self.out_thr1 = config["out_thresh1"]
        self.out_thr2 = config["out_thresh2"]
        self.do = config["do"]

        # Define the active branch with LIF neurons
        self.active_branch = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3),
            nn.BatchNorm2d(6),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr1, init_hidden=True),

            nn.Conv2d(6, 12, kernel_size=3),
            nn.BatchNorm2d(12),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr2, init_hidden=True),

            nn.Conv2d(12, 24, kernel_size=3),
            nn.BatchNorm2d(24),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr3, init_hidden=True),

            nn.Conv2d(24, 48, kernel_size=3),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta, threshold=self.a_thr4, init_hidden=True),
        )

        self.reactive_branch = nn.Sequential(
            nn.Conv2d(3, 6, kernel_size=3),
            nn.BatchNorm2d(6),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta_re, threshold=self.re_thr1, init_hidden=True),

            nn.Conv2d(6, 12, kernel_size=3),
            nn.BatchNorm2d(12),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta_re, threshold=self.re_thr2, init_hidden=True),

            nn.Conv2d(12, 24, kernel_size=3),
            nn.BatchNorm2d(24),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta_re, threshold=self.re_thr3, init_hidden=True),

            nn.Conv2d(24, 48, kernel_size=3),
            nn.BatchNorm2d(48),
            nn.MaxPool2d(2),
            snn.Leaky(beta=self.beta_re, threshold=self.re_thr4, init_hidden=True),
        )

        self.fc = nn.Sequential(
            nn.Linear(5760, 512), # use 4608 when stft onesided=True
            snn.Leaky(beta=self.beta_out, threshold=self.out_thr1, init_hidden=True),
            nn.Linear(512, 256),
            snn.Leaky(beta=self.beta_out, threshold=self.out_thr2, init_hidden=True),
            nn.Linear(256, 128),
            nn.Dropout(p=self.do),
            nn.Linear(128, 2)
        )

    # def forward(self, active_input, reactive_input):
    #     active_in_permuted = active_input.permute(1, 0, 2, 3, 4) # [num_step, batch_size, channels, num_samples, num_frames]
    #     reactive_in_permuted = reactive_input.permute(1, 0, 2, 3, 4)

    #     active_outputs = []
    #     reactive_outputs = []

    #     for step in range(active_in_permuted.size(0)):
    #         active_step = self.active_branch(active_in_permuted[step])
    #         reactive_step = self.reactive_branch(reactive_in_permuted[step])

    #         active_outputs.append(active_step)
    #         reactive_outputs.append(reactive_step)

    #     # Stack the outputs across time steps to form tensors of shape [num_steps, batch_size, channels, height, width]
    #     active_stacked = torch.stack(active_outputs, dim=0)
    #     reactive_stacked = torch.stack(reactive_outputs, dim=0)

    #     # Aggregate across time steps, e.g., by taking the mean or sum
    #     active_agg = torch.mean(active_stacked, dim=0)
    #     reactive_agg = torch.mean(reactive_stacked, dim=0)

    #     # Flatten and concatenate the aggregated outputs for the final MLP
    #     active_flat = active_agg.view(active_agg.size(0), -1)
    #     reactive_flat = reactive_agg.view(reactive_agg.size(0), -1)
    #     combined = torch.cat((active_flat, reactive_flat), dim=1)

    #     output = self.fc(combined)

    #     return output

    def forward(self, active_input, reactive_input):
        step_outputs = []
        permute_active = active_input.permute(1, 0, 2, 3, 4)
        permute_reactive = reactive_input.permute(1, 0, 2, 3, 4)

        if permute_active.size(0) != permute_reactive.size(0):
            raise ValueError("The Tme Steps from active and reactive Do NOT Match")

        for step in range(permute_active.size(0)):

            current_active = permute_active[step]
            current_reactive = permute_reactive[step]

            # Process inputs through active and reactive branches
            active_out = self.active_branch(current_active)
            reactive_out = self.reactive_branch(current_reactive)

            # Flatten and combine the outputs
            combined = torch.cat((active_out, reactive_out), dim=1)
            combined = combined.view(combined.size(0), -1)

            fc_out = self.fc(combined)

            step_outputs.append(fc_out)

        tensor_out = torch.stack(step_outputs, dim=0)
        output = torch.mean(tensor_out, dim=0)

        return output




## Util Functions

In [9]:
def calc_median_absolute_error(t_azimuth, t_elevation, p_azimuth, p_elevation):
    """
    Calculate the median absolute error of the angular distance between the true
    and predicted azimuth and elevation angles.

    Parameters:
    t_azimuth (tensor): Azimuth angles of the true points in radians.
    t_elevation (tensor): Elevation angles of the true points in radians.
    p_azimuth (tensor): Azimuth angles of the predicted points in radians.
    p_elevation (tensor): Elevation angles of the predicted points in radians.

    Returns:
    tensor: The median angular distance in degrees.
    """

    # Calculate the cosine of the angular distance
    cosine_of_angle = (
        torch.sin(t_azimuth) * torch.sin(p_azimuth) +
        torch.cos(t_azimuth) * torch.cos(p_azimuth) * torch.cos(t_elevation - p_elevation)
    )

    # Clamp the cosine of the angle to the range [-1, 1] to avoid errors due to numerical instability
    cosine_of_angle = torch.clamp(cosine_of_angle, -1, 1)

    # Calculate the angular distance in radians
    error_rad = torch.acos(cosine_of_angle)

    # Convert the angular distance from radians to degrees
    error_deg = torch.rad2deg(error_rad)

    # Calculate the median of the absolute errors in degrees
    median_error = torch.median(torch.abs(error_deg))

    return median_error


## Training, Validation and Testing Functions

### Training Function

In [10]:
# def train(model, train_loader, criterion, optimizer, device):
#     model.train()  # Set model to training mode

#     train_loss = 0.0
#     true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

#     for batch_idx, (active_input, reactive_input, labels) in enumerate(tqdm(train_loader, desc="Training")):
#         #print(f"Batch {batch_idx+1}/{len(train_loader)}")

#         # Move data and labels to the device
#         active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

#         # Reset branches - required for init_hidden=True
#         utils.reset(model.active_branch)
#         utils.reset(model.reactive_branch)

#         # Zero the parameter gradients
#         optimizer.zero_grad()

#         # Forward pass
#         outputs = model(active_input, reactive_input)  # Pass both inputs to the model

#         # Compute loss
#         loss = criterion(outputs, labels)

#         # Backward pass and optimize
#         loss.backward()
#         optimizer.step()

#         # Update training loss
#         train_loss = train_loss + loss.item() * active_input.size(0)

#         true_azimuths.append(labels[:, 0].detach())
#         true_elevations.append(labels[:, 1].detach())
#         pred_azimuths.append(outputs[:, 0].detach())
#         pred_elevations.append(outputs[:, 1].detach())

#     # Calculate average loss over the dataset
#     train_loss = train_loss / len(train_loader.dataset)

#     return train_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


In [11]:
def train(model, train_loader, criterion, optimizer, device):
    model.train()  # Set model to training mode

    train_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    for batch_idx, (active_input, reactive_input, labels) in enumerate(train_loader):
        # Move data and labels to the device
        active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

        # Check if model is wrapped in DataParallel and access the original model for reset
        if isinstance(model, nn.DataParallel):
            utils.reset(model.module.active_branch)
            utils.reset(model.module.reactive_branch)
        else:
            utils.reset(model.active_branch)
            utils.reset(model.reactive_branch)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(active_input, reactive_input)  # Pass both inputs to the model

        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimize
        loss.backward()
        optimizer.step()

        # Update training loss
        train_loss += loss.item() * active_input.size(0)

        true_azimuths.append(labels[:, 0].detach())
        true_elevations.append(labels[:, 1].detach())
        pred_azimuths.append(outputs[:, 0].detach())
        pred_elevations.append(outputs[:, 1].detach())

    # Calculate average loss over the dataset
    train_loss = train_loss / len(train_loader.dataset)

    return train_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


### Validation Function

In [12]:
def validate(model, val_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode

    valid_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    with torch.no_grad():  # No gradients needed
        for batch_idx, (active_input, reactive_input, labels) in enumerate(val_loader):
            #print(f"Batch {batch_idx+1}/{len(val_loader)}")

            # Move the inputs and labels to the specified device
            active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

            # Forward pass: compute the model output
            outputs = model(active_input, reactive_input)

            # Compute the loss
            loss = criterion(outputs, labels)
            valid_loss = valid_loss + loss.item() * active_input.size(0)

            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    # Calculate the average loss over the dataset
    valid_loss = valid_loss / len(val_loader.dataset)

    return valid_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


### Testing Function

In [13]:
def test(model, test_loader, criterion, device):
    model.eval()  # Set the model to evaluation mode

    test_loss = 0.0
    true_azimuths, true_elevations, pred_azimuths, pred_elevations = [], [], [], []

    with torch.no_grad():  # No gradients needed during testing
        for batch_idx, (active_input, reactive_input, labels) in enumerate(tqdm(test_loader, desc="Testing")):
            active_input, reactive_input, labels = active_input.to(device), reactive_input.to(device), labels.to(device)

            # Forward pass: compute the model output
            outputs = model(active_input, reactive_input)

            # Compute the loss
            loss = criterion(outputs, labels)
            test_loss = test_loss + loss.item() * active_input.size(0)

            # Optionally, accumulate metrics here
            true_azimuths.append(labels[:, 0].detach())
            true_elevations.append(labels[:, 1].detach())
            pred_azimuths.append(outputs[:, 0].detach())
            pred_elevations.append(outputs[:, 1].detach())

    # Calculate the average loss over the dataset
    test_loss = test_loss / len(test_loader.dataset)

    return test_loss, torch.cat(true_azimuths), torch.cat(true_elevations), torch.cat(pred_azimuths), torch.cat(pred_elevations)


## DataLoaders, Training setup, Training Loop

### DataLoaders

In [14]:
# data = filter_data()
# train_data, val_data, test_data = split_data(data)

# train_dataset = AmbisonicDataset(data=train_data, config=config)
# val_dataset = AmbisonicDataset(data=val_data, config=config)
# test_dataset = AmbisonicDataset(data=test_data, config=config)

# train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
# val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)
# test_loader = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)

### Training Setup

In [15]:
# device = config["device"]

# model = SNN(config)
# if torch.cuda.device_count() > 1:
#     print(f"Using {torch.cuda.device_count()} GPUs!")
#     model = torch.nn.DataParallel(model)
# model.to(device)
# criterion = nn.MSELoss()
# optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

### Training Loop

In [16]:
# torch.autograd.set_detect_anomaly(True)

# train_loss = []
# valid_loss = []

# for epoch in tqdm(range(config["num_epochs"]), desc="Epochs"):

#     train_epoch_loss, train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation = train(model, train_loader, criterion, optimizer, device)
#     valid_epoch_loss, valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation = validate(model, val_loader, criterion, device)

#     train_loss.append(train_epoch_loss)
#     valid_loss.append(valid_epoch_loss)

#     train_angle_error = calc_median_absolute_error(train_true_azimuth, train_true_elevation, train_pred_azimuth, train_pred_elevation)
#     valid_angle_error = calc_median_absolute_error(valid_true_azimuth, valid_true_elevation, valid_pred_azimuth, valid_pred_elevation)

#     print(f"Train Loss: {train_epoch_loss:.4f} | Validation Loss: {valid_epoch_loss:.4f}")
#     print(f"Train Angle Error: {train_angle_error:.4f}° | Validation Angle Error: {valid_angle_error:.4f}°\n")

# test_loss, test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation = test(model, test_loader, criterion, device)
# test_angle_error = calc_median_absolute_error(test_true_azimuth, test_true_elevation, test_pred_azimuth, test_pred_elevation)

# print(f"Test Loss: {test_loss:.4f}")
# print(f"Test Angle Error: {test_angle_error:.4f}\n")

## Hypyerparameter Tuning

### Objective Function

In [17]:
# def objective(trial):
#     # Hyperparameters to optimize
#     config["a_thresh1"] = trial.suggest_float("a_thresh1", 1, 50)
#     config["a_thresh2"] = trial.suggest_float("a_thresh2", 1, 50)
#     config["a_thresh3"] = trial.suggest_float("a_thresh3", 1, 50)
#     config["a_thresh4"] = trial.suggest_float("a_thresh4", 1, 50)
#     config["beta"] = trial.suggest_float("beta", 0.1, 0.9)


#     config["re_thresh1"] = trial.suggest_float("re_thresh1", 1, 50)
#     config["re_thresh2"] = trial.suggest_float("re_thresh2", 1, 50)
#     config["re_thresh3"] = trial.suggest_float("re_thresh3", 1, 50)
#     config["re_thresh4"] = trial.suggest_float("re_thresh4", 1, 50)
#     config["beta_re"] = trial.suggest_float("beta_re", 0.1, 0.9)

#     config["out_thresh1"] = trial.suggest_float("out_thresh1", 1, 50)
#     config["out_thresh2"] = trial.suggest_float("out_thresh2", 1, 50)
#     config["beta_out"] = trial.suggest_float("beta_out", 0.1, 0.9)

#     config["do"] = trial.suggest_float("do", 0.01, 0.9)
#     config["num_steps"] = trial.suggest_int("num_steps", 5, 25)
#     config["lr"] = trial.suggest_float("lr", 1e-10, 1e-3)



#     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#     model = SNN(config)
#     if torch.cuda.device_count() > 1:
#         print(f"Using {torch.cuda.device_count()} GPUs.")
#         model = nn.DataParallel(model)

#     model.to(device)
#     criterion = nn.MSELoss()
#     optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

#     data = filter_data()
#     train_data, val_data, _ = split_data(data)

#     train_dataset = AmbisonicDataset(data=train_data, config=config)
#     val_dataset = AmbisonicDataset(data=val_data, config=config)
    

#     train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
#     val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)

#     for epoch in tqdm(range(config["num_epochs"])):
#         train(model, train_loader, criterion, optimizer, device)
#         val_loss, val_true_azimuth, val_true_elevation, val_pred_azimuth, val_pred_elevation = validate(model, val_loader, criterion, device)
        
#         val_error_deg = calc_median_absolute_error(val_true_azimuth, val_true_elevation, val_pred_azimuth, val_pred_elevation)

#     return val_error_deg.item()  # Optuna minimizes this value

### Objective Function With Early Stopping

In [18]:
def objective(trial):
    # Hyperparameters to optimize
    config["a_thresh1"] = trial.suggest_float("a_thresh1", 1, 50)
    config["a_thresh2"] = trial.suggest_float("a_thresh2", 1, 50)
    config["a_thresh3"] = trial.suggest_float("a_thresh3", 1, 50)
    config["a_thresh4"] = trial.suggest_float("a_thresh4", 1, 50)
    config["beta"] = trial.suggest_float("beta", 0.1, 0.9)

    config["re_thresh1"] = trial.suggest_float("re_thresh1", 1, 50)
    config["re_thresh2"] = trial.suggest_float("re_thresh2", 1, 50)
    config["re_thresh3"] = trial.suggest_float("re_thresh3", 1, 50)
    config["re_thresh4"] = trial.suggest_float("re_thresh4", 1, 50)
    config["beta_re"] = trial.suggest_float("beta_re", 0.1, 0.9)

    config["out_thresh1"] = trial.suggest_float("out_thresh1", 1, 50)
    config["out_thresh2"] = trial.suggest_float("out_thresh2", 1, 50)
    config["beta_out"] = trial.suggest_float("beta_out", 0.1, 0.9)

    config["do"] = trial.suggest_float("do", 0.01, 0.9)
    # config["num_steps"] = trial.suggest_int("num_steps", 5, 40) # 23 is the max num_steps for 2 GPUs
    config["lr"] = trial.suggest_float("lr", 1e-10, 1e-3)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = SNN(config)
    if torch.cuda.device_count() > 1:
        print(f"Using {torch.cuda.device_count()} GPUs.")
        model = nn.DataParallel(model)

    model.to(device)
    criterion = nn.MSELoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

    data = filter_data()
    train_data, val_data, _ = split_data(data)

    train_dataset = AmbisonicDataset(data=train_data, config=config)
    val_dataset = AmbisonicDataset(data=val_data, config=config)

    train_loader = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=config["batch_size"], shuffle=False, drop_last=True)

    best_val_error_deg = float('inf')
    epochs_no_improve = 0
    early_stop_epochs = config["early_stop_epochs"]

    for epoch in tqdm(range(config["num_epochs"]), desc="Epochs"):
        train(model, train_loader, criterion, optimizer, device)

        _, val_true_azimuth, val_true_elevation, val_pred_azimuth, val_pred_elevation = validate(model, val_loader, criterion, device)
        val_error_deg = calc_median_absolute_error(val_true_azimuth, val_true_elevation, val_pred_azimuth, val_pred_elevation)

        if val_error_deg < best_val_error_deg:
            best_val_error_deg = val_error_deg
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve == early_stop_epochs:
            print("Early stopping triggered.")
            break

    return best_val_error_deg.item()  # Return the best validation angle error



### Run Study Function

In [19]:
def run_study():
    study_name = "snn_study"
    storage_name = "sqlite:///{}.db".format(study_name)

    study = optuna.create_study(study_name=study_name, storage=storage_name, load_if_exists=True, direction="minimize")
    study.optimize(objective, n_trials=config["n_trials"])

    # Save the study to disk
    study.trials_dataframe().to_csv("study_results.csv")

    print("Study statistics: ")
    print("  Number of finished trials: ", len(study.trials))
    print("  Best trial:")
    trial = study.best_trial
    print("    Value: ", trial.value)
    print("    Params: ")
    for key, value in trial.params.items():
        print(f"      {key}: {value}")



In [20]:
run_study()

[I 2024-03-09 22:13:27,068] A new study created in RDB with name: snn_study


Using 4 GPUs.


Epochs:  46%|▍| 23/5
[I 2024-03-09 23:53:29,382] Trial 0 finished with value: 46.8217658996582 and parameters: {'a_thresh1': 24.14804850839552, 'a_thresh2': 8.705196502743302, 'a_thresh3': 25.457568298566656, 'a_thresh4': 38.89868923524451, 'beta': 0.8364181114137388, 're_thresh1': 39.078790415573735, 're_thresh2': 29.914453168403792, 're_thresh3': 17.961092124810666, 're_thresh4': 18.503767452497325, 'beta_re': 0.5621469668615829, 'out_thresh1': 9.66505877531899, 'out_thresh2': 30.215289591871343, 'beta_out': 0.8877286162285448, 'do': 0.01495168788934992, 'lr': 0.0001246923242269125}. Best is trial 0 with value: 46.8217658996582.


Early stopping triggered.
Using 4 GPUs.


Epochs:  12%| | 6/50
[I 2024-03-10 00:22:16,986] Trial 1 finished with value: 87.75468444824219 and parameters: {'a_thresh1': 18.694148755735057, 'a_thresh2': 32.54513253240053, 'a_thresh3': 42.623932633490966, 'a_thresh4': 33.78888180707709, 'beta': 0.8406655387914931, 're_thresh1': 36.23913190595872, 're_thresh2': 10.535205045366169, 're_thresh3': 38.496165659023, 're_thresh4': 37.18975876129634, 'beta_re': 0.8111615637280576, 'out_thresh1': 18.716781253009863, 'out_thresh2': 23.375438396962288, 'beta_out': 0.3676250543616044, 'do': 0.14917582538959862, 'lr': 0.0006091240821915014}. Best is trial 0 with value: 46.8217658996582.


Early stopping triggered.
Using 4 GPUs.


Epochs:  16%|▏| 8/50
[I 2024-03-10 00:59:17,700] Trial 2 finished with value: 87.91520690917969 and parameters: {'a_thresh1': 12.572368725105258, 'a_thresh2': 39.40411255153983, 'a_thresh3': 33.42698010958485, 'a_thresh4': 43.88464097914778, 'beta': 0.17323855855716552, 're_thresh1': 8.084952962405596, 're_thresh2': 28.844085940807165, 're_thresh3': 26.652695880516973, 're_thresh4': 41.76347953368189, 'beta_re': 0.7392256242743976, 'out_thresh1': 27.65460808493841, 'out_thresh2': 14.828292969330567, 'beta_out': 0.3835283451298265, 'do': 0.7288081643067245, 'lr': 0.000634948747727468}. Best is trial 0 with value: 46.8217658996582.


Early stopping triggered.
Using 4 GPUs.


Epochs:  18%|▏| 9/50
[I 2024-03-10 01:40:22,468] Trial 3 finished with value: 87.7423095703125 and parameters: {'a_thresh1': 37.460966879466234, 'a_thresh2': 44.53387146337519, 'a_thresh3': 18.138714243328526, 'a_thresh4': 49.478455533438634, 'beta': 0.36599220862763215, 're_thresh1': 11.705490648843236, 're_thresh2': 14.041904837311035, 're_thresh3': 14.220721716248601, 're_thresh4': 15.392189007504742, 'beta_re': 0.48408622207599084, 'out_thresh1': 16.940393199815812, 'out_thresh2': 2.784012251663486, 'beta_out': 0.4626106288110742, 'do': 0.38211402180861204, 'lr': 0.0005340050813005703}. Best is trial 0 with value: 46.8217658996582.


Early stopping triggered.
Using 4 GPUs.


Epochs:  42%|▍| 21/5
[I 2024-03-10 03:10:47,233] Trial 4 finished with value: 44.83888626098633 and parameters: {'a_thresh1': 19.58794414352232, 'a_thresh2': 30.668943952690064, 'a_thresh3': 8.428338336534651, 'a_thresh4': 22.205160244752875, 'beta': 0.552669420146198, 're_thresh1': 28.694992410731498, 're_thresh2': 44.707881453679306, 're_thresh3': 17.70458303901189, 're_thresh4': 17.842462307009395, 'beta_re': 0.1559826225374984, 'out_thresh1': 20.07041856997057, 'out_thresh2': 7.688074581581245, 'beta_out': 0.6357308126351245, 'do': 0.5853706575329012, 'lr': 0.0007419824830226382}. Best is trial 4 with value: 44.83888626098633.


Early stopping triggered.
Using 4 GPUs.


Epochs:  54%|▌| 27/5
[I 2024-03-10 05:05:38,144] Trial 5 finished with value: 49.071075439453125 and parameters: {'a_thresh1': 26.909726319894542, 'a_thresh2': 18.98858436922156, 'a_thresh3': 41.28996696376394, 'a_thresh4': 43.526139339300066, 'beta': 0.2167454927683674, 're_thresh1': 23.302432597846547, 're_thresh2': 39.660557321297816, 're_thresh3': 25.830462125575583, 're_thresh4': 7.492603152328584, 'beta_re': 0.5017993507777871, 'out_thresh1': 17.402950106910627, 'out_thresh2': 1.9762063382146358, 'beta_out': 0.37143061411398826, 'do': 0.6667822736870025, 'lr': 0.0004998905889294094}. Best is trial 4 with value: 44.83888626098633.


Early stopping triggered.
Using 4 GPUs.


Epochs:  26%|▎| 13/5
[I 2024-03-10 06:02:57,874] Trial 6 finished with value: 87.71710205078125 and parameters: {'a_thresh1': 22.165436443610293, 'a_thresh2': 40.208133697984636, 'a_thresh3': 12.090418113418545, 'a_thresh4': 47.03043460060559, 'beta': 0.41908908192115246, 're_thresh1': 30.36882285502488, 're_thresh2': 26.041784396254165, 're_thresh3': 14.565288273985644, 're_thresh4': 38.71010963242826, 'beta_re': 0.32365194431236455, 'out_thresh1': 40.62845057891607, 'out_thresh2': 22.395075610729904, 'beta_out': 0.8965574146367218, 'do': 0.09135796766712598, 'lr': 0.0006595730836563502}. Best is trial 4 with value: 44.83888626098633.


Early stopping triggered.
Using 4 GPUs.


Epochs:  24%|▏| 12/5
[I 2024-03-10 06:56:21,879] Trial 7 finished with value: 87.7152099609375 and parameters: {'a_thresh1': 47.191561032377166, 'a_thresh2': 49.57143874993244, 'a_thresh3': 44.39970010962288, 'a_thresh4': 5.627410071344547, 'beta': 0.8331716790838909, 're_thresh1': 25.379073334699044, 're_thresh2': 30.475885157859686, 're_thresh3': 31.887250958160298, 're_thresh4': 32.8662983693749, 'beta_re': 0.2405447083072131, 'out_thresh1': 30.177112359382598, 'out_thresh2': 49.611353131953564, 'beta_out': 0.41603858077737055, 'do': 0.7371674187893291, 'lr': 0.000772428907937475}. Best is trial 4 with value: 44.83888626098633.


Early stopping triggered.
Using 4 GPUs.


Epochs:  34%|▎| 17/5
[I 2024-03-10 08:10:08,736] Trial 8 finished with value: 43.63385772705078 and parameters: {'a_thresh1': 24.200347600950114, 'a_thresh2': 46.32642467924991, 'a_thresh3': 36.650658080300246, 'a_thresh4': 19.71176494555868, 'beta': 0.8691181037886607, 're_thresh1': 39.929965718111795, 're_thresh2': 28.08466516460287, 're_thresh3': 11.788136423451956, 're_thresh4': 19.659776077901203, 'beta_re': 0.18525695991378363, 'out_thresh1': 8.945316556418545, 'out_thresh2': 47.03096690361973, 'beta_out': 0.5859184925908621, 'do': 0.4266910698594458, 'lr': 0.000991169819709162}. Best is trial 8 with value: 43.63385772705078.


Early stopping triggered.
Using 4 GPUs.


Epochs:  16%|▏| 8/50
[I 2024-03-10 08:47:05,718] Trial 9 finished with value: 87.93296813964844 and parameters: {'a_thresh1': 24.9981791751783, 'a_thresh2': 25.364535078813528, 'a_thresh3': 10.461819666600679, 'a_thresh4': 35.121462374820595, 'beta': 0.7186791791190006, 're_thresh1': 31.034910671020608, 're_thresh2': 27.64506422417026, 're_thresh3': 48.408146811943965, 're_thresh4': 8.313370612331742, 'beta_re': 0.6076718443920066, 'out_thresh1': 11.753243444870332, 'out_thresh2': 36.26077286694322, 'beta_out': 0.4770502464560583, 'do': 0.22161023895332274, 'lr': 0.0007255005386902611}. Best is trial 8 with value: 43.63385772705078.


Early stopping triggered.
Using 4 GPUs.


Epochs:  24%|▏| 12/5
[I 2024-03-10 09:40:34,311] Trial 10 finished with value: 44.087158203125 and parameters: {'a_thresh1': 4.779988051100613, 'a_thresh2': 1.5307493202825242, 'a_thresh3': 31.538240904735822, 'a_thresh4': 14.646764059533977, 'beta': 0.5796490740941157, 're_thresh1': 45.88182121873736, 're_thresh2': 1.6983758297006304, 're_thresh3': 2.7519894059450962, 're_thresh4': 47.62171320051058, 'beta_re': 0.30676889565630905, 'out_thresh1': 1.609132353506432, 'out_thresh2': 48.64637106308699, 'beta_out': 0.11973585229847339, 'do': 0.4249760377735084, 'lr': 0.0009949127534094647}. Best is trial 8 with value: 43.63385772705078.


Early stopping triggered.
Using 4 GPUs.


Epochs:  30%|▎| 15/5
[I 2024-03-10 10:46:05,023] Trial 11 finished with value: 44.50565719604492 and parameters: {'a_thresh1': 1.7501297300449115, 'a_thresh2': 1.1794862342229244, 'a_thresh3': 32.97541536001435, 'a_thresh4': 16.26202763171532, 'beta': 0.6199364446115921, 're_thresh1': 49.63732412336238, 're_thresh2': 2.592479494687601, 're_thresh3': 2.4366618807727054, 're_thresh4': 27.78005030678579, 'beta_re': 0.3529854681744159, 'out_thresh1': 1.5946400066436963, 'out_thresh2': 47.51058682531974, 'beta_out': 0.10142580218495367, 'do': 0.4233039611616064, 'lr': 0.000998520976312836}. Best is trial 8 with value: 43.63385772705078.


Early stopping triggered.
Using 4 GPUs.


Epochs:  40%|▍| 20/5
[I 2024-03-10 12:12:22,269] Trial 12 finished with value: 43.31108093261719 and parameters: {'a_thresh1': 2.2215283080981543, 'a_thresh2': 15.413678790280017, 'a_thresh3': 31.771089543980573, 'a_thresh4': 14.135468308962125, 'beta': 0.7045797903009885, 're_thresh1': 49.764995863138864, 're_thresh2': 17.68025481979039, 're_thresh3': 2.6198189955959044, 're_thresh4': 25.07728811413309, 'beta_re': 0.11277589604275587, 'out_thresh1': 1.4965684367907546, 'out_thresh2': 40.89253865802889, 'beta_out': 0.10426228116794273, 'do': 0.3472058357760688, 'lr': 0.0009855678780376575}. Best is trial 12 with value: 43.31108093261719.


Early stopping triggered.
Using 4 GPUs.


Epochs:  24%|▏| 12/5
[I 2024-03-10 13:05:34,987] Trial 13 finished with value: 44.46614074707031 and parameters: {'a_thresh1': 33.788092703456655, 'a_thresh2': 16.064654832980196, 'a_thresh3': 49.80464094519537, 'a_thresh4': 5.251225433540068, 'beta': 0.6758501511231252, 're_thresh1': 42.8012304418542, 're_thresh2': 17.25362537101039, 're_thresh3': 7.721064113854732, 're_thresh4': 24.93138091992838, 'beta_re': 0.13059801439305888, 'out_thresh1': 7.477781984715629, 'out_thresh2': 39.64734555338618, 'beta_out': 0.6484484856057803, 'do': 0.30744229457666455, 'lr': 0.0008924954640343709}. Best is trial 12 with value: 43.31108093261719.


Early stopping triggered.
Using 4 GPUs.


Epochs:  20%|▏| 10/5
[I 2024-03-10 13:50:47,648] Trial 14 finished with value: 87.70569610595703 and parameters: {'a_thresh1': 10.40086692875532, 'a_thresh2': 15.438819148194314, 'a_thresh3': 24.936779121370886, 'a_thresh4': 26.09210537993621, 'beta': 0.7313664312683367, 're_thresh1': 49.964877529013506, 're_thresh2': 21.089323290590194, 're_thresh3': 8.882183546100823, 're_thresh4': 1.8883328499419534, 'beta_re': 0.12528301394009034, 'out_thresh1': 47.53634551796412, 'out_thresh2': 39.5738874485886, 'beta_out': 0.25205267975920376, 'do': 0.8806361712681727, 'lr': 0.00037176731394677516}. Best is trial 12 with value: 43.31108093261719.


Early stopping triggered.
Using 4 GPUs.


Epochs: 100%|█| 50/5
[I 2024-03-10 17:15:32,504] Trial 15 finished with value: 34.45215606689453 and parameters: {'a_thresh1': 49.608034142070736, 'a_thresh2': 23.474020768452473, 'a_thresh3': 35.891284104836785, 'a_thresh4': 12.731680295564324, 'beta': 0.8853033577925282, 're_thresh1': 38.16388540501069, 're_thresh2': 37.23104386628324, 're_thresh3': 8.967364821303914, 're_thresh4': 25.09339852812513, 'beta_re': 0.25405605844912327, 'out_thresh1': 6.009454952176458, 'out_thresh2': 31.724859077316292, 'beta_out': 0.6164189066353682, 'do': 0.5315362630460818, 'lr': 0.000861324176788757}. Best is trial 15 with value: 34.45215606689453.


Using 4 GPUs.


Epochs:  26%|▎| 13/5
[I 2024-03-10 18:13:15,955] Trial 16 finished with value: 87.79843139648438 and parameters: {'a_thresh1': 49.60317713599318, 'a_thresh2': 24.62014159095875, 'a_thresh3': 18.098440172475517, 'a_thresh4': 10.874550772946929, 'beta': 0.7494363047074147, 're_thresh1': 19.781142727970362, 're_thresh2': 38.963411903865484, 're_thresh3': 1.457493807349925, 're_thresh4': 28.077741763883065, 'beta_re': 0.3964784618426147, 'out_thresh1': 33.927696924596134, 'out_thresh2': 30.680943247650887, 'beta_out': 0.7371416389565941, 'do': 0.5474394310687647, 'lr': 0.0008935672251284855}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  26%|▎| 13/5
[I 2024-03-10 19:10:51,080] Trial 17 finished with value: 43.61077880859375 and parameters: {'a_thresh1': 38.74123818093319, 'a_thresh2': 10.701377328632912, 'a_thresh3': 1.5197066894783475, 'a_thresh4': 1.8034786633153992, 'beta': 0.48595472332449363, 're_thresh1': 35.19981483512364, 're_thresh2': 49.46422797910811, 're_thresh3': 22.070178374864422, 're_thresh4': 24.01269015653902, 'beta_re': 0.23138750196039454, 'out_thresh1': 4.470280513297259, 'out_thresh2': 32.959289141058214, 'beta_out': 0.26566363359078493, 'do': 0.28418728728583653, 'lr': 0.0002637629929249455}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  42%|▍| 21/5
[I 2024-03-10 20:41:19,606] Trial 18 finished with value: 44.54331970214844 and parameters: {'a_thresh1': 43.726193583345704, 'a_thresh2': 19.142468838402316, 'a_thresh3': 28.042823753692765, 'a_thresh4': 27.792765620894276, 'beta': 0.7648197695817869, 're_thresh1': 44.311732620938386, 're_thresh2': 35.20959915352219, 're_thresh3': 7.984217845976975, 're_thresh4': 32.100464432384655, 'beta_re': 0.25589766428594524, 'out_thresh1': 13.13827904456206, 'out_thresh2': 16.982511974868224, 'beta_out': 0.7843932801351392, 'do': 0.5199529575365112, 'lr': 0.000877814833401282}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  28%|▎| 14/5
[I 2024-03-10 21:43:03,429] Trial 19 finished with value: 87.71772766113281 and parameters: {'a_thresh1': 30.673650331674093, 'a_thresh2': 31.894420609862934, 'a_thresh3': 38.76881360391759, 'a_thresh4': 10.503363221491954, 'beta': 0.27538419330151287, 're_thresh1': 17.326841215139396, 're_thresh2': 8.145457736134272, 're_thresh3': 34.21166745718686, 're_thresh4': 12.535813350021455, 'beta_re': 0.39359265002216715, 'out_thresh1': 23.50447009463123, 'out_thresh2': 41.36610332781504, 'beta_out': 0.5706078871309855, 'do': 0.3166690004716969, 'lr': 0.0008192915808884389}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  44%|▍| 22/5
[I 2024-03-10 23:17:32,498] Trial 20 finished with value: 42.510154724121094 and parameters: {'a_thresh1': 12.80343501147184, 'a_thresh2': 8.386430807327844, 'a_thresh3': 20.585460578844724, 'a_thresh4': 10.522739390782384, 'beta': 0.8962866823904059, 're_thresh1': 1.1059941498379686, 're_thresh2': 21.00609672305942, 're_thresh3': 6.193232681413043, 're_thresh4': 21.77118230288342, 'beta_re': 0.10448938894625917, 'out_thresh1': 4.693475453919284, 'out_thresh2': 28.16775360429154, 'beta_out': 0.2030024754974741, 'do': 0.6374770280721956, 'lr': 0.0004009603289223086}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  42%|▍| 21/5
[I 2024-03-11 00:47:55,194] Trial 21 finished with value: 42.69572448730469 and parameters: {'a_thresh1': 11.70511072470589, 'a_thresh2': 8.712099797767939, 'a_thresh3': 20.117468751310376, 'a_thresh4': 10.225972310878692, 'beta': 0.8782301606287504, 're_thresh1': 5.9866711023469, 're_thresh2': 22.81767878763677, 're_thresh3': 5.621678055866667, 're_thresh4': 20.52799836882143, 'beta_re': 0.10668832831210102, 'out_thresh1': 5.078729244808299, 'out_thresh2': 27.191947567314863, 'beta_out': 0.19939605616767708, 'do': 0.617242212700601, 'lr': 0.0003675712861153853}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  82%|▊| 41/5
[I 2024-03-11 03:40:09,742] Trial 22 finished with value: 42.85721969604492 and parameters: {'a_thresh1': 13.3905039015296, 'a_thresh2': 7.788742728165858, 'a_thresh3': 17.1316089039353, 'a_thresh4': 8.31291568589048, 'beta': 0.8999524333134095, 're_thresh1': 2.6007001115722375, 're_thresh2': 22.789529632757713, 're_thresh3': 8.595448496030468, 're_thresh4': 21.210336891415544, 'beta_re': 0.1949446916923121, 'out_thresh1': 6.439065452064699, 'out_thresh2': 27.93242843241829, 'beta_out': 0.22109602163570308, 'do': 0.6408006393234775, 'lr': 0.0003321187547683428}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  50%|▌| 25/5
[I 2024-03-11 05:26:49,654] Trial 23 finished with value: 41.09369659423828 and parameters: {'a_thresh1': 9.673257609785292, 'a_thresh2': 4.97009983904792, 'a_thresh3': 21.703723883449772, 'a_thresh4': 2.352275879426511, 'beta': 0.822353206309521, 're_thresh1': 2.568135008306874, 're_thresh2': 34.05071049494633, 're_thresh3': 20.263496925827617, 're_thresh4': 12.310240771091964, 'beta_re': 0.2772506303338381, 'out_thresh1': 13.587040085252513, 'out_thresh2': 21.23915166446955, 'beta_out': 0.29857823046046517, 'do': 0.8390554351018624, 'lr': 0.0004446542478083647}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  56%|▌| 28/5
[I 2024-03-11 07:26:03,918] Trial 24 finished with value: 39.849178314208984 and parameters: {'a_thresh1': 7.69529804886125, 'a_thresh2': 4.864091203641536, 'a_thresh3': 22.043160347763816, 'a_thresh4': 1.899323350988574, 'beta': 0.7887730276489686, 're_thresh1': 1.0883220577704091, 're_thresh2': 33.84370248564919, 're_thresh3': 22.32440956065254, 're_thresh4': 12.731278192806638, 'beta_re': 0.2644770270367679, 'out_thresh1': 12.416020487471009, 'out_thresh2': 17.93878275253946, 'beta_out': 0.31346836610611556, 'do': 0.8955773461631109, 'lr': 0.00045493519676962876}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  54%|▌| 27/5
[I 2024-03-11 09:21:04,717] Trial 25 finished with value: 42.624568939208984 and parameters: {'a_thresh1': 7.742249397991662, 'a_thresh2': 5.7556982148105575, 'a_thresh3': 27.35361417749275, 'a_thresh4': 2.4027755231656656, 'beta': 0.7924910084586474, 're_thresh1': 14.845673235362732, 're_thresh2': 36.4513090230654, 're_thresh3': 21.24032291242146, 're_thresh4': 10.257703458039966, 'beta_re': 0.4311893538567403, 'out_thresh1': 12.697935971315278, 'out_thresh2': 17.072774767318514, 'beta_out': 0.30939032266320426, 'do': 0.892596863676414, 'lr': 0.00017825405551474697}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  16%|▏| 8/50
[I 2024-03-11 09:58:11,594] Trial 26 finished with value: 87.72759246826172 and parameters: {'a_thresh1': 17.28269529449394, 'a_thresh2': 12.444906454907766, 'a_thresh3': 23.028091132169322, 'a_thresh4': 5.727199362282809, 'beta': 0.6430330319626413, 're_thresh1': 9.646863791543446, 're_thresh2': 34.674872621465774, 're_thresh3': 21.614464431804603, 're_thresh4': 4.059916107304563, 'beta_re': 0.28657464772482555, 'out_thresh1': 23.2887670797899, 'out_thresh2': 12.390261529475406, 'beta_out': 0.3110182162323886, 'do': 0.8274033707435685, 'lr': 0.00048789823299227067}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  12%| | 6/50
[I 2024-03-11 10:27:03,257] Trial 27 finished with value: 87.93340301513672 and parameters: {'a_thresh1': 6.267541693878515, 'a_thresh2': 3.510444924678877, 'a_thresh3': 13.66394154353612, 'a_thresh4': 1.6531639085240073, 'beta': 0.7858221632689087, 're_thresh1': 4.288575863341207, 're_thresh2': 43.53345292699481, 're_thresh3': 27.327329607017617, 're_thresh4': 13.994999939010917, 'beta_re': 0.344620237771764, 'out_thresh1': 14.339927057611176, 'out_thresh2': 20.00856286696658, 'beta_out': 0.5344204339834484, 'do': 0.7999339308544238, 'lr': 2.3700393013581836e-05}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  48%|▍| 24/5
[I 2024-03-11 12:09:57,500] Trial 28 finished with value: 42.85895919799805 and parameters: {'a_thresh1': 16.10277475302737, 'a_thresh2': 22.17720010761043, 'a_thresh3': 6.127812683395529, 'a_thresh4': 20.48728943177118, 'beta': 0.7965678536357409, 're_thresh1': 13.84205259101915, 're_thresh2': 32.64841235081682, 're_thresh3': 32.32115382662957, 're_thresh4': 6.000118377096172, 'beta_re': 0.6201320169828316, 'out_thresh1': 14.892709104996523, 'out_thresh2': 9.708819392321141, 'beta_out': 0.7142591478880544, 'do': 0.7489020256384038, 'lr': 0.0004339947002328961}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  54%|▌| 27/5
[I 2024-03-11 14:05:26,933] Trial 29 finished with value: 44.01004409790039 and parameters: {'a_thresh1': 7.937816176823279, 'a_thresh2': 29.43040136100945, 'a_thresh3': 28.144812303704768, 'a_thresh4': 6.552735638799828, 'beta': 0.6658537114870784, 're_thresh1': 6.853338493809807, 're_thresh2': 42.76143871398238, 're_thresh3': 17.690061875252134, 're_thresh4': 17.57105254371804, 'beta_re': 0.22245958916246183, 'out_thresh1': 9.501668151586273, 'out_thresh2': 23.506230081228587, 'beta_out': 0.4346847649668676, 'do': 0.8126877456357683, 'lr': 0.0002345382878289218}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  20%|▏| 10/5
[I 2024-03-11 14:50:51,718] Trial 30 finished with value: 87.77084350585938 and parameters: {'a_thresh1': 28.022100102059937, 'a_thresh2': 4.788081564467162, 'a_thresh3': 23.606044880261837, 'a_thresh4': 16.478017167290048, 'beta': 0.8245609695126273, 're_thresh1': 21.839037073322014, 're_thresh2': 39.005134549057125, 're_thresh3': 14.523207244345848, 're_thresh4': 11.391944095564739, 'beta_re': 0.4612723497322878, 'out_thresh1': 20.94184004810604, 'out_thresh2': 35.23796439179613, 'beta_out': 0.5303532359928302, 'do': 0.48387292381033425, 'lr': 0.0005909736699517914}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  24%|▏| 12/5
[I 2024-03-11 15:44:20,730] Trial 31 finished with value: 87.7149429321289 and parameters: {'a_thresh1': 15.881333019516275, 'a_thresh2': 11.031098257864773, 'a_thresh3': 21.839282810736815, 'a_thresh4': 12.633396610524876, 'beta': 0.8984826407173859, 're_thresh1': 1.7099318062641933, 're_thresh2': 31.693794466970964, 're_thresh3': 11.739107299644878, 're_thresh4': 16.31145617282642, 'beta_re': 0.18095915660558906, 'out_thresh1': 10.355416575926114, 'out_thresh2': 29.915260803028467, 'beta_out': 0.15637326903518603, 'do': 0.6698684589664938, 'lr': 0.00044581041003327083}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  20%|▏| 10/5
[I 2024-03-11 16:29:42,750] Trial 32 finished with value: 44.96511459350586 and parameters: {'a_thresh1': 10.414920847986487, 'a_thresh2': 7.341098481405486, 'a_thresh3': 16.0566335753521, 'a_thresh4': 3.413078445618808, 'beta': 0.841064248808876, 're_thresh1': 1.0355063800250315, 're_thresh2': 47.80578089776592, 're_thresh3': 23.036782551354644, 're_thresh4': 30.060780234854533, 'beta_re': 0.27809937235216464, 'out_thresh1': 5.282279357190771, 'out_thresh2': 20.055818959444245, 'beta_out': 0.32341820829570883, 'do': 0.8411846561077183, 'lr': 0.0002968352699880563}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  12%| | 6/50
[I 2024-03-11 16:58:36,313] Trial 33 finished with value: 87.76870727539062 and parameters: {'a_thresh1': 20.292751656332833, 'a_thresh2': 35.33512567150251, 'a_thresh3': 26.332862044345696, 'a_thresh4': 8.219730202812514, 'beta': 0.10095681630521436, 're_thresh1': 4.754672298033874, 're_thresh2': 32.26643427799211, 're_thresh3': 28.959490281097203, 're_thresh4': 22.210735230500323, 'beta_re': 0.7423892379720081, 'out_thresh1': 8.487430388984256, 'out_thresh2': 24.06708267909972, 'beta_out': 0.2601968477526517, 'do': 0.698623330284654, 'lr': 0.0005711119778006525}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  16%|▏| 8/50
[I 2024-03-11 17:35:53,333] Trial 34 finished with value: 87.87390899658203 and parameters: {'a_thresh1': 13.67369134497441, 'a_thresh2': 12.35568176715222, 'a_thresh3': 20.554941985594947, 'a_thresh4': 9.131388130172574, 'beta': 0.8383633618206917, 're_thresh1': 10.756422433116484, 're_thresh2': 35.74936828194322, 're_thresh3': 40.040050974674784, 're_thresh4': 14.546091348334748, 'beta_re': 0.3764687325478694, 'out_thresh1': 15.63361990015957, 'out_thresh2': 26.570356574151866, 'beta_out': 0.18840555744194495, 'do': 0.5660460952432205, 'lr': 0.00041842676151479703}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  10%| | 5/50
[I 2024-03-11 18:00:40,524] Trial 35 finished with value: 87.71527099609375 and parameters: {'a_thresh1': 5.040976732141543, 'a_thresh2': 4.957371202750347, 'a_thresh3': 35.820026109678246, 'a_thresh4': 4.168948125600148, 'beta': 0.773098711123522, 're_thresh1': 8.699980492770937, 're_thresh2': 24.531743581485586, 're_thresh3': 11.976311508393534, 're_thresh4': 10.232095632091088, 'beta_re': 0.533372337989437, 'out_thresh1': 18.958589807433277, 'out_thresh2': 20.23873813579569, 'beta_out': 0.33657527960387534, 'do': 0.7716418272734784, 'lr': 0.0006671334540372325}. Best is trial 15 with value: 34.45215606689453.


Early stopping triggered.
Using 4 GPUs.


Epochs:  66%|▋| 33/5
[I 2024-03-11 20:20:44,362] Trial 36 finished with value: 32.09950637817383 and parameters: {'a_thresh1': 38.99096656859945, 'a_thresh2': 9.55354093570784, 'a_thresh3': 29.61925710354054, 'a_thresh4': 1.262112186662491, 'beta': 0.8372435142513912, 're_thresh1': 34.56455694133143, 're_thresh2': 40.996422516232215, 're_thresh3': 17.79099774504679, 're_thresh4': 35.06981317583107, 'beta_re': 0.17041894668337929, 'out_thresh1': 3.4083313097973784, 'out_thresh2': 14.66397286545895, 'beta_out': 0.820485156794482, 'do': 0.5036782978656426, 'lr': 0.0004990367932184232}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  26%|▎| 13/5
[I 2024-03-11 21:18:30,252] Trial 37 finished with value: 51.21992874145508 and parameters: {'a_thresh1': 41.26957757573465, 'a_thresh2': 20.125375927639837, 'a_thresh3': 30.39579674262807, 'a_thresh4': 29.98959520254087, 'beta': 0.5011148070073608, 're_thresh1': 36.181895997145915, 're_thresh2': 40.21170482167835, 're_thresh3': 18.88864873101405, 're_thresh4': 36.47955631179465, 'beta_re': 0.8804429863780123, 'out_thresh1': 11.421061314571512, 'out_thresh2': 6.39503798213299, 'beta_out': 0.7786969401029628, 'do': 0.02623755348847534, 'lr': 0.0005073160123852932}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  60%|▌| 30/5
[I 2024-03-11 23:26:06,921] Trial 38 finished with value: 39.46993637084961 and parameters: {'a_thresh1': 45.35719982300018, 'a_thresh2': 27.35270341873757, 'a_thresh3': 44.876505245524754, 'a_thresh4': 1.4319808984410778, 'beta': 0.817136155060782, 're_thresh1': 33.65626847992245, 're_thresh2': 45.46952279599137, 're_thresh3': 17.30325936514068, 're_thresh4': 42.325750010634835, 'beta_re': 0.17516598025705826, 'out_thresh1': 27.35497584467963, 'out_thresh2': 14.958737889627967, 'beta_out': 0.8505319866479802, 'do': 0.472235880787672, 'lr': 0.0005544836872509931}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  40%|▍| 20/5
[I 2024-03-12 00:52:34,646] Trial 39 finished with value: 46.87090301513672 and parameters: {'a_thresh1': 45.096006068517525, 'a_thresh2': 28.137458171096455, 'a_thresh3': 44.789886802984874, 'a_thresh4': 7.250522659996815, 'beta': 0.3464992433643109, 're_thresh1': 34.03643139224673, 're_thresh2': 46.446276656806766, 're_thresh3': 17.2965270018265, 're_thresh4': 44.821235026546745, 'beta_re': 0.15987445650228574, 'out_thresh1': 28.666225615587248, 'out_thresh2': 14.061926383789134, 'beta_out': 0.8076426897037572, 'do': 0.5075869329478316, 'lr': 0.000688168309742169}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  68%|▋| 34/5
[I 2024-03-12 03:17:12,841] Trial 40 finished with value: 41.59700393676758 and parameters: {'a_thresh1': 49.32183998903734, 'a_thresh2': 35.46563418502602, 'a_thresh3': 48.95764712439068, 'a_thresh4': 1.00649986398572, 'beta': 0.7184012555054857, 're_thresh1': 39.646691219718534, 're_thresh2': 41.712286127835306, 're_thresh3': 25.022720464433103, 're_thresh4': 40.132836897850474, 'beta_re': 0.22180790371996328, 'out_thresh1': 40.24385017665759, 'out_thresh2': 11.425145445852804, 'beta_out': 0.8705375269313135, 'do': 0.23833422121337094, 'lr': 0.0006132455039681554}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  66%|▋| 33/5
[I 2024-03-12 05:37:36,593] Trial 41 finished with value: 37.479671478271484 and parameters: {'a_thresh1': 43.146982559707936, 'a_thresh2': 27.548632050598883, 'a_thresh3': 41.9004749078835, 'a_thresh4': 3.407188346163583, 'beta': 0.821160243718463, 're_thresh1': 32.41312304747817, 're_thresh2': 45.523594699133945, 're_thresh3': 19.575989541525768, 're_thresh4': 43.39492338401966, 'beta_re': 0.2593010746832306, 'out_thresh1': 31.5978042301923, 'out_thresh2': 16.88518003195665, 'beta_out': 0.8415385784496954, 'do': 0.46856821269077, 'lr': 0.0005440896081910605}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs: 100%|█| 50/5
[I 2024-03-12 09:03:54,082] Trial 42 finished with value: 35.909461975097656 and parameters: {'a_thresh1': 41.581367041765574, 'a_thresh2': 27.963602815201092, 'a_thresh3': 40.599144791555744, 'a_thresh4': 4.146743145992078, 'beta': 0.8540398396663051, 're_thresh1': 28.31588683681043, 're_thresh2': 45.63599393571328, 're_thresh3': 16.181244972488262, 're_thresh4': 44.05691247050478, 'beta_re': 0.19548290334718765, 'out_thresh1': 31.836524309764346, 'out_thresh2': 16.740110414040323, 'beta_out': 0.8556785672540197, 'do': 0.3833783899122988, 'lr': 0.0005082163166903468}. Best is trial 36 with value: 32.09950637817383.


Using 4 GPUs.


Epochs:  34%|▎| 17/5
[I 2024-03-12 10:18:19,078] Trial 43 finished with value: 44.14511489868164 and parameters: {'a_thresh1': 36.20180244176057, 'a_thresh2': 27.294946847791756, 'a_thresh3': 40.936103655738634, 'a_thresh4': 4.032160007659723, 'beta': 0.8401766952377326, 're_thresh1': 29.356872795921767, 're_thresh2': 46.58044188240385, 're_thresh3': 16.627128049246174, 're_thresh4': 43.54061207924587, 'beta_re': 0.18762380126715436, 'out_thresh1': 32.79755391450226, 'out_thresh2': 5.282700456396617, 'beta_out': 0.8499721097357169, 'do': 0.38304006613796737, 'lr': 0.0005551757493547388}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  30%|▎| 15/5
[I 2024-03-12 11:24:18,446] Trial 44 finished with value: 45.33736038208008 and parameters: {'a_thresh1': 40.88336433228423, 'a_thresh2': 36.12750777285475, 'a_thresh3': 46.79710357658629, 'a_thresh4': 5.676642599307306, 'beta': 0.8533623585588545, 're_thresh1': 32.824391866672336, 're_thresh2': 45.471441511227454, 're_thresh3': 12.919949188757636, 're_thresh4': 35.50311233458892, 'beta_re': 0.15237997898744632, 'out_thresh1': 35.445788139280324, 'out_thresh2': 14.747298268232223, 'beta_out': 0.837895413238038, 'do': 0.47231837268849175, 'lr': 0.000530400800568672}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  72%|▋| 36/5
[I 2024-03-12 13:56:59,397] Trial 45 finished with value: 33.2998046875 and parameters: {'a_thresh1': 46.35637595439344, 'a_thresh2': 24.325660190396917, 'a_thresh3': 36.43418279860869, 'a_thresh4': 12.516166069252646, 'beta': 0.5984690607889351, 're_thresh1': 26.932006458995076, 're_thresh2': 49.916102364594366, 're_thresh3': 15.276852114112556, 're_thresh4': 47.40568633813357, 'beta_re': 0.3112286703685595, 'out_thresh1': 26.946211205602367, 'out_thresh2': 10.03879083627657, 'beta_out': 0.8962058580238874, 'do': 0.3764478474507137, 'lr': 0.0007738146073537622}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  46%|▍| 23/5
[I 2024-03-12 15:35:47,359] Trial 46 finished with value: 45.69802474975586 and parameters: {'a_thresh1': 42.13826980107903, 'a_thresh2': 23.559660127598193, 'a_thresh3': 36.26198444708266, 'a_thresh4': 19.44481689636575, 'beta': 0.5453161994978813, 're_thresh1': 25.837315099475717, 're_thresh2': 49.880675799283495, 're_thresh3': 15.104282902935289, 're_thresh4': 49.88219244161923, 'beta_re': 0.31509665855934743, 'out_thresh1': 38.28019602568013, 'out_thresh2': 8.976403264528662, 'beta_out': 0.6791633115303046, 'do': 0.3769985775447732, 'lr': 0.0008073803705156071}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:  46%|▍| 23/5
[I 2024-03-12 17:14:44,158] Trial 47 finished with value: 43.43788146972656 and parameters: {'a_thresh1': 47.281999669785364, 'a_thresh2': 22.20891766453178, 'a_thresh3': 39.25762339952405, 'a_thresh4': 13.029488675293893, 'beta': 0.6025089679129183, 're_thresh1': 27.834788806784587, 're_thresh2': 41.321576142637625, 're_thresh3': 10.524677736442207, 're_thresh4': 46.63183325258827, 'beta_re': 0.32055718065944055, 'out_thresh1': 30.549286784421817, 'out_thresh2': 3.6656110001545983, 'beta_out': 0.8940690488152928, 'do': 0.4316356640096315, 'lr': 0.0007224862522925233}. Best is trial 36 with value: 32.09950637817383.


Early stopping triggered.
Using 4 GPUs.


Epochs:   8%| | 4/50