In [10]:
%load_ext autoreload
%autoreload 1
%aimport lib

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [11]:
import collections
import random
import time
import math

import torch
import numpy as np
import gym
import matplotlib.pyplot as plt
%matplotlib notebook

In [16]:
env = gym.make("CartPole-v0").unwrapped

n_middle_feature = 100
n_middle_advantage = 50
n_middle_value = 50
n_input = env.observation_space.shape[0]
n_output = env.action_space.n

model_type = "dueling"
#model_type = "single"
if model_type == "dueling":
    namer = lib.make_namer()
    feater_output = torch.nn.Linear(n_middle_feature, n_middle_feature)
    feature = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_input, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), feater_output),
        (namer("ac"), torch.nn.Tanh()),
    ]))
    namer = lib.make_namer()
    advantage = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_advantage)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_advantage, n_middle_advantage)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_advantage, n_output)),
        (namer("mean0"), lib.Mean0()),
    ]))
    namer = lib.make_namer()
    value = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_value)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_value, n_middle_value)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_value, 1)),
    ]))
    model = lib.Model(feature, value, advantage)
elif model_type == "single":
    namer = lib.make_namer()
    model = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_input, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_output)),
    ]))
else:
    raise ValueError(f"Unsupported model_type: {model_type}")
model.apply(lib.init_model)

opt = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)
#opt = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-3)

n_replay_memory = 1_000_000
n_batch = 31

rm = lib.ReplayMemory(capacity=n_replay_memory, random_state=42)
loss = torch.nn.SmoothL1Loss()
#loss = torch.nn.MSELoss()
agent = lib.Agent(model=model, opt=opt, gamma=0.99, replay_memory=rm, n_batch=n_batch, cuda=False, alpha=0.5, loss=loss, dqn_mode="doubledqn", td_mode="mnih2015")

n_episodes = 1000
n_steps = 300
n_steps_start = 300
n_step_update = 50
n_step_train = 1
epsilon = 0.05

i_total_step = 1
for i_episode in range(1, n_episodes + 1):
    si = env.reset()
    r_episode = 0
    for i_step in range(1, n_steps + 1):
        if i_total_step <= n_steps_start:
            ai1 = env.action_space.sample()
        elif random.random() < epsilon:
            ai1 = env.action_space.sample()
        else:
            ai1 = agent.act(si)
        si1, ri1, done, _ = env.step(ai1)
        r_episode += ri1
        agent.push(state=si, action=ai1, reward_next=ri1, state_next=si1, done=done)
        if (i_total_step%n_step_train == 0) and (i_total_step > n_steps_start):
            loss = agent.train()
            #print(loss["loss"].data.numpy()[0], np.std(loss["td"].data.numpy()))
        if i_total_step%n_step_update == 0:
            agent.update_target_model()
        if done:
            break
        si = si1
        i_total_step += 1
    if i_episode%4 == 0:
        print(i_episode, i_total_step, r_episode)
        

# env.reset()
# for i in range(1000):
#     env.render()
#     a = env.action_space.sample()
#     r = env.step(a)
#     print(a, r[1:])
#     time.sleep(0.25)

4 53 13.0
8 130 15.0
12 201 20.0
16 306 25.0
20 390 10.0
24 426 10.0
28 463 12.0
32 498 10.0
36 533 10.0
40 566 8.0
44 614 14.0
48 762 65.0
52 1462 149.0
56 1827 57.0
60 1971 38.0
64 2119 39.0
68 2313 62.0
72 2623 104.0
76 2992 108.0
80 3369 117.0
84 3621 62.0
88 3971 86.0
92 4256 31.0
96 4342 21.0
100 4412 20.0
104 4501 26.0
108 4597 22.0
112 4678 20.0
116 4765 22.0
120 4843 17.0
124 4963 35.0
128 5053 14.0
132 5182 16.0
136 5315 76.0
140 5549 89.0
144 5873 46.0
148 6270 126.0
152 6806 137.0
156 7597 300.0
160 8353 209.0
164 8869 106.0
168 9258 109.0
172 9549 66.0
176 9927 95.0
180 10350 153.0
184 11180 300.0
188 12380 300.0
192 13580 300.0
196 14505 300.0
200 15113 136.0
204 16313 300.0
208 16593 104.0
212 16759 54.0
216 16955 47.0
220 17100 49.0
224 17345 54.0
228 17823 83.0
232 18276 67.0
236 18739 58.0
240 19140 163.0
244 19677 86.0
248 20216 83.0
252 21017 113.0
256 21549 135.0
260 21995 86.0
264 22595 169.0
268 23330 300.0
272 23898 147.0
276 24448 59.0
280 24654 41.0
284 24884 