This notebook follows the TorchRL tutorial at https://pytorch.org/rl/stable/tutorials/getting-started-5.html

In [2]:
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
import torch
import time
import gymnasium as gym
from torchrl.envs import GymEnv, StepCounter, TransformedEnv, GymWrapper


# Fix PyTorch's global RNG seed before anything stochastic is created.
torch.manual_seed(0)

# Create the Gymnasium environment and hand it straight to the TorchRL
# wrapper, keeping everything on the CPU.
# NOTE(review): "intersection-v1" is presumably registered by a separate
# package that this cell never imports -- confirm the id resolves on a
# fresh kernel (Restart & Run All).
base_env = GymWrapper(gym.make("intersection-v1"), device="cpu")

# Stack a StepCounter transform on top of the wrapped environment.
env = TransformedEnv(base_env, StepCounter())

# Seed the environment as well; left as the last expression so the value
# returned by set_seed is displayed as the cell output.
env.set_seed(0)

795726461

In [6]:
# Display the repr of the transformed environment (wrapper + transform).
env

TransformedEnv(
    env=GymWrapper(env=<OrderEnforcing<PassiveEnvChecker<ContinuousIntersectionEnv<intersection-v1>>>>, batch_size=torch.Size([])),
    transform=StepCounter(keys=[]))

{torchrl.data.tensor_specs.OneHotDiscreteTensorSpec: 'one_hot',
 torchrl.data.tensor_specs.MultiOneHotDiscreteTensorSpec: 'mult_one_hot',
 torchrl.data.tensor_specs.BinaryDiscreteTensorSpec: 'binary',
 torchrl.data.tensor_specs.DiscreteTensorSpec: 'categorical',
 'one_hot': 'one_hot',
 'one-hot': 'one_hot',
 'mult_one_hot': 'mult_one_hot',
 'mult-one-hot': 'mult_one_hot',
 'multi_one_hot': 'mult_one_hot',
 'multi-one-hot': 'mult_one_hot',
 'binary': 'binary',
 'categorical': 'categorical',
 torchrl.data.tensor_specs.MultiDiscreteTensorSpec: 'multi_categorical',
 'multi_categorical': 'multi_categorical',
 'multi-categorical': 'multi_categorical',
 'multi_discrete': 'multi_categorical',
 'multi-discrete': 'multi_categorical'}

In [10]:
# Display the environment's action spec (shape, bounds, device, dtype).
env.action_spec

BoundedTensorSpec(
    shape=torch.Size([2]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

In [14]:
# Read the network's input/output sizes off the environment specs.
obs_shape = env.observation_spec["observation"].shape
act_shape = env.action_spec.shape

# Only the trailing dimension of each spec is used.
# NOTE(review): if the observation has more than one dimension, a
# Linear(n_obs, ...) layer would act on each row rather than on a flattened
# vector -- confirm the observation shape before training.
n_obs, n_act = obs_shape[-1], act_shape[-1]

In [15]:
from torchrl.modules import EGreedyModule, MLP, QValueModule
from torch import nn
from torchrl.data.tensor_specs import BoundedTensorSpec
from torchrl.modules.distributions.continuous import NormalParamWrapper, TanhNormal
from torchrl.modules.tensordict_module.actors import ProbabilisticActor, ValueOperator
from torchrl.modules.tensordict_module.common import SafeModule
from torchrl.objectives.ppo import PPOLoss
from tensordict import TensorDict

# --- Policy -----------------------------------------------------------------
# First layer; the very same object is reused by the value network below, so
# its parameters are shared between actor and critic.
base_layer = nn.Linear(n_obs, 5)

# The policy head emits 2 * n_act values per observation; the wrapped network's
# outputs are written to the tensordict under "loc" and "scale".
policy_head = nn.Linear(5, 2 * n_act)
net = NormalParamWrapper(nn.Sequential(base_layer, policy_head))
module = SafeModule(
    net,
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

# Actor: builds a TanhNormal distribution from "loc"/"scale", using the
# environment's action spec.
actor = ProbabilisticActor(
    module=module,
    distribution_class=TanhNormal,
    in_keys=["loc", "scale"],
    spec=env.action_spec,
)

# --- Value ------------------------------------------------------------------
# Scalar value head on top of the shared first layer.
# Note: `module` is deliberately rebound here (it previously held the policy
# SafeModule); after this cell it refers to the value network.
value_head = nn.Linear(5, 1)
module = nn.Sequential(base_layer, value_head)
value = ValueOperator(module=module, in_keys=["observation"])

# --- Objective --------------------------------------------------------------
# PPO loss module combining the actor and the value operator.
loss = PPOLoss(actor, value)