In [102]:
# Each replay-memory sample is a list: [state, action, reward, next_state, is_done]
# is_done flags whether next_state is terminal (no future reward is bootstrapped from it)

import os
import random

import numpy as np
import tensorflow as tf

class ReplayMemory:
    """Storage and sizing settings for an experience-replay buffer.

    Attributes:
        main_memory: list of stored transitions, each
            ``[state, action, reward, next_state, is_done]``.
        max_reply: maximum number of transitions to keep ("max replay";
            the attribute name is kept for compatibility with callers).
        num_batch: number of transitions drawn per mini-batch.
    """
    max_reply = 0
    num_batch = 0

    def __init__(self, max_replay: int, mini_batch_num: int):
        # Per-instance list: a class-level ``[]`` would be one list shared by
        # every ReplayMemory (and therefore every agent) in the process.
        self.main_memory = []
        self.max_reply = max_replay
        self.num_batch = mini_batch_num

class DeepQAgent:
    """Epsilon-greedy deep Q-learning agent with experience replay and a target network.

    ``eval_model`` is trained on every ``learn()`` call and is used to pick
    actions. ``target_model`` is a separate network that is re-synchronised
    from ``eval_model`` every ``update_weight_on`` learn() calls and is only
    used to compute the bootstrapped Q-targets.

    Replay samples are stored as ``[state, action, reward, next_state, is_done]``.
    """
    replay: "ReplayMemory" = None
    num_actions: int = None
    eval_model = None
    target_model = None
    gamma: float = None          # discount factor applied to future rewards
    epsilon: float = None        # current exploration probability
    epsilon_min: float = None    # floor that epsilon never decays below
    epsilon_decay: float = None  # multiplicative decay applied after each learn()
    # Shape of a single observation; used as the networks' input shape.
    state_shape: tuple = (4,)

    # Number of completed learn() calls; drives target-network syncing.
    learn_counter: int = 0
    # Copy eval weights into the target network every this many learn() calls
    # (0 disables syncing).
    update_weight_on: int = 0

    def __init__(self, num_actions: int, max_replay: int, mini_batch_num: int,
                 weight_update: int, epsilon: float, epsilon_min: float,
                 epsilon_decay: float, gamma: float, state_shape: tuple = (4,)):
        """Create the replay buffer and both Q-networks.

        Args:
            num_actions: size of the discrete action space (network output width).
            max_replay: replay-buffer capacity.
            mini_batch_num: transitions sampled per learn() call.
            weight_update: sync the target network every this many learn() calls.
            epsilon: initial exploration probability.
            epsilon_min: lower bound for epsilon.
            epsilon_decay: factor epsilon is multiplied by after each learn().
            gamma: discount factor.
            state_shape: shape of one observation (network input shape).
        """
        self.replay = ReplayMemory(max_replay, mini_batch_num)
        # create_model() sizes the networks from these, so set them first.
        self.num_actions = int(num_actions)
        self.state_shape = tuple(state_shape)
        self.eval_model, self.target_model = self.create_model()
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_decay
        self.update_weight_on = weight_update

    def _build_network(self):
        """Build and compile one Q-network: state -> one Q-value per action."""
        model = tf.keras.models.Sequential([
            tf.keras.layers.Input(shape=self.state_shape),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(32, activation='relu'),
            tf.keras.layers.Dense(self.num_actions, activation='linear'),
        ])
        model.compile(optimizer='adam',
                      loss='mean_squared_error',
                      metrics=['mse'])
        return model

    def create_model(self):
        """Return ``(eval_model, target_model)`` as two independent networks.

        They must be distinct objects: if both names refer to one model, the
        "target" network moves with every gradient step and syncing is a no-op.
        The target starts with the same weights as the eval network.
        """
        eval_model = self._build_network()
        target_model = self._build_network()
        target_model.set_weights(eval_model.get_weights())
        return eval_model, target_model

    def store_memory(self, state, action, rewards, next_state, is_done):
        """Append one transition, evicting the oldest ones once the buffer is full."""
        memory = self.replay.main_memory
        # ``>=`` (not ``==``) keeps the cap even if the buffer was over-filled
        # elsewhere; the ``memory`` guard avoids popping from an empty list.
        while memory and len(memory) >= self.replay.max_reply:
            memory.pop(0)
        memory.append([state, action, rewards, next_state, is_done])

    def pick_action(self, state, epsilon=None):
        """Return an action index chosen epsilon-greedily.

        Args:
            state: a single observation (list, tuple or array).
            epsilon: exploration probability; defaults to ``self.epsilon``.
                Pass 0 for a purely greedy policy.
        Returns:
            int action index in ``[0, num_actions)``.
        """
        if epsilon is None:
            epsilon = self.epsilon
        if random.random() > epsilon:
            # Add a batch axis; np.asarray accepts lists, tuples and arrays alike.
            batch = np.asarray(state, dtype=np.float32)[np.newaxis, ...]
            q_values = self.eval_model.predict(batch, verbose=0)[0]
            return int(np.argmax(q_values))
        return random.randint(0, self.num_actions - 1)

    def learn(self, epochs: int = 10, verbose: int = 1):
        """Fit ``eval_model`` on one replay mini-batch, then update counters/epsilon.

        Does nothing until the buffer holds at least one full mini-batch.

        Args:
            epochs: Keras ``fit`` epochs per call.
            verbose: Keras ``fit`` verbosity; when non-zero the pre-decay
                epsilon is also printed.
        """
        memory = self.replay.main_memory
        batch_size = self.replay.num_batch
        if batch_size <= 0 or len(memory) < batch_size:
            return
        samples = self.__sample_mini_batch__()
        states = np.array([s[0] for s in samples])
        actions = np.array([s[1] for s in samples], dtype=int)
        rewards = np.array([s[2] for s in samples], dtype=np.float64)
        next_states = np.array([s[3] for s in samples])
        dones = np.array([bool(s[4]) for s in samples])

        # Start from the current estimates so only the taken action's Q-value
        # receives a training signal.
        q_targets = np.array(self.eval_model.predict(states, verbose=0))
        next_q = self.target_model.predict(next_states, verbose=0)
        # Bellman target: r + gamma * max_a' Q_target(s', a'); a terminal
        # next state contributes no future reward.
        bootstrapped = rewards + self.gamma * np.max(next_q, axis=1)
        q_targets[np.arange(len(samples)), actions] = np.where(dones, rewards, bootstrapped)

        self.eval_model.fit(states, q_targets, verbose=verbose, epochs=epochs)
        if verbose:
            print(self.epsilon)
        if self.update_weight_on and self.learn_counter % self.update_weight_on == 0:
            self.__update_target_models__()

        # Post Learn
        self.learn_counter += 1
        self.epsilon = max(self.epsilon * self.epsilon_decay, self.epsilon_min)

    def load_model(self, path: str):
        """Load a saved model from ``path`` into both networks (as separate objects)."""
        self.target_model = tf.keras.models.load_model(path)
        self.eval_model = tf.keras.models.load_model(path)
        print("Model Loaded")

    def save_model(self, path: str):
        """Save the eval network to ``path``."""
        self.eval_model.save(path)
        print("Model saved")

    def __sample_mini_batch__(self):
        """Draw ``num_batch`` distinct transitions uniformly at random."""
        return random.sample(self.replay.main_memory, self.replay.num_batch)

    def __update_target_models__(self):
        """Copy the eval network's weights into the target network."""
        self.target_model.set_weights(self.eval_model.get_weights())

