In [1]:
from torchrl.envs import GymEnv
import torch

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
env = GymEnv("CartPole-v1")

In [4]:
env.observation_spec

Composite(
    observation: BoundedContinuous(
        shape=torch.Size([4]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    device=None,
    shape=torch.Size([]))

In [5]:
env.observation_spec["observation"]

BoundedContinuous(
    shape=torch.Size([4]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

In [6]:
env.observation_spec["observation"].shape, env.observation_spec["observation"].shape[-1]

(torch.Size([4]), 4)

In [7]:
env.action_spec

OneHot(
    shape=torch.Size([2]),
    space=CategoricalBox(n=2),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

In [8]:
env.action_spec.space.n

2

In [9]:
s_dim = env.observation_spec["observation"].shape[-1]
a_dim = env.action_spec.space.n

In [10]:
from torchrl.envs import TransformedEnv, StepCounter

In [11]:
def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False):
    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
    env = TransformedEnv(env)
    env.append_transform(StepCounter())

    return env

In [12]:
env.specs

Composite(
    output_spec: Composite(
        full_observation_spec: Composite(
            observation: BoundedContinuous(
                shape=torch.Size([4]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            device=None,
            shape=torch.Size([])),
        full_done_spec: Composite(
            done: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete),
            terminated: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
 

In [13]:
from torchrl.modules import MLP, SafeProbabilisticModule
from torchrl.data import Composite

import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch.distributions import Categorical
from torchrl.modules.distributions import OneHotCategorical

In [14]:
def make_REINFORCE_modules(p_env, device):

    obs_d = p_env.observation_spec["observation"].shape
    env_specs = p_env.specs
    act_d = env_specs["input_spec", "full_action_spec", "action"].space.n
    act_spec = env_specs["input_spec", "full_action_spec", "action"]

    mlp = MLP(
        in_features=obs_d[-1],
        activation_class=nn.ReLU,
        out_features=act_d,
        num_cells=[128, 128],
        device = device
    )

    policy_net = TensorDictModule(
        module=mlp,
        in_keys=["observation"],
        out_keys=["logits"]
    )

    actor = SafeProbabilisticModule(
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=OneHotCategorical,
        spec=act_spec.to(device),
        return_log_prob=True,
        log_prob_key="action_log_prob"
    )

    actor = TensorDictSequential(policy_net, actor)

    return actor

def make_REINFORCE_module(env_name, device):
    proof_environment = make_env(env_name, device=device)
    REINFORCE_actor = make_REINFORCE_modules(proof_environment, device=device)
    del proof_environment
    return REINFORCE_actor



In [19]:
def eval_model(actor, test_env, num_episodes=3, max_steps=10_000, gamma=0.99, discounted=True): 
    test_rewards = torch.zeros(num_episodes, dtype=torch.float32) 
    actor.eval() 
    with torch.no_grad(): 
        for i in range(num_episodes): 
            td_test = test_env.rollout( policy=actor, 
                                       auto_reset=True, 
                                       auto_cast_to_device=True, 
                                       break_when_any_done=True, 
                                       max_steps=max_steps, ) 
            rewards = td_test["next", "reward"].squeeze(-1) # [T] 
            if discounted: 
                G = 0 
                ret = 0 
                for r in reversed(rewards): 
                    G = r + gamma * G 
                    ret = G 
                    # final G after loop is discounted return from step 0 
                test_rewards[i] = ret 
            else: test_rewards[i] = rewards.sum() 
    return test_rewards.mean().item()

In [15]:
policy = make_REINFORCE_module("CartPole-v1", device=device)

In [16]:
policy

TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=MLP(
            (0): Linear(in_features=4, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): ReLU()
            (4): Linear(in_features=128, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torchrl.modules.distributions.discrete.OneHotCategorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])

In [17]:
import torch.optim as optim
from torchrl.envs.utils import ExplorationType, set_exploration_type

In [21]:
def train_REINFORCE(
    env_name="CartPole-v1",
    device="cpu",
    num_episodes=1000,
    gamma=0.99,
    lr=5e-4,
    log_interval=50,
):

    # Environment
    train_env = make_env(env_name, device=device)
    test_env = make_env(env_name, device=device)

    # Actor
    actor = make_REINFORCE_module(env_name, device).to(device)

    # Optimizer
    optimizer = optim.Adam(actor.parameters(), lr=lr)
    with set_exploration_type(ExplorationType.RANDOM):
        # Training
        for episode in range(1, num_episodes + 1):
            actor.train()

            # Rollout one episode
            td = train_env.rollout(
                policy=actor,
                auto_reset=True,
                auto_cast_to_device=True,
                break_when_any_done=True,
                max_steps=500
            )

            rewards = td["next", "reward"].squeeze(-1)  # [T]
            
            log_probs = td.get("action_log_prob")  # from SafeProbabilisticModule
            T = rewards.shape[0]

            # Compute discounted returns
            returns = torch.zeros(T, device=device)
            G = 0
            for t in reversed(range(T)):
                G = rewards[t] + gamma * G
                returns[t] = G

            # Normalize returns (optional, stabilizes training)
            returns = (returns - returns.mean()) / (returns.std() + 1e-8)

            # Compute REINFORCE loss
            loss = -(log_probs.squeeze(-1) * returns).mean()

            # Optimize
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            if episode % log_interval == 0:
            
                avg_reward = eval_model(actor, test_env, num_episodes=5)
                print(f"Episode {episode}/{num_episodes}, Loss: {loss.item():.3f}, Avg Reward: {avg_reward:.2f}")

    return actor



if __name__ == "__main__":
    trained_actor = train_REINFORCE(env_name="CartPole-v1", device="cpu", num_episodes=1_000)

Episode 50/1000, Loss: -0.003, Avg Reward: 12.18
Episode 100/1000, Loss: 0.011, Avg Reward: 19.59
Episode 150/1000, Loss: 0.000, Avg Reward: 20.98
Episode 200/1000, Loss: -0.013, Avg Reward: 45.23
Episode 250/1000, Loss: -0.022, Avg Reward: 53.00
Episode 300/1000, Loss: -0.046, Avg Reward: 78.34
Episode 350/1000, Loss: 0.016, Avg Reward: 83.40
Episode 400/1000, Loss: -0.054, Avg Reward: 76.43
Episode 450/1000, Loss: 0.008, Avg Reward: 92.52
Episode 500/1000, Loss: 0.009, Avg Reward: 74.32
Episode 550/1000, Loss: 0.016, Avg Reward: 96.15
Episode 600/1000, Loss: 0.023, Avg Reward: 97.32
Episode 650/1000, Loss: -0.010, Avg Reward: 99.29
Episode 700/1000, Loss: -0.002, Avg Reward: 98.02
Episode 750/1000, Loss: -0.019, Avg Reward: 99.27
Episode 800/1000, Loss: -0.029, Avg Reward: 94.78
Episode 850/1000, Loss: -0.011, Avg Reward: 85.51
Episode 900/1000, Loss: 0.021, Avg Reward: 98.97
Episode 950/1000, Loss: 0.025, Avg Reward: 86.83
Episode 1000/1000, Loss: -0.033, Avg Reward: 99.34
