In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from lstmppo.init import initialize
params = initialize(seconds_since_epoch=1766683383)


╭─ Unrecognized options ───────────────────────────────────────────────────────╮
│ Unrecognized options: --f=/run/user/1000/jupyter/runtime/kernel-v3c99d292aac │
│ 18ff8d8f51217662b34f1f2b0f57f2.json                                          │
│ ──────────────────────────────────────────────────────────────────────────── │
│ For full helptext, run /home/mspcvsp/miniconda/envs/cage2_env/lib/python3.10 │
│ /site-packages/ipykernel_launcher.py --help                                  │
╰──────────────────────────────────────────────────────────────────────────────╯


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [None]:
import torch

class RecurrentRolloutBuffer:
    def __init__(self, T, N, obs_shape, action_dim, hidden_size, device):
        self.T = T
        self.N = N
        self.device = device

        # Core storage
        self.obs = torch.zeros(T, N, *obs_shape, device=device)
        self.actions = torch.zeros(T, N, action_dim, device=device)
        self.rewards = torch.zeros(T, N, device=device)
        self.values = torch.zeros(T, N, device=device)
        self.logprobs = torch.zeros(T, N, device=device)

        # Episode termination logic
        self.terminated = torch.zeros(T, N, device=device, dtype=torch.bool)
        self.truncated = torch.zeros(T, N, device=device, dtype=torch.bool)

        # Hidden states at the *start* of each timestep
        # Shape: (T, N, num_layers, hidden_size)
        self.hxs = torch.zeros(T, N, hidden_size, device=device)
        self.cxs = torch.zeros(T, N, hidden_size, device=device)

        # Filled index
        self.step = 0

    def add(self, obs, actions, rewards, values, logprobs,
            terminated, truncated, hxs, cxs):
        """
        obs: (N, obs_dim)
        hxs, cxs: (N, hidden_size)
        """
        t = self.step
        self.obs[t].copy_(obs)
        self.actions[t].copy_(actions)
        self.rewards[t].copy_(rewards)
        self.values[t].copy_(values)
        self.logprobs[t].copy_(logprobs)
        self.terminated[t].copy_(terminated)
        self.truncated[t].copy_(truncated)
        self.hxs[t].copy_(hxs)
        self.cxs[t].copy_(cxs)

        self.step += 1

    def compute_gae(self, last_value, gamma=0.99, lam=0.95):
        """
        last_value: (N,)
        """
        T, N = self.T, self.N

        advantages = torch.zeros(T, N, device=self.device)
        last_gae = torch.zeros(N, device=self.device)

        for t in reversed(range(T)):
            # True terminal: no bootstrap
            true_terminal = self.terminated[t]

            # Time-limit truncation: DO bootstrap
            bootstrap = ~true_terminal

            next_value = last_value if t == T - 1 else self.values[t + 1]

            delta = (
                self.rewards[t]
                + gamma * next_value * bootstrap
                - self.values[t]
            )

            last_gae = delta + gamma * lam * last_gae * bootstrap
            advantages[t] = last_gae

        returns = advantages + self.values
        self.advantages = advantages
        self.returns = returns

    def get_recurrent_minibatches(self, batch_size):
        """
        Returns sequences of shape:
        (seq_len=T, batch_size, ...)
        """
        N = self.N
        env_indices = torch.randperm(N)

        for start in range(0, N, batch_size):
            idx = env_indices[start:start + batch_size]

            yield {
                "obs": self.obs[:, idx],
                "actions": self.actions[:, idx],
                "values": self.values[:, idx],
                "logprobs": self.logprobs[:, idx],
                "returns": self.returns[:, idx],
                "advantages": self.advantages[:, idx],
                "hxs": self.hxs[0, idx],  # initial hidden state
                "cxs": self.cxs[0, idx],
                "terminated": self.terminated[:, idx],
                "truncated": self.truncated[:, idx],
            }