# Defining callbacks


We will see in this tutorial how to define a bending callaback to perform custom callback operations. 

## Defining stateless callbacks

To start easy, let's define a callback that does not need any internal buffer. A bending callback must define two functions : 
- the `forward` function, that is used for activation bending
- the `apply_to_param` function, that is used for weight bending


In [4]:
import torch, torch.nn as nn
from typing import Optional
import sys; sys.path.append("..")
import torchbend as tb

class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module_1 = nn.Conv1d(1, 4, 3)
        self.conv_module_2 = nn.Conv1d(4, 1, 3)
    def forward(self, x):
        out_1 = self.conv_module_1(x)
        out_2 = self.conv_module_2(out_1)
        return out_2
    
class Square(tb.BendingCallback):
    # define class attributes to provide meta information on what the callback can do
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        if cache is not None:
            param.set_(cache * cache)

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * x


module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

bended.bend(Square(), "conv_module_1$", "conv_module.1.weight")
out = bended(x)
print(out)

scripted = bended.script()

tensor([[[ 0.4191,  0.3484, -0.3344, -0.4367, -0.2221,  0.1707, -0.0082,
          -0.1306, -0.0400, -0.3263, -0.1859, -0.3868, -0.6936, -0.0299,
          -0.4795, -0.2998, -0.3795,  0.4457,  0.4713, -0.0185, -0.1110,
          -0.0366,  0.2109,  0.0369, -0.4674, -0.2995,  0.2409,  0.1765]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[-0.0446, -0.1115, -0.0431, -0.0711, -0.1146, -0.1058, -0.0979,
          -0.1205, -0.1049, -0.1261, -0.1084, -0.1138, -0.1679, -0.1088,
          -0.0636, -0.0761, -0.1908,  0.0334, -0.0152, -0.1071, -0.1122,
          -0.1065, -0.0984, -0.1044, -0.0904, -0.1380, -0.0618, -0.1113]]],
       grad_fn=<ConvolutionBackward0>)


## Stateful callbacks 

Stateful callbacks may be a little trickier, nothing heavy do not worry but some steps can be cumbersome due to torch.jit. Stateful callbacks implies overriding some more methods :
- `register_parameter` : register a bended parameter in the callback module
- `register_activation` : register a bended activation shape in the callback module
- `update`, that updates inner states after a `BendingParameter` change

In [2]:
class UpClamp(tb.BendingCallback):
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    # provide this to inform BendingCallback that this init arg can be contorlled
    controllable_params = ['threshold']

    def __init__(self, threshold = 1.):
        super().__init__()
        self.threshold = threshold
        self._masks = torch.nn.ParameterList()
        self._mask_names = []

    def _add_mask(self, name, parameter):
        mask = nn.Parameter((parameter.data > self.get('threshold')).float(), requires_grad=False)
        self._masks.append(mask)
        # disable gradient
        self._mask_names.append(name)

    def register_parameter(self, parameter, name=None, cache = True):
        name = super().register_parameter(parameter, name=name, cache=cache)
        self._add_mask(name, parameter=parameter)

    def register_activation(self, name, shape):
        name = super().register_activation(name, shape)
        # here we don't need to do anything, as only parameter updates require states
        return name
    
    def get_mask_from_index(self, idx: int):
        # Don't judge me, this is because torch.jit only allows literal indexing. 
        for i, v in enumerate(self._masks):
            if i == idx:
                return v
        raise tb.BendingCallbackException('%d not present in masks'%idx)
    
    def update(self):
        for i, v in enumerate(self._masks):
                v.data.set_((self.get_cache(i) > self.get('threshold')).float())

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        assert cache is not None
        param.set_(cache * self.get_mask_from_index(idx))

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * (x < self.get('threshold')).float()

module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

c1 = tb.BendingParameter('threshold', value=1.)
bended.bend(UpClamp(threshold=c1), "conv_module_2$", "conv_module_1.weight")
out = bended(x)
print(out)

bended.update(c1.name, 0.)
out = bended(x)
print(out)

scripted = bended.script()

tensor([[[ 0.3871,  0.3391, -0.0650,  0.0677,  0.4748, -0.0197, -0.0200,
           0.4790,  0.6838, -0.0125, -0.1180,  0.0709,  0.4235,  0.3217,
           0.3106,  0.1544,  0.1873,  0.1464, -0.4558,  0.0875,  0.7703,
           0.3424, -0.0955,  0.1810,  0.6059,  0.0449,  0.0875,  0.5416]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[ 0.3871,  0.3391, -0.0650,  0.0677,  0.4748, -0.0197, -0.0200,
           0.4790,  0.6838, -0.0125, -0.1180,  0.0709,  0.4235,  0.3217,
           0.3106,  0.1544,  0.1873,  0.1464, -0.4558,  0.0875,  0.7703,
           0.3424, -0.0955,  0.1810,  0.6059,  0.0449,  0.0875,  0.5416]]],
       grad_fn=<MulBackward0>)
tensor([[[ 0.0000,  0.0000, -0.1217, -0.0214,  0.0000,  0.0000, -0.1343,
           0.0000,  0.0000,  0.0000, -0.3166, -0.0575,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000, -0.6491, -0.0991,  0.0000,
           0.0000, -0.1503,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<MulBackward0>)
