In [None]:
# Tutorial by www.pylessons.com
# Tutorial written for - Tensorflow 2.3.1

import os
#os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
import random
import gym
import pylab
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Input, Dense, Lambda, Add, Conv2D, Flatten
from tensorflow.keras.optimizers import Adam, RMSprop
from tensorflow.keras import backend as K
import cv2
import threading
from threading import Thread, Lock
import time
import tensorflow_probability as tfp
from typing import Any, List, Sequence, Tuple


# Hide every CUDA device from this process, presumably to force CPU execution.
# NOTE(review): this runs after `import tensorflow` above; confirm it still takes
# effect in your environment (setting it before the import is the usual advice).
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

# Short alias for the TensorFlow Probability distributions module.
tfd = tfp.distributions


class ActorCritic(tf.keras.Model):
  """Combined actor-critic network.

  Pipeline: three convolutions -> LSTM over the flattened spatial positions
  -> shared dense layer -> separate policy (logits) and value heads.
  """

  def __init__(
      self,
      num_actions: int,
      num_hidden_units: int):
    """Create the layers.

    Args:
      num_actions: number of outputs of the policy head.
      num_hidden_units: width of the dense layer shared by both heads.
    """
    super().__init__()

    self.num_actions = num_actions
    # Kept on the instance because get_config() reports it.
    self.num_hidden_units = num_hidden_units

    self.conv_1 = tf.keras.layers.Conv2D(16, 8, 4, padding="valid", activation="relu", kernel_regularizer='l2')
    self.conv_2 = tf.keras.layers.Conv2D(32, 4, 2, padding="valid", activation="relu", kernel_regularizer='l2')
    self.conv_3 = tf.keras.layers.Conv2D(32, 3, 1, padding="valid", activation="relu", kernel_regularizer='l2')

    self.lstm = tf.keras.layers.LSTM(128, return_sequences=True, return_state=True, kernel_regularizer='l2')

    self.common = tf.keras.layers.Dense(num_hidden_units, activation="relu", kernel_regularizer='l2')
    self.actor = tf.keras.layers.Dense(num_actions, kernel_regularizer='l2')
    self.critic = tf.keras.layers.Dense(1, kernel_regularizer='l2')

  def get_config(self):
    """Return the constructor arguments (plus any base-class config)."""
    try:
      config = dict(super().get_config())
    except NotImplementedError:
      # Some TensorFlow releases do not implement get_config() for
      # subclassed models; fall back to reporting only our own arguments.
      config = {}
    config.update({
        'num_actions': self.num_actions,
        'num_hidden_units': self.num_hidden_units
    })
    return config

  def call(self, inputs: tf.Tensor, memory_state: tf.Tensor, carry_state: tf.Tensor, training) -> Tuple[tf.Tensor, tf.Tensor,
                                                                                                        tf.Tensor, tf.Tensor]:
    """Run one forward pass.

    Args:
      inputs: batch of image frames fed to the first convolution.
      memory_state: first element of the LSTM initial state.
      carry_state: second element of the LSTM initial state.
      training: forwarded to the LSTM layer.

    Returns:
      (policy logits, value estimate, LSTM memory state after this pass,
      LSTM carry state after this pass).
    """
    conv_1 = self.conv_1(inputs)
    conv_2 = self.conv_2(conv_1)
    conv_3 = self.conv_3(conv_2)
    # Treat the conv feature map as a sequence of 16 positions with 32
    # channels each. NOTE(review): this hard-codes a 4x4 map, which is what
    # 64x64 inputs produce with the strides above - confirm for other sizes.
    conv_3_reshaped = tf.keras.layers.Reshape((4*4,32))(conv_3)

    initial_state = (memory_state, carry_state)
    lstm_output, final_memory_state, final_carry_state  = self.lstm(conv_3_reshaped, initial_state=initial_state,
                                                                    training=training)
    X_input = tf.keras.layers.Flatten()(lstm_output)
    x = self.common(X_input)

    # Hand back the *updated* recurrent state so callers can feed it into the
    # next step; returning the inputs would freeze the state forever.
    return self.actor(x), self.critic(x), final_memory_state, final_carry_state


