In [1]:
import gym

In [80]:
import numpy as np
import functools

import random

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

# Run on the GPU when PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Environment id handed to gym.make() when the environment is created.
GAME = "CartPole-v1"

In [5]:
# Create the environment named by GAME and read the space sizes from it.
env = gym.make(GAME)
env.reset()  # initial reset; the value it returns is discarded here

# First dimension of the observation space's shape.
obs_n = env.observation_space.shape[0]
# Number of discrete actions exposed by the action space.
act_n = env.action_space.n

In [146]:
class BaseQNet(nn.Module):
    """Small fully connected network with one hidden layer and sigmoid outputs.

    The stack is Linear -> ReLU -> Linear -> Sigmoid, so every output value
    lies in (0, 1).

    Args:
        in_: size of the input feature vector.
        out_: number of output units.
        hidden: width of the hidden layer (defaults to 128).
    """

    def __init__(self, in_, out_, hidden = 128):
        super().__init__()
        # Layers are listed in forward order and wrapped in one Sequential
        # that is stored as `self.net`.
        layers = [
            nn.Linear(in_, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_),
            nn.Sigmoid(),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Pass `x` through the layer stack and return the sigmoid activations."""
        out = self.net(x)
        return out

In [152]:
# Hyperparameters for the sample-then-train loop.
LR = 0.01      # learning rate passed to Adam
EPOCH = 1000   # number of outer rounds
SAMPLE = 50    # episodes simulated per round
bestN = 15     # longest episodes kept as training data each round

# One sigmoid output unit: the loop rounds it to pick action 0 or 1.
net = BaseQNet(obs_n, 1)
net = net.to(device)

# Binary cross-entropy between the predicted probability and the action taken.
crt = nn.BCELoss()
opt = optim.Adam(net.parameters(), lr=LR)

In [None]:
for epoch in range(EPOCH):
    # ---- Sampling ----------------------------------------------------------
    # Simulate SAMPLE episodes with an epsilon-greedy policy and keep the
    # bestN longest ones as training data for this round.
    #
    # The exploration probability depends only on the epoch, so it is computed
    # once per round rather than once per environment step.
    epsilon = np.exp(-epoch / 100)
    step_list = []
    for sample in range(SAMPLE):
        reset_out = env.reset()
        # Newer gym releases return (obs, info) from reset(); older releases
        # return the observation alone. Accept both.
        obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out
        step = []
        while True:
            if random.random() < epsilon:
                # Explore: uniformly random action.
                act = env.action_space.sample()
            else:
                # Exploit: the network emits one sigmoid value; rounding it
                # selects action 0 or 1.
                with torch.no_grad():
                    act_v = net(torch.as_tensor(obs, dtype=torch.float32, device=device))
                    act = int(torch.round(act_v).item())
            step_out = env.step(act)
            next_obs, rew = step_out[0], step_out[1]
            # Old API: (obs, reward, done, info).
            # New API: (obs, reward, terminated, truncated, info).
            if len(step_out) == 5:
                done = bool(step_out[2] or step_out[3])
            else:
                done = bool(step_out[2])
            # Record the observation the action was chosen from, not the result.
            step.append((obs, act))
            obs = next_obs
            if done:
                break
        step_list.append(step)
    # Longer episodes rank first; the sort is stable for equal lengths.
    sorted_step = sorted(step_list, key=len, reverse=True)
    best_step = sorted_step[:bestN]

    # ---- Training ----------------------------------------------------------
    # One optimizer step per retained episode: fit the network to reproduce
    # the actions taken in that episode.
    loss_sum = 0.0
    for step in best_step:
        obs_seq, act_seq = zip(*step)
        # Stack with numpy first: building a tensor straight from a sequence
        # of ndarrays goes through a slow element-by-element path.
        obs_batch = torch.as_tensor(np.asarray(obs_seq, dtype=np.float32), device=device)
        act_batch = torch.as_tensor(np.asarray(act_seq, dtype=np.float32), device=device).unsqueeze(1)

        pred_act = net(obs_batch)
        loss = crt(pred_act, act_batch)
        opt.zero_grad()
        loss.backward()
        opt.step()
        # .item() yields a plain Python float without touching `.data`.
        loss_sum += loss.item()

    # Report: epoch index, mean episode length over all sampled episodes, and
    # the summed training loss over the retained episodes.
    mean_len = sum(len(episode) for episode in step_list) / SAMPLE
    print(epoch, mean_len, loss_sum)

0 20.4 10.444130778312683
1 22.62 10.448449313640594
2 21.9 10.369653284549713
3 23.84 10.430789232254028
4 25.54 10.414540410041809
5 23.24 10.316456079483032
6 24.78 10.223991513252258
7 33.08 10.170708298683167
8 26.62 10.19842004776001
9 27.82 10.126807510852814
10 26.56 10.31234472990036
11 27.2 10.133429825305939
12 30.34 10.063583850860596
13 27.04 10.194552421569824
14 24.0 10.260365009307861
15 33.46 10.029829800128937
16 35.26 10.011495232582092
17 33.1 9.920941054821014
18 37.4 9.948222756385803
19 34.2 9.91957277059555
20 30.86 10.303567826747894
21 34.32 9.983331084251404
22 36.22 9.95633590221405
23 35.22 9.905102849006653
24 38.9 9.750160813331604
25 30.84 10.04047167301178
26 37.36 9.84970897436142
27 40.14 9.848584711551666
28 44.1 9.636337757110596
29 38.28 9.763657331466675
30 51.46 9.692278027534485
31 41.88 9.671737849712372
32 47.94 9.57748955488205
33 53.9 9.513760805130005
34 47.38 9.626357913017273
35 53.42 9.575250267982483
36 56.1 9.464346885681152
37 54.46 9