In [1]:
import os
import random

import numpy as np
import seaborn as sns
from tensorflow import keras
from tensorflow.keras import backend as K
from tensorflow.keras.layers import Dense, Activation, Input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.optimizers import Adam
from tensorflow.python.framework.ops import disable_eager_execution

# Turn off TensorFlow eager execution for the whole process before any model
# is built.
# NOTE(review): presumably needed because the actor's custom loss closes over
# a Keras Input tensor (`delta`) instead of receiving it as an argument —
# confirm before removing.
disable_eager_execution()

In [3]:
class Agent:
    """Actor-critic agent with a softmax policy network and a Q-value critic.

    Transitions are kept in a bounded replay memory; ``learn`` fits both
    networks for one epoch on everything currently held in that memory.

    Args:
        alpha: Learning rate of the actor's Adam optimizer.
        beta: Learning rate of the critic's Adam optimizer.
        gamma: Discount factor applied to the bootstrapped next-state value.
        n_actions: Number of discrete actions (width of both output layers).
        layer1_size: Units in the first hidden layer of each network.
        layer2_size: Units in the second hidden layer of each network.
        input_dims: Length of the flat observation vector fed to the networks.
        memory_size: Number of stored transitions at which eviction starts.
    """

    def __init__(self, alpha, beta, gamma=0.99, n_actions=4,
                 layer1_size=1024, layer2_size=512, input_dims=32,
                 memory_size=1000):
        self.gamma = gamma
        self.alpha = alpha
        self.beta = beta
        self.input_dims = input_dims
        self.fc1_dims = layer1_size
        self.fc2_dims = layer2_size
        self.n_actions = n_actions

        self.actor, self.critic, self.policy = self.build_actor_critic_network()
        self.action_space = list(range(n_actions))

        # Replay memory: a list of [state, action, reward, next_state, terminal].
        self.memory = []
        self.memory_size = memory_size
        # Suffix for the loss figures written by ``learn``; bumped on each call.
        self.fig_count = 0

    def build_actor_critic_network(self):
        """Build the actor, critic and policy Keras models.

        The actor and the policy share the same layers; the actor additionally
        takes a ``delta`` input that weights its loss, while the policy is the
        inference-only view used for action selection. The critic is a separate
        network with one linear output per action.

        Returns:
            Tuple ``(actor, critic, policy)``. The actor and critic are
            compiled; the policy is not.
        """
        inputs = Input(shape=(self.input_dims,))
        delta = Input(shape=[1])
        dense1 = Dense(self.fc1_dims, activation='relu')(inputs)
        dense2 = Dense(self.fc2_dims, activation='relu')(dense1)
        probs = Dense(self.n_actions, activation='softmax')(dense2)

        inputs_critic = Input(shape=(self.input_dims,))
        dense1_critic = Dense(self.fc1_dims, activation='relu')(inputs_critic)
        dense2_critic = Dense(self.fc2_dims, activation='relu')(dense1_critic)
        q_values = Dense(self.n_actions, activation='linear')(dense2_critic)

        def custom_loss(y_true, y_pred):
            # Policy-gradient loss: negative log-probability of the taken
            # action (y_true is one-hot) weighted by the `delta` input tensor
            # captured from the enclosing scope.
            # NOTE(review): probabilities are clipped to [0.1, 0.9], which
            # removes the gradient outside that band; a far smaller epsilon is
            # the usual guard against log(0) — confirm this is intentional.
            out = K.clip(y_pred, 1e-1, 1-1e-1)
            log_lik = y_true*K.log(out)

            return K.sum(-log_lik*delta)

        actor = Model(inputs=[inputs, delta], outputs=[probs])
        actor.compile(optimizer=Adam(learning_rate=self.alpha), loss=custom_loss)

        critic = Model(inputs=[inputs_critic], outputs=[q_values])
        critic.compile(optimizer=Adam(learning_rate=self.beta),
                       loss='mean_squared_error')

        policy = Model(inputs=[inputs], outputs=[probs])

        return actor, critic, policy

    def choose_action(self, observation, ep=0):
        """Pick an action for ``observation``.

        Args:
            observation: A single observation; it is wrapped into a batch of
                one before being passed to the policy network.
            ep: Probability of ignoring the policy and picking an action
                uniformly at random.

        Returns:
            An element of ``self.action_space``.
        """
        if random.random() < ep:
            # Exploratory step: the policy output would be discarded anyway,
            # so skip the forward pass entirely.
            return np.random.choice(self.action_space)

        state = np.array([observation])
        probabilities = self.policy.predict(state)[0]
        return np.random.choice(self.action_space, p=probabilities)

    def addStateTransition(self, stateTransition):
        """Append a transition, evicting a random stored one when full.

        A non-positive ``memory_size`` degenerates to keeping only the most
        recently added transition instead of raising.
        """
        # Guard on a non-empty memory so there is always something to evict.
        if self.memory and len(self.memory) >= self.memory_size:
            del self.memory[random.randrange(len(self.memory))]
        self.memory.append(stateTransition)

    def stateTransition(self, gs0, action, reward, gs1, terminal):
        """Record one step as ``[gs0, action, reward, gs1, terminal]``."""
        self.addStateTransition([gs0, action, reward, gs1, terminal])

    def learn(self):
        """Fit the actor and critic for one epoch on the whole replay memory.

        The critic target for the taken action is
        ``reward + gamma * max_a Q(next_state, a)``, with the bootstrap term
        dropped for terminal transitions. The actor is weighted by ``delta``,
        the target minus the critic's maximum Q-value for the current state.
        After fitting, one loss figure per network is written under
        ``figures/`` and ``fig_count`` is incremented.

        Does nothing when the memory is empty.
        """
        if not self.memory:
            return

        # Shuffle a copy so the stored order of the memory is left untouched.
        batch = list(self.memory)
        random.shuffle(batch)

        # Each transition is [current_state, action, reward, next_state, terminal].
        current_states = np.array([transition[0] for transition in batch])
        actions = np.array([transition[1] for transition in batch])
        rewards = np.array([transition[2] for transition in batch])
        next_states = np.array([transition[3] for transition in batch])
        # 1 where the episode continues, 0 where the transition was terminal.
        not_done = np.array([0 if transition[4] else 1 for transition in batch])

        critic_value_next = np.max(self.critic.predict(next_states), axis=1)
        critic_q_values = self.critic.predict(current_states)

        target = rewards + self.gamma * critic_value_next * not_done
        # Must be computed before critic_q_values is overwritten below.
        delta = target - np.max(critic_q_values, axis=1)

        # Only the taken action's Q-value is moved toward the target; the other
        # outputs keep their own predictions and so contribute zero error.
        rows = np.arange(critic_q_values.shape[0])
        critic_q_values[rows, actions] = target

        actions_one_hot = np.zeros([actions.shape[0], self.n_actions])
        actions_one_hot[rows, actions] = 1

        actor_losses = self.actor.fit(
            [current_states, delta], actions_one_hot, verbose=0, epochs=1
        ).history['loss']
        critic_losses = self.critic.fit(
            current_states, critic_q_values, verbose=0, epochs=1
        ).history['loss']

        self._save_loss_plot(actor_losses, "actor")
        self._save_loss_plot(critic_losses, "critic")
        self.fig_count += 1

    def _save_loss_plot(self, losses, network_name):
        """Write ``losses`` as a line plot to ``figures/fig_<name>_<count>.png``."""
        # savefig does not create missing directories.
        os.makedirs("figures", exist_ok=True)
        # Work through the Axes that seaborn returns rather than a `plt` name,
        # which none of the imports visible in this file define.
        axes = sns.lineplot(x=range(len(losses)), y=losses)
        figure = axes.figure
        figure.savefig(f"figures/fig_{network_name}_{self.fig_count}.png")
        figure.clf()
            
            