<center><h1>Implementation of SaShiMi</h1><h3>From paper <a href="https://arxiv.org/abs/2202.09729">It's Raw! Audio Generation with State-Space Models</a></h3></center>

SaShiMi generates audio autoregressively from raw waveform samples. The same model can be run as a CNN (convolutional mode) or as an RNN (recurrent mode).

# Prelude

Setup, import modules, select device, etc.

In [1]:
%matplotlib inline
# Built-in IPython extension to reload modules when updated.
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm.auto import tqdm
import copy
import IPython.display as ipd

# Custom modules
from S4 import *
from wav_dataset import *

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#device = torch.device("cpu")
device

device(type='cuda')

# Dataset

The dataset is the audio track of [this 4-hour piano video from YouTube](https://www.youtube.com/watch?v=EhO_MrRfftU), resampled to 16 kHz and split into 1-minute WAV files. Thankfully, the authors of the paper have already done this preprocessing and uploaded the result [here on HuggingFace](https://huggingface.co/datasets/krandiash/youtubemix).

In [4]:
youtube_mix_transform = YoutubeMixTransform(device)
train_dataset = WavDataset("../datasets/youtube-mix/train", youtube_mix_transform)

In [5]:
# Check dataset size
assert len(train_dataset) == 212
# Check sample rate
assert train_dataset.sample_rate == 16000

In [6]:
x, y = train_dataset[0]
# Approximately 1 minute (not all files were exactly 1 minute)
assert list(x.size()) == [958400]
assert list(y.size()) == [958400]

`x` contains samples $0$ to $N$ of the sound file; each sample is a float between $-1$ and $+1$. `y` contains samples $1$ to $N+1$; each sample is an integer between 0 and 255. For each input sample $i$, the network uses inputs $0$ to $i$ to predict a probability distribution over the 256 possible 8-bit values of the next sample ($i+1$).
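
As a rough illustration of this input/target layout, the sketch below builds such a pair from a short waveform. The helper name `make_next_sample_pair` and the plain linear 8-bit quantization are assumptions made for this example; the actual `YoutubeMixTransform` may quantize differently (for instance with μ-law companding).

```python
import torch

def make_next_sample_pair(wave: torch.Tensor):
    """Hypothetical helper: split a float waveform in [-1, 1] into (x, y).

    x holds every sample except the last one (float inputs).
    y holds every sample except the first one, quantized to integers 0..255.
    """
    x = wave[:-1]
    # Assumption: linear quantization, mapping [-1, 1] onto {0, ..., 255}.
    quantized = ((wave + 1.0) / 2.0 * 255.0).round().clamp(0, 255).long()
    y = quantized[1:]
    return x, y

wave = torch.tensor([-1.0, 0.0, 1.0, 0.5])
x, y = make_next_sample_pair(wave)
print(x.shape, y.shape, y.dtype)
```

Here `y[i]` is the class label of the sample that follows `x[i]`, which is exactly the target a 256-way cross-entropy loss expects.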

Let's listen to one of the sound files in the dataset:

In [7]:
ipd.Audio(x.cpu().numpy(), rate=16000)

# SaShiMi Architecture

Here, we borrow Figure 1 from the paper:

<img src="images/sashimi-architecture.png" alt="SaShiMi Figure"/>


## S4 Layer

At a high level, an S4 layer can be thought of as a mapping from a 1D sequence to another 1D sequence. Some of the important properties of S4:
- It's causal. The output at index $i$ depends only on inputs from $0$ to $i$, never on later inputs.
- It can be computed using either convolution or recurrence.
    - We use convolution during training and recurrence during sampling.
- Although it's 1D to 1D, we stack multiple S4 Layers together to process multidimensional signals.
    - In this case, different dimensions don't interact with each other.
    - Linear layers compute interactions between different signal dimensions but they cannot handle interactions through time. Therefore, linear layers complement S4 layers nicely.
- S4 doesn't provide non-linearity by itself. It needs to be combined with activation functions.
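
The convolution/recurrence duality and the causality property can be checked on a toy linear state-space model. This is only a sketch under simplifying assumptions: the matrices `A`, `B`, `C` below are small random discrete-time parameters chosen for illustration, not the structured parameterization that S4 actually uses, and the function names are hypothetical.

```python
import torch

def run_recurrent(A, B, C, u):
    """Step through the sequence one sample at a time (RNN-style)."""
    h = torch.zeros(A.size(0), 1)
    ys = []
    for u_t in u:
        h = A @ h + B * u_t           # state update
        ys.append((C @ h).squeeze())  # scalar readout
    return torch.stack(ys)

def run_convolution(A, B, C, u):
    """Apply the equivalent causal convolution with kernel K[k] = C A^k B."""
    T = u.size(0)
    K = torch.stack(
        [(C @ torch.matrix_power(A, k) @ B).squeeze() for k in range(T)]
    )
    # y[t] = sum_{k=0..t} K[k] * u[t - k]; no input after index t is used.
    return torch.stack(
        [(K[: t + 1].flip(0) * u[: t + 1]).sum() for t in range(T)]
    )

torch.manual_seed(0)
N, T = 4, 16                          # state size, sequence length
A = torch.diag(torch.rand(N) * 0.9)   # stable toy state matrix
B = torch.randn(N, 1)
C = torch.randn(1, N)
u = torch.randn(T)

y_rec = run_recurrent(A, B, C, u)
y_conv = run_convolution(A, B, C, u)
print(torch.allclose(y_rec, y_conv, atol=1e-5))
```

Both functions compute the same sequence. The recurrent form needs only the current state per step, which suits sampling, while the convolutional form processes the whole sequence at once, which suits training.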

For more information, you can read the [paper](https://arxiv.org/abs/2202.09729) or the source code of our S4 implementation. Our code contains many docstrings and comments.

## S4 Block

The main component of the SaShiMi architecture is the **S4 Block**. The paper explains the details of the S4 block in Appendix A.2.

### High-level Architecture

First pass:
1. Input
2. LayerNorm
3. S4 Layer
4. GELU
5. Linear
6. Residual connection from 1

Second pass:
1. Output of the first pass
2. LayerNorm
3. Linear
4. GELU
5. Linear
6. Residual connection from 1

All linear layers are position-wise, i.e., they operate on the signal dimensions, not
the time dimension.
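
To make "position-wise" concrete, the small check below applies an `nn.Linear` layer to a whole `(T, H)` sequence and then to each time step separately. The sizes are arbitrary values picked for the example.

```python
import torch
from torch import nn

torch.manual_seed(0)
T, H_in, H_out = 8, 3, 5
linear = nn.Linear(H_in, H_out)
x = torch.randn(T, H_in)  # (time, signal dimension)

# Applying the layer to the whole sequence at once...
full = linear(x)
# ...gives the same result as applying it to every time step on its own.
per_step = torch.stack([linear(x[t]) for t in range(T)])

print(torch.allclose(full, per_step, atol=1e-6))
```

Since no information crosses time steps in these layers, all mixing along the time axis has to come from the S4 layer.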

### Implementation

Once the S4 layer is implemented, constructing an S4 block is easy.

The following function constructs the S4 block from these arguments:
- `signal_dim`: Number of dimensions in the signal.
- `state_dim`: Number of dimensions in the inner state.
- `sequence_length`: The length of the sequence on which this model will operate.
    - Can be changed later, but models trained on one sequence length perform poorly on another sequence length.
- `expansion_factor`: The factor by which the number of dimensions will be multiplied between two linear layers in the second pass.

In [8]:
def S4Block(signal_dim: int, state_dim: int, sequence_length: int, expansion_factor: int = 2):
    return Sequential(
        Residual(
            nn.LayerNorm(signal_dim),
            S4Base(signal_dim, state_dim, sequence_length),
            nn.GELU(),
            nn.Linear(signal_dim, signal_dim),
        ),
        Residual(
            nn.LayerNorm(signal_dim),
            nn.Linear(signal_dim, signal_dim * expansion_factor),
            nn.GELU(),
            nn.Linear(signal_dim * expansion_factor, signal_dim),
        ),
    )

Note that we use our own `Sequential` implementation instead of `torch.nn.Sequential`. Our implementation inherits from `torch.nn.Sequential` and accounts for the S4 layer when we want to run it in recurrent mode. `Residual` is the same as `Sequential`, but it adds a residual connection from the beginning to the end, i.e., $\text{Residual}(x) = x + \text{Sequential}(x)$.
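
For readers who don't have the `S4` module at hand, the residual wrapper can be approximated as follows. `SimpleResidual` is a hypothetical, simplified stand-in written for this example: it builds on plain `torch.nn.Sequential` and ignores the recurrent-mode handling that the notebook's own `Residual` provides.

```python
import torch
from torch import nn

class SimpleResidual(nn.Sequential):
    """Simplified stand-in: run the wrapped layers, then add the input back."""

    def forward(self, x):
        return x + super().forward(x)

torch.manual_seed(0)
block = SimpleResidual(nn.LayerNorm(4), nn.Linear(4, 4), nn.GELU())
x = torch.randn(10, 4)
out = block(x)

# The same layers without the skip connection, for comparison.
inner = nn.Sequential(*block)
print(torch.allclose(out, x + inner(x)))
```

The wrapped layers must preserve the input shape, which is why each pass of the S4 block ends with a linear layer that maps back to `signal_dim`.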

## Up-Pool and Down-Pool

The next components, **up-pool** and **down-pool**, each combine a reshape with a linear layer.

In particular, for an input of shape $(T, H)$ where $T$ is the sequence length (sample count) and $H$ is the signal dimension, the **down-pool** is

$$
(T,H) \xrightarrow{\text{reshape}} (T/p, H \cdot p) \xrightarrow{\text{linear}} (T/p, H \cdot q)
$$

where $p$ is the pooling factor and $q$ is the expansion factor.

**Up-pool** is simply the opposite operation:

$$
(T/p, H \cdot q) \xrightarrow{\text{linear}} (T/p, H \cdot p) \xrightarrow{\text{reshape}} (T,H)
$$

The paper uses $p=4$ and $q=2$.
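
The reshape step is easiest to understand on a tiny tensor. In the sketch below (sizes chosen only for illustration), every group of $p$ consecutive time steps is flattened into a single, wider feature vector, and reshaping back restores the original layout.

```python
import torch

T, H, p = 8, 2, 4
# Row t holds the H features of time step t: [[0, 1], [2, 3], ...]
x = torch.arange(T * H).reshape(T, H)

# Down-pool reshape: (T, H) -> (T / p, H * p)
pooled = x.reshape(T // p, H * p)
print(pooled)

# Up-pool reshape: (T / p, H * p) -> (T, H)
restored = pooled.reshape(T, H)
print(torch.equal(restored, x))
```

The first row of `pooled` contains the features of time steps 0 to 3 side by side, so the linear layer that follows can mix information across those $p$ neighbouring samples.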

In [9]:
class DownPool(nn.Module):
    def __init__(self, signal_dim: int, pooling_factor: int = 4, expansion_factor: int = 2):
        super().__init__()
        self.pooling_factor = pooling_factor
        self.linear = nn.Linear(
            signal_dim * pooling_factor,
            signal_dim * expansion_factor,
        )

    def forward(self, x):
        T = x.size(dim=-2)
        H = x.size(dim=-1)
        x = x.reshape(-1, T // self.pooling_factor, H * self.pooling_factor)
        return self.linear(x)

In [10]:
class UpPool(nn.Module):
    def __init__(self, signal_dim: int, pooling_factor: int = 4, expansion_factor: int = 2):
        super().__init__()
        self.pooling_factor = pooling_factor
        self.linear = nn.Linear(
            signal_dim * expansion_factor,
            signal_dim * pooling_factor,
        )

    def forward(self, x):
        y = self.linear(x)
        T = y.size(dim=-2)
        H = y.size(dim=-1)
        return y.reshape(-1, T * self.pooling_factor, H // self.pooling_factor)

Let's check whether the dimensions are correct:

In [11]:
x = torch.randn(1, 64, 2)
y = DownPool(2)(x)
assert list(y.size()) == [1, 64 // 4, 2 * 2]
z = UpPool(2)(y)
assert list(z.size()) == [1, 64, 2]

The paper also presents a **bidirectional S4** variant, which concatenates the outputs of two S4 layers, one of which processes the sequence in reverse order. Because the reversed layer sees future samples, this variant cannot be used for autoregressive generation, so we don't implement it in this notebook.

## Overall Architecture

Now, we have all the components required to implement the SaShiMi architecture given in Figure 1.
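
Before wiring the pieces together, it helps to trace the tensor shape at each level of the hierarchy. The helper below is a hypothetical utility written for this example; it assumes two pooling levels with $p=4$ and $q=2$, as described above, and uses the sequence length and hidden size that appear elsewhere in this notebook.

```python
def sashimi_level_shapes(sequence_length: int,
                         hidden_dim: int,
                         pooling_factor: int = 4,
                         expansion_factor: int = 2,
                         levels: int = 2):
    """Return the (T, H) shape seen by the S4 blocks at each level."""
    shapes = [(sequence_length, hidden_dim)]
    for _ in range(levels):
        t, h = shapes[-1]
        # Each down-pool needs the length to be divisible by the pooling factor.
        assert t % pooling_factor == 0, "sequence length must be divisible"
        shapes.append((t // pooling_factor, h * expansion_factor))
    return shapes

for level, (t, h) in enumerate(sashimi_level_shapes(958400, 128)):
    print(f"level {level}: T={t}, H={h}")
```

With two levels, the sequence length has to be divisible by $4^2 = 16$, and the innermost S4 blocks operate on sequences 16 times shorter but 4 times wider than the outermost ones.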

In [17]:
def SaShiMi(input_dim: int,
            hidden_dim: int,
            output_dim: int,
            state_dim: int,
            sequence_length: int,
            block_count: int,
           ):
    return Sequential(
        nn.Linear(input_dim, hidden_dim),
        Residual(
            DownPool(hidden_dim),
            Residual(
                DownPool(2 * hidden_dim),
                Residual(*[
                    S4Block(4 * hidden_dim, state_dim, sequence_length // 16)
                    for _ in range(block_count)
                ]),
                UpPool(2 * hidden_dim),
            ),
            *[S4Block(2 * hidden_dim, state_dim, sequence_length // 4) for _ in range(block_count)],
            UpPool(hidden_dim),
        ),
        *[S4Block(hidden_dim, state_dim, sequence_length) for _ in range(block_count)],
        nn.Linear(hidden_dim, output_dim),
    )

# Overfitting Test

In accordance with the well-respected 796 tradition, we first test whether the model works by trying to overfit to a single example.

In [18]:
torch.manual_seed(42)

model = SaShiMi(
    input_dim=1,
    hidden_dim=128,
    output_dim=256,
    state_dim=512,
    sequence_length=youtube_mix_transform.sequence_length,
    block_count=2,
).to(device)

overfit_dataset = [train_dataset[0]]

In [19]:
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = torch.nn.CrossEntropyLoss()
losses = []

In [26]:
for epoch in tqdm(range(50), leave=False):
    train_loss = 0.0
    for x, y in overfit_dataset:
        y_hat = model(x.view(-1, 1))
        loss = criterion(y_hat, y)

        loss_val = loss.detach().cpu().item()
        # Average over the examples we actually iterate over (`dataloader` is not defined here).
        train_loss += loss_val / len(overfit_dataset)
        losses.append(loss_val)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

plt.plot(losses)
plt.show()

  0%|          | 0/50 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 116.99 GiB (GPU 0; 15.73 GiB total capacity; 4.48 GiB already allocated; 10.32 GiB free; 4.51 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

That was an anticlimax.

`cauchy_denominator` has size `[59900, 512]`, with each element requiring 8 bytes. `a0 * b0` has size `[512, 1, 512]`. Broadcasting them together produces a tensor of size `[512, 59900, 512]`, which requires 119800 MiB (about 117 GiB), matching the 116.99 GiB allocation in the error message.
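
The arithmetic can be reproduced in a few lines. The variable names are chosen for this example; the sizes are the ones quoted above.

```python
signal_dim, sequence_length, state_dim = 512, 59900, 512
bytes_per_element = 8  # as stated above

elements = signal_dim * sequence_length * state_dim
total_bytes = elements * bytes_per_element

mib = total_bytes / 2**20
gib = total_bytes / 2**30
print(f"{elements:,} elements, {mib:,.0f} MiB, {gib:.2f} GiB")
```

That is roughly seven times the 15.73 GiB available on the GPU, so the broadcasted intermediate cannot be materialized as written.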

We need to optimize the computation of the Cauchy kernel.
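
One possible direction, sketched here under stated assumptions, is to avoid the broadcasted 3D tensor altogether. The sketch assumes the kernel has the form $\sum_n v_{h,n} / (z_l - w_n)$, i.e., that the numerator (`a0 * b0` above) is divided by the denominator and then summed over the state dimension. Under that assumption the sum is a matrix product, so only `(L, N)` and `(H, L)` tensors are ever stored. The names `v`, `z`, `w`, `cauchy_broadcast`, and `cauchy_matmul` are hypothetical and do not come from the `S4` module.

```python
import math
import torch

def cauchy_broadcast(v, z, w):
    """Naive version: materializes an (H, L, N) tensor before summing."""
    # v: (H, N), z: (L,), w: (N,)
    return (v[:, None, :] / (z[None, :, None] - w[None, None, :])).sum(dim=-1)

def cauchy_matmul(v, z, w):
    """Same result via a matrix product; the largest temporary is (L, N)."""
    inv = 1.0 / (z[:, None] - w[None, :])  # (L, N)
    return v @ inv.T                       # (H, N) @ (N, L) -> (H, L)

torch.manual_seed(0)
H, L, N = 3, 16, 8
v = torch.randn(H, N, dtype=torch.cfloat)
# Evaluation points on the unit circle.
z = torch.polar(torch.ones(L), 2 * math.pi * torch.arange(L) / L)
# Poles kept well away from the unit circle so no denominator is near zero.
w = torch.complex(-2.0 - torch.rand(N), torch.randn(N))

slow = cauchy_broadcast(v, z, w)
fast = cauchy_matmul(v, z, w)
print((slow - fast).abs().max().item() < 1e-4)
```

If the real computation does not reduce over the state dimension in this way, processing the sequence dimension in chunks would be an alternative way to bound peak memory.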

In [25]:
%pdb

Automatic pdb calling has been turned OFF
