# REINFORCE - Monte Carlo Policy Gradient Method for Cartpole

The REINFORCE policy gradient is defined by $\nabla_\theta J(\theta)=E_\pi[G_t \nabla_\theta \log \pi_\theta(a_t \mid state_t)]$, where $G_t$ is the discounted Monte Carlo return from step $t$ to the end of the episode, i.e. $G_t = \sum_{k=0}^{T-t}\gamma^k r_{t+k}$.

To compute the policy gradient in PyTorch, the implementation builds the policy gradient estimator as the surrogate loss $loss_{policy}=-\sum_{t=0}^T \log \pi_\theta(a_t \mid state_t)\, G_t$ (negated because the optimizer minimizes the loss) and simply calls the `backward()` function on the loss tensor.

The basic REINFORCE algorithm suffers from high variance because it weights every log-probability by the raw Monte Carlo return $G_t = \sum_{k=0}^{T-t}\gamma^k r_{t+k}$. An
improved version of REINFORCE adds a state-value function estimator as a baseline to reduce the variance. The REINFORCE_wBaseline class, defined after the vanilla REINFORCE class, adds a neural network as a state-value function approximator. The new policy gradient estimator with baseline is: $loss_{policy}=-\sum_{t=0}^T \log \pi_\theta(a_t \mid state_t)\,(G_t - V(state_t))$.

For the state-value function, the loss can simply reuse the same Monte Carlo return as the regression target and compute the mean squared error (MSE), $loss_v=MSE(V(state_t),\, G_t)$, or it can use the TD error, $loss_v=MSE(V(state_t),\, r_t + \gamma V(state_{t+1}))$. The implementation here reuses the Monte Carlo return $G_t$.

In [1]:
import os
import gym
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Categorical, Normal

In [2]:
# Choose the typed-tensor constructors once, based on CUDA availability: the
# torch.cuda variants when a GPU is usable, the plain CPU variants otherwise.
use_cuda = torch.cuda.is_available()
_tensor_types = torch.cuda if use_cuda else torch
FloatTensor = _tensor_types.FloatTensor
LongTensor = _tensor_types.LongTensor
ByteTensor = _tensor_types.ByteTensor

In [3]:
def save_torch_model(model, filename):
  """Write ``model.state_dict()`` to ``filename``, creating parent directories.

  Args:
    model: object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    filename: destination path; may be a bare file name without a directory.
  """
  directory = os.path.dirname(filename)
  # dirname is '' for a bare file name and os.makedirs('') raises, so only
  # create a directory when there is one.  exist_ok=True replaces the racy
  # "check exists, then create" sequence.
  if directory:
    os.makedirs(directory, exist_ok=True)
  torch.save(model.state_dict(), filename)

def load_torch_model(model, filename, map_location=None):
  """Load the state dict stored at ``filename`` into ``model`` in place.

  Args:
    model: object exposing ``load_state_dict()`` (e.g. an ``nn.Module``).
    filename: path of a file written with ``torch.save(state_dict, ...)``.
    map_location: forwarded to ``torch.load``; pass e.g. ``'cpu'`` to load a
      checkpoint that was saved from GPU tensors on a machine without CUDA.
      The default ``None`` keeps ``torch.load``'s own default behaviour.
  """
  state = torch.load(filename, map_location=map_location)
  model.load_state_dict(state)


In [4]:
class PolicyNet_discret(nn.Module):
  """MLP policy for a discrete action space.

  Maps a state vector to a probability for each of ``output_size`` actions
  through two ReLU hidden layers (128 and 64 units) and a softmax head.
  The head's weights start at zero, so the initial output depends only on
  the head bias and not on the input state.
  """

  def __init__(self, input_size, output_size):
    super().__init__()
    self.l1_linear = nn.Linear(input_size, 128)
    self.l2_linear = nn.Linear(128, 64)
    self.l3_linear = nn.Linear(64, output_size)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    nn.init.kaiming_normal_(self.l2_linear.weight)
    self.l3_linear.weight.data.zero_()

  def forward(self, x):
    """Return action probabilities for ``x``.

    ``x`` may be a single state of shape ``(input_size,)`` or a batch of
    shape ``(N, input_size)``; probabilities sum to 1 along the last axis.
    """
    out = F.relu(self.l1_linear(x))
    out = F.relu(self.l2_linear(out))
    # Normalise over the action axis (the last one).  dim=0 gave the same
    # result for a single 1-D state but normalised across the batch axis
    # when given 2-D input.
    return F.softmax(self.l3_linear(out), dim=-1)

