In [9]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import itertools

# self.gated_spiking = nn.Sequential(
        #     nn.Linear(input_dim, hidden_dim),
        #     nn.Sigmoid()  # Simulates gating
        # )

In [10]:
class FullbandModel(nn.Module):
    """
    Fullband stage: LayerNorm -> Linear -> Leaky (LIF) spiking neuron.

    Each (freq_bins, time_bins) spectrogram is flattened into a single feature
    vector, normalized, projected to ``hidden_dim`` units and applied as a
    constant input current to a Leaky neuron layer for ``num_steps`` steps.
    """

    def __init__(self, freq_bins, time_bins, hidden_dim, beta=0.9):
        """
        Args:
        - freq_bins: Number of frequency bins in the spectrogram.
        - time_bins: Number of time bins in the spectrogram.
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - beta: Decay parameter for the LIF neuron.
        """
        super(FullbandModel, self).__init__()

        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.input_dim = freq_bins * time_bins  # Flattened input feature size
        self.hidden_dim = hidden_dim

        # Layers (registration order unchanged so parameter initialisation and
        # state_dict keys stay the same)
        self.normalization = nn.LayerNorm(self.input_dim)  # Normalize input features
        self.spikingneuron = snn.Leaky(beta=beta)          # Leaky Integrate-and-Fire neuron
        self.linear = nn.Linear(self.input_dim, self.hidden_dim)  # Linear transformation

    def forward(self, x, num_steps=10):
        """
        Run the spiking simulation for ``num_steps`` steps.

        Args:
        - x: Input tensor of shape (batch_size, freq_bins, time_bins).
        - num_steps: Number of time steps for spiking neuron simulation (>= 1).
        Returns:
        - spk_rec: Spiking activity across timesteps (num_steps, batch_size, hidden_dim).
        - mem_rec: Membrane potential across timesteps (num_steps, batch_size, hidden_dim).
        Raises:
        - ValueError: if ``x`` is not 3-D, its trailing dimensions differ from
          the ones given at construction, or ``num_steps`` < 1.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Expected input of shape (batch_size, freq_bins, time_bins), got {tuple(x.shape)}."
            )
        batch_size, freq_bins, time_bins = x.shape
        if freq_bins != self.freq_bins or time_bins != self.time_bins:
            raise ValueError("Input dimensions must match model initialization dimensions.")
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}.")

        # reshape (not view) so non-contiguous inputs, e.g. transposed spectrograms, work too
        x = x.reshape(batch_size, -1)  # Shape: (batch_size, input_dim)
        x = self.normalization(x)

        # The input does not change between steps, so the input current is
        # computed once instead of re-running the linear layer every step.
        cur = self.linear(x)  # Shape: (batch_size, hidden_dim)

        # Membrane starts at zero; zeros_like matches cur's shape, device and dtype.
        mem = torch.zeros_like(cur)

        spk_rec = []
        mem_rec = []
        for _ in range(num_steps):
            spk, mem = self.spikingneuron(cur, mem)
            spk_rec.append(spk)
            mem_rec.append(mem)

        # Shapes: (num_steps, batch_size, hidden_dim)
        return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)

In [11]:
class SubbandModel(nn.Module):
    """
    Subband stage: LayerNorm -> Linear -> Leaky (LIF) spiking neuron.

    The subband width can be given up front via ``input_dim``. Otherwise the
    normalization and linear layers are built from the width of the first input
    passed to ``forward``.
    """

    def __init__(self, hidden_dim, num_steps, beta=0.9, input_dim=None):
        """
        Args:
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - num_steps: Default number of timesteps, used when forward() is called
          without ``num_steps``.
        - beta: Decay parameter for the LIF neuron.
        - input_dim: Optional subband width. When given, the layers are created
          here, so they are registered before any optimizer or state_dict is
          built. When None (default), they are created on the first forward()
          call. An optimizer constructed before that call will NOT see their
          parameters.
        """
        super(SubbandModel, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.spikingneuron = snn.Leaky(beta=beta)  # Leaky Integrate-and-Fire neuron
        self.normalization = None  # Built by _build_layers
        self.linear = None         # Built by _build_layers
        if input_dim is not None:
            self._build_layers(input_dim)

    def _build_layers(self, input_dim, device=None):
        """Create the normalization and linear layers for inputs of width ``input_dim``."""
        normalization = nn.LayerNorm(input_dim)
        linear = nn.Linear(input_dim, self.hidden_dim)
        if device is not None:
            # Created on the default device first, then moved, as the lazy path always did
            normalization = normalization.to(device)
            linear = linear.to(device)
        self.normalization = normalization
        self.linear = linear

    def forward(self, x, num_steps=None):
        """
        Run the spiking simulation for one subband.

        Args:
        - x: Input tensor of shape (batch_size, subband_dim).
        - num_steps: Number of time steps (>= 1); defaults to ``self.num_steps``.
        Returns:
        - spk_rec: Spiking activity across timesteps (num_steps, batch_size, hidden_dim).
        - mem_rec: Membrane potential across timesteps (num_steps, batch_size, hidden_dim).
        Raises:
        - ValueError: if ``x`` is not 2-D, its width differs from the width the
          layers were built for, or ``num_steps`` < 1.
        """
        if num_steps is None:
            num_steps = self.num_steps
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}.")
        if x.dim() != 2:
            raise ValueError(
                f"Expected input of shape (batch_size, subband_dim), got {tuple(x.shape)}."
            )
        subband_dim = x.shape[1]

        if self.linear is None:
            self._build_layers(subband_dim, device=x.device)
        elif self.linear.in_features != subband_dim:
            raise ValueError(
                f"SubbandModel was built for subbands of width {self.linear.in_features}, "
                f"got width {subband_dim}."
            )

        x = self.normalization(x)

        # Same input at every step: compute the input current once.
        cur = self.linear(x)  # Shape: (batch_size, hidden_dim)
        mem = torch.zeros_like(cur)  # Membrane starts at zero, on cur's device/dtype

        spk_rec = []
        mem_rec = []
        for _ in range(num_steps):
            spk, mem = self.spikingneuron(cur, mem)
            spk_rec.append(spk)
            mem_rec.append(mem)

        # Shapes: (num_steps, batch_size, hidden_dim)
        return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)


In [12]:
def frequency_partition(spectrogram, num_subbands, *, verbose=True):
    """
    Split a 2-D tensor into exactly ``num_subbands`` contiguous chunks along dim 1.

    Args:
    - spectrogram: Input tensor of shape (batch_size, hidden_dim).
    - num_subbands: Number of subbands to split the hidden_dim into;
      must satisfy 1 <= num_subbands <= hidden_dim.
    - verbose: If True (default), print the number of subbands produced.
    Returns:
    - subbands: Tuple of ``num_subbands`` tensors (views of the input), each
      of shape (batch_size, subband_size). When hidden_dim is divisible by
      num_subbands every chunk has width hidden_dim // num_subbands. Otherwise
      the first hidden_dim % num_subbands chunks are one element wider, so no
      feature is dropped and no extra remainder chunk is produced.
    Raises:
    - ValueError: if ``spectrogram`` is not 2-D or ``num_subbands`` is out of range.
    """
    _, hidden_dim = spectrogram.shape
    if not 1 <= num_subbands <= hidden_dim:
        raise ValueError(
            f"num_subbands must be between 1 and hidden_dim ({hidden_dim}), got {num_subbands}."
        )

    # tensor_split always returns exactly num_subbands pieces. torch.split with a
    # fixed size would add an extra, smaller chunk when the division is not exact.
    subbands = torch.tensor_split(spectrogram, num_subbands, dim=1)
    if verbose:
        print(f"Number of Subbands (fp func): {len(subbands)}")
    return subbands


def frequency_reconstruct(subbands):
    """
    Join processed subbands back into one tensor.

    Args:
    - subbands: Sequence of tensors that agree in every dimension except
      dim 1, e.g. (batch_size, subband_size) chunks.
    Returns:
    - The subbands concatenated, in order, along dim 1.
    """
    joined = torch.cat(subbands, dim=1)
    return joined

In [13]:
# class IntegratedModel(nn.Module):
#     def __init__(self, freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta=0.9):
#         """
#         Integrated model combining FullbandModel, frequency partitioning, and SubbandModels.
#         Args:
#         - freq_bins: Number of frequency bins in the spectrogram.
#         - time_bins: Number of time bins in the spectrogram.
#         - hidden_dim: Number of hidden units for the spiking neuron layer.
#         - num_steps: Number of timesteps for spiking neuron simulation.
#         - num_subbands: Number of frequency subbands.
#         - beta: Decay parameter for the LIF neuron.
#         """
#         super(IntegratedModel, self).__init__()

