In [315]:
from torchrl.envs import GymEnv
import torch

In [316]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [317]:
env = GymEnv("CartPole-v1")

In [318]:
env.observation_spec

Composite(
    observation: BoundedContinuous(
        shape=torch.Size([4]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    device=None,
    shape=torch.Size([]))

In [319]:
env.observation_spec["observation"]

BoundedContinuous(
    shape=torch.Size([4]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

In [320]:
env.observation_spec["observation"].shape, env.observation_spec["observation"].shape[-1]

(torch.Size([4]), 4)

In [321]:
env.action_spec

OneHot(
    shape=torch.Size([2]),
    space=CategoricalBox(n=2),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

In [322]:
env.action_spec.space.n

2

In [323]:
s_dim = env.observation_spec["observation"].shape[-1]
a_dim = env.action_spec.space.n

In [324]:
from torchrl.envs import TransformedEnv, StepCounter

In [325]:
def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False):
    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
    env = TransformedEnv(env)
    env.append_transform(StepCounter())

    return env

In [326]:
env.specs

Composite(
    output_spec: Composite(
        full_observation_spec: Composite(
            observation: BoundedContinuous(
                shape=torch.Size([4]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            device=None,
            shape=torch.Size([])),
        full_done_spec: Composite(
            done: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete),
            terminated: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
 

In [327]:
from torchrl.modules import MLP, SafeProbabilisticModule
from torchrl.data import Composite

import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch.distributions import Categorical

In [328]:
def make_REINFORCE_modules(p_env, device):

    obs_d = p_env.observation_spec["observation"].shape
    env_specs = p_env.specs
    act_d = env_specs["input_spec", "full_action_spec", "action"].space.n
    act_spec = env_specs["input_spec", "full_action_spec", "action"]

    mlp = MLP(
        in_features=obs_d[-1],
        activation_class=nn.ReLU,
        out_features=act_d,
        num_cells=[128, 128],
        device = device
    )

    policy_net = TensorDictModule(
        module=mlp,
        in_keys=["observation"],
        out_keys=["logits"]
    )

    actor = SafeProbabilisticModule(
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=Categorical,
        spec=act_spec.to(device),
        return_log_prob=True,
        log_prob_key="action_log_prob"
    )

    actor = TensorDictSequential(policy_net, actor)

    return actor

def make_REINFORCE_module(env_name, device):
    proof_environment = make_env(env_name, device=device)
    REINFORCE_actor = make_REINFORCE_modules(proof_environment, device=device)
    del proof_environment
    return REINFORCE_actor



In [329]:
policy = make_REINFORCE_module("CartPole-v1", device=device)

In [330]:
policy

TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=MLP(
            (0): Linear(in_features=4, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): ReLU()
            (4): Linear(in_features=128, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torch.distributions.categorical.Categorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])

In [331]:
import torch.optim as optim
from torchrl.envs.utils import ExplorationType, set_exploration_type

In [332]:
def train_REINFORCE(
    env_name="CartPole-v1",
    device="cpu",
    num_episodes=1000,
    gamma=0.99,
    lr=5e-4,
    log_interval=50,
):

    # Environment
    train_env = make_env(env_name, device=device)
    test_env = make_env(env_name, device=device)

    # Actor
    actor = make_REINFORCE_module(env_name, device).to(device)

    # Optimizer
    optimizer = optim.Adam(actor.parameters(), lr=lr)
    with set_exploration_type(ExplorationType.DETERMINISTIC):
        # Training
        for episode in range(1, num_episodes + 1):
            actor.train()

            # Rollout one episode
            td = train_env.rollout(
                policy=actor,
                auto_reset=True,
                auto_cast_to_device=True,
                break_when_any_done=True,
                max_steps=500
            )

            rewards = td["next", "reward"].squeeze(-1)  # [T]
            
            log_probs = td.get("action_log_prob")  # from SafeProbabilisticModule
            T = rewards.shape[0]

            # Compute discounted returns
            returns = torch.zeros(T, device=device)
            G = 0
            for t in reversed(range(T)):
                G = rewards[t] + gamma * G
                returns[t] = G

            # Normalize returns (optional, stabilizes training)
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

            # Compute REINFORCE loss
            loss = -(log_probs.squeeze(-1) * returns).mean()

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            if episode % log_interval == 0:
            
                avg_reward = eval_model(actor, test_env, num_episodes=5)
                print(f"Episode {episode}/{num_episodes}, Loss: {loss.item():.3f}, Avg Reward: {avg_reward:.2f}")

    return actor



if __name__ == "__main__":
    trained_actor = train_REINFORCE(env_name="CartPole-v1", device="cpu", num_episodes=1_000)

Episode 50/1000, Loss: -0.145, Avg Reward: 9.19
Episode 100/1000, Loss: -0.219, Avg Reward: 8.83
Episode 150/1000, Loss: -0.253, Avg Reward: 9.38
Episode 200/1000, Loss: -0.268, Avg Reward: 9.19
Episode 250/1000, Loss: -0.273, Avg Reward: 9.01
Episode 300/1000, Loss: -0.275, Avg Reward: 9.01
Episode 350/1000, Loss: -0.273, Avg Reward: 9.20
Episode 400/1000, Loss: -0.278, Avg Reward: 9.20
Episode 450/1000, Loss: -0.272, Avg Reward: 8.65
Episode 500/1000, Loss: -0.278, Avg Reward: 9.38
Episode 550/1000, Loss: -0.273, Avg Reward: 9.38
Episode 600/1000, Loss: -0.276, Avg Reward: 9.01
Episode 650/1000, Loss: -0.276, Avg Reward: 9.01
Episode 700/1000, Loss: -0.282, Avg Reward: 8.65
Episode 750/1000, Loss: -0.265, Avg Reward: 9.01
Episode 800/1000, Loss: -0.278, Avg Reward: 8.83
Episode 850/1000, Loss: -0.282, Avg Reward: 8.46
Episode 900/1000, Loss: -0.270, Avg Reward: 9.20
Episode 950/1000, Loss: -0.276, Avg Reward: 9.38
Episode 1000/1000, Loss: -0.279, Avg Reward: 8.28
