# Training DQN for CartPole

Based on [Deep Q-Network from Tianshou](https://tianshou.org/en/stable/01_tutorials/00_dqn.html).

In [1]:
# Enable IPython's autoreload extension; mode 2 picks up edits to imported
# modules live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Install dependencies

In [2]:
!pip install gymnasium==0.29.1 pygame==2.3.0 pettingzoo==1.24.3 tianshou==0.5.1

Collecting gymnasium==0.29.1
  Downloading gymnasium-0.29.1-py3-none-any.whl (953 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m953.9/953.9 kB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pygame==2.3.0
  Downloading pygame-2.3.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting pettingzoo==1.24.3
  Downloading pettingzoo-1.24.3-py3-none-any.whl (847 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m847.8/847.8 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting tianshou==0.5.1
  Downloading tianshou-0.5.1-py3-none-any.whl (163 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m163.1/163.1 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Collecting farama-notifications>=0.0.1 (from gymnasium==0.29.1)
  Downloading Farama_Notifications-0.0.4-py3-none-any.w

# Setup environment

In [25]:
import gymnasium as gym
import tianshou as ts

def get_env(render_mode=None):
    """Build a new CartPole-v1 environment.

    Args:
        render_mode: Forwarded unchanged to `gym.make` (default: None).

    Returns:
        The environment object returned by `gym.make`.
    """
    cartpole = gym.make("CartPole-v1", render_mode=render_mode)
    return cartpole


# Single reference environment; later cells read its observation and action
# spaces from it.
env = get_env()

In [26]:
# Each vector env receives a list of environment factories (the `get_env`
# function itself, not an environment instance): 10 for training, 100 for testing.
train_envs = ts.env.DummyVectorEnv([get_env] * 10)
test_envs = ts.env.DummyVectorEnv([get_env] * 100)

# Setup PyTorch Network

In [27]:
import torch
import numpy as np
from torch import nn


class Net(nn.Module):
    """Fully connected Q-network: flattened observation in, one value per action out.

    Args:
        state_shape: Observation shape (or a plain size); the product of its
            entries is the input width of the first layer.
        action_shape: Action shape (or number of discrete actions); the product
            of its entries is the output width of the last layer.
    """

    def __init__(self, state_shape, action_shape):
        super().__init__()
        # np.prod returns a NumPy scalar; cast to built-in int so the layer
        # sizes are plain Python integers.
        in_dim = int(np.prod(state_shape))
        out_dim = int(np.prod(action_shape))
        self.model = nn.Sequential(
            nn.Linear(in_dim, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
            nn.Linear(128, out_dim),
        )

    def forward(self, obs, state=None, info=None):
        """Compute per-action values for a batch of observations.

        Args:
            obs: Batch of observations (array-like or tensor) whose first
                dimension is the batch size; remaining dimensions are flattened.
            state: Passed through unchanged and returned as-is.
            info: Accepted for interface compatibility and unused. Defaults to
                None rather than a shared mutable dict.

        Returns:
            Tuple `(logits, state)` where `logits` has shape
            `(batch, prod(action_shape))`.
        """
        # as_tensor converts arrays AND casts tensors of another dtype
        # (e.g. float64) to float32, which the linear layers require.
        obs = torch.as_tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        # reshape (unlike view) also works for non-contiguous inputs.
        logits = self.model(obs.reshape(batch, -1))
        return logits, state


# For each space, use `.shape` when it is truthy and fall back to `.n` otherwise.
obs_space = env.observation_space
act_space = env.action_space
state_shape = obs_space.shape or obs_space.n
action_shape = act_space.shape or act_space.n

# Q-network sized from the environment, optimized with Adam.
net = Net(state_shape=state_shape, action_shape=action_shape)
optim = torch.optim.Adam(params=net.parameters(), lr=1e-3)

# Create Policy

In [28]:
# DQN hyperparameters gathered in one place; they are forwarded to DQNPolicy
# as keyword arguments.
dqn_kwargs = dict(
    discount_factor=0.9,
    estimation_step=3,
    target_update_freq=320,
)

policy = ts.policy.DQNPolicy(
    model=net,
    optim=optim,
    action_space=env.action_space,
    **dqn_kwargs,
)

In [29]:
# Replay buffer with 20000 total slots, split into one sub-buffer per training
# environment. The count is derived from `train_envs` so it cannot drift from
# the number of environments actually created.
train_buffer = ts.data.VectorReplayBuffer(20000, len(train_envs))

# The training collector stores transitions in the buffer; the test collector
# is created without one. Both enable exploration noise.
train_collector = ts.data.Collector(
    policy, train_envs, train_buffer, exploration_noise=True
)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise=True)

# Train the agent
Using Tianshou's `OffpolicyTrainer`.

In [31]:
def _train_hook(epoch, env_step):
    """Trainer `train_fn` hook: set the policy's epsilon to 0.1."""
    return policy.set_eps(0.1)


def _test_hook(epoch, env_step):
    """Trainer `test_fn` hook: set the policy's epsilon to 0.05."""
    return policy.set_eps(0.05)


def _should_stop(mean_rewards):
    """Trainer `stop_fn` hook: true once the mean reward reaches 500."""
    return mean_rewards >= 500


trainer = ts.trainer.OffpolicyTrainer(
    policy=policy,
    train_collector=train_collector,
    test_collector=test_collector,
    # Schedule and batch settings.
    max_epoch=10,
    step_per_epoch=50000,
    step_per_collect=2000,
    update_per_step=0.1,
    episode_per_test=100,
    batch_size=256,
    # Hooks defined above.
    train_fn=_train_hook,
    test_fn=_test_hook,
    stop_fn=_should_stop,
)
result = trainer.run()
print(f'Finished training! Use {result["duration"]}')

# Trailing expression so the notebook displays the full result.
result

Epoch #1: 50001it [00:27, 1792.95it/s, env_step=50000, len=165, loss=0.016, n/ep=12, n/st=2000, rew=165.92]                           


Epoch #1: test_reward: 129.220000 ± 5.369506, best_reward: 129.220000 ± 5.369506 in #1


Epoch #2: 50001it [00:33, 1512.59it/s, env_step=100000, len=472, loss=0.008, n/ep=8, n/st=2000, rew=472.75]                           


Epoch #2: test_reward: 500.000000 ± 0.000000, best_reward: 500.000000 ± 0.000000 in #2
Finished training! Use 63.38s


{'duration': '63.38s',
 'train_time/model': '45.43s',
 'test_step': 178156,
 'test_episode': 600,
 'test_time': '6.93s',
 'test_speed': '25693.27 step/s',
 'best_reward': 500.0,
 'best_result': '500.00 ± 0.00',
 'train_step': 100000,
 'train_episode': 1158,
 'train_time/collector': '11.02s',
 'train_speed': '1771.57 step/s'}

# Play

Run the trained agent for a number of episodes and print the mean reward.

In [32]:
policy.eval()

# Dedicated single-environment vector env for evaluation. Using its own name
# leaves the notebook-level `env` untouched, so earlier cells that read
# `env.observation_space` / `env.action_space` can still be re-run. The factory
# calls `get_env` directly instead of closing over a name that is rebound on
# the same line.
play_envs = ts.env.DummyVectorEnv([lambda: get_env(render_mode=None)])
collector = ts.data.Collector(policy, play_envs, exploration_noise=True)

# Collect 100 episodes without rendering; `result` is also read by the plot cell.
result = collector.collect(n_episode=100, render=None)
rews, lens = result["rews"], result["lens"]

display(rews.mean())

490.97

# Plot Result

In [33]:
import plotly.figure_factory as ff

# Distribution plot of the per-episode rewards from the play run, as a single
# group labelled 'reward'.
reward_groups = [result['rews']]
fig = ff.create_distplot(reward_groups, ['reward'])
fig.update_layout(title_text='CartPole DQN Result')
fig.show()