In [29]:
import gym
import chainer
import collections
import numpy as np
import itertools
import math
import copy
import random

# `.unwrapped` is currently disabled, so the object returned by gym.make is used as-is.
env = gym.make("CartPole-v0")

# Network input/output sizes are read from the environment's spaces.
n_input, n_output = env.observation_space.shape[0], env.action_space.n

def _conf_of(**kwargs):
    return collections.namedtuple("_Conf", kwargs.keys())(**kwargs)

# Hyper-parameters for the experiment, collected in one immutable record.
args = _conf_of(
    n_middle=50,                   # hidden-layer width (presumably for Model; confirm at the call site)
    lr=1e-2,                       # learning rate (presumably for the optimizer; confirm at the call site)
    gamma=0.95,                    # discount factor applied to the bootstrapped next-state value
    n_batch=32,                    # number of transitions sampled from the buffer per update
    n_episodes=200,                # number of episodes to run
    n_start_train=500,             # updates begin once the global step counter exceeds this
    n_target_update_interval=100,  # the target network is refreshed every this many global steps
    n_steps=200,                   # maximum number of steps per episode
    epsilon=0.3,                   # probability of taking a random action
)

class Model(chainer.Chain):
    """Three-layer fully connected network with tanh hidden activations.

    The network applies Linear -> tanh -> Linear -> tanh -> Linear and returns
    the raw output of the last layer (no final activation).
    """

    def __init__(self, n_input, n_middle, n_ouput):
        """Create the three Linear links.

        Args:
            n_input: Size of the input feature vector.
            n_middle: Width of both hidden layers.
            n_ouput: Size of the output layer. The misspelled name is kept
                so that existing keyword callers keep working.
        """
        super().__init__()
        with self.init_scope():
            self.l1 = chainer.links.Linear(n_input, n_middle)
            self.l2 = chainer.links.Linear(n_middle, n_middle)
            # Use the constructor argument. Previously this line read the
            # module-level ``n_output`` and silently ignored the parameter,
            # so the output size could not be chosen per instance.
            self.l3 = chainer.links.Linear(n_middle, n_ouput)

    def __call__(self, input):
        """Run the forward pass and return the last layer's output.

        Args:
            input: Batch fed to the first Linear link.

        Returns:
            The output of ``l3`` applied to the second hidden activation.
        """
        h = chainer.functions.tanh(self.l1(input))
        h = chainer.functions.tanh(self.l2(h))
        return self.l3(h)

        
# Online Q-network. The hidden width comes from the shared config record
# instead of a duplicated literal, so changing args.n_middle takes effect.
model = Model(n_input, args.n_middle, n_output)
# The target network starts as an independent deep copy of the online network.
target_model = copy.deepcopy(model)


def copy_param(target_link, source_link):
    """Copy parameters (and batch-norm statistics) from one link to another.

    Values are written in place into the target's existing arrays, so the
    target keeps its own parameter objects.

    Args:
        target_link: Link whose parameters are overwritten.
        source_link: Link whose parameters are read.

    Raises:
        KeyError: If the source has a parameter or BatchNormalization link
            name that the target does not have.
    """
    target_params = dict(target_link.namedparams())
    for param_name, param in source_link.namedparams():
        target_params[param_name].data[:] = param.data

    # Copy Batch Normalization's running statistics, which are not parameters.
    target_links = dict(target_link.namedlinks())
    for link_name, link in source_link.namedlinks():
        # Fixed: this used to reference ``torch.links``, which is not imported
        # in this file and made every call fail with NameError.
        if isinstance(link, chainer.links.BatchNormalization):
            target_bn = target_links[link_name]
            target_bn.avg_mean[:] = link.avg_mean
            target_bn.avg_var[:] = link.avg_var


# Plain SGD on the online network. The learning rate is read from the config
# record instead of a duplicated literal, so changing args.lr takes effect.
# Alternative that was tried: chainer.optimizers.Adam(eps=1e-2)
opt = chainer.optimizers.SGD(lr=args.lr)
opt.setup(model)

def _select_action(q_net, state, action_space, epsilon):
    """Pick an action epsilon-greedily with respect to ``q_net``.

    With probability ``epsilon`` a random action is sampled from
    ``action_space``; otherwise the argmax of the network output is used.
    """
    if random.random() < epsilon:
        return action_space.sample()
    with chainer.no_backprop_mode():
        q_values = q_net(chainer.Variable(np.array([state], dtype=np.float32)))
    return int(q_values.data.argmax(axis=1)[0])


