# Extended tracing

We will now dive into `torchbend`'s extended tracer. Let's take this odd `nn.Module`, which is useless apart from demonstrating how `torchbend`'s tracer extends the original `torch.fx` tracer, and experiment with tracing and bending.

In [2]:
import torch, torch.nn as nn
import sys; sys.path.append("..")
import torchbend as tb

class Doug(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module = nn.Conv1d(1, 16, 3)
        self.batch_norm = nn.BatchNorm1d(16)
        
    def forward(self, x: torch.Tensor):
        if x.shape[1] > 1: 
            return torch.cat([self.forward(x[:, [i]]) for i in range(x.shape[1])], -1)
        else:
            outs = []
            for i in range(x.shape[0]):
                out_tmp = self.conv_module(x[[i]])
                out_tmp = self.batch_norm(out_tmp)
                outs.append(out_tmp)
            return sum(outs)

module = Doug()

try: 
    torch.fx.symbolic_trace(module)
except torch.fx.proxy.TraceError as e:
    print('torch.fx error : ', e)

bended = tb.BendedModule(module)
bended.trace(x=torch.randn(1, 1, 16))
print('torchbend graph: ')
print(bended.graph())

torch.fx error :  symbolically traced variables cannot be used as inputs to control flow
torchbend graph: 
graph():
    %x : torch.Tensor [num_users=3] = placeholder[target=x]
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%x, shape), kwargs = {})
    %getitem : int [num_users=1] = call_function[target=operator.getitem](args = (%getattr_1, 1), kwargs = {})
    %gt : [num_users=0] = call_function[target=operator.gt](args = (%getitem, 1), kwargs = {})
    %getattr_2 : [num_users=1] = call_function[target=builtins.getattr](args = (%x, shape), kwargs = {})
    %getitem_1 : int [num_users=0] = call_function[target=operator.getitem](args = (%getattr_2, 0), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%x, [0]), kwargs = {})
    %conv_module : [num_users=1] = call_module[target=conv_module](args = (%getitem_2,), kwargs = {})
    %batch_norm : [num_users=1] = call_module[target=batch_norm](args = (%conv_module,), k

What happens here? `torch.fx` is a purely symbolic tracer: it feeds `Proxy` objects as inputs to record every operation of the computation graph without actually executing it. While this has many advantages, it also rules out operations that depend on the concrete values of the arguments, such as, here, branching and looping on the shape of the input. `torchbend` alleviates this by pairing the purely symbolic tracing with a parallel concrete execution, at the cost of requiring an actual value for the input argument `x`.

![tracing process](img/tracing.png "Tracing process")
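
The doubled tracing process can be sketched without `torch` at all. In this toy example, `Traced` is a hypothetical class (torchbend's actual implementation is not reproduced here) that logs every operation like a symbolic proxy while also carrying the concrete result. A data-dependent `if` can then be resolved, and the decision logged, instead of raising an error as `torch.fx` does.

```python
class Traced:
    """Toy value that is both symbolic (it logs every operation into a shared
    graph) and concrete (it carries the actual result). Hypothetical helper,
    for illustration only."""

    def __init__(self, value, graph, name):
        self.value = value  # concrete result of the parallel execution
        self.graph = graph  # shared list of recorded operations
        self.name = name    # node name in the recorded graph

    def _op(self, op, other, fn):
        other_val = other.value if isinstance(other, Traced) else other
        other_name = other.name if isinstance(other, Traced) else repr(other)
        name = f"{op}_{len(self.graph)}"
        self.graph.append((name, op, self.name, other_name))  # symbolic record
        return Traced(fn(self.value, other_val), self.graph, name)  # concrete run

    def __gt__(self, other):
        return self._op("gt", other, lambda a, b: a > b)

    def __mul__(self, other):
        return self._op("mul", other, lambda a, b: a * b)

    def __bool__(self):
        # A purely symbolic tracer has no value to branch on and must raise
        # here; with a concrete value we can pick the branch and log it.
        self.graph.append(("flow", self.name, bool(self.value)))
        return bool(self.value)


def f(x):
    if x > 1:  # data-dependent control flow
        return x * 2
    return x * 3


graph = []
out = f(Traced(3, graph, "x"))
print(out.value)  # 6
for node in graph:
    print(node)
```

Tracing `f` with `x=0` instead would log `False` for the same `gt_0` node and record a `* 3` node: the recorded graph depends on the example input, which is the subject of the next section.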

## Hardcoded control flow

While this doubled tracing process does not restrict the original `torch.fx` tracer, it calls for some caution (similarly to how [JAX](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) handles computational graphs). Indeed, the control flow is somehow "hardcoded" in the graph, so that different inputs may lead to different graphs:

In [5]:
module = Doug()

bended = tb.BendedModule(module)
bended.trace(x=torch.randn(1, 1, 16))
print(bended.graph())

bended.trace(x=torch.randn(1, 2, 16))
print(bended.graph())


graph():
    %x : torch.Tensor [num_users=3] = placeholder[target=x]
    %getattr_1 : [num_users=1] = call_function[target=builtins.getattr](args = (%x, shape), kwargs = {})
    %getitem : int [num_users=1] = call_function[target=operator.getitem](args = (%getattr_1, 1), kwargs = {})
    %gt : [num_users=0] = call_function[target=operator.gt](args = (%getitem, 1), kwargs = {})
    %getattr_2 : [num_users=1] = call_function[target=builtins.getattr](args = (%x, shape), kwargs = {})
    %getitem_1 : int [num_users=0] = call_function[target=operator.getitem](args = (%getattr_2, 0), kwargs = {})
    %getitem_2 : [num_users=1] = call_function[target=operator.getitem](args = (%x, [0]), kwargs = {})
    %conv_module : [num_users=1] = call_module[target=conv_module](args = (%getitem_2,), kwargs = {})
    %batch_norm : [num_users=1] = call_module[target=batch_norm](args = (%conv_module,), kwargs = {})
    %add : [num_users=1] = call_function[target=operator.add](args = (0, %batch_norm), kwargs =

We can see that, in the second case, the loops are "hardcoded" in the graph. The resulting graph is different, so the bended module may behave differently, or even fail, on the same input depending on which input it was traced with:

In [11]:
module = Doug()

bended = tb.BendedModule(module)
bended.trace(x=torch.randn(1, 1, 16))
bended(torch.randn(1, 1, 16))
try: 
    bended(torch.randn(1, 2, 16))
except Exception as e:
    print("raised error : ", e)
    

bended.trace(x=torch.randn(1, 2, 16))
bended(torch.randn(1, 2, 16))
try: 
    bended(torch.randn(1, 1, 16))
except Exception as e:
    print("raised error :", e)
    


raised error :  Given groups=1, weight of size [16, 1, 3], expected input[1, 2, 16] to have 1 channels, but got 2 channels instead
raised error : index 1 is out of bounds for dimension 0 with size 1


Traceback (most recent call last):
  File "/Users/domkirke/miniconda3/envs/ml2/lib/python3.11/site-packages/torch/fx/graph_module.py", line 303, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/domkirke/miniconda3/envs/ml2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/domkirke/miniconda3/envs/ml2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "<eval_with_key>.12", line 20, in forward
    getitem_6 = x[(slice(None, None, None), [1])];  x = None
                ~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index 1 is out of bounds for dimension 0 with size 1

Call using an FX-traced Module, line 20 of the 
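
Both errors come from replaying a graph whose branch and loop bounds were resolved once, on the traced input. The torch-free toy below mirrors this with nested lists (one inner list per channel); `trace_channels` is a hypothetical helper, not part of torchbend. Note that the one-channel graph applied to two channels silently returns a different result here, whereas in `Doug` the mismatch happened to be caught by the convolution's channel check.

```python
def trace_channels(example):
    """Toy concrete tracer for a Doug-like model (hypothetical, torch-free).
    `example` is a list of channels, each a list of floats. The branch on the
    number of channels and the loop bound are resolved once, on the example,
    and only the chosen path is kept in the returned 'graph'."""
    if len(example) > 1:  # like `x.shape[1] > 1`, evaluated at trace time
        channel_ids = list(range(len(example)))  # loop unrolled for this input

        def graph(x):
            # per-channel path, with indices fixed at trace time
            return [sum(x[c]) for c in channel_ids]
    else:
        def graph(x):
            # single-channel path: merge everything into one output
            return [sum(v for ch in x for v in ch)]
    return graph


two_ch = trace_channels([[0.0], [0.0]])
print(two_ch([[1.0], [2.0]]))  # [1.0, 2.0]
try:
    two_ch([[1.0]])  # replays the 2-channel path on a 1-channel input
except IndexError as e:
    print("raised error :", e)  # list index out of range

one_ch = trace_channels([[0.0]])
print(one_ch([[1.0], [2.0]]))  # [3.0]: the 1-channel path, taken silently
```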

Each control-flow decision is recorded in the graph, and can be retrieved through the `flow_steps` attribute of the corresponding `torch.fx.Graph` object:

In [12]:
bended.graph().flow_steps

[LogicalFlowStep(name=gt, value=True, file=/var/folders/vk/nn1706pd25b57gxz9y3p1wqh0000gn/T/ipykernel_87887/1390080083.py:forward.11)),
 LogicalFlowStep(name=gt_1, value=False, file=/var/folders/vk/nn1706pd25b57gxz9y3p1wqh0000gn/T/ipykernel_87887/1390080083.py:forward.11)),
 LogicalFlowStep(name=gt_2, value=False, file=/var/folders/vk/nn1706pd25b57gxz9y3p1wqh0000gn/T/ipykernel_87887/1390080083.py:forward.11))]
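
These steps only cover logical branches: in `Doug`, the number of iterations of `range(x.shape[0])` is also fixed at trace time but does not appear among the `LogicalFlowStep` entries. A conservative pattern is therefore to keep one traced graph per input shape and trace again whenever a new shape comes in, similar in spirit to how `jax.jit` retraces a function called with new input shapes. The sketch below is torch-free; `RetraceOnNewKey` and `trace_sum` are hypothetical helpers, not torchbend API.

```python
from typing import Callable, Dict, Hashable


class RetraceOnNewKey:
    """Keep one traced graph per input signature and retrace on unseen ones.
    Hypothetical helper: `trace_fn(example)` must return a callable graph, and
    `key(x)` must capture everything the traced control flow depends on."""

    def __init__(self, trace_fn: Callable, key: Callable[[object], Hashable]):
        self.trace_fn = trace_fn
        self.key = key
        self.graphs: Dict[Hashable, Callable] = {}

    def __call__(self, x):
        k = self.key(x)
        if k not in self.graphs:
            # the recorded flow is only valid for this signature: trace again
            self.graphs[k] = self.trace_fn(x)
        return self.graphs[k](x)


def trace_sum(example):
    # Toy tracer: `for i in range(len(x))` unrolled for the example's length.
    indices = list(range(len(example)))
    return lambda x: sum(x[i] for i in indices)


stale = trace_sum([0])
print(stale([1, 2, 3]))  # 1: loop unrolled for a single element

model = RetraceOnNewKey(trace_sum, key=len)
print(model([1, 2, 3]), model([5]), len(model.graphs))  # 6 5 2
```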

## Tracer extensions: shape, loop, logical flow

So far, three operations that the original `torch.fx.Tracer` does not allow are implemented: shape attributes, logical control flow, and loops. Let's see them at work on these specific modules:

In [19]:
class LoopFoo(nn.Module):
    def forward(self, x, n: int):
        for i in range(n):
            x = x * x
        return x

class LogicalFoo(nn.Module):
    def forward(self, x):
        if x.all():
            return 1
        else: 
            return 0

class ShapeFoo(nn.Module):
    def forward(self, x):
        return x * x.shape[0]

# TODO: tracing LoopFoo does not work yet
# foo = LoopFoo()
# try: 
#     mod = torch.fx.symbolic_trace(foo)
# except torch.fx.proxy.TraceError as e:
#     print(e)
# bended = tb.BendedModule(foo)
# graph, out = bended.trace(x=torch.tensor(2), n=4, _return_out=True)
# print(out, graph.flow_steps)

foo = LogicalFoo()
try: 
    mod = torch.fx.symbolic_trace(foo)
except torch.fx.proxy.TraceError as e:
    print("torch.fx error :", e)
bended = tb.BendedModule(foo)
graph, out = bended.trace(x=torch.tensor(0), _return_out=True)
print(out, graph.flow_steps)


foo = ShapeFoo()
try: 
    mod = torch.fx.symbolic_trace(foo)
except torch.fx.proxy.TraceError as e:
    print("torch.fx error :", e)
bended = tb.BendedModule(foo)
graph, out = bended.trace(x=torch.ones(4), _return_out=True)
print(out, graph.flow_steps)
# calling with another shape works: x.shape[0] is read again at call time
out_ok = bended(torch.ones(6))
print('ok case : ', out_ok)



torch.fx error : symbolically traced variables cannot be used as inputs to control flow
(None,) [LogicalFlowStep(name=all_1, value=False, file=/var/folders/vk/nn1706pd25b57gxz9y3p1wqh0000gn/T/ipykernel_87887/2232374878.py:forward.8))]
(tensor([4., 4., 4., 4.]),) []
ok case :  tensor([6., 6., 6., 6., 6., 6.])
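
The last output (sixes for a 6-element input, after tracing on a 4-element one) indicates that `x.shape[0]` is not frozen as the constant 4: as with the `getattr`/`getitem` nodes in the first graph, the shape lookup is recorded in the graph and evaluated again at call time. The toy sketch below contrasts both options on plain lists; `trace_scale` is a hypothetical helper, not torchbend's implementation.

```python
def trace_scale(example):
    """Toy 'tracing' of `x * x.shape[0]` on flat lists (hypothetical helper).
    Returns two graphs: one freezing the length seen at trace time, one
    recording the length lookup itself."""
    n = len(example)  # concrete value observed while tracing

    def baked(x):
        return [v * n for v in x]  # constant frozen into the graph

    def recorded(x):
        return [v * len(x) for v in x]  # lookup evaluated again at call time

    return baked, recorded


baked, recorded = trace_scale([1.0] * 4)
print(baked([1.0] * 6))     # [4.0, 4.0, 4.0, 4.0, 4.0, 4.0]
print(recorded([1.0] * 6))  # [6.0, 6.0, 6.0, 6.0, 6.0, 6.0]
```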


## Cautions with activation bending

This hardcoding of critical values in the graph can also impact activation-bending callbacks that have to initialize their internal state with a given shape. For example, let's take two different bending callbacks:
- `tb.Mask`, which initializes a binary mask for a given target (hence requiring a shape);
- `tb.Bias`, which biases a given target with a static scalar value (hence not requiring a shape).

We can see below that the former cannot adapt to a change of shape, while the latter can. Such issues may be alleviated by carefully configuring the bending callback, as we do here by masking only specific channels (`dim=-2`).
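
Why one callback breaks and the others do not can be checked against the standard NumPy/PyTorch broadcasting rules. The torch-free sketch below assumes (this is not read from torchbend's code) that `tb.Mask(0.3)` stores a mask with the full traced activation shape `(1, 4, 14)`, that `tb.Bias(1.)` stores a scalar, and that `dim=-2` yields one mask value per channel, i.e. a `(4, 1)` mask; `broadcast_shape` is a hypothetical helper.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Broadcast shape of `a` and `b` under the NumPy/PyTorch rules (align
    trailing dims; each pair must be equal or contain a 1), or None if the
    shapes are incompatible. Hypothetical helper."""
    out = []
    for da, db in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if da == db or db == 1:
            out.append(da)
        elif da == 1:
            out.append(db)
        else:
            return None
    return tuple(reversed(out))


traced_act = (1, 4, 14)  # conv_module output during tracing
new_act = (4, 4, 30)     # conv_module output for an input of shape (4, 1, 32)

# Assumed internal state shapes (illustrative only):
states = {"bias": (), "mask": traced_act, "channel mask": (4, 1)}
for name, state in states.items():
    print(name, broadcast_shape(new_act, state))
# bias (4, 4, 30)
# mask None
# channel mask (4, 4, 30)
```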


In [44]:
class Greg(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_module = nn.Conv1d(1, 4, 3)
    def forward(self, x):
        return self.conv_module(x)

module = Greg()
bended = tb.BendedModule(module)
x = torch.randn(1, 1, 16)
graph, out = bended.trace(x=x, _return_out=True)
bended.print_activations()
print('\n-- Original out')
print(out[0])

# bending with Bias (a scalar offset, independent of the activation shape)
print('\n-- Bias')
bended.bend(tb.Bias(1.), "conv_module$")
print(bended(x))
bended(torch.randn(4, 1, 32)); # -> OK

print('\n-- Mask')
bended.reset()
bended.bend(tb.Mask(0.3), "conv_module$")
print(bended(x))
try:
    bended(torch.randn(4, 1, 32)); # -> not OK
except Exception as e:
    print("Error with different shape : ", e)


print('\n-- Channeled mask')
bended.reset()
bended.bend(tb.Mask(prob=0.3, dim=-2), "conv_module$")
print(bended(x))
bended(torch.randn(4, 1, 32)); # -> OK



-----------  -----------  ----------------------
x            placeholder  torch.Size([1, 1, 16])
conv_module  call_module  torch.Size([1, 4, 14])
-----------  -----------  ----------------------

-- Original out
tensor([[[-1.0954, -0.4893,  0.1828, -0.0825, -0.8288, -0.8804,  0.5599,
          -0.1416, -1.2253, -0.0266, -0.0884, -0.4542,  0.2300, -1.3383],
         [ 0.4717, -0.3082, -0.4564, -0.3030,  0.6222, -0.1784, -0.9403,
           0.2578,  0.3855, -0.7710,  0.3001, -0.7324,  0.0178,  0.6979],
         [-0.1401, -0.3743, -0.2121,  0.2424,  0.1981, -0.0893, -0.4107,
           0.3499,  0.3239, -0.4214, -0.0099,  0.2867, -0.1838,  0.4334],
         [-0.8375, -0.4086, -0.5440, -0.0560, -0.9016,  0.0270, -0.3381,
          -0.7345, -0.0863, -0.1345, -1.2418,  0.9393, -1.2654, -0.3447]]],
       grad_fn=<ConvolutionBackward0>)

-- Bias
tensor([[[-0.0954,  0.5107,  1.1828,  0.9175,  0.1712,  0.1196,  1.5599,
           0.8584, -0.2253,  0.9734,  0.9116,  0.5458,  1.2300, -0.3383],
  