In [291]:
from torchrl.envs import GymEnv
import torch

In [292]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [293]:
env = GymEnv("CartPole-v1")

In [294]:
env.observation_spec

Composite(
    observation: BoundedContinuous(
        shape=torch.Size([4]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    device=None,
    shape=torch.Size([]))

In [295]:
env.observation_spec["observation"]

BoundedContinuous(
    shape=torch.Size([4]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

In [296]:
env.observation_spec["observation"].shape, env.observation_spec["observation"].shape[-1]

(torch.Size([4]), 4)

In [297]:
env.action_spec

OneHot(
    shape=torch.Size([2]),
    space=CategoricalBox(n=2),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

In [298]:
env.action_spec.space.n

2

In [299]:
s_dim = env.observation_spec["observation"].shape[-1]
a_dim = env.action_spec.space.n

In [300]:
from torchrl.envs import TransformedEnv, StepCounter

In [301]:
def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False):
    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
    env = TransformedEnv(env)
    env.append_transform(StepCounter())

    return env

In [302]:
env.specs

Composite(
    output_spec: Composite(
        full_observation_spec: Composite(
            observation: BoundedContinuous(
                shape=torch.Size([4]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            device=None,
            shape=torch.Size([])),
        full_done_spec: Composite(
            done: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete),
            terminated: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
 

In [303]:
from torchrl.modules import MLP, SafeProbabilisticModule
from torchrl.data import Composite

import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch.distributions import Categorical

In [304]:
def make_REINFORCE_modules(p_env, device):

    obs_d = p_env.observation_spec["observation"].shape
    env_specs = p_env.specs
    act_d = env_specs["input_spec", "full_action_spec", "action"].space.n
    act_spec = env_specs["input_spec", "full_action_spec", "action"]

    mlp = MLP(
        in_features=obs_d[-1],
        activation_class=nn.ReLU,
        out_features=act_d,
        num_cells=[128, 128],
        device = device
    )

    policy_net = TensorDictModule(
        module=mlp,
        in_keys=["observation"],
        out_keys=["logits"]
    )

    actor = SafeProbabilisticModule(
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=Categorical,
        spec=act_spec.to(device),
        return_log_prob=True,
        log_prob_key="action_log_prob"
    )

    actor = TensorDictSequential(policy_net, actor)

    return actor

def make_REINFORCE_module(env_name, device):
    proof_environment = make_env(env_name, device=device)
    REINFORCE_actor = make_REINFORCE_modules(proof_environment, device=device)
    del proof_environment
    return REINFORCE_actor



In [305]:
policy = make_REINFORCE_module("CartPole-v1", device=device)

In [306]:
policy

TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=MLP(
            (0): Linear(in_features=4, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): ReLU()
            (4): Linear(in_features=128, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torch.distributions.categorical.Categorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])

In [307]:
def eval_model(actor, test_env, num_episodes=3, max_steps=10_000, gamma=0.99, discounted=True):
    test_rewards = torch.zeros(num_episodes, dtype=torch.float32)

    actor.eval()
    with torch.no_grad():
        for i in range(num_episodes):
            td_test = test_env.rollout(
                policy=actor,
                auto_reset=True,
                auto_cast_to_device=True,
                break_when_any_done=True,
                max_steps=max_steps,
            )

            rewards = td_test["next", "reward"].squeeze(-1)  # [T]
            if discounted:
                G = 0
                ret = 0
                for r in reversed(rewards):
                    G = r + gamma * G
                    ret = G  # final G after loop is discounted return from step 0
                test_rewards[i] = ret
            else:
                test_rewards[i] = rewards.sum()

    return test_rewards.mean().item()



In [308]:
import torch.optim as optim

In [309]:
def train_REINFORCE(
    env_name="CartPole-v1",
    device="cpu",
    num_episodes=1000,
    gamma=0.99,
    lr=5e-4,
    log_interval=50,
):

    # Environment
    train_env = make_env(env_name, device=device)
    test_env = make_env(env_name, device=device)

    # Actor
    actor = make_REINFORCE_module(env_name, device).to(device)

    # Optimizer
    optimizer = optim.Adam(actor.parameters(), lr=lr)

    # Training
    for episode in range(1, num_episodes + 1):
        actor.train()

        # Rollout one episode
        td = train_env.rollout(
            policy=actor,
            auto_reset=True,
            auto_cast_to_device=True,
            break_when_any_done=True,
            max_steps=500
        )

        rewards = td["next", "reward"].squeeze(-1)  # [T]
        
        log_probs = td.get("action_log_prob")  # from SafeProbabilisticModule
        T = rewards.shape[0]

        # Compute discounted returns
        returns = torch.zeros(T, device=device)
        G = 0
        for t in reversed(range(T)):
            G = rewards[t] + gamma * G
            returns[t] = G

        # Normalize returns (optional, stabilizes training)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # Compute REINFORCE loss
        loss = -(log_probs.squeeze(-1) * returns).mean()

        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        if episode % log_interval == 0:
            print(rewards)
            print(returns)
            print(log_probs)
            avg_reward = eval_model(actor, test_env, num_episodes=5)
            print(f"Episode {episode}/{num_episodes}, Loss: {loss.item():.3f}, Avg Reward: {avg_reward:.2f}")

    return actor



if __name__ == "__main__":
    trained_actor = train_REINFORCE(env_name="CartPole-v1", device="cpu", num_episodes=1_000)

tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 1.4849,  1.1968,  0.9058,  0.6119,  0.3150,  0.0151, -0.2878, -0.5937,
        -0.9028, -1.2150, -1.5303])
tensor([-0.2314, -0.2688, -0.3136, -0.3754, -0.4426, -0.5221, -0.6046, -0.6773,
        -0.6651, -0.6325, -0.6021], grad_fn=<StackBackward0>)
Episode 50/1000, Loss: -0.140, Avg Reward: 8.83
tensor([1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 1.4145,  1.0185,  0.6184,  0.2143, -0.1938, -0.6061, -1.0226, -1.4432])
tensor([-0.0472, -0.0700, -0.1071, -0.1827, -0.2951, -0.4601, -0.6088, -0.6463],
       grad_fn=<StackBackward0>)
Episode 100/1000, Loss: -0.206, Avg Reward: 8.65
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 1.4664,  1.1492,  0.8289,  0.5052,  0.1784, -0.1518, -0.4853, -0.8222,
        -1.1625, -1.5062])
tensor([-0.0141, -0.0255, -0.0474, -0.1052, -0.2284, -0.4472, -0.6347, -0.6638,
        -0.6802, -0.6896], grad_fn=<StackBackward0>)
Episode 150/1000, Loss: -0.257, Avg Reward: 8.46
tensor([1., 1., 1.

In [310]:
actor.parameters

<bound method Module.parameters of TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=MLP(
            (0): Linear(in_features=4, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): ReLU()
            (4): Linear(in_features=128, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torch.distributions.categorical.Categorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])>