In [5]:
class ValueNet(nn.Module):
  """State-value approximator: an MLP that emits one scalar per input state.

  Two ReLU hidden layers (128 and 64 units) feed a single linear output
  unit.  The output layer's weights start at zero, so the initial
  prediction equals the output bias whatever the input is.
  """

  def __init__(self, input_size):
    super().__init__()
    self.l1_linear = nn.Linear(input_size, 128)
    self.l2_linear = nn.Linear(128, 64)
    self.l3_linear = nn.Linear(64, 1)
    # Kaiming-normal weights for the hidden layers, in layer order.
    for hidden_layer in (self.l1_linear, self.l2_linear):
      nn.init.kaiming_normal_(hidden_layer.weight)
    self.l3_linear.weight.data.fill_(0)

  def forward(self, x):
    """Return the value estimate for ``x`` (last dimension has size 1)."""
    hidden = self.l1_linear(x)
    hidden = self.l2_linear(F.relu(hidden))
    return self.l3_linear(F.relu(hidden))

In [28]:
class REINFORCE():
  """Vanilla REINFORCE (Monte Carlo policy gradient) agent, discrete actions.

  The policy input is the flattened concatenation of the last
  ``steps_in_state`` environment observations.  Returns are discounted with
  ``self._gamma`` (0.96).
  """

  def __init__(self, env, steps_in_state = 2):
    """Build the policy network sized from ``env``'s observation/action spaces."""
    self.steps_in_state = steps_in_state
    self.policy = PolicyNet_discret(env.observation_space.shape[0] * steps_in_state,env.action_space.n)
    # States are built with FloatTensor, which is a CUDA tensor type when
    # use_cuda is set, so the network has to live on the GPU as well.
    if use_cuda:
      self.policy.cuda()
    self.env = env
    self._gamma = 0.96

  def init_state(self, env_state):
    """Start an episode: fill the observation window with ``env_state``."""
    self.running_state_seq = [env_state] * self.steps_in_state
    self.running_state = FloatTensor(self.running_state_seq).view(-1)

  def add_state(self, env_state):
    """Slide the window: drop the oldest observation, append ``env_state``."""
    self.running_state_seq = self.running_state_seq[1:] + [env_state]
    self.running_state = FloatTensor(self.running_state_seq).view(-1)

  def get_action(self):
    """Return the greedy (highest-probability) action for the current window."""
    # Pure inference: no autograd graph is needed for an argmax.
    with torch.no_grad():
      probs = self.policy(self.running_state)
    return probs.argmax().item()

  def pick_action(self, state):
    """Sample an action from the policy for ``state``.

    Returns:
      ``(action, log_prob)``: the sampled action as a Python int and the
      log-probability of that action as a tensor attached to the policy's
      autograd graph (one element for a single 1-D ``state``).
    """
    action_dist = Categorical(self.policy(state))
    sampled = action_dist.sample()
    # Score the sampled index tensor directly instead of round-tripping it
    # through a Python int and a float tensor; unsqueeze keeps the
    # one-element shape of the returned log-probability.
    return (sampled.item(), action_dist.log_prob(sampled.unsqueeze(0)))

  def _discounted_returns(self, rewards):
    """Return the discounted return G_t for every step, in time order."""
    returns = []
    running = 0
    for r in reversed(rewards):
      running = r + self._gamma * running
      returns.append(running)
    # Built back-to-front with O(1) appends, then flipped once; inserting at
    # index 0 on every step is quadratic in the episode length.
    returns.reverse()
    return returns

  def update_policy(self, rollout):
    """Take one optimizer step on the REINFORCE loss of a finished episode.

    Args:
      rollout: sequence of ``(state, action, reward, log_prob)`` tuples in
        time order; only rewards and log-probs are used.  An empty rollout
        is a no-op.

    Requires ``self.optimizer`` (created by ``train``).
    """
    if not rollout:
      return
    (_, _, rewards, log_probs) = zip(*rollout)
    returns = self._discounted_returns(rewards)
    # Negated so that minimising the loss ascends the policy-gradient objective.
    loss = torch.stack([-log_prob * g for (log_prob, g) in zip(log_probs, returns)]).sum()
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

  def train(self, env, episode, lr=1e-3, checkpoint=100):
    """Train the policy for ``episode`` episodes, one update per episode.

    Args:
      env: environment whose ``reset()`` returns an observation and whose
        ``step(action)`` returns ``(observation, reward, done, info)``.
      episode: number of episodes to run.
      lr: Adam learning rate (weight decay is fixed at 1e-3).
      checkpoint: every ``checkpoint`` episodes the policy is saved, the
        best-so-far copy is refreshed if the running score improved, and
        the running score is printed.
    """
    self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr, weight_decay=1e-3)
    # -inf rather than 0 so a "best" model is also kept when scores are
    # zero or negative.
    best_score = float('-inf')
    running_score = None
    for i in range(episode):
      self.init_state(env.reset())
      state = self.running_state
      rollout = []
      episode_ended = False
      score = 0
      while not episode_ended:
        (action, log_prob) = self.pick_action(state)
        (s1, reward, episode_ended, info) = env.step(action)
        rollout.append((state, action, reward, log_prob))
        self.add_state(s1)
        state = self.running_state
        score += reward

      # Exponential moving average of the episode score.
      if running_score is None:
        running_score = score
      else:
        running_score = running_score * 0.9 + 0.1 * score

      # Checkpoints are written before this episode's update is applied.
      if (i + 1) % checkpoint == 0:
        if running_score > best_score:
          save_torch_model(self.policy, 'model/reinforce_cartpole_best.pth')
          best_score = running_score
        save_torch_model(self.policy,'model/reinforce_cartpole_iter_%d.pth' %(i+1))
        print(i+1,': running_score:', running_score)

      self.update_policy(rollout)

