# Bended modules 

At the very base of the `torchbend` modules lies the `BendedModule`, a wrapper for `torch.nn.Module` that uses the improved `torchbend` tracer to provide a handy interface for weight bending and interpolation, as well as activation retrieval and bending.

## Dissecting weights and activations

Let's see how to bend this simple module:

In [1]:
import sys; sys.path.append("..")
import torch, torch.nn as nn
import torchbend as tb

class Stanley(nn.Module):
    def __init__(self, n_channels = 4):
        super().__init__()
        self.conv_modules = nn.Sequential(
            nn.Conv1d(1, n_channels, 3), 
            nn.Conv1d(n_channels, 8, 3)
        )
        self.batch_norm = nn.BatchNorm1d(8)
        self.nnlin = nn.Sigmoid()
        self.n_channels = n_channels
    
    def forward(self, x):
        out = self.conv_modules(x)
        out = self.batch_norm(out)
        out = self.nnlin(out)
        return out

    def forward_nobatch(self, x):
        out = self.conv_modules(x)
        out = self.nnlin(out)
        return out


# instantiate 
module = Stanley()

# wrap with Bended Module
bended = tb.BendedModule(module)

`BendedModule` is one of the main objects of the `torchbend` library; it takes an `nn.Module` instance as its only argument. Before bending, let's analyse the weights of our `Stanley` instance through the `BendedModule` wrapper:

In [2]:
weight_names = bended.weight_names
weight_shapes = list(map(bended.weight_shape, weight_names))

# print_weights also prints weight information as a table.
# The out keyword may be used to export information as a .txt file
bended.print_weights();

name                   shape                  dtype                min       max        mean    stddev
---------------------  ---------------------  -------------  ---------  --------  ----------  --------
conv_modules.0.weight  torch.Size([4, 1, 3])  torch.float32  -0.320379  0.556905   0.17653    0.296265
conv_modules.0.bias    torch.Size([4])        torch.float32  -0.527377  0.530088   0.0369331  0.561641
conv_modules.1.weight  torch.Size([8, 4, 3])  torch.float32  -0.277825  0.288154  -0.0127818  0.165427
conv_modules.1.bias    torch.Size([8])        torch.float32  -0.198426  0.259111   0.0349824  0.182413
batch_norm.weight      torch.Size([8])        torch.float32   1         1          1          0
batch_norm.bias        torch.Size([8])        torch.float32   0         0          0          0
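For reference, the columns of this table are ordinary summary statistics. Here is a plain-Python sketch of how one row could be computed (our own illustration, not the code behind `print_weights`), assuming the tensor has been flattened to a list and that `stddev` is the sample standard deviation, which is what `torch.std` returns by default:

```python
import statistics

def summarize(name, values):
    """Return (name, min, max, mean, stddev) for a flat list of values.

    Sketch only: `values` stands in for a flattened weight tensor, and the
    sample standard deviation (n - 1 denominator) is an assumption.
    """
    std = statistics.stdev(values) if len(values) > 1 else 0.0
    return (name, min(values), max(values), statistics.fmean(values), std)

# A freshly initialized BatchNorm1d(8) has weights of 1 and biases of 0,
# which is why those rows show a standard deviation of 0.
rows = [
    summarize("batch_norm.weight", [1.0] * 8),
    summarize("batch_norm.bias", [0.0] * 8),
]
for row in rows:
    print("{:<20} {:>8.3f} {:>8.3f} {:>8.3f} {:>8.3f}".format(*row))
```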


Perfect, we now have all the information we need about Stanley's weights. To retrieve activations, the target methods need to be traced first; this is done with the `trace` method, which requires a set of example inputs for the traced function.

In [3]:
# by default, trace traces the forward method.
x = torch.zeros(4, 1, 128)
bended.trace(x=x)
activation_names = bended.activation_names()
activation_shapes = list(map(bended.activation_shape, activation_names))
print('forward method : ')
bended.print_activations();

# as activations depend on the traced method, the method name may be
# given to specify the target:
from functools import partial

fn = "forward_nobatch"
bended.trace(fn, x=x)
activation_names = bended.activation_names(fn)
activation_shapes = list(map(partial(bended.activation_shape, fn=fn), activation_names))
print(f'{fn} method : ')
bended.print_activations(fn);


forward method : 
--------------  -----------  -----------------------
x               placeholder  torch.Size([4, 1, 128])
conv_modules_0  call_module  torch.Size([4, 4, 126])
conv_modules_1  call_module  torch.Size([4, 8, 124])
batch_norm      call_module  torch.Size([4, 8, 124])
nnlin           call_module  torch.Size([4, 8, 124])
--------------  -----------  -----------------------
forward_nobatch method : 
--------------  -----------  -----------------------
x               placeholder  torch.Size([4, 1, 128])
conv_modules_0  call_module  torch.Size([4, 4, 126])
conv_modules_1  call_module  torch.Size([4, 8, 124])
nnlin           call_module  torch.Size([4, 8, 124])
--------------  -----------  -----------------------
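The shapes recorded during tracing follow directly from the example input. For `Stanley`, they can be checked by hand with the output-length formula given in the `torch.nn.Conv1d` documentation, `L_out = floor((L_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)`: with `kernel_size=3` and default padding, dilation and stride, each convolution shortens the signal by 2 samples. A small sketch (the `stanley_shapes` helper is ours, not part of `torchbend`):

```python
def conv1d_out_len(l_in, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of torch.nn.Conv1d, following the formula in the PyTorch docs."""
    return (l_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

def stanley_shapes(batch, length, n_channels=4):
    """Propagate an input shape (batch, 1, length) through Stanley.forward by hand."""
    shapes = {"x": (batch, 1, length)}
    l1 = conv1d_out_len(length, 3)
    shapes["conv_modules_0"] = (batch, n_channels, l1)
    l2 = conv1d_out_len(l1, 3)
    shapes["conv_modules_1"] = (batch, 8, l2)
    # BatchNorm1d and Sigmoid leave the shape unchanged
    shapes["batch_norm"] = shapes["conv_modules_1"]
    shapes["nnlin"] = shapes["conv_modules_1"]
    return shapes

print(stanley_shapes(4, 128))
```

Calling `stanley_shapes(1, 64)` gives `(1, 8, 60)` for `nnlin`, a different shape from the one recorded when tracing with a length-128 input.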


By tracing a given method, `BendedModule` decomposes it into a [torch.fx.Graph](https://pytorch.org/docs/stable/fx.html#torch.fx.Graph), which records all the operations applied to a given set of inputs. The graph of a function is independent of the values of the module's parameters; the combination of a graph and a state dict is called a [torch.fx.GraphModule](https://pytorch.org/docs/stable/fx.html#torch.fx.GraphModule). Both can be retrieved directly from `BendedModule`:

In [4]:
fn = "forward"
graph = bended.graph(fn)
print('Graph : ')
graph.print_tabular()
fn = "forward_nobatch"
graph_module = bended.graph_module(fn)
print('\nGraph : ')
graph_module.graph.print_tabular()


Graph : 
opcode       name            target          args               kwargs
-----------  --------------  --------------  -----------------  --------
placeholder  x               x               ()                 {}
call_module  conv_modules_0  conv_modules.0  (x,)               {}
call_module  conv_modules_1  conv_modules.1  (conv_modules_0,)  {}
call_module  batch_norm      batch_norm      (conv_modules_1,)  {}
call_module  nnlin           nnlin           (batch_norm,)      {}
output       output          output          (nnlin,)           {}

Graph : 
opcode       name            target          args               kwargs
-----------  --------------  --------------  -----------------  --------
placeholder  x               x               ()                 {}
call_module  conv_modules_0  conv_modules.0  (x,)               {}
call_module  conv_modules_1  conv_modules.1  (conv_modules_0,)  {}
call_module  nnlin           nnlin           (conv_modules_1,)  {}
output       output          output          (nnlin,)           {}
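To make the split between graph and parameters concrete, here is a toy plain-Python analogue (deliberately not the `torch.fx` API): the graph is a list of nodes naming an operation and its input, while the parameters live in a separate dict. Running the same graph with two parameter dicts gives two different functions, which is the role a `GraphModule` plays for a `torch.fx.Graph` and a set of parameters.

```python
# Toy graph: each node is (name, op, input node names); "x" is the placeholder.
graph = [
    ("scale", "mul_param", ["x"]),
    ("shift", "add_param", ["scale"]),
    ("output", "identity", ["shift"]),
]

def run(graph, params, x):
    """Execute the toy graph with a given parameter dict (the 'state dict')."""
    env = {"x": x}
    for name, op, inputs in graph:
        value = env[inputs[0]]
        if op == "mul_param":
            env[name] = value * params[name + ".weight"]
        elif op == "add_param":
            env[name] = value + params[name + ".bias"]
        elif op == "identity":
            env[name] = value
        else:
            raise ValueError(f"unknown op {op}")
    return env["output"]

# The same graph combined with two parameter dicts gives two different "graph modules".
params_a = {"scale.weight": 2.0, "shift.bias": 1.0}
params_b = {"scale.weight": -1.0, "shift.bias": 0.5}
print(run(graph, params_a, 3.0), run(graph, params_b, 3.0))
```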

Specific activations can be retrieved as a `dict` using the `get_activations` method:

In [5]:
outs = bended.get_activations("conv_modules_0", "nnlin",  x=x, fn="forward")
print({k: v.shape for k, v in outs.items()})

{'conv_modules_0': torch.Size([4, 4, 126]), 'nnlin': torch.Size([4, 8, 124])}
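Conceptually, retrieving activations amounts to executing the graph node by node and keeping the values of the requested nodes. Here is a self-contained plain-Python sketch; the graph, its elementwise operations, and the `get_activations` function below are illustrative stand-ins, not `torchbend` internals.

```python
# Toy graph in execution order: node name -> (function, input node name).
GRAPH = {
    "conv_modules_0": (lambda v: [2 * e for e in v], "x"),
    "conv_modules_1": (lambda v: [e - 1 for e in v], "conv_modules_0"),
    "nnlin": (lambda v: [max(e, 0) for e in v], "conv_modules_1"),  # ReLU-like stand-in
}

def get_activations(*names, x):
    """Run the toy graph on x and return {name: value} for the requested nodes."""
    env = {"x": x}
    for node, (fn, src) in GRAPH.items():  # dicts keep insertion order
        env[node] = fn(env[src])
    missing = [n for n in names if n not in env]
    if missing:
        raise KeyError(f"unknown activations: {missing}")
    return {n: env[n] for n in names}

print(get_activations("conv_modules_0", "nnlin", x=[0, 1, 2]))
```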


## Bending

Here we will see how to bend specific weights and activations, using the `tb.Mask` bending operation, which multiplies the target feature by a random binary mask. Bending operations do not modify the original module and are not performed in place, so every bending operation can be reverted with the `reset` method.

In [6]:
out = bended.forward(x)

cb = tb.Mask(prob=0.4)
bended.bend(cb, "conv_modules.0.weight", "nnlin")

out_bended = bended.forward(x)
print("original == bended :", (out == out_bended).all())
print("original param == bended param :", (module.conv_modules[0].weight == bended.module.conv_modules[0].weight).all())

# revert bending
bended.reset()
out_reverted = bended.forward(x)
print("original == reverted :", (out == out_reverted).all())

original == bended : tensor(False)
original param == bended param : tensor(True)
original == reverted : tensor(True)
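This non-destructive behaviour can be sketched in plain Python: bending writes masked copies into a separate dict, and resetting simply discards them. The `ToyBender` class and its methods are hypothetical, and we assume here that `prob` is the probability of zeroing each element; `tb.Mask`'s exact convention may differ.

```python
import random

class ToyBender:
    """Hypothetical stand-in showing out-of-place bending with a reset."""

    def __init__(self, state_dict):
        self.state_dict = state_dict   # original weights, never modified
        self.bended = {}               # bended copies, keyed by weight name

    def bend(self, prob, *keys, seed=0):
        rng = random.Random(seed)
        for key in keys:
            # binary mask: 0 with probability `prob` (assumed convention), else 1
            mask = [0.0 if rng.random() < prob else 1.0 for _ in self.state_dict[key]]
            self.bended[key] = [w * m for w, m in zip(self.state_dict[key], mask)]

    def reset(self):
        self.bended.clear()

    def weights(self, key):
        """Bended weights if present, otherwise the untouched originals."""
        return self.bended.get(key, self.state_dict[key])

bender = ToyBender({"conv.weight": [0.5, -0.2, 0.3, 0.9]})
bender.bend(0.4, "conv.weight")
print(bender.weights("conv.weight"), bender.state_dict["conv.weight"])
bender.reset()
print(bender.weights("conv.weight"))
```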


Ok, the module has been correctly bended! Let's see in detail how `BendedModule` bends the original module. The process can be summarized as follows:


![bending process](img/bending.png "Bending process")
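One way to picture the activation half of this process, as a sketch under our own assumptions (a plain list of nodes instead of a `torch.fx.Graph`, and a hypothetical `insert_callback_node` helper rather than any `torchbend` method): a callback node is inserted right after the target activation, and its consumers are rewired to read the callback's output, while the original graph is left untouched.

```python
def insert_callback_node(graph, target, callback_name):
    """Return a new node list where `callback_name` is applied to `target`'s output.

    graph: list of (name, op, args) tuples in execution order; args are node names.
    The input list is not modified.
    """
    bent_name = f"{target}_bended"
    new_graph = []
    for name, op, args in graph:
        # rewire every consumer of the target to the bended node
        args = [bent_name if a == target else a for a in args]
        new_graph.append((name, op, args))
        if name == target:
            new_graph.append((bent_name, callback_name, [target]))
    return new_graph

graph = [
    ("x", "placeholder", []),
    ("conv_modules_0", "call_module", ["x"]),
    ("nnlin", "call_module", ["conv_modules_0"]),
    ("output", "output", ["nnlin"]),
]
for node in insert_callback_node(graph, "nnlin", "mask"):
    print(node)
```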

Let's now compare, item by item, the effect of the bending process using several inner methods of `BendedModule`:

In [7]:
cb = tb.Mask(prob=0.4)
bended.bend(cb, "conv_modules.0.weight", "nnlin")

original_state_dict = bended.state_dict()
bended_state_dict = bended.bended_state_dict()

print("-- original state dict :\n", original_state_dict['conv_modules.0.weight'])
print("-- bended state dict: \n", bended_state_dict['conv_modules.0.weight'])

original_graph = bended.graph()
bended_graph = bended.bend_graph()

print("-- original graph :\n", original_graph)
print("-- bended graph: \n", bended_graph)

x = torch.randn(4, 1, 128)
original_activation = bended.get_activations("nnlin", x=x, bended=False)
bended_activation = bended.get_activations("nnlin", x=x)
print(original_activation['nnlin'].shape)
print(bended_activation['nnlin'].shape)
print("-- original activation :\n", original_activation['nnlin'][0, 0])
print("-- bended activation: \n", bended_activation['nnlin'][0, 0])

-- original state dict :
 tensor([[[-0.1613, -0.1353,  0.4003]],

        [[ 0.3638, -0.3204, -0.1862]],

        [[ 0.3850,  0.4367,  0.3142]],

        [[ 0.1633,  0.5569,  0.3013]]])
-- bended state dict: 
 tensor([[[-0.1613, -0.0000,  0.4003]],

        [[ 0.0000, -0.0000, -0.0000]],

        [[ 0.3850,  0.0000,  0.0000]],

        [[ 0.0000,  0.5569,  0.0000]]])
-- original graph :
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv_modules_0 : [num_users=1] = call_module[target=conv_modules.0](args = (%x,), kwargs = {})
    %conv_modules_1 : [num_users=1] = call_module[target=conv_modules.1](args = (%conv_modules_0,), kwargs = {})
    %batch_norm : [num_users=1] = call_module[target=batch_norm](args = (%conv_modules_1,), kwargs = {})
    %nnlin : [num_users=1] = call_module[target=nnlin](args = (%batch_norm,), kwargs = {})
    return nnlin
-- bended graph: 
 graph():
    %x : [num_users=1] = placeholder[target=x]
    %conv_modules_0 : [num_users=1] = call_module[ta

### Common issues with activation bending and shapes

Let's now look at a critical case of activation bending: shape handling. Let's change the shape of the input and apply our bending operations again:

In [8]:
x = torch.randn(1, 1, 64)
try: 
    out = bended.forward(x)
except RuntimeError as e:
    print(e)

The size of tensor a (60) must match the size of tensor b (124) at non-singleton dimension 2


This happens because, during tracing, the activation `nnlin` was recorded with shape `torch.Size([4, 8, 124])`, so the `Mask` callback was initialized with a mask of that shape. With an input of length 64, the two unpadded kernel-size-3 convolutions produce an `nnlin` activation of length 60, which cannot be multiplied with the 124-long mask, hence the `RuntimeError`. To make this bending independent of the last dimension, we can mask only the channel dimension:

In [9]:
bended.reset()
bended.bend(tb.Mask(prob=0.3, dim=-2), "nnlin")

x = torch.randn(1, 1, 64)
out = bended.forward(x)
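The fix works because of broadcasting: PyTorch aligns shapes from the last dimension, and each pair of sizes must be equal or contain a 1. The sketch below implements that rule in plain Python; the `(1, 8, 1)` shape we give to a channel-only mask is our assumption about what `dim=-2` produces, not taken from `torchbend`'s code.

```python
def broadcast_shape(a, b):
    """Result shape of broadcasting shapes a and b under NumPy/PyTorch rules.

    Dimensions are aligned from the right; each pair must be equal or contain a 1.
    Raises ValueError when the shapes are incompatible.
    """
    result = []
    for i in range(1, max(len(a), len(b)) + 1):
        da = a[-i] if i <= len(a) else 1
        db = b[-i] if i <= len(b) else 1
        if da != db and 1 not in (da, db):
            raise ValueError(f"size {da} does not match size {db} at dimension -{i}")
        result.append(da if db == 1 else db)
    return tuple(reversed(result))

activation = (1, 8, 60)     # nnlin for an input of shape (1, 1, 64)
full_mask = (4, 8, 124)     # mask shaped like the traced activation
channel_mask = (1, 8, 1)    # assumed shape of a mask restricted to dim=-2

try:
    broadcast_shape(activation, full_mask)
except ValueError as e:
    print("full mask:", e)
print("channel mask:", broadcast_shape(activation, channel_mask))
```

Note that the batch sizes (4 vs 1) would broadcast without complaint; only the last dimension fails, which matches the dimension reported in the error above.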

Activation bending is made possible by our improved tracer, which records the shapes of activations during graph tracing. In exchange, you need to be careful about how you bend the graph, so that no incompatible bending operation is applied at execution time. For more information on this, jump to the next tutorial!

## Monitoring bending operations with `BendingConfig`

The bending operations of a `BendedModule` can be captured as a `BendingConfig` object, which can in turn be used to apply, pickle, and monitor bending operations.

In [10]:
module = Stanley()
bended = tb.BendedModule(module)
bended.trace(x=torch.randn(1, 1, 1024))

bended.bend(tb.Mask(prob=0.4), "conv_modules.0.weight", "nnlin")
bended.bend(tb.Mask(prob=0.6), "conv_modules.0.weight", "conv_modules.1.weight")

bending_config = bended.bending_config();
print("bending config : ", bending_config)
key = "conv_modules.0.weight"
print(f"operations for key {key}:", bending_config.op_from_key(key))

bending_config.save('test.tb')
bending_config = tb.BendingConfig.load('test.tb', module=bended)
print("loaded config : ", bending_config)

BendingConfig(
module = Stanley(id = 5501192016)
	Mask(prob=0.400): ['conv_modules.0.weight', 'nnlin']
	Mask(prob=0.600): ['conv_modules.0.weight', 'conv_modules.1.weight']
)
bending config :  BendingConfig(
module = Stanley(id = 5501192016)
	Mask(prob=0.400): ['conv_modules.0.weight', 'nnlin']
	Mask(prob=0.600): ['conv_modules.0.weight', 'conv_modules.1.weight']
)
operations for key conv_modules.0.weight: [Mask(prob=0.400), Mask(prob=0.600)]
loaded config :  BendingConfig(
module = Stanley(id = 5501192016)
	Mask(prob=0.400): ['conv_modules.0.weight', 'nnlin']
	Mask(prob=0.600): ['conv_modules.0.weight', 'conv_modules.1.weight']
)
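The printed config reads as an ordered list of (operation, targets) pairs. A minimal plain-Python sketch of such a structure, using hypothetical `ToyMask` and `ToyConfig` classes rather than `torchbend`'s own, shows how a query like `op_from_key` can be answered and why the whole config can be pickled. It uses the standard `pickle` module; we make no claim that `BendingConfig.save` serializes the same way.

```python
import pickle
from dataclasses import dataclass, field

@dataclass(frozen=True)
class ToyMask:
    """Hypothetical stand-in for a bending callback."""
    prob: float

@dataclass
class ToyConfig:
    """Sketch of a bending config: an ordered list of (callback, target keys)."""
    ops: list = field(default_factory=list)

    def add(self, callback, *keys):
        self.ops.append((callback, list(keys)))

    def op_from_key(self, key):
        """All callbacks applied to `key`, in the order they were registered."""
        return [cb for cb, keys in self.ops if key in keys]

config = ToyConfig()
config.add(ToyMask(0.4), "conv_modules.0.weight", "nnlin")
config.add(ToyMask(0.6), "conv_modules.0.weight", "conv_modules.1.weight")
print(config.op_from_key("conv_modules.0.weight"))

# Because the config is plain data, it can be pickled and restored.
restored = pickle.loads(pickle.dumps(config))
print(restored == config)
```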




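A `BendingConfig` may be created from scratch, or bound to a given `BendedModule` for automatic key resolution: once bound, regular-expression keys such as `conv_modules.\d.weight` are expanded into the matching weight and activation names. Configs can also be added together and compared: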

In [11]:
cb1 = tb.Mask(0.8)
# keys may be regular expressions; raw strings avoid invalid-escape warnings for \d
config1 = tb.BendingConfig(
    (cb1, r"conv_modules.\d.weight"),
    (cb1, "nnlin")
)
config2 = tb.BendingConfig((cb1, r"conv_modules.\d.weight")) + tb.BendingConfig((cb1, "nnlin"))

print('Before binding : ')
print("config1 :", config1)
print("config2 :", config2)
print(config1 == config2)

config1.bind(bended)
config2.bind(bended)

print('\nAfter binding : ')
print("config1 :", config1)
print("config2 :", config2)

bended.reset()
bended.bend(config1)
module_config = bended.bending_config()
print('\nComparison : ')
print(config1 == module_config)

Before binding : 
config1 : BendingConfig(
	(Mask(prob=0.800), 'conv_modules.\\d.weight')
	(Mask(prob=0.800), 'nnlin')
)
config2 : BendingConfig(
	(Mask(prob=0.800), 'conv_modules.\\d.weight')
	(Mask(prob=0.800), 'nnlin')
)
True

After binding : 
config1 : BendingConfig(
module = Stanley(id = 5501192016)
	Mask(prob=0.800): ['conv_modules.0.weight', 'conv_modules.1.weight', 'nnlin']
)
config2 : BendingConfig(
module = Stanley(id = 5501192016)
	Mask(prob=0.800): ['conv_modules.0.weight', 'conv_modules.1.weight', 'nnlin']
)
BendingConfig(
module = Stanley(id = 5501192016)
	Mask(prob=0.800): ['conv_modules.0.weight', 'conv_modules.1.weight', 'nnlin']
)

Comparison : 
True
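Binding can be thought of as expanding each key pattern into the concrete names known to the module and grouping the targets by callback. The sketch below assumes keys are matched with `re.fullmatch`, represents configs as plain lists of `(callback, pattern)` pairs, and models config addition as list concatenation; `bind` and the string callbacks are hypothetical stand-ins for `torchbend`'s objects.

```python
import re

def bind(ops, names):
    """Resolve regex key patterns against concrete names, grouping keys by callback.

    ops: list of (callback, pattern) pairs; names: weight and activation names.
    Returns {callback: [matching names]}, preserving the order of `names`.
    """
    resolved = {}
    for callback, pattern in ops:
        matches = [n for n in names if re.fullmatch(pattern, n)]
        targets = resolved.setdefault(callback, [])
        targets.extend(m for m in matches if m not in targets)
    return resolved

names = ["conv_modules.0.weight", "conv_modules.0.bias",
         "conv_modules.1.weight", "conv_modules.1.bias",
         "batch_norm.weight", "batch_norm.bias", "x", "conv_modules_0", "nnlin"]
config1 = [("mask_0.8", r"conv_modules.\d.weight"), ("mask_0.8", "nnlin")]
config2 = [("mask_0.8", r"conv_modules.\d.weight")] + [("mask_0.8", "nnlin")]

print(config1 == config2)   # addition modeled as concatenation of operations
print(bind(config1, names))
```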


## Versions and interpolation

`BendedModule` also lets you create several versions of the original module and smoothly interpolate between them. Beware, though, that this only works with weight bending, as smooth interpolation between graphs does not really make sense!

In [12]:
module = Stanley()
bended = tb.BendedModule(module)

x = torch.randn(1, 1, 128)
out_unbended = bended(x)

print("default version : ", bended.version)
bended.bend(tb.Mask(prob=0.3), "conv_modules.0.weight")
bended.write("bended")
print("current version : ", bended.version)

# revert to default
bended.version = None
print('current version : ', bended.version)

with bended.set_version():
    out_original = bended(x)
with bended.set_version("bended"):
    out_bended = bended(x)

# bended.interpolate takes an optional positional argument giving the weight
# of the default version, plus one keyword argument per additional version
# giving that version's weight.
with bended.interpolate(1., bended=1.):
    out_interpolated = bended(x)

print("original == bended : ", (out_original == out_bended).all())
print("original == interpolated : ", (out_original == out_interpolated).all())
print("bended == interpolated : ", (out_bended == out_interpolated).all())

default version :  _default
current version :  bended
current version :  _default
original == bended :  tensor(False)
original == interpolated :  tensor(False)
bended == interpolated :  tensor(False)
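The interpolation above can be pictured as a weighted average of the versions' state dicts. The sketch below assumes the weights are normalized by their sum, which is consistent with `interpolate(1., bended=1.)` differing from both versions but is our assumption, not taken from `torchbend`'s code; the `interpolate_state_dicts` helper is hypothetical and uses lists of floats in place of tensors.

```python
def interpolate_state_dicts(weights, state_dicts):
    """Weighted average of state dicts with identical keys.

    weights: {version_name: weight}; state_dicts: {version_name: {param: [floats]}}.
    Weights are normalized by their sum (assumed convention).
    """
    total = sum(weights.values())
    if total == 0:
        raise ValueError("interpolation weights must not sum to zero")
    keys = next(iter(state_dicts.values())).keys()
    out = {}
    for key in keys:
        # one tuple per element position, with one value per weighted version
        columns = zip(*(state_dicts[v][key] for v in weights))
        out[key] = [sum(weights[v] * x for v, x in zip(weights, col)) / total
                    for col in columns]
    return out

versions = {"_default": {"w": [1.0, -2.0]}, "bended": {"w": [0.0, -2.0]}}
print(interpolate_state_dicts({"_default": 1.0, "bended": 1.0}, versions))
```

With equal weights, a masked (zeroed) element ends up halfway between its original value and zero, so the result differs from both versions.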


You can also import another module as a new version with `create_version`, provided that their state dicts have the same structure, and interpolate between versions with the `interpolate` context manager.

In [13]:
module = Stanley()
module2 = Stanley()
bended = tb.BendedModule(module)
bended.create_version("imported", module2)

x = torch.randn(1, 1, 128)
with bended.set_version():
    out_original = bended(x)
with bended.set_version("imported"):
    out_imported = bended(x)
with bended.interpolate(1., imported=1.):
    out_interpolated = bended(x)

print("original == imported : ", (out_original == out_imported).all())
print("original == interpolated : ", (out_original == out_interpolated).all())
print("imported == interpolated : ", (out_imported == out_interpolated).all())

class Doppleganger(nn.Module):
    pass

try: 
    bended.create_version("imported", Doppleganger())
except tb.BendingError as e: 
    print(e)

original == imported :  tensor(False)
original == interpolated :  tensor(False)
imported == interpolated :  tensor(False)
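What makes two modules compatible? A plausible check, sketched here under the assumption that compatibility means identical parameter names and shapes (our reading, not `torchbend`'s documented rule; `check_compatible` and `IncompatibleVersionError` are hypothetical), compares the structure of both state dicts before a version is created. A parameter-free module such as `Doppleganger` fails it. The shapes below are the `Stanley` parameter shapes listed in the weight table at the start of this tutorial.

```python
class IncompatibleVersionError(ValueError):
    """Hypothetical error type for this sketch."""

def check_compatible(reference, candidate):
    """Raise if two state dicts ({name: shape tuple}) differ in keys or shapes."""
    missing = sorted(set(reference) - set(candidate))
    unexpected = sorted(set(candidate) - set(reference))
    if missing or unexpected:
        raise IncompatibleVersionError(f"missing keys {missing}, unexpected keys {unexpected}")
    for name, shape in reference.items():
        if candidate[name] != shape:
            raise IncompatibleVersionError(f"shape mismatch for {name}: {shape} vs {candidate[name]}")

stanley_params = {
    "conv_modules.0.weight": (4, 1, 3), "conv_modules.0.bias": (4,),
    "conv_modules.1.weight": (8, 4, 3), "conv_modules.1.bias": (8,),
    "batch_norm.weight": (8,), "batch_norm.bias": (8,),
}
check_compatible(stanley_params, dict(stanley_params))  # same structure: passes
try:
    check_compatible(stanley_params, {})                 # parameter-free module
except IncompatibleVersionError as e:
    print(e)
```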