def _collate(transitions):
    """Stack a list of transition dicts into per-field numpy arrays.

    ``mask`` is 1.0 for transitions that did not end the episode and 0.0 for
    those that did, so it zeroes the bootstrap term at episode end.
    """
    return dict(
        si=np.array([t["si"] for t in transitions], dtype=np.float32),
        ai1=np.array([t["ai1"] for t in transitions], dtype=int),
        ri1=np.array([t["ri1"] for t in transitions], dtype=np.float32),
        si1=np.array([t["si1"] for t in transitions], dtype=np.float32),
        mask=np.array([not t["done"] for t in transitions], dtype=np.float32),
    )


def _double_dqn_loss(q_net, target_net, batch, gamma):
    """Return ``(loss, q_pred)`` for one batch using a Double-DQN target.

    The next action is chosen by ``q_net`` and its value is read from
    ``target_net``; the target is computed without building a graph.
    """
    F = chainer.functions
    q_pred = F.reshape(
        F.select_item(q_net(chainer.Variable(batch["si"])), chainer.Variable(batch["ai1"])),
        (-1, 1),
    )
    with chainer.no_backprop_mode():
        next_actions = q_net(chainer.Variable(batch["si1"])).data.argmax(axis=1)
        next_values = F.select_item(
            target_net(chainer.Variable(batch["si1"])), chainer.Variable(next_actions)
        )
        q_target = (
            chainer.Variable(batch["ri1"])
            + gamma*chainer.Variable(batch["mask"])*next_values
        ).data.reshape(-1, 1)
    loss = F.mean(F.huber_loss(q_pred, chainer.Variable(q_target), delta=1))
    return loss, q_pred


buffer = []               # replay memory: one dict per environment transition
episode_result_list = []  # per episode: list of per-update metric dicts
q_list = []               # mean predicted Q of every update (was referenced but never defined)
i_total_step = -1         # global step counter across episodes; first step is 0
for i_episode in range(1, args.n_episodes + 1):
    si = env.reset()
    step_result_list = []
    for i_step in range(1, args.n_steps + 1):
        i_total_step += 1

        ai1 = _select_action(model, si, env.action_space, args.epsilon)
        si1, ri1, done, debug_info = env.step(ai1)
        buffer.append(dict(si=si, ai1=ai1, ri1=ri1, si1=si1, done=done))

        metric = None
        if i_total_step > args.n_start_train:
            batch = _collate(random.sample(buffer, args.n_batch))
            loss, q_pred = _double_dqn_loss(model, target_model, batch, args.gamma)
            model.cleargrads()
            loss.backward()
            opt.update()
            # Record what this update saw. Previously ``metric`` was never
            # assigned, so the logging branch below was dead code.
            metric = dict(loss=float(loss.data), q_pred=q_pred.data)
        if metric is not None:
            q_list.append(np.mean(metric["q_pred"]))
            step_result_list.append(metric)
        if i_total_step%(args.n_target_update_interval) == 0:
            # Refresh the target network with a snapshot of the online network.
            target_model = copy.deepcopy(model)
        if done:
            break
        si = si1
    episode_result_list.append(step_result_list)
    # Print the episode length, ten episodes per output row.
    print(i_step, end="\t")
    if i_episode%10 == 0:
        print()

69	124	79	53	50	73	50	63	67	75	
57	45	75	44	41	33	42	93	67	80	
54	136	76	49	41	64	83	49	64	49	
62	50	50	42	46	32	47	42	66	58	
105	40	70	56	57	70	20	35	151	73	
86	65	79	41	50	39	111	49	65	67	
117	89	59	108	79	65	72	96	62	101	
112	128	78	73	80	50	65	53	94	154	
51	76	48	55	57	78	98	57	113	154	
50	58	85	63	55	70	54	56	66	55	
52	53	50	75	34	81	44	61	58	80	
48	58	113	57	59	37	52	56	65	105	
69	166	64	65	36	62	65	64	59	93	
77	81	116	91	71	46	81	61	63	43	
158	93	106	44	36	72	60	76	90	76	
63	130	77	49	106	100	40	83	95	119	
74	130	163	97	78	103	92	154	125	139	
126	99	145	84	200	149	105	200	157	174	
81	127	71	152	61	145	58	62	200	96	
150	83	152	75	170	175	82	71	170	173	


In [22]:
# Fresh, untrained Model instance; the cells below inspect its initial weights.
mm = Model(4, 50, 2)

In [25]:
# Variance over all entries of the first layer's initial weight matrix.
mm.l1.W.data.var()

0.23128141

In [26]:
# Variance over all entries of the second layer's initial weight matrix.
mm.l2.W.data.var()

0.020486917

In [27]:
# Variance over all entries of the output layer's initial weight matrix.
mm.l3.W.data.var()

0.017133638