In [1]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch.nn.functional as F

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import itertools

In [2]:
class FullbandModel(nn.Module):
    """Full-band stage: LayerNorm -> Linear -> leaky integrate-and-fire neuron.

    The whole spectrogram is flattened to one vector per sample, normalised,
    projected to ``hidden_dim`` units and then fed, unchanged, to the spiking
    neuron for ``num_steps`` simulation steps.
    """

    def __init__(self, freq_bins, time_bins, hidden_dim, beta=0.9):
        """
        Args:
        - freq_bins: Number of frequency bins in the spectrogram.
        - time_bins: Number of time bins in the spectrogram.
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - beta: Decay parameter for the LIF neuron.
        """
        super().__init__()

        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.input_dim = freq_bins * time_bins  # Flattened input feature size
        self.hidden_dim = hidden_dim

        # Layers (construction order kept so seeded initialisation is unchanged)
        self.normalization = nn.LayerNorm(self.input_dim)  # Normalize input features
        self.spikingneuron = snn.Leaky(beta=beta)          # Leaky Integrate-and-Fire neuron
        self.linear = nn.Linear(self.input_dim, self.hidden_dim)  # Linear transformation

    def forward(self, x, num_steps=10):
        """
        Run the spiking neuron for ``num_steps`` steps on a constant input current.

        Args:
        - x: Input tensor of shape (batch_size, freq_bins, time_bins). It does not
          have to be contiguous.
        - num_steps: Number of time steps for spiking neuron simulation.
        Returns:
        - spk_rec: Spiking activity across timesteps (num_steps, batch_size, hidden_dim).
        - mem_rec: Membrane potential across timesteps (num_steps, batch_size, hidden_dim).
        Raises:
        - ValueError: if ``x`` is not 3-dimensional (from the shape unpacking).
        - AssertionError: if the spectrogram size differs from the configured one.
        """
        batch_size, freq_bins, time_bins = x.shape
        assert freq_bins == self.freq_bins and time_bins == self.time_bins, \
            "Input dimensions must match model initialization dimensions."

        # reshape (not view) so transposed / sliced inputs are accepted as well
        x = x.reshape(batch_size, -1)  # Shape: (batch_size, input_dim)
        x = self.normalization(x)

        # The input current is identical at every step, so compute it once
        # instead of re-running the linear layer inside the loop.
        cur = self.linear(x)  # Shape: (batch_size, hidden_dim)

        # Membrane potential starts at zero, matching the current's dtype and device
        mem = torch.zeros_like(cur)

        spk_rec = []
        mem_rec = []
        for _ in range(num_steps):
            spk, mem = self.spikingneuron(cur, mem)
            spk_rec.append(spk)
            mem_rec.append(mem)

        # Stack outputs across timesteps -> (num_steps, batch_size, hidden_dim)
        spk_rec = torch.stack(spk_rec, dim=0)
        mem_rec = torch.stack(mem_rec, dim=0)

        return spk_rec, mem_rec

