In [62]:
# Automatically reload edited local modules (e.g. dandb) before each cell is executed
%load_ext autoreload
%autoreload 2

## Bending RAVE for corrupted neural audio generation

In [63]:
from functools import partial
import os 

from IPython.display import display, Audio, Markdown
import json 
import numpy as np
import panel as pn
pn.extension()
from plotly.subplots import make_subplots
import plotly.graph_objects as go
import torch
torch.set_grad_enabled(False)
import torchbend as tb
import rave 
 
from dandb import *

### First step: importing the model and creating a data loader

Before generating audio with RAVE and bending the model's weights/activations, we first need to import the model. For convenience, we rely on a model trained on audio from [the Studio On Line dataset](https://forum.ircam.fr/projects/detail/tinysol/), which contains recordings of a variety of musical instruments. However, all the techniques illustrated in this notebook can be applied to any RAVE model, including models trained on your own datasets (more details [here](https://github.com/acids-ircam/RAVE)).

In [64]:
# Sampling rate of the audio the RAVE checkpoint below was trained on (44.1 kHz)
RAVE_SAMPLE_RATE = 44100

# Machine-specific locations (model checkpoint, audio dataset) live in paths.json
with open('paths.json', 'r') as f:
    paths = json.load(f)
checkpoint_path = paths['checkpoint_path']
dataset_path = paths['data_path']

# Load the pretrained RAVE model.
# NOTE(review): this cell emits a `torch.load(weights_only=False)` warning, i.e. the
# checkpoint is unpickled — only point paths.json at checkpoints you trust.
rave_model = load_rave(checkpoint_path)

# Data loader = container for the audio excerpts; bs / num_workers are presumably the
# batch size and number of loader workers — TODO confirm in dandb.make_loader
audio_loader = make_loader(dataset_path, bs=8, num_workers=8)

# One reference batch and its (unbended) RAVE reconstruction, reused by later cells
x = next(iter(audio_loader))
x_rec = rave_model(x)


`torch.nn.utils.weight_norm` is deprecated in favor of `torch.nn.utils.parametrizations.weight_norm`.


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [65]:
# Play a few original / reconstructed pairs from the reference batch.
# The batch is sliced first and each example squeezed on its own: squeezing the whole
# batch (x.squeeze()) would also drop the batch axis when the batch holds a single
# example, and the loop would then iterate over individual audio samples.
N_LISTENING_EXAMPLES = 4
for audio_example, reconstructed_audio_example in zip(x[:N_LISTENING_EXAMPLES], x_rec[:N_LISTENING_EXAMPLES]):
    print('Original audio')
    display(Audio(data=audio_example.squeeze().numpy(), rate=RAVE_SAMPLE_RATE))
    print('Reconstructed audio')
    display(Audio(data=reconstructed_audio_example.squeeze().numpy(), rate=RAVE_SAMPLE_RATE))
    print('='*32)

Original audio


Reconstructed audio


Original audio


Reconstructed audio


Original audio


Reconstructed audio


Original audio


Reconstructed audio




### Second step: creating and exploiting a bended RAVE

In [66]:
# Wrap RAVE with torchbend so that its weights/activations can be bent
bended_rave = tb.BendedModule(rave_model)
# Trace the wrapped model on the reference batch `x`; presumably this is what makes
# named activations (e.g. 'add_25', used below) addressable — TODO confirm in torchbend docs
bended_rave.trace(x=x)
print('You may now bend RAVE')

You may now bend RAVE


#### Extracting activations and computing their similarity

Extracting groups (or clusters) of similar activations is achieved in three steps. First, we feed multiple audio samples to the model and record the outputs of each activation for which we want to compute clusters. Then, for a given layer, we compute the cross-similarity between all pairs of its features. Precisely, for a given input, the output of a specific activation is a set of features $f_1, f_2, \dots, f_N$. We define a similarity function $s$ that takes two features as inputs and returns a similarity score between them: two identical features give $s(f_i, f_j)=1$, whereas completely uncorrelated features give a score of $0$. Finally, features sharing high similarity scores are grouped into clusters.

