In [1]:
%load_ext autoreload
%autoreload 2

In [7]:
import sys
from lstmppo.env import make_env, RecurrentVecEnvWrapper
from lstmppo.buffer import RecurrentRolloutBuffer
from lstmppo.policy import LSTMPPOPolicy
from gymnasium.vector import SyncVectorEnv
from lstmppo.init import initialize

In [8]:
# Replace the kernel's command-line arguments with a single empty entry
# before building the config.
# NOTE(review): presumably `initialize` parses sys.argv and would otherwise
# see the Jupyter kernel's own launch arguments -- confirm in lstmppo.init.
sys.argv = [""]
# Build the run configuration with a pinned timestamp.
# NOTE(review): how `seconds_since_epoch` is used (run naming, seeding, ...)
# is not visible in this notebook -- confirm in lstmppo.init.
cfg = initialize(seconds_since_epoch=1766683383)


In [1]:
import torch.optim as optim
import torch

In [None]:

class PPOTrainer(object):
    """Bundle the policy, vectorised env, rollout buffer and optimizer for LSTM-PPO.

    Intended use: call ``reset()`` once, then ``collect_rollout()`` repeatedly.
    All hyper-parameters are read from the ``cfg`` object passed to the
    constructor; no notebook-level globals are consulted.
    """

    def __init__(self,
                 cfg):
        """Build every training component from ``cfg``.

        Reads ``cfg.rollout_steps``, ``cfg.device``, ``cfg.env_id``,
        ``cfg.num_envs`` and ``cfg.learning_rate``.
        """
        # Keep the config on the instance so methods never depend on a
        # notebook-global ``cfg`` happening to exist.
        self.cfg = cfg

        self.vec_state_env = None
        self.global_step = 0
        self.rollout_steps = cfg.rollout_steps

        # Build the policy exactly once.  A second, later construction used
        # to overwrite this one and silently dropped the ``.to(cfg.device)``
        # placement.
        self.policy = LSTMPPOPolicy(cfg).to(cfg.device)

        venv = SyncVectorEnv([make_env(cfg.env_id)
                              for _ in range(cfg.num_envs)])

        self.env = RecurrentVecEnvWrapper(cfg,
                                          venv)

        self.buffer = RecurrentRolloutBuffer(cfg)

        self.optimizer = optim.Adam(self.policy.parameters(),
                                    lr=cfg.learning_rate,
                                    eps=1e-5)

    def reset(self):
        """Reset the environments, the rollout buffer and the step counter."""
        self.vec_state_env = self.env.reset()
        self.buffer.reset()
        self.global_step = 0

    def collect_rollout(self):
        """Step the envs ``self.rollout_steps`` times, store each step, then run GAE.

        Raises:
            RuntimeError: if ``reset()`` has not been called yet.

        FIXME(review): this method is still mid-refactor.  ``actions``,
        ``values``, ``logprobs``, ``new_hxs`` and ``new_cxs`` have to be
        unpacked from the result of ``self.policy.act(...)``, and ``obs`` /
        ``hxs`` / ``cxs`` have to be seeded from ``self.vec_state_env`` before
        their first use.  Neither layout is visible in this notebook, so they
        are left unbound here rather than guessed.  ``self.vec_state_env`` is
        also never refreshed inside the loop, so every ``act`` call currently
        receives the same input.
        """
        if self.vec_state_env is None:
            raise RuntimeError("call reset() before collect_rollout()")

        for _ in range(self.rollout_steps):

            with torch.no_grad():

                # Named ``policy_out`` so it no longer shadows the policy
                # module that is needed again for the bootstrap value below.
                policy_out =\
                    self.policy.act(self.vec_state_env.to_policy_input())

            # FIXME(review): unpack ``policy_out`` into actions / values /
            # logprobs / new_hxs / new_cxs here (see docstring).
            next_obs, rewards, terminated, truncated, info, hxs_env, cxs_env = \
                self.env.step(actions)

            # Store step (use hxs_env/cxs_env = hidden state at start of step)
            self.buffer.add(
                obs=obs,
                actions=actions,
                rewards=rewards,
                values=values,
                logprobs=logprobs,
                terminated=terminated,
                truncated=truncated,
                hxs=hxs_env,
                cxs=cxs_env,
            )

            # Update hidden states inside env wrapper to new policy states
            self.env.update_hidden_states(new_hxs, new_cxs)

            obs = next_obs
            hxs = new_hxs
            cxs = new_cxs

            # One environment step was taken in each parallel env.
            self.global_step += self.cfg.num_envs

        # Bootstrap value at last obs
        with torch.no_grad():
            _, last_values, _, _ = self.policy.forward(obs, hxs, cxs)
            last_value = last_values  # (N,)

        self.buffer.compute_gae(last_value=last_value,
                                gamma=self.cfg.gamma,
                                lam=self.cfg.gae_lambda)