In [3]:
class SubbandModel(nn.Module):
    """Sub-band stage: LayerNorm -> Linear -> leaky integrate-and-fire neuron."""

    def __init__(self, hidden_dim, num_steps, beta=0.9, subband_dim=None):
        """
        Args:
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - num_steps: Default number of timesteps for spiking neuron simulation.
        - beta: Decay parameter for the LIF neuron.
        - subband_dim: Width of the sub-band input. When given, the LayerNorm and
          Linear layers are created here, so they are part of ``parameters()``
          (and follow ``.to(device)``) before any forward pass. When ``None``
          the layers are created on the first ``forward`` call instead; in that
          case run one forward pass *before* building an optimizer, otherwise
          the optimizer will not see these parameters.
        """
        super().__init__()

        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.spikingneuron = snn.Leaky(beta=beta)  # Leaky Integrate-and-Fire neuron

        # Placeholders; replaced by real modules eagerly (below) or lazily (forward)
        self.normalization = None
        self.linear = None
        if subband_dim is not None:
            self.normalization = nn.LayerNorm(subband_dim)
            self.linear = nn.Linear(subband_dim, self.hidden_dim)

    def forward(self, x, num_steps=None):
        """
        Run the spiking neuron for ``num_steps`` steps on a constant input current.

        Args:
        - x: Input tensor of shape (batch_size, subband_dim).
        - num_steps: Number of time steps; defaults to the value given at construction.
        Returns:
        - spk_rec: Spiking activity across timesteps (num_steps, batch_size, hidden_dim).
        - mem_rec: Membrane potential across timesteps (num_steps, batch_size, hidden_dim).
        """
        if num_steps is None:
            num_steps = self.num_steps

        _, subband_dim = x.shape

        # Lazy fallback: create whichever layer is still missing on x's device
        if getattr(self, 'normalization', None) is None:
            self.normalization = nn.LayerNorm(subband_dim).to(x.device)
        if getattr(self, 'linear', None) is None:
            self.linear = nn.Linear(subband_dim, self.hidden_dim).to(x.device)

        x = self.normalization(x)

        # The input current does not change between steps; compute it once
        cur = self.linear(x)  # Shape: (batch_size, hidden_dim)

        # Membrane potential starts at zero, matching the current's dtype and device
        mem = torch.zeros_like(cur)

        spk_rec = []
        mem_rec = []
        for _ in range(num_steps):
            spk, mem = self.spikingneuron(cur, mem)
            spk_rec.append(spk)
            mem_rec.append(mem)

        # Stack the recorded values across timesteps -> (num_steps, batch_size, hidden_dim)
        spk_rec = torch.stack(spk_rec, dim=0)
        mem_rec = torch.stack(mem_rec, dim=0)

        return spk_rec, mem_rec


In [4]:
# def frequency_partition(spectrogram):
#     """
#     Splits the input spectrogram tensor into 3 sub-tensors of different frequencies.

#     Args:
#     - spectrogram: Input tensor of shape (128, 860) (frequency bins, time bins).

#     Returns:
#     - A tuple containing three sub-tensors:
#         - low_freq: Frequencies from [0, 16).
#         - mid_freq: Frequencies from [16, 64).
#         - high_freq: Frequencies from [64, 128).
#     """
    
#     if spectrogram.shape != (128, 860):
#         raise ValueError(f"Input spectrogram must have shape (128, 860) - ({spectrogram.shape})")

#     low_freq = spectrogram[:16, :]
#     mid_freq = spectrogram[16:64, :]
#     high_freq = spectrogram[64:128, :]

#     return low_freq, mid_freq, high_freq

In [5]:
def frequency_partition(spectrogram, num_subbands):
    """
    Splits the input tensor into exactly ``num_subbands`` chunks along dim 1.

    When the width is divisible by ``num_subbands`` every chunk has the same
    size. Otherwise the leading chunks are one column wider, so the number of
    chunks is still ``num_subbands`` (a plain ``torch.split`` with a floor
    chunk size would return extra chunks in that case).

    Args:
    - spectrogram: Input tensor of shape (batch_size, hidden_dim).
    - num_subbands: Number of subbands to split the hidden_dim into.
    Returns:
    - subbands: Tuple of ``num_subbands`` tensors (views of the input), each of
      shape (batch_size, subband_size).
    Raises:
    - ValueError: if the input is not 2-D, or ``num_subbands`` is not in
      ``[1, hidden_dim]``.
    """
    _, hidden_dim = spectrogram.shape
    if not 1 <= num_subbands <= hidden_dim:
        raise ValueError(
            f"num_subbands must be between 1 and {hidden_dim}, got {num_subbands}"
        )

    # Split along the hidden_dim axis
    return torch.tensor_split(spectrogram, num_subbands, dim=1)


def frequency_reconstruct(subbands):
    """
    Joins processed subbands back together along dim 1.

    Args:
    - subbands: Sequence of tensors that agree in every dimension except dim 1,
      e.g. each of shape (batch_size, subband_size, time_bins).
    Returns:
    - Tensor whose dim 1 is the sum of the inputs' dim-1 sizes,
      e.g. (batch_size, frequency_bins, time_bins).
    """
    return torch.cat(subbands, dim=1)

