In [41]:
import snntorch as snn
from snntorch import spikeplot as splt
from snntorch import spikegen

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import itertools

# self.gated_spiking = nn.Sequential(
        #     nn.Linear(input_dim, hidden_dim),
        #     nn.Sigmoid()  # Simulates gating
        # )

In [42]:
class FullbandModel(nn.Module):
    def __init__(self, freq_bins, time_bins, hidden_dim, beta=0.9):
        """
        Fullband model: LayerNorm over the flattened spectrogram, a linear
        projection to ``hidden_dim`` units, and a Leaky Integrate-and-Fire (LIF)
        spiking layer driven by that projection for ``num_steps`` timesteps.

        Args:
        - freq_bins: Number of frequency bins in the spectrogram.
        - time_bins: Number of time bins in the spectrogram.
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - beta: Decay parameter for the LIF neuron.
        """
        super(FullbandModel, self).__init__()

        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.input_dim = freq_bins * time_bins  # Flattened input feature size
        self.hidden_dim = hidden_dim

        # Layers, applied in forward() as: normalization -> linear -> spikingneuron
        self.normalization = nn.LayerNorm(self.input_dim)  # Normalize flattened input features
        self.spikingneuron = snn.Leaky(beta=beta)          # Leaky Integrate-and-Fire neuron
        self.linear = nn.Linear(self.input_dim, self.hidden_dim)  # Projection to hidden_dim

    def forward(self, x, num_steps=10):
        """
        Forward pass for the FullbandModel with time-stepped spiking neuron dynamics.

        Args:
        - x: Input tensor of shape (batch_size, freq_bins, time_bins).
        - num_steps: Number of time steps for spiking neuron simulation (>= 1).

        Returns:
        - spk_rec: Spiking activity across timesteps (num_steps, batch_size, hidden_dim).
        - mem_rec: Membrane potential across timesteps (num_steps, batch_size, hidden_dim).

        Raises:
        - ValueError: if x is not 3-D, its trailing dimensions differ from the
          (freq_bins, time_bins) given at construction, or num_steps < 1.
        """
        if x.dim() != 3:
            raise ValueError(
                f"Expected input of shape (batch_size, freq_bins, time_bins), got {tuple(x.shape)}."
            )
        batch_size, freq_bins, time_bins = x.shape
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if freq_bins != self.freq_bins or time_bins != self.time_bins:
            raise ValueError("Input dimensions must match model initialization dimensions.")
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}.")

        # Flatten each spectrogram into a 1-D vector. `reshape` (unlike `view`)
        # also accepts non-contiguous inputs such as transposed tensors.
        x = self.normalization(x.reshape(batch_size, -1))  # (batch_size, input_dim)

        # The input is identical at every timestep, so the synaptic current is
        # too: compute it once instead of re-running the linear layer per step.
        cur = self.linear(x)  # (batch_size, hidden_dim)

        # Membrane potential starts at zero, on the same device/dtype as the current.
        mem = torch.zeros_like(cur)

        spk_rec = []
        mem_rec = []
        for _ in range(num_steps):
            spk, mem = self.spikingneuron(cur, mem)
            spk_rec.append(spk)
            mem_rec.append(mem)

        # Stack outputs across timesteps: (num_steps, batch_size, hidden_dim)
        return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)

