## LossModule parameter reset

Sometimes, it may be necessary to reset the trainable parameters of the functions that a `LossModule` is calculating a loss for.

### Current behavior

Currently, to reset the parameters, a user has to directly access the child modules of the loss module. The child modules represent the trainable functions that the loss module calculates the loss for.

For instance, in the following code, we are setting up a value function and loss for DQN, and then we reset the parameters on the value function directly by calling its `reset_parameters_recursive` function.

```python
from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn
import torch

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))

value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")

# First row of the first Linear layer's weight, before the reset.
print(loss.value_network_params['module', '0', 'module', '0', 'weight'][0])

# Reset the value network directly; the loss module's value params change too.
value_net.reset_parameters_recursive()

print(loss.value_network_params['module', '0', 'module', '0', 'weight'][0])
```

```
tensor([-0.3242], grad_fn=<SelectBackward0>)
tensor([0.8511], grad_fn=<SelectBackward0>)
```


But `LossModule` also has target params, and we may need to reset those too.

To reset the target params, we can temporarily load them into the value network by calling `to_module`, and then call the value network's reset function. `to_module` returns a TensorDict that can be used as a context manager: upon exiting the `with` block, it puts the original parameters back into the value network. The target params were updated in place, so the changes to them persist.

```python
print(loss.target_value_network_params['module', '0', 'module', '0', 'weight'][0])

# Temporarily load the target params into value_net and reset them; the
# original value params are put back when the `with` block exits.
with loss.target_value_network_params.to_module(value_net):
    value_net.reset_parameters_recursive()

print(loss.target_value_network_params['module', '0', 'module', '0', 'weight'][0])
```

```
tensor([-0.3242])
tensor([-0.2589])
```
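The same load-reset-restore pattern can be sketched in plain PyTorch, which may help show what the `to_module` context is doing for us. This is a minimal sketch under stated assumptions, not torchrl's or tensordict's implementation: `swapped_params` is a hypothetical helper, and instead of swapping the tensors themselves it copies values into the module and back out under `torch.no_grad()`.

```python
import contextlib

import torch
from torch import nn


@contextlib.contextmanager
def swapped_params(module: nn.Module, params: dict):
    """Hypothetical helper (not a torchrl/tensordict API).

    Temporarily loads `params` (a dict mapping parameter names to tensors)
    into `module`. On exit, whatever the block did to the module's values is
    written back into `params`, and the module's original values are restored.
    """
    own = dict(module.named_parameters())
    with torch.no_grad():
        # Remember the module's own values so they can be restored on exit.
        saved = {name: p.detach().clone() for name, p in own.items()}
        for name, t in params.items():
            own[name].copy_(t)
    try:
        yield module
    finally:
        with torch.no_grad():
            for name, t in params.items():
                # Persist changes made inside the block into `params`.
                t.copy_(own[name])
                # Then put the module's original values back.
                own[name].copy_(saved[name])


net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
# Stand-in for target params: detached copies of the online params.
target = {name: p.detach().clone() for name, p in net.named_parameters()}
online_before = net[0].weight.detach().clone()
target_before = target["0.weight"].clone()

with swapped_params(net, target):
    for m in net.modules():
        if isinstance(m, nn.Linear):
            m.reset_parameters()  # overwrites the loaded target values

print(torch.equal(net[0].weight, online_before))       # online weights restored
print(torch.equal(target["0.weight"], target_before))  # target weights re-drawn
```

Copying rather than swapping keeps the sketch short, at the cost of two extra copies of every parameter.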




### Problem

The problem is that this is cumbersome. The code required to reset the parameters of a loss module is verbose, and it differs for each type of `LossModule`, since the child modules have different names.

It would be nice to have a simple `LossModule.reset_parameters` method.

### Previous ideas

One of the ideas from [this PR](https://github.com/pytorch/rl/pull/2017) was to have an API that requires the user to supply their own reset function. An example was this:

```python
def reset_parameters(params):
    """ User specified resetting function depending on their needs for initialization """
    if len(params.shape) > 1:
        # weights
        nn.init.xavier_uniform_(params)
    elif len(params.shape) == 1:
        # biases
        nn.init.zeros_(params)
    else:
        raise ValueError("Unknown parameter shape: {}".format(params.shape))

with loss.value_network_params.to_module(loss.value_network):
    loss.apply(lambda x: reset_parameters(x.data) if hasattr(x, "data") else None)

print(loss.value_network_params['module', '0', 'module', '0', 'weight'][0])
print(loss.target_value_network_params['module', '0', 'module', '0', 'weight'][0])
```

```
tensor([0.0736], grad_fn=<SelectBackward0>)
tensor([0.1608])
```


This does accomplish the goal, but it would be hard to make this kind of reset function work generically, since the function you'd have to write is very module-specific.

That PR also tried to add a `LossModule.reset_parameters` method that requires a user-defined function, which is not much different from calling `LossModule.apply` directly.

Vincent says [here](https://github.com/pytorch/rl/pull/2017#issuecomment-2009336630):
> The way I usually see this work is to use the module `reset_parameters` if there is one, which provides a better control over difference in initialization methods.

(Note: I think he actually meant `reset_parameters_recursive`.)

What he means is that `LossModule.reset_parameters` should call the `reset_parameters_recursive` function of all of its child modules.

**The point is that we need the default behavior to be a simple method call with no args, not requiring the user to supply a reset func.**
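Vincent's suggestion, reusing each module's own initializer rather than a user-supplied function, can be sketched for a plain `nn.Module` as a walk over its submodules. This is roughly the idea behind `reset_parameters_recursive`, but `reset_recursive` below is a hypothetical helper written for illustration, not tensordict's implementation.

```python
import torch
from torch import nn


def reset_recursive(module: nn.Module) -> int:
    """Hypothetical helper: call `reset_parameters()` on every submodule that
    defines one, so each layer uses its own default initialization.
    Returns how many submodules were reset."""
    count = 0
    for m in module.modules():
        reset = getattr(m, "reset_parameters", None)
        if callable(reset):
            reset()
            count += 1
    return count


net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
before = net[2].weight.detach().clone()
n = reset_recursive(net)
print(n)  # 2: only the Linear layers define reset_parameters
```

Note that the caller passes no arguments: the initialization scheme comes from the layers themselves, which is the default behavior asked for above.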

### Feature requirements

To summarize what I think are the requirements of this feature, I need to make a reset function for `LossModule` which accomplishes the following:

* Update both value params and target params for all child modules.
* Match each set of params with its corresponding child module of the `LossModule` and use that child module's `reset_parameters_recursive` function.
* (Lower priority) Allow the user to optionally provide their own reset function.
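To make the first two requirements concrete, here is a toy, pure-PyTorch model of the problem. `ToyLoss`, its `_pairs` registry, and its `reset_parameters` method are all hypothetical and are not torchrl's `LossModule`; the point is only the shape of the logic: each child network is paired with its target params, and the network's own initializers are reused for both the online and the target sets.

```python
import torch
from torch import nn


def reset_module(module: nn.Module) -> None:
    # Re-run each submodule's own default initializer, if it defines one.
    for m in module.modules():
        if callable(getattr(m, "reset_parameters", None)):
            m.reset_parameters()


class ToyLoss(nn.Module):
    """Hypothetical stand-in for a loss module: one online network plus a
    detached copy of its parameters acting as the target params."""

    def __init__(self, value_network: nn.Module):
        super().__init__()
        self.value_network = value_network
        self.target_params = {
            name: p.detach().clone() for name, p in value_network.named_parameters()
        }
        # Hypothetical registry pairing each child network with its target params.
        self._pairs = [(self.value_network, self.target_params)]

    @torch.no_grad()
    def reset_parameters(self) -> None:
        for network, target in self._pairs:
            # Draw a fresh initialization with the network's own initializers
            # and store it as the new target params.
            reset_module(network)
            for name, p in network.named_parameters():
                target[name].copy_(p)
            # Draw a second, independent initialization for the online params.
            reset_module(network)


net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
loss = ToyLoss(net)
online_before = net[0].weight.detach().clone()
target_before = loss.target_params["0.weight"].clone()
loss.reset_parameters()  # no arguments, as the requirements ask
```

Whether the target params should instead be re-synced to the new online params is a design choice; this sketch draws them independently.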

However, I question the necessity of that last point. If users want fine-grained control and their own reset function, they could just as easily access the child modules of the `LossModule` directly, rather than going through an awkward API that expects a user-defined function. With such an API, it would be hard to make apparent what exactly happens inside that function, whereas if users reset the child modules directly, anyone reading the code can see exactly where the parameters come from and what they represent. (Perhaps it would be a good idea to document how a user can reset the target parameters with the `to_module` context manager.)
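For the custom case, applying the user's scheme to the child network directly keeps it explicit. A minimal sketch in plain PyTorch: the user's function dispatches on layer type instead of tensor shape (unlike the `reset_parameters(params)` example earlier), and `nn.Module.apply` visits every submodule. The `nn.Sequential` here is a stand-in for the network wrapped by `QValueActor`, and `xavier_init` is a hypothetical name; for the target params, the same call could be made inside the `to_module` context.

```python
import torch
from torch import nn


def xavier_init(m: nn.Module) -> None:
    # User-chosen scheme: Xavier-uniform weights and zero biases, Linear layers only.
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


value_net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
# Operating on the child network makes it obvious which params are affected.
value_net.apply(xavier_init)
print(value_net[0].bias.abs().sum().item())  # 0.0
```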

#### Experimenting

```python
[n for n, p in loss.named_parameters()]
```

```
['value_network_params.module.0.module.0.weight',
 'value_network_params.module.0.module.0.bias',
 'value_network_params.module.0.module.2.weight',
 'value_network_params.module.0.module.2.bias']
```

```python
loss.__dict__.keys()
```

```
dict_keys(['training', '_parameters', '_buffers', '_non_persistent_buffers_set', '_backward_pre_hooks', '_backward_hooks', '_is_full_backward_hook', '_forward_hooks', '_forward_hooks_with_kwargs', '_forward_hooks_always_called', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_state_dict_hooks', '_state_dict_pre_hooks', '_load_state_dict_pre_hooks', '_load_state_dict_post_hooks', '_modules', '_cache', '_param_maps', '_value_estimator', '_has_update_associated', 'value_type', '_tensor_keys', '_in_keys', 'double_dqn', 'delay_value', 'value_network', 'value_network_in_keys', 'loss_function', 'action_space', 'reduction'])
```

```python
loss._modules.keys()
```

```
dict_keys(['value_network_params', 'target_value_network_params'])
```

```python
list(loss._networks())
```

```
[]
```

```python
isinstance(loss.__dict__['value_network'], nn.Module)
```

```
True
```
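These results suggest that `value_network` lives in the loss module's `__dict__` without being registered in `_modules`, which would explain why `named_parameters()` only reports the `value_network_params.*` entries. In plain PyTorch, `nn.Module.__setattr__` registers any module assigned as an attribute, so ending up in `__dict__` alone implies the assignment bypassed it. The sketch below reproduces that effect with a hypothetical `Holder` class; it is only an illustration and not necessarily how torchrl does it.

```python
from torch import nn


class Holder(nn.Module):
    def __init__(self, registered: nn.Module, hidden: nn.Module):
        super().__init__()
        # Normal assignment: nn.Module.__setattr__ puts it in self._modules.
        self.registered = registered
        # Writing into __dict__ directly bypasses registration.
        self.__dict__["hidden"] = hidden


h = Holder(nn.Linear(2, 2), nn.Linear(3, 3))
print(list(h._modules.keys()))               # ['registered']
print("hidden" in h.__dict__)                # True
print([n for n, _ in h.named_parameters()])  # ['registered.weight', 'registered.bias']
```

A consequence for this feature: a generic `LossModule.reset_parameters` cannot discover the child networks through `self.children()` alone, since unregistered modules are invisible to it.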

```python
len(loss.__dict__.keys())
```

```
31
```

```python
ret = getattr(loss, "value_network_params1", None)
print(ret)
```

```
None
```

```python
[a for a, b in loss.named_parameters(prefix='value_network')]
```

```
['value_network.value_network_params.module.0.module.0.weight',
 'value_network.value_network_params.module.0.module.0.bias',
 'value_network.value_network_params.module.0.module.2.weight',
 'value_network.value_network_params.module.0.module.2.bias']
```

```python
td = loss.value_network_params
```

```python
td._param_td
```

```
TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                0: TensorDict(
                                    fields={
                                        bias: Parameter(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                                        weight: Parameter(shape=torch.Size([64, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False),
                                2: TensorDict(
                                    fields={
                                        bias: Parameter(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
```

```python
from torchrl.objectives import DQNLoss
from torchrl.modules import QValueActor
from torch import nn
import torch

module = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))

value_net = QValueActor(module=module, action_space="categorical")
loss = DQNLoss(value_network=value_net, action_space="categorical")


print(loss.value_network_params['module', '0', 'module', '0', 'bias'][0])
print(loss._modules.get('target_value_network_params')['module', '0', 'module', '0', 'bias'][0])

loss.reset_parameters_recursive()

print(loss.value_network_params['module', '0', 'module', '0', 'bias'][0])
print(loss._modules.get('target_value_network_params')['module', '0', 'module', '0', 'bias'][0])
```

```
tensor(0.5856, grad_fn=<SelectBackward0>)
tensor(0.5856)
tensor(-0.0158, grad_fn=<SelectBackward0>)
tensor(-0.5987)
```
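Here `loss.reset_parameters_recursive()` changed both the value and the target biases, and to different values, so the two sets appear to be re-initialized independently. Printing one element per tensor is a weak check, though. A slightly stronger check, sketched in plain PyTorch: snapshot a module's `state_dict()`, run the reset, and list every entry that changed. `changed_tensors` is a hypothetical helper, and the toy network stands in for a loss module's child network.

```python
import torch
from torch import nn


def changed_tensors(before: dict, module: nn.Module) -> list:
    """Hypothetical helper: return the state_dict keys whose values differ
    from the snapshot `before`."""
    after = module.state_dict()
    return [k for k, v in before.items() if not torch.equal(v, after[k])]


net = nn.Sequential(nn.Linear(1, 64), nn.ReLU(), nn.Linear(64, 64))
# Clone so the snapshot is not affected by the in-place reset.
snapshot = {k: v.clone() for k, v in net.state_dict().items()}
net[2].reset_parameters()  # reset only the last Linear layer
print(changed_tensors(snapshot, net))
```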


```python
a = loss._modules.get('target_value_network_params')
```

```python
a.clone()
```

```
TensorDictParams(params=TensorDict(
    fields={
        module: TensorDict(
            fields={
                0: TensorDict(
                    fields={
                        module: TensorDict(
                            fields={
                                0: TensorDict(
                                    fields={
                                        bias: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
                                        weight: Tensor(shape=torch.Size([64, 1]), device=cpu, dtype=torch.float32, is_shared=False)},
                                    batch_size=torch.Size([]),
                                    device=None,
                                    is_shared=False),
                                2: TensorDict(
                                    fields={
                                        bias: Tensor(shape=torch.Size([64]), device=cpu, dtype=torch.float32, is_shared=False),
```

```python
a is a
```

```
True
```