In [6]:
class DeepFilterLayer(nn.Module):
    """Small 2-D convolutional filter: input conv -> dilated conv stack -> output conv."""

    def __init__(self, in_channels=1, num_filters=64, kernel_size=3, dilation_rates=(1, 2, 4)):
        """
        Args:
        in_channels (int): Number of input (and output) channels. Default is 1.
        num_filters (int): Channel width of the intermediate convolutions. Default is 64.
        kernel_size (int): Size of the convolutional kernel. Default is 3.
        dilation_rates (tuple): One dilated convolution is created per entry. Default is (1, 2, 4).
        """
        super().__init__()

        edge_pad = kernel_size // 2

        # Lift the input to `num_filters` channels
        self.input_conv = nn.Conv2d(
            in_channels, num_filters, kernel_size=kernel_size, padding=edge_pad
        )

        # One dilated convolution per rate; padding grows with the dilation
        dilated_stack = []
        for rate in dilation_rates:
            dilated_stack.append(
                nn.Conv2d(
                    num_filters, num_filters, kernel_size=kernel_size,
                    dilation=rate, padding=((kernel_size - 1) * rate) // 2
                )
            )
        self.dilated_convs = nn.ModuleList(dilated_stack)

        # Project back down to the input channel count
        self.output_conv = nn.Conv2d(
            num_filters, in_channels, kernel_size=kernel_size, padding=edge_pad
        )

    def forward(self, x):
        """
        Apply the convolution stack.

        Args:
            x (Tensor): Either (batch_size, freq_bins, time_bins), in which case a
                channel axis is inserted at dim 1, or an already 4-D tensor.

        Returns:
            Tensor: Output of the final convolution with dim 1 squeezed
            (a no-op unless that dimension has size 1).
        """
        features = x.unsqueeze(1) if x.dim() == 3 else x

        features = F.relu(self.input_conv(features))

        # Plain sequential stack: each dilated conv feeds the next, no skip paths
        for conv in self.dilated_convs:
            features = F.relu(conv(features))

        projected = self.output_conv(features)
        return projected.squeeze(1)


In [7]:
class IntegratedModel(nn.Module):
    """FullbandModel -> per-subband SubbandModel -> per-subband DeepFilterLayer -> resize."""

    def __init__(self, freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta=0.9, num_filters=64):
        """
        Integrated model with FullbandModel, SubbandModels, and per-subband DeepFilteringLayers.
        Args:
        - freq_bins: Number of frequency bins in the spectrogram.
        - time_bins: Number of time bins in the spectrogram.
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - num_steps: Number of timesteps for spiking neuron simulation.
        - num_subbands: Number of subbands the full-band output is split into.
          Must divide ``hidden_dim``.
        - beta: Decay parameter for the LIF neuron.
        - num_filters: Channel width of each DeepFilterLayer.
        Raises:
        - ValueError: if ``num_subbands`` is < 1 or does not divide ``hidden_dim``.
        """
        super().__init__()

        if num_subbands < 1 or hidden_dim % num_subbands != 0:
            raise ValueError(
                f"hidden_dim ({hidden_dim}) must be divisible by num_subbands ({num_subbands})"
            )

        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.num_subbands = num_subbands

        # The sub-bands are slices of the full-band *hidden* output (see forward),
        # so their width derives from hidden_dim, not from freq_bins.
        subband_dim = hidden_dim // num_subbands

        # Fullband model
        self.fullband_model = FullbandModel(freq_bins, time_bins, hidden_dim, beta)

        # Subband models
        self.subband_models = nn.ModuleList([
            SubbandModel(hidden_dim, num_steps, beta)
            for _ in range(num_subbands)
        ])
        # Create each sub-band model's layers now rather than on the first forward
        # pass, so they are already included in `parameters()` when an optimizer is
        # built and are moved by `.to(device)` together with the rest of the model.
        for subband_model in self.subband_models:
            subband_model.normalization = nn.LayerNorm(subband_dim)
            subband_model.linear = nn.Linear(subband_dim, hidden_dim)

        # Per-subband Deep Filtering Layers
        self.deep_filtering_layers = nn.ModuleList([
            DeepFilterLayer(in_channels=1, num_filters=num_filters)
            for _ in range(num_subbands)
        ])

    def forward(self, x, num_steps, clean_time_bins):
        """
        Forward pass through the IntegratedModel.

        Args:
            x (Tensor): Input tensor of shape (batch_size, freq_bins, time_bins).
            num_steps (int): Number of timesteps for spiking neuron simulation.
            clean_time_bins (int): Time bins of the clean tensor for resizing.

        Returns:
            Tensor: Filtered output of shape (batch_size, freq_bins, clean_time_bins).
        """
        # Fullband processing; only the spikes of the final step are used
        fullband_output, _ = self.fullband_model(x, num_steps)

        # Subband processing: split (batch, hidden_dim) along the hidden axis and
        # keep each sub-band model's final-step spikes, shape (batch, hidden_dim)
        subbands = frequency_partition(fullband_output[-1], self.num_subbands)
        subband_outputs = [
            self.subband_models[i](subband, num_steps=num_steps)[0][-1]
            for i, subband in enumerate(subbands)
        ]

        # Per-subband deep filtering on (batch, 1, hidden_dim) inputs
        filtered_subbands = [
            self.deep_filtering_layers[i](subband_output.unsqueeze(1))
            for i, subband_output in enumerate(subband_outputs)
        ]

        # Join the sub-bands along the last axis: (batch, 1, num_subbands * hidden_dim)
        concatenated_output = torch.cat(filtered_subbands, dim=2)

        # Rearrange to (batch, channels=1, joined_axis, 1) for 2-D interpolation
        concatenated_output = concatenated_output.permute(0, 2, 1).unsqueeze(1)

        # Resize to the target spectrogram size
        filtered_output = F.interpolate(
            concatenated_output, size=(self.freq_bins, clean_time_bins), mode='bilinear', align_corners=True
        )

        # Drop the channel axis -> (batch_size, freq_bins, clean_time_bins)
        return filtered_output.squeeze(1)

