In [262]:
from torchrl.envs import GymEnv
import torch

In [263]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [264]:
env = GymEnv("CartPole-v1")

In [265]:
env.observation_spec

Composite(
    observation: BoundedContinuous(
        shape=torch.Size([4]),
        space=ContinuousBox(
            low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
            high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
        device=cpu,
        dtype=torch.float32,
        domain=continuous),
    device=None,
    shape=torch.Size([]))

In [266]:
env.observation_spec["observation"]

BoundedContinuous(
    shape=torch.Size([4]),
    space=ContinuousBox(
        low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
        high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
    device=cpu,
    dtype=torch.float32,
    domain=continuous)

In [267]:
env.observation_spec["observation"].shape, env.observation_spec["observation"].shape[-1]

(torch.Size([4]), 4)

In [268]:
env.action_spec

OneHot(
    shape=torch.Size([2]),
    space=CategoricalBox(n=2),
    device=cpu,
    dtype=torch.int64,
    domain=discrete)

In [269]:
env.action_spec.space.n

2

In [270]:
s_dim = env.observation_spec["observation"].shape[-1]
a_dim = env.action_spec.space.n

In [None]:
from torchrl.envs import TransformedEnv, StepCounter

In [272]:
def make_env(env_name="CartPole-v1", device="cpu", from_pixels=False):
    env = GymEnv(env_name, device=device, from_pixels=from_pixels, pixels_only=False)
    env = TransformedEnv(env)
    env.append_transform(StepCounter())

    return env

In [273]:
env.specs

Composite(
    output_spec: Composite(
        full_observation_spec: Composite(
            observation: BoundedContinuous(
                shape=torch.Size([4]),
                space=ContinuousBox(
                    low=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True),
                    high=Tensor(shape=torch.Size([4]), device=cpu, dtype=torch.float32, contiguous=True)),
                device=cpu,
                dtype=torch.float32,
                domain=continuous),
            device=None,
            shape=torch.Size([])),
        full_done_spec: Composite(
            done: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
                domain=discrete),
            terminated: Categorical(
                shape=torch.Size([1]),
                space=CategoricalBox(n=2),
                device=cpu,
                dtype=torch.bool,
 

In [274]:
from torchrl.modules import MLP, SafeProbabilisticModule
from torchrl.data import Composite

import torch.nn as nn
from tensordict.nn import TensorDictModule, TensorDictSequential
from torch.distributions import Categorical

In [275]:
def make_REINFORCE_modules(p_env, device):

    obs_d = p_env.observation_spec["observation"].shape
    env_specs = p_env.specs
    act_d = env_specs["input_spec", "full_action_spec", "action"].space.n
    act_spec = env_specs["input_spec", "full_action_spec", "action"]

    mlp = MLP(
        in_features=obs_d[-1],
        activation_class=nn.ReLU,
        out_features=act_d,
        num_cells=[128, 128],
        device = device
    )

    policy_net = TensorDictModule(
        module=mlp,
        in_keys=["observation"],
        out_keys=["logits"]
    )

    actor = SafeProbabilisticModule(
        in_keys=["logits"],
        out_keys=["action"],
        distribution_class=Categorical,
        spec=act_spec.to(device),
        return_log_prob=True,
        log_prob_key="action_log_prob"
    )

    actor = TensorDictSequential(policy_net, actor)

    return actor

def make_REINFORCE_module(env_name, device):
    proof_environment = make_env(env_name, device=device)
    REINFORCE_actor = make_REINFORCE_modules(proof_environment, device=device)
    del proof_environment
    return REINFORCE_actor



In [276]:
policy = make_REINFORCE_module("CartPole-v1", device=device)

In [277]:
policy

TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=MLP(
            (0): Linear(in_features=4, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): ReLU()
            (4): Linear(in_features=128, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torch.distributions.categorical.Categorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])

In [278]:
def eval_model(actor, test_env, num_episodes=3, max_steps=10_000, gamma=0.99, discounted=True):
    test_rewards = torch.zeros(num_episodes, dtype=torch.float32)

    actor.eval()
    with torch.no_grad():
        for i in range(num_episodes):
            td_test = test_env.rollout(
                policy=actor,
                auto_reset=True,
                auto_cast_to_device=True,
                break_when_any_done=True,
                max_steps=max_steps,
            )

            rewards = td_test["next", "reward"].squeeze(-1)  # [T]
            if discounted:
                G = 0
                ret = 0
                for r in reversed(rewards):
                    G = r + gamma * G
                    ret = G  # final G after loop is discounted return from step 0
                test_rewards[i] = ret
            else:
                test_rewards[i] = rewards.sum()

    return test_rewards.mean().item()



In [279]:
import torch.optim as optim

In [288]:
def train_REINFORCE(
    env_name="CartPole-v1",
    device="cpu",
    num_episodes=1000,
    gamma=0.99,
    lr=5e-4,
    log_interval=50,
):

    # Environment
    train_env = make_env(env_name, device=device)
    test_env = make_env(env_name, device=device)

    # Actor
    actor = make_REINFORCE_module(env_name, device).to(device)

    # Optimizer
    optimizer = optim.Adam(actor.parameters(), lr=lr)

    # Training
    for episode in range(1, num_episodes + 1):
        actor.train()

        # Rollout one episode
        td = train_env.rollout(
            policy=actor,
            auto_reset=True,
            auto_cast_to_device=True,
            break_when_any_done=True,
            max_steps=500
        )

        rewards = torch.tensor(td["next", "reward"].squeeze(-1))  # [T]
        
        log_probs = td.get("action_log_prob")  # from SafeProbabilisticModule
        T = rewards.shape[0]

        # Compute discounted returns
        returns = torch.zeros(T, device=device)
        G = 0
        for t in reversed(range(T)):
            G = rewards[t] + gamma * G
            returns[t] = G

        # Normalize returns (optional, stabilizes training)
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)

        # Compute REINFORCE loss
        loss = -(log_probs.squeeze(-1) * returns).mean()

        # Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Logging
        if episode % log_interval == 0:
            print(rewards)
            print(returns)
            print(log_probs)
            avg_reward = eval_model(actor, test_env, num_episodes=5)
            print(f"Episode {episode}/{num_episodes}, Loss: {loss.item():.3f}, Avg Reward: {avg_reward:.2f}")

    return actor



if __name__ == "__main__":
    trained_actor = train_REINFORCE(env_name="CartPole-v1", device="cpu", num_episodes=1_000)

  rewards = torch.tensor(td["next", "reward"].squeeze(-1))  # [T]


tensor([1., 1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 1.4664,  1.1492,  0.8289,  0.5052,  0.1784, -0.1518, -0.4853, -0.8222,
        -1.1625, -1.5062])
tensor([-0.1858, -0.2233, -0.2867, -0.3728, -0.4601, -0.5403, -0.6169, -0.6731,
        -0.6752, -0.6375], grad_fn=<StackBackward0>)
Episode 50/1000, Loss: -0.164, Avg Reward: 8.65
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 1.4435,  1.0910,  0.7350,  0.3755,  0.0122, -0.3547, -0.7253, -1.0996,
        -1.4777])
tensor([-0.0357, -0.0569, -0.1002, -0.1925, -0.3444, -0.5313, -0.6466, -0.6901,
        -0.6409], grad_fn=<StackBackward0>)
Episode 100/1000, Loss: -0.233, Avg Reward: 9.19
tensor([1., 1., 1., 1., 1., 1., 1., 1., 1.])
tensor([ 1.4435,  1.0910,  0.7350,  0.3755,  0.0122, -0.3547, -0.7253, -1.0996,
        -1.4777])
tensor([-0.0119, -0.0231, -0.0515, -0.1292, -0.3254, -0.6162, -0.6636, -0.6824,
        -0.6875], grad_fn=<StackBackward0>)
Episode 150/1000, Loss: -0.259, Avg Reward: 8.83
tensor([1., 1., 1., 1., 1., 1.,

In [290]:
actor.parameters

<bound method Module.parameters of TensorDictSequential(
    module=ModuleList(
      (0): TensorDictModule(
          module=MLP(
            (0): Linear(in_features=4, out_features=128, bias=True)
            (1): ReLU()
            (2): Linear(in_features=128, out_features=128, bias=True)
            (3): ReLU()
            (4): Linear(in_features=128, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torch.distributions.categorical.Categorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])>