<a href="https://colab.research.google.com/github/henry-bokyum-kim/NNStudy/blob/master/%5BRL%5D%5BMW%5DCartPole-V1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install gym pyvirtualdisplay > /dev/null 2>&1
!apt-get install -y xvfb python-opengl ffmpeg > /dev/null 2>&1
!apt-get install x11-utils

import gym
from gym import logger as gymlogger
from gym.wrappers import Monitor
gymlogger.set_level(40)

import math
import glob
import io
import base64
from IPython.display import HTML

from pyvirtualdisplay import Display
from IPython import display as ipythondisplay

# Start a pyvirtualdisplay screen (not visible, 1400x900) for the rest of the notebook.
# NOTE(review): this rebinds the global name `display`; IPython's display helper
# is still reachable through the `ipythondisplay` alias imported above.
display = Display(visible=0, size=(1400,900),)
display.start()

Reading package lists... Done
Building dependency tree       
Reading state information... Done
x11-utils is already the newest version (7.7+3build1).
The following package was automatically installed and is no longer required:
  libnvidia-common-430
Use 'apt autoremove' to remove it.
0 upgraded, 0 newly installed, 0 to remove and 7 not upgraded.


<Display cmd_param=['Xvfb', '-br', '-nolisten', 'tcp', '-screen', '0', '1400x900x24', ':1001'] cmd=['Xvfb', '-br', '-nolisten', 'tcp', '-screen', '0', '1400x900x24', ':1001'] oserror=None return_code=None stdout="None" stderr="None" timeout_happened=False>

In [0]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim

# Hyperparameters used by the cells below.
HIDDEN = 128      # width of the hidden layer passed to Net
BATCH_SIZE = 32   # number of finished episodes collected per training batch
PERCENTILE = 50   # reward percentile passed to np.percentile in filter_batch

class Net(nn.Module):
  """Small fully connected network: Linear -> ReLU -> Linear.

  Args:
    obs_size: number of input features of the first linear layer.
    hidden: width of the hidden layer.
    action_size: number of outputs of the last linear layer.
  """

  def __init__(self, obs_size, hidden, action_size):
    super().__init__()
    # No softmax is applied here, so forward() returns the raw layer output.
    layers = [
        nn.Linear(obs_size, hidden),
        nn.ReLU(),
        nn.Linear(hidden, action_size),
    ]
    self.net = nn.Sequential(*layers)

  def forward(self, x):
    """Pass `x` through the layer stack and return the result."""
    out = self.net(x)
    return out

In [0]:
import copy

def get_batch(batch_size, env, net):
  """Play episodes forever with the current policy and yield them in batches.

  Actions are sampled (via the global NumPy RNG) from the softmax of the
  network output for the current observation.

  Args:
    batch_size: number of finished episodes per yielded batch.
    env: environment exposing `reset()`, `step(action)` returning
      `(next_obs, reward, done, info)`, and `action_space.n`.
    net: callable mapping a float tensor with a leading batch dimension of 1
      to one score per action.

  Yields:
    A list of `batch_size` tuples `(total_reward, steps)`, where `steps` is the
    list of `(observation, action)` pairs recorded during one episode.
  """
  batch = []
  reward = 0.0
  obs = env.reset()
  steps = []
  sm = nn.Softmax(dim=1)

  while True:
    # Build the (1, obs_size) input from a single ndarray; constructing a tensor
    # from a Python list of ndarrays is a slow path in PyTorch.
    obs_v = torch.as_tensor(np.asarray(obs)[None], dtype=torch.float32)
    # Rollouts need no gradients. The no_grad scope covers only the forward
    # pass, never the `yield`, so the caller's grad mode is left untouched.
    with torch.no_grad():
      action_probs = sm(net(obs_v)).numpy()[0]
    action = np.random.choice(env.action_space.n, p=action_probs)

    next_obs, rew, done, _ = env.step(action)
    reward += rew
    # Store the observation the action was chosen for, not the resulting one.
    steps.append((obs, action))

    if done:
      batch.append((reward, steps))
      reward = 0.0
      next_obs = env.reset()
      steps = []
      if len(batch) == batch_size:
        yield batch
        batch = []
    obs = next_obs

In [0]:
def filter_batch(batch, percentile=None):
  """Select the steps of the best-rewarded episodes in `batch`.

  Args:
    batch: sequence of `(total_reward, steps)` episodes, where `steps` is a
      sequence of `(observation, action)` pairs.
    percentile: reward percentile used as the cut-off. When omitted, the
      module-level `PERCENTILE` constant is used.

  Returns:
    A tuple `(obs, actions, threshold, mean)`:
      obs: float32 tensor of the observations of all kept episodes.
      actions: int64 tensor of the matching actions.
      threshold: the reward percentile used as cut-off.
      mean: mean total reward over every episode in `batch`.
  """
  if percentile is None:
    # Looked up at call time so edits to the notebook constant take effect.
    percentile = PERCENTILE

  rews = [episode[0] for episode in batch]
  rews_threshold = np.percentile(rews, percentile)
  rews_mean = np.mean(rews)

  ret_obs = []
  ret_action = []
  for episode in batch:
    # Episodes whose reward equals the threshold are kept.
    if episode[0] < rews_threshold:
      continue
    for step in episode[1]:
      ret_obs.append(step[0])
      ret_action.append(step[1])

  # Go through one ndarray each: building a tensor straight from a list of
  # ndarrays is a slow path in PyTorch.
  obs_v = torch.as_tensor(np.asarray(ret_obs), dtype=torch.float32)
  action_v = torch.as_tensor(np.asarray(ret_action), dtype=torch.int64)
  return obs_v, action_v, rews_threshold, rews_mean