In [8]:
def si_sdr_loss(clean, enhanced, eps=1e-8):
    """
    Negative Scale-Invariant Signal-to-Distortion Ratio, averaged over all
    leading dimensions. The ratio is computed along the last axis.

    Args:
    - clean: Ground truth clean signal (..., time_steps).
    - enhanced: Enhanced (predicted) signal, same shape as ``clean``.
    - eps: Small value added to the energies to avoid division by zero / log(0).
    Returns:
    - Scalar loss value (negative mean SI-SDR in dB).
    """
    # Remove the mean of each signal along the last axis
    target = clean - clean.mean(dim=-1, keepdim=True)
    estimate = enhanced - enhanced.mean(dim=-1, keepdim=True)

    # Optimal scaling of the target towards the estimate
    target_energy = target.pow(2).sum(dim=-1, keepdim=True) + eps
    gain = (target * estimate).sum(dim=-1, keepdim=True) / target_energy

    # Split the estimate into its target-aligned part and the residual
    scaled_target = gain * target
    residual = estimate - scaled_target

    signal_power = scaled_target.pow(2).sum(dim=-1) + eps
    residual_power = residual.pow(2).sum(dim=-1) + eps
    ratio_db = 10 * torch.log10(signal_power / residual_power)

    # Higher SI-SDR is better, so the loss is its negative
    return -ratio_db.mean()

