In [13]:
from gpytorch.kernels import Kernel
from gpytorch import constraints
import torch
import gymnasium as gym
from aps.gaps_policy import GAPSPolicy

In [14]:
env = gym.make("Swimmer-v4")
env.observation_space.shape[0]
num_inputs = env.observation_space.shape[0]
num_outputs = env.action_space.shape[0]
policy = GAPSPolicy(num_inputs, num_outputs)

In [29]:
policy.theta.shape

torch.Size([2, 8])

In [30]:
torch.rand(100, 16).reshape(100, )

tensor([[0.2023, 0.6221, 0.3917,  ..., 0.3406, 0.4091, 0.9884],
        [0.6824, 0.4499, 0.1533,  ..., 0.9209, 0.1099, 0.8873],
        [0.0996, 0.6225, 0.9153,  ..., 0.6180, 0.4695, 0.3822],
        ...,
        [0.0215, 0.4812, 0.3174,  ..., 0.0676, 0.6586, 0.5673],
        [0.4331, 0.1044, 0.9031,  ..., 0.5594, 0.0915, 0.2769],
        [0.5729, 0.3790, 0.8501,  ..., 0.7663, 0.7300, 0.4241]])

In [27]:
class PolicyKernel(Kernel):
    def __init__(self, state_size: int, num_samples: int=100):
        super(PolicyKernel, self).__init__()
        self.state_size = state_size
        self.num_samples = num_samples

        self.register_parameter(name="raw_lengthscale", 
                                parameter=torch.nn.Parameter(torch.zeros(1)))
        self.register_parameter(name="raw_variance", 
                                parameter=torch.nn.Parameter(torch.zeros(1)))
        self.raw_lengthscale_constraint = constraints.Positive()
        self.raw_variance_constraint = constraints.Positive()

    @property
    def lengthscale(self):
        return self.raw_lengthscale_constraint.transform(self.raw_lengthscale)

    @property
    def variance(self):
        return self.raw_variance_constraint.transform(self.raw_variance)

    @lengthscale.setter
    def lengthscale(self, value):
        self._set_lengthscale(value)

    @variance.setter
    def variance(self, value):
        self._set_variance(value)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor):
        random_states = (-1.0 - 1.0) * torch.rand(self.state_size, self.num_samples).to(torch.float64) + 1.0
        actions1 = torch.matmul(x1, random_states)
        actions2 = torch.matmul(x2, random_states)
        dists = self.covar_dist(actions1, actions2, square=True).mean(axis=0)

        return self.variance * torch.exp(-dists / self.lengthscale)

In [26]:
kern = PolicyKernel(state_size=num_inputs)
kern.forward(policy.theta, policy.theta)

tensor([0.0027, 0.0027], dtype=torch.float64, grad_fn=<MulBackward0>)

tensor([[[ 0.2734,  0.3679],
         [-0.6261, -0.1530],
         [-0.3638, -0.4592]],

        [[ 0.4196,  0.6323],
         [ 0.0509, -0.6803],
         [ 0.4745,  0.3169]],

        [[-0.3987, -0.8536],
         [ 0.2364, -0.3556],
         [ 0.3059,  0.3107]],

        [[-0.9406, -0.1197],
         [ 0.1445, -0.3468],
         [ 0.9934,  0.8769]],

        [[-0.4008, -0.9256],
         [-0.4688, -0.5982],
         [-0.0200, -0.1050]]])