In [114]:
# Example 
t = torch.linspace(0, 1, 200).unsqueeze(0)
feature_1 = torch.cos(20*t)
feature_2 = torch.cos(21*t)+.01*torch.randn_like(feature_1)
feature_3 = 3*t+1
feature_4 = 5*t+1.1

fig = make_subplots(rows=1, cols=4, subplot_titles=[f'Feature {idx+1}' for idx in range(4)])
for idx, feature in enumerate([feature_1, feature_2, feature_3, feature_4]):
    fig.add_trace(
        go.Scatter(
            y=feature.squeeze().numpy(), 
            x=t.squeeze().numpy(), 
            name = f'Feature {idx+1}'
            ), 
        row=1, col=idx+1)
fig.update_layout(showlegend=False, title='Illustration for deep features (activations) of deep neural networks')
fig.show()

In the above example, we would expect features 1 and 2 to have a high similarity score, and likewise for features 3 and 4.

In [115]:
# Cosine similarity along the last (sample) axis: 1 for features pointing in the same
# direction, 0 for orthogonal ones, -1 for opposite ones.
# `torch.nn` is spelled out rather than relying on a bare `nn` leaking out of
# `from dandb import *`.
similarity_function = torch.nn.CosineSimilarity(dim=-1)

features = torch.cat([feature_1, feature_2, feature_3, feature_4], dim=0)
n_features = features.shape[0]
# Broadcasting (n, 1, T) against (1, n, T) yields the full (n, n) similarity matrix at once
similarity_scores = similarity_function(features.unsqueeze(1), features.unsqueeze(0))

# Pairs (0-based) we expect to be similar; they are rendered in bold
expected_similar_pairs = {(0, 1), (2, 3)}

display(Markdown('### Similarity scores : '))
for i in range(n_features):
    # Start at i + 1: skip self-similarity and the symmetric lower triangle
    for j in range(i + 1, n_features):
        score_text = f'feature {i+1} and feature {j+1}: {round(100*similarity_scores[i, j].item(), 2)}%'
        if (i, j) in expected_similar_pairs:
            display(Markdown(f' - **{score_text}**'))
        else:
            display(Markdown(f' - {score_text}'))

### Similarity scores : 

 - **feature 1 and feature 2: 83.71%**

 - feature 1 and feature 3: 9.72%

 - feature 1 and feature 4: 10.05%

 - feature 2 and feature 3: 7.9%

 - feature 2 and feature 4: 8.1%

 - **feature 3 and feature 4: 99.88%**

As expected, our similarity function assigns high similarity scores to features 1 and 2, as well as to features 3 and 4. It also outputs low similarity scores for seemingly unrelated features, _e.g._ features 1 and 3. Let us now apply this to the inner features of an actual neural network.

In [69]:
# Record the decoder activations named in RAVE_DECODER_ACT_NAMES over at most 20 loader
# batches (avg_batch=True presumably averages over the batch axis — TODO confirm in dandb),
# then compute feature cross-similarities, indexed below as activations_similarity[name][i].
rave_activations = extract_activations(bended_rave, RAVE_DECODER_ACT_NAMES, audio_loader,
                                       avg_batch=True, max_batches=20)
activations_similarity = compute_activations_similarity(rave_activations)

To make sure our method works, we can pick one feature of an intermediate activation of RAVE and display the other features of that activation that are highly similar to it.

In [103]:
# Pick one feature of an intermediate decoder activation and plot it together with the
# other features of the same activation whose similarity with it is >= threshold.
target_activation_name = 'add_25'
feature_index = 10
threshold = .8
# At most this many similar features are plotted (one subplot row each). The cap is
# applied uniformly; previously exactly 4 matches were all kept while 5+ were cut to 3.
max_similar_features = 3