def composite_loss(filtered_output, clean_tensor, 
                   filtered_output_complex=None, clean_tensor_complex=None, 
                   alpha=0.5, beta=0.3, p=1, target_metric=4.5):
    """
    Composite loss: ``alpha * SI-SDR loss + beta * proxy-metric loss + frequency loss``.

    The two spectral terms are only active when BOTH complex spectrograms are
    supplied; otherwise they contribute 0 and the result is ``alpha * SI-SDR loss``.

    Args:
    - filtered_output: Enhanced signal passed to ``si_sdr_loss``.
    - clean_tensor: Clean reference signal passed to ``si_sdr_loss``.
    - filtered_output_complex: Complex spectrogram of enhanced signal (optional).
    - clean_tensor_complex: Complex spectrogram of clean signal (optional).
    - alpha: Weight for SI-SDR loss.
    - beta: Weight for the proxy-metric term.
    - p: Exponent applied to the magnitudes (dynamic range compression).
    - target_metric: Value the spectral distance is regressed towards in the
      proxy-metric term.
    Returns:
    - Combined loss value. The frequency loss always has weight 1.
    """
    # SI-SDR Loss
    si_sdr = si_sdr_loss(clean_tensor, filtered_output)

    if filtered_output_complex is not None and clean_tensor_complex is not None:
        # Compressed-magnitude distance (shared by both spectral terms)
        magnitude_loss = torch.mean(
            torch.abs(clean_tensor_complex.abs() ** p - filtered_output_complex.abs() ** p)
        )

        # Phase is circular: take the angle of clean * conj(enhanced) so the
        # difference is wrapped into (-pi, pi]. Subtracting the raw angles would
        # penalise two almost identical phases on either side of +/-pi by ~2*pi.
        phase_difference = torch.angle(clean_tensor_complex * torch.conj(filtered_output_complex))
        phase_loss = torch.mean(torch.abs(phase_difference))

        # Proxy MetricGAN+ Generator Loss
        # NOTE(review): this pulls a *distance* (lower is better) towards
        # `target_metric` rather than towards 0 - confirm this is intended.
        gen_loss = magnitude_loss + phase_loss
        proxy_metric_loss = (gen_loss - target_metric) ** 2

        # Frequency-Domain Loss: magnitude distance plus complex distance
        freq_loss = magnitude_loss + torch.mean(
            torch.abs(filtered_output_complex - clean_tensor_complex)
        )
    else:
        proxy_metric_loss = 0
        freq_loss = 0

    # Composite Loss
    total_loss = alpha * si_sdr + beta * proxy_metric_loss + freq_loss
    return total_loss

In [9]:
# Resample clean_tensor to match noisy_tensor's time bins
def resample_tensor(clean_tensor, target_time_bins):
    """
    Bilinearly resizes a batch of spectrograms along the time axis.

    Args:
        clean_tensor (torch.Tensor): Tensor of shape [batch, freq_bins, time_bins].
        target_time_bins (int): Target number of time bins.
    Returns:
        torch.Tensor: Resampled tensor of shape [batch, freq_bins, target_time_bins];
        the frequency axis keeps its size.
    """
    # Unpacking also enforces a 3-D input
    _, freq_bins, _ = clean_tensor.shape

    # F.interpolate in bilinear mode expects (batch, channel, H, W)
    with_channel = clean_tensor.unsqueeze(1)
    resized = F.interpolate(
        with_channel,
        size=(freq_bins, target_time_bins),
        mode="bilinear",
        align_corners=False,
    )
    return resized.squeeze(1)


In [10]:
import os
import torch
import numpy as np
from sklearn.model_selection import train_test_split

# Paths to the feature (noisy) and label (clean) directories
feature_dir = "E:/CS541 - Deep Learning/noisy_audio_np"
label_dir = "E:/CS541 - Deep Learning/clean_audio_np"

# Load all feature and label file names. Both listings are sorted so that the
# i-th noisy file is paired with the i-th clean file
# (presumably the two directories use matching file names - verify).
feature_files = sorted(os.listdir(feature_dir))
label_files = sorted(os.listdir(label_dir))

# Check the pairing assumption before anything is derived from the listings
assert len(feature_files) == len(label_files), "Mismatch between feature and label file counts!"

indices = torch.arange(len(feature_files))

# 80/20 train/test split (no validation set). The seed is fixed so that the
# same files land in the test set on every Restart & Run All.
split_seed = 42
train_features, test_features, train_labels, test_labels, train_indices, test_indices = train_test_split(
    feature_files, label_files, indices, test_size=0.2, random_state=split_seed
)

In [11]:
# from tqdm import tqdm  # For tracking progress

# # Model parameters
# freq_bins, time_bins = 128, 860     # Based on the dimensions of your spectrograms
# hidden_dims = [32,64,128,256]       # Hidden dimension for spiking neurons
# num_steps_list = [10,15,25,100]     # Number of timesteps for spiking neurons
# num_subbands_list = [4]             # Number of subbands (adjusted for even division)
# betas = [0.1,0.5,0.9]               # Decay parameter for LIF neurons

# # Best subset:
# # 128, 860
# # 32
# # 10
# # 4
# # 0.9

# # Initialize the model
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Track losses
# average_losses = []

# for hidden_dim in hidden_dims:
#     for num_subbands in num_subbands_list:
#         for num_steps in num_steps_list:
#             for beta in betas:

#                 model = IntegratedModel(freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta).to(device)
#                 model.train()