In [43]:
class SubbandModel(nn.Module):
    def __init__(self, hidden_dim, num_steps, beta=0.9, input_dim=None):
        """
        Subband model: LayerNorm over a subband vector, a linear projection to
        ``hidden_dim`` units, and a Leaky Integrate-and-Fire (LIF) spiking layer
        driven by that projection for ``num_steps`` timesteps.

        Args:
        - hidden_dim: Number of hidden units for the spiking neuron layer.
        - num_steps: Default number of timesteps for spiking neuron simulation.
        - beta: Decay parameter for the LIF neuron.
        - input_dim: Optional subband width. When given, the normalization and
          linear layers are created here, so their parameters exist before an
          optimizer is built, before `.to(device)`, and in `state_dict()`.
          When None (the original behaviour), they are created on the first
          forward call from the input's width.
        """
        super(SubbandModel, self).__init__()

        self.hidden_dim = hidden_dim
        self.num_steps = num_steps
        self.spikingneuron = snn.Leaky(beta=beta)  # Leaky Integrate-and-Fire neuron

        # Declared up front so the attributes always exist; filled eagerly when
        # input_dim is known, otherwise lazily in forward().
        self.normalization = None
        self.linear = None
        if input_dim is not None:
            self.normalization = nn.LayerNorm(input_dim)
            self.linear = nn.Linear(input_dim, hidden_dim)

    def forward(self, x, num_steps=None):
        """
        Forward pass for the SubbandModel with time-stepped spiking neuron dynamics.

        Args:
        - x: Input tensor of shape (batch_size, subband_dim).
        - num_steps: Number of time steps for spiking neuron simulation
          (defaults to the value given at construction; must be >= 1).

        Returns:
        - spk_rec: Spiking activity across timesteps (num_steps, batch_size, hidden_dim).
        - mem_rec: Membrane potential across timesteps (num_steps, batch_size, hidden_dim).

        Raises:
        - ValueError: if num_steps < 1 or subband_dim differs from the width
          the layers were built for.
        """
        if num_steps is None:
            num_steps = self.num_steps
        if num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {num_steps}.")

        batch_size, subband_dim = x.shape

        # Lazy construction (only when input_dim was not given). NOTE: parameters
        # created here are not seen by an optimizer constructed before this call.
        if getattr(self, "normalization", None) is None:
            self.normalization = nn.LayerNorm(subband_dim).to(x.device)
        if self.linear is None:
            self.linear = nn.Linear(subband_dim, self.hidden_dim).to(x.device)
        if self.linear.in_features != subband_dim:
            raise ValueError(
                f"SubbandModel was built for subband width {self.linear.in_features}, "
                f"got input of width {subband_dim}."
            )

        x = self.normalization(x)

        # The input is identical at every timestep, so compute the current once.
        cur = self.linear(x)  # (batch_size, hidden_dim)

        # Membrane potential starts at zero, on the same device/dtype as the current.
        mem = torch.zeros_like(cur)

        spk_rec = []
        mem_rec = []
        for _ in range(num_steps):
            spk, mem = self.spikingneuron(cur, mem)
            spk_rec.append(spk)
            mem_rec.append(mem)

        # Stack across timesteps: (num_steps, batch_size, hidden_dim)
        return torch.stack(spk_rec, dim=0), torch.stack(mem_rec, dim=0)


In [44]:
def frequency_partition(spectrogram, num_subbands):
    """
    Split a 2-D tensor into exactly ``num_subbands`` contiguous chunks along dim 1.

    When dim 1 is not divisible by ``num_subbands``, the first
    ``dim1 % num_subbands`` chunks are one element wider (torch.tensor_split
    semantics). The previous torch.split-based version instead produced
    MORE than ``num_subbands`` chunks in that case.

    Args:
    - spectrogram: Input tensor of shape (batch_size, hidden_dim).
    - num_subbands: Number of subbands to split the hidden_dim into
      (1 <= num_subbands <= hidden_dim).

    Returns:
    - subbands: Tuple of ``num_subbands`` tensors, each of shape
      (batch_size, subband_size); concatenating them along dim 1 recovers the input.

    Raises:
    - ValueError: if spectrogram is not 2-D or num_subbands is out of range.
    """
    if spectrogram.dim() != 2:
        raise ValueError(
            f"Expected a tensor of shape (batch_size, hidden_dim), got {tuple(spectrogram.shape)}."
        )
    hidden_dim = spectrogram.shape[1]
    if not 1 <= num_subbands <= hidden_dim:
        raise ValueError(
            f"num_subbands must be between 1 and {hidden_dim}, got {num_subbands}."
        )

    # Always yields exactly num_subbands views, so callers can pair them 1:1 with models.
    subbands = torch.tensor_split(spectrogram, num_subbands, dim=1)
    print(f"Number of Subbands (fp func): {len(subbands)}")
    return subbands


def frequency_reconstruct(subbands):
    """
    Join processed subbands back into one tensor by concatenating along dim 1.

    This is the inverse of splitting along dim 1 (as frequency_partition does):
    every piece must agree on all dimensions except dim 1.

    Args:
    - subbands: Sequence of tensors sharing every dimension except dim 1,
      e.g. (batch_size, subband_size) pieces from frequency_partition
      (3-D pieces such as (batch_size, subband_size, time_bins) also work).

    Returns:
    - Tensor whose dim 1 is the sum of the pieces' dim-1 sizes.
    """
    return torch.cat(subbands, dim=1)

