## Neural Analog Modelling 
Please note that the dataset used for this project is not 100% clean, and the errors from the authors errata page are addressed and resolved in the other notebook `0-data-validation.ipynb`. This is done to ensure that in the below our models are not impacted.
### -- Introduction
This repo is concerned with modelling analog effects with neural networks. Why? Mostly for curiosity.

Most VST effects are typically implemented in C++, using handy frameworks like JUCE which contains libraries to handle many of the typical challenges of plugin design such as cross-platform functionality, DSP, frontend design.. In order to do real analog modelling, one typically needs to have some high level domain knowledge and have the capacity to understand complicated analog circuits present in analog synths and effects, as well as a knowledge of DSP which allows us to model these signals numerically. Other approaches include simulation of physical processes like reverberation.

In principle, neural networks are a more high level approach which, presupposing the access to a relevant dataset of dry/wet signals, allow us to model the effects without having to bust out Korg manuals from the 1980s and examine the intricacies of the circuits for their wonderful filters. 

There are clearly many limits in using neural networks to process audio. Some include:
- Clean datasets of dry/wet signals are not typically easy to come by, and are labour intensive to repair. Hence NNs are not always a suitable approach for modelling analog effects.
- Neural networks are typically quite bulky, and implemented in Python (which is typically slower compared to DSP implementations in C++). Although not impossible (see work of [C. Steinmetz](https://scholar.google.com/citations?user=jSvSfIMAAAAJ&hl=en)), this makes it difficult to use neural networks to process signals in real time. This inherent slowness makes a lot of neural approaches only suitable to asynchronous audio processing. Sequence to sequence audio modelling is notoriously slow, especially given the large size of common audio models used today e.g [Demucs](https://github.com/facebookresearch/demucs) (for source separation). 

I am interesting in exploring these limits, especially the second one. 

### -- Data 
In this case, we use the [SignalTrain](https://zenodo.org/records/3824876) dataset from 2019 which contains various dry and wet recordings, where the wet recordings are processed with an analog compressor, the Universal Audio LA-2A. 

This compressor is a very simple one, and we will be concerned with modelling the signal using only two parameters on the compressor:
- The switch between compression and limiting
- The peak reduction

The information about the parameters of the compressor are contained in the file names. In particular, the value between 0-100 represents the peak reduction knob, whereas the binary value 0/1 represents the switch between compression (0) and limiting (1). The authors say that there was no changes to the input or output gains, and that only these two parameters above were changed during recording.

### -- Audio
The audio in this dataset has a sampling rate of 44.1kHz and is mono, as the original analog LA-2A was designed to process mono signals. The individual audio files actually contain a "collage" of different pieces of music which are stitched seamlessly together. The whole recording will be passed through the compressor to obtain the wet signal.

I followed the advice of the authors in cleaning up errors in the dataset (removing / moving certain files) in other notebook `data_validation.ipynb`. I cross correlated all the signals to see if any signals except those mentioned by the authors had a phase shift between the dry and wet signals. To simplify my life, I found all the signals which had a relative phase shift and removed them, keeping only perfectly correlated dry/ wet pairs; which to the credit of the authors, was the vast majority of the signals.

## Compressors

In the context of music production we are interested in Dynamic Range Compressors. These are effects which reduce the dynamic range (the gap between the loudest and quietest parts of a signal) via downward or upward compression (reducing the gain of loud sounds, and increasing the gain of quiet sounds respectively). They have several parameters, the most important ones being:

- Threshold: the volume at which the compressor will be activated. In the case of downward compression, if the threshold is -6dB, and the signal peaks at -10dB, the compressor will never be activated, if it is, the compressor will activate and trigger gain reduction.  On the other hand, in upward compression, if the threshold is -6dB, any sound below this threshold (e.g -10dB) will activate the compression and cause a gain increase.

- Ratio: a parameter which controls the amount of compression applied to an incoming signal. It is called a ratio because it is typically defined in such a way that if the ratio is 4:1, any signal 4dB **over** the threshold will be reduced to 1dB over the threshold.

Some other common parameters include attack and release, which delay or extend resp. the activation of the compressor. e.g if the attack is 50ms, when the compressor detects a signal which crosses the threshold, it will delay its activation by 50ms. The release, if 50ms, will extend the action of the compressor by 50ms. The effect of attack / release is often not considered in a step function like manner, but will often be smoother, where for example during the attack phase of 50ms after the detection of a signal crossing the compressor's threshold, the action of the compressor may be linearly ramped up to its full action. 

Another important control is the "knee" setting. In the case of a "hard knee" compressor, the compression will only activate once the signal crosses the threshold, triggering gain reduction. In the case of a soft knee compressor, this transition point is 'blurred' and less abrupt, even for signals below the threshold, the gain may be attenuated slightly by the compressor. This results in a 'smoother' transition between compressor and uncompressed parts of the signal. 

Here is an example of a minimal compressor implemented in numpy

In [40]:
import numpy as np

def compressor(signal, threshold, ratio):
    compressed_signal = np.zeros_like(signal)
    for i in range(len(signal)):
        if np.abs(signal[i]) > threshold:
            compressed_signal[i] = np.sign(signal[i]) * (threshold + (np.abs(signal[i]) - threshold) / ratio)
        else:
            compressed_signal[i] = signal[i]
    return compressed_signal

Something important to remember is the units in which you are working. In a 16 bit system, a sample can take any value between -32768 and +32767 (2^16 values hence 16 "bit"). If a signal exceeds this maximum value, it may result in clipping (the value of the signal will be cut off at the maximum value). From these raw numbers, we can define the dBFS units (decibels relative to full scale), where here our "full scale" value will be the maximum value of our amplitude in our 16 bit system, +32767. The conversion can be defined as $ X \text(dB) = 20 \text{log}_{10} \left( \frac{\text{Amplitude}}{32767} \right)$. If we reach our max value, notice that the corresponding value in dBFS will be 0dB, as expected. 

dB is commonly used because humans have a logarithmic perception of volume - i.e doubling the intensity of a signal will not be perceived by humans as "twice as loud" indeed the human perception of loudness more closely mimics that of the logarithm of intensity, rather than a linear relationship - interestingly, loudness is a "psychological" quantity, and humans even have different perceptions of loudness depending on the frequency of the incoming signal - that means that a signal at 2kHz and 15kHz with the same intensity will not be perceived as the same loudness by a human. For further details you can read about Fletcher-Munson curves which discuss this phenomenon.

## Metadata Preparation 
It is common when working with data such as audio to keep a metadata dataframe which contains important information about the audio that is not contained in the raw audio signal, so that we can feed this information to our model. Whether compression or limiting is being applied is an example of such metadata. The authors did not include such a metadata frame i.e a df with each row being a track, with a unique id, the paths to the raw and processed audio, and the compression settings applied, so we have to create it ourselves from the filenames provided.

Our data consists of some raw audio files of the form e.g `./data/<split>/input_XXX_.wav` and some corresponding processed audio files `./data/<split>.target_XXX_LA2A_YY__Z__WW.wav`. The compression parameters are contained in the file name of the processed file. The parameters are the following:
- XXX: audio file id
- YY: seems to be the compressor revision, not that important
- Z: compressor/limiter switch either 0 or 1. 
- WW: peak reduction switch, from 0-100. 

In the other notebook I already removed pairs of audio which were out of phase (an error in the creation of the dataset), so we can proceed with extracting the data from the track files as described above. 

 The input data will be:
- Raw audio segment.
- Compressor / Limiter Switch (0 or 1 resp.).
- Compressor peak reduction (0 - 100).
Where the target will be the corresponding processed audio segment.

In [41]:
# Get list of unique track ids by parsing the info in the file names
import os
import pandas as pd
from collections import defaultdict


def get_track_ids(track_paths):
    """
    Gets a list of track_ids which are unique from the given track_paths
    """
    track_ids = [track_path.split("_")[1] for track_path in track_paths]
    numeric_ids_only = [track for track in track_ids if track.isdigit()]
    return list(set(numeric_ids_only))


def prepare_metadata_records(splits):
    """
    Gets a set of parsed metadata records for every track in each split.
    - Creates a dict with split names as keys e.g {'train': .., 'test': .., ..}
    - The value corresponding to each split name is a list of nested records of the form [{'track_id':{'X_path':xxx, 'param1':yyy, 'param2':zzz, 'Y_path':www}, ..]
    """
    metadata = defaultdict(list)
    for split in splits:
        split_name = split.split("/")[-2].lower()  # unfortunate naming of variable here
        track_paths = os.listdir(split)
        track_ids = get_track_ids(track_paths)

        split_metadata = []
        for track_id in track_ids:
            track_level_data = defaultdict(dict)
            for track_path in track_paths:
                if track_id in track_path and "target" in track_path:
                    split_path = track_path.split("_")
                    compress_or_limit = split_path[-3]
                    peak_reduction = split_path[-1].split(".wav")[0]

                    track_level_data[track_id]["raw_audio_path"] = (
                        split + "input_" + track_id + "_.wav"
                    )
                    track_level_data[track_id]["compress_or_limit"] = compress_or_limit
                    track_level_data[track_id]["peak_reduction"] = peak_reduction
                    track_level_data[track_id]["processed_audio_path"] = (
                        split + track_path
                    )
            split_metadata.append(track_level_data)
        metadata[split_name] = split_metadata
    return metadata


def turn_records_into_df(records):
    """
    Turns the metadata records into a nice and simple dataframe
    """
    flattened_data = []
    for split, records in records.items():
        for record in records:
            for record_id, details in record.items():
                flattened_data.append(
                    {"split": split, "track_id": record_id, **details}
                )

    return pd.DataFrame(flattened_data)


splits = ["./data/Train/", "./data/Test/", "./data/Val/"]
records = prepare_metadata_records(splits)
df = turn_records_into_df(records)

Now we have our basic metadata df that will allow us to prepare each small audio segment for training, let's just do a quick sanity check to make sure we have no parsing errors:

In [42]:
df.isna().sum()

split                   0
track_id                0
raw_audio_path          0
compress_or_limit       0
peak_reduction          0
processed_audio_path    0
dtype: int64

In [43]:
df.track_id.duplicated().sum()

0

Looks good, no NaNs or duplicates. Let's just fix the types of some columns as our model will need them to have the correct types later:

In [44]:
df["track_id"] = df["track_id"].astype(int)
df["compress_or_limit"] = df["compress_or_limit"].astype(int)
df["peak_reduction"] = df["peak_reduction"].astype(int)

## Creating Audio Segments for Training + Evaluation + Challenges
Each file in the dataset is not a unique track, but rather a very long (can be up to 20 minutes) collage of different tracks, stitched together without interruption. All tracks in the file are compressed with the same compression parameters indicated in the "target" file name. We also have access to the raw, uncompressed audio file corresponding to this.

Our model will not have an input size of 20 minutes and will not handle variable input sizes effectively, so we need to break each file into chunks of 3-10 seconds and process them one by one, these are common input sizes in the literature for audio ML (e.g ShortChunk has 15 seconds, whereas MusiCNN has 3 seconds). 

This means that from 1 file of 20 minutes, we will actually many training examples for a model with 1 second input length. 

It also means that during inference (i.e during modelling) we will not be able to process a whole file at once but will need to split the incoming audio into 3 second chunks and then process each chunk separately. We will call the incoming audio chunk the "buffer". This fact is actually a serious design challenge for modelling a time based effect such as compression, because a naive model architecture means it will only use incoming audio in the buffer to produce an output - but time based effects like compression have parameters such as attack/release (discussed above) which means audio in the buffer should trigger compression in the next incoming buffer. If this is not clear, imagine that our 3 second audio signal in the buffer has a loud peak at 2.99s. If our compressor has an attack time of 0,015s (15ms), the compression will only "kick in" or be triggered _after_ the current signal is out of the buffer (2.99 + 0.15 > 3.0), affecting only the start of the next 3 second signal coming into the buffer. This means there is a potential dependency between windows that are being processed independently by our model. The same problem exists in the realm of DSP audio plugin design, where these "transient" errors are often treated using a lookahead buffer. This challenge may end up motivating the design of our model later if naive models prove to suffer from this possible complication.

First, we will load a full 15-20 minute audio file and split it up into windows of size 1 seconds. We will have a window stride of 0.5 second. This means that we will keep a bit of the "last" window in the current window, which is a common technique in audio processing / time series analysis to have smoother transitions between adjacent windows. Let's make a helper function that will take an object of size N, a window of size K (K <= N) and a stride length S and compute how many overlapping intervals we can compute from our audio. For an example of what this means look below for an explicit example.

In [45]:
def overlapping_interval_count(object_size, window_size, window_stride):
    """all lengths in samples here"""
    stride = 0
    count = 0
    while True:
        pos = stride + window_size 
        if pos >= object_size:
            return count + 1 # gives number of overlapping intervals that fit in the 
        count += 1
        stride += window_stride

Now let's create our training dataset

In [46]:
import torchaudio
import torch 
from torch.utils.data import Dataset


class AudioDataset(Dataset):
    def __init__(
        self,
        input_audio_paths,
        target_audio_paths,
        window_size=0.1,
        overlap=0.05,
        params=None,
    ):
        self.input_audio_paths = input_audio_paths
        self.target_audio_paths = target_audio_paths
        self.window_size = window_size  # in seconds
        self.window_size_samples = self.window_size * 44100 # in samples
        self.overlap = overlap  # in seconds
        self.overlap_samples = self.overlap * 44100 # in samples
        self.params = params  # compressor parameters i.e compress-limit switch, peak reduction
        self.window_to_audio_mapping = self.window_index_to_audio_indices() # map an
        self.examples = self.create_examples()

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, window_index):
        """ 
        Converts a window index to a tuple of audio indices, i.e the first window (0)
        corresponds to the first audio (0) and the first window in this audio (0). 

        The set of {window_index: (audio_index, window_index_in_audio)} is precomputed 
        in the `self.window_to_audio_mapping method`. 
        
        We use this tuple of indices to 
        load the actual audio data (the stream of samples), lazily, for training. 
        """
        audio_index, window_index_in_audio = self.window_to_audio_mapping[window_index]
        input_audio_path, target_audio_path, params = self.examples[audio_index]

        input_waveform, _ = torchaudio.load(
            input_audio_path,
            num_frames=self.window_size_samples,
            frame_offset=window_index_in_audio * self.overlap_samples,
        ) 

        target_waveform, _ = torchaudio.load(
            target_audio_path,
            num_frames=self.window_size_samples,
            frame_offset=window_index_in_audio * self.overlap_samples,
        )
        
        input_waveform = input_waveform[0] # 1D tensor e.g [0.1, 0.03, ...] of length window size * sample rate
        target_waveform = target_waveform[0]

        input_waveform = (input_waveform + 1.0) / 2.0 # Rescale (invertibly) values that go from -1 to 1, to be between 0 and 1
        target_waveform = (target_waveform + 1.0) / 2.0 

        compress_limit = torch.tensor(params[0], dtype=torch.float32)
        peak_reduction = torch.tensor(params[1]/100.0, dtype=torch.float32) # Ensure the peak reduction is between 0 and 1, not 0 - 100

        return (input_waveform, compress_limit, peak_reduction), target_waveform

    def window_index_to_audio_indices(self):
        """ 
        Our audio dataset constructor takes in a list of input/target audio paths. In the SignalTrain dataset, each 
        individual audio is very long, up to 20 mins. Our model will have a much smaller input size. Therefore, we 
        will use a given audio file to create several training examples by taking overlapping windows of for example 3 seconds 
        from each file. e.g 0-3s, 1-4s, 2-5s all the way up until we use all 20 minutes of the audio file.

        Since __getitem__ method is lazy, it will try to load each window of audio data on the fly instead of preparing
        them in advance. This means that when we say __getitem__(293), we need to know which audio file corresponds to 
        window 293, and where precisely this window is situated in that audio file, so that we can load the right data.
        
        For example, window with index 0 can be identified with the first audio file, and the first window in that audio i.e (0,0).
        
        This function is used to pre-compute the mapping from a window index into a tuple of indices identifying both the audio file,
        and the position of this window in that audio file so that we can load the audio data on the fly in __getitem__. 
        """
        cumulative_window_index = 0
        mapping = {}
        for audio_index, path in enumerate(self.input_audio_paths):
            num_windows = self.calculate_num_windows(path)
            for i in range(num_windows):
                mapping[cumulative_window_index + i] = (
                    audio_index,
                    i,
                )  # e.g (3rd audio, 4th window of 3rd audio)
            cumulative_window_index += num_windows
        self.dataset_size = len(mapping)

        return mapping

    def calculate_num_windows(self, audio_path):
        waveform, sample_rate = torchaudio.load(audio_path)
        window_samples = self.window_size * sample_rate
        overlap_samples = self.overlap * sample_rate
        num_windows = overlapping_interval_count(
            waveform.size(1), window_samples, overlap_samples
        )
        return num_windows

    def create_examples(self):
        return list(zip(self.input_audio_paths, self.target_audio_paths, self.params))


# Example usage with multiple paths and parameters
train_df = df[df.split == "train"]
test_df = df[df.split == "test"]
val_df = df[df.split == "val"]

train_dataset = AudioDataset(
    train_df["raw_audio_path"].to_list(),
    train_df["processed_audio_path"].to_list(),
    window_size=0.01,
    overlap=0.005,
    params=list(zip(train_df["compress_or_limit"], train_df["peak_reduction"])),
)
test_dataset = AudioDataset(
    test_df["raw_audio_path"].to_list(),
    test_df["processed_audio_path"].to_list(),
    window_size=0.01,
    overlap=0.005,
    params=list(zip(test_df["compress_or_limit"], test_df["peak_reduction"])),
)
val_dataset = AudioDataset(
    val_df["raw_audio_path"].to_list(),
    val_df["processed_audio_path"].to_list(),
    window_size=0.01,
    overlap=0.005,
    params=list(zip(val_df["compress_or_limit"], val_df["peak_reduction"])),
)

Note: 
- Torchaudio loads audio as a tuple containing a 2D tensor and the sample rate. The 2D tensor is structured as (num_channels, num_frames). In the case of mono audio (1 channel), it still follows this format, but with the number of channels being 1.

Below is an example of what our input / output data will look like:

In [47]:
(input_waveform, cl, pr), target_waveform = train_dataset.__getitem__(400)

print(f"input waveform:\n{input_waveform}, shape: {input_waveform.shape}\n")
print(f"compress or limit:\n{cl}, shape: {cl.shape}\n")
print(f"peak reduction:\n{pr}, shape: {pr.shape}\n")
print(f"target waveform:\n{target_waveform}, shape: {target_waveform.shape}\n")

input waveform:
tensor([0.4200, 0.4219, 0.4201, 0.4143, 0.4059, 0.3955, 0.3854, 0.3816, 0.3848,
        0.3880, 0.3871, 0.3857, 0.3873, 0.3886, 0.3815, 0.3664, 0.3570, 0.3626,
        0.3739, 0.3795, 0.3817, 0.3871, 0.3931, 0.3941, 0.3915, 0.3900, 0.3900,
        0.3906, 0.3937, 0.4004, 0.4081, 0.4152, 0.4216, 0.4260, 0.4273, 0.4268,
        0.4287, 0.4355, 0.4402, 0.4364, 0.4345, 0.4453, 0.4592, 0.4667, 0.4786,
        0.5056, 0.5427, 0.5806, 0.6131, 0.6332, 0.6272, 0.5823, 0.5208, 0.5022,
        0.5538, 0.6329, 0.6812, 0.6703, 0.6021, 0.5143, 0.4588, 0.4493, 0.4684,
        0.5023, 0.5241, 0.5021, 0.4567, 0.4379, 0.4451, 0.4406, 0.4164, 0.3920,
        0.3772, 0.3701, 0.3672, 0.3643, 0.3596, 0.3534, 0.3462, 0.3404, 0.3387,
        0.3381, 0.3326, 0.3202, 0.3069, 0.3010, 0.3018, 0.3027, 0.3023, 0.3046,
        0.3076, 0.3068, 0.3044, 0.3036, 0.3009, 0.2934, 0.2878, 0.2904, 0.2945,
        0.2952, 0.2991, 0.3085, 0.3158, 0.3197, 0.3238, 0.3235, 0.3173, 0.3129,
        0.3098, 0.3039, 

## Model Architecture

Let's consider a fairly standard audio sequence modelling setup. We will use an LSTM to model the input audio sequence and obtain a sequence of the same length as the input sequence, but where for each time step we will have constructed a "hidden representation" i.e an input of length N will have an output of shape (N, H) where H is the hidden dimension of the LSTM. As for the compressor parameters, we will embed them into dense representations of size K, and then transpose them. We can then flatten our LSTM output to get a 1D tensor, concat it with our transposed compressor parameters and then pass all of this into a FFN to recover an output of length N (our target audio). Here's an example diagram:

<img src="assets/sample-archi.png" alt="Image Alt Text" width="75%">

Assuming our audio has a sample rate of 44.1kHz and we use a window size of 3 seconds, we will be predicting 132300 samples (might be a lot).

Later we'll try some other architectures..

In [48]:
X, y = train_dataset.__getitem__(0)

In [49]:
X

(tensor([0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000, 0.5000,
         0.5000, 0.5000, 0.5

In [50]:
y.shape

torch.Size([441])

for data, label in dataloader:
    print(len(data))
    print(data[0].shape)
    print(data[1].shape)
    print(data[2].shape)
    print(label.shape)
    raise TypeError

In [51]:
print(len(train_dataset))

13067938


In [52]:
import torch
import torch.nn as nn
import time
from torch.utils.data import DataLoader

device = torch.device("mps")

# Hyperparameters
sequence_length = 441  # Assuming each audio interval has this length
batch_size = 16
hidden_size = 64  # Experiment with different values
fixed_embedding_size = 128
learning_rate = 0.001

dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

class MLPDecoder(nn.Module):
    def __init__(self, input_size, output_size):
        super().__init__()
        self.relu1 = nn.ReLU()
        self.relu2 = nn.ReLU()
        self.relu3 = nn.ReLU()
        self.relu4 = nn.ReLU()
        self.dense1 = nn.Linear(input_size, int(input_size/4))
        self.dense2 = nn.Linear( int(input_size/4),  int(input_size/8))
        self.dense3 = nn.Linear( int(input_size/8),  int(input_size/16))
        self.dense4 = nn.Linear( int(input_size/16),  sequence_length)

    def forward(self, x):
        x = self.relu1(self.dense1(x))
        x = self.relu2(self.dense2(x))
        x = self.relu3(self.dense3(x))
        x = self.relu4(self.dense4(x))
        return x
    
class LSTMModel(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(LSTMModel, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size)
        self.compress_switch_embedding = nn.Embedding(num_embeddings=1, embedding_dim=fixed_embedding_size)
        self.peak_reduction_embedding = nn.Embedding(num_embeddings=1, embedding_dim=fixed_embedding_size)
        self.decoder = MLPDecoder(sequence_length*hidden_size + 2*fixed_embedding_size, sequence_length)

    def forward(self, inputs):
        lstm_input = inputs[0].to(device)
        compress_limit_input = inputs[1].to(device)
        peak_reduction_input = inputs[2].to(device)

        # LSTM
        input_seq = lstm_input.transpose(0, 1)
        input_seq = input_seq.unsqueeze(2)
        print(f"INPUT SEQ SHAPE: {input_seq.shape}")
        lstm_output, (h_n, c_n) = self.lstm(input_seq)
        transposed_lstm_output = lstm_output.transpose(1,0)
        flattened_lstm_output = torch.flatten(transposed_lstm_output, 1,2)

        # Embed Compression Switch 
        embedded_compress_switch = self.compress_switch_embedding(compress_limit_input)

        # Embed PR Parameter
        embedded_peak_reduction = self.peak_reduction_embedding(peak_reduction_input)

        # Concat all embeddings
        concat_embeddings = torch.cat([flattened_lstm_output, embedded_compress_switch, embedded_peak_reduction], dim=1)

        # Pass embeddings through decoder MLP 
        prediction = self.decoder(concat_embeddings)

        return prediction


print(f"Before Model")
model = LSTMModel(1, hidden_size)  # Input size is 1 due to reshaping
print(f"Model initialized")
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
print(f"Optimizer initialized")
loss_fn = nn.MSELoss()  # Adjust loss function if needed
print(f"Loss fn initialized")
model.to(device)
print(f"Model on GPU: {model}")

# Training loop
for epoch in range(1):
    for data, label in dataloader:
        start_time = time.time()

        label = label.to(device)

        optimizer.zero_grad()
        prediction = model(data)
        loss = loss_fn(prediction, label)
        loss.backward()
        optimizer.step()
        
        end_time = time.time()
        print(f"EPOCH {epoch+1} TIME: {(end_time - start_time)/60.0} mins")
        print(f"EPOCH {epoch+1} LOSS: {loss.item()}\n")

print("Training complete!")

Before Model
Model initialized
Optimizer initialized
Loss fn initialized
Model on GPU: LSTMModel(
  (lstm): LSTM(1, 64)
  (compress_switch_embedding): Embedding(1, 128)
  (peak_reduction_embedding): Embedding(1, 128)
  (decoder): MLPDecoder(
    (relu1): ReLU()
    (relu2): ReLU()
    (relu3): ReLU()
    (relu4): ReLU()
    (dense1): Linear(in_features=28480, out_features=7120, bias=True)
    (dense2): Linear(in_features=7120, out_features=3560, bias=True)
    (dense3): Linear(in_features=3560, out_features=1780, bias=True)
    (dense4): Linear(in_features=1780, out_features=441, bias=True)
  )
)


RuntimeError: stack expects each tensor to be equal size, but got [441] at entry 0 and [442] at entry 2

THE OVERLAP SAMPLES IS NOT CORRECT I THINK, IT IS A DECIMAL 44100 * 0.005 = 221.5 or so, it should probably be whole number..

fix up shapes to match what we actually use and then make this work with the dataset object