# Meta policy
The aim of this notebook is to run a few short experiments that validate the potential of training a meta-policy to select reward terms. Specifically, we want to train a policy that reliably selects $r_s$ before selecting $r_x$, as those two terms showed the biggest difference in training influence in preliminary tests.

In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [27]:
# Random placeholder inputs, each a single row of 10 values.
q_val_in = 100 * torch.rand(1, 10)  # uniform samples scaled by 100
r_hat_in = torch.rand(1, 10)
r_in = torch.rand(1, 10)

# Start from zeros, then overwrite the only row with ones, so w is all ones.
# NOTE(review): if a one-hot vector was intended, this would need to be
# `w[0, 0] = 1` instead -- confirm the intent.
w = torch.zeros(1, 4)
w[0, :] = 1

In [28]:
class MetaPolicy(nn.Module):
    """Small two-layer MLP that outputs a probability distribution.

    The input is passed through a linear layer with ReLU activation, then a
    second linear layer whose outputs are normalised with a softmax over the
    last dimension.

    Args:
        in_shape: Size of the last dimension of the input tensor.
        out_shape: Number of output probabilities.
        hidden_shape: Width of the hidden layer.

    Note:
        The positional order is ``(in_shape, out_shape, hidden_shape)``, i.e.
        the output size comes before the hidden size.
    """

    def __init__(self, in_shape: int, out_shape: int, hidden_shape: int):
        super().__init__()
        self.fc1 = nn.Linear(in_features=in_shape, out_features=hidden_shape)
        self.fc2 = nn.Linear(in_features=hidden_shape, out_features=out_shape)

    def forward(self, x):
        """Return softmax probabilities computed over the last dimension."""
        hidden = F.relu(self.fc1(x))
        logits = self.fc2(hidden)
        return F.softmax(logits, dim=-1)

In [39]:
# Stack the three input tensors along a new axis at position 1.
f = torch.stack((q_val_in, r_hat_in, r_in), dim=1)

In [40]:
# Notebook display: inspect the shape of the stacked tensor.
f.shape

torch.Size([1, 3, 10])

In [32]:
# Input width 34 = three 10-wide tensors (q, r_hat, r) plus the 4-wide w.
meta_policy = MetaPolicy(in_shape=34, out_shape=4, hidden_shape=4)

# Concatenate every input along the last axis and run one forward pass.
x = torch.cat((q_val_in, r_hat_in, r_in, w), dim=-1)
y = meta_policy(x)
x.shape, y.shape

(torch.Size([1, 34]), torch.Size([1, 4]))

In [None]:
# TODO: maybe we want some sort of attention mechanism here, especially for q, r_hat and r

In [None]:
# load base policy
# rollout base policy

# train supervised on base policy rollouts and labels

In [None]:
# TODO
# - build a mechanism to load the base policy