In [1]:
%load_ext autoreload
%autoreload 1
%aimport lib_prm

In [10]:
import collections
import random
import time
import math

import torch
import numpy as np
import gym
import matplotlib.pyplot as plt
%matplotlib notebook

In [18]:
env = gym.make("CartPole-v0").unwrapped

n_middle_feature = 100
n_middle_advantage = 50
n_middle_value = 50
n_input = env.observation_space.shape[0]
n_output = env.action_space.n

model_type = "dueling"
#model_type = "single"
if model_type == "dueling":
    namer = lib_prm.make_namer()
    feater_output = torch.nn.Linear(n_middle_feature, n_middle_feature)
    feature = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_input, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), feater_output),
        (namer("ac"), torch.nn.Tanh()),
    ]))
    namer = lib_prm.make_namer()
    advantage = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_advantage)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_advantage, n_middle_advantage)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_advantage, n_output)),
        (namer("mean0"), lib_prm.Mean0()),
    ]))
    namer = lib_prm.make_namer()
    value = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_value)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_value, n_middle_value)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_value, 1)),
    ]))
    model = lib_prm.Model(feature, value, advantage)
elif model_type == "single":
    namer = lib_prm.make_namer()
    model = torch.nn.Sequential(collections.OrderedDict([
        (namer("fc"), torch.nn.Linear(n_input, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_middle_feature)),
        (namer("ac"), torch.nn.Tanh()),
        (namer("fc"), torch.nn.Linear(n_middle_feature, n_output)),
    ]))
else:
    raise ValueError(f"Unsupported model_type: {model_type}")
model.apply(lib_prm.init_model)

opt = torch.optim.SGD(model.parameters(), lr=3e-4, momentum=0.9)
#opt = torch.optim.Adam(model.parameters(), lr=1e-3, eps=1e-3)

n_replay_memory = 1_000_000
n_batch = 31

rm = lib_prm.prm.PrioritizedReplayMemory(capacity=n_replay_memory, n_batch=n_batch, alpha=0.5, random_state=42)
loss = torch.nn.SmoothL1Loss()
#loss = torch.nn.MSELoss()
agent = lib_prm.Agent(model=model, opt=opt, gamma=0.99, replay_memory=rm, n_batch=n_batch, cuda=False, alpha=0.5, loss=loss, dqn_mode="doubledqn", td_mode="mnih2015")

n_episodes = 1000
n_steps = 1000
n_steps_start = 300
n_step_update = 50
n_step_train = 1
epsilon = 0.05

i_total_step = 1
for i_episode in range(1, n_episodes + 1):
    si = env.reset()
    r_episode = 0
    for i_step in range(1, n_steps + 1):
        if i_total_step <= n_steps_start:
            ai1 = env.action_space.sample()
        elif random.random() < epsilon:
            ai1 = env.action_space.sample()
        else:
            ai1 = agent.act(si)
        si1, ri1, done, _ = env.step(ai1)
        r_episode += ri1
        env.render()
        import time
        time.sleep(0.001)
        agent.push(state=si, action=ai1, reward_next=ri1, state_next=si1, done=done)
        if (i_total_step%n_step_train == 0) and (i_total_step > n_steps_start):
            loss = agent.train()
            #print(loss["loss"].data.numpy()[0], np.std(loss["td"].data.numpy()))
        if i_total_step%n_step_update == 0:
            agent.update_target_model()
        if i_total_step%50 == 0:
            agent.replay_memory.sort()
        if done:
            break
        si = si1
        i_total_step += 1
    if i_episode%4 == 0:
        print(i_episode, i_total_step, r_episode)
        

# env.reset()
# for i in range(1000):
#     env.render()
#     a = env.action_space.sample()
#     r = env.step(a)
#     print(a, r[1:])
#     time.sleep(0.25)

4 75 17.0
8 170 28.0
12 250 15.0
16 312 10.0
20 344 8.0
24 376 9.0
28 411 10.0
32 445 8.0
36 481 10.0
40 529 10.0
44 579 13.0
48 804 79.0
52 1280 113.0
56 2057 320.0
60 3183 197.0
64 4098 342.0
68 5311 179.0
72 6358 318.0
76 7966 260.0
80 9415 142.0
84 11261 1000.0
88 13208 106.0
92 14228 303.0
96 16330 1000.0
100 18748 451.0
104 20058 153.0
108 20839 160.0
112 22100 201.0
116 24279 348.0
120 25324 136.0
124 26969 469.0
128 29047 178.0
132 29849 180.0
136 30545 237.0
140 32065 291.0
144 32497 130.0
148 33352 141.0


KeyboardInterrupt: 