# Instantiate the trainer from the notebook-level ``cfg``; the constructor
# builds the vector env, policy, rollout buffer and optimizer.
trainer = PPOTrainer(cfg)

In [None]:
def train_lstm_ppo_position_only_cartpole(cfg: "PPOConfig", num_updates: int = 1):
    """Legacy single-function LSTM-PPO loop for POPGym position-only CartPole.

    The original cell had lost its outer loop and rollout section: the PPO
    update was left dangling at a deeper indent (the cell did not parse), and
    ``optimizer``, ``buffer`` and ``env`` were never created.  This version
    restores a self-consistent rollout -> GAE -> update loop.

    NOTE(review): this function targets the keyword-style
    ``LSTMPPOPolicy(obs_dim=..., action_dim=..., hidden_size=...)`` constructor
    and the tuple-returning ``RecurrentVecEnvWrapper(venv, hidden_size, device)``
    / ``RecurrentRolloutBuffer(T, N, ...)`` classes defined in this notebook,
    not the ``cfg``-constructed ones used by ``PPOTrainer``.  Confirm which
    definitions are bound before calling it.

    Args:
        cfg: Config object.  Reads ``device``, ``num_envs``, ``hidden_size``,
            ``rollout_steps``, ``learning_rate``, ``gamma``, ``gae_lambda``,
            ``update_epochs``, ``batch_envs``, ``clip_coef``, ``vf_coef``,
            ``ent_coef`` and ``max_grad_norm``.
        num_updates: Number of rollout + update iterations to run.  The
            original loop bound was not recoverable from the cell, so it is
            exposed as a parameter instead of guessing a config field.
    """
    device = torch.device(cfg.device)
    env_id = "POPGym-PositionOnlyCartPole-v0"  # check id if needed

    # Vectorized env
    venv = SyncVectorEnv([make_env(env_id) for _ in range(cfg.num_envs)])

    # Everything below runs under try/finally so the vector env is closed
    # even when setup or training raises.
    try:
        # Probe a throw-away env for the observation / action sizes.
        dummy_env = gym.make(env_id)
        obs_shape = dummy_env.observation_space.shape
        action_dim = dummy_env.action_space.n
        dummy_env.close()

        # Policy and optimizer
        policy = LSTMPPOPolicy(
            obs_dim=obs_shape[0],
            action_dim=action_dim,
            hidden_size=cfg.hidden_size
        ).to(device)
        optimizer = optim.Adam(policy.parameters(),
                               lr=cfg.learning_rate,
                               eps=1e-5)

        # Hidden-state tracking wrapper and rollout storage.  Actions are
        # stored with a trailing singleton dim, i.e. (T, N, 1).
        env = RecurrentVecEnvWrapper(venv, cfg.hidden_size, device)
        buffer = RecurrentRolloutBuffer(cfg.rollout_steps, cfg.num_envs,
                                        obs_shape, 1, cfg.hidden_size, device)

        global_step = 0
        obs, _, hxs, cxs = env.reset()

        for _update in range(num_updates):
            # ==== ROLLOUT ====
            buffer.step = 0  # rewind the write cursor; storage is overwritten

            for _t in range(cfg.rollout_steps):
                with torch.no_grad():
                    # NOTE(review): assumes forward() returns
                    # (logits, values, new_hxs, new_cxs) for a single
                    # (N, obs_dim) step -- confirm against the policy class.
                    logits, values, new_hxs, new_cxs = policy.forward(obs, hxs, cxs)
                    dist = torch.distributions.Categorical(logits=logits)
                    actions = dist.sample()            # (N,)
                    logprobs = dist.log_prob(actions)  # (N,)

                # Hand the post-step hidden state to the wrapper *before*
                # stepping, so the per-env zeroing it applies on episode end
                # is not overwritten afterwards.
                env.update_hidden_states(new_hxs, new_cxs)
                (next_obs, rewards, terminated, truncated,
                 _info, next_hxs, next_cxs) = env.step(actions)

                # Store the step with the hidden state the policy started from.
                buffer.add(
                    obs=obs,
                    actions=actions.unsqueeze(-1),
                    rewards=rewards,
                    values=values,
                    logprobs=logprobs,
                    terminated=terminated,
                    truncated=truncated,
                    hxs=hxs,
                    cxs=cxs,
                )

                obs, hxs, cxs = next_obs, next_hxs, next_cxs
                global_step += cfg.num_envs

            # Bootstrap value at last obs
            with torch.no_grad():
                _, last_value, _, _ = policy.forward(obs, hxs, cxs)  # (N,)

            buffer.compute_gae(last_value=last_value,
                               gamma=cfg.gamma,
                               lam=cfg.gae_lambda)

            # ==== PPO UPDATE ====
            last_mb_returns = None
            for epoch in range(cfg.update_epochs):
                for mb in buffer.get_recurrent_minibatches(cfg.batch_envs):
                    mb_obs = mb["obs"]          # (T,B,obs)
                    mb_actions = mb["actions"]  # (T,B,1)
                    mb_returns = mb["returns"]  # (T,B)
                    mb_advantages = mb["advantages"]  # (T,B)
                    mb_logprobs_old = mb["logprobs"]  # (T,B)
                    mb_hxs = mb["hxs"]          # (B,H)
                    mb_cxs = mb["cxs"]          # (B,H)

                    # Normalize advantages per minibatch
                    adv = mb_advantages
                    adv = (adv - adv.mean()) / (adv.std() + 1e-8)

                    # Evaluate actions
                    new_logprobs, entropy, new_values = policy.evaluate_actions(
                        mb_obs, mb_hxs, mb_cxs, mb_actions
                    )

                    # PPO ratio
                    ratio = (new_logprobs - mb_logprobs_old).exp()

                    # Policy loss (clipped surrogate objective)
                    unclipped = ratio * adv
                    clipped = torch.clamp(ratio, 1.0 - cfg.clip_coef, 1.0 + cfg.clip_coef) * adv
                    policy_loss = -torch.min(unclipped, clipped).mean()

                    # Value loss
                    value_loss = F.mse_loss(new_values, mb_returns)

                    # Entropy bonus
                    entropy_loss = entropy.mean()

                    loss = policy_loss + cfg.vf_coef * value_loss - cfg.ent_coef * entropy_loss

                    optimizer.zero_grad()
                    loss.backward()
                    nn.utils.clip_grad_norm_(policy.parameters(), cfg.max_grad_norm)
                    optimizer.step()

                    last_mb_returns = mb_returns

            # Logging (replace with your logger).  Skipped when no minibatch
            # ran, e.g. cfg.update_epochs == 0.
            if last_mb_returns is not None:
                avg_return = last_mb_returns.mean().item()
                print(f"Step: {global_step}, ApproxReturn(last mb): {avg_return:.2f}")
    finally:
        venv.close()