In [29]:
# Create the CartPole-v0 environment and a vanilla REINFORCE agent for it
# (steps_in_state is left at its default of 2).
env = gym.make('CartPole-v0')
agent = REINFORCE(env)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [30]:
# Train for 2000 episodes with learning rate 7e-5; every 200 episodes the
# policy is checkpointed and the running score is printed.
agent.train(env, episode=2000, lr=7e-5, checkpoint=200)

200 : running_score: 21.269940462098422
400 : running_score: 29.375980104308145
600 : running_score: 50.05871229111651
800 : running_score: 69.30404227531287
1000 : running_score: 97.64375629710904
1200 : running_score: 141.81128305453186
1400 : running_score: 164.22592458012318
1600 : running_score: 172.4267281043214
1800 : running_score: 191.29587196835206
2000 : running_score: 161.36453303726924


In [42]:
class REINFORCE_wBaseline():
  """REINFORCE with a learned state-value baseline, for discrete actions.

  A policy network and a value network share the same input: the flattened
  concatenation of the last ``steps_in_state`` observations.  The value
  network is regressed onto the Monte Carlo returns and its (detached)
  prediction is subtracted from the return in the policy loss.
  """

  def __init__(self, env, steps_in_state = 2):
    """Build policy and value networks sized from ``env``'s spaces."""
    self.steps_in_state = steps_in_state
    self.policy = PolicyNet_discret(env.observation_space.shape[0] * steps_in_state,env.action_space.n)
    self.value = ValueNet(env.observation_space.shape[0] * steps_in_state)
    # States are built with FloatTensor, which is a CUDA tensor type when
    # use_cuda is set, so both networks have to live on the GPU as well.
    if use_cuda:
      self.policy.cuda()
      self.value.cuda()
    self.env = env
    self._gamma = 0.96

  def predict_value(self, state):
    """Return the value network's output for ``state``."""
    return self.value(state)

  def predict_action(self, state):
    """Return the policy network's action probabilities for ``state``."""
    return self.policy(state)

  def pick_action(self, state):
    """Sample an action from the policy for ``state``.

    Returns:
      ``(action, log_prob)``: the sampled action as a Python int and the
      log-probability of that action as a tensor attached to the policy's
      autograd graph (one element for a single 1-D ``state``).
    """
    action_dist = Categorical(self.predict_action(state))
    sampled = action_dist.sample()
    # Score the sampled index tensor directly instead of round-tripping it
    # through a Python int and a float tensor; unsqueeze keeps the
    # one-element shape of the returned log-probability.
    return (sampled.item(), action_dist.log_prob(sampled.unsqueeze(0)))

  def _discounted_returns(self, rewards):
    """Return the discounted return G_t for every step, in time order."""
    returns = []
    running = 0
    for r in reversed(rewards):
      running = r + self._gamma * running
      returns.append(running)
    # Built back-to-front with O(1) appends, then flipped once; inserting at
    # index 0 on every step is quadratic in the episode length.
    returns.reverse()
    return returns

  def update_policy_and_value(self, rollout):
    """Update the value net, then the policy net, from one finished episode.

    Args:
      rollout: sequence of ``(state, action, reward, log_prob)`` tuples in
        time order; ``state`` entries must be stackable tensors.  An empty
        rollout is a no-op.

    Requires ``self.value_optimizer`` and ``self.policy_optimizer`` (created
    by ``train``).
    """
    if not rollout:
      return
    (states, _, rewards, log_probs) = zip(*rollout)
    returns = self._discounted_returns(rewards)

    value_prediction = self.predict_value(torch.stack(states))
    # new_tensor gives the target the prediction's dtype and device.
    value_target = value_prediction.new_tensor(returns).view(-1,1)
    value_loss = F.mse_loss(value_prediction, value_target)

    # The baseline is used as plain numbers: the policy loss must not
    # back-propagate into the value network.
    baselines = value_prediction.detach().view(-1).tolist()
    policy_loss = torch.stack([
      -log_prob * (g - baseline)
      for (log_prob, g, baseline) in zip(log_probs, returns, baselines)
    ]).sum()

    self.value_optimizer.zero_grad()
    value_loss.backward()
    self.value_optimizer.step()

    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

  def train(self, env, episode=1000, lr=1e-3, lr_policy=None, lr_value=None, checkpoint=100):
    """Train both networks for ``episode`` episodes, one update per episode.

    Args:
      env: environment whose ``reset()`` returns an observation and whose
        ``step(action)`` returns ``(observation, reward, done, info)``.
      episode: number of episodes to run.
      lr: default Adam learning rate for both networks.
      lr_policy: policy learning rate; falls back to ``lr`` when ``None``.
      lr_value: value learning rate; falls back to ``lr`` when ``None``.
      checkpoint: every ``checkpoint`` episodes the policy is saved, the
        best-so-far copy is refreshed if the running score improved, and
        the running score is printed.
    """
    lr_policy = lr if lr_policy is None else lr_policy
    lr_value = lr if lr_value is None else lr_value
    self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr_policy, weight_decay=1e-3)
    self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=lr_value, weight_decay=1e-3)
    best_score = float('-inf')
    running_score = None
    for i in range(episode):
      s0 = env.reset()
      seq = [s0] * self.steps_in_state
      rollout = []
      state = FloatTensor(seq).view(-1)
      episode_ended = False
      score = 0
      while not episode_ended:
        (action, log_prob) = self.pick_action(state)
        (s1, reward, episode_ended, info) = env.step(action)
        rollout.append((state, action, reward, log_prob))
        # Slide the observation window forward by one step.
        seq = seq[1:] + [s1]
        state = FloatTensor(seq).view(-1)
        score += reward

      # Exponential moving average of the episode score.
      if running_score is None:
        running_score = score
      else:
        running_score = running_score * 0.9 + score * 0.1

      # Checkpoints are written before this episode's update is applied.
      # NOTE(review): these are the same file names REINFORCE.train writes,
      # and only the policy network is saved (not the value network) —
      # confirm whether that is intended.
      if (i + 1) % checkpoint == 0:
        if running_score > best_score:
          save_torch_model(self.policy, 'model/reinforce_cartpole_best.pth')
          best_score = running_score
        save_torch_model(self.policy,'model/reinforce_cartpole_iter_%d.pth' %(i+1))
        print(i+1,': running_score:', running_score)

      self.update_policy_and_value(rollout)

In [40]:
# Create a fresh CartPole-v0 environment and a REINFORCE-with-baseline agent
# for it (steps_in_state is left at its default of 2).
env = gym.make('CartPole-v0')
agent = REINFORCE_wBaseline(env)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [41]:
# Train for 2000 episodes; lr=1e-4 is used for both the policy and the value
# optimizer because lr_policy / lr_value are not given.  Every 200 episodes
# the policy is checkpointed and the running score is printed.
agent.train(env, episode=2000, lr=1e-4, checkpoint=200)

200 : running_score: 22.324116660275386
400 : running_score: 43.74702665869132
600 : running_score: 95.82207029866576
800 : running_score: 163.71738047252506
1000 : running_score: 184.3465488749348
1200 : running_score: 171.7118348831992
1400 : running_score: 192.37269649593756
1600 : running_score: 179.388897217873
1800 : running_score: 194.5123934497959
2000 : running_score: 193.0464476556347