In [103]:
import gym
class Environment:
    """Couples a Gym environment with a DeepQAgent for training and playback."""

    def __init__(self, env_name: str = "CartPole-v1", max_replay: int = 10000,
                 mini_batch_num: int = 2000, weight_update: int = 10,
                 epsilon: float = 1, epsilon_min: float = 0.05,
                 epsilon_decay: float = 0.995, gamma: float = 0.95):
        """Create the Gym environment and an agent sized to its action space.

        The defaults reproduce the original hard-coded CartPole setup.
        """
        self.game = gym.make(env_name)
        action_space = self.game.action_space.n
        self.agent = DeepQAgent(action_space, max_replay, mini_batch_num,
                                weight_update, epsilon, epsilon_min,
                                epsilon_decay, gamma)

    def _reset(self):
        """Reset the game and return only the initial observation."""
        result = self.game.reset()
        # gym>=0.26 returns (obs, info); older versions return obs alone.
        if isinstance(result, tuple) and len(result) == 2 and isinstance(result[1], dict):
            return result[0]
        return result

    def _step(self, action):
        """Step the game; return ``(next_state, reward, terminated, truncated)``."""
        result = self.game.step(action)
        if len(result) == 5:
            # gym>=0.26: (obs, reward, terminated, truncated, info)
            next_state, reward, terminated, truncated, _ = result
        else:
            # Older gym: (obs, reward, done, info). The TimeLimit wrapper marks a
            # step-limit cut-off with info["TimeLimit.truncated"].
            next_state, reward, done, info = result
            truncated = bool(info.get("TimeLimit.truncated", False))
            terminated = bool(done) and not truncated
        return next_state, reward, terminated, truncated

    def train(self, num_ep: int, render: bool = True):
        """Run ``num_ep`` training episodes, calling ``agent.learn()`` after each.

        Args:
            num_ep: number of episodes to play.
            render: draw every frame (slow; pass False for faster training).
        """
        for i in range(1, num_ep + 1):
            print(f"Episodes {i}")
            state = self._reset()
            while True:
                if render:
                    self.game.render()
                action = self.agent.pick_action(state)
                state_next, reward, terminated, truncated = self._step(action)
                if terminated:
                    # Penalise failure; a time-limit cut-off is not a failure.
                    reward = -1.0
                # Only a genuine terminal state stops bootstrapping in learn().
                self.agent.store_memory(state, action, reward, state_next, terminated)
                state = state_next
                if terminated or truncated:
                    break

            self.agent.learn()

    def play(self, num_ep: int = None, render: bool = True):
        """Run the greedy policy (epsilon = 0).

        Args:
            num_ep: number of episodes; None loops until interrupted.
            render: draw every frame.
        """
        episode = 0
        while num_ep is None or episode < num_ep:
            state = self._reset()
            while True:
                if render:
                    self.game.render()
                action = self.agent.pick_action(state, 0)
                state, _, terminated, truncated = self._step(action)
                if terminated or truncated:
                    break
            episode += 1
            

