In [1]:
import gym
import tianshou as ts

import torch




In [2]:
env = gym.make('Acrobot-v1')

In [3]:
train_envs = ts.env.DummyVectorEnv([lambda: gym.make('Acrobot-v1') for _ in range(10)])
test_envs = ts.env.DummyVectorEnv([lambda: gym.make('Acrobot-v1') for _ in range(100)])

In [4]:
state_shape = env.observation_space.shape or env.observation_space.n
action_shape = env.action_space.shape or env.action_space.n
net = ts.utils.net.common.Net(state_shape, action_shape, activation='ReLu')
optim = torch.optim.Adam(net.parameters(), lr=1e-3)

In [5]:
policy = ts.policy.DQNPolicy(net, optim, discount_factor = 0.9, estimation_step = 3, target_update_freq = 320)

In [6]:
train_collector = ts.data.Collector(policy, train_envs, ts.data.VectorReplayBuffer(20000, 10), exploration_noise = True)
test_collector = ts.data.Collector(policy, test_envs, exploration_noise = True)

In [7]:
result = ts.trainer.offpolicy_trainer(
    policy, train_collector, test_collector,
    max_epoch=10, step_per_epoch=10000, step_per_collect=10,
    update_per_step=0.1, episode_per_test=100, batch_size=64,
    train_fn=lambda epoch, env_step: policy.set_eps(0.1),
    test_fn=lambda epoch, env_step: policy.set_eps(0.05),
    stop_fn=lambda mean_rewards: mean_rewards >= env.spec.reward_threshold)
print(f'Finished training! Use {result["duration"]}')

Epoch #1: 10001it [00:06, 1449.29it/s, env_step=10000, len=500, loss=3.613, n/ep=4, n/st=10, rew=-500.00]                           


Epoch #1: test_reward: -500.000000 ± 0.000000, best_reward: -500.000000 ± 0.000000 in #0


Epoch #2: 10001it [00:06, 1487.00it/s, env_step=20000, len=473, loss=3.087, n/ep=0, n/st=10, rew=-472.00]                           


Epoch #2: test_reward: -258.320000 ± 104.503003, best_reward: -258.320000 ± 104.503003 in #2


Epoch #3: 10001it [00:19, 508.24it/s, env_step=30000, len=189, loss=3.680, n/ep=0, n/st=10, rew=-188.00]                            


Epoch #3: test_reward: -167.810000 ± 39.565817, best_reward: -167.810000 ± 39.565817 in #3


Epoch #4: 10001it [00:06, 1478.81it/s, env_step=40000, len=148, loss=4.268, n/ep=0, n/st=10, rew=-147.00]                           


Epoch #4: test_reward: -216.690000 ± 62.779885, best_reward: -167.810000 ± 39.565817 in #3


Epoch #5: 10001it [00:07, 1322.79it/s, env_step=50000, len=188, loss=3.156, n/ep=0, n/st=10, rew=-187.00]                           


Epoch #5: test_reward: -416.500000 ± 93.917251, best_reward: -167.810000 ± 39.565817 in #3


Epoch #6: 10001it [00:29, 336.80it/s, env_step=60000, len=178, loss=2.369, n/ep=0, n/st=10, rew=-177.00]                            


Epoch #6: test_reward: -244.810000 ± 84.784868, best_reward: -167.810000 ± 39.565817 in #3


Epoch #7: 10001it [00:06, 1430.06it/s, env_step=70000, len=172, loss=1.952, n/ep=0, n/st=10, rew=-171.00]                           


Epoch #7: test_reward: -235.510000 ± 70.430035, best_reward: -167.810000 ± 39.565817 in #3


Epoch #8: 10001it [00:06, 1507.21it/s, env_step=80000, len=210, loss=1.461, n/ep=0, n/st=10, rew=-209.00]                           


Epoch #8: test_reward: -156.470000 ± 31.354890, best_reward: -156.470000 ± 31.354890 in #8


Epoch #9: 10001it [00:06, 1567.58it/s, env_step=90000, len=197, loss=1.339, n/ep=0, n/st=10, rew=-196.00]                           


Epoch #9: test_reward: -193.200000 ± 69.508704, best_reward: -156.470000 ± 31.354890 in #8


Epoch #10: 10001it [00:06, 1553.69it/s, env_step=100000, len=224, loss=1.075, n/ep=0, n/st=10, rew=-223.00]                           


Epoch #10: test_reward: -228.890000 ± 58.131901, best_reward: -156.470000 ± 31.354890 in #8
Finished training! Use 165.72s


In [8]:
from torch.utils.tensorboard import SummaryWriter
from tianshou.utils import BasicLogger

writer = SummaryWriter('log/dqn_Acrobot-v1')
logger = BasicLogger(writer)



In [9]:
torch.save(policy.state_dict(), 'dqn.pth')
policy.load_state_dict(torch.load('dqn.pth'))

<All keys matched successfully>

In [10]:
policy.eval()
policy.set_eps(0.07)
collector = ts.data.Collector(policy, env, exploration_noise=True)
collector.collect(n_episode=1, render=1 / 35)



{'n/ep': 1,
 'n/st': 238,
 'rews': array([-237.]),
 'lens': array([238]),
 'idxs': array([0]),
 'rew': -237.0,
 'len': 238.0,
 'rew_std': 0.0,
 'len_std': 0.0}