# Actor Critic Cartpole

In [1]:
%matplotlib inline
from IPython import display
from IPython.display import HTML
import matplotlib
import matplotlib.pyplot as plt
import matplotlib.animation as animation

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.distributions import Normal, Categorical

import numpy as np
import random
import os
import gym

In [2]:
use_cuda = torch.cuda.is_available()
FloatTensor = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
LongTensor = torch.cuda.LongTensor if use_cuda else torch.LongTensor
ByteTensor = torch.cuda.ByteTensor if use_cuda else torch.ByteTensor

In [3]:
def save_torch_model(model, filename):
  if not os.path.exists(os.path.dirname(filename)):
    os.makedirs(os.path.dirname(filename))
  torch.save(model.state_dict(), filename)

def load_torch_model(model, filename):
  model.load_state_dict(torch.load(filename))


In [4]:
def copy_grad(source, target):
  grads = []
  for param in source.parameters():
    grads.append(param.grad.clone())
  grads.reverse()
  for param in target.parameters():
    param.grad = grads.pop()

def zero_grad(model):
  for param in model.parameters():
    if type(param.grad) != type(None):
      param.grad.data.zero_()
      
def update_target(target_net, eval_net, tau):
  fast = eval_net.state_dict()
  slow = target_net.state_dict()
  for t in slow:
    slow[t] = slow[t] * (1. - tau) + fast[t] * tau

  target_net.load_state_dict(slow)

In [5]:
class PolicyNet_discret(nn.Module):
  def __init__(self, input_size, output_size):
    super(PolicyNet_discret,self).__init__()
    self.l1_linear = nn.Linear(input_size, 512)
    self.l2_linear = nn.Linear(512, 256)
    self.l3_linear = nn.Linear(256, output_size)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    nn.init.kaiming_normal_(self.l2_linear.weight)
    self.l3_linear.weight.data.zero_()
    
  def forward(self,x):
    out = F.relu(self.l1_linear(x))
    out = F.relu(self.l2_linear(out))
    out = F.softmax(self.l3_linear(out),dim=0)
    return out

In [6]:
class ValueNet(nn.Module):
  def __init__(self, input_size):
    super(ValueNet,self).__init__()
    self.l1_linear = nn.Linear(input_size, 512)
    self.l2_linear = nn.Linear(512,256)
    self.l3_linear = nn.Linear(256, 1)
    nn.init.kaiming_normal_(self.l1_linear.weight)
    nn.init.kaiming_normal_(self.l2_linear.weight)
    self.l3_linear.weight.data.zero_()

  def forward(self, x):
    out = F.relu(self.l1_linear(x))
    out = F.relu(self.l2_linear(out))
    out = self.l3_linear(out)
    return out    

In [7]:
class ActorCritic():
  def __init__(self, env, steps_in_state = 2):
    self.steps_in_state = steps_in_state
    self.policy = PolicyNet_discret(env.observation_space.shape[0] * steps_in_state,env.action_space.n)
    self.value = ValueNet(env.observation_space.shape[0] * steps_in_state)
    if use_cuda:
      self.policy.cuda()
      self.value.cuda()
    self.env = env
    self._gamma = 0.96
    
  def pick_action(self, state):
    probs = self.policy(state)
    action_dist = Categorical(probs)
    action = action_dist.sample()
    action = action.item()
    return (action, action_dist.log_prob(FloatTensor([action])))
  
  def update_actor_critic(self, episode):
    (states, actions, rewards, next_states, log_probs, ended) = zip(*episode)
    
    rewards = FloatTensor(rewards)
    ended = FloatTensor(ended)
    state_value = self.value(torch.stack(states))
    next_state_value = self.value(torch.stack(next_states))
    target_value = rewards + (1 - ended) * self._gamma * next_state_value
    
    delta = target_value - state_value

    value_loss = F.mse_loss(state_value, target_value)
    
    policy_loss = []
    for log_prob, d in zip(log_probs, delta):
      policy_loss.append(-log_prob * d)
    policy_loss = torch.stack(policy_loss).sum()
    
    self.value_optimizer.zero_grad()
    value_loss.backward(retain_graph=True)
    self.value_optimizer.step()
    
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

  def train(self, env, episode, lr=1e-3, target_copylr=1e-3, lr_policy=None, lr_value=None, checkpoint=100):
    lr_policy = lr if lr_policy == None else lr_policy
    lr_value = lr if lr_value == None else lr_value
    self.policy_optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr_policy, weight_decay=1e-3)
    self.value_optimizer = torch.optim.Adam(self.value.parameters(), lr=lr_value, weight_decay=1e-3)
    best_score = -99999
    running_score = None
    for i in range(episode):
      s0 = env.reset()
      seq = [s0] * self.steps_in_state
      state = FloatTensor(seq).view(-1)
      episode = []
      episode_ended = False
      score = 0
      while not episode_ended:
        (action, log_prob) =  self.pick_action(state)
        (s1, reward, episode_ended, info) = env.step(action)
        seq = seq[1:]
        seq.append(s1)
        next_state = FloatTensor(seq).view(-1)
        if episode_ended:
          ended = 1
        else:
          ended = 0
        episode.append((state, action, reward, next_state, log_prob, ended))

        s0 = s1
        state = next_state
        score += reward
        
      if running_score == None:
        running_score = score
      else:
        running_score = running_score * 0.9 + score * 0.1
        
      if (i + 1) % checkpoint == 0:
        is_best = False
        if running_score > best_score:
          is_best = True
          save_torch_model(self.policy, 'model/actor_critic_cartpole_policy_best.pth')
          best_score = running_score
        save_torch_model(self.policy,'model/actor_critic_cartpole_policy_iter_%d.pth' %(i+1))
        print('%d: running_score:%.2f, is_best:%s' %(i+1, running_score, is_best))

      self.update_actor_critic(episode)

In [8]:
env = gym.make('CartPole-v0')
agent = ActorCritic(env)

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [9]:
agent.train(env, 1000, lr=1e-4, checkpoint=100)

100: running_score:24.29, is_best:True
200: running_score:28.67, is_best:True
300: running_score:50.39, is_best:True
400: running_score:131.17, is_best:True
500: running_score:147.21, is_best:True
600: running_score:198.62, is_best:True
700: running_score:155.24, is_best:False
800: running_score:180.74, is_best:False
900: running_score:194.96, is_best:False
1000: running_score:192.11, is_best:False
