In [None]:
%load_ext autoreload
%autoreload 2

# Dissecting RAVE

Ok, now let's take our scalpel and disassemble RAVE's encoder and decoder. Recall our detailed schema:

![RAVE detailed](assets/rave_detailed.png)


We remember that the encoder is composed of 
- a *filter bank* (pseudo-quadrature mirror filters, PQMF)
- a sequence of N strided convolution blocks (one conv, one norm, one activation)


and that the decoder is composed of 
- a sequence of residual convolutional blocks
- followed by an optional filtered noise, a loudness coefficient, and a waveform obtained from the residual blocks.

In fact, throughout this session we will use a *no-PQMF* version of RAVE for the decoder, for reasons we will explain later. We will use `torchbend` to analyse these different layers and get a more precise idea of the inner DSP of RAVE.

## Dissecting the encoder
Let's dissect the encoder of RAVE : 

In [None]:
import sys; sys.path.append('../torchbend')
import torch
import torchbend as tb; print(tb.__file__)
tb.set_output('notebook')
from dandb import download_models, import_model

model_list = download_models()
model = import_model(model_list["sol_full_nopqmf"])
print(model.encoder)

Let's print out all the steps of the computation graph that call an inner module, to get an idea of which submodules of the encoder are called during the encoding process:

In [None]:
import torch 
# trace the forward function
# model.trace(x=torch.randn(1,1,2048), fn="encode")
model.print_graph("encode")

Ok, this is a very complex graph; indeed, every sub-operation (even accessing an item of a tensor, for example) is recorded in the graph, and we can access ALL of them. This makes the analysis a little tedious, especially with an architecture like RAVE's that is based on complex convolutional modules (residual and multi-dilation, in fact), but with a bit of a detective's attitude we will manage to find what we need. First, let's filter out all the "uninteresting" steps such as attribute access, getitem, reshape, copy, etc.
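
As a rough picture of what such name-based filtering amounts to, here is a minimal sketch using regular expressions. The node names are made up for illustration, and the idea that `exclude` drops every node whose name matches one of the patterns is an assumption on our side, not a description of `torchbend`'s internals.

```python
import re

# Hypothetical node names, loosely imitating what a traced graph could contain.
node_names = ["x", "getattr_1", "getitem_3", "conv1d_2", "sin_1",
              "pow_1", "add_4", "cat_2", "copy_1"]


def filter_nodes(names, exclude=()):
    """Keep the names that match none of the regex patterns in `exclude`."""
    patterns = [re.compile(p) for p in exclude]
    return [n for n in names if not any(p.search(n) for p in patterns)]


kept = filter_nodes(node_names, exclude=[r"getattr", r"getitem", r"copy", r"cat"])
print(kept)
```

Only the nodes that actually compute something (convolutions, `sin`, `pow`, `add`) survive, which is the kind of shortlist we are after.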

In [None]:
model.print_graph("encode", exclude=[r'getattr', r'getitem', r'copy', r'cat'])

Ok, this is a bit more readable, but still hard to decipher. We can notice, though, that:
1) the encoder regularly accesses convolutional weights (`encoder.encoder.net.3.aligned.branches.0.net.1.weight_v` for example) and applies convolution operations
2) the encoder also regularly applies `sin` and `pow`, which looks like an activation function

For the 1st point, let's check that : 

In [None]:
model.print_activations("encode", flt='conv')

Observing the shapes on the left of the table (and remembering that shapes are `n_batch, n_channels, n_steps`), we can see that every 8 layers the outputs are downsampled in time and the number of channels increases. For point 2), let's quickly check the model's `.gin` configuration file:

```
...

# Macros:
# ==============================================================================
ACTIVATION = @blocks.Snake
```

Ok, the activation blocks used here are `Snake`, whose definition can be found in `RAVE/rave/blocks.py` after a quick search:
```
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return x + (self.alpha + 1e-9).reciprocal() * (self.alpha *
                                                    x).sin().pow(2)
```
So it seems that these `pow` operations could indeed correspond to the activation output of each layer. Let's check that:

In [None]:
model.print_activations("encode", flt='pow')

However, in the Snake activation the actual output comes a few operations later, after a multiplication and an addition with the input. So let's locate the `add_*` `call_function` nodes that follow these `pow_*` operations:

In [None]:
model.print_activations("encode", flt='add')
model.print_graph("encode", flt='add');


*Bingo!* By cross-checking the shapes in the activation table against the arguments in the graph (args of type `add_*, mul_*`), it seems that the activations we want are `add_4`, `add_9`, `add_14`, and `add_18`. We are not done yet, though: we are still missing the output of the final layer, of size `n_examples x 256 x n_seq`, where 256 is twice the model's latent size because both the mean and the standard deviation are encoded (as in traditional variational auto-encoders). We give it to you here: this activation is called `conv1d_28`, and is the final convolution that linearly maps 1024 channels to 256.
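
To see why the activation outputs show up as `add_*` nodes rather than `pow_*` nodes, it can help to spell out Snake as the elementary operations a tracer would record. The sketch below is a scalar, plain-Python version of the `forward` shown earlier; the step labels in the comments are indicative only and do not reproduce the exact node names of the traced graph.

```python
import math


def snake(x: float, alpha: float = 1.0) -> float:
    """Scalar Snake activation, written one elementary operation per line."""
    inner = alpha * x                       # mul
    s = math.sin(inner)                     # sin
    p = s ** 2                              # pow
    scaled = p * (1.0 / (alpha + 1e-9))     # reciprocal, then mul
    return x + scaled                       # add: the value we actually want to observe


for x in (-1.0, 0.0, 0.5, 2.0):
    print(x, snake(x))
```

The `pow` step is only an intermediate value: the output of the activation is the final addition with the input, which is why we target the `add_*` nodes.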

Now that we have the activations we want to target, let's retrieve them using the `get_activations` method of `torchbend` :

In [None]:
import torch
from dandb import get_sounds, plot_1d_activation

sounds = get_sounds()
x = sounds.load('violin.wav', sr=model.sample_rate)

activations = [*['add_%d'%i for i in [4,9,14,18]], 'conv1d_28']
out = model.get_activations(*activations, fn="encode", x=x)
out['conv1d_28'] = out['conv1d_28'][:, :128] # we delete the std part of the latent encoding


# change the number below to update the number of plotted dimensions! 
n_plot_dims = 16
for k, v in out.items():
    print(v.shape)
    act_max = torch.argsort(v.amax(dim=[0, 2]), descending=True)[:n_plot_dims]
    plot_1d_activation(k, v[:, act_max], display=True, channel_idx=act_max.tolist()) 
    
    

Ok, so this is what the encoder's internal activations look like. We can see that the activations are strongly periodic, which explains why the latent representations can look so messy when viewed through the scope of the RAVE VST. We can also observe a kind of "flattening" or "normalization" effect: at the very beginning of the encoding process the amplitudes roughly follow the amplitude of the sound, while the later activations become saturated and more abstract, besides being drastically downsampled (the number of channels increases, however, so that overall most of the compression of the input happens at the very end of the encoding process).

The features of the encoder can then be understood as encodings of local features (indexed by the channel index) covering an increasing temporal scope, finally piped to the latent space, which actually performs most of the compression. So, what about the decoder?
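
The "increasing temporal scope" can be made concrete with a small receptive-field computation. This is a sketch under stated assumptions: the `(kernel_size, stride)` pairs below are hypothetical and are not read from the RAVE model, and dilation and residual branches are ignored.

```python
def receptive_fields(layers):
    """Return (receptive_field, total_stride) after each layer of a conv stack.

    `layers` is a list of (kernel_size, stride) pairs; dilation is ignored.
    """
    rf, jump = 1, 1
    out = []
    for kernel_size, stride in layers:
        rf += (kernel_size - 1) * jump   # each extra tap reaches `jump` input samples further
        jump *= stride                   # spacing between two outputs, in input samples
        out.append((rf, jump))
    return out


# Hypothetical stack of strided convolutions (not RAVE's actual configuration).
hypothetical_stack = [(8, 4), (8, 4), (8, 4), (4, 2)]
stages = receptive_fields(hypothetical_stack)
for i, (rf, total_stride) in enumerate(stages):
    print(f"layer {i}: sees {rf} input samples, one output every {total_stride} samples")
```

Each layer produces fewer time steps, but every one of those steps summarizes a longer stretch of the waveform.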

## Dissecting the decoder

In fact, the decoder is roughly the reverse process: a sequence of upsampling convolutions. If you want, try to find the activations to plot yourself, as we did for the encoder! If the courage (or interest) is missing, you can jump to the next cell and uncomment its last line.

In [None]:
import dandb

# try finding the good activations to plot by modifying the flt or exclude keywords below! 
flt = ""
exclude = None
model.print_graph("decode", flt=flt, exclude=exclude)
model.print_activations("decode", flt=flt, exclude=exclude)

target_activations = []
# discouraged or bored? uncomment the line below : 
target_activations = dandb.RAVE_DECODER_ACT_NAMES_DECODE

Ok, let's also plot them as we did for the encoder's activations :

In [None]:
import dandb
from IPython.display import Audio, display
from torchaudio.functional import resample

sounds = get_sounds()
x = sounds.load('violin.wav', sr=model.sample_rate)
n_samples = x.shape[-1]
z = model.encode(x=x)


out = model.get_activations(*target_activations, fn="decode", z=z)

# change the number below to update the number of plotted dimensions! 
n_plot_dims = 16
for k, v in out.items():
    print(k)
    act_max = torch.argsort(v.amax(dim=[0, 2]), descending=True)[:n_plot_dims]
    plot_1d_activation(k, v[:, act_max], display=True, channel_idx=act_max.tolist()) 

    # export as audio files the intermediary activations
    sample_rate = (v.shape[-1] / n_samples) * model.sample_rate
    v_sum = v.sum(-2)
    v_sum = v_sum / v_sum.abs().amax()  # peak-normalise to [-1, 1]
    v_sum = resample(v_sum, int(sample_rate), 44100)
    audio = Audio(v_sum.numpy(), rate=44100)
    display(audio)


The upsampling inherent to RAVE's decoding process is clear here: each step, consisting of an upsampling layer and a sequence of convolutional operations, can be thought of as a transformation of the latent that starts from the low-frequency components and moves towards the higher-frequency ones. The process is non-linear, though, so it cannot be exactly equated with a "bass2treble" generation; still, the idea is there. The same process is at work in decoder-based image generation models such as VAEs or StyleGANs, where the image is progressively upsampled and details are added along the generation process.
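
The link between upsampling and frequency content can be sketched with the same formula as in the cell above, `(v.shape[-1] / n_samples) * model.sample_rate`. The numbers below are assumptions chosen for illustration (a 44100 Hz model and made-up stage lengths), not values read from the model.

```python
def stage_rates(stage_lengths, n_samples, sample_rate):
    """Effective sample rate of each intermediate activation."""
    return [(length / n_samples) * sample_rate for length in stage_lengths]


sample_rate = 44100                               # assumed model sample rate
n_samples = 65536                                 # assumed length of the input waveform
stage_lengths = [512, 2048, 8192, 32768, 65536]   # hypothetical time steps per decoder stage

rates = stage_rates(stage_lengths, n_samples, sample_rate)
for length, rate in zip(stage_lengths, rates):
    print(f"{length:6d} steps -> {rate:9.1f} Hz, "
          f"frequencies representable up to {rate / 2:8.1f} Hz")
```

An activation with few time steps simply cannot carry high frequencies, so the early stages can only shape the slow evolution of the sound; higher frequencies become available as the signal is upsampled.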

## What about the weights? 

A 1-d convolution is mathematically defined as follows:

$\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{in} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)$

where $\star$ is the valid cross-correlation operator, $N$ is the batch size, $C$ denotes a number of channels, and $L$ is the length of the signal sequence. RAVE uses bias-free convolutions, so each convolution is fully described by its weights, of shape `[channels_out, channels_in, kernel_size]`. RAVE also uses strided transposed convolutions, which upsample the signal and allow multi-resolution generation (no stride on the left, stride on the right):

![](assets/conv_transpose_nostride.gif) ![](assets/conv_transpose_stride.gif)
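
Both operations can be written out in a few lines of plain Python. This is a deliberately naive sketch meant to mirror the formula and the animations above (no padding, no dilation, no batch dimension), not the way RAVE or PyTorch implement them.

```python
def conv1d(x, weight, stride=1):
    """Bias-free 'valid' 1-d cross-correlation.

    x:      [C_in][L]          input channels
    weight: [C_out][C_in][K]   kernels
    returns [C_out][L_out] with L_out = (L - K) // stride + 1
    """
    c_out, c_in, k = len(weight), len(weight[0]), len(weight[0][0])
    l_out = (len(x[0]) - k) // stride + 1
    out = [[0.0] * l_out for _ in range(c_out)]
    for j in range(c_out):
        for c in range(c_in):
            for t in range(l_out):
                for i in range(k):
                    out[j][t] += weight[j][c][i] * x[c][t * stride + i]
    return out


def conv_transpose1d(x, weight, stride=1):
    """Bias-free 1-d transposed convolution (no padding).

    x:      [C_in][L]
    weight: [C_in][C_out][K]   (PyTorch's layout for transposed convolutions)
    returns [C_out][L_out] with L_out = (L - 1) * stride + K
    """
    c_in, c_out, k = len(weight), len(weight[0]), len(weight[0][0])
    l_out = (len(x[0]) - 1) * stride + k
    out = [[0.0] * l_out for _ in range(c_out)]
    for c in range(c_in):
        for j in range(c_out):
            for t, value in enumerate(x[c]):
                for i in range(k):
                    # every input sample "stamps" a scaled copy of the kernel
                    out[j][t * stride + i] += value * weight[c][j][i]
    return out


x = [[1.0, 2.0, 3.0, 4.0]]                              # one channel, four steps
print(conv1d(x, [[[1.0, -1.0]]]))                       # shorter output
print(conv_transpose1d(x, [[[1.0, 1.0]]], stride=1))    # slightly longer output
print(conv_transpose1d(x, [[[1.0, 1.0]]], stride=2))    # upsampled output
```

With a stride of 2, the transposed convolution doubles the spacing between the stamped kernels, which is how the decoder increases the time resolution at each stage.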

We can locate the weights of RAVE's encoder with the following request:

In [None]:
model.print_weights(flt=r"encoder.*weight_v")

So, how can we represent all the kernels of a convolutional layer? Well, by plotting every kernel linking one input channel to one output channel; however, this can be far too many in some cases, such as ours. For example, even for the lightest layer there are still 64 * 64 = 4096 kernels to plot, which is still huge. This is why deep learning models are so difficult to apprehend and are said to be *black boxes*: getting a concrete idea of the role of each separate part is most of the time impossible, as the dimensionality is too high and the inner behavior highly non-linear.

Is this a reason to give up? No! By looking closely, we can see that there are several types of convolutions:
- the convolutions with kernels of length 1, which can be thought of as "mixing" operations between channels
- the convolutions with kernels of length 3, which are "proper" convolution operations
- the convolutions with kernels of length 8 and channel augmentation (convolutions with indexes `[0,5,10,15,19]`)

Let's focus on the second kind and plot the amplitude of each convolution kernel, to see whether some could be left out (again using this amplitude criterion, which is debatable).
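
The amplitude criterion boils down to three steps: flatten the `[channels_out, channels_in, kernel_size]` weight into a list of kernels, take the L2 norm of each kernel, and sort. Here is a minimal plain-Python sketch on a toy weight; the values are invented, and the helper `rank_kernels` is a hypothetical name, not part of `dandb` or `torchbend`.

```python
import math


def rank_kernels(weight, n_kernels):
    """Flatten [C_out][C_in][K] into a list of kernels and rank them by L2 norm.

    Returns (flat_index, norm) pairs, largest norm first, where
    flat_index = out_channel * C_in + in_channel.
    """
    kernels = [kernel for per_output in weight for kernel in per_output]
    norms = [math.sqrt(sum(w * w for w in kernel)) for kernel in kernels]
    order = sorted(range(len(kernels)), key=lambda i: norms[i], reverse=True)
    return [(i, norms[i]) for i in order[:n_kernels]]


# Toy weight: 2 output channels, 2 input channels, kernels of length 3.
toy_weight = [
    [[0.0, 0.1, 0.0], [3.0, 0.0, 4.0]],
    [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
]
top = rank_kernels(toy_weight, n_kernels=2)
print(top)
```

A large norm only says that a kernel contributes a lot of energy, not that it matters perceptually, which is why the criterion remains debatable.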

In [None]:
from dandb import plot_kernel_grid
from plotly.express import histogram
# possible layers : 
weight_name = "encoder.encoder.net.2.aligned.branches.0.net.1.weight_v"
param = model.state_dict()[weight_name]

# mix first and second dimension
print(param.shape)
param = param.reshape(-1, param.shape[-1])
print(param.shape)
# take L2 norm
param_norm = param.pow(2).sum(-1).sqrt()  

histogram({'amplitude': param_norm.numpy()}, x="amplitude", height=200).show()

sorted_idx = torch.argsort(param_norm, descending=True)
n_kernels = 64
plot_kernel_grid(param[sorted_idx[:n_kernels]])

Ok, these kernels are not very informative... Let's plot the longer ones:

In [None]:
from dandb import plot_kernel_grid
from plotly.express import histogram
# possible layers : 
for layer in [0,5,10,15,19]:
    weight_name = f"encoder.encoder.net.{layer}.weight_v"
    print(weight_name)
    param = model.state_dict()[weight_name]
    # mix first and second dimension
    param = param.reshape(-1, param.shape[-1])
    # take L2 norm
    param_norm = param.pow(2).sum(-1).sqrt()  

    histogram({'amplitude': param_norm.numpy()}, x="amplitude", height=200).show()

    sorted_idx = torch.argsort(param_norm, descending=True)
    n_kernels = 64
    plot_kernel_grid(param[sorted_idx[:n_kernels]], display=True)

We cannot say that these weights are much more informative; yet, this quickly gives an idea of the "filterbank" formed by the successive downsampling layers of the encoder. We can see that the filters closest to the waveform have learned some kind of periodicity and are quite similar, while the kernels in the higher levels differ a little more. All these kernels can be thought of as the "dictionary" of the encoding, the plotted ones being the most significant words.

### What about the decoder?

Let's do the same kind of operation with the decoder : 

In [None]:
model.print_weights(flt=r"decoder.*weight_v")

We can see that the process is very similar. Let's plot our most significant convolution weights for layers with channel reduction :

In [None]:
from dandb import plot_kernel_grid
from plotly.express import histogram
# possible layers : 
for layer in [2,6,11,16,21]:
    weight_name = f"decoder.net.{layer}.weight_v"
    print(weight_name)
    param = model.state_dict()[weight_name]
    # mix first and second dimension
    param = param.reshape(-1, param.shape[-1])
    # take L2 norm
    param_norm = param.pow(2).sum(-1).sqrt()  

    histogram({'amplitude': param_norm.numpy()}, x="amplitude", height=200).show()

    sorted_idx = torch.argsort(param_norm, descending=True)
    n_kernels = 64
    plot_kernel_grid(param[sorted_idx[:n_kernels]], display=True)

How informative is this? Not very. However, it gives a feeling for what is happening inside, looking at all these little wavelets that are summed and processed along the decoding process. In particular, visualizing them can give some ideas for actually *messing* with these kernels and activations, knowing a little more about what to target and which modifications to make. We will see that this afternoon!