In [9]:
import os
import math
import random
from dataclasses import dataclass

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import gymnasium as gym
import popgym
from gymnasium.vector import SyncVectorEnv

In [8]:
import torch
import numpy as np

class RecurrentVecEnvWrapper:
    """
    Wraps a vectorized Gymnasium-style environment and manages
    per-environment LSTM hidden states.

    The wrapped object only needs ``num_envs``, ``reset()`` returning
    ``(obs, info)`` and ``step(actions)`` returning
    ``(obs, rewards, terminated, truncated, info)``.

    NOTE: this definition shadows the ``RecurrentVecEnvWrapper`` imported from
    ``lstmppo.env`` earlier in the notebook, which ``PPOTrainer`` constructs
    as ``RecurrentVecEnvWrapper(cfg, venv)`` -- a different signature.

    Args:
        venv: the vectorized environment to wrap.
        hidden_size: width of each per-env hidden / cell state vector.
        device: torch device for all returned tensors and stored states.
        reset_on_truncation: when True (default) the hidden state of an env is
            zeroed whenever its episode ends, whether by termination or by
            truncation, so no memory leaks into the next episode.  Pass False
            to zero only on ``terminated`` (the previous behaviour).
    """

    def __init__(self, venv, hidden_size, device, reset_on_truncation=True):
        self.venv = venv
        self.num_envs = venv.num_envs
        self.hidden_size = hidden_size
        self.device = device
        self.reset_on_truncation = reset_on_truncation

        # Hidden states per environment, shape (num_envs, hidden_size)
        self.hxs = torch.zeros(self.num_envs, hidden_size, device=device)
        self.cxs = torch.zeros(self.num_envs, hidden_size, device=device)

    def reset(self):
        """Reset every env and zero every hidden state.

        Returns:
            (obs, info, hxs, cxs) where ``obs`` is a tensor on ``self.device``
            and ``hxs`` / ``cxs`` are copies of the (all-zero) stored states.
        """
        obs, info = self.venv.reset()

        # Reset all hidden states
        self.hxs.zero_()
        self.cxs.zero_()

        return torch.tensor(obs, device=self.device), info, self.hxs.clone(), self.cxs.clone()

    def step(self, actions):
        """
        Step the wrapped env and clear the memory of envs whose episode ended.

        actions: tensor of per-env actions; it is detached, moved to CPU and
            converted to a NumPy array before being passed to the wrapped env.
        Returns:
            obs: tensor built from the env's observations
            rewards: tensor built from the env's rewards
            terminated: (N,) bool tensor
            truncated: (N,) bool tensor
            info: passed through unchanged from the wrapped env
            hxs: (N, hidden_size) copy of the stored hidden states
            cxs: (N, hidden_size) copy of the stored cell states
        """
        # detach() so an action tensor that still carries grad can be converted.
        obs, rewards, terminated, truncated, info = self.venv.step(
            actions.detach().cpu().numpy()
        )

        terminated = torch.tensor(terminated, device=self.device, dtype=torch.bool)
        truncated = torch.tensor(truncated, device=self.device, dtype=torch.bool)

        # Truncation matters for value bootstrapping, but for the recurrent
        # memory any episode end means the next observation starts a new
        # episode, so the state must not carry over.
        if self.reset_on_truncation:
            episode_over = terminated | truncated
        else:
            episode_over = terminated
        if episode_over.any():
            self.hxs[episode_over] = 0
            self.cxs[episode_over] = 0

        return (
            torch.tensor(obs, device=self.device),
            torch.tensor(rewards, device=self.device),
            terminated,
            truncated,
            info,
            self.hxs.clone(),
            self.cxs.clone(),
        )

    def update_hidden_states(self, new_hxs, new_cxs):
        """
        Overwrite the stored states with the policy's latest ones.

        new_hxs, new_cxs: (N, hidden_size)

        Gotcha: this copies over *every* row, including rows that ``step()``
        has just zeroed.  Call it after the policy forward pass and before
        ``step()`` if the zeroing is meant to survive.
        """
        self.hxs.copy_(new_hxs)
        self.cxs.copy_(new_cxs)