In [104]:
# Build the CartPole environment and its DQN agent (the Q-networks are created here).
envir = Environment()

In [105]:
# Long-running: trains for up to 30000 episodes and renders every step.
# Interrupt the kernel to stop early; `envir.agent` keeps whatever it has learned so far.
envir.train(30000)

Episodes 1
Episodes 2
Episodes 3
Episodes 4
Episodes 5
Episodes 6
Episodes 7
Episodes 8
Episodes 9
Episodes 10
Episodes 11
Episodes 12
Episodes 13
Episodes 14
Episodes 15
Episodes 16
Episodes 17
Episodes 18
Episodes 19
Episodes 20
Episodes 21
Episodes 22
Episodes 23
Episodes 24
Episodes 25
Episodes 26
Episodes 27
Episodes 28
Episodes 29
Episodes 30
Episodes 31
Episodes 32
Episodes 33
Episodes 34
Episodes 35
Episodes 36
Episodes 37
Episodes 38
Episodes 39
Episodes 40
Episodes 41
Episodes 42
Episodes 43
Episodes 44
Episodes 45
Episodes 46
Episodes 47
Episodes 48
Episodes 49
Episodes 50
Episodes 51
Episodes 52
Episodes 53
Episodes 54
Episodes 55
Episodes 56
Episodes 57
Episodes 58
Episodes 59
Episodes 60
Episodes 61
Episodes 62
Episodes 63
Episodes 64
Episodes 65
Episodes 66
Episodes 67
Episodes 68
Episodes 69
Episodes 70
Episodes 71
Episodes 72
Episodes 73
Episodes 74
Episodes 75
Episodes 76
Episodes 77
Episodes 78
Episodes 79
Episodes 80
Episodes 81
Episodes 82
Episodes 83
Episodes 84
E

KeyboardInterrupt: 

In [106]:
# Persist the trained eval network. Create the target folder first so a fresh
# checkout (no ./model directory yet) does not fail when saving.
MODEL_PATH = "./model/cart-pole.h5"
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
envir.agent.save_model(MODEL_PATH)

Model saved


In [107]:
# Watch the greedy policy (epsilon = 0). This replays episodes until the kernel is interrupted.
envir.play()

KeyboardInterrupt: 