#         self.freq_bins = freq_bins
#         self.time_bins = time_bins
#         self.num_subbands = num_subbands

#         subband_size = freq_bins // num_subbands

#         # Fullband model
#         self.fullband_model = FullbandModel(freq_bins, time_bins, hidden_dim, beta)

#         # Subband models
#         self.subband_models = nn.ModuleList([
#             SubbandModel(hidden_dim, num_steps, beta)
#             for _ in range(num_subbands)
#         ])
        
#         print(f"Number of Subband Models: {len(self.subband_models)}")

#     def forward(self, x, num_steps=10):
#         """
#         Forward pass through FullbandModel, frequency partitioning, and SubbandModels.
#         Args:
#         - x: Input tensor of shape (batch_size, frequency_bins, time_bins).
#         - num_steps: Number of time steps for spiking neuron simulation.
#         Returns:
#         - subband_outputs: List of tensors, each corresponding to the output of a SubbandModel.
#         """
#         # Fullband processing
#         fullband_output, _ = self.fullband_model(x, num_steps)

#         # Use the last timestep output for partitioning
#         subbands = frequency_partition(fullband_output[-1], self.num_subbands)
#         print(f"Number of Subbands: {len(subbands)}")

#         # Subband processing
#         subband_outputs = [
#             self.subband_models[i](subband, num_steps=num_steps)[0][-1]
#             for i, subband in enumerate(subbands)
#         ]