In [None]:
import torch

class RecurrentRolloutBuffer:
    """Fixed-size (T, N) storage for recurrent PPO rollouts.

    Holds T timesteps for N parallel environments, including the LSTM hidden
    and cell state each step started from, and computes GAE advantages and
    returns over the stored rollout.

    NOTE: this definition shadows the ``RecurrentRolloutBuffer`` imported from
    ``lstmppo.buffer`` earlier in the notebook, which ``PPOTrainer``
    constructs as ``RecurrentRolloutBuffer(cfg)`` -- a different signature.

    Args:
        T: number of timesteps per rollout.
        N: number of parallel environments.
        obs_shape: shape of a single observation (iterable of ints).
        action_dim: size of the trailing action dimension that is stored.
        hidden_size: width of each hidden / cell state vector.
        device: torch device for all storage tensors.
    """

    def __init__(self, T, N, obs_shape, action_dim, hidden_size, device):
        self.T = T
        self.N = N
        self.device = device

        # Core storage
        self.obs = torch.zeros(T, N, *obs_shape, device=device)
        self.actions = torch.zeros(T, N, action_dim, device=device)
        self.rewards = torch.zeros(T, N, device=device)
        self.values = torch.zeros(T, N, device=device)
        self.logprobs = torch.zeros(T, N, device=device)

        # Episode termination logic
        self.terminated = torch.zeros(T, N, device=device, dtype=torch.bool)
        self.truncated = torch.zeros(T, N, device=device, dtype=torch.bool)

        # Hidden states at the *start* of each timestep
        # Shape: (T, N, hidden_size)
        self.hxs = torch.zeros(T, N, hidden_size, device=device)
        self.cxs = torch.zeros(T, N, hidden_size, device=device)

        # Index of the next timestep to be written
        self.step = 0

    def reset(self):
        """Rewind the write cursor so the next rollout overwrites the storage."""
        self.step = 0

    def add(self, obs, actions, rewards, values, logprobs,
            terminated, truncated, hxs, cxs):
        """
        Store one vectorised step at the current write position.

        obs: (N, *obs_shape)
        actions: (N, action_dim)
        rewards, values, logprobs, terminated, truncated: (N,)
        hxs, cxs: (N, hidden_size)

        Raises:
            IndexError: if T steps have already been stored since the last
                ``reset()``.
        """
        t = self.step
        if t >= self.T:
            raise IndexError(
                f"rollout buffer is full ({self.T} steps); "
                "call reset() before adding more"
            )

        self.obs[t].copy_(obs)
        self.actions[t].copy_(actions)
        self.rewards[t].copy_(rewards)
        self.values[t].copy_(values)
        self.logprobs[t].copy_(logprobs)
        self.terminated[t].copy_(terminated)
        self.truncated[t].copy_(truncated)
        self.hxs[t].copy_(hxs)
        self.cxs[t].copy_(cxs)

        self.step += 1

    def compute_gae(self, last_value, gamma=0.99, lam=0.95):
        """
        Compute GAE advantages and returns over all T stored steps.

        last_value: (N,) value estimate for the observation after the last
            stored step.

        Results are stored on ``self.advantages`` and ``self.returns``,
        both of shape (T, N).
        """
        T, N = self.T, self.N
        mask_dtype = self.values.dtype

        advantages = torch.zeros(T, N, device=self.device)
        last_gae = torch.zeros(N, device=self.device)

        for t in reversed(range(T)):
            terminated_t = self.terminated[t]
            episode_over = terminated_t | self.truncated[t]

            # True terminal: no value bootstrap.
            # Time-limit truncation: DO bootstrap.
            # NOTE(review): at a truncated step the stored values[t + 1] may
            # already belong to the next episode if the env auto-resets; that
            # approximation is left unchanged here.
            bootstrap = (~terminated_t).to(mask_dtype)

            # The advantage trace must never cross an episode boundary,
            # whether the episode ended by termination or by truncation.
            same_episode = (~episode_over).to(mask_dtype)

            next_value = last_value if t == T - 1 else self.values[t + 1]

            delta = (
                self.rewards[t]
                + gamma * next_value * bootstrap
                - self.values[t]
            )

            last_gae = delta + gamma * lam * last_gae * same_episode
            advantages[t] = last_gae

        self.advantages = advantages
        self.returns = advantages + self.values

    def get_recurrent_minibatches(self, batch_size):
        """
        Yield dicts covering random groups of ``batch_size`` environments.

        Sequence tensors have shape (seq_len=T, batch, ...).  ``hxs`` / ``cxs``
        are the states stored at timestep 0 only, shape (batch, hidden_size).
        ``compute_gae`` must have been called first, since ``returns`` and
        ``advantages`` are read here.
        """
        N = self.N
        env_indices = torch.randperm(N)

        for start in range(0, N, batch_size):
            idx = env_indices[start:start + batch_size]

            yield {
                "obs": self.obs[:, idx],
                "actions": self.actions[:, idx],
                "values": self.values[:, idx],
                "logprobs": self.logprobs[:, idx],
                "returns": self.returns[:, idx],
                "advantages": self.advantages[:, idx],
                "hxs": self.hxs[0, idx],  # initial hidden state
                "cxs": self.cxs[0, idx],
                "terminated": self.terminated[:, idx],
                "truncated": self.truncated[:, idx],
            }