In [None]:
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple

In [None]:
# Hyperparameters read by REINFORCEAgent.
LEARNING_RATE = 1e-3  # step size handed to the Adam optimizer
GAMMA = 0.99          # discount factor applied when accumulating episode returns

In [None]:
# One environment step as stored in the agent's episode buffer.
Transition = namedtuple(
    'Transition',
    ['state', 'action', 'next_state', 'reward', 'done'],
)

In [None]:
class REINFORCEAgent:
    """Policy-gradient (REINFORCE) agent with a small softmax MLP policy.

    Transitions passed to `train` are buffered until a transition with
    `done=True` arrives; at that point one gradient step is taken on the whole
    buffered episode and the buffer is cleared.

    The class reads the module-level names `LEARNING_RATE` (in `__init__`),
    `GAMMA` and `Transition` (in `train`).
    """

    # Lower bound applied to probabilities before taking the log, so the loss
    # can never contain -inf / NaN.
    _PROB_FLOOR = 1e-8

    def __init__(self, num_states, num_actions):
        """Create the policy network and its Adam optimizer.

        Args:
            num_states: Length of the flat observation vector fed to the network.
            num_actions: Number of discrete actions (width of the softmax output).
        """
        self.num_states = num_states
        self.num_actions = num_actions
        self.action_list = np.arange(num_actions)
        self.transitions = []
        self.model = self._get_model()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=LEARNING_RATE)

    def get_action(self, state, episode=None):
        """Sample an action from the current policy for a single observation.

        Args:
            state: Observation (array-like); it is flattened to shape (1, -1).
            episode: Unused; accepted so callers can pass it uniformly.

        Returns:
            One element of `self.action_list`, drawn with `np.random.choice`.
        """
        state = torch.as_tensor(np.asarray(state, dtype=np.float32)).reshape(1, -1)
        # Acting never needs gradients; skip building an autograd graph per step.
        with torch.no_grad():
            probs = self.model(state).numpy().ravel().astype(np.float64)
        # The network emits float32 probabilities; renormalise in float64 so
        # rounding drift cannot trip np.random.choice's sum-to-one check.
        probs /= probs.sum()
        return np.random.choice(self.action_list, p=probs)

    def train(self, state, action, next_state, reward, done, episode=None):
        """Buffer one transition and update the policy when the episode ends.

        Args:
            state: Observation the action was taken in.
            action: Index of the action that was taken.
            next_state: Observation that followed (stored, not used in the update).
            reward: Reward credited to this step.
            done: True if this transition ends the episode.
            episode: Unused; accepted so callers can pass it uniformly.

        Returns:
            None while the episode is still running; otherwise the loss of the
            update as a Python float.
        """
        self.transitions.append(Transition(state, action, next_state, reward, done))

        if not done:
            return None

        # Transpose the list of transitions into one tuple per field.
        batch = Transition(*zip(*self.transitions))
        return_list = self._get_returns(batch.reward, GAMMA)

        # Stack through NumPy first: constructing a tensor directly from a
        # tuple of ndarrays is the slow path in PyTorch.
        state_batch = torch.as_tensor(
            np.asarray(batch.state, dtype=np.float32)
        ).reshape(-1, self.num_states)
        action_batch = torch.as_tensor(
            np.asarray(batch.action), dtype=torch.int64
        ).reshape(-1, 1)
        return_batch = torch.as_tensor(return_list, dtype=torch.float32).reshape(-1, 1)

        policy_batch = self.model(state_batch)
        selected_action_prob_batch = policy_batch.gather(1, action_batch)
        log_selected_action_prob_batch = torch.log(
            torch.clamp(selected_action_prob_batch, min=self._PROB_FLOOR)
        )

        # Minimising the negative return-weighted log-likelihood is gradient
        # ascent on the expected return.
        loss = -(return_batch * log_selected_action_prob_batch).mean()

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

        # The update consumed the episode; start the next one with an empty buffer.
        self.transitions = []

        return loss.item()

    def _get_model(self):
        """Build the policy network: two 32-unit ReLU layers and a softmax head."""
        model = nn.Sequential()
        model.add_module('fc1', nn.Linear(self.num_states, 32))
        model.add_module('relu1', nn.ReLU())
        model.add_module('fc2', nn.Linear(32, 32))
        model.add_module('relu2', nn.ReLU())
        model.add_module('fc3', nn.Linear(32, self.num_actions))
        model.add_module('softmax1', nn.Softmax(dim=1))
        return model

    def _get_returns(self, rewards, gamma=0.99):
        """Compute mean-centred discounted returns for one episode.

        Args:
            rewards: Sequence of per-step rewards in time order.
            gamma: Discount factor.

        Returns:
            NumPy array with one entry per step: the discounted sum of rewards
            from that step onward, minus the mean of those sums over the
            episode. Empty input yields an empty array.
        """
        if len(rewards) == 0:
            # Avoid taking the mean of an empty array (NaN plus a warning).
            return np.array([], dtype=np.float64)

        g_list = []
        g = 0.0
        # Walk backwards so each return reuses the one that follows it.
        for r in reversed(rewards):
            g = r + gamma * g
            g_list.append(g)
        g_list = np.array(g_list[::-1])
        return g_list - g_list.mean()

In [None]:
# Create the CartPole-v0 environment; the agent is sized from its observation
# and action spaces, and the training loop steps through it.
env = gym.make('CartPole-v0')

In [None]:
# Size the policy network from the environment's observation and action spaces.
agent = REINFORCEAgent(
    num_states=env.observation_space.shape[0],
    num_actions=env.action_space.n,
)

In [None]:
# Maximum number of training episodes; the training loop iterates over range(N_EPOCH).
N_EPOCH = 500

In [None]:
# An episode counts as a success when the index of its final step reaches this value.
SUCCESS_STEP_THRESHOLD = 195
# Training stops once this many successful episodes have occurred in a row.
REQUIRED_SUCCESS_STREAK = 10

# Number of consecutive successful episodes so far.
# NOTE(review): spelling kept as-is because other cells may read this name.
continues_sucess = 0

for episode in range(N_EPOCH):
    done = False

    # Legacy Gym returns the observation alone from reset(); gym>=0.26 and
    # gymnasium return an (observation, info) pair.
    reset_result = env.reset()
    if (
        isinstance(reset_result, tuple)
        and len(reset_result) == 2
        and isinstance(reset_result[1], dict)
    ):
        state = reset_result[0]
    else:
        state = reset_result

    step = 0
    while not done:
        action = agent.get_action(state)

        # Legacy step() yields 4 values; the newer API splits `done` into
        # `terminated` and `truncated` and yields 5.
        step_result = env.step(action)
        if len(step_result) == 5:
            next_state, _, terminated, truncated, _ = step_result
            done = bool(terminated or truncated)
        else:
            next_state, _, done, _ = step_result

        # The environment's own reward is ignored: only the final step carries
        # a signal (+1 for a long episode, -1 for an early failure).
        if done:
            if step < SUCCESS_STEP_THRESHOLD:
                reward = -1.0
                continues_sucess = 0
            else:
                reward = 1.0
                continues_sucess += 1
        else:
            reward = 0.0

        agent.train(state, action, next_state, reward, done, episode)

        state = next_state
        step += 1

    # Report before the stop check so the episode that completes the streak is logged too.
    print("Episode {} Step {}".format(episode, step))

    if continues_sucess >= REQUIRED_SUCCESS_STREAK:
        break