#                 # Optimizer
#                 optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

#                 losses = []

#                 # Training loop
#                 for idx in tqdm(range(len(train_features)), desc="Processing files"):
#                     feature_path = os.path.join(feature_dir, train_features[idx])
#                     label_path = os.path.join(label_dir, train_labels[idx])

#                     noisy_tensor = torch.tensor(np.load(feature_path)).unsqueeze(0).to(device)
#                     clean_tensor = torch.tensor(np.load(label_path)).unsqueeze(0).to(device)

#                     clean_tensor_resampled = resample_tensor(clean_tensor, noisy_tensor.shape[2])

#                     # Forward pass
#                     filtered_output = model(noisy_tensor, num_steps, clean_time_bins=clean_tensor.shape[2])

#                     # Compute composite loss
#                     total_loss = composite_loss(
#                         filtered_output, clean_tensor,
#                         filtered_output_complex=None, clean_tensor_complex=None,  # Replace None if using spectrograms
#                         alpha=0.25, beta=0.1, p=1, target_metric=4.5
#                     )
                    
#                     losses.append(total_loss)

#                     # Backpropagation
#                     optimizer.zero_grad()
#                     total_loss.backward()
#                     optimizer.step()

#                 # Print average loss across all files
#                 average_loss = sum(losses) / len(losses)
#                 average_losses.append(average_loss)
#                 print(f"\nAverage Loss across all files: {average_loss:.12f}")

In [12]:
from tqdm import tqdm  # For tracking progress

# Model parameters
freq_bins, time_bins = 128, 860  # Based on the dimensions of your spectrograms
hidden_dim = 32                  # Hidden dimension for hidden layers
num_steps = 10                   # Number of timesteps for spiking neurons
num_subbands = 4                 # Number of subbands (chosen so hidden_dim divides evenly)
beta = 0.9                       # Decay parameter for LIF neurons

epochs = 3                       # Number of epochs to train for

# Initialize the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = IntegratedModel(freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta).to(device)
model.train()

# SubbandModel may create its LayerNorm/Linear layers on the first forward call.
# Run one throw-away pass so that every parameter exists before the optimizer
# collects them; otherwise those layers would never be updated.
with torch.no_grad():
    model(torch.zeros(1, freq_bins, time_bins, device=device), num_steps, clean_time_bins=time_bins)

# Optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

# Track losses as plain floats (storing the loss tensors would keep every
# step's autograd graph alive for the whole run)
losses = []

for epoch in range(epochs):
    # Training loop: one file per optimisation step
    for idx in tqdm(range(len(train_features)), desc="Processing files"):
        feature_path = os.path.join(feature_dir, train_features[idx])
        label_path = os.path.join(label_dir, train_labels[idx])

        noisy_tensor = torch.tensor(np.load(feature_path)).unsqueeze(0).to(device)
        clean_tensor = torch.tensor(np.load(label_path)).unsqueeze(0).to(device)

        # Forward pass; the model resizes its output to the clean tensor's time bins
        filtered_output = model(noisy_tensor, num_steps, clean_time_bins=clean_tensor.shape[2])

        # Compute composite loss
        total_loss = composite_loss(
            filtered_output, clean_tensor,
            filtered_output_complex=None, clean_tensor_complex=None,  # Replace None if using spectrograms
            alpha=0.25, beta=0.1, p=1, target_metric=4.5
        )

        # Backpropagation
        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        losses.append(total_loss.item())

# Print average loss across all files (and all epochs)
average_loss = sum(losses) / len(losses)
print(f"\nAverage Loss across all files: {average_loss:.6f}")

Processing files: 100%|██████████| 6400/6400 [11:06<00:00,  9.61it/s]
Processing files: 100%|██████████| 6400/6400 [10:02<00:00, 10.61it/s]
Processing files:  14%|█▍        | 887/6400 [01:21<08:29, 10.83it/s]


KeyboardInterrupt: 

In [None]:
# from tqdm import tqdm  # For tracking progress

# # Model parameters
# freq_bins, time_bins = 128, 860     # Based on the dimensions of your spectrograms
# hidden_dims = [32,64,128,256]       # Hidden dimension for spiking neurons
# num_steps_list = [10,15,25,100]     # Number of timesteps for spiking neurons
# num_subbands_list = [4]             # Number of subbands
# beta = 0.9                          # Decay parameter for LIF neurons

