In [14]:
import gymnasium as gym
import tianshou as ts

env = gym.make("CartPole-v1")

## Environments

In [15]:
train_envs = gym.make("CartPole-v1")
test_envs = gym.make("CartPole-v1")

In [16]:
train_envs = ts.env.DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make("CartPole-v1") for _ in range(100)])

## Creat the PyTorch Network

In [17]:
import torch
import numpy as np
from torch import nn


class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, np.prod(action_shape)),
        )

    def forward(self, obs, state=None, info={}):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits, state


state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = Net(state_shape, action_shape)
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

## Create the Policy

In [18]:
policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    action_space=env.action_space,
    discount_factor=0.9,
    estimation_step=3,
    target_update_freq=320,
)

In [19]:
train_collector = ts.data.Collector(
    policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise=True
)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

## Train

In [24]:
result = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    max_epoch=1,
    step_per_epoch=10000,
    step_per_collect=10,
    update_per_step=0.1,
    episode_per_test=100,
    batch_size=64,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold,
).run()
print(f'Finished training! Use {result["duration"]}')

result

Epoch #1: 10001it [01:45, 94.47it/s, env_step=10000, len=218, loss=0.304, n/ep=0, n/st=10, rew=218.00]                           


Epoch #1: test_reward: 217.220000 ± 56.067920, best_reward: 246.240000 ± 79.688534 in #0
Finished training! Use 107.56s


{'duration': '107.56s',
 'train_time/model': '71.23s',
 'test_step': 46346,
 'test_episode': 200,
 'test_time': '1.67s',
 'test_speed': '27679.65 step/s',
 'best_reward': 246.24,
 'best_result': '246.24 ± 79.69',
 'train_step': 10000,
 'train_episode': 50,
 'train_time/collector': '34.66s',
 'train_speed': '94.44 step/s'}

## Test Policy

In [25]:
policy.eval()
policy.set_eps(0.05)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)

  gym.logger.warn(


{'n/ep': 1,
 'n/st': 197,
 'rews': array([197.]),
 'lens': array([197]),
 'idxs': array([0]),
 'rew': 197.0,
 'len': 197.0,
 'rew_std': 0.0,
 'len_std': 0.0}

In [26]:
torch.save(policy.state_dict(), 'cartpole_dqn.pth')