In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchrl.envs.utils import set_exploration_type, ExplorationType
from torchrl.envs import GymEnv
from torchrl.modules import SafeModule, ProbabilisticActor
from torchrl.modules.distributions import OneHotCategorical

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = GymEnv("CartPole-v1", device=device)

actor = ProbabilisticActor(
    module=SafeModule(
        module=nn.Sequential(
            nn.Linear(4, 16),
            nn.ReLU(),
            nn.Linear(16, 16),
            nn.ReLU(),
            nn.Linear(16, 2),
        ),
        in_keys=["observation"],
        out_keys=["logits"],
    ),
    spec=env.action_spec,
    in_keys=["logits"],
    distribution_class=OneHotCategorical,
    return_log_prob=True,
).to(device)

In [4]:
actor

ProbabilisticActor(
    module=ModuleList(
      (0): SafeModule(
          module=Sequential(
            (0): Linear(in_features=4, out_features=16, bias=True)
            (1): ReLU()
            (2): Linear(in_features=16, out_features=16, bias=True)
            (3): ReLU()
            (4): Linear(in_features=16, out_features=2, bias=True)
          ),
          device=cpu,
          in_keys=['observation'],
          out_keys=['logits'])
      (1): SafeProbabilisticModule(
          in_keys=['logits'],
          out_keys=['action', 'action_log_prob'],
          distribution_class=<class 'torchrl.modules.distributions.discrete.OneHotCategorical'>, 
          distribution_kwargs={}),
          default_interaction_type=deterministic),
          num_samples=None))
    ),
    device=cpu,
    in_keys=['observation'],
    out_keys=['logits', 'action', 'action_log_prob'])

In [5]:
optimizer = optim.Adam(actor.parameters(), lr=1e-3)

num_episodes = 1000
max_steps_per_episode = 500
reward_log = []
gamma = 0.99

for episode in range(num_episodes):
    episode_data = []
    td = env.reset()

    with set_exploration_type(ExplorationType.RANDOM):
        for step in range(max_steps_per_episode):
            td = actor(td)
            td_next = env.step(td)
            episode_data.append(td_next.clone())
            
            done = td_next.get(("next", "done"))
            if done.any():
                break
            
            td = td_next.get("next")

    
    T = len(episode_data)
    rewards = torch.tensor(
        [td_step.get(("next", "reward")) for td_step in episode_data],
        device=device,
    )

    returns = torch.zeros(T, device=device)
    G = 0
    for t in reversed(range(T)):
        G = rewards[t] + gamma * G
        returns[t] = G
    
    logits = torch.stack([td_step.get("action_log_prob") for td_step in episode_data])
    loss = -(torch.mean(logits * returns))
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    reward_log.append(rewards.sum().item())
    
    if episode % 50 == 0 or episode == 0:
        avg_reward = sum(reward_log[-50:]) / len(reward_log[-50:])
        print(f"Episode {episode}: Average Reward (last 50): {avg_reward:.2f}")
    
    if len(reward_log) > 10 and sum(reward_log[-10:]) / 10 > 475:
        print(f"\nSolved at episode {episode}!")
        break

env.close()

Episode 0: Average Reward (last 50): 13.00
Episode 50: Average Reward (last 50): 21.64
Episode 100: Average Reward (last 50): 20.66
Episode 150: Average Reward (last 50): 19.70
Episode 200: Average Reward (last 50): 19.78
Episode 250: Average Reward (last 50): 23.92
Episode 300: Average Reward (last 50): 25.68
Episode 350: Average Reward (last 50): 20.84
Episode 400: Average Reward (last 50): 22.52
Episode 450: Average Reward (last 50): 30.16
Episode 500: Average Reward (last 50): 26.92
Episode 550: Average Reward (last 50): 33.28
Episode 600: Average Reward (last 50): 31.24
Episode 650: Average Reward (last 50): 36.82
Episode 700: Average Reward (last 50): 42.40
Episode 750: Average Reward (last 50): 58.20
Episode 800: Average Reward (last 50): 55.12
Episode 850: Average Reward (last 50): 42.54
Episode 900: Average Reward (last 50): 75.74
Episode 950: Average Reward (last 50): 80.92
