# Defining callbacks


We will see in this tutorial how to define a bending callaback to perform custom callback operations. 

## Defining stateless callbacks

To start easy, let's define a callback that does not need any internal buffer. A bending callback must define two functions : 
- the `forward` function, that is used for activation bending
- the `apply_to_param` function, that is used for weight bending


In [3]:
import torch, torch.nn as nn
from typing import Optional
import sys; sys.path.append("..")
import torchbend as tb

class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module_1 = nn.Conv1d(1, 4, 3)
        self.conv_module_2 = nn.Conv1d(4, 1, 3)
    def forward(self, x):
        out_1 = self.conv_module_1(x)
        out_2 = self.conv_module_2(out_1)
        return out_2
    
class Square(tb.BendingCallback):
    # define class attributes to provide meta information on what the callback can do
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        if cache is not None:
            param.set_(cache * cache)

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * x


module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

bended.bend(Square(), "conv_module_1$", "conv_module.1.weight")
out = bended(x)
print(out)

scripted = bended.script()

tensor([[[ 0.0477, -0.3881,  0.8609, -0.2017,  0.3453, -0.2270,  0.1903,
           0.3508, -0.3234,  0.2075,  0.2420, -0.5515,  0.1830, -0.2658,
           0.1675,  0.1738,  0.1183,  0.0775, -0.0210,  0.0071,  0.1731,
          -0.1167, -0.3698,  0.5692,  0.1424, -0.1753,  0.0266,  0.2850]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[-0.0461,  0.0040, -0.0027,  0.0541,  0.0488,  0.0567,  0.0513,
           0.0700,  0.0420,  0.0533,  0.0543,  0.0456, -0.0016, -0.0268,
          -0.0409, -0.0141,  0.0076,  0.0398,  0.0332, -0.0019,  0.0298,
          -0.0193, -0.0367, -0.1067,  0.0098, -0.0542,  0.0266, -0.0605]]],
       grad_fn=<ConvolutionBackward0>)
caca
caca
caca
caca


