In [1]:
import numpy as np
import gym
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple

In [2]:
# Hyperparameters for the A2C agent.
CRITIC_LEARNING_RATE = 0.001   # Adam step size for the value network (critic)
ACTOR_LEARNING_RATE = 0.0001   # Adam step size for the policy network (actor)
GAMMA = 0.9                    # discount factor used in the one-step TD target

# Seed the NumPy RNG (used for action sampling) and the PyTorch RNG (used for
# weight initialisation) so that Restart & Run All reproduces the same run.
# NOTE(review): the gym environment's own RNG is not seeded here; how to seed
# it depends on the installed gym version -- TODO confirm and add.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
class A2CAgent:
    """One-step (online) advantage actor-critic agent with separate actor and critic MLPs.

    The critic maps a state vector to a single scalar value estimate V(s).
    The actor maps a state vector to a softmax distribution over
    ``num_actions`` discrete actions. Each call to :meth:`train` performs one
    TD(0) critic update followed by one policy-gradient actor update, using a
    single transition.

    Args:
        num_states: Length of the flattened observation vector.
        num_actions: Number of discrete actions.
        critic_lr: Adam learning rate for the critic. Defaults to the
            notebook constant ``CRITIC_LEARNING_RATE``.
        actor_lr: Adam learning rate for the actor. Defaults to the notebook
            constant ``ACTOR_LEARNING_RATE``.
        gamma: Discount factor for the TD target. Defaults to the notebook
            constant ``GAMMA``.
    """

    # Lower bound applied to the chosen action's probability before taking the
    # log, so a probability that underflows to 0.0 cannot yield -inf / NaN.
    _MIN_PROB = 1e-8

    def __init__(self, num_states, num_actions, critic_lr=None, actor_lr=None, gamma=None):
        self.num_states = num_states
        self.num_actions = num_actions
        self.action_list = np.arange(num_actions)
        # Fall back to the module-level hyperparameters only when not given;
        # the conditional expressions evaluate the globals lazily.
        self.gamma = GAMMA if gamma is None else gamma
        critic_lr = CRITIC_LEARNING_RATE if critic_lr is None else critic_lr
        actor_lr = ACTOR_LEARNING_RATE if actor_lr is None else actor_lr

        self.critic_model = self._get_critic_model()
        self.critic_optimizer = torch.optim.Adam(self.critic_model.parameters(), lr=critic_lr)

        self.actor_model = self._get_actor_model()
        self.actor_optimizer = torch.optim.Adam(self.actor_model.parameters(), lr=actor_lr)

    def get_action(self, state, train=True, episode=None):
        """Sample an action index from the current policy for ``state``.

        Args:
            state: Observation, flattened to a 1 x num_states float tensor.
            train: Unused; kept for interface compatibility.
            episode: Unused; kept for interface compatibility.

        Returns:
            An element of ``self.action_list`` drawn with the actor's
            probabilities (via ``np.random.choice``).
        """
        state = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
        # Sampling is inference only; no autograd graph is needed.
        with torch.no_grad():
            policy = self.actor_model(state)
        probs = policy.numpy().ravel().astype(np.float64)
        # Renormalise in float64 so np.random.choice's sum-to-one check is not
        # tripped by float32 rounding in the softmax output.
        probs /= probs.sum()
        return np.random.choice(self.action_list, p=probs)

    def train(self, state, action, next_state, reward, done, episode=None):
        """Run one critic update and one actor update from a single transition.

        Args:
            state: Observation before the action.
            action: Integer index of the action taken.
            next_state: Observation after the action.
            reward: Scalar reward for the transition.
            done: Truthy if the episode ended on this transition; disables
                bootstrapping from ``next_state``.
            episode: Unused; kept for interface compatibility.

        Returns:
            The actor loss as a Python float.
        """
        state = torch.as_tensor(state, dtype=torch.float32).reshape(1, -1)
        next_state = torch.as_tensor(next_state, dtype=torch.float32).reshape(1, -1)
        action = torch.tensor([[int(action)]], dtype=torch.long)
        reward = torch.tensor([[float(reward)]])
        # Float mask: 1.0 for a non-terminal transition, 0.0 at episode end.
        not_done = torch.tensor([[0.0 if done else 1.0]])

        # Critic: regress V(s) toward the one-step TD target r + gamma * V(s').
        state_value = self.critic_model(state)
        # The bootstrap target is treated as a constant (semi-gradient TD), so
        # no gradient flows through V(s').
        with torch.no_grad():
            next_state_value = self.critic_model(next_state)
        critic_target = reward + self.gamma * not_done * next_state_value

        critic_loss = F.smooth_l1_loss(state_value, critic_target)
        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # One-step advantage from the pre-update critic estimates; detached so
        # the actor update does not touch the critic.
        advantage_value = (critic_target - state_value).detach()

        # Actor: minimise -log pi(a|s) * advantage.
        policy = self.actor_model(state)
        selected_action_prob = policy.gather(1, action)
        log_selected_action_prob = torch.log(selected_action_prob.clamp_min(self._MIN_PROB))

        actor_loss = -(log_selected_action_prob * advantage_value).mean()
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        return actor_loss.item()

    def _build_mlp(self, num_outputs):
        """Build the shared Linear(num_states,32)-ReLU-Linear(32,32)-ReLU-Linear(32,num_outputs) stack.

        Submodule names (fc1, relu1, fc2, relu2, fc3) are kept fixed because
        they determine the model's state_dict keys.
        """
        model = nn.Sequential()
        model.add_module('fc1', nn.Linear(self.num_states, 32))
        model.add_module('relu1', nn.ReLU())
        model.add_module('fc2', nn.Linear(32, 32))
        model.add_module('relu2', nn.ReLU())
        model.add_module('fc3', nn.Linear(32, num_outputs))
        return model

    def _get_critic_model(self):
        """Value network: state vector -> single scalar estimate."""
        return self._build_mlp(1)

    def _get_actor_model(self):
        """Policy network: state vector -> softmax over ``num_actions`` actions."""
        model = self._build_mlp(self.num_actions)
        model.add_module('softmax1', nn.Softmax(dim=1))
        return model