# # Best subset:
# # 128, 860
# # 32
# # 10
# # 4
# # 0.9

# # Initialize the model
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # Track losses
# average_losses = []

# for hidden_dim in hidden_dims:
#     for num_subbands in num_subbands_list:
#         for num_steps in num_steps_list:

#             model = IntegratedModel(freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta).to(device)
#             model.train()

#             # Optimizer
#             optimizer = torch.optim.AdamW(model.parameters(), lr=0.001)

#             losses = []

#             # Training loop
#             for idx in tqdm(range(len(train_features)), desc="Processing files"):
#                 feature_path = os.path.join(feature_dir, train_features[idx])
#                 label_path = os.path.join(label_dir, train_labels[idx])

#                 noisy_tensor = torch.tensor(np.load(feature_path)).unsqueeze(0).to(device)
#                 clean_tensor = torch.tensor(np.load(label_path)).unsqueeze(0).to(device)

#                 clean_tensor_resampled = resample_tensor(clean_tensor, noisy_tensor.shape[2])

#                 # Forward pass
#                 filtered_output = model(noisy_tensor, num_steps, clean_time_bins=clean_tensor.shape[2])

#                 # Compute composite loss
#                 total_loss = composite_loss(
#                     filtered_output, clean_tensor,
#                     filtered_output_complex=None, clean_tensor_complex=None,  # Replace None if using spectrograms
#                     alpha=0.25, beta=0.1, p=1, target_metric=4.5
#                 )
                
#                 losses.append(total_loss)

#                 # Backpropagation
#                 optimizer.zero_grad()
#                 total_loss.backward()
#                 optimizer.step()

#             # Print average loss across all files
#             average_loss = sum(losses) / len(losses)
#             average_losses.append(average_loss)
#             print(f"\nAverage Loss across all files: {average_loss:.12f}")

Processing files: 100%|██████████| 6400/6400 [10:03<00:00, 10.61it/s]



Average Loss across all files: 0.000000


Processing files: 100%|██████████| 6400/6400 [14:55<00:00,  7.15it/s]



Average Loss across all files: 0.000000


Processing files: 100%|██████████| 6400/6400 [10:05<00:00, 10.57it/s]



Average Loss across all files: 0.000000


Processing files: 100%|██████████| 6400/6400 [15:17<00:00,  6.98it/s]  



Average Loss across all files: 0.000000


Processing files: 100%|██████████| 6400/6400 [10:33<00:00, 10.11it/s]



Average Loss across all files: 0.000000


Processing files: 100%|██████████| 6400/6400 [16:26<00:00,  6.49it/s]



Average Loss across all files: 0.000000


In [None]:
# for loss in average_losses:
#     print(f"{loss:.12f}")

0.000000002345
0.000000006213
0.000000002729
0.000000005716
0.000000005132
0.000000005178


In [None]:
model.eval()

test_loss = 0.0                  # Running sum of per-file losses (plain float)
test_set_size = len(test_features)
filtered_spectrograms = []       # Model output for every test file, in test order

with torch.no_grad():
    for idx in tqdm(range(test_set_size), desc="Processing files"):
        feature_path = os.path.join(feature_dir, test_features[idx])
        label_path = os.path.join(label_dir, test_labels[idx])

        noisy_tensor = torch.tensor(np.load(feature_path)).unsqueeze(0).to(device)
        clean_tensor = torch.tensor(np.load(label_path)).unsqueeze(0).to(device)

        # Forward pass with the same number of simulation steps used for training
        filtered_output = model(noisy_tensor, num_steps=num_steps, clean_time_bins=clean_tensor.shape[2])
        filtered_spectrograms.append(filtered_output)

        # Compute composite loss (same weights as in training)
        loss = composite_loss(
            filtered_output, clean_tensor,
            filtered_output_complex=None, clean_tensor_complex=None,  # Replace None if using spectrograms
            alpha=0.25, beta=0.1, p=1, target_metric=4.5
        )
        test_loss += loss.item()

# Calculate average test loss
avg_test_loss = test_loss / test_set_size
print(f"Average Test Loss: {avg_test_loss:.8f}")