# Defining callbacks


We will see in this tutorial how to define a bending callaback to perform custom callback operations. 

## Defining stateless callbacks

To start easy, let's define a callback that does not need any internal buffer. A bending callback must define two functions : 
- the `bend_input` function, that is used for weight and activation bending
- the `apply_to_param` function, that is used for weight bending after JIT compilation 


In [1]:
import torch, torch.nn as nn
from typing import Optional
import sys; sys.path.append("..")
import torchbend as tb

class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module_1 = nn.Conv1d(1, 4, 3)
        self.conv_module_2 = nn.Conv1d(4, 1, 3)
    def forward(self, x):
        out_1 = self.conv_module_1(x)
        out_2 = self.conv_module_2(out_1)
        return out_2
    
class Square(tb.BendingCallback):
    # define class attributes to provide meta information on what the callback can do
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        if cache is not None:
            param.set_(cache * cache)

    def bend_input(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * x


module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

bended.bend(Square(), "conv_module_1$", "conv_module.1.weight")
out = bended(x)
print(out)

scripted = bended.script()

tensor([[[ 0.0384, -0.4096, -0.3645, -0.8078, -0.6729, -0.7594, -0.5992,
           0.0806, -0.8036, -0.4242, -0.3870, -0.6751, -0.2078, -0.5525,
          -0.2823, -0.5777, -0.4736, -0.0823, -0.7809, -0.8411, -0.2782,
          -0.2959, -0.3533, -0.1839, -0.0553, -0.2795, -0.5492, -0.1855]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[-0.2223, -0.3371, -0.3487, -0.2570, -0.2490, -0.2163, -0.2379,
          -0.3109, -0.2827, -0.2512, -0.2077, -0.2559, -0.2952, -0.2522,
          -0.2119, -0.1845, -0.2836, -0.2968, -0.2521, -0.1654, -0.1012,
          -0.0396, -0.0348, -0.0646, -0.1804, -0.2581, -0.2929, -0.2695]]],
       grad_fn=<ConvolutionBackward0>)


## Stateful callbacks 

Stateful callbacks may be a little trickier, nothing heavy do not worry but some steps can be cumbersome due to torch.jit. Stateful callbacks implies overriding some more methods :
- `register_parameter` : register a bended parameter in the callback module
- `register_activation` : register a bended activation shape in the callback module
- `update`, that updates inner states after a `BendingParameter` change

In [2]:
class UpClamp(tb.BendingCallback):
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    # provide this to inform BendingCallback that this init arg can be contorlled
    controllable_params = ['threshold']

    def __init__(self, threshold = 1.):
        super().__init__()
        self.threshold = threshold
        self._masks = torch.nn.ParameterList()
        self._mask_names = []

    def _add_mask(self, name, parameter):
        mask = nn.Parameter((parameter.data > self.get('threshold')).float(), requires_grad=False)
        self._masks.append(mask)
        # disable gradient
        self._mask_names.append(name)

    def register_parameter(self, parameter, name=None, cache = True):
        name = super().register_parameter(parameter, name=name, cache=cache)
        self._add_mask(name, parameter=parameter)

    def register_activation(self, name, shape):
        name = super().register_activation(name, shape)
        # here we don't need to do anything, as only parameter updates require states
        return name
    
    def get_mask_from_index(self, idx: int):
        # Don't judge me, this is because torch.jit only allows literal indexing. 
        for i, v in enumerate(self._masks):
            if i == idx:
                return v
        raise tb.BendingCallbackException('%d not present in masks'%idx)
    
    def update(self):
        for i, v in enumerate(self._masks):
                v.data.set_((self.get_cache(i) > self.get('threshold')).float())

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        assert cache is not None
        param.set_(cache * self.get_mask_from_index(idx))

    def bend_input(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * (x < self.get('threshold')).float()

module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

c1 = tb.BendingParameter('threshold', value=1.)
bended.bend(UpClamp(threshold=c1), "conv_module_2$", "conv_module_1.weight")
out = bended(x)
print(out)

bended.update(c1.name, 0.)
out = bended(x)
print(out)

scripted = bended.script()

tensor([[[ 0.5531,  0.2135,  0.1794, -0.1940, -0.2208,  0.0698,  0.2735,
           0.1670,  0.3737,  0.2017,  0.4460,  0.3786,  0.5029,  0.3215,
           0.3719,  0.0941,  0.2514, -0.1249, -0.0867, -0.2876, -0.3940,
          -0.2336,  0.3189,  0.3808,  0.2921, -0.1922, -0.4330, -0.8138]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[ 0.5531,  0.2135,  0.1794, -0.1940, -0.2208,  0.0698,  0.2735,
           0.1670,  0.3737,  0.2017,  0.4460,  0.3786,  0.5029,  0.3215,
           0.3719,  0.0941,  0.2514, -0.1249, -0.0867, -0.2876, -0.3940,
          -0.2336,  0.3189,  0.3808,  0.2921, -0.1922, -0.4330, -0.8138]]],
       grad_fn=<MulBackward0>)
tensor([[[ 0.0000,  0.0000,  0.0000, -0.3109, -0.2443,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000, -0.2082, -0.2309, -0.4662, -0.4808,
          -0.2194,  0.0000,  0.0000,  0.0000, -0.4393, -0.6785, -0.9573]]],
       grad_fn=<MulBackward0>)


## Capturing

Callbacks can also capture a set of entries, and perform some operations for inference. The most basic callback implementing that is the `tb.Capture` callback, that be can be used as a base class for every class implementing capturing : 

In [3]:
module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)

capture = tb.Capture()
bended.bend(capture, 'conv_module_1$')

capture.capture()
for i in range(4):
    bended(torch.randn(1, 1, 32))
capture.stop()

print("captured activation shape : ", capture.captures['conv_module_1'].shape)


captured activation shape :  torch.Size([4, 4, 30])


In [4]:
module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)