#         return subband_outputs


In [14]:
class DeepFilteringLayer(nn.Module):
    """
    Small 1-D convolutional head, inspired by DeepFilterNet.

    Pipeline: Conv1d(k=3) -> ReLU -> Conv1d(k=3) -> ReLU -> Conv1d(k=1).
    It maps ``input_dim`` channels to ``output_dim`` channels. The k=3
    convolutions use padding=1, so the length of the last axis is preserved.
    """

    def __init__(self, input_dim, output_dim, num_filters=64):
        """
        Args:
        - input_dim: Number of input channels (e.g. combined subband dimensions).
        - output_dim: Number of output channels (e.g. original spectrogram dimensions).
        - num_filters: Channel width of the two hidden convolutions.
        """
        super().__init__()
        # Registration order (conv1, relu, conv2, conv3) is kept so parameter
        # initialisation, state_dict keys and the module repr are unchanged.
        self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=num_filters,
                               kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters,
                               kernel_size=3, padding=1)
        self.conv3 = nn.Conv1d(in_channels=num_filters, out_channels=output_dim,
                               kernel_size=1)

    def forward(self, x):
        """
        Args:
        - x: Tensor of shape (batch_size, input_dim, time_steps).
        Returns:
        - Tensor of shape (batch_size, output_dim, time_steps).
        """
        hidden = self.relu(self.conv1(x))
        hidden = self.relu(self.conv2(hidden))
        return self.conv3(hidden)

In [15]:
class IntegratedModel(nn.Module):
    """
    Fullband SNN -> subband split -> per-subband SNN -> per-subband deep filtering.

    The fullband model's spikes at its last time step, of shape
    (batch_size, hidden_dim), are split into ``num_subbands`` chunks by
    ``frequency_partition``. Each chunk passes through its own SubbandModel
    (again keeping only the last-step spikes) and its own DeepFilteringLayer.
    The filtered chunks are then concatenated into a (batch_size, freq_bins)
    tensor.
    """

    def __init__(self, freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta=0.9):
        """
        Args:
        - freq_bins: Number of frequency bins in the spectrogram.
        - time_bins: Number of time bins in the spectrogram.
        - hidden_dim: Number of hidden units for the spiking neuron layers.
        - num_steps: Default number of timesteps for spiking neuron simulation,
          used by forward() when it is called without ``num_steps``.
        - num_subbands: Number of frequency subbands (1 <= num_subbands <= freq_bins).
        - beta: Decay parameter for the LIF neurons.
        Raises:
        - ValueError: if ``num_subbands`` is out of range.
        """
        super(IntegratedModel, self).__init__()

        if not 1 <= num_subbands <= freq_bins:
            raise ValueError(
                f"num_subbands must be between 1 and freq_bins ({freq_bins}), got {num_subbands}."
            )

        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.num_subbands = num_subbands

        # Output width of each filtering layer. The remainder of
        # freq_bins / num_subbands goes to the first subbands, so the
        # concatenated output always has exactly freq_bins features. With
        # freq_bins // num_subbands for every subband, the remainder was dropped.
        base, extra = divmod(freq_bins, num_subbands)
        self.output_sizes = [base + (1 if i < extra else 0) for i in range(num_subbands)]

        # Fullband model
        self.fullband_model = FullbandModel(freq_bins, time_bins, hidden_dim, beta)

        # Subband models
        self.subband_models = nn.ModuleList([
            SubbandModel(hidden_dim, num_steps, beta)
            for _ in range(num_subbands)
        ])

        # Per-subband Deep Filtering Layers
        self.deep_filtering_layers = nn.ModuleList([
            DeepFilteringLayer(input_dim=hidden_dim, output_dim=size)
            for size in self.output_sizes
        ])

    def forward(self, x, num_steps=None):
        """
        Forward pass through the IntegratedModel.

        Args:
        - x: Input tensor of shape (batch_size, freq_bins, time_bins).
        - num_steps: Number of time steps for both spiking stages; defaults to
          the ``num_steps`` given to the constructor.
        Returns:
        - Tensor of shape (batch_size, freq_bins): the filtered subbands,
          concatenated in order.
        Raises:
        - ValueError: if the partition does not yield one chunk per subband model.
        """
        if num_steps is None:
            num_steps = self.num_steps

        # Fullband processing; only the spike record is used
        fullband_spikes, _ = self.fullband_model(x, num_steps)

        # Partition the last-step spikes: shape (batch_size, hidden_dim)
        subbands = frequency_partition(fullband_spikes[-1], self.num_subbands)
        if len(subbands) != self.num_subbands:
            raise ValueError(
                f"frequency_partition returned {len(subbands)} chunks for "
                f"{self.num_subbands} subband models (hidden_dim={self.hidden_dim})."
            )

        filtered_subbands = []
        for subband, subband_model, filtering_layer in zip(
            subbands, self.subband_models, self.deep_filtering_layers
        ):
            # Last-step spikes of this subband: (batch_size, hidden_dim)
            subband_spikes = subband_model(subband, num_steps=num_steps)[0][-1]
            # Conv1d expects (batch, channels, length): treat the spikes as a length-1 sequence
            filtered = filtering_layer(subband_spikes.unsqueeze(-1)).squeeze(-1)
            filtered_subbands.append(filtered)

        # Concatenate filtered subbands along the feature dimension
        return torch.cat(filtered_subbands, dim=1)  # Shape: (batch_size, freq_bins)