def safe_log(x):
  """Return log(x) element-wise, with entries equal to zero mapped to 0.

  Values are clamped to at least 1e-12 before the logarithm is taken.
  """
  is_zero = tf.math.equal(x, 0)
  clamped = tf.math.maximum(1e-12, x)
  return tf.where(is_zero, tf.zeros_like(x), tf.math.log(clamped))


def take_vector_elements(vectors, indices):
    """
    Select one component from every row of a batch of vectors.
    Args:
      vectors: a [batch x dims] Tensor.
      indices: an int32 Tensor with `batch` entries; entry i is the column
        to read from row i.
    Returns:
      A Tensor with `batch` entries, one per row of `vectors`.
    """
    row_count = tf.shape(vectors)[0]
    row_ids = tf.range(row_count)
    # Pair each row number with its requested column: [[0, i0], [1, i1], ...].
    coordinates = tf.stack([row_ids, indices], axis=1)
    return tf.gather_nd(vectors, coordinates)


# Module-level Keras loss objects.
# NOTE(review): none of these is referenced in the visible part of this file;
# the agent's update() builds its losses by hand.
# Huber loss, summed over the batch.
huber_loss = tf.keras.losses.Huber(reduction=tf.keras.losses.Reduction.SUM)
# Sparse categorical cross-entropy on raw logits, summed over the batch.
sparse_ce = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True, reduction=tf.keras.losses.Reduction.SUM)
# Mean squared error with the default reduction.
mse_loss = tf.keras.losses.MeanSquaredError()