c1 = tb.Capture()
c2 = tb.Capture()
bended.bend(c1, 'conv_module_1$')
bended.bend(c2, 'conv_module_2$')

# you can put arguments inside capture to specify some precise callbacks
with bended.capture():
    for i in range(4):
        bended(torch.randn(1, 1, 32))

# you can put arguments inside capture to specify some precise callbacks
with bended.capture(c2):
    for i in range(4):
        bended(torch.randn(1, 1, 32))


print("captured activation shape : ", c1.captures['conv_module_1'].shape)
print("captured activation shape : ", c2.captures['conv_module_2'].shape)


captured activation shape :  torch.Size([4, 4, 30])
captured activation shape :  torch.Size([8, 1, 28])


You can overload this `tb.Capture` to implement various callbacks based on recorded activations. For example, here we will remove half of the channels that have the lowest amplitude :

In [6]:
class RemoveLowestHalf(tb.Capture):
    def __init__(self):
        super().__init__()
        self._not_ready_str = "Please capture some activations before using this callback in inference mode"
        self.init_masks()

    def init_masks(self):
        self.masks = nn.ParameterDict()
    
    def compute_masks(self):
        assert self._is_initialized
        self.masks = nn.ParameterDict()
        for k, v in self._captures.items():
            energies = v.mean(dim=(0, 2)).abs()
            threshold = torch.median(energies)
            self.masks[k] = (energies > threshold)[None, :, None].float()

    def stop(self):
        """overload the stop method to preprocess the captured elements"""
        super().stop()
        self.compute_masks()

    def forward(self, x, name):
        if name is None: raise Exception()
        return x * self.get_mask_from_name(name)


module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)

c1 = RemoveLowestHalf()
bended.bend(c1, 'conv_module_1$')

with bended.capture():
    for i in range(4):
        bended(torch.randn(1, 1, 32))

RuntimeError: torch.cat(): expected a non-empty list of Tensors