# Convert the fitted GPyTorch model to a TorchRL Env

In [1]:
%load_ext autoreload
%autoreload 2

## Imports

In [2]:
import gpytorch
import torch
import numpy as np
import gymnasium as gym

from torchrl.collectors import SyncDataCollector
from tensordict import TensorDict
import torchopt
from torchrl.envs.libs.gym import GymEnv
from torchrl.envs.utils import RandomPolicy
from torchrl.data import ReplayBuffer
from torchrl.data import LazyTensorStorage

from torch_pilco.model_learning.dynamical_models import (
    DynamicalModel,
    fit,
)
from torch_pilco.policy_learning.rbf_layer import RBFLayer
from torch_pilco.rewards import pendulum_cost

## Functions

In [3]:
def build_pendulum_training_data(
    data_tensordict: TensorDict,
 ) -> tuple[torch.Tensor, torch.Tensor]:
    return data_tensordict['observation'].float(), data_tensordict['action'].float()

## Parameters

In [4]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"
frames_per_batch = 100

env = GymEnv("Pendulum-v1")
random_policy = RandomPolicy(env.action_spec)
action_dim = env.action_space.shape[0]
x = env.reset()
state_dim = x['observation'].shape[0]

num_particles = 400
num_basis = 100

In [5]:
control_policy = RBFLayer(
    state_dim,
    num_basis,
    action_dim,
    u_max=env.action_space.high[0],
) 
batched_policy = torch.vmap(control_policy, in_dims=0)

## Methods

In [6]:
# Generate a random trajectory from the environment
# Should create about 5-8 trajectories then stitch them together
collector = SyncDataCollector(
    env,
    policy=random_policy,
    frames_per_batch=frames_per_batch,
    total_frames=frames_per_batch,
)
# Now determine how many frames are stacked for the dynamical model input:

replay_buffer = ReplayBuffer(storage=LazyTensorStorage(10000))

In [7]:
# Now grab some data and fit the GP
for data in collector:
    # convert the tensordict from collector to a version
    # suitable for dynamical model
    replay_buffer.extend(data)
    states, actions = build_pendulum_training_data(data)

    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
        num_tasks=states.shape[1]
    )
    model = DynamicalModel(
        states,
        actions,
        likelihood,
    )

    # Find optimal model hyperparameters
    fit(model, likelihood, print_loss = False)

## Convert Model

In [8]:
from torchrl.envs.utils import check_env_specs

In [9]:
from torch_pilco.policy_learning.rollout import GPyTorchEnv

In [10]:
# Example usage (assuming you have a fitted GPyTorch model named 'fitted_gp_model'):
gp_env = GPyTorchEnv(model,env,pendulum_cost,replay_buffer,batch_size=(num_particles,))

In [11]:
check_env_specs(gp_env)

[92m2025-12-17 15:38:31,812 [torchrl][INFO][0m    check_env_specs succeeded![92m [END][0m


In [12]:
gp_env.reset();

In [13]:
from tensordict.nn import TensorDictModule
policy = TensorDictModule(
    batched_policy,
    in_keys=["observation"],
    out_keys=["action"],
)

In [15]:
optim = torch.optim.Adam(control_policy.parameters(), lr=1e-3)

In [16]:
import tqdm
from collections import defaultdict

In [17]:
batch_size = num_particles
N = 20000
pbar = tqdm.tqdm(range(N // batch_size))        # unsqueeze states
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, N)
logs = defaultdict(list)

for _ in pbar:
    rollout = gp_env.rollout(35, control_policy)
    traj_return = rollout["next", "reward"].mean(dim=0).sum()
    traj_return.backward()
    gn = torch.nn.utils.clip_grad_norm_(control_policy.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean(dim=0).item())
    scheduler.step()

reward:  114.2512, last reward:  3.3466, gradient norm:  7.887e-05: 100%|██████████████████████████████████████████████| 50/50 [12:23<00:00, 14.86s/it]


In [None]:
# Did control policy change?  or just policy?  what about batched_policy?

In [18]:
?policy

[31mSignature:[39m      policy(*args, **kwargs)
[31mType:[39m           TensorDictModule
[31mString form:[39m   
TensorDictModule(
    module=<function vmap.<locals>.wrapped at 0x7fc023341440>,
    device=cpu,
    in_keys=['observation'],
    out_keys=['action'])
[31mFile:[39m           ~/dev/torch-pilco/.venv/lib/python3.12/site-packages/tensordict/nn/common.py
[31mDocstring:[39m     
A TensorDictModule, is a python wrapper around a :obj:`nn.Module` that reads and writes to a TensorDict.

Args:
    module (Callable[[Any], Any]): a callable, typically a :class:`torch.nn.Module`,
        used to map the input to the output parameter space. Its forward method
        can return a single tensor, a tuple of tensors or even a dictionary.
        In the latter case, the output keys of the :class:`TensorDictModule`
        will be used to populate the output tensordict (ie. the keys present
        in ``out_keys`` should be present in the dictionary returned by the
        ``module`

In [21]:
# Now run the true environment with the learned policy
def env_control_policy(observation):
    td_in = TensorDict({"observation": observation})
    td_out = policy(td_in)
    return td_out["action"].squeeze()

In [22]:
collector = SyncDataCollector(
    env,
    policy=env_control_policy,
    frames_per_batch=frames_per_batch,
    total_frames=frames_per_batch,
)

RuntimeError: shape '[3, 1, 1, 3]' is invalid for input of size 3

In [None]:
# Now grab some data and fit the GP
for data in collector:
    # convert the tensordict from collector to a version
    # suitable for dynamical model
    replay_buffer.extend(data)
    # Now train with all of the data seen so far:
    # We get this by sampling from the replay buffer as many items as there are!
    states, actions = build_pendulum_training_data(replay_buffer.sample(len(replay_buffer)))

    likelihood = gpytorch.likelihoods.MultitaskGaussianLikelihood(
        num_tasks=states.shape[1]
    )
    model = DynamicalModel(
        states,
        actions,
        likelihood,
    )

    # Find optimal model hyperparameters
    fit(model, likelihood, print_loss = False)

In [None]:
gp_env = GPyTorchEnv(model,env,pendulum_cost,replay_buffer,batch_size=(num_particles,))
batched_policy = torch.vmap(control_policy, in_dims=0)

In [None]:
policy = TensorDictModule(
    batched_policy,
    in_keys=["observation"],
    out_keys=["action"],
)
optim = torch.optim.Adam(control_policy.parameters(), lr=2e-3)

In [None]:
N = 20_000
pbar = tqdm.tqdm(range(N // batch_size))        # unsqueeze states
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optim, N)
logs = defaultdict(list)

for _ in pbar:
    rollout = gp_env.rollout(35, control_policy)
    traj_return = rollout["next", "reward"].mean(dim=0).sum()
    traj_return.backward()
    gn = torch.nn.utils.clip_grad_norm_(control_policy.parameters(), 1.0)
    optim.step()
    optim.zero_grad()
    pbar.set_description(
        f"reward: {traj_return: 4.4f}, "
        f"last reward: {rollout[..., -1]['next', 'reward'].mean(): 4.4f}, gradient norm: {gn: 4.4}"
    )
    logs["return"].append(traj_return.item())
    logs["last_reward"].append(rollout[..., -1]["next", "reward"].mean(dim=0).item())
    scheduler.step()