# Quickstart for audio-heads

Welcome to `torchbend`, a high-level framework for dissecting, analyzing and bending machine learning models programmed with [PyTorch](https://pytorch.org/docs/stable/index.html). This framework extends `torch.fx` and provides convenient methods to target specific activations of a network, bend its parameters or internal values, and easily perform some [active divergence](https://arxiv.org/pdf/2107.05599) techniques to unlock the co-creative potential of neural networks.

In this tutorial, we will take a short tour of `torchbend` and see how it can be applied to existing audio modules, using three different models :
- a naive additive synthesis implementation to get familiar with `torchbend` concepts
- [audiocraft](https://github.com/facebookresearch/audiocraft)
- [RAVE](https://github.com/acids-ircam/RAVE), and how to perform real-time bending with [nn~](https://github.com/acids-ircam/nn_tilde). 

See the accompanying notebooks for [image](0_getting_started_image.ipynb) and [text](0_getting_started_text.ipynb).

## How does `torchbend` work?

In PyTorch, a machine learning model can be described by two structures : 
- a set of parameters, learned during training
- a computational `graph`, which represents the set of operations applied to the inputs using these parameters. 

### Weights and activations

To see the difference, let's make a naive neural additive synthesizer and see how to bend it with `torchbend` :

In [1]:
from IPython.display import Audio
import torch, torch.nn as nn
import sys; sys.path.append('..')
import torchbend as tb

class Joseph(nn.Module):
    def __init__(self, f0, n_partials, fs=44100):
        super().__init__()
        self.f0 = nn.Parameter(torch.full((1, 1, 1), f0), requires_grad=False)
        self.f_mult = nn.Parameter(torch.arange(1, n_partials+1).unsqueeze(-1), requires_grad=False)
        self.amps = nn.Parameter(torch.ones(1, n_partials, 1), requires_grad=False)

    def forward(self, t):
        t = t.unsqueeze(-2)
        freqs = self.f0 * self.f_mult
        waves =  torch.sin(2 * torch.pi * freqs * t) 
        waves = waves * torch.nn.functional.softmax(self.amps, dim=-2)
        out = waves.sum(-2)
        return out

T = 2.0
fs = 44100
module = Joseph(110, 4, fs)
t = torch.linspace(0., T, int(T*fs))

wave = module(t[None])
Audio(wave.numpy(), rate=fs)

What are the weights of the module? Let us use `torchbend` to print these weights : 

In [2]:
bended = tb.BendedModule(module)
bended.print_weights();

name    shape                  dtype            min    max    mean     stddev
------  ---------------------  -------------  -----  -----  ------  ---------
f0      torch.Size([1, 1, 1])  torch.int64      110    110   110    nan
f_mult  torch.Size([4, 1])     torch.int64        1      4     2.5    1.29099
amps    torch.Size([1, 4, 1])  torch.float32      1      1     1      0



We can see that `Joseph` has three different parameters : 
- `f0` : the fundamental frequency of the module
- `f_mult` : the frequency multiplier of each of the wave's partials
- `amps` : the weight of each partial.

Let us now print the computing graph of the module. This first requires tracing (`trace`) the forward function of the object with a given input : 

In [3]:
bended.trace(t=t)
print('Graph : ')
bended.graph().print_tabular()

print('')
print('Activations : ')
bended.print_activations();

Graph : 
opcode         name       target                                               args                      kwargs
-------------  ---------  ---------------------------------------------------  ------------------------  --------------------------------------------
placeholder    t          t                                                    ()                        {}
call_method    unsqueeze  unsqueeze                                            (t, -2)                   {}
get_attr       f0         f0                                                   ()                        {}
get_attr       f_mult     f_mult                                               ()                        {}
call_function  mul        <built-in function mul>                              (f0, f_mult)              {}
call_function  mul_1      <built-in function mul>                              (6.283185307179586, mul)  {}
call_function  mul_2      <built-in function mul>                              (m

By printing the graph, we can see all the operations performed by the module on the input `t`. The `print_activations` method of `tb.BendedModule` summarizes all the intermediate values of the computing graph, called `activations`, with their corresponding shapes. Analyzing graphs and activations can be a little tedious, but here it is quite simple, as the following examples show : 

- the activation `mul` is the product of `f0` and `f_mult`, and corresponds to the line `freqs = self.f0 * self.f_mult`. `mul` is thus the frequency of each partial.
- the activation `sin` is the sine function applied to the phase vector, which is here `mul_2`.
- the activation `mul_3` is the product of `softmax` (the normalized amplitude of each partial) and `sin`, and thus represents the weighted partials.
- the activation `sum_1` is the sum of all the partials, hence the final waveform. 
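
To make this chain concrete, the sketch below recomputes each activation by hand for a single time sample, using only the Python standard library. It does not use `torchbend` or `torch` : the values (`f0 = 110`, four partials, unit amplitudes) are those of the `Joseph` module above, the time `t = 0.001` is an arbitrary choice, and the assignment of `mul_1` and `mul_2` follows the order in which Python evaluates `2 * torch.pi * freqs * t`.

```python
import math

# Values taken from the tutorial: Joseph(110, 4) with all amplitudes set to 1.
f0 = 110
f_mult = [1, 2, 3, 4]
amps = [1.0, 1.0, 1.0, 1.0]
t = 0.001  # one time sample, in seconds (arbitrary choice for this sketch)

# `mul`: frequency of each partial (f0 * f_mult)
mul = [f0 * m for m in f_mult]
# `mul_1`: angular frequency of each partial (2 * pi * mul)
mul_1 = [2 * math.pi * f for f in mul]
# `mul_2`: phase of each partial at time t
mul_2 = [w * t for w in mul_1]
# `sin`: one sine value per partial
sin = [math.sin(p) for p in mul_2]
# `softmax`: amplitudes normalized so that they sum to 1
exp_amps = [math.exp(a) for a in amps]
softmax = [e / sum(exp_amps) for e in exp_amps]
# `mul_3`: weighted partials
mul_3 = [s * w for s, w in zip(sin, softmax)]
# `sum_1`: output sample
sum_1 = sum(mul_3)

print(mul)  # [110, 220, 330, 440]
print(sum_1)
```

Since all amplitudes are equal, `softmax` gives each partial a weight of about 0.25, so the output sample is simply the average of the four sine values.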

`torchbend` can bend both weights (the trained parameters of the module) and activations (intermediate values of a computing graph) in an almost infinite number of ways, and can even export the bended module with `torch.jit` to embed the model into C++ code. Let's see how to bend this tiny generator :-) 

### Bending weights

Bending a weight means applying transformations to the original weight of a module. These transformations all derive from the `tb.BendingCallback` object, and are listed in the documentation. For example, let's alternately multiply the `f0` parameter by 4 with the `tb.Scale` callback, and the `f_mult` parameter by 1.3 : 

In [4]:
T = 2.0
fs = 44100
module = Joseph(110, 4, fs)
bended = tb.BendedModule(module)

# Original wave
t = torch.linspace(0., T, int(T*fs))
wave = bended(t[None])

# multiply f0 by 4 with tb.Scale
bended.bend(tb.Scale(4.0), 'f0')
wave_bended = bended(t[None])

# reset the bending
bended.reset()

# multiply f_mult by 1.3 with tb.Scale
bended.bend(tb.Scale(1.3), 'f_mult')
wave_bended_2 = bended(t[None])

wave_out = torch.cat([wave, torch.zeros(1, fs), wave_bended, torch.zeros(1, fs), wave_bended_2], -1)

Audio(wave_out.numpy(), rate=fs)
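
To put numbers on what you hear, the two scale factors can be translated into musical intervals. This is a small standard-library sketch, independent of `torchbend` ; the helper `semitones` is introduced here for illustration only, and the computation assumes that the scaling is carried out in floating point.

```python
import math

def semitones(ratio):
    """Interval, in equal-tempered semitones, corresponding to a frequency ratio."""
    return 12 * math.log2(ratio)

f0 = 110
f_mult = [1, 2, 3, 4]

# First bend: f0 scaled by 4
print(semitones(4.0))                           # 24.0
print([4.0 * f0 * m for m in f_mult])           # [440.0, 880.0, 1320.0, 1760.0]

# Second bend: f_mult scaled by 1.3
print(round(semitones(1.3), 2))                 # 4.54
print([round(f0 * 1.3 * m, 1) for m in f_mult]) # [143.0, 286.0, 429.0, 572.0]
```

Note that scaling every entry of `f_mult` by the same factor keeps the partials harmonic : the result is a transposition of about 4.5 semitones, with a new fundamental of 143 Hz.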

Let's make a few remarks : 
- In the first bending, `tb.Scale(4.)` multiplies the parameter `f0` by 4 before the processing step, resulting in a two-octave upward transposition of the original wave. 
- The `bend` method does not erase the original module, so every bending operation can be reverted with the `reset` method.
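
One way to picture this non-destructive behaviour is the toy class below. It is a hypothetical stand-in written for this tutorial, not `torchbend`'s actual implementation : the class `ToyBendable` and its `value` method are invented for illustration. The idea is simply that the original values are never overwritten, and that callbacks are applied each time a value is read.

```python
class ToyBendable:
    """Hypothetical stand-in for a bendable module; NOT torchbend's implementation."""

    def __init__(self, params):
        self._original = dict(params)  # original values, never modified
        self._callbacks = {}           # parameter name -> list of callbacks

    def bend(self, callback, name):
        # Register a callback instead of overwriting the stored value
        self._callbacks.setdefault(name, []).append(callback)

    def reset(self):
        # Forget every callback: the original values are intact
        self._callbacks.clear()

    def value(self, name):
        # The bent value is recomputed from the original each time it is read
        out = self._original[name]
        for callback in self._callbacks.get(name, []):
            out = callback(out)
        return out


toy = ToyBendable({"f0": 110.0})
toy.bend(lambda v: v * 4.0, "f0")
print(toy.value("f0"))  # 440.0
toy.reset()
print(toy.value("f0"))  # 110.0
```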

Another important thing with `torchbend` is that, when using `bend`, every target is actually a *regular expression*, meaning that a single key may target several parameters. For example, we can modify `f0` and `f_mult` at the same time by bending the `f.*` target, which basically means "anything starting with `f`" (for more explanation, see Python's `re` package).

In [5]:
bended.reset()
bended.bend(tb.Scale(2.), 'f.*')

# you can get all the bended entries of a BendedModule by calling the bended_keys method :
print(bended.bended_keys())
wave_bended = bended(t[None])

# both f0 and f_mult are multiplied by 2
Audio(wave_bended.numpy(), rate=fs)
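
To get a feel for how such patterns select names, here is a small experiment with Python's `re` module on the weight and activation names of `Joseph`. It assumes that a pattern is matched from the first character of the name, as `re.match` does ; the matching function `torchbend` uses internally is not shown in this tutorial, and the helper `select` is hypothetical.

```python
import re

weights = ["f0", "f_mult", "amps"]
activations = ["unsqueeze", "mul", "mul_1", "mul_2", "sin", "softmax", "mul_3", "sum_1"]

def select(pattern, names):
    # Assumption: the pattern must match from the start of the name (re.match)
    return [name for name in names if re.match(pattern, name)]

print(select(r"f.*", weights))       # ['f0', 'f_mult']
print(select(r"mul", activations))   # ['mul', 'mul_1', 'mul_2', 'mul_3']
print(select(r"mul$", activations))  # ['mul']
```

Under this assumption, a pattern without a trailing `$` also selects every name that merely starts with it, which is worth keeping in mind when a short name like `mul` is a prefix of several others.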

### Bending activations

Bending activations is very similar to bending weights, but may require a little more vigilance, as transforming the computing graph may lead to errors. Before bending, `torchbend` also lets you easily access these internal activations with the `get_activations` method : 

In [6]:
from plotly.subplots import make_subplots
import plotly.graph_objects as go

T = 2.0
fs = 44100
t = torch.linspace(0., T, int(T*fs))

module = Joseph(110, 4, fs)
bended = tb.BendedModule(module)
bended.trace(t=t)

# the trailing $ in "mul$" anchors the pattern at the end of the name, so that only "mul" is selected
out = bended.get_activations(r"mul$", r"sin", r"mul_3", r"sum_1", t=t)

print("output keys : ", out.keys())
print('mul activation : ', out['mul'].squeeze())

# Plotting sin activation
fig = make_subplots(rows=2, cols=2)
for i in range(4): 
    y = fig.add_trace(go.Scatter(y=out['sin'][0, i, :int(0.1 * fs)].numpy()), row=int(i%2)+1, col=int(i//2)+1)
fig.show();

output keys :  dict_keys(['mul', 'sin', 'mul_3', 'sum_1'])
mul activation :  tensor([110, 220, 330, 440])


We can see that the `sin` activation holds, along its second dimension, the different partials of the wave. Let's apply a weird partial-wise transformation using the `tb.Lambda` callback, which allows us to apply an arbitrary function to a specific activation : 

In [7]:
def waveshape(x):
    # square wave
    x[..., 0, :] = (x[..., 0, :] > 0).float()
    # differentiate
    x[..., 1, :] = (x[..., 1, :] - x[..., 1, :].roll(-1, -1)) / 2 * 100
    # add noise
    x[..., 2, :] = x[..., 2, :] + 0.2 * torch.randn_like(x[..., 2, :])
    # ???
    x[..., 3, :] = x[..., 2, :].pow(8) % 0.4
    return x

bended.bend(tb.Lambda(waveshape), "sin")
out = bended.get_activations(r"mul$", r"sin", r"mul_3", r"sum_1", t=t)
# Plotting sin activation
fig = make_subplots(rows=2, cols=2)
for i in range(4): 
    y = fig.add_trace(go.Scatter(y=out['sin'][0, i, :int(0.1 * fs)].numpy()), row=int(i%2)+1, col=int(i//2)+1)
fig.show();

Activation bending is a little more sensitive, though. Imagine that you would like to add a fifth partial to the wave, and do it in a way that changes the shape of one of the internal activations : 

In [15]:
def wrong_add_partial(x):
    new_partial = torch.sqrt(1 - x[..., [0], :].pow(2)) # |cos| from sin (the sign is lost)
    x = torch.cat([x, new_partial], -2)
    return x

bended.reset()
bended.bend(tb.Lambda(wrong_add_partial), "sin")
try: 
    out = bended.get_activations(r"mul$", r"sin",t=t)
except Exception as e:
    print("Got error : ", e)


Got error :  The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1


Traceback (most recent call last):
  File "/Users/domkirke/miniconda3/envs/ml2/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/domkirke/miniconda3/envs/ml2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/domkirke/miniconda3/envs/ml2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.8", line 15, in forward
    mul_3 = sin_bended * softmax;  softmax = None
            ~~~~~~~~~~~^~~~~~~~~
RuntimeError: The size of tensor a (5) must match the size of tensor b (4) at non-singleton dimension 1

Call using an FX-traced Module, line 

This error occurs because the callback `wrong_add_partial` adds a fifth entry along the partial dimension of the `sin` activation, changing its shape from `torch.Size([1, 4, 88200])` to `torch.Size([1, 5, 88200])`. However, the subsequent operation multiplies it with another tensor, `softmax`, of shape `torch.Size([1, 4, 1])`, leading to a dimension mismatch. Hence, activation bending requires great care about the modifications it applies to the computational graph, and depends heavily on the bended model.
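
The rule behind this error is broadcasting : two shapes are compatible when, comparing dimensions from the last one backwards, the sizes are equal or one of them is 1. The pure-Python sketch below reimplements this check for illustration (the helper `broadcast_shape` is written for this tutorial and is not a `torch` or `torchbend` function) and applies it to the two shapes discussed above.

```python
def broadcast_shape(a, b):
    """Return the broadcast shape of shapes a and b, or raise ValueError."""
    ndim = max(len(a), len(b))
    result = []
    # Compare dimensions from the last one backwards; missing dimensions count as 1
    for i in range(1, ndim + 1):
        size_a = a[-i] if i <= len(a) else 1
        size_b = b[-i] if i <= len(b) else 1
        if size_a != size_b and size_a != 1 and size_b != 1:
            raise ValueError(
                f"size {size_a} does not match size {size_b} at dimension {ndim - i}"
            )
        result.append(max(size_a, size_b))
    return tuple(reversed(result))


# Original `sin` activation times `softmax`: compatible
print(broadcast_shape((1, 4, 88200), (1, 4, 1)))  # (1, 4, 88200)

# Bended `sin` activation (5 partials) times `softmax` (4 partials): incompatible
try:
    broadcast_shape((1, 5, 88200), (1, 4, 1))
except ValueError as e:
    print("Got error :", e)  # Got error : size 5 does not match size 4 at dimension 1
```

A callback that keeps four partials, for instance by overwriting an existing partial instead of concatenating a new one, would leave the shape unchanged and avoid the mismatch.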

### Summary

Let's summarize what we learnt so far :

- machine learning modules are composed of two parts : a set of trained `parameters`, and a computational `graph` with a set of intermediary `activations`
- `torchbend` can bend two different types of values : `parameters` and `activations` 
- any activation can be retrieved with the `get_activations` method
- any key can be bent in a non-destructive way with the `bend` method

If you want to get more hands-on with network bending, do not hesitate to read the subsequent notebooks ; but first, let's apply all of that to actual neural audio generators, for which `torchbend` provides dedicated bending interfaces.

## Bending MusicGen and AudioGen

MusicGen and AudioGen are two open-source text-to-audio models included in Meta's [AudioCraft](https://github.com/facebookresearch/audiocraft) project. While `torchbend`'s main object is `BendedModule`, the package also offers higher-level objects, called *interfaces*, that address specific packages and generators from the open-source community (all derived from an object called `BendingInterface`). Let's see how to use them here!

### Bending MusicGen 
The `BendedMusicGen` and `BendedAudioGen` interfaces are used like the original `MusicGen` and `AudioGen` objects, except that they also expose the bending routines of the `BendedModule` object.

In [None]:
from IPython.display import Audio
import sys; sys.path.append('..')
import torchbend as tb
from torchbend.interfaces.audiocraft import BendedMusicGen

bended = BendedMusicGen('facebook/musicgen-small', cache_dir = "/tmp")
prompt = ["elevator music from the 50s"]
bended.set_generation_params(duration=5) 
out = bended.generate(prompt)
Audio(out[0].numpy(), rate=bended.sample_rate)





Now, let's try to bend some weights. Let's print the weights of the `BendedMusicGen` object : 

In [5]:
bended.print_weights();

name                                                               shape                        dtype                    min           max          mean        stddev
-----------------------------------------------------------------  ---------------------------  -------------  -------------  ------------  ------------  ------------
compression_model.encoder.model.0.conv.conv.bias                   torch.Size([64])             torch.float32   -0.250833      0.179229     -0.00761436     0.0905939
compression_model.encoder.model.0.conv.conv.weight_g               torch.Size([64, 1, 1])       torch.float32    0.293595      1.27896       0.786986       0.161317
compression_model.encoder.model.0.conv.conv.weight_v               torch.Size([64, 1, 7])       torch.float32   -0.810647      0.606752     -0.00172898     0.2686
compression_model.encoder.model.1.block.1.conv.conv.bias           torch.Size([32])             torch.float32   -0.35318       0.0923138    -0.0886009      0.0898439
compre

OK, that's way messier than our simple additive synthesizer... But don't give up : the module is actually not much more complicated than most neural generators, and the weights are named in a hierarchical manner. The name `compression_model.decoder.model.1.lstm.weight_ih_l0` says that it points to the compression model, then the decoder, then the element at index 1 of the `model` attribute, then the LSTM, then `weight_ih_l0` (meaning : the input-to-hidden weights of layer 0, the first LSTM layer). Of course, if you don't know the architecture of the model, this is totally obscure ; still, you at least know that you are targeting the compression model's decoder. This is why `bend` takes string patterns instead of individual weights : trying every parameter one by one would take very long, and it is generally more convenient to target a whole part or subpart of a module. 
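
As a quick way to find your bearings in such a list, the dotted names can be split and grouped by their first levels. The sketch below does this with the standard library on a handful of names copied from the table and the paragraph above ; the choice of grouping on the first two levels is just an illustration.

```python
from collections import defaultdict

# Weight names copied from this tutorial
names = [
    "compression_model.encoder.model.0.conv.conv.bias",
    "compression_model.encoder.model.0.conv.conv.weight_g",
    "compression_model.encoder.model.0.conv.conv.weight_v",
    "compression_model.encoder.model.1.block.1.conv.conv.bias",
    "compression_model.decoder.model.1.lstm.weight_ih_l0",
]

# Group the names by their first two hierarchy levels
groups = defaultdict(list)
for name in names:
    parts = name.split(".")
    groups[".".join(parts[:2])].append(".".join(parts[2:]))

for prefix, leaves in groups.items():
    print(prefix, len(leaves))
# compression_model.encoder 4
# compression_model.decoder 1
```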

For example, here, all we need to know is that `MusicGen` is made out of two main parts (and we can see that from the weight names) : 
- a `compression_model`, which is actually the neural codec used to generate sounds from latents
- a `lm` (lm stands for language model), that converts the text to the latent.

The `compression_model` also has two modules : an `encoder` and a `decoder`. Here we will only bend the decoder, as the encoder (audio -> latents) is not used in text-to-audio generation. The main module of the `lm` is the `transformer` model, an attention model that transforms the text sequence into a latent sequence. Hence, without going into further detail, we know that : 
- bending the `lm` module will alter how `MusicGen` converts the text into audio queries (encoded as latents)
- bending the `compression_model` will alter how the model converts the audio queries into raw audio waveforms.

A little in the clouds 😶‍🌫️? Maybe it's better to listen to some examples 🥳

In [None]:
# callback shared by both experiments below
callback = tb.Mask(0.4)

# bend compression_model
audios_compression = []
for layer in [1, 4, 13]:
    bended.reset()
    bended.bend(callback, r"compression_model.*%d.*weight.*"%layer, r"compression_model.*%d.*bias.*"%layer)
    audios_compression.append(bended(prompt))
    print('-- bending layer %s with callback %s...' % (layer, callback))
    print('bended keys : ', bended.bended_keys())
    
# bend language model (lm)
audios_lm = []
for layer in [1, 6, 18]:
    bended.reset()
    bended.bend(callback, r"lm.transformer.%d.weight"%layer, r"lm.transformer.%d.bias"%layer)
    audios_lm.append(bended(prompt)) 
    print('-- bending layer %s with callback %s...' % (layer, callback))
    print('bended keys : ', bended.bended_keys())
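
When building patterns from a layer index like this, it is worth checking what they actually select, since an unescaped `.` matches any character and the digit `1` also appears in `13`. The sketch below compares a loose and a stricter pattern with Python's `re` module. It assumes `re.match` semantics, which may differ from what `torchbend` does internally ; the third name is hypothetical and only there to illustrate the pitfall, and the stricter pattern is one possible choice among others.

```python
import re

names = [
    "compression_model.encoder.model.1.block.1.conv.conv.bias",  # from the table above
    "compression_model.decoder.model.1.lstm.weight_ih_l0",       # from the text above
    "compression_model.decoder.model.13.conv.conv.weight_v",     # hypothetical name
]

layer = 1
# Loose pattern, as in the cell above: any "1" followed later by "weight"
loose = r"compression_model.*%d.*weight.*" % layer
# Stricter pattern: escaped dots, and the index must be a full path component
strict = r"compression_model\.decoder\.model\.%d\..*weight.*" % layer

print([n for n in names if re.match(loose, n)])
# ['compression_model.decoder.model.1.lstm.weight_ih_l0', 'compression_model.decoder.model.13.conv.conv.weight_v']
print([n for n in names if re.match(strict, n)])
# ['compression_model.decoder.model.1.lstm.weight_ih_l0']
```

Printing `bended.bended_keys()` after each `bend`, as done in the cell above, is the simplest way to confirm which weights a pattern really reached.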
    

## Bending RAVE