In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import gymnasium as gym
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.distributions.categorical import Categorical

In [3]:
import sys
sys.argv = [""]

In [16]:
from cart_pole.utils import initialize
from cart_pole.agent import Agent
from numpy.random import Generator, MT19937, SeedSequence
from torch.utils.tensorboard import SummaryWriter
from cart_pole.environment import make_sync_vector_env
import pandas as pd
import pdb

In [12]:
args, writer = initialize(seconds_since_epoch=1764177877)
envs = make_sync_vector_env(args)

In [13]:
agent = Agent(envs).to(args.device)

optimizer = optim.Adam(agent.parameters(),
                       lr=args.learning_rate,
                       eps=1e-5)

In [19]:
class Storage(object):

    def __init__(self,
                 args,
                 envs):

        self.num_steps = args.num_steps
        self.num_envs = args.num_envs
        self.device = args.device

        self.reset(envs)

    def store(self,
              paramid,
              value):

        assert self.step < self.num_steps, "Storage is full"
        self.__dict__[paramid][self.step] = value

    def update_episode_info(self,
                            info):

        if all(key in info for key in ["episode", "_episode"]):

            select_row = info["_episode"]

            self.episode_info.loc[select_row, "cumulative_reward"] =\
                info["episode"]["r"][select_row]

            self.episode_info.loc[select_row, "episode_length"] =\
                info["episode"]["l"][select_row]

        self.step += 1

    def reset(self,
              envs):

        self.step = 0

        shape2d = (self.num_steps, self.num_envs) 
        shape3d = shape2d + envs.single_observation_space.shape
        
        self.obs = torch.zeros(shape3d).to(self.device)
        self.actions = torch.zeros(shape3d).to(self.device)

        self.logprobs = torch.zeros(shape2d).to(self.device)
        self.rewards = torch.zeros(shape2d).to(self.device)
        self.terminated = torch.zeros(shape2d).to(self.device)
        self.truncated = torch.zeros(shape2d).to(self.device)
        self.values = torch.zeros(shape2d).to(self.device)

        self.episode_info =\
            pd.DataFrame([{"cumulative_reward": 0, "episode_length": 0}] *\
                         self.num_envs)

storage = Storage(args, envs)

In [None]:
global_step = 0
start_time = time.time()

In [None]:
seed_seq = SeedSequence(args.seed)
rng = Generator(MT19937(seed_seq))

next_obs = torch.Tensor(envs.reset(seed=seeds)[0]).to(args.device)
next_terminated = torch.zeros(args.num_envs).to(args.device)
next_truncated = torch.zeros(args.num_envs).to(args.device)
num_updates = args.total_timesteps // args.batch_size
storage.reset(envs)

for _ in range(args.num_steps):

    storage.store("obs", next_obs)
    storage.store("terminated", next_terminated)
    storage.store("truncated", next_truncated)

    with torch.no_grad():
        action, logprob, _, value = agent.get_action_and_value(next_obs)
        storage.store("values", value.flatten())
        storage.store("actions", action)
        storage.store("logprobs", logprob)

    next_obs, reward, terminated, truncated, info =\
        envs.step(action.cpu().numpy())
    
    storage.store("rewards",
                  torch.tensor(reward).to(args.device).view(-1))

    next_obs = torch.Tensor(next_obs).to(args.device)
    next_terminated = torch.Tensor(terminated).to(args.device)
    next_truncated = torch.Tensor(truncated).to(args.device)

    storage.update_episode_info(info)
    if any(terminated):
        print("episode(s) terminated")
    pdb.set_trace()

> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  storage.step


13


ipdb>  storage.terminated[:15,:]


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


episode(s) terminated
> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  info


{'episode': {'r': array([ 0., 18.,  0.,  0.]), '_r': array([False,  True, False, False]), 'l': array([ 0, 18,  0,  0]), '_l': array([False,  True, False, False]), 't': array([ 0.      , 56.411231,  0.      ,  0.      ]), '_t': array([False,  True, False, False])}, '_episode': array([False,  True, False, False])}


ipdb>  storage.episode_info


   cumulative_reward  episode_length
0                  0               0
1                 18              18
2                  0               0
3                  0               0


ipdb>  next_terminated


tensor([0., 1., 0., 0.], device='cuda:0')


ipdb>  c


> [0;32m/tmp/ipykernel_2079/256065787.py[0m(10)[0;36m<module>[0;34m()[0m
[0;32m      8 [0;31m[0mstorage[0m[0;34m.[0m[0mreset[0m[0;34m([0m[0menvs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m      9 [0;31m[0;34m[0m[0m
[0m[0;32m---> 10 [0;31m[0;32mfor[0m [0m_[0m [0;32min[0m [0mrange[0m[0;34m([0m[0margs[0m[0;34m.[0m[0mnum_steps[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     11 [0;31m[0;34m[0m[0m
[0m[0;32m     12 [0;31m    [0mstorage[0m[0;34m.[0m[0mstore[0m[0;34m([0m[0;34m"obs"[0m[0;34m,[0m [0mnext_obs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  storage.terminated[:20,:]


tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')


In [21]:
storage.episode_info

Unnamed: 0,cumulative_reward,episode_length
0,10,10
1,15,15
2,14,14
3,13,13


In [24]:
storage.terminated[:15,0].argwhere()

tensor([[12]], device='cuda:0')