In [None]:
%load_ext autoreload
%autoreload 2
import sys; sys.path.append('../torchbend')

# Bending a MNIST auto-encoder

This notebook introduces the notion of *network bending* and lets you perform simple bending operations on a small generative image model, which is generally more intuitive to work with than large audio models.

## What is bending? 

The idea of *network bending*, first referenced by [Terence Broad](https://arxiv.org/pdf/2005.12420), is the equivalent of [circuit bending](https://fr.wikipedia.org/wiki/Circuit_bending) for neural network-based models: hijacking a model designed for one purpose (here, the original task of the machine learning model) by altering its inner circuitry. You can also see it as a *modular* approach to machine learning, although the idea of *hijacking* remains important and is not naturally covered by modular synthesis. `torchbend` is a library originally designed to provide high-level functions for graph and parameter bending, making it possible to generate data in ways that would not be possible without this kind of alteration. It can also be used for *dissection*: understanding how the model works matters, so that you do not waste time on alterations that make no sense, exactly as when bending analog devices.

Ready to go? Let's bend this little image model.

In [None]:
import torch
from dandb.networks import make_mnist_vae
import torchbend as tb

vae = make_mnist_vae()
state_dict = torch.load('models/original/mnist_vae/final.ckpt', map_location=torch.device('cpu'))
vae.load_state_dict(state_dict)
bended_vae = tb.BendedModule(vae)

### Weight bending

One way of bending a module is to bend its *parameters*, or *weights*. The nice part about weight bending is that it is available for every model written with PyTorch, while other bending techniques unfortunately depend on the coding style of the target module (that's silly, but true). Let's try that with our module: the only thing you have to do is call the `bend` method with a weight key.

In [None]:
from dandb import plot_image_activations

names = bended_vae.resolve_parameters(r"decoder.net.convt.*weight.*")
print(names)

# we will here sample the latent space, as we will only bend the decoder.
n_samples = 16
z = torch.randn(n_samples, 64)
# we trace the decode function
bended_vae.trace("decode", z=z)

# we will use the Mask bending callback, that destroys the weight / activation
# with a given probability rate. 

# we keep 20% of the mask
mask = 0.2
# we put this into a BendingParameter object ; we'll see why later
param = tb.BendingParameter('mask', mask)
# init the callback
cb = tb.Mask(prob=param)

outs = []
for weight_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, weight_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

plot_image_activations(outs, n_rows=1, display=True)
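To make the masking operation concrete, here is a minimal plain-Python sketch of a random mask applied to a flat list of weights. The helper `random_mask` is hypothetical and is not part of `torchbend`; in particular, it assumes that the probability is a *keep* probability (matching the "keep 20%" comment above), which the actual `tb.Mask` callback may or may not follow.

```python
import random

def random_mask(values, keep_prob, seed=0):
    """Keep each entry independently with probability keep_prob, zero it otherwise.

    Hypothetical helper for illustration only; not the torchbend implementation.
    """
    rng = random.Random(seed)
    return [v if rng.random() < keep_prob else 0.0 for v in values]

weights = [0.5, -1.2, 0.8, 0.3, -0.7, 1.1, -0.4, 0.9, 0.2, -0.6]
masked = random_mask(weights, keep_prob=0.2)
kept = sum(1 for v in masked if v != 0.0)
print(masked)
print(f"{kept} of {len(weights)} entries kept")
```

Each surviving entry is left untouched, so the bent layer computes the same function as before, only with most of its connections removed.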


We see that bending a weight affects the produced output differently depending on how close its layer is to the output. Let's apply a scale transformation:

In [None]:
from dandb import plot_image_activations

names = bended_vae.resolve_parameters(r"decoder.net.convt.*weight.*")
print(names)

# we will here sample the latent space, as we will only bend the decoder.
n_samples = 16
z = torch.randn(n_samples, 64)
# we trace the decode function
bended_vae.trace("decode", z=z)

# we will use the Scale bending callback, that multiplies the weight / activation
# by a given factor.


# scale all the weights by 10
scale = 10
# or, invert all the weights (uncomment)
# scale = -1


# we put this into a BendingParameter object ; we'll see why later
param = tb.BendingParameter('scale', scale)
# init the callback
cb = tb.Scale(scale=param)

outs = []
for weight_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, weight_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

plot_image_activations(outs, n_rows=1, display=True)
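The difference between a positive scale and a sign inversion can be checked on a toy example. The sketch below uses plain Python lists in place of tensors and assumes a single ReLU with no bias term: a positive factor can be moved across the ReLU unchanged, while a negative one cannot.

```python
def relu(xs):
    """Element-wise ReLU on a list of floats."""
    return [max(0.0, x) for x in xs]

def scale(xs, s):
    """Multiply every element by the factor s."""
    return [s * x for x in xs]

pre_activation = [-2.0, -0.5, 0.0, 1.5, 3.0]

# positive factor: scaling before or after the ReLU gives the same values
scaled_before = relu(scale(pre_activation, 10.0))
scaled_after = scale(relu(pre_activation), 10.0)
print(scaled_before == scaled_after)

# negative factor: the ReLU now keeps the entries it used to discard
flipped_before = relu(scale(pre_activation, -1.0))
flipped_after = scale(relu(pre_activation), -1.0)
print(flipped_before)
print(flipped_after)
```

With a positive factor the change simply propagates to the next layer as a gain, whereas inverting the weights changes *which* units are active.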


We see that the network responds *almost* linearly to a positive scaling, which is expected: convolutional operators are linear, and the nonlinearity used here acts like a saturation unit. If you invert the weights, though, the output is very different; this is because the `ReLU` nonlinearity is asymmetric. Let's try with a bias now:

In [None]:
from dandb import plot_image_activations

names = bended_vae.resolve_parameters(r"decoder.net.convt.*weight.*")
print(names)

# we will here sample the latent space, as we will only bend the decoder.
n_samples = 16
z = torch.randn(n_samples, 64)
# we trace the decode function
bended_vae.trace("decode", z=z)

# we will use the Bias bending callback, that adds a constant value
# to the weight / activation.


# bias positively all the weights
bias = 0.1
# or negatively
# bias = -0.1


# we put this into a BendingParameter object ; we'll see why later
param = tb.BendingParameter('bias', bias)
# init the callback
cb = tb.Bias(bias=param)

outs = []
for weight_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, weight_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

plot_image_activations(outs, n_rows=1, display=True)
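A short worked example may help explain why biasing the weights has such a strong effect. For a single linear unit (a plain-Python stand-in for one output position of a convolution, with hypothetical values), adding a constant `bias` to every weight shifts the output by `bias` times the sum of the inputs, so the shift grows with the number and magnitude of the inputs.

```python
def linear_unit(weights, inputs):
    """Weighted sum of the inputs, as computed by one linear unit."""
    return sum(w * x for w, x in zip(weights, inputs))

# hypothetical values, chosen only for illustration
weights = [0.5, -0.25, 1.0, -0.75]
inputs = [1.0, 2.0, 0.5, 4.0]
bias = 0.125

original = linear_unit(weights, inputs)
bent = linear_unit([w + bias for w in weights], inputs)

# the bent output equals the original one plus bias * sum(inputs)
shift = bias * sum(inputs)
print(original, bent, original + shift)
```

Since the incoming activations are non-negative after a ReLU, their sum is typically large, which is one plausible reason why even a small weight bias changes the picture so much.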


The impact of the bias is more dramatic: it shifts all the activations, and the effect also varies greatly across layers. You can experiment with any function you want using the `Lambda` callback:



In [None]:
from dandb import plot_image_activations

names = bended_vae.resolve_parameters(r"decoder.net.convt.*weight.*")
print(names)

# we will here sample the latent space, as we will only bend the decoder.
n_samples = 16
z = torch.randn(n_samples, 64)
# we trace the decode function
bended_vae.trace("decode", z=z)

# any tensor-to-tensor function can be used as a bending operation ;
# here, the weights are passed through a cosine of frequency f
def bending_op(x, f=8):
    return torch.cos(2 * f * torch.pi * x)

# init the callback
cb = tb.Lambda(bending_op)

outs = []
for weight_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, weight_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

plot_image_activations(outs, n_rows=1, display=True)
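To get a feel for what the cosine operation does to a weight tensor, the sketch below applies a scalar version of `bending_op` (using `math` instead of `torch`, purely for illustration) to a few hypothetical weight values.

```python
import math

def bending_op(x, f=8):
    """Scalar version of the cosine bending operation used above."""
    return math.cos(2 * f * math.pi * x)

# hypothetical weight values, typically small in a trained network
small_weights = [0.0, 0.001, -0.01, 0.03125, 0.0625]
bent = [bending_op(w) for w in small_weights]

for w, b in zip(small_weights, bent):
    print(f"{w:+.5f} -> {b:+.5f}")
```

Two things stand out: every output lies in [-1, 1] whatever the original magnitude, and a weight of exactly zero is mapped to 1, so connections that were switched off become fully active. The sign of the original weight is also lost, since the cosine is an even function.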


This is how, by applying transformations to the model's weights, we can strongly affect the produced output in a way that would not be achievable without altering the network. Yet this method bends all inputs in the same manner. Activation bending offers a subtler approach, applying different operations to different inputs.

### Activation bending

The other way to bend a module is to directly modulate its intermediate processing values, the *activations*. The syntax is the same, except that the keys must be among the activations listed in the activation list. If you want to be sure to bend only activations, you can add the `bend_activation` keyword.

We will use here the `tb.ThresholdActivation` bending callback, that only keeps a given proportion of the active channels of an activation.



In [None]:
from dandb import plot_image_activations

bended_vae.reset()
names = list(sorted(bended_vae.resolve_activations(r"decoder_net_convt.*", fn="decode")))
print(names)

# we will here sample the latent space, as we will only bend the decoder.
n_samples = 16
z = torch.randn(n_samples, 64)
# we trace the decode function
bended_vae.trace("decode", z=z)

# here, we filter out the less activated half of the features
# change the float number below to adjust the amount of channels kept
param = tb.BendingParameter('threshold', 0.5)
# init the callback
cb = tb.ThresholdActivation(threshold=param, dim=-3, invert=False)

outs = []
for activation_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, activation_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

plot_image_activations(outs, n_rows=1, display=True)
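The key property of this kind of thresholding is that the selection is made per input. The sketch below is a plain-Python illustration under an assumed behaviour, namely that channels are ranked by some per-instance energy and only the most active fraction is kept; `keep_most_active` is a hypothetical helper and the exact criterion used by `tb.ThresholdActivation` may differ.

```python
def keep_most_active(channel_energy, keep_ratio):
    """Return a 0/1 mask keeping the keep_ratio most active channels.

    Hypothetical helper for illustration only; not the torchbend implementation.
    """
    n_channels = len(channel_energy)
    n_keep = max(1, round(keep_ratio * n_channels))
    ranked = sorted(range(n_channels), key=lambda c: channel_energy[c], reverse=True)
    kept = set(ranked[:n_keep])
    return [1 if c in kept else 0 for c in range(n_channels)]

# hypothetical per-channel energies for two different inputs (4 channels each)
batch_energy = [
    [0.9, 0.1, 0.4, 0.2],
    [0.05, 0.7, 0.3, 0.8],
]
masks = [keep_most_active(energy, keep_ratio=0.5) for energy in batch_energy]
print(masks)
```

The two inputs end up with different masks, which is something a weight mask cannot do, since it is fixed once and shared by all inputs.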


With activation bending, the effect can be different for every input, while weight bending has the same effect on every input. For example, instead of masking the channels of lower amplitude, we can normalize the activations, either across instances or across channels:

In [None]:
from dandb import plot_image_activations

bended_vae.reset()
names = list(sorted(bended_vae.resolve_activations(r"decoder_net_convt.*", fn="decode")))

# we will here sample the latent space, as we will only bend the decoder.
n_samples = 16
z = torch.randn(n_samples, 64)

def norm_by_instance(activation):
    # activation shape is (batch x channel x height x width), so
    # we normalize here across instances
    return activation / activation.amax(0, keepdim=True)

param = tb.BendingParameter('threshold', 0.5)
cb = tb.Lambda(norm_by_instance)

outs = []
for activation_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, activation_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

print("normalization across instances : ")
plot_image_activations(outs, n_rows=1, display=True)


def norm_by_channel(activation):
    # activation shape is (batch x channel x height x width), so
    # we normalize here across channels
    return activation / activation.amax(1, keepdim=True)

param = tb.BendingParameter('threshold', 0.5)
cb = tb.Lambda(norm_by_channel)

outs = []
for activation_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, activation_name)
    out = bended_vae.decode(z=z)
    outs.append(out)

print('normalization across channels : ')
plot_image_activations(outs, n_rows=1, display=True)
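The difference between the two normalizations comes down to the axis along which the maximum is taken. The sketch below reproduces both on a tiny (batch x channel) nested list, dropping the spatial dimensions for readability; it mirrors `amax(0)` and `amax(1)` in plain Python and uses hypothetical values.

```python
def norm_by_instance(act):
    """Divide each entry by the maximum over the batch axis (like amax(0))."""
    n_channels = len(act[0])
    peaks = [max(row[c] for row in act) for c in range(n_channels)]
    return [[v / peaks[c] for c, v in enumerate(row)] for row in act]

def norm_by_channel(act):
    """Divide each entry by the maximum over the channel axis (like amax(1))."""
    return [[v / max(row) for v in row] for row in act]

# 2 instances x 2 channels, hypothetical positive activations
act = [[1.0, 8.0],
       [2.0, 4.0]]

by_instance = norm_by_instance(act)
by_channel = norm_by_channel(act)
print(by_instance)
print(by_channel)
```

Normalizing over the batch axis makes every channel peak at 1 somewhere in the batch, so the result for one image depends on the other images. Normalizing over the channel axis makes every instance peak at 1 in its strongest channel, independently of the rest of the batch. Both versions assume that the maximum is strictly positive; a zero or negative maximum would lead to a division by zero or a sign flip.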


What if we mixed activations? We can try that out with the `tb.InterpolateActivation` callback, which takes a mixing matrix as an additional input and uses it to linearly interpolate across a set of input activations. With this callback, we will leverage the `from_activations` method, the exact complement of `get_activations`, which lets you feed activations directly to a sub-part of the network.

In [None]:
from torchvision import datasets, transforms
from dandb import plot_image_activations

bended_vae.reset()

names = list(sorted(bended_vae.resolve_activations(r"decoder_net_convt.*", fn="decode")))

n_samples = 2
val_set = datasets.MNIST(root='datasets/MNIST', train=False, download=True, transform = transforms.ToTensor())
examples = torch.stack([val_set[0][0], val_set[100][0]])
z = bended_vae.encode(examples)[0]
bended_vae.trace("decode", z=z)

n_interp = 8
cb = tb.InterpolateActivation()

outs = []
for activation_name in names:
    # reset previous bendings
    bended_vae.reset()
    bended_vae.bend(cb, activation_name)
    out_activation = bended_vae.get_activations(activation_name, z=z, fn="decode", _filter_bended=True)
    # make linear interpolation between these two examples
    interp_weights = torch.stack([torch.linspace(0., 1., n_interp), torch.linspace(1, 0, n_interp)], 1)
    out = bended_vae.from_activations(activation_name, **out_activation, interp_weights=interp_weights, fn="decode")
    outs.append(out)

print("layer-wise interpolation across instances : ")
plot_image_activations(outs, n_rows=1, display=True)
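The mixing matrix built in the cell above has one row per interpolation step and one column per source activation. The following plain-Python sketch shows the layout and one plausible way such a matrix can be applied, as a weighted sum of the two sources; the internals of `tb.InterpolateActivation` are not shown in this notebook, so treat this as an assumption. The `linspace` helper and the activation values are hypothetical.

```python
def linspace(start, stop, n):
    """Evenly spaced values from start to stop inclusive (n >= 2)."""
    step = (stop - start) / (n - 1)
    return [start + i * step for i in range(n)]

n_interp = 5
# one row per step: (weight of source A, weight of source B)
interp_weights = list(zip(linspace(0.0, 1.0, n_interp), linspace(1.0, 0.0, n_interp)))

# two hypothetical flattened activations, one per input example
act_a = [1.0, 0.0, 2.0]
act_b = [0.0, 4.0, 2.0]

# assumed behaviour: each output is a weighted sum of the two sources
mixed = [[wa * a + wb * b for a, b in zip(act_a, act_b)] for wa, wb in interp_weights]

for weights, row in zip(interp_weights, mixed):
    print(weights, row)
```

With this column order, the first row is entirely the second source and the last row is entirely the first one, and each row of weights sums to 1, which keeps the overall activation scale comparable along the interpolation path.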


We can see that the higher the layer of the decoder where the interpolation is performed (closer to the latent space), the "smoother" the interpolation. This reflects the general idea in deep learning that higher layers represent more abstract features, while lower layers (closer to the data) represent more "local" features (in this example, interpolating there is closer to a direct "linear" interpolation of images). Is this also the case with RAVE? Let's try that out!