## Compiling `TensorDictParams`

Vincent asked me to look into fixing how `TensorDictParams` behaves under `torch.compile`. He gave a few examples here: <https://github.com/pytorch/tensordict/pull/1100#issuecomment-2491805272>

### This case works

```python
from tensordict import from_module, TensorDictParams
import torch.nn

module = torch.nn.Module()
module.params = torch.nn.Parameter(torch.randn(3))
params2 = from_module(module).data.clone()
params2 *= 0
params2 = TensorDictParams(params2)

@torch.compile(fullgraph=True)
def func(z, params2):
    with params2.to_module(module):
        out = z + module.params
    return out

print(func(torch.zeros(()), params2))
```

```
tensor([0., 0., 0.], grad_fn=<CompiledFunctionBackward>)
```
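Used as a context manager, `to_module` temporarily swaps the values in `params2` into `module` and restores the originals on exit. The output is all zeros because `module.params` reads as the zeroed copy inside `func`, even though the original parameter is random. The sketch below shows only that swap-and-restore idea on a plain Python object. `swap_attrs` and `Box` are hypothetical names, not tensordict API. The real `to_module` also handles nested modules and parameter registration.

```python
from contextlib import contextmanager


@contextmanager
def swap_attrs(obj, new_values):
    """Temporarily set attributes on ``obj`` and restore the old values on exit.

    Hypothetical helper illustrating the swap-and-restore idea behind
    ``TensorDict.to_module``; it is not tensordict's implementation.
    """
    old_values = {name: getattr(obj, name) for name in new_values}
    try:
        for name, value in new_values.items():
            setattr(obj, name, value)
        yield obj
    finally:
        # Restore the original attributes even if the with-body raised.
        for name, value in old_values.items():
            setattr(obj, name, value)


class Box:
    """Stand-in for an nn.Module holding a single "parameter"."""

    def __init__(self):
        self.params = [1.0, 2.0, 3.0]


box = Box()
with swap_attrs(box, {"params": [0.0, 0.0, 0.0]}):
    inside = list(box.params)
after = list(box.params)
print(inside, after)  # [0.0, 0.0, 0.0] [1.0, 2.0, 3.0]
```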


### This case does not work

```python
from tensordict import from_module, TensorDictParams
import torch.nn

module = torch.nn.Module()
module.params = torch.nn.Parameter(torch.randn(3))
params2 = from_module(module).data.clone()
params2 *= 0
params2 = TensorDictParams(params2)
# Isolate the inner tensordict
params2 = params2._param_td

@torch.compile(fullgraph=True)
def func(z, params2):
    with params2.to_module(module):
        out = z + module.params
    return out

print(func(torch.zeros(()), params2))
```

```
Unsupported: Graph break under GenericContextWrappingVariable

from user code:
   File "/tmp/ipykernel_24205/3934583939.py", line 15, in func
    out = z + module.params

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```


Notice that the only difference between the two examples above is the extra line `params2 = params2._param_td` in the second one, just before the compiled function is called. What does this line do?

```python
from tensordict import from_module, TensorDictParams
import torch.nn

module = torch.nn.Module()
module.params = torch.nn.Parameter(torch.randn(3))
params2 = from_module(module).data.clone()
params2 *= 0
params2 = TensorDictParams(params2)
print(f"before: {type(params2)}")
# Isolate the inner tensordict
params2 = params2._param_td
print(f"after: {type(params2)}")
```

```
before: <class 'tensordict.nn.params.TensorDictParams'>
after: <class 'tensordict._td.TensorDict'>
```


Before the extra line, `params2` is a `TensorDictParams`. After it, `params2` is a plain `TensorDict`. In other words, `TensorDictParams._param_td` holds the underlying `TensorDict` that the `TensorDictParams` object wraps.
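`TensorDictParams` forwards many of its methods to this inner `TensorDict`. One of the tracebacks further down shows a wrapper in `tensordict/nn/params.py` doing `getattr(self._param_td, name)(*args, **kwargs)`. Here is a minimal sketch of that delegation pattern. `ParamsWrapper`, `InnerTD` and `_forward_to_inner` are hypothetical stand-ins, and a plain `dict` plays the role of the `TensorDict`.

```python
class InnerTD(dict):
    """Stand-in for the inner TensorDict (just a dict here)."""

    def double(self):
        return InnerTD({k: 2 * v for k, v in self.items()})


def _forward_to_inner(name):
    # Build a method that looks up `name` on the wrapped object and calls it,
    # mirroring `getattr(self._param_td, name)(*args, **kwargs)`.
    def method(self, *args, **kwargs):
        return getattr(self._param_td, name)(*args, **kwargs)

    method.__name__ = name
    return method


class ParamsWrapper:
    """Hypothetical, minimal wrapper in the spirit of TensorDictParams."""

    def __init__(self, td):
        self._param_td = td

    # Forward selected methods to the inner container.
    double = _forward_to_inner("double")
    keys = _forward_to_inner("keys")


wrapped = ParamsWrapper(InnerTD(a=1.0, b=2.0))
print(type(wrapped).__name__, type(wrapped._param_td).__name__)  # ParamsWrapper InnerTD
print(wrapped.double())  # {'a': 2.0, 'b': 4.0}
```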

As Vincent says, the error is that "dynamo doesn't like that we set attribute to a MutableMapping". 

The relevant part from the error message is this:

```
File ~/miniconda/envs/tensordict-0/lib/python3.9/site-packages/torch/_dynamo/exc.py:313, in unimplemented(msg, from_exc, case_name)
    312     raise Unsupported(msg, case_name=case_name) from from_exc
--> 313 raise Unsupported(msg, case_name=case_name)

Unsupported: setattr(MutableMappingVariable(TensorDict), _last_op, ...)
```

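The object being mutated here is the plain `TensorDict` we extracted via `_param_td`, and the attribute being set on it is `_last_op`. Dynamo models that object as a `MutableMappingVariable`, presumably because `TensorDict`, like `TensorDictParams`, is a subclass of `collections.abc.MutableMapping`: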

```python
from tensordict import from_module, TensorDictParams
import torch.nn
from collections.abc import MutableMapping

module = torch.nn.Module()
module.params = torch.nn.Parameter(torch.randn(3))
params2 = from_module(module).data.clone()
params2 *= 0
params2 = TensorDictParams(params2)
print(issubclass(type(params2), MutableMapping))
```

```
True
```
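A `MutableMapping` subclass only has to implement five methods, and the mixin supplies the rest (`keys`, `items`, `update`, ...). In eager Python, setting an ordinary attribute such as `_last_op` on such an object is trivial, yet that is the operation dynamo reported as unsupported. `ToyMapping` is a hypothetical stand-in, not `TensorDict`.

```python
from collections.abc import MutableMapping


class ToyMapping(MutableMapping):
    """Hypothetical minimal MutableMapping, standing in for TensorDict."""

    def __init__(self, **items):
        self._data = dict(items)
        # Plain instance attribute; the dynamo error above is a setattr of `_last_op`.
        self._last_op = None

    # The five abstract methods MutableMapping requires:
    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


m = ToyMapping(a=0.0)
m._last_op = "to_module"  # the kind of setattr dynamo refused to trace
print(issubclass(ToyMapping, MutableMapping), m._last_op, dict(m.items()))
```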


`MutableMappingVariable` is defined in pytorch under [`torch/_dynamo/variables/user_defined.py`](https://github.com/pytorch/pytorch/blob/f044c1a7c89f1893e2dbef5d3fcc02b7a090484a/torch/_dynamo/variables/user_defined.py#L1412)

On the face of it, it seems like we need to add `MutableMappingVariable.var_setattr`.

### Another failing case

```python
from tensordict import from_module, TensorDictParams, TensorDict
import torch.nn

module = torch.nn.Module()
module.params = TensorDictParams(
    # string="a string!",
    TensorDict(a=0.0)
)
params2 = from_module(module).data.clone()
params2 *= 0
params2 = TensorDictParams(params2)

@torch.compile(fullgraph=True)
def func(z, params2):
    with params2.to_module(module):
        out = z + module.params["a"]
    return out

print(func(torch.zeros(()), params2))
```

```
Unsupported: call_method GetAttrVariable(GenericContextWrappingVariable(TensorDictParams), __dict__) __setitem__ [ConstantVariable(str: '_parameters'), ConstDictVariable()] {}

from user code:
   File "/tmp/ipykernel_24205/1940307110.py", line 15, in func
    with params2.to_module(module):
  File "/home/endoplasm/develop/tensordict-0/tensordict/utils.py", line 1238, in new_func
    out = func(_self, *args, **kwargs)
  File "/home/endoplasm/develop/tensordict-0/tensordict/base.py", line 1704, in to_module
    return self._to_module(
  File "/home/endoplasm/develop/tensordict-0/tensordict/nn/params.py", line 190, in new_func
    out = getattr(self._param_td, name)(*args, **kwargs)
  File "/home/endoplasm/develop/tensordict-0/tensordict/_td.py", line 577, in _to_module
    local_out = value._to_module(
  File "/home/endoplasm/develop/tensordict-0/tensordict/_td.py", line 482, in _to_module
    swap = module.copy()
  File "/home/endoplasm/develop/tensordict-0/tensordict/base.py", line 9685, in copy
    return self.clone(recurse=False)
  File "/home/endoplasm/develop/tensordict-0/tensordict/base.py", line 9672, in clone
    result = self._clone(recurse=recurse, **kwargs)
  File "/home/endoplasm/develop/tensordict-0/tensordict/nn/params.py", line 623, in _clone
    return TensorDictParams(self._param_td._clone(False), no_convert="skip")
  File "/home/endoplasm/develop/tensordict-0/tensordict/nn/params.py", line 340, in __init__
    self._reset_params()
  File "/home/endoplasm/develop/tensordict-0/tensordict/nn/params.py", line 384, in _reset_params
    self.__dict__["_parameters"] = dict(zip(param_keys, params))

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```


The illuminating part of the error is:

```
File ~/miniconda/envs/tensordict-0/lib/python3.9/site-packages/torch/_dynamo/exc.py:313, in unimplemented(msg, from_exc, case_name)
    311 if from_exc is not _NOTHING:
    312     raise Unsupported(msg, case_name=case_name) from from_exc
--> 313 raise Unsupported(msg, case_name=case_name)

Unsupported: call_method GetAttrVariable(GenericContextWrappingVariable(TensorDictParams), __dict__) __setitem__ [ConstantVariable(str: '_parameters'), ConstDictVariable()] {}
```
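The failing line is `self.__dict__["_parameters"] = dict(...)` in `_reset_params`. Writing straight into an instance's `__dict__` is a common way to bypass a custom `__setattr__`, and `torch.nn.Module` defines one. Whatever tensordict's exact motivation, dynamo is tracking the object as a `GenericContextWrappingVariable` at this point and has no rule for `__setitem__` on its `__dict__`. The toy below shows only the eager-mode semantics. `Guarded` is a hypothetical class.

```python
class Guarded:
    """Hypothetical class with a __setattr__ hook, loosely like nn.Module."""

    def __init__(self):
        # Set `log` without going through __setattr__.
        self.__dict__["log"] = []

    def __setattr__(self, name, value):
        # Every normal attribute assignment is recorded here.
        self.log.append(name)
        super().__setattr__(name, value)


g = Guarded()
g.normal = 1  # goes through __setattr__, so it is logged
g.__dict__["_parameters"] = {"a": 0.0}  # bypasses __setattr__ entirely
print(g.log, g._parameters)  # ['normal'] {'a': 0.0}
```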

### PyTorch issue


Vincent opened a PyTorch issue: <https://github.com/pytorch/pytorch/issues/141118>

He mentions that `TensorDictParams` has more than one identity. It is a `MutableMapping` (through `TensorDictBase`) and also a `torch.nn.Module`, which dynamo wraps as an `UnspecializedNNModuleVariable`.
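The ambiguity can be reproduced in plain Python with a hypothetical class that, like `TensorDictParams`, is both a `MutableMapping` and a module. `Module` below is a bare stand-in for `torch.nn.Module`, and `MappingBase` stands in for `TensorDictBase`, so the snippet runs without PyTorch.

```python
from collections.abc import MutableMapping


class Module:
    """Bare stand-in for torch.nn.Module (hypothetical)."""


class MappingBase(MutableMapping):
    """Stand-in for TensorDictBase: a MutableMapping backed by a dict."""

    def __init__(self):
        self._data = {}

    def __getitem__(self, key):
        return self._data[key]

    def __setitem__(self, key, value):
        self._data[key] = value

    def __delitem__(self, key):
        del self._data[key]

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


class ToyParams(MappingBase, Module):
    """Hypothetical analogue of a class with both a mapping base and a module base."""


p = ToyParams()
print(isinstance(p, MutableMapping), isinstance(p, Module))  # True True
print([cls.__name__ for cls in ToyParams.__mro__])
```

Both `isinstance` checks succeed, so any code that classifies objects by type has to pick one of them.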

He brings up another case that fails compilation:


```python
import torch
from tensordict import TensorDictParams, TensorDict

td = TensorDictParams(TensorDict(a=1, b=2, c=True))

@torch.compile(fullgraph=True)
def add1(td):
    return TensorDict(**td)+1

add1(td)
```

```
Unsupported: dict(): (UnspecializedNNModuleVariable(TensorDictParams),) {}

from user code:
   File "/tmp/ipykernel_24205/2760013897.py", line 8, in add1
    return TensorDict(**td)+1

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```


In short, the error is:

```
File ~/miniconda/envs/tensordict-0/lib/python3.9/site-packages/torch/_dynamo/exc.py:313, in unimplemented(msg, from_exc, case_name)
    311 if from_exc is not _NOTHING:
    312     raise Unsupported(msg, case_name=case_name) from from_exc
--> 313 raise Unsupported(msg, case_name=case_name)

Unsupported: dict(): (UnspecializedNNModuleVariable(TensorDictParams),) {}
```


One question, of course, is still how to fix the error.

Another is how dynamo decides which of the multiple base classes to resolve to.
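One plausible answer to the second question is that dynamo's wrapping logic tests candidate types in a fixed order and takes the first match. For a class with both bases, whichever check comes first would then win. The toy dispatcher below illustrates that idea only and is not dynamo's actual code. The tracker labels mimic dynamo's class names, but the check order is an assumption chosen to match the observed `UnspecializedNNModuleVariable` result.

```python
from collections import UserDict
from collections.abc import MutableMapping


class Module:
    """Bare stand-in for torch.nn.Module (hypothetical)."""


class ToyParams(UserDict, Module):
    """Hypothetical class that is both a MutableMapping and a "module"."""


class ToyMapping(UserDict):
    """Hypothetical MutableMapping that is not a module."""


# Ordered (predicate, tracker label) pairs: the first match wins.
# The order is an assumption for illustration, not taken from dynamo's source.
RULES = [
    (lambda v: isinstance(v, Module), "UnspecializedNNModuleVariable"),
    (lambda v: isinstance(v, MutableMapping), "MutableMappingVariable"),
]


def choose_tracker(value):
    for predicate, label in RULES:
        if predicate(value):
            return label
    return "UserDefinedObjectVariable"  # fallback label


for obj in (ToyParams(a=1), ToyMapping(a=1), object()):
    print(type(obj).__name__, "->", choose_tracker(obj))
```

Swapping the two entries in `RULES` flips the result for `ToyParams` and leaves `ToyMapping` unchanged. So for a class with both bases, the outcome depends entirely on check order.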

`UnspecializedNNModuleVariable` is defined under [`torch/_dynamo/variables/nn_module.py`](https://github.com/pytorch/pytorch/blob/f869a0ffe1c42c2c4a95d9d22f49260e7ce5fdb9/torch/_dynamo/variables/nn_module.py#L764)

In the example above, `td` was resolved to `UnspecializedNNModuleVariable(TensorDictParams)` rather than to a `MutableMappingVariable`. For comparison, `**` unpacking of the same `td` works in eager mode and produces a plain `dict`:

```python
def func(**kwargs):
    print(kwargs)
    print(type(kwargs))
func(**td)
```

```
{'a': tensor(1), 'b': tensor(2), 'c': tensor(True)}
<class 'dict'>
```
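In eager mode, `**` unpacking and `dict(obj)` only need the mapping protocol, specifically `keys()` and `__getitem__`, which `TensorDictParams` provides. The `dict(): (UnspecializedNNModuleVariable(TensorDictParams),)` error therefore suggests that dynamo has no rule for that conversion on this variable type. The object itself does not lack the protocol. `KeysOnly` is a hypothetical class showing the minimal protocol.

```python
class KeysOnly:
    """Hypothetical object exposing only keys() and __getitem__."""

    def __init__(self, data):
        self._data = data

    def keys(self):
        return self._data.keys()

    def __getitem__(self, key):
        return self._data[key]


def func(**kwargs):
    return kwargs


obj = KeysOnly({"a": 1, "b": 2, "c": True})
print(func(**obj))  # {'a': 1, 'b': 2, 'c': True}
print(dict(obj) == func(**obj))  # True
```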
