# Defining callbacks


We will see in this tutorial how to define a bending callaback to perform custom callback operations. 

## Defining stateless callbacks

To start easy, let's define a callback that does not need any internal buffer. A bending callback must define two functions : 
- the `forward` function, that is used for activation bending
- the `apply_to_param` function, that is used for weight bending


In [2]:
import torch, torch.nn as nn
from typing import Optional
import sys; sys.path.append("..")
import torchbend as tb

class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module_1 = nn.Conv1d(1, 4, 3)
        self.conv_module_2 = nn.Conv1d(4, 1, 3)
    def forward(self, x):
        out_1 = self.conv_module_1(x)
        out_2 = self.conv_module_2(out_1)
        return out_2
    
class Square(tb.BendingCallback):
    # define class attributes to provide meta information on what the callback can do
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        param.set_(cache * cache)

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * x


module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

bended.bend(Square(), "conv_module_1$", "conv_module.1.weight")
out = bended(x)
print(out)



tensor([[[-0.3480, -0.0668, -1.1553,  0.6916, -1.0866, -0.2815,  0.0194,
          -0.7783, -0.3196, -0.6542,  0.0686, -0.9633, -0.2639, -0.4900,
          -0.8211,  0.0595, -0.4032, -0.9059,  0.1834, -1.1281,  0.1439,
          -0.4905, -0.7331, -0.4273,  0.0104, -0.9659,  0.2734, -1.0282]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[-0.3753, -0.4296, -0.2166, -0.5104, -0.1630, -0.2979, -0.3271,
          -0.1033, -0.2116, -0.1745,  0.0533, -0.1424, -0.1964, -0.1858,
          -0.2004, -0.2497, -0.2142, -0.1430, -0.2445, -0.1591, -0.1478,
          -0.2582, -0.1512, -0.3252, -0.3273, -0.1801, -0.3254, -0.1487]]],
       grad_fn=<ConvolutionBackward0>)


## Stateful callbacks 

Stateful callbacks may be a little trickier, nothing heavy do not worry but some steps can be cumbersome due to torch.jit. Stateful callbacks implies overriding some more methods :
- `register_parameter` : register a bended parameter in the callback module
- `register_activation` : register a bended activation shape in the callback module
- `update`, that updates inner states after a `BendingParameter` change

In [12]:
class UpClamp(tb.BendingCallback):
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    # provide this to inform BendingCallback that this init arg can be contorlled
    controllable_params = ['threshold']

    def __init__(self, threshold = 1.):
        super().__init__()
        self.threshold = threshold
        self._masks = torch.nn.ParameterList()
        self._mask_names = []

    def _init_mask(self, parameter):
        return nn.Parameter((parameter.data > self.get('threshold')).float().requires_grad_(False))

    def _add_mask(self, name, parameter):
        self._masks.append(self._init_mask(parameter))
        # disable gradient
        self._mask_names.append(name)

    def register_parameter(self, parameter, name=None, cache = True):
        name = super().register_parameter(parameter, name=name, cache=cache)
        self._add_mask(name, parameter=parameter)

    def register_activation(self, name, shape):
        name = super().register_activation(name, shape)
        # here we don't need to do anything, as only parameter updates require states
        return name
    
    def get_mask_from_index(self, idx: int):
        # Don't judge me, this is because torch.jit only allows literal indexing. 
        for i, v in enumerate(self._masks):
            if i == idx:
                return v
        raise tb.BendingCallbackException('%d not present in masks'%idx)
    
    def update(self):
        with torch.set_grad_enabled(False):
            for i, v in enumerate(self._masks):
                v.set_(self._init_mask(self.get_cache(i)))

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        assert cache is not None
        param.set_(cache * self.get_mask_from_index(idx))

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * (x < self.get('threshold')).float()



module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

c1 = tb.BendingParameter('threshold', value=1.)
bended.bend(UpClamp(threshold=c1), "conv_module_2$", "conv_module_1.weight")
out = bended(x)
print(out)

bended.update(c1.name, 0.)
out = bended(x)
print(out)

tensor([[[-0.2619,  0.1348,  0.3441,  0.4760,  0.6323,  0.5979,  0.7836,
           0.3744, -0.0860, -0.3237,  0.1679,  0.4835,  0.9399,  0.3376,
           0.2188,  0.0457,  0.4031,  0.6875,  0.5118,  0.2731,  0.0115,
           0.2652,  0.1423,  0.2137,  0.0537,  0.2528,  0.1761,  0.4793]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[-0.2619,  0.1348,  0.3441,  0.4760,  0.6323,  0.5979,  0.7836,
           0.3744, -0.0860, -0.3237,  0.1679,  0.4835,  0.9399,  0.3376,
           0.2188,  0.0457,  0.4031,  0.6875,  0.5118,  0.2731,  0.0115,
           0.2652,  0.1423,  0.2137,  0.0537,  0.2528,  0.1761,  0.4793]]],
       grad_fn=<MulBackward0>)
tensor([[[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
          0., 0., 0., 0., 0.]]], grad_fn=<MulBackward0>)