In [45]:
class IntegratedModel(nn.Module):
    def __init__(self, freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta=0.9):
        """
        Integrated model: a FullbandModel, whose last-timestep spikes are split
        along the hidden dimension into ``num_subbands`` chunks, each chunk fed
        to its own SubbandModel.

        Args:
        - freq_bins: Number of frequency bins in the spectrogram.
        - time_bins: Number of time bins in the spectrogram.
        - hidden_dim: Number of hidden units for the spiking neuron layers.
        - num_steps: Default number of timesteps for spiking neuron simulation.
        - num_subbands: Number of subbands (chunks of the fullband hidden vector).
        - beta: Decay parameter for the LIF neurons.
        """
        super(IntegratedModel, self).__init__()

        self.freq_bins = freq_bins
        self.time_bins = time_bins
        self.num_steps = num_steps
        self.num_subbands = num_subbands

        # Fullband model
        self.fullband_model = FullbandModel(freq_bins, time_bins, hidden_dim, beta)

        # One SubbandModel per chunk. NOTE: SubbandModel creates its
        # normalization/linear layers on its first forward call, so run one
        # forward pass before building an optimizer over self.parameters().
        self.subband_models = nn.ModuleList([
            SubbandModel(hidden_dim, num_steps, beta)
            for _ in range(num_subbands)
        ])

        print(f"Number of Subband Models: {len(self.subband_models)}")

    def forward(self, x, num_steps=None):
        """
        Forward pass through FullbandModel, frequency partitioning, and SubbandModels.

        Args:
        - x: Input tensor of shape (batch_size, frequency_bins, time_bins).
        - num_steps: Number of time steps for spiking neuron simulation;
          defaults to the num_steps given at construction.

        Returns:
        - subband_outputs: List with one tensor per SubbandModel: that model's
          spikes at its last timestep, shape (batch_size, hidden_dim).

        Raises:
        - ValueError: if partitioning yields a different number of chunks than
          there are subband models.
        """
        if num_steps is None:
            num_steps = self.num_steps

        # Fullband processing; only the spike record is used downstream.
        fullband_output, _ = self.fullband_model(x, num_steps)

        # Use the last timestep output for partitioning
        subbands = frequency_partition(fullband_output[-1], self.num_subbands)
        print(f"Number of Subbands: {len(subbands)}")
        if len(subbands) != len(self.subband_models):
            raise ValueError(
                f"Partitioning produced {len(subbands)} chunks but there are "
                f"{len(self.subband_models)} subband models; check that hidden_dim "
                f"is compatible with num_subbands."
            )

        # Each SubbandModel returns (spk_rec, mem_rec); keep the last-step spikes.
        return [
            subband_model(subband, num_steps=num_steps)[0][-1]
            for subband_model, subband in zip(self.subband_models, subbands)
        ]


In [47]:
import os
import torch
import numpy as np

# Define paths (edit DATA_ROOT to point at your local copy of the dataset)
DATA_ROOT = "E:/CS541 - Deep Learning"
feature_dir = os.path.join(DATA_ROOT, "noisy_audio_np")
label_dir = os.path.join(DATA_ROOT, "clean_audio_np")

# Load one sample noisy/clean pair (each presumably (freq_bins, time_bins) — see printed shape)
test_feature_file = os.path.join(feature_dir, "noisy_spectrogram1.npy")
test_label_file = os.path.join(label_dir, "clean_spectrogram1.npy")
noisy_spectrogram = np.load(test_feature_file)
clean_spectrogram = np.load(test_label_file)

# Convert to PyTorch tensors and add a leading batch dimension
noisy_tensor = torch.tensor(noisy_spectrogram, dtype=torch.float32).unsqueeze(0)
clean_tensor = torch.tensor(clean_spectrogram, dtype=torch.float32).unsqueeze(0)
print("Loaded tensor shape:", noisy_tensor.shape)

# Parameters for the model
freq_bins = noisy_tensor.shape[1]  # Frequency bins
time_bins = noisy_tensor.shape[2]  # Time bins
hidden_dim = 64   # Hidden layer size for spiking neurons
num_steps = 10    # Number of timesteps for spiking neuron simulation
num_subbands = 4  # Number of frequency partitions
beta = 0.9        # Decay parameter for LIF neurons

# Use the GPU when present, otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize the IntegratedModel
model = IntegratedModel(freq_bins, time_bins, hidden_dim, num_steps, num_subbands, beta).to(device)

# Move tensors to the same device as the model
noisy_tensor = noisy_tensor.to(device)
print("Input tensor shape:", noisy_tensor.shape)  # Should match (batch_size, freq_bins, time_bins)
clean_tensor = clean_tensor.to(device)
print("Model freq_bins:", model.freq_bins)
print("Model time_bins:", model.time_bins)

# Forward pass
subband_outputs = model(noisy_tensor, num_steps=num_steps)

# Print the shape of subband outputs for validation
for i, subband_output in enumerate(subband_outputs, start=1):
    print(f"Subband {i} output shape: {subband_output.shape}")

Loaded tensor shape: torch.Size([1, 128, 860])
Number of Subband Models: 4
Input tensor shape: torch.Size([1, 128, 860])
Model freq_bins: 128
Model time_bins: 860
Number of Subbands (fp func): 4
Number of Subbands: 4
Subband 1 output shape: torch.Size([1, 64])
Subband 2 output shape: torch.Size([1, 64])
Subband 3 output shape: torch.Size([1, 64])
Subband 4 output shape: torch.Size([1, 64])


In [None]:
# # Model initialization
# freq_bins = 128   # Number of Mel frequency bins
# time_bins = 431   # Number of time bins in each spectrogram
# input_dim = freq_bins * time_bins  # Flattened input dimension
# hidden_dim = 512  # Hidden layer size
# num_steps = 25    # Number of timesteps for spiking dynamics