class A3CAgent:
    # Actor-Critic Main Optimization Algorithm
    def __init__(self, env_name):
        """Create the environment, replay buffer, network and optimizer.

        Args:
            env_name: gym environment id passed to ``gym.make``.
        """
        # Environment and training parameters
        self.env_name = env_name
        self.env = gym.make(env_name)
        self.action_size = self.env.action_space.n
        self.EPISODES, self.episode, self.max_average = 20000, 0, -21.0 # specific for pong

        # Replay buffer. The capacity is kept on the instance so that code
        # filling the buffer can enforce it.
        self.memory_size = 10000
        self.memory = []
        self.lock = Lock()
        self.lr = 0.0001

        num_hidden_units = 512

        self.batch_size = 64

        # Instantiate plot memory
        self.scores, self.episodes, self.average = [], [], []

        self.Save_Path = 'Models'

        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.Save_Path, exist_ok=True)
        self.path = '{}_A3C_{}'.format(self.env_name, self.lr)
        self.model_name = os.path.join(self.Save_Path, self.path)

        # Create Actor-Critic network model
        self.model = ActorCritic(self.action_size, num_hidden_units)

        # Single source of truth for the learning rate; `learning_rate` is
        # kept as an alias of `lr` for existing readers of that attribute.
        self.learning_rate = self.lr
        self.optimizer = tf.keras.optimizers.Adam(self.lr)

    def remember(self, state, action, policy, reward, done, memory_state, carry_state):
        """Append one transition to the replay buffer, dropping the oldest on overflow.

        The transition is stored as the tuple
        ``(state, action, policy, reward, done, memory_state, carry_state)``.
        """
        experience = (state, action, policy, reward, done, memory_state, carry_state)
        self.memory.append(experience)

        # Without a cap the list grows for the whole run. Fall back to 10000
        # when the instance does not define `memory_size`.
        capacity = getattr(self, 'memory_size', 10000)
        overflow = len(self.memory) - capacity
        if overflow > 0:
            del self.memory[:overflow]
        
    def act(self, state, memory_state, carry_state):
        """Sample an action for `state` from the current policy.

        Returns:
            (sampled action as int, policy logits returned by the model,
            third model output, fourth model output) - the last two are the
            recurrent states to pass to the next call.
        """
        outputs = self.model(state, memory_state, carry_state, training=False)
        logits = outputs[0]

        # Draw from the categorical distribution defined by the logits and
        # take the first (only) batch entry.
        sampled = tfd.Categorical(logits=logits).sample()

        return int(sampled[0]), logits, outputs[2], outputs[3]

    def update(self, states, actions, agent_policies, rewards, dones, memory_states, carry_states):
        """Apply one V-trace style actor-critic gradient step on a trajectory slice.

        Index ``t`` of every argument describes the same stored transition:
        the frame, the action taken on it, the behaviour-policy logits that
        produced the action, and the reward/done flag returned by that action.
        The last entry is used only to bootstrap the value of the final state.

        Args:
            states: frames, one per step.
            actions: integer action per step.
            agent_policies: behaviour-policy logits per step.
            rewards: reward obtained by `actions[t]`.
            dones: whether `actions[t]` ended the episode.
            memory_states, carry_states: stored LSTM states; only index 0 is
                used, as the initial recurrent state of the slice.

        NOTE(review): the layers declare `kernel_regularizer='l2'`, but the
        regularization losses are not added to the loss here; confirm whether
        that is intended.
        """
        states = tf.convert_to_tensor(states, dtype=tf.float32)
        actions = tf.convert_to_tensor(actions, dtype=tf.int32)
        agent_policies = tf.convert_to_tensor(agent_policies, dtype=tf.float32)
        rewards = tf.convert_to_tensor(rewards, dtype=tf.float32)
        dones = tf.convert_to_tensor(dones, dtype=tf.bool)
        memory_states = tf.convert_to_tensor(memory_states, dtype=tf.float32)
        carry_states = tf.convert_to_tensor(carry_states, dtype=tf.float32)

        num_steps = states.shape[0]
        if num_steps is None or num_steps < 2:
            # One transition plus one bootstrap step is the minimum; with less
            # the means below are taken over empty tensors and become NaN.
            return

        discounting = 0.99
        lambda_ = 1.0
        clip_rho_threshold = 1.0
        clip_pg_rho_threshold = 1.0
        baseline_cost = 0.5

        with tf.GradientTape() as tape:
            # Replay the slice through the current network, step by step, so
            # the recurrent state is threaded exactly as during acting.
            step_logits = []
            step_values = []
            memory_state = tf.expand_dims(memory_states[0], 0)
            carry_state = tf.expand_dims(carry_states[0], 0)
            for i in range(num_steps):
                learner_output = self.model(tf.expand_dims(states[i], 0), memory_state, carry_state,
                                            training=True)
                step_logits.append(tf.squeeze(learner_output[0], axis=0))
                step_values.append(tf.squeeze(learner_output[1]))

                memory_state = learner_output[2]
                carry_state = learner_output[3]
                if bool(dones[i]):
                    # The next frame starts a new episode, which the actor
                    # began from an all-zero recurrent state.
                    memory_state = tf.zeros_like(memory_state)
                    carry_state = tf.zeros_like(carry_state)

            learner_logits = tf.stack(step_logits)
            learner_values = tf.stack(step_values)

            # log_softmax is the numerically stable form of log(softmax(x)).
            learner_log_probs = tf.nn.log_softmax(learner_logits[:-1])
            agent_log_probs = tf.nn.log_softmax(agent_policies[:-1])

            # Transition t = (action t, reward t, done t); the last stored step
            # contributes only its value estimate as the bootstrap.
            actions = actions[:-1]
            rewards = rewards[:-1]
            dones = dones[:-1]

            bootstrap_value = learner_values[-1]
            values = learner_values[:-1]

            # No discounted future value past a terminal step.
            discounts = tf.cast(tf.logical_not(dones), tf.float32) * discounting

            target_action_log_probs = take_vector_elements(learner_log_probs, actions)
            behaviour_action_log_probs = take_vector_elements(agent_log_probs, actions)

            # Importance weights between learner and behaviour policy.
            log_rhos = target_action_log_probs - behaviour_action_log_probs
            rhos = tf.math.exp(log_rhos)
            clipped_rhos = tf.minimum(clip_rho_threshold, rhos, name='clipped_rhos')
            cs = tf.minimum(1.0, rhos, name='cs') * lambda_

            values_t_plus_1 = tf.concat([values[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
            deltas = clipped_rhos * (rewards + discounts * values_t_plus_1 - values)

            # Backward recursion: acc_t = delta_t + discount_t * c_t * acc_{t+1}.
            acc = tf.zeros_like(bootstrap_value)
            vs_minus_v_xs = []
            for t in range(int(discounts.shape[0]) - 1, -1, -1):
                acc = deltas[t] + discounts[t] * cs[t] * acc
                vs_minus_v_xs.append(acc)
            vs_minus_v_xs = tf.stack(vs_minus_v_xs[::-1])

            vs = tf.add(vs_minus_v_xs, values, name='vs')
            vs_t_plus_1 = tf.concat([vs[1:], tf.expand_dims(bootstrap_value, 0)], axis=0)
            clipped_pg_rhos = tf.minimum(clip_pg_rho_threshold, rhos, name='clipped_pg_rhos')

            pg_advantages = (clipped_pg_rhos * (rewards + discounts * vs_t_plus_1 - values))

            # Targets are constants for the optimizer: gradients flow only
            # through the action log-probabilities and the value estimates.
            vs = tf.stop_gradient(vs)
            pg_advantages = tf.stop_gradient(pg_advantages)

            actor_loss = -tf.reduce_mean(target_action_log_probs * pg_advantages)

            v_error = values - vs
            critic_loss = baseline_cost * 0.5 * tf.reduce_mean(tf.square(v_error))

            total_loss = actor_loss + critic_loss

        grads = tape.gradient(total_loss, self.model.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.model.trainable_variables))
    
    def replay(self):
        """Train on one random contiguous window of `batch_size` stored transitions.

        Does nothing until the buffer holds more than `batch_size` entries.
        """
        memory_len = len(self.memory)
        if memory_len <= self.batch_size:
            return

        start_index = random.randint(0, memory_len - self.batch_size)
        minibatch = self.memory[start_index:start_index+self.batch_size]
        if len(minibatch) < self.batch_size:
            # The buffer shrank between the length check and the slice.
            return

        states = np.zeros((self.batch_size, 64, 64, 3), dtype=np.float32)
        actions = np.zeros(self.batch_size, dtype=np.int32)
        policies = np.zeros((self.batch_size, self.action_size), dtype=np.float32)
        rewards = np.zeros(self.batch_size, dtype=np.float32)
        # Builtin bool: the `np.bool` alias no longer exists in current NumPy.
        dones = np.zeros(self.batch_size, dtype=bool)
        memory_states = np.zeros((self.batch_size, 128), dtype=np.float32)
        carry_states = np.zeros((self.batch_size, 128), dtype=np.float32)

        for i, experience in enumerate(minibatch):
            state, action, policy, reward, done, memory_state, carry_state = experience
            states[i] = state
            actions[i] = action
            policies[i] = policy
            rewards[i] = reward
            dones[i] = done
            memory_states[i] = memory_state
            carry_states[i] = carry_state

        self.update(states, actions, policies, rewards, dones, memory_states, carry_states)
        
    def load(self, model_name):
        """Load a saved model from `model_name` and make it the agent's network.

        The loaded model replaces `self.model`, which is what the agent acts
        and trains with. `self.ActorCritic` is kept as an alias for code that
        reads that attribute.
        """
        loaded_model = load_model(model_name, compile=False)
        self.model = loaded_model
        self.ActorCritic = loaded_model

    def save(self):
        """Save the network the agent is training to `self.model_name`."""
        # `self.model` is created in __init__ and is the network being
        # optimized; `self.ActorCritic` only exists after load() has run.
        self.model.save(self.model_name)

    pylab.figure(figsize=(18, 9))
    def PlotModel(self, score, episode):
        """Record an episode score and refresh the score plot every 100 episodes.

        Args:
            score: score of the finished episode.
            episode: index of the finished episode.

        Returns:
            The mean of the most recent scores (at most the last 50).
        """
        self.scores.append(score)
        self.episodes.append(episode)
        recent_scores = self.scores[-50:]
        self.average.append(sum(recent_scores) / len(recent_scores))

        # Redraw on positive multiples of 100 (episode 0 is skipped).
        if episode > 0 and episode % 100 == 0:
            pylab.plot(self.episodes, self.scores, 'b')
            pylab.plot(self.episodes, self.average, 'r')
            pylab.ylabel('Score', fontsize=18)
            pylab.xlabel('Steps', fontsize=18)
            try:
                pylab.savefig(self.path + ".png")
            except OSError as err:
                # Saving the plot is best-effort, but say why it was skipped.
                print("could not save plot {}.png: {}".format(self.path, err))

        return self.average[-1]
    
    def imshow(self, image, rem_step=0):
        """Show slice `rem_step` of `image` in an OpenCV window; pressing 'q' closes all windows."""
        window_title = self.model_name + str(rem_step)
        frame = image[rem_step,...]
        cv2.imshow(window_title, frame)

        pressed_key = cv2.waitKey(25) & 0xFF
        if pressed_key == ord("q"):
            cv2.destroyAllWindows()

    def reset(self, env):
        """Reset `env` and return the initial observation.

        Accepts both return conventions of ``env.reset()``: a bare observation,
        or an ``(observation, info_dict)`` pair as returned by newer gym
        releases.
        """
        result = env.reset()
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]

        return result
    
    def step(self, action, env):
        """Advance `env` by one action.

        Returns:
            ``(next_state, reward, done, info)``. When the environment returns
            the newer 5-tuple ``(obs, reward, terminated, truncated, info)``,
            `done` is true if either flag is set.
        """
        result = env.step(action)
        if len(result) == 5:
            next_state, reward, terminated, truncated, info = result
            return next_state, reward, terminated or truncated, info

        next_state, reward, done, info = result
        return next_state, reward, done, info
    
    def train(self, n_threads):
        """Run training in `n_threads` worker threads and wait for them to finish.

        Each worker gets its own environment instance; all workers share this
        agent (its network, replay buffer and lock).
        """
        self.env.close()
        # Instantiate one environment per thread
        envs = [gym.make(self.env_name) for _ in range(n_threads)]

        # Create threads. train_threading takes the agent as an explicit
        # argument in addition to the bound `self`.
        threads = [threading.Thread(
                target=self.train_threading,
                daemon=True,
                args=(self,
                    env,
                    index)) for index, env in enumerate(envs)]

        for t in threads:
            # Stagger worker start-up.
            time.sleep(2)
            t.start()

        # join() already blocks until the worker ends; no extra delay needed.
        for t in threads:
            t.join()

        # Release the per-thread environments once every worker is done.
        for env in envs:
            env.close()
    
    def render(self, obs):
        """Display `obs` in an OpenCV window named 'obs' and pump the GUI event loop briefly."""
        # The frame is shown as given; no RGB->BGR conversion is applied.
        window_name = 'obs'
        cv2.imshow(window_name, obs)
        cv2.waitKey(1)
    
    def train_threading(self, agent, env, thread):
        """Worker loop: play episodes in `env`, store transitions, train periodically.

        Args:
            agent: unused; kept because `train` passes it positionally.
            env: this worker's own environment instance.
            thread: worker index, used only in the progress printout.
        """
        def preprocess(frame):
            # Crop rows 35..194, keep every second pixel, resize to 64x64,
            # add a leading batch axis and scale to [0, 1].
            frame = frame[35:195:2, ::2,:]
            frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_CUBIC)
            return np.expand_dims(frame, 0) / 255.0

        total_step = 0
        for _ in range(self.EPISODES):
            state = preprocess(self.reset(env))

            done = False
            score = 0

            # Every episode starts from an all-zero recurrent state.
            memory_state = tf.zeros([1,128], dtype=tf.dtypes.float32)
            carry_state = tf.zeros([1,128], dtype=tf.dtypes.float32)
            while not done:
                # Keep the recurrent state that was fed INTO the network for
                # this frame: update() uses the stored state of the first
                # frame in a batch as the initial state for that frame.
                prev_memory_state, prev_carry_state = memory_state, carry_state
                action, policy, memory_state, carry_state = self.act(state, memory_state, carry_state)

                next_state, reward, done, _ = self.step(action, env)
                next_state = preprocess(next_state)

                self.remember(state, action, policy, reward / 20.0, done, prev_memory_state, prev_carry_state)
                state = next_state
                score += reward

                if done:
                    break

                # `with` releases the lock even if replay() raises; a manual
                # acquire/release would leave the other workers blocked.
                with self.lock:
                    if total_step % 100 == 0:
                        # train model
                        self.replay()

                total_step += 1

            # Update episode count
            with self.lock:
                average = self.PlotModel(score, self.episode)
                # saving best models
                if average >= self.max_average:
                    self.max_average = average
                    #self.save()
                    SAVING = "SAVING"
                else:
                    SAVING = ""

                print("episode: {}/{}, thread: {}, score: {}, average: {:.2f} {}".format(self.episode, self.EPISODES, thread, score, average, SAVING))
                if self.episode < self.EPISODES:
                    self.episode += 1
                 
    def test(self, Actor_name, Critic_name=None):
        """Play 100 rendered episodes with a saved model, choosing the highest-logit action.

        Args:
            Actor_name: path of the saved model; passed to `load`.
            Critic_name: unused. Policy and value share one network, so there
                is no separate critic file; kept for call compatibility.
        """
        self.load(Actor_name)
        # load() stores the model in `self.ActorCritic`; fall back to the
        # agent's own network if that attribute is absent.
        model = getattr(self, 'ActorCritic', self.model)

        def preprocess(frame):
            # Same frame preparation as used during training.
            frame = frame[35:195:2, ::2,:]
            frame = cv2.resize(frame, (64, 64), interpolation=cv2.INTER_CUBIC)
            return np.expand_dims(frame, 0) / 255.0

        for e in range(100):
            state = preprocess(self.reset(self.env))
            memory_state = tf.zeros([1,128], dtype=tf.dtypes.float32)
            carry_state = tf.zeros([1,128], dtype=tf.dtypes.float32)
            done = False
            score = 0
            while not done:
                self.env.render()
                prediction = model(state, memory_state, carry_state, training=False)
                action = int(np.argmax(prediction[0]))
                memory_state, carry_state = prediction[2], prediction[3]

                next_state, reward, done, _ = self.step(action, self.env)
                state = preprocess(next_state)
                score += reward
            print("episode: {}/{}, score: {}".format(e, self.EPISODES, score))

        self.env.close()