In [0]:
def show_video():
  """Embed every `video/*.mp4` file in the notebook output as an HTML5 player.

  Each file's path is printed, then its bytes are base64-encoded into a
  `data:` URI inside a `<video>` element and displayed. Prints
  "Could not find video" when no file matches.
  """
  # Sorted so the players appear in a stable order across runs.
  mp4list = sorted(glob.glob('video/*.mp4'))
  if not mp4list:
    print("Could not find video")
    return

  for mp4 in mp4list:
    # Read-only binary mode inside a context manager: the handle is always
    # closed and no write permission on the file is required.
    with open(mp4, 'rb') as video_file:
      video = video_file.read()
    print(mp4)
    encoded = base64.b64encode(video)
    ipythondisplay.display(HTML(data='''<video alt="test" autoplay loop controls style="height: 300px;">
                                  <source src="data:video/mp4;base64,{0}" type="video/mp4"/>
                                  </video>'''.format(encoded.decode('ascii'))))


In [5]:
# Train the policy: repeatedly collect a batch of episodes, keep the
# better-rewarded ones (filter_batch) and fit the network to the actions taken in them.

# CartPole-v1 wrapped in a Monitor writing into the "video" directory
# (the same directory show_video reads from).
# NOTE(review): `c` in video_callable is presumably the episode index, i.e. every
# 100th episode is recorded — confirm against the gym Monitor documentation.
env = gym.wrappers.Monitor(gym.make("CartPole-v1"), "video", force=True, video_callable=lambda c:c%100 ==0)
net = Net(env.observation_space.shape[0], HIDDEN, env.action_space.n)
# CrossEntropyLoss is fed the raw network output and the chosen action indices.
CELoss = nn.CrossEntropyLoss()
opt = optim.Adam(net.parameters(), 0.01)

# get_batch is an endless generator; the loop only ends through the `break` below.
for i,batch in enumerate(get_batch(BATCH_SIZE, env, net)):
  # Observations/actions of the kept episodes plus reward statistics of the whole batch.
  obs, action, threshold, mean = filter_batch(batch)
  opt.zero_grad()

  action_out = net(obs)
  loss = CELoss(action_out, action)
  loss.backward()
  opt.step()
  print("%d, loss = %.4f, th = %.2f, mean = %.2f"%(i, loss, threshold, mean))
  # Stop once the mean episode reward of a batch exceeds 499.
  if mean>499:
    print("Solved")
    break

0, loss = 0.6933, th = 17.50, mean = 22.12
1, loss = 0.6799, th = 20.50, mean = 24.91
2, loss = 0.6672, th = 22.50, mean = 27.97
3, loss = 0.6532, th = 35.50, mean = 41.41
4, loss = 0.6378, th = 36.00, mean = 40.91
5, loss = 0.6363, th = 40.50, mean = 48.75
6, loss = 0.6114, th = 42.50, mean = 48.22
7, loss = 0.6234, th = 51.00, mean = 56.41
8, loss = 0.6004, th = 44.00, mean = 52.75
9, loss = 0.5969, th = 71.00, mean = 68.62
10, loss = 0.5782, th = 56.00, mean = 67.16
11, loss = 0.5743, th = 65.00, mean = 66.62
12, loss = 0.5717, th = 67.00, mean = 77.44
13, loss = 0.5648, th = 66.50, mean = 71.22
14, loss = 0.5490, th = 67.00, mean = 76.06
15, loss = 0.5534, th = 76.00, mean = 80.62
16, loss = 0.5400, th = 78.00, mean = 90.47
17, loss = 0.5270, th = 71.50, mean = 74.31
18, loss = 0.5289, th = 93.00, mean = 99.62
19, loss = 0.5327, th = 94.00, mean = 105.97
20, loss = 0.5263, th = 122.50, mean = 119.84
21, loss = 0.5279, th = 119.50, mean = 137.19
22, loss = 0.5238, th = 135.00, mean 

In [6]:
# Show the first episode of the last collected batch:
# (total_reward, [(observation, action), ...]).
batch[0]

(500.0,
 [(array([-0.04730696,  0.02757794, -0.03917624,  0.03374646]), 1),
  (array([-0.0467554 ,  0.22323915, -0.03850131, -0.27103504]), 0),
  (array([-0.04229062,  0.02868715, -0.04392201,  0.0092601 ]), 0),
  (array([-0.04171687, -0.16577827, -0.04373681,  0.287768  ]), 1),
  (array([-0.04503244,  0.02993922, -0.03798145, -0.01838201]), 0),
  (array([-0.04443365, -0.16461803, -0.03834909,  0.26207953]), 1),
  (array([-0.04772601,  0.03102975, -0.0331075 , -0.04244824]), 0),
  (array([-0.04710542, -0.16360219, -0.03395646,  0.23960786]), 1),
  (array([-0.05037746,  0.03198797, -0.02916431, -0.06358955]), 0),
  (array([-0.0497377 , -0.16270396, -0.0304361 ,  0.21975107]), 0),
  (array([-0.05299178, -0.35737792, -0.02604108,  0.50267994]), 1),
  (array([-0.06013934, -0.16189879, -0.01598748,  0.20190545]), 1),
  (array([-0.06337732,  0.03344812, -0.01194937, -0.09577758]), 0),
  (array([-0.06270835, -0.16150055, -0.01386492,  0.19311151]), 0),
  (array([-0.06593836, -0.35642144, -0.0