In [None]:
%load_ext autoreload
%autoreload 2

## Bending RAVE for neural corrupted audio generation 

In [None]:
from functools import partial
import os 

from IPython.display import display, Audio
import json 
import numpy as np
import panel as pn
pn.extension()
import torch
torch.set_grad_enabled(False)
import torchbend as tb
import rave 
 
from dandb import *

### First step: importing the model and creating the data loader

Before generating audio with RAVE and bending the model's weights/activations, we first need to import the model. For convenience, we will rely on a model that was trained on audio from [the Studio On Line dataset](https://forum.ircam.fr/projects/detail/tinysol/), which contains recordings of a variety of musical instruments. However, all the techniques illustrated in this notebook can be applied to any RAVE model, notably models trained on your own datasets (more details [here](https://github.com/acids-ircam/RAVE)).

In [None]:
RAVE_SAMPLE_RATE = 44100 # the RAVE we'll use was trained using 44.1 kHz audio data

# paths.json holds the locations of the model checkpoint and of the audio dataset.
# JSON is UTF-8 by specification, so decode it explicitly instead of relying on the
# platform's locale encoding (which can garble non-ASCII paths, e.g. on Windows).
with open('paths.json', 'r', encoding='utf-8') as f:
    paths = json.load(f)

# The file is already closed here; a missing entry surfaces as a KeyError naming the key.
checkpoint_path = paths['checkpoint_path']
dataset_path = paths['data_path']

# Load the RAVE model from the checkpoint path read from paths.json
rave_model = load_rave(checkpoint_path)

# Create the data loader, which is basically a container for your audio samples
# NOTE(review): bs=8 is presumably the batch size -- confirm against make_loader in dandb
audio_loader = make_loader(dataset_path, bs=8, num_workers=8)
# Draw a single batch; `x` is the reference audio reused by the cells below
x = next(iter(audio_loader))
# Run that batch through the unmodified model; `x_rec` is the un-bent reconstruction
# that the later cells play back and compare against
x_rec = rave_model(x)

In [None]:
# Play at most four originals, each followed by the model's reconstruction of it.
# NOTE(review): squeeze() drops every size-1 axis -- assumes the batch holds more
# than one clip, otherwise iteration would run over samples rather than clips.
for _, (audio_example, reconstructed_audio_example) in zip(
    range(4),
    zip(x.squeeze().numpy(), x_rec.squeeze().numpy()),
):
    for caption, clip in (
        ('Original audio', audio_example),
        ('Reconstructed audio', reconstructed_audio_example),
    ):
        print(caption)
        display(Audio(data=clip, rate=RAVE_SAMPLE_RATE))
    print('=' * 32)

### Second step: creating and exploiting a bended RAVE

In [None]:
# Wrap the loaded RAVE in a torchbend BendedModule; this wrapper (not rave_model)
# is what the activation-extraction and bending cells below operate on.
bended_rave = tb.BendedModule(rave_model)
# Trace the wrapper, feeding the example batch as the input named `x`.
# NOTE(review): presumably this records the model graph so that activations can be
# addressed by name afterwards -- confirm against the torchbend documentation.
bended_rave.trace(x=x)
print('You may now bend RAVE')

#### Extracting activations and computing activation similarity

In [None]:
# Gather the activations named in RAVE_DECODER_ACT_NAMES from the bended model while
# iterating over the loader.
# NOTE(review): max_batches=20 presumably caps how many batches are consumed and
# avg_batch=True presumably averages over the batch axis -- confirm in dandb.
rave_activations = extract_activations(
    bended_rave, RAVE_DECODER_ACT_NAMES, audio_loader, avg_batch=True, max_batches=20
)
activations_similarity = compute_activations_similarity(rave_activations)

# Clustering pipeline, innermost call first:
#   1. cluster the similarity structure with a 0.75 threshold,
#   2. keep only the non-singleton clusters,
#   3. sort what remains.
activations_clusters = sort_clusters(
    compute_non_singleton_clusters(
        compute_clusters(activations_similarity, threshold=0.75)
    )
)

### Creating bending callbacks and finally applying them to RAVE's intermediate activations!

In [None]:
# Build the affine bending modules for the decoder activation names ...
affine_benders = make_affine_bending_modules(RAVE_DECODER_ACT_NAMES)

# ... then combine them with the activation clusters computed earlier to obtain
# the clustered bending callbacks used by the interactive cell.
clustered_affine_benders = make_clustered_bending_callbacks(affine_benders, activations_clusters)

In [None]:
# Widgets for the decoder activations. The result is unpacked with ** further down,
# so it is a name -> widget mapping; each name is split on '/' to recover its op name.
op_widgets = make_widgets(RAVE_DECODER_ACT_NAMES)

# Pre-bind the clustered callbacks so that only the widget values are left to pass in.
update_bending_params = partial(update_affine_params, clustered_affine_cb=clustered_affine_benders)

def _bend_rave(
    bended_rave_model, 
    audio_batch, 
    audio_batch_rec, 
    **params
):
    bending_callbacks = update_bending_params(
        **params)
    bended_audios = get_bended_rave_audios(
        audio_batch, 
        bended_rave_model, 
        bending_callbacks, 
    )
    audio_grid = make_audio_grid(audio_batch_rec, 
                                 bended_audios, 
                                 sr=RAVE_SAMPLE_RATE, 
                                 ncols=3)
    return audio_grid

def make_widget_box(op_names, **widgets):
    """Group widgets per operation and arrange the groups side by side.

    Args:
        op_names: iterable of operation names; one widget box is created for each,
            in this order, even when no widget belongs to it.
        **widgets: widgets keyed by name; the part of the name before the first
            '/' selects the operation the widget belongs to.

    Returns:
        A ``pn.Row`` holding one ``pn.WidgetBox`` per operation name.

    Raises:
        KeyError: if a widget name's prefix is not one of ``op_names``.
    """
    grouped = {}
    for name in op_names:
        grouped[name] = []
    for key, widget in widgets.items():
        # Everything before the first '/' is the operation name.
        prefix, _, _ = key.partition('/')
        grouped[prefix].append(widget)
    boxes = [pn.WidgetBox(*members) for members in grouped.values()]
    return pn.Row(*boxes)
    

# Freeze the bended model and the reference audio so that only the bending
# parameters remain as free keyword arguments.
bend_rave = partial(
    _bend_rave,
    bended_rave_model=bended_rave,
    audio_batch=x,
    audio_batch_rec=x_rec,
)

# Bind every widget to the keyword argument carrying its name.
# NOTE(review): with pn.bind the output is expected to be re-evaluated whenever a
# widget value changes -- confirm against the Panel documentation.
dynamic_bend_rave = pn.bind(bend_rave, **op_widgets)

# Widget boxes on top, bound output underneath. This stays the cell's last
# expression so the notebook renders it.
pn.Column(make_widget_box(RAVE_DECODER_ACT_NAMES, **op_widgets), dynamic_bend_rave)
    