# Script entry point: build the agent for the chosen environment and train it.
if __name__ == "__main__":
    env_name = 'PongDeterministic-v4'
    # Alternative environment id:
    #env_name = 'Pong-v0'
    agent = A3CAgent(env_name)

    # NOTE(review): A3CAgent defines no `run` method in this file, so the
    # next line cannot simply be re-enabled.
    #agent.run() # use as A2C
    agent.train(n_threads=1) # threaded training with a single worker
    # Evaluation of a saved model (path is an example):
    #agent.test('Models/Pong-v0_A3C_2.5e-05_Actor.h5', '')

episode: 0/20000, thread: 0, score: -21.0, average: -21.00 SAVING
episode: 1/20000, thread: 0, score: -21.0, average: -21.00 SAVING
episode: 2/20000, thread: 0, score: -21.0, average: -21.00 SAVING
episode: 3/20000, thread: 0, score: -21.0, average: -21.00 SAVING
episode: 4/20000, thread: 0, score: -21.0, average: -21.00 SAVING
episode: 5/20000, thread: 0, score: -20.0, average: -20.83 SAVING
episode: 6/20000, thread: 0, score: -21.0, average: -20.86 
episode: 7/20000, thread: 0, score: -20.0, average: -20.75 SAVING
episode: 8/20000, thread: 0, score: -20.0, average: -20.67 SAVING
episode: 9/20000, thread: 0, score: -20.0, average: -20.60 SAVING
episode: 10/20000, thread: 0, score: -21.0, average: -20.64 