RuntimeError: 
Arguments for call are not valid.
The following variants are available:
  
  aten::mul.Tensor(Tensor self, Tensor other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.Scalar(Tensor self, Scalar other) -> Tensor:
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.out(Tensor self, Tensor other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.Scalar_out(Tensor self, Scalar other, *, Tensor(a!) out) -> Tensor(a!):
  Expected a value of type 'Tensor' for argument 'self' but instead found type 'Optional[Tensor]'.
  
  aten::mul.left_t(t[] l, int n) -> t[]:
  Could not match type Optional[Tensor] to List[t] in argument 'l': Cannot match List[t] to Optional[Tensor].
  
  aten::mul.right_(int n, t[] l) -> t[]:
  Expected a value of type 'int' for argument 'n' but instead found type 'Optional[Tensor]'.
  
  aten::mul.int(int a, int b) -> int:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.complex(complex a, complex b) -> complex:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.float(float a, float b) -> float:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.int_complex(int a, complex b) -> complex:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.complex_int(complex a, int b) -> complex:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.float_complex(float a, complex b) -> complex:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.complex_float(complex a, float b) -> complex:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.int_float(int a, float b) -> float:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul.float_int(float a, int b) -> float:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  aten::mul(Scalar a, Scalar b) -> Scalar:
  Expected a value of type 'number' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  mul(float a, Tensor b) -> Tensor:
  Expected a value of type 'float' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  mul(int a, Tensor b) -> Tensor:
  Expected a value of type 'int' for argument 'a' but instead found type 'Optional[Tensor]'.
  
  mul(complex a, Tensor b) -> Tensor:
  Expected a value of type 'complex' for argument 'a' but instead found type 'Optional[Tensor]'.

The original call is:
  File "/var/folders/vk/nn1706pd25b57gxz9y3p1wqh0000gn/T/ipykernel_45941/43206288.py", line 31
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        param.set_(cache * cache)
                   ~~~~~~~~~~~~~ <--- HERE
'Square.apply_to_param' is being compiled since it was called from 'Square.apply'
  File "/Users/domkirke/Dropbox/code/torchbend/docs/../torchbend/bending/base.py", line 168
        for i, v in enumerate(self._bending_targets):
            v_cached = self.cache_from_id(i).data
            self.apply_to_param(i, v, v_cached)
            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
'Square.apply' is being compiled since it was called from 'ScriptedBendedModule._update_weights'
  File "/Users/domkirke/Dropbox/code/torchbend/docs/../torchbend/tracing/script.py", line 150
            for i, c in enumerate(self._bending_callbacks):
                for j in callbacks:
                    if i == j: c.apply()
                               ~~~~~~~ <--- HERE
'ScriptedBendedModule._update_weights' is being compiled since it was called from 'ScriptedBendedModule._set_bending_control'
  File "/Users/domkirke/Dropbox/code/torchbend/docs/../torchbend/tracing/script.py", line 167
            if v.name == name:
                v.set_value(value)
        self._update_weights(name)
        ~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
        return 0


## Stateful callbacks 

Stateful callbacks may be a little trickier, nothing heavy do not worry but some steps can be cumbersome due to torch.jit. Stateful callbacks implies overriding some more methods :
- `register_parameter` : register a bended parameter in the callback module
- `register_activation` : register a bended activation shape in the callback module
- `update`, that updates inner states after a `BendingParameter` change

In [2]:
class UpClamp(tb.BendingCallback):
    weight_compatible = True 
    activation_compatible = True
    jit_compatible = True 
    nntilde_compatible = True 

    # provide this to inform BendingCallback that this init arg can be contorlled
    controllable_params = ['threshold']

    def __init__(self, threshold = 1.):
        super().__init__()
        self.threshold = threshold
        self._masks = torch.nn.ParameterList()
        self._mask_names = []

    def _add_mask(self, name, parameter):
        mask = nn.Parameter((parameter.data > self.get('threshold')).float(), requires_grad=False)
        self._masks.append(mask)
        # disable gradient
        self._mask_names.append(name)

    def register_parameter(self, parameter, name=None, cache = True):
        name = super().register_parameter(parameter, name=name, cache=cache)
        self._add_mask(name, parameter=parameter)

    def register_activation(self, name, shape):
        name = super().register_activation(name, shape)
        # here we don't need to do anything, as only parameter updates require states
        return name
    
    def get_mask_from_index(self, idx: int):
        # Don't judge me, this is because torch.jit only allows literal indexing. 
        for i, v in enumerate(self._masks):
            if i == idx:
                return v
        raise tb.BendingCallbackException('%d not present in masks'%idx)
    
    def update(self):
        for i, v in enumerate(self._masks):
                v.data.set_((self.get_cache(i) > self.get('threshold')).float())

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        for weight bending, apply_to_param receives three arguments : 
        - idx of recorded param (int) : bended parameters are recorded as lists to be jit-compatible, such that providing the id is useful to recover 
        paramter-wise buffers
        - param : the module parameter, modified in place
        - cache : (optional): callbacks cache the original parameter value for dynamical modification
        """
        assert cache is not None
        param.set_(cache * self.get_mask_from_index(idx))

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        for activation bending, apply_to_param receives two arguments : 
        - x : value of activation
        - name : the activation name, that can be used to retrieve activation-wise buffers
        """
        return x * (x < self.get('threshold')).float()

module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

c1 = tb.BendingParameter('threshold', value=1.)
bended.bend(UpClamp(threshold=c1), "conv_module_2$", "conv_module_1.weight")
out = bended(x)
print(out)

bended.update(c1.name, 0.)
out = bended(x)
print(out)

scripted = bended.script()

tensor([[[ 0.3871,  0.3391, -0.0650,  0.0677,  0.4748, -0.0197, -0.0200,
           0.4790,  0.6838, -0.0125, -0.1180,  0.0709,  0.4235,  0.3217,
           0.3106,  0.1544,  0.1873,  0.1464, -0.4558,  0.0875,  0.7703,
           0.3424, -0.0955,  0.1810,  0.6059,  0.0449,  0.0875,  0.5416]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[ 0.3871,  0.3391, -0.0650,  0.0677,  0.4748, -0.0197, -0.0200,
           0.4790,  0.6838, -0.0125, -0.1180,  0.0709,  0.4235,  0.3217,
           0.3106,  0.1544,  0.1873,  0.1464, -0.4558,  0.0875,  0.7703,
           0.3424, -0.0955,  0.1810,  0.6059,  0.0449,  0.0875,  0.5416]]],
       grad_fn=<MulBackward0>)
tensor([[[ 0.0000,  0.0000, -0.1217, -0.0214,  0.0000,  0.0000, -0.1343,
           0.0000,  0.0000,  0.0000, -0.3166, -0.0575,  0.0000,  0.0000,
           0.0000,  0.0000,  0.0000,  0.0000, -0.6491, -0.0991,  0.0000,
           0.0000, -0.1503,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000]]],
       grad_fn=<MulBackward0>)
