Using a DRL algorithm (PPO) to solve the classic CartPole-v1 problem in Gym.

Building blocks may include:
* batch
* replay buffer
* vectorized environment wrapper
* data collector
* trainer
* logger

In [1]:
# Install the notebook's dependencies (tianshou and gym) through uv's pip interface.
# NOTE(review): the code below imports `gymnasium`, which the install log shows being
# pulled in alongside tianshou — confirm whether the explicit `gym` package is needed.
!uv pip install tianshou gym

[2mResolved [1m51 packages[0m [2min 886ms[0m[0m
[2mPrepared [1m6 packages[0m [2min 4.61s[0m[0m
[2mUninstalled [1m1 package[0m [2min 5ms[0m[0m
[2mInstalled [1m42 packages[0m [2min 6.02s[0m[0m
 [32m+[39m [1mabsl-py[0m[2m==2.2.2[0m
 [32m+[39m [1mcloudpickle[0m[2m==3.1.1[0m
 [32m+[39m [1mcontourpy[0m[2m==1.3.2[0m
 [32m+[39m [1mcycler[0m[2m==0.12.1[0m
 [32m+[39m [1mdeepdiff[0m[2m==7.0.1[0m
 [32m+[39m [1mdistlib[0m[2m==0.3.9[0m
 [32m+[39m [1mfarama-notifications[0m[2m==0.0.4[0m
 [32m+[39m [1mfilelock[0m[2m==3.18.0[0m
 [32m+[39m [1mfonttools[0m[2m==4.57.0[0m
 [32m+[39m [1mfsspec[0m[2m==2025.3.2[0m
 [32m+[39m [1mgrpcio[0m[2m==1.71.0[0m
 [32m+[39m [1mgym[0m[2m==0.26.2[0m
 [32m+[39m [1mgym-notices[0m[2m==0.0.8[0m
 [32m+[39m [1mgymnasium[0m[2m==0.28.1[0m
 [32m+[39m [1mh5py[0m[2m==3.13.0[0m
 [32m+[39m [1mjax-jumpy[0m[2m==1.0.0[0m
 [32m+[39m [1mkiwisolver[0m[2m==1.4.8[0m
 [

In [5]:
%%capture
import gymnasium as gym
import torch

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import DummyVectorEnv
from tianshou.policy import PPOPolicy
from tianshou.trainer import OnpolicyTrainer
from tianshou.utils.net.common import ActorCritic, Net
from tianshou.utils.net.discrete import Actor, Critic

# Use the GPU when torch reports one as available; otherwise stay on the CPU.
device = "cpu" if not torch.cuda.is_available() else "cuda"

In [9]:
# Environments: one standalone env (used to read the observation/action spaces)
# plus vectorized sets of 20 training and 10 test environments.


def _make_cartpole():
    """Create one fresh CartPole-v1 environment."""
    return gym.make("CartPole-v1")


env = _make_cartpole()
train_envs = DummyVectorEnv([_make_cartpole] * 20)
test_envs = DummyVectorEnv([_make_cartpole] * 10)

# Model and optimizer
assert env.observation_space.shape is not None  # for mypy
assert isinstance(env.action_space, gym.spaces.Discrete)  # for mypy

# The same network (hidden sizes [64, 64]) is passed to both the actor and the
# critic as their preprocessing net.
net = Net(
    state_shape=env.observation_space.shape,
    hidden_sizes=[64, 64],
    device=device,
)
actor = Actor(
    preprocess_net=net,
    action_shape=env.action_space.n,
    device=device,
).to(device)
critic = Critic(preprocess_net=net, device=device).to(device)

# A single Adam optimizer is built over the parameters exposed by the
# ActorCritic wrapper around both heads.
actor_critic = ActorCritic(actor, critic)
optim = torch.optim.Adam(actor_critic.parameters(), lr=3e-4)


# PPO policy
# The action space was asserted to be Discrete above, and a categorical
# distribution class is handed to the policy as its dist_fn.
dist = torch.distributions.Categorical
policy: PPOPolicy = PPOPolicy(
    actor=actor,
    critic=critic,
    optim=optim,
    action_space=env.action_space,
    dist_fn=dist,
    action_scaling=False,
)

# Collectors
# The training collector gets a vector replay buffer of total size 20000 whose
# buffer count equals the number of training environments; the test collector
# is created without an explicit buffer.
train_buffer = VectorReplayBuffer(total_size=20000, buffer_num=len(train_envs))
train_collector = Collector(policy, train_envs, train_buffer)
test_collector = Collector(policy, test_envs)

# Trainer
def _reward_goal_reached(mean_reward):
    """Return True once the mean reward handed in by the trainer is at least 195."""
    # NOTE(review): 195 is the threshold usually quoted for CartPole-v0 — confirm
    # it is the intended stopping target for CartPole-v1.
    return mean_reward >= 195


trainer = OnpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=10,
    step_per_epoch=50000,
    step_per_collect=2000,
    repeat_per_collect=10,
    episode_per_test=10,
    batch_size=256,
    stop_fn=_reward_goal_reached,
)
train_result = trainer.run()


Epoch #1: 50001it [00:17, 2910.45it/s, env_step=50000, gradient_step=200, len=118, n/ep=23, n/st=2000, rew=118.13]


Epoch #1: test_reward: 171.200000 ± 107.831164, best_reward: 171.200000 ± 107.831164 in #1


Epoch #2: 50001it [00:16, 3012.57it/s, env_step=100000, gradient_step=400, len=96, n/ep=8, n/st=2000, rew=96.00]


Epoch #2: test_reward: 212.200000 ± 141.184843, best_reward: 212.200000 ± 141.184843 in #2


In [10]:
# Pretty-print the statistics object returned by the trainer's run() call.
train_result.pprint_asdict()

InfoStats
----------------------------------------
{   'best_reward': 212.2,
    'best_reward_std': 141.1848433791673,
    'gradient_step': 400,
    'test_episode': 30,
    'test_step': 4029,
    'timing': {   'test_time': 1.4898450374603271,
                  'total_time': 35.28361105918884,
                  'train_time': 33.793766021728516,
                  'train_time_collect': 18.464709043502808,
                  'train_time_update': 15.129051923751831,
                  'update_speed': 2959.1256545867836},
    'train_episode': 1425,
    'train_step': 100000}


In [11]:
# Performance check: collect 3 episodes with the test collector (rendering off)
# and report the mean return and mean episode length.
policy.eval()
eval_result = test_collector.collect(n_episode=3, render=False)
mean_return = eval_result.returns.mean()
mean_length = eval_result.lens.mean()
print(f"Final reward: {mean_return}, length: {mean_length}")




Final reward: 332.3333333333333, length: 332.3333333333333


# batch

* Batch is the most basic data structure in Tianshou.
* It is like a numpy version of a Python dictionary, similar to PyTorch's TensorDict but with a different type structure.
* In DRL you need to handle a lot of dictionary-format data: most algorithms require storing state, action and reward data for every step of interaction with the environment.
* All of these can be organized as dictionaries, and the Batch class helps unify the interface of a diverse set of algorithms.
* Batch supports advanced indexing, concatenation, splitting and formatted printing, just like a numpy array, which has proved helpful for developers.

In [17]:
# %%capture

import pickle

import numpy as np
import torch

from tianshou.data import Batch

# A Batch can be created from a plain Python dictionary
print("========================================")
batch1 = Batch(dict(a=[4, 4], b=(5, 5)))
print(batch1)

# Keyword arguments build a Batch equivalent to batch1
print("========================================")
batch2 = Batch(a=[4, 4], b=(5, 5))
print(batch2)

# Nested dictionaries are turned into nested Batch objects
print("========================================")
nested_obs = {"rgb_obs": np.zeros((3, 3)), "flatten_obs": np.ones(5)}
data = {"action": np.array([1.0, 2.0, 3.0]), "reward": 3.66, "obs": nested_obs}

batch3 = Batch(data, extra="extra_string")
print(batch3)
# batch3.obs is itself a Batch
print(type(batch3.obs))
print(batch3.obs.rgb_obs)

# A list of dictionaries/Batches is concatenated/stacked automatically, which is
# convenient when collecting data from parallelized environments.
print("========================================")
batch4 = Batch([data for _ in range(3)])
print(batch4)
print(batch4.obs.rgb_obs.shape)

Batch(
    a: array([4, 4]),
    b: array([5, 5]),
)
Batch(
    a: array([4, 4]),
    b: array([5, 5]),
)
Batch(
    action: array([1., 2., 3.]),
    reward: array(3.66),
    obs: Batch(
             rgb_obs: array([[0., 0., 0.],
                             [0., 0., 0.],
                             [0., 0., 0.]]),
             flatten_obs: array([1., 1., 1., 1., 1.]),
         ),
    extra: 'extra_string',
)
<class 'tianshou.data.batch.Batch'>
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]
Batch(
    obs: Batch(
             flatten_obs: array([[1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1.],
                                 [1., 1., 1., 1., 1.]]),
             rgb_obs: array([[[0., 0., 0.],
                              [0., 0., 0.],
                              [0., 0., 0.]],
                      
                             [[0., 0., 0.],
                              [0., 0., 0.],
                              [0., 0., 0.]],
                      
           

In [18]:
# Accessing data
# Key-value pairs in a Batch can be looked up or changed much like in a dictionary

batch1 = Batch(a=[4, 4], b=(5, 5))
print(batch1)

# Add a nested entry under "c" and remove the entry "a"
print("========================================")
batch1.c = Batch(c1=np.arange(3), c2=False)
del batch1.a
print(batch1)

# Item access and attribute access return the very same object
print("========================================")
assert batch1["c"] is batch1.c
print("c" in batch1)

# Walk over every key/value pair of the Batch
print("========================================")
for key, value in batch1.items():
    print(str(key), str(value), sep=": ")

Batch(
    a: array([4, 4]),
    b: array([5, 5]),
)
Batch(
    b: array([5, 5]),
    c: Batch(
           c1: array([0, 1, 2]),
           c2: array(False),
       ),
)
True
b: [5 5]
c: Batch(
    c1: array([0, 1, 2]),
    c2: array(False),
)


In [19]:
# Indexing and slicing
# When the entries of a Batch share the same shape along a dimension, the Batch
# supports array-like indexing and slicing.


def _fake_step_output():
    """Build one made-up step result with a random action and a random done flag."""
    return {
        "act": np.random.randint(10),
        "rew": 0.0,
        "obs": np.ones((3, 3)),
        "info": {"done": np.random.choice(2), "failed": False},
        "terminated": False,
        "truncated": False,
    }


# Pretend the data was collected by stepping 4 environments
step_outputs = [_fake_step_output() for _ in range(4)]
batch = Batch(step_outputs)
print(batch)
print(batch.shape)

# Advanced indexing selects the data of a given subset of environments
print("========================================")
print(batch[0])
print(batch[[0, 3]])

# Slices work as well
print("========================================")
print(batch[-2:])

Batch(
    obs: array([[[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],
         
                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],
         
                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]],
         
                [[1., 1., 1.],
                 [1., 1., 1.],
                 [1., 1., 1.]]]),
    info: Batch(
              done: array([1, 1, 0, 1]),
              failed: array([False, False, False, False]),
          ),
    act: array([2, 5, 0, 4]),
    rew: array([0., 0., 0., 0.]),
    truncated: array([False, False, False, False]),
    terminated: array([False, False, False, False]),
)
[4]
Batch(
    obs: array([[1., 1., 1.],
                [1., 1., 1.],
                [1., 1., 1.]]),
    info: Batch(
              done: 1,
              failed: False,
          ),
    act: 2,
    rew: 0.0,
    truncated: False,
    terminated: False,
)
Batch(
    obs: array([[[1

In [20]:
# Aggregation and splitting
# Concatenate batches that have compatible keys
# (try incompatible keys yourself if you feel curious)
print("========================================")
b1 = Batch(a=[{"b": np.float64(1.0), "d": Batch(e=np.array(3.0))}])
b2 = Batch(a=[{"b": np.float64(4.0), "d": {"e": np.array(6.0)}}])
b12_cat_out = Batch.cat([b1, b2])
for shown in (b1, b2, b12_cat_out):
    print(shown)

# Stack batches that have compatible keys
# (try incompatible keys yourself if you feel curious)
print("========================================")
b3 = Batch(a=np.zeros((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[1], [2]]))
b4 = Batch(a=np.ones((3, 2)), b=np.ones((2, 3)), c=Batch(d=[[0], [3]]))
b34_stack = Batch.stack((b3, b4), axis=1)
for shown in (b3, b4, b34_stack):
    print(shown)

# Split the batch into pieces of size 1; with shuffle=True the original order
# of the data is not kept
print("========================================")
print(type(b34_stack.split(1)))
print([piece for piece in b34_stack.split(1, shuffle=True)])

Batch(
    a: Batch(
           d: Batch(
                  e: array([3.]),
              ),
           b: array([1.]),
       ),
)
Batch(
    a: Batch(
           d: Batch(
                  e: array([6.]),
              ),
           b: array([4.]),
       ),
)
Batch(
    a: Batch(
           d: Batch(
                  e: array([3., 6.]),
              ),
           b: array([1., 4.]),
       ),
)
Batch(
    a: array([[0., 0.],
              [0., 0.],
              [0., 0.]]),
    b: array([[1., 1., 1.],
              [1., 1., 1.]]),
    c: Batch(
           d: array([[1],
                     [2]]),
       ),
)
Batch(
    a: array([[1., 1.],
              [1., 1.],
              [1., 1.]]),
    b: array([[1., 1., 1.],
              [1., 1., 1.]]),
    c: Batch(
           d: array([[0],
                     [3]]),
       ),
)
Batch(
    c: Batch(
           d: array([[[1],
                      [0]],
              
                     [[2],
                      [3]]]),
       ),


In [23]:
# Batch also accepts torch tensors; the usage is exactly the same
batch1 = Batch(a=np.arange(2), b=torch.zeros((2, 2)))
batch2 = Batch(a=np.arange(2), b=torch.ones((2, 2)))
batch_cat = Batch.cat([batch1, batch2, batch1])
print(batch_cat)

# When a mix of numpy and torch data is no longer wanted, convert everything in
# place: first to numpy, then to torch, printing the Batch after each conversion
for convert_in_place in (batch_cat.to_numpy_, batch_cat.to_torch_):
    convert_in_place()
    print(batch_cat)

# A Batch is serializable, e.g. to save it to disk and restore it later;
# here it goes through a pickle round trip in memory
batch = Batch(obs=Batch(a=0.0, c=torch.tensor([1.0, 2.0])), np=np.zeros((3, 4)))
serialized = pickle.dumps(batch)
batch_pk = pickle.loads(serialized)
print(batch_pk)

Batch(
    b: tensor([[0., 0.],
               [0., 0.],
               [1., 1.],
               [1., 1.],
               [0., 0.],
               [0., 0.]]),
    a: array([0, 1, 0, 1, 0, 1]),
)
Batch(
    b: array([[0., 0.],
              [0., 0.],
              [1., 1.],
              [1., 1.],
              [0., 0.],
              [0., 0.]], dtype=float32),
    a: array([0, 1, 0, 1, 0, 1]),
)
Batch(
    b: tensor([[0., 0.],
               [0., 0.],
               [1., 1.],
               [1., 1.],
               [0., 0.],
               [0., 0.]]),
    a: tensor([0, 1, 0, 1, 0, 1], dtype=torch.int32),
)
Batch(
    obs: Batch(
             a: array(0.),
             c: tensor([1., 2.]),
         ),
    np: array([[0., 0., 0., 0.],
               [0., 0., 0., 0.],
               [0., 0., 0., 0.]]),
)