In [4]:
# CartPole environment. NOTE(review): the training loop below unpacks a
# 4-tuple from env.step and uses env.reset()'s return value directly as the
# observation, i.e. it assumes the pre-0.26 gym API -- confirm the version.
env = gym.make("CartPole-v0")

In [5]:
# Size the networks from the environment: observation vector length in,
# one softmax entry per discrete action out.
agent = A2CAgent(
    num_states=env.observation_space.shape[0],
    num_actions=env.action_space.n,
)

In [7]:
N_EPOCH: int = 500  # number of training episodes run by the loop below

In [None]:
# Online training: one agent update per environment transition.
# Shaped sparse reward: 0.0 on every non-terminal step; at episode end +1.0 if
# the episode lasted at least SUCCESS_STEPS steps, otherwise -1.0.
SUCCESS_STEPS = 195  # presumably CartPole-v0's "solved" episode length -- confirm
for episode in range(N_EPOCH):
    done = False
    state = env.reset()
    step = 0
    while not done:
        action = agent.get_action(state)
        next_state, reward, done, _ = env.step(action)
        # Count this transition before judging success, so the threshold is
        # compared against the same episode length that is printed below.
        step += 1
        if done:
            reward = 1.0 if step >= SUCCESS_STEPS else -1.0
        else:
            reward = 0.0

        agent.train(state, action, next_state, reward, done)

        state = next_state

    print("Episode {} Step {}".format(episode, step))

Episode 0 Step 14
Episode 1 Step 64
Episode 2 Step 22
Episode 3 Step 26
Episode 4 Step 24
Episode 5 Step 18
Episode 6 Step 22
Episode 7 Step 32
Episode 8 Step 31
Episode 9 Step 16
Episode 10 Step 29
Episode 11 Step 46
Episode 12 Step 40
Episode 13 Step 21
Episode 14 Step 59
Episode 15 Step 20
Episode 16 Step 21
Episode 17 Step 12
Episode 18 Step 19
Episode 19 Step 11
Episode 20 Step 24
Episode 21 Step 22
Episode 22 Step 63
Episode 23 Step 28
Episode 24 Step 13
Episode 25 Step 13
Episode 26 Step 44
Episode 27 Step 10
Episode 28 Step 23
Episode 29 Step 36
Episode 30 Step 24
Episode 31 Step 19
Episode 32 Step 17
Episode 33 Step 25
Episode 34 Step 19
Episode 35 Step 80
Episode 36 Step 36
Episode 37 Step 10
Episode 38 Step 32
Episode 39 Step 14
Episode 40 Step 13
Episode 41 Step 16
Episode 42 Step 19
Episode 43 Step 29
Episode 44 Step 37
Episode 45 Step 17
Episode 46 Step 37
Episode 47 Step 20
Episode 48 Step 49
Episode 49 Step 38
Episode 50 Step 37
Episode 51 Step 17
Episode 52 Step 40
Epi