# Defining callbacks


In this tutorial, we will see how to define a bending callback that performs custom operations.

## Defining stateless callbacks

To start simple, let's define a callback that does not need any internal buffer. A bending callback must define two methods:
- the `forward` function, that is used for activation bending
- the `apply_to_param` function, that is used for weight bending
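To make the two roles concrete, here is a minimal sketch in plain PyTorch of how an engine *might* drive these two methods. `ToySquare` and the driver lines are hypothetical illustrations, not torchbend internals. The sketch shows why weight bending rebuilds the parameter from a cached copy of its original value: applying the callback twice still gives `w**2` rather than `w**4`.

```python
import torch
import torch.nn as nn
from typing import Optional


class ToySquare(nn.Module):
    """Hypothetical stand-in following the same two-method contract (not torchbend code)."""

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]) -> None:
        # weight bending: rebuild the parameter from its cached original value
        assert cache is not None
        param.set_(cache * cache)

    def forward(self, x: torch.Tensor, name: Optional[str] = None) -> torch.Tensor:
        # activation bending: transform the activation flowing through the graph
        return x * x


torch.manual_seed(0)
conv = nn.Conv1d(1, 4, 3)
callback = ToySquare()

# hypothetical driver: keep an untouched copy of the weight as the cache
cache = conv.weight.detach().clone()
with torch.no_grad():
    callback.apply_to_param(0, conv.weight, cache)
    callback.apply_to_param(0, conv.weight, cache)  # re-applying from the cache is idempotent

# activation bending on the layer output
x = torch.randn(1, 1, 32)
activation = callback(conv(x), name="conv")
```

Had `apply_to_param` squared `param` itself instead of `cache`, every re-application would compound the transformation.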


```python
import sys
from typing import Optional

import torch
import torch.nn as nn

sys.path.append("..")  # make the local torchbend checkout importable from this notebook
import torchbend as tb

class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module_1 = nn.Conv1d(1, 4, 3)
        self.conv_module_2 = nn.Conv1d(4, 1, 3)

    def forward(self, x):
        out_1 = self.conv_module_1(x)
        out_2 = self.conv_module_2(out_1)
        return out_2

class Square(tb.BendingCallback):
    # class attributes providing meta information on what the callback supports
    weight_compatible = True
    activation_compatible = True
    jit_compatible = True
    nntilde_compatible = True

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        For weight bending, apply_to_param receives three arguments:
        - idx (int): index of the recorded parameter. Bended parameters are recorded as lists
          to be jit-compatible, so the index is used to retrieve parameter-wise buffers.
        - param: the module parameter, modified in place.
        - cache (optional): the original parameter value, cached by the callback so that
          the parameter can be modified dynamically.
        """
        # refine Optional[Tensor] to Tensor, as TorchScript requires before using it
        assert cache is not None
        param.set_(cache * cache)

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        For activation bending, forward receives two arguments:
        - x: the activation value.
        - name: the activation name, which can be used to retrieve activation-wise buffers.
        """
        return x * x


module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 32)
# trace the module, also returning its original output
_, out = bended.trace(x=x, _return_out=True)
print(out[0])

# square the output activation of conv_module_1 and its weight
bended.bend(Square(), "conv_module_1$", "conv_module_1.weight")
out = bended(x)
print(out)
```

```
tensor([[[ 0.4892,  0.5153,  0.4182,  0.4789,  0.3084,  0.2741,  0.1425,
           0.1976,  0.0509,  0.1303,  0.0505,  0.2202,  0.1965,  0.2649,
           0.2580,  0.4473,  0.2939,  0.2936,  0.1679, -0.0621,  0.3148,
           0.1342,  0.3729,  0.3753,  0.1876,  0.2506, -0.1649, -0.0830]]],
       grad_fn=<ConvolutionBackward0>)
tensor([[[-0.1816, -0.1793, -0.1797, -0.1860, -0.1238, -0.1114, -0.0931,
          -0.0999, -0.0282, -0.0395, -0.0591, -0.0896, -0.0855, -0.1226,
          -0.1002, -0.1346, -0.1343, -0.1034, -0.0558, -0.0400, -0.0925,
          -0.0800, -0.1440, -0.1201, -0.0507, -0.0346, -0.0106, -0.0324]]],
       grad_fn=<ConvolutionBackward0>)
```
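For intuition, the bend above (squaring both the output activation and the weight of `conv_module_1`) can be reproduced in plain PyTorch with an in-place weight update and a forward hook. This is a hypothetical equivalent for checking results, not how torchbend implements bending; it assumes that only the weight (not the bias) of `conv_module_1` is bent, as the parameter pattern suggests.

```python
import copy

import torch
import torch.nn as nn
import torch.nn.functional as F


class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module_1 = nn.Conv1d(1, 4, 3)
        self.conv_module_2 = nn.Conv1d(4, 1, 3)

    def forward(self, x):
        return self.conv_module_2(self.conv_module_1(x))


torch.manual_seed(0)
module = Greg()
x = torch.randn(1, 1, 32)

# work on a copy so the original module stays untouched
bent = copy.deepcopy(module)

# weight bending: square the weight of conv_module_1 in place
with torch.no_grad():
    bent.conv_module_1.weight.copy_(bent.conv_module_1.weight ** 2)

# activation bending: a forward hook that returns a value replaces the module output
handle = bent.conv_module_1.register_forward_hook(lambda mod, inp, out: out * out)
out = bent(x)
handle.remove()

# reference computed by hand from the original module
w1 = module.conv_module_1.weight ** 2
h = F.conv1d(x, w1, module.conv_module_1.bias) ** 2
reference = module.conv_module_2(h)
```

Comparing `out` with `reference` is a quick way to sanity-check what a given bend is expected to compute.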


## Stateful callbacks 

Stateful callbacks are slightly trickier. Nothing heavy, don't worry, but some steps can be cumbersome because of `torch.jit` constraints.

```python
class UpClamp(tb.BendingCallback):
    weight_compatible = True
    activation_compatible = True
    jit_compatible = True
    nntilde_compatible = True
    # inform BendingCallback that this init argument can be controlled
    controllable_params = ['threshold']

    def __init__(self, threshold: float = 1.):
        super().__init__()
        self.threshold = threshold

    def apply_to_param(self, idx: int, param: nn.Parameter, cache: Optional[torch.Tensor]):
        """
        Weight bending: clamp the cached original values from above at `threshold`.
        Arguments are the same as for Square.apply_to_param.
        """
        # refine Optional[Tensor] to Tensor, as TorchScript requires before using it
        assert cache is not None
        param.set_(torch.clamp(cache, max=self.threshold))

    def forward(self, x: torch.Tensor, name: Optional[str] = None):
        """
        Activation bending: clamp the activation from above at `threshold`.
        Arguments are the same as for Square.forward.
        """
        return torch.clamp(x, max=self.threshold)
```
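On the `torch.jit` side, one standard way to keep mutable state script-friendly is to store it in a registered buffer rather than a plain Python attribute: the buffer has a fixed tensor type and can be updated in place after scripting. The sketch below does this for an up-clamp on a plain `nn.Module`. `ClampState` and `set_threshold` are hypothetical names, and this does not show how torchbend itself exposes `controllable_params`. Note the `@torch.jit.export` decorator: methods not reachable from `forward` are otherwise dropped by `torch.jit.script`.

```python
import torch
import torch.nn as nn
from typing import Optional


class ClampState(nn.Module):
    """Hypothetical stateful up-clamp, written as a plain module (not a torchbend callback)."""

    def __init__(self, threshold: float = 1.0):
        super().__init__()
        # store the state as a buffer so it survives scripting and can be updated in place
        self.register_buffer("threshold", torch.tensor(threshold))

    def forward(self, x: torch.Tensor, name: Optional[str] = None) -> torch.Tensor:
        # clamp values from above; the 0-d buffer broadcasts against x
        return torch.minimum(x, self.threshold)

    @torch.jit.export
    def set_threshold(self, value: float) -> None:
        # update the state in place, so the scripted module sees the new value
        self.threshold.fill_(value)


scripted = torch.jit.script(ClampState(0.5))
x = torch.tensor([0.0, 0.4, 0.9, 2.0])
y_low = scripted(x)

scripted.set_threshold(1.0)
y_high = scripted(x)
```

Because the threshold lives in a buffer, changing it does not require re-scripting the module.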