# Dissecting RAVE with `torchbend`

Welcome to this tutorial! The idea here is to use the `torchbend` library to dissect the inner guts of a general machine learning model, and especially of the `RAVE` models that we will learn to bend this afternoon. If you have never used a Jupyter notebook, it is quite simple:
- Execute the cells containing Python code one by one by clicking the `Run` button in the top toolbar (or pressing Shift + Enter)
- If you feel like it, you can change some of the variables to play a little with the code!

Try that with the code cell below, and read it carefully:

In [1]:
# ok, we are in a Python cell! 
# <- this sharp symbol means that this line is commented, meaning that it is just text and no code.
# below, we define some variables using the "=" assignment operator

# print is used to output text below the code cell, where the outputs are.
print('Testing python stuff : ')


# Execute the cell and observe the outputs!
# Try changing the variables below, and observe how the output changes.
number = 3
other_number = 4.3
string = "hello world!"


list_of_things = [3, 4, 5, 2]
print(list_of_things[1]) # lists are indexed by integers, starting from 0

dictionary = {'a': 3, 'b': 4}
print(dictionary['a']) # dictionaries are indexed by keys


# example of conditional statements:
if (number == 3):
    print('number is three here!')
elif (number == 4):
    print('number is four here!')
else:
    print('number is... something')


# example of loops : 
for i in range(4):
    print("current i value : ", i)
# example of looping over the elements of a list:
for v in list_of_things:
    print(v)

# defining a function with def
def square(x):
    return x * x
print(3, square(3))
print(4, square(4))

# defining a class, i.e. a blueprint for objects bundling attributes and functions (methods)
class Object():
    def __init__(self, a, b):
        # the initialization method, called when a new object is created
        self.a = a
        self.b = b

    def describe(self):
        print("a: ", self.a)
        print("b : ", self.b)

obj1 = Object(1, 2)
obj2 = Object("hello", "goodbye")
# call methods
obj1.describe()
obj2.describe()
# get attributes
a = obj1.a
b = obj2.b

Testing python stuff : 
4
3
number is three here!
current i value :  0
current i value :  1
current i value :  2
current i value :  3
3
4
5
2
3 9
4 16
a:  1
b :  2
a:  hello
b :  goodbye


## What is a machine learning model in Python?

There are many machine learning libraries for Python: [PyTorch](https://pytorch.org/), [TensorFlow](https://www.tensorflow.org/?hl=fr), [JAX](https://github.com/jax-ml/jax)... Each has its own logic, but they all share a similar view of what an ML model is:

- a set of **weights**, that are typically trained during a training process
- a **computing graph**, that describes how the weights are used to process the inputs.

`torchbend` is a library for analysing both the weights and the computing graph of a machine learning model, and for hacking both in order to bend existing models in creative ways. Let's illustrate this logic in three steps:
1) a dumb machine learning model
2) an additive synthesizer
3) a pre-trained RAVE model
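To give a feel for what "bending" means before using `torchbend`, here is a minimal sketch in plain PyTorch (this is not `torchbend`'s API): a weight is bent by editing a parameter in place, and an activation is bent with a forward hook that returns a modified output. The toy model, the scaling factor and the `invert` hook are arbitrary, hypothetical choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# hypothetical toy model: linear -> tanh -> linear
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 2))
x = torch.randn(3, 4)

with torch.no_grad():
    reference = model(x)

    # 1) weight bending: permanently rescale the first layer's weight matrix
    model[0].weight.mul_(2.0)
    weight_bent = model(x)

# 2) activation bending: if a forward hook returns a value,
#    that value replaces the module's output for this forward pass
def invert(module, inputs, output):
    return -output

handle = model[1].register_forward_hook(invert)
with torch.no_grad():
    activation_bent = model(x)
handle.remove()  # removing the hook restores the (weight-bent) behaviour
```

The two kinds of bending differ in scope: the weight edit persists for every future input, while the hook only acts while it is registered.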

Ready? Let's go!

### A simple and useless machine learning model

In [2]:
import torch 

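# a machine learning model in torch is generally implemented as a subclass of torch.nn.Module, as below
# this class is a dumb module chaining two linear transformations,
# each of the form out = A * x + b (A: weight matrix, b: bias).
# a non-linearity (a simple tanh function) is also declared as a submodule,
# but note that forward() below does not call it.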

class Foo(torch.nn.Module):
    # initialization method
    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        # we initialize here two different linear modules, that are two linear transformations of the input (out = A * x + b):
        self.linear_1 = torch.nn.Linear(in_dim, hidden_dim)
        self.linear_2 = torch.nn.Linear(hidden_dim, out_dim)
        # we also init a non-linearity module, called nnlin
        self.nnlin = torch.nn.Tanh()

    # definition of how the data is processed
    def forward(self, x: torch.Tensor):
        out = self.linear_1(x)
        print("first layer shape : ", out.shape)
        out = self.linear_2(out)
        print("second layer shape : ", out.shape)
        return out


in_dim = 4
hidden_dim = 80
out_dim = 16

# let's create a model
module = Foo(in_dim, hidden_dim, out_dim)

# let's print this model
print(module)

Foo(
  (linear_1): Linear(in_features=4, out_features=80, bias=True)
  (linear_2): Linear(in_features=80, out_features=16, bias=True)
  (nnlin): Tanh()
)


We can see that this module has three different submodules:
- *linear_1*, the first linear transformation (expanding the input dimension 4 to 80),
- *linear_2*, the second linear transformation (projecting the hidden dimension 80 down to the output dimension 16),
- *nnlin*, a simple object representing the `Tanh` function (note that `forward` above never calls it).
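The `A * x + b` notation in the comments is shorthand: `torch.nn.Linear` computes `x @ W.T + b`, with its weight `W` stored with shape `(out_features, in_features)`. This is why `linear_1.weight` is reported with shape `[80, 4]` further down. A quick sanity check on a freshly initialized layer (not the `module` above):

```python
import torch

torch.manual_seed(0)
linear = torch.nn.Linear(4, 80)   # same shape as linear_1 above
x = torch.randn(2, 4)             # a batch of 2 examples

# the weight is stored as (out_features, in_features), hence the transpose
manual = x @ linear.weight.T + linear.bias
print(linear.weight.shape, linear.bias.shape)    # torch.Size([80, 4]) torch.Size([80])
print(torch.allclose(linear(x), manual, atol=1e-6))  # True

# the Tanh module simply applies torch.tanh element-wise
tanh = torch.nn.Tanh()
```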

Let's use this module to process some inputs:

In [3]:
# the batch size is the number of examples processed at once
n_batch = 4

# we create n_batch random input examples of `in_dim` dimensions:
x = torch.randn(n_batch, in_dim)
print('-- input : ')
print(x)

# process input
print("-- processing : ")
out = module(x)

# the output is a set of different examples of `out_dim` dimensions
print('-- output : ')
print(out)

-- input : 
tensor([[-1.0926, -0.3753,  0.9194, -1.4237],
        [ 0.8316,  0.4937, -0.4718, -1.3776],
        [-0.9213, -0.3054,  0.1044,  1.6254],
        [ 0.5309, -0.5009,  0.4503,  2.2799]])
-- processing : 
first layer shape :  torch.Size([4, 80])
second layer shape :  torch.Size([4, 16])
-- output : 
tensor([[ 0.9726, -0.8537,  0.7230, -0.1153, -0.2366, -0.0328, -0.0422,  0.0116,
          0.1455, -0.2930, -0.0625,  0.0289, -0.2523, -0.0250,  0.3223,  0.2894],
        [ 0.1882, -0.6582,  0.2066, -0.3789, -0.1684, -0.1852, -0.4643,  0.6144,
          0.3161,  0.1304,  0.2667, -0.4079,  0.1778, -0.4463,  0.2174,  0.3992],
        [-0.3142,  0.4904, -0.2038,  0.5743, -0.2762,  0.0994,  0.4435, -0.0547,
          0.0370,  0.5079,  0.3332,  0.2003, -0.4797,  0.7250,  0.0857, -0.1943],
        [-1.1852,  0.8448, -0.1709,  0.5016, -0.1405, -0.2295,  0.1248,  0.2963,
         -0.0640,  0.6052,  0.3877, -0.1702, -0.1692,  0.6749, -0.2559, -0.1670]],
       grad_fn=<AddmmBackward0>)


OK! Now let's use `torchbend` to see what the *parameters* and the *graph* of this simple module are.

In [4]:
import torchbend as tb

# wrapping an existing module inside a BendedModule object lets us analyse it
bended_module = tb.BendedModule(module)
# "trace" is needed to analyse the computing graph of our module. 
bended_module.trace(x=x)

# print weights
print('Weights : ')
bended_module.print_weights()

print('\nGraphs : ')
print(bended_module.graph().print_tabular())

print('\nActivations : ')
# print activations
bended_module.print_activations();

first layer shape :  ShapeAttribute(root=BendingProxy(linear_1), value=shape)
second layer shape :  ShapeAttribute(root=BendingProxy(linear_2), value=shape)
Weights : 
name             shape                 dtype                 min       max         mean     stddev
---------------  --------------------  -------------  ----------  --------  -----------  ---------
linear_1.weight  torch.Size([80, 4])   torch.float32  -0.49725    0.49778   -0.00486112  0.295011
linear_1.bias    torch.Size([80])      torch.float32  -0.491403   0.492119  -0.021417    0.287867
linear_2.weight  torch.Size([16, 80])  torch.float32  -0.111686   0.111333  -0.00196044  0.0639649
linear_2.bias    torch.Size([16])      torch.float32  -0.0992248  0.108824  -0.00208028  0.0726927

Graphs : 
opcode       name      target    args         kwargs
-----------  --------  --------  -----------  --------
placeholder  x         x         ()           {}
call_module  linear_1  linear_1  (x,)         {}
call_module  linear_2  

In the output above, we can see:
- the **weights** of the module, which here are the parameters of the two linear transformations (A * x + b, where A is the weight matrix and b the bias)
- the **graph** of the module, i.e. all the operations applied from the input `x` to the output
- the **activations**, i.e. the *intermediate* values computed from the input `x`.

We could describe the difference between weights and activations like this:

**[TODO: add a diagram here]**

Typically, weights are modified during the model's training process, but do not change when the model is used. The graph describes which operations are applied to the model's inputs, and *activations* are all the intermediate values computed along the model's computing graph; they therefore differ from one input to another.
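This distinction can be checked with plain PyTorch forward hooks, independently of `torchbend`. The sketch below uses a two-layer model similar to `Foo` (here the `Tanh` is actually placed between the two linear layers) and a hypothetical `save_activation` hook that records the `Tanh` output for two different inputs: the weights are identical before and after the forward passes, while the recorded activations differ.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# a model similar to Foo, with tanh applied between the two linear layers
model = nn.Sequential(nn.Linear(4, 80), nn.Tanh(), nn.Linear(80, 16))

recorded = []
def save_activation(module, inputs, output):
    # store a detached copy of the intermediate value
    recorded.append(output.detach().clone())

handle = model[1].register_forward_hook(save_activation)

weights_before = model[0].weight.detach().clone()
with torch.no_grad():
    model(torch.randn(1, 4))   # first input
    model(torch.randn(1, 4))   # second input
handle.remove()

# weights are untouched by inference; activations depend on the input
print(torch.equal(model[0].weight, weights_before))  # True
print(torch.allclose(recorded[0], recorded[1]))       # False
print(recorded[0].shape)                              # torch.Size([1, 80])
```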



### An additive synthesizer in PyTorch

To make this clearer, let's take an example that should feel a little less abstract: an additive synthesizer. An additive synthesizer can be described in the same way as a machine learning module, and it will make the distinction between *weights* and *activations* easier to see.

Let's define our additive synthesizer :

In [5]:
from IPython.display import Audio
import torch, torch.nn as nn
import sys; sys.path.append('..')
import torchbend as tb

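class Joseph(nn.Module):
    def __init__(self, f0, n_partials, fs=44100):
        super().__init__()
        # fundamental frequency (Hz), shape (1, 1, 1)
        self.f0 = nn.Parameter(torch.full((1, 1, 1), f0), requires_grad=False)
        # harmonic multipliers 1, 2, ..., n_partials, shape (n_partials, 1)
        self.f_mult = nn.Parameter(torch.arange(1, n_partials+1).unsqueeze(-1), requires_grad=False)
        # unnormalized amplitude of each partial, shape (1, n_partials, 1)
        self.amps = nn.Parameter(torch.ones(1, n_partials, 1), requires_grad=False)
        # note: fs is not used by the module itself

    def forward(self, t):
        # t: time of each sample in seconds, (..., n_samples) -> (..., 1, n_samples)
        t = t.unsqueeze(-2)
        # frequency of each partial, shape (1, n_partials, 1)
        freqs = self.f0 * self.f_mult
        # one sine wave per partial, shape (..., n_partials, n_samples)
        waves = torch.sin(2 * torch.pi * freqs * t)
        # weight each partial by its softmax-normalized amplitude
        waves = waves * torch.nn.functional.softmax(self.amps, dim=-2)
        # sum the partials into a single waveform, shape (..., n_samples)
        out = waves.sum(-2)
        return out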

T = 2.0
fs = 44100
module = Joseph(110, 4, fs)
t = torch.linspace(0., T, int(T*fs))

wave = module(t[None])
Audio(wave.numpy(), rate=fs)

Here, the module `Joseph` (named after Joseph Fourier, of course) takes an input `t` giving the time of each sample (in seconds), and generates a waveform made of the first `n_partials` harmonics of `f0`. What are the weights here, and what are the activations? Take a moment to think, then execute the cell below to get the answer.

In [6]:
import torchbend as tb

# wrapping an existing module inside a BendedModule object lets us analyse it
bended_module = tb.BendedModule(module)
# "trace" is needed to analyse the computing graph of our module. 
bended_module.trace(t=t)

# print weights
print('Weights : ')
bended_module.print_weights()

print('\nGraphs : ')
print(bended_module.graph().print_tabular())

print('\nActivations : ')
# print activations
bended_module.print_activations();

Weights : 
name    shape                  dtype            min    max    mean     stddev
------  ---------------------  -------------  -----  -----  ------  ---------
f0      torch.Size([1, 1, 1])  torch.int64      110    110   110    nan
f_mult  torch.Size([4, 1])     torch.int64        1      4     2.5    1.29099
amps    torch.Size([1, 4, 1])  torch.float32      1      1     1      0

Graphs : 
opcode         name       target                                               args                      kwargs
-------------  ---------  ---------------------------------------------------  ------------------------  --------------------------------------------
placeholder    t          t                                                    ()                        {}
call_method    unsqueeze  unsqueeze                                            (t, -2)                   {}
get_attr       f0         f0                                                   ()                        {}
get_attr      



OK, this is a little more complicated, but with some attention everything is quite easy to understand.

**Weights.** We can see that `Joseph` has three different parameters:
- `f0`: the fundamental frequency of the module
- `f_mult`: the frequency multiplier of each of the wave's partials
- `amps`: the amplitude weight of each partial (normalized with a softmax in `forward`).

These values are the **weights** of the model: they define the module's behavior and do not change from one input to another. They could typically be trained to reproduce a given sound, in a similar way to [DDSP](https://github.com/magenta/ddsp).
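Put together, the weights enter `forward` as follows, with $N$ = `n_partials`, $k$ running over the entries of `f_mult`, and $a_k$ the entries of `amps`:

$$y(t) = \sum_{k=1}^{N} \frac{e^{a_k}}{\sum_{j=1}^{N} e^{a_j}} \sin(2\pi k f_0 t)$$

With `amps` initialised to ones, every partial gets the amplitude $1/N$ ($1/4$ here). The sketch below uses a hypothetical helper, `additive`, that mirrors `Joseph.forward` for illustration only; it shows that editing the `amps` weight reshapes the spectrum while the computation itself stays the same.

```python
import torch

def additive(t, f0, f_mult, amps):
    """Hypothetical helper mirroring Joseph.forward (illustration only)."""
    t = t.unsqueeze(-2)                          # (1, n_samples)
    freqs = f0 * f_mult                          # (1, n_partials, 1)
    waves = torch.sin(2 * torch.pi * freqs * t)  # one sine per partial
    weights = torch.softmax(amps, dim=-2)        # normalized amplitudes
    return (waves * weights).sum(-2), weights

fs = 44100
t = torch.linspace(0., 2.0, int(2.0 * fs))
f0 = torch.full((1, 1, 1), 110.)
f_mult = torch.arange(1, 5).unsqueeze(-1)        # partials 1..4, shape (4, 1)

# default weights: amps = 1 everywhere -> 1/4 per partial
wave, w_flat = additive(t, f0, f_mult, torch.ones(1, 4, 1))
print(w_flat.flatten())  # tensor([0.2500, 0.2500, 0.2500, 0.2500])

# "bent" weights: a large first entry makes the fundamental dominate
wave_bent, w_bent = additive(t, f0, f_mult, torch.tensor([[[10.], [0.], [0.], [0.]]]))
print(w_bent.flatten()[0].item() > 0.99)  # True
```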

**Graph.** This part is a little more difficult to read, but by carefully reading every line of the output you should be able to locate the corresponding operation in the code.
The `opcode` column describes the type of operation:
- *placeholder* is an input of the graph
- *get_attr* means retrieving the parameter `target` from the module, as here with `f0` and `f_mult`
- *call_function* means calling the function `target` with the given arguments (in the `args` and `kwargs` columns)
- *call_method* means calling the method `target` on a given object (the first element of `args`)
- *call_module* (seen in the first example) means calling the submodule `target`
- *output* represents the output of the computing graph
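The table printed above has the same columns as `torch.fx`'s `Graph.print_tabular`, which suggests that `torchbend` builds on `torch.fx` (an assumption; this tutorial does not state it). A minimal sketch with `torch.fx.symbolic_trace` on a hypothetical toy module, independent of `torchbend`, shows which line of code produces which opcode:

```python
import torch
import torch.nn as nn
from torch.fx import symbolic_trace

class Tiny(nn.Module):
    """Hypothetical toy module producing the opcodes listed above."""
    def __init__(self):
        super().__init__()
        self.gain = nn.Parameter(torch.tensor(2.0), requires_grad=False)

    def forward(self, t):                  # placeholder: t
        t = t.unsqueeze(-2)                # call_method: unsqueeze on t
        return torch.sin(self.gain * t)    # get_attr (gain), call_function (mul, sin)

graph_module = symbolic_trace(Tiny())
for node in graph_module.graph.nodes:
    print(f"{node.op:<14} {node.name:<10} {node.target}")
```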

**Activations.** These are all the intermediate values of the processing graph, i.e. the outputs of every operation described by the graph. You can see the shape of every activation in the table; these shapes depend on the shape of the input given at the `.trace(t=t)` step. Analyzing graphs and activations can be a little tedious, but here it is quite simple, as the following examples show:

- the activation `mul` is the multiplication of `f0` and `f_mult`, and corresponds to the line `freqs = self.f0 * self.f_mult`. `mul` is therefore the frequency of each partial.
- the activation `sin` is the application of the sine function to the phase vector, here `mul_2`.
- the activation `mul_3` is the multiplication of `softmax` (the normalized amplitude of each partial) and `sin`, and therefore represents the weighted partials.
- the activation `sum_1` is the sum of all the partials, hence the final waveform.
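One way to record every activation listed above without `torchbend` is to trace the module with `torch.fx` and run the traced graph with a `torch.fx.Interpreter` subclass that stores the output of each node. This is only a sketch of the idea: `ActivationRecorder` is a hypothetical name, and `torchbend` may collect activations differently. Node names such as `mul_2` and `sum_1` are generated by `torch.fx`.

```python
import torch
import torch.nn as nn
from torch.fx import Interpreter, symbolic_trace

class Joseph(nn.Module):
    # same additive synthesizer as in the cell above
    def __init__(self, f0, n_partials, fs=44100):
        super().__init__()
        self.f0 = nn.Parameter(torch.full((1, 1, 1), f0), requires_grad=False)
        self.f_mult = nn.Parameter(torch.arange(1, n_partials+1).unsqueeze(-1), requires_grad=False)
        self.amps = nn.Parameter(torch.ones(1, n_partials, 1), requires_grad=False)

    def forward(self, t):
        t = t.unsqueeze(-2)
        freqs = self.f0 * self.f_mult
        waves = torch.sin(2 * torch.pi * freqs * t)
        waves = waves * torch.nn.functional.softmax(self.amps, dim=-2)
        out = waves.sum(-2)
        return out

class ActivationRecorder(Interpreter):
    """Hypothetical helper: runs a traced graph and keeps every node's output."""
    def __init__(self, graph_module):
        super().__init__(graph_module)
        self.activations = {}

    def run_node(self, node):
        result = super().run_node(node)
        self.activations[node.name] = result   # keyed by the fx node name
        return result

fs = 44100
t = torch.linspace(0., 2.0, int(2.0 * fs))
module = Joseph(110, 4, fs)
recorder = ActivationRecorder(symbolic_trace(module))
out = recorder.run(t)

for name in ["t", "mul_2", "sin", "mul_3", "sum_1"]:
    print(name, tuple(recorder.activations[name].shape))
```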

Let's plot all the activations that correspond to time series. You can identify them by their shapes: their last dimension is 88200 samples (2 seconds at 44100 Hz). You can retrieve them easily with the `get_activations` method of `torchbend`:

In [8]:
from dandb import plot_1d_activation

activation_names = ["t", "mul_2", "sin", "mul_3", "sum_1"]
# returns a dictionary {activation name: activation value} for the given input
activations = bended_module.get_activations(*activation_names, t=t)

for activation_name, activation_value in activations.items():
    # plot given activation
    plot = plot_1d_activation(activation_name, activation_value)
    # show plot
    plot.show()

Ok