In [1]:
%matplotlib inline
import numpy as np
import matplotlib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.autograd import Variable
import torchvision.transforms as T

In [2]:
import gym
from gym.wrappers import Monitor
# Build the environment once and look at what it exposes.
gym_env = "CartPole-v0"
env = gym.make(gym_env)  # gym.make is the same factory as gym.envs.make
state = env.reset()
# Number of discrete actions, then the initial observation, one per line.
print(env.action_space.n, state, sep="\n")

2
[-0.02371382 -0.01289756  0.03717611 -0.04757698]


In [3]:
# Record whether PyTorch can see a CUDA device in this kernel.
use_gpu = bool(torch.cuda.is_available())
gpu_message = 'Use GPU: {}'.format(use_gpu)
print(gpu_message)

Use GPU: True


In [4]:
import glob
import io
import base64
from IPython.display import HTML
from IPython import display as ipythondisplay
from pyvirtualdisplay import Display
# Start an invisible (Xvfb-backed) virtual display of 1400x900; presumably
# needed so gym can render frames for Monitor videos without a real screen.
# NOTE(review): this rebinds `display`, shadowing IPython's built-in display()
# helper in this kernel; rename it if display() is needed later.
display = Display(visible=0, size=(1400, 900))
display.start();  # trailing ';' suppresses the noisy Display repr in the cell output

<Display cmd_param=['Xvfb', '-br', '-nolisten', 'tcp', '-screen', '0', '1400x900x24', ':1013'] cmd=['Xvfb', '-br', '-nolisten', 'tcp', '-screen', '0', '1400x900x24', ':1013'] oserror=None return_code=None stdout="None" stderr="None" timeout_happened=False>

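### Policy Class

A two-layer MLP with a ReLU hidden layer. Its softmax output is a probability distribution over the discrete actions.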

In [5]:
class Policy(nn.Module):
    """Two-layer MLP policy that maps states to action probabilities.

    Args:
        state_size: Dimension of the input observation.
        action_size: Number of discrete actions (width of the output).
        hidden_dim: Width of the hidden ReLU layer.
    """

    def __init__(self, state_size, action_size, hidden_dim):
        super(Policy, self).__init__()
        self.fc1 = nn.Linear(state_size, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, action_size)

    def forward(self, x):
        """Return action probabilities for ``x``.

        ``x`` may be a single state ``[state_size]`` or a batch
        ``[batch x state_size]``. The softmax is taken over the last
        (action) dimension, so each output row sums to 1. With the old
        ``dim=1``, a single 1-D state raised an error.
        """
        hidden = F.relu(self.fc1(x))
        return F.softmax(self.fc2(hidden), dim=-1)

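### REINFORCE

Monte-Carlo policy gradient. After each episode, the log-probability of every action taken is weighted by the discounted return from that step onward. The policy then takes one optimizer step to maximize that weighted sum.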

In [6]:
class REINFORCE:
    """Vanilla REINFORCE (Monte-Carlo policy-gradient) agent.

    Actions are sampled from a ``Policy`` network. After each episode, one
    optimizer step is taken on ``-sum_t log pi(a_t | s_t) * G_t``, where
    ``G_t`` is the discounted return from step ``t`` onward.

    Args:
        state_size: Dimension of the observation vector (default 4, as before).
        action_size: Number of discrete actions (default 2, as before).
        hidden_dim: Width of the policy's hidden layer (default 128).
        lr: Adam learning rate (default 1e-3).
        device: Torch device for the policy. Defaults to CUDA when it is
            available and otherwise CPU. The previous hard-coded ``.cuda()``
            crashed on CPU-only machines.
    """

    def __init__(self, state_size=4, action_size=2, hidden_dim=128, lr=1e-3, device=None):
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = torch.device(device)
        self.policy_model = Policy(state_size, action_size, hidden_dim).to(self.device)
        self.policy_model.train()
        self.optimizer = optim.Adam(self.policy_model.parameters(), lr=lr)

    def select_action(self, state):
        """Sample an action for a single observation.

        Args:
            state: One observation (array-like with ``state_size`` values).

        Returns:
            ``(action, log_prob)``. ``action`` is a Python int. ``log_prob``
            is a ``[1 x 1]`` tensor on ``self.device`` that stays attached to
            the autograd graph, because ``update_parameters`` backpropagates
            through it.
        """
        # as_tensor avoids the slow/deprecated torch.Tensor([ndarray]) path.
        state_tensor = torch.as_tensor(state, dtype=torch.float32, device=self.device).view(1, -1)  # [1 x state_size]
        probs = self.policy_model(state_tensor)  # [1 x action_size]
        # Categorical works for any number of actions. The old threshold test
        # on probs[0] only handled exactly two. Its log_prob also avoids the
        # -inf that probs.log() produces when a probability underflows to 0.
        dist = torch.distributions.Categorical(probs=probs)
        action = dist.sample()  # [1]
        log_prob = dist.log_prob(action).view(1, 1)  # [1 x 1], keeps grad
        return int(action.item()), log_prob

    def update_parameters(self, rewards, log_probs, gamma):
        """Apply one REINFORCE gradient step from a finished episode.

        Args:
            rewards: Per-step rewards ``r_0 .. r_{T-1}``.
            log_probs: Matching ``[1 x 1]`` tensors from ``select_action``.
            gamma: Discount factor.

        Returns:
            The scalar policy loss as a float. Returns ``None`` for an empty
            episode, in which case no update is made.

        Raises:
            ValueError: if ``rewards`` and ``log_probs`` differ in length.
        """
        if len(rewards) != len(log_probs):
            raise ValueError(
                "rewards and log_probs must have the same length, got {} and {}".format(
                    len(rewards), len(log_probs)))
        if len(rewards) == 0:
            # torch.cat([]) would raise; there is nothing to learn from.
            return None
        R = 0.0
        policy_loss = []
        # Walk backwards so R accumulates the discounted return G_i.
        for i in reversed(range(len(rewards))):
            R = rewards[i] + gamma * R
            policy_loss.append(-log_probs[i] * R)
        policy_loss = torch.cat(policy_loss).sum()
        self.optimizer.zero_grad()
        policy_loss.backward()
        self.optimizer.step()
        return policy_loss.item()

### Start Training

In [7]:
# Train a REINFORCE agent on the environment, with gym's Monitor recording to ./video.
SEED = 0
max_episode = 5
max_steps = 1000  # safety cap; CartPole-v0 ends episodes on its own (max reward 200)
gamma = 0.99

# Seed every RNG involved so runs are reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)

epi_rewards = []
# Reuse `gym_env` rather than repeating the environment id. To record only
# the final episode, add to the Monitor call:
#   video_callable=lambda episode_id: episode_id == max_episode - 1
env = Monitor(gym.make(gym_env), './video', force=True)
env.seed(SEED)

agent = REINFORCE()
try:
    for i_episode in range(max_episode):
        state = env.reset()
        log_probs = []
        rewards = []
        for _ in range(max_steps):
            action, log_prob = agent.select_action(state)
            next_state, reward, done, _info = env.step(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            state = next_state
            if done:
                break

        agent.update_parameters(rewards, log_probs, gamma)
        # Max tot_reward of CartPole-v0 is 200.
        tot_reward = np.sum(rewards)
        epi_rewards.append(tot_reward)
        print("Episode: {}, reward: {}".format(i_episode, tot_reward), end="\r")
finally:
    # Close the monitor even if an episode raises, so its recorders are shut
    # down instead of being left open.
    env.close()

ValueError: invalid literal for int() with base 10: ''

In [None]:
# Learning curve: total reward collected in each training episode.
fig, ax = plt.subplots()
ax.plot(epi_rewards)
ax.set(title="REINFORCE training curve", xlabel="Episode", ylabel="Total reward")
plt.show()

In [8]:
!pip list

Package              Version 
-------------------- --------
absl-py              0.7.1   
astor                0.7.1   
atari-py             0.1.7   
attrs                19.1.0  
backcall             0.1.0   
bleach               3.1.0   
box2d-py             2.3.8   
certifi              2019.3.9
cffi                 1.12.3  
chardet              3.0.4   
cycler               0.10.0  
decorator            4.4.0   
defusedxml           0.5.0   
EasyProcess          0.2.5   
entrypoints          0.3     
future               0.17.1  
gast                 0.2.2   
greenlet             0.4.15  
grpcio               1.20.1  
gym                  0.12.1  
h5py                 2.9.0   
idna                 2.8     
ipdb                 0.12    
ipykernel            5.1.0   
ipython              7.3.0   
ipython-genutils     0.2.0   
ipywidgets           7.4.2   
jedi                 0.13.3  
Jinja2               2.10    
jsonschema           3.0.1   
jupyter              1.0.0   
jupyter-cl

In [9]:
!pip install gym pyglet