target_activation = rave_activations[target_activation_name]
similarity_scores = activations_similarity[target_activation_name]
# Indices of every other feature of this activation that is similar enough to the target
similar_features = [
    compared_feature_index
    for compared_feature_index, score in enumerate(similarity_scores[feature_index])
    if compared_feature_index != feature_index and score >= threshold
]
similar_features = similar_features[:max_similar_features]

if not similar_features:
    display(Markdown('This activation does not seem to have any other similar features'))
else:
    subplot_titles = ['An inner activation of RAVE'] + [f'Similar feature {idx+1}' for idx in range(len(similar_features))]
    fig = make_subplots(cols=1, rows=len(similar_features)+1, subplot_titles=subplot_titles)
    fig.add_trace(
        go.Scatter(
            y=target_activation[feature_index].numpy(),
            name='Target feature'
            ),
        col=1, row=1)
    for idx, compared_feature_idx in enumerate(similar_features):
        fig.add_trace(
            go.Scatter(
                y=target_activation[compared_feature_idx].numpy(),
                name=f'Similar feature {idx+1}'
                ),
            col=1, row=idx+2)
    fig.update_layout(height=800, width=1000)
    fig.show()

Our similarity function does seem to identify features that look alike. We can now define clusters of similar activations by grouping those that share high similarity scores.

In [None]:
# Build clusters from the similarity scores (threshold 0.75; the exact grouping rule
# lives in dandb.compute_clusters), drop single-element clusters (not bent for now),
# and order the remaining clusters from largest to smallest.
activations_clusters = compute_clusters(activations_similarity, threshold=0.75)
activations_clusters = compute_non_singleton_clusters(activations_clusters)
activations_clusters = sort_clusters(activations_clusters)

### Creating bending callbacks and finally applying them to RAVE's intermediate activations

In [None]:
# Affine bending modules for the decoder activations...
affine_benders = make_affine_bending_modules(RAVE_DECODER_ACT_NAMES)
# ...wrapped into ClusteredBendingCallbacks, so that each bending operation is applied
# to a whole cluster (group) of activations rather than to a single one.
clustered_affine_benders = make_clustered_bending_callbacks(affine_benders, activations_clusters)

In [None]:
# Interactive controls for the decoder activations (a mapping, unpacked as **op_widgets below)
op_widgets = make_widgets(RAVE_DECODER_ACT_NAMES)

# Pre-bind the clustered callbacks so update_bending_params(**values) only needs widget values
update_bending_params = partial(update_affine_params, clustered_affine_cb=clustered_affine_benders)

def _bend_rave(
    bended_rave_model,
    audio_batch,
    audio_batch_rec,
    **params
):
    """Render bended RAVE outputs alongside the unbended reconstructions.

    Args:
        bended_rave_model: the torchbend-wrapped RAVE model.
        audio_batch: batch of input audio fed to the bended model.
        audio_batch_rec: unbended RAVE reconstructions of ``audio_batch``, shown for comparison.
        **params: current widget values, forwarded as-is to ``update_bending_params``.

    Returns:
        Whatever ``make_audio_grid`` builds (3 columns at ``RAVE_SAMPLE_RATE``).
    """
    # Apply the widget values to the clustered callbacks first, then render with them
    callbacks = update_bending_params(**params)
    return make_audio_grid(
        audio_batch_rec,
        get_bended_rave_audios(audio_batch, bended_rave_model, callbacks),
        sr=RAVE_SAMPLE_RATE,
        ncols=3,
    )

# Everything except the widget values is fixed: the bended model, the reference batch x
# and its plain reconstruction x_rec.
bend_rave = partial(
    _bend_rave,
    bended_rave_model=bended_rave,
    audio_batch=x,
    audio_batch_rec=x_rec,
)

# Tie bend_rave to the widgets' values
dynamic_bend_rave = pn.bind(bend_rave, **op_widgets)

# Widgets on top, resulting audio grid below; kept as the cell's last expression so it renders
bend_rave_app = pn.Column(
    make_widget_box(RAVE_DECODER_ACT_NAMES, **op_widgets),
    dynamic_bend_rave,
)
bend_rave_app
    