In [1]:
import numpy as np
import tensorflow as tf
import gym
import tensorflow_probability as tfp

In [7]:
class ActorCritic(tf.keras.Model):
    def __init__(self, action_dim):
        super().__init__()
        self.fc1 = tf.keras.layers.Dense(512, activation="relu")
        self.fc2 = tf.keras.layers.Dense(128, activation="relu")
        self.critic = tf.keras.layers.Dense(1, activation=None)
        self.actor = tf.keras.layers.Dense(action_dim, activation=None)
        
    def call(self, input_data):
        x = self.fc1(input_data)
        x1 = self.fc2(x)
        actor = self.actor(x1)
        critic = self.critic(x1)
        return actor, critic
    

In [3]:
class Agent:
    def __init__(self, action_dim=4, gamma=0.99):
        self.gamma = gamma 
        self.opt = tf.keras.optimizers.Adam(learning_rate=1e-4)
        self.actor_critic = ActorCritic(action_dim)
        
    def get_action(self, state):
        _, action_probs = self.actor_critic(np.array([state]))
        action_probs = tf.nn.softmax(action_probs)
        action_probs = action_probs.numpy()
        dist = tfp.distributions.Categorical(
            probs=action_probs, dtype=tf.float32
        )
        action = dist.sample()
        return int(action.numpy()[0])
    
    def actor_loss(self, prob, action, td):
        prob = tf.nn.softmax(prob)
        dist = tfp.distributions.Categorical(probs=prob, dtype=tf.float32)
        log_prob = dist.log_prob(action)
        loss = -log_prob * td
        return loss
    
    def learn(self, state, action, reward, next_state, done):
        state = np.array([state])
        next_state = np.array([next_state])
        
        with tf.GradientTape() as tape:
            value, action_probs = self.actor_critic(state, training=True)
            value_next_st, _ = self.actor_critic(next_state, training=True)
            td = reward + self.gamma * value_next_st * (1 - int(done)) - value
            actor_loss = self.actor_loss(action_probs, action, td)
            critic_loss = td ** 2
            total_loss = actor_loss + critic_loss
        grads = tape.gradient(total_loss, self.actor_critic.trainable_variables)
        self.opt.apply_gradients(zip(grads, self.actor_critic.trainable_variables))
        return total_loss

In [4]:
def train(agent, env, episodes, render=True):
    
    for episode in range(episodes):
        
        done = False
        state = env.reset()
        total_reward = 0
        all_loss = []
        
        while not done:
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            loss = agent.learn(state, action, reward, next_state, done)
            all_loss.append(loss)
            state = next_state
            total_reward += reward
            if render: 
                env.render()
            if done:
                print("\n")
            print(f"Episode {episode}, ep_reward: {total_reward}", end="\r")

In [9]:
env = gym.make("CartPole-v1")
agent = Agent(env.action_space.n)
num_episodes = 200
train(agent, env, num_episodes, render=False)

Episode 0, ep_reward: 9.0

Episode 1, ep_reward: 9.00

Episode 2, ep_reward: 8.00

Episode 3, ep_reward: 8.0

Episode 4, ep_reward: 9.0

Episode 5, ep_reward: 9.00

Episode 6, ep_reward: 9.00

Episode 7, ep_reward: 8.00

Episode 8, ep_reward: 7.0

Episode 9, ep_reward: 7.0

Episode 10, ep_reward: 8.0

Episode 11, ep_reward: 8.0

Episode 12, ep_reward: 10.0

Episode 13, ep_reward: 8.00

Episode 14, ep_reward: 8.0

Episode 15, ep_reward: 8.0

Episode 16, ep_reward: 9.0

Episode 17, ep_reward: 9.00

Episode 18, ep_reward: 9.00

Episode 19, ep_reward: 9.00

Episode 20, ep_reward: 8.00

Episode 21, ep_reward: 9.0

Episode 22, ep_reward: 10.0

Episode 23, ep_reward: 8.00

Episode 24, ep_reward: 9.0

Episode 25, ep_reward: 10.0

Episode 26, ep_reward: 9.00

Episode 27, ep_reward: 8.00

Episode 28, ep_reward: 9.0

Episode 29, ep_reward: 7.00

Episode 30, ep_reward: 9.0

Episode 31, ep_reward: 9.00

Episode 32, ep_reward: 9.00

Episode 33, ep_reward: 9.00

Episode 34, ep_reward: 8.00

Episode 3