### Utils

In [None]:
# System packages: xvfb (virtual framebuffer X server), python-opengl and x11-utils.
# Presumably required so env.render(mode='rgb_array') works on a machine without a display — TODO confirm.
!apt-get install -y xvfb python-opengl x11-utils
# pyvirtualdisplay is pinned to 0.2.5; it provides the Display class imported further down.
!pip install pyvirtualdisplay==0.2.5

Reading package lists... Done
Building dependency tree       
Reading state information... Done
x11-utils is already the newest version (7.7+3build1).
python-opengl is already the newest version (3.1.0+dfsg-1).
xvfb is already the newest version (2:1.19.6-1ubuntu4.9).
0 upgraded, 0 newly installed, 0 to remove and 37 not upgraded.


In [None]:
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.distributions import Categorical, Normal

import random, datetime, gym, os, time, psutil, cv2
import numpy as np
from torchsummary import summary as summary_
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import display, HTML
from collections import deque, namedtuple

%matplotlib inline

In [None]:
from pyvirtualdisplay import Display
# Start a 400x300 virtual display that is not shown on screen (visible=0).
# The handle is kept in v_display so it stays referenced for the session.
v_display = Display(visible=0, size=(400, 300))
v_display.start()

import warnings
# Suppress NumPy's VisibleDeprecationWarning for every later cell.
warnings.filterwarnings("ignore", category=np.VisibleDeprecationWarning)

def display_animation(anim):
    """Convert a matplotlib animation into an IPython ``HTML`` object.

    The figure owned by ``anim`` is closed before the animation is serialised
    with ``to_jshtml()``; the resulting markup is wrapped in ``HTML`` and
    returned (it is not displayed here).
    """
    owning_figure = anim._fig
    plt.close(owning_figure)
    js_markup = anim.to_jshtml()
    return HTML(js_markup)

def display_frames_as_gif(frames, interval=10):
    """Show a sequence of image frames as an inline animation.

    Args:
        frames: Sequence of images accepted by ``imshow`` (train passes the
            RGB arrays it collected from ``eval_env.render``).
        interval: Delay between frames in milliseconds, forwarded to
            ``FuncAnimation``. Defaults to 10, the value previously hard-coded.

    Returns:
        None. The animation is shown through IPython's ``display``. When
        ``frames`` is empty nothing is shown.
    """
    if len(frames) == 0:
        # frames[0] below would raise IndexError; there is nothing to animate.
        return

    # Use a dedicated figure instead of whatever plt.gcf() happens to be, so a
    # figure left open by an earlier cell is never drawn into.
    fig, ax = plt.subplots()
    patch = ax.imshow(frames[0])
    ax.axis('off')

    def animate(i):
        # Swap the pixel data of the single image artist for frame i.
        patch.set_data(frames[i])
        return (patch,)

    anim = animation.FuncAnimation(
        fig, animate, frames=len(frames), interval=interval)
    display(display_animation(anim))

# Actor Critic Model
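- CartPole (`CartPole-v0`; this bullet originally read "mountain car", but the environments created in the cells below are CartPole)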
- discrete action space

In [None]:
class ActorCritic(nn.Module):
    """Actor-critic network with one shared hidden layer and two heads.

    A 256-unit linear layer feeds a policy head (``n_action`` outputs turned
    into a ``Categorical`` distribution) and a value head (one output). The
    instance also owns its Adam optimizer and a list of transitions that
    ``update`` consumes.

    Args:
        n_state: Size of the state vector fed to the shared layer.
        n_action: Number of discrete actions.
        lr: Adam learning rate. When None (default) the notebook-level
            ``learning_rate`` variable is read at construction time.
        discount: Discount factor for the TD target. When None (default) the
            notebook-level ``gamma`` variable is read each time ``update`` runs.
    """

    def __init__(self, n_state, n_action, lr=None, discount=None):
        super(ActorCritic, self).__init__()
        self.shared_layer = nn.Linear(n_state, 256)
        self.policy_layer = nn.Linear(256, n_action)
        self.value_layer = nn.Linear(256, 1)
        # The global is only looked up when no explicit rate was supplied.
        self.optimizer = optim.Adam(
            self.parameters(), lr=learning_rate if lr is None else lr)
        self._discount = discount
        # Transitions collected since the last update()/get_batch() call.
        self.data = []

    def policy(self, state, softmax_dim=0):
        """Return the action distribution for ``state``.

        Args:
            state: Float tensor of state features.
            softmax_dim: Dimension the softmax is taken over. The default 0 is
                what ``get_action`` uses for a single state; ``update`` passes
                1 for a batch.

        Returns:
            A ``Categorical`` distribution built from the softmax probabilities.
        """
        x = F.relu(self.shared_layer(state))
        x = self.policy_layer(x)
        prob = F.softmax(x, dim=softmax_dim)
        return Categorical(prob)

    def get_action(self, state):
        """Sample one action (a Python int) for a single NumPy state array."""
        # Acting never needs gradients; skip building the autograd graph.
        with torch.no_grad():
            dist = self.policy(torch.from_numpy(state).float())
        return dist.sample().item()

    def value(self, state):
        """Return the value head's output for ``state``."""
        x = F.relu(self.shared_layer(state))
        value = self.value_layer(x)
        return value

    def save(self, transition):
        """Append a ``(state, action, reward, next_state, done)`` tuple to the buffer."""
        self.data.append(transition)

    def get_batch(self):
        """Turn the buffered transitions into tensors and empty the buffer.

        Returns:
            Tuple ``(state, action, reward, next_state, not_done)``. ``action``,
            ``reward`` and ``not_done`` carry one single-element row per
            transition. Rewards are divided by 100. ``not_done`` is 0.0 for
            transitions whose ``done`` flag was true and 1.0 otherwise.
        """
        states, actions, rewards, next_states, not_done = [], [], [], [], []
        for state, action, reward, next_state, done in self.data:
            states.append(state)
            actions.append([action])
            rewards.append([reward / 100.0])
            next_states.append(next_state)
            not_done.append([0.0 if done else 1.0])

        # Stack into a single ndarray first: building a tensor directly from a
        # list of ndarrays goes element by element and is much slower.
        state_batch = torch.tensor(np.array(states), dtype=torch.float)
        # gather()/log_prob() need integer indices, so pin the dtype.
        action_batch = torch.tensor(actions, dtype=torch.long)
        reward_batch = torch.tensor(rewards, dtype=torch.float)
        next_state_batch = torch.tensor(np.array(next_states), dtype=torch.float)
        not_done_batch = torch.tensor(not_done, dtype=torch.float)
        self.data = []

        return state_batch, action_batch, reward_batch, next_state_batch, not_done_batch

    def update(self):
        """Run one optimizer step on all buffered transitions.

        The loss is the mean over transitions of
        ``-log pi(a|s) * td_error`` (TD error treated as a constant) plus the
        squared TD error of the value head against a constant TD target.
        Does nothing when the buffer is empty.
        """
        if not self.data:
            return

        state, action, reward, next_state, not_done = self.get_batch()
        discount = gamma if self._discount is None else self._discount

        # The TD target is a constant for both losses, so no graph is needed.
        with torch.no_grad():
            td_target = reward + discount * self.value(next_state) * not_done
        # Single forward pass of the value head, reused by both loss terms.
        td_error = td_target - self.value(state)

        dist = self.policy(state, softmax_dim=1)
        # log_prob clamps probabilities away from 0, avoiding log(0) = -inf.
        log_prob_action = dist.log_prob(action.squeeze(1)).unsqueeze(1)
        loss_actor = -log_prob_action * td_error.detach()
        loss_critic = td_error ** 2
        loss = loss_actor + loss_critic

        self.optimizer.zero_grad()
        loss.mean().backward()
        self.optimizer.step()


# Train

In [None]:
#Hyperparameters
# Discount factor applied to the next-state value in ActorCritic.update.
gamma = 0.98
# Maximum number of environment steps train() collects before each agent.update().
update_every = 10
# Learning rate handed to the Adam optimizer created in ActorCritic.__init__.
learning_rate=0.0002

In [None]:
def _render_eval_episode(agent, eval_env, caption):
    """Play one episode on ``eval_env`` and return its captioned RGB frames."""
    frames = []
    state, done = eval_env.reset(), False
    while not done:
        action = agent.get_action(state)
        state, _, done, _ = eval_env.step(action)
        frame = eval_env.render(mode='rgb_array')
        # putText draws in place, so annotate a copy of the rendered frame.
        frames.append(cv2.putText(
            img=np.copy(frame),
            text=caption,
            org=(300, 100), fontFace=3, fontScale=1, color=(0, 0, 255), thickness=3))
    return frames


def train(agent, env, eval_env, print_every=200, num_episodes=10000, learning_rate=0.0002, display=True, n_steps=None):
    """Train ``agent`` on ``env`` with short rollouts between updates.

    Within an episode the agent acts for up to ``n_steps`` environment steps
    (fewer if the episode ends), then ``agent.update()`` is called; this
    repeats until the episode is done.

    Args:
        agent: Object providing ``get_action(state)``, ``save(transition)``
            and ``update()``.
        env: Training environment; ``reset()`` returns a state and ``step``
            returns ``(next_state, reward, done, info)``.
        eval_env: Environment used for the rendered evaluation episode. Only
            touched when ``display`` is true.
        print_every: Progress is printed after the first episode and after
            every ``print_every``-th episode.
        num_episodes: Number of training episodes.
        learning_rate: Not used by this function; kept so existing calls that
            pass it continue to work.
        display: When true, each progress report also plays one evaluation
            episode and shows it as an animation.
        n_steps: Rollout length between updates. When None (default) the
            notebook-level ``update_every`` variable is used.

    Returns:
        NumPy array of length ``num_episodes`` holding the undiscounted sum of
        rewards of each training episode.

    Raises:
        ValueError: If the rollout length is smaller than 1, which would make
            the episode loop spin without ever stepping the environment.
    """
    rollout_len = update_every if n_steps is None else n_steps
    if rollout_len < 1:
        raise ValueError("rollout length must be at least 1, got {}".format(rollout_len))

    episodic_reward = np.zeros([num_episodes])
    for n_epi in range(num_episodes):
        done = False
        state = env.reset()
        while not done:
            for _ in range(rollout_len):
                action = agent.get_action(state)
                next_state, reward, done, info = env.step(action)
                agent.save((state, action, reward, next_state, done))
                state = next_state
                episodic_reward[n_epi] += reward
                if done:
                    break
            agent.update()

        if (n_epi == 0) or ((n_epi + 1) % print_every == 0):
            print("# of episode :{}, return : {:.1f}".format(n_epi, episodic_reward[n_epi]))
            if display:
                print(f"[Eval. start] step:[{n_epi + 1}/{num_episodes}]")
                frames = _render_eval_episode(
                    agent, eval_env, caption='epoch:{}'.format(n_epi + 1))
                print("[Eval. done]")

                #Display GIF
                display_frames_as_gif(frames)

    return episodic_reward


In [None]:
# Two independent CartPole-v0 instances: one is stepped during training, the
# other is passed to train() as eval_env for the rendered evaluation episodes.
cart_pole = gym.make('CartPole-v0')
cart_pole_eval = gym.make('CartPole-v0')

In [None]:
# Echo the observation space; the recorded output shows a float32 Box of shape (4,).
cart_pole.observation_space

Box(-3.4028234663852886e+38, 3.4028234663852886e+38, (4,), float32)

In [None]:
# Echo the action space; the recorded output shows Discrete(2).
cart_pole.action_space

Discrete(2)

In [None]:
# n_state=4 and n_action=2 match the Box shape (4,) and Discrete(2) spaces echoed above.
actor_critic = ActorCritic(n_state=4, n_action=2)

In [None]:
# Keyword arguments label what were bare positional 200 / 1000 / 0.0002.
# Binding the result to a name keeps the cell from echoing whatever train() returns.
training_result = train(
    actor_critic, cart_pole, cart_pole_eval,
    print_every=200, num_episodes=1000, learning_rate=0.0002)