In [22]:
import os
import torch
import numpy as np

# Paths to the feature (noisy) and label (clean) directories
feature_dir = "E:/CS541 - Deep Learning/noisy_audio_np"
label_dir = "E:/CS541 - Deep Learning/clean_audio_np"

# Features and labels are paired by position in the sorted listings, so both
# directories must hold correspondingly named files.
feature_files = sorted(os.listdir(feature_dir))
label_files = sorted(os.listdir(label_dir))

# Select an example
idx = 0  # Change this index to test different samples
feature_path = os.path.join(feature_dir, feature_files[idx])
label_path = os.path.join(label_dir, label_files[idx])

# Load the .npy files as float32 to match the model's parameter dtype
# (float64 arrays would otherwise fail inside LayerNorm/Linear).
noisy_tensor = torch.as_tensor(np.load(feature_path), dtype=torch.float32).unsqueeze(0)  # Shape: (1, freq_bins, time_bins)
clean_tensor = torch.as_tensor(np.load(label_path), dtype=torch.float32).unsqueeze(0)  # Shape: (1, freq_bins, time_bins)

# Print shapes to verify data loading
print(f"Loaded noisy tensor shape: {noisy_tensor.shape}")
print(f"Loaded clean tensor shape: {clean_tensor.shape}")

# Move tensors to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
noisy_tensor = noisy_tensor.to(device)
clean_tensor = clean_tensor.to(device)

# Model parameters
freq_bins, time_bins = noisy_tensor.shape[1], noisy_tensor.shape[2]
hidden_dim = 64        # Hidden dimension for spiking neurons
num_steps = 10         # Number of timesteps for spiking neurons
num_subbands = 4       # Number of subbands (adjusted for even division)
beta = 0.9             # Decay parameter for LIF neurons

# Initialize the model
model = IntegratedModel(freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta).to(device)

# Forward pass through the model
with torch.no_grad():  # Disable gradients for testing
    filtered_output = model(noisy_tensor, num_steps=num_steps)

# Print the output shape
print(f"Filtered output shape: {filtered_output.shape}")

# Optional: Compare the filtered output with the clean tensor.
# NOTE(review): as written this cannot work. The model returns
# (batch, freq_bins), while clean_tensor is (batch, freq_bins, time_bins).
# The recorded run of this cell also shows different time_bins for the noisy
# (860) and clean (1938) arrays, so the pair needs alignment first.
# loss = torch.nn.MSELoss()(filtered_output, clean_tensor)
# print(f"Loss (MSE) between filtered output and clean tensor: {loss.item()}")

Loaded noisy tensor shape: torch.Size([1, 128, 860])
Loaded clean tensor shape: torch.Size([1, 128, 1938])
Number of Subbands (fp func): 4
Filtered output shape: torch.Size([1, 128])
