In [1]:
import random
import gym
import math
import numpy as np
from collections import deque
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import Adam
import logger

def random_states_experiment(model, episode_num, env, num_samples=10000):
    """Probe the model on states drawn uniformly from the observation space and log the result.

    Draws ``num_samples`` points uniformly between ``env.observation_space.low``
    and ``env.observation_space.high``, runs ``model.predict`` on them and passes
    samples, predictions and bounds to ``logger.log_experiment_random_states``
    with ``apply_softmax=True``.

    Args:
        model: object exposing ``predict(batch)`` (here the Keras Q-network).
        episode_num: episode index forwarded to the logger.
        env: environment whose ``observation_space`` provides ``low``/``high`` arrays.
        num_samples: number of random states to draw (default 10000, the value
            that used to be hard-coded).

    NOTE(review): uniform sampling over the full Box bounds may cover regions the
    agent never visits if some bounds are very large — confirm the bounds of the
    environment in use before interpreting the plots.
    """
    obs_space = env.observation_space
    obs_min = obs_space.low
    obs_max = obs_space.high

    # The sample shape follows the bound arrays, so non-flat Box spaces work too;
    # for a 1-D space this is (num_samples, obs_dim), exactly as before.
    sample_shape = (num_samples,) + np.shape(obs_min)
    random_state_samples = np.random.uniform(low=obs_min, high=obs_max, size=sample_shape)

    predicted_dists = np.asarray(model.predict(random_state_samples))

    logger.log_experiment_random_states(
        random_state_samples, predicted_dists, obs_min, obs_max, episode_num, [], apply_softmax=True)


In [2]:
# Inspired by https://keon.io/deep-q-learning/

class DQNCartPoleSolver():
    """Deep Q-Network agent for CartPole with extensive per-step logging.

    Builds a small Keras MLP Q-network (4 inputs -> 24 tanh -> 48 tanh -> 2
    linear outputs) and stores transitions in a replay buffer. It trains on one
    sampled minibatch after each episode. All diagnostics go through the
    project-local ``logger`` module (frames, weights, action probabilities,
    rewards, losses, ...).

    Args:
        n_episodes: maximum number of episodes to run.
        n_win_ticks: mean score over the last 100 episodes that counts as solved.
        max_env_steps: if given, overrides ``env._max_episode_steps``.
        gamma: discount factor of the Q-learning target.
        epsilon: upper bound on the exploration rate; also decayed in ``replay``.
        epsilon_min: lower bound on the exploration rate.
        epsilon_log_decay: multiplicative decay applied to ``self.epsilon`` after
            each replay; also the scale factor inside ``get_epsilon``'s log schedule.
        alpha: Adam learning rate.
        alpha_decay: Adam learning-rate decay.
        batch_size: replay minibatch size.
        monitor: wrap the env in ``gym.wrappers.Monitor`` writing to ``../data/cartpole-1``.
        quiet: suppress console progress messages.
    """

    def __init__(self, n_episodes=1000, n_win_ticks=195, max_env_steps=None, gamma=1.0, epsilon=1.0, epsilon_min=0.01, epsilon_log_decay=0.995, alpha=0.01, alpha_decay=0.01, batch_size=64, monitor=False, quiet=False):
        logger.create_logger("logs_case_study")
        logger.log_action_meanings(["LEFT", "RIGHT"])
        # Replay buffer: once 100000 transitions are stored, the oldest are dropped.
        self.memory = deque(maxlen=100000)
        self.env = gym.make('CartPole-v0')
        if monitor:
            self.env = gym.wrappers.Monitor(self.env, '../data/cartpole-1', force=True)
        self.gamma = gamma
        self.epsilon = epsilon
        self.epsilon_min = epsilon_min
        self.epsilon_decay = epsilon_log_decay
        self.alpha = alpha
        self.alpha_decay = alpha_decay
        self.n_episodes = n_episodes
        self.n_win_ticks = n_win_ticks
        self.batch_size = batch_size
        self.quiet = quiet
        if max_env_steps is not None:
            self.env._max_episode_steps = max_env_steps

        # Q-network: 4 state features in, one linear Q-value per action out.
        self.model = Sequential()
        self.model.add(Dense(24, input_dim=4, activation='tanh'))
        self.model.add(Dense(48, activation='tanh'))
        self.model.add(Dense(2, activation='linear'))
        # `learning_rate` is the current name of the deprecated `lr` alias (same value).
        self.model.compile(loss='mse', optimizer=Adam(learning_rate=self.alpha, decay=self.alpha_decay))

    def remember(self, state, action, reward, next_state, done):
        """Append one (state, action, reward, next_state, done) transition to the replay buffer."""
        self.memory.append((state, action, reward, next_state, done))

    def choose_action(self, state, epsilon, timestep, episode):
        """Epsilon-greedy action selection.

        The Q-values are predicted in both branches so that action probabilities
        are logged even when the action is random.

        Returns:
            ``(action, preds)``, where ``preds`` is ``model.predict(state)``.
        """
        explore = np.random.random() <= epsilon
        # NOTE(review): "random" is logged as 0 for a random action and 1 for a
        # greedy one, which reads inverted; kept as-is so existing logs stay comparable.
        logger.log_custom_timestep_scalar(0 if explore else 1, timestep, episode, "random")
        preds = self.model.predict(state)
        if explore:
            logger.log_action_probs(preds[0], episode, timestep, apply_softmax=True)
            return self.env.action_space.sample(), preds
        logger.log_custom_timestep_scalar(np.mean(preds), timestep, episode, "q_val")
        logger.log_action_probs(preds[0], episode, timestep, apply_softmax=True)
        return np.argmax(preds), preds

    def get_epsilon(self, t):
        """Exploration rate for episode ``t``.

        Computes ``1 - log10((t + 1) * epsilon_decay)``, capped above by
        ``self.epsilon`` and below by ``self.epsilon_min``.
        """
        return max(self.epsilon_min, min(self.epsilon, 1.0 - math.log10((t + 1) * self.epsilon_decay)))

    def preprocess_state(self, state):
        """Reshape a raw observation into a single-row batch of shape (1, 4)."""
        return np.reshape(state, [1, 4])

    def replay(self, batch_size):
        """Fit the Q-network on a random minibatch of stored transitions.

        Building the target row for each sampled transition:
            1. Start from the network's current prediction for ``state``.
            2. Replace the taken action's entry with ``reward`` if the transition
               is terminal, otherwise with
               ``reward + gamma * max(model.predict(next_state))``.

        After fitting, ``self.epsilon`` is multiplied by ``epsilon_decay`` while
        it is still above ``epsilon_min``.

        Args:
            batch_size: maximum number of transitions to sample (capped at the
                current buffer size).

        Returns:
            The loss reported by the single ``model.fit`` epoch.
        """
        minibatch = random.sample(self.memory, min(len(self.memory), batch_size))
        # Stored states come from preprocess_state, so each is (1, 4).
        # Stacking them gives one (n, 4) batch, which takes one predict call
        # per batch instead of one per transition.
        states = np.vstack([transition[0] for transition in minibatch])
        next_states = np.vstack([transition[3] for transition in minibatch])
        y_batch = np.array(self.model.predict(states))
        next_q = self.model.predict(next_states)
        for row, (_, action, reward, _, done) in enumerate(minibatch):
            y_batch[row][action] = reward if done else reward + self.gamma * np.max(next_q[row])

        hist = self.model.fit(states, y_batch, batch_size=len(minibatch), verbose=0)
        loss = hist.history["loss"][0]
        if self.epsilon > self.epsilon_min:
            self.epsilon *= self.epsilon_decay
        return loss

    def run(self):
        """Train until solved or until ``n_episodes`` episodes have been played.

        Returns:
            ``e - 100`` as soon as the mean score of the last 100 episodes reaches
            ``n_win_ticks`` (only checked from episode 100 on). Otherwise returns
            the index of the last episode, ``n_episodes - 1``; this is -1 when
            ``n_episodes`` is 0, instead of raising ``NameError``.
        """
        scores = deque(maxlen=100)
        # Softmaxed action probabilities of the previous episode, for the divergence log.
        probs_old = None

        for e in range(self.n_episodes):
            curr_probs = []
            actions_episode = []
            returns = []
            state = self.preprocess_state(self.env.reset())
            done = False
            i = 0
            while not done:
                ### JUST FOR LOGGING ###
                frame_img = self.env.render(mode="rgb_array")
                logger.log_frame(frame=frame_img, episode_count=e, step=i)
                ########################

                action, preds = self.choose_action(state, self.get_epsilon(e), i, e)
                curr_probs.append(logger.softmax(preds[0]))
                actions_episode.append(action)

                next_state, reward, done, _ = self.env.step(action)
                logger.log_custom_timestep_scalar(reward, i, e, 'reward')
                returns.append(reward)

                next_state = self.preprocess_state(next_state)
                self.remember(state, action, reward, next_state, done)
                state = next_state

                ### JUST FOR LOGGING ###
                # weights[-2] — presumably the output layer's kernel; confirm against Keras weight ordering.
                weights = self.model.weights[-2].numpy()
                logger.log_weights(weight_tensor=weights, step=i, episode_count=e)
                ########################

                i += 1

            ### LOGGING ###
            # The step count `i` is used as the episode return (assumes reward 1 per step — TODO confirm).
            logger.log_episode_return(episode_return=i, episode_count=e)
            logger.log_action_distribution(np.array(actions_episode), e)
            logger.log_custom_distribution(np.array(returns), "reward_distribution", e)
            ###############

            scores.append(i)
            mean_score = np.mean(scores)
            if not self.quiet:
                print('Episode {} - Mean score: {} - Episode return {}'.format(e, mean_score, i))
            if mean_score >= self.n_win_ticks and e >= 100:
                # Episodes 0..e have been played, i.e. e + 1 of them.
                if not self.quiet:
                    print('Ran {} episodes. Solved after {} trials ✔'.format(e + 1, e - 100))
                return e - 100

            if e % 10 == 0:
                random_states_experiment(self.model, e, self.env)

            if probs_old is not None:
                logger.log_action_divergence(curr_probs, probs_old, e)
            probs_old = curr_probs

            if e % 100 == 0 and not self.quiet:
                print('[Episode {}] - Mean survival time over last 100 episodes was {} ticks.'.format(e, mean_score))

            loss = self.replay(self.batch_size)
            logger.log_custom_episode_scalar(loss, e, "loss")

        if not self.quiet:
            print('Did not solve after {} episodes 😞'.format(self.n_episodes))
        return self.n_episodes - 1

# Entry point: build the solver and train it. In a notebook kernel `__name__` is
# '__main__', so running this cell starts training immediately. `agent` stays in
# the kernel namespace afterwards. The value returned by `run()` is not kept.
if __name__ == '__main__':
    agent = DQNCartPoleSolver()
    agent.run()

Episode 0 - Mean score: 15.0 - Episode return 15
[Episode 0] - Mean survival time over last 100 episodes was 15.0 ticks.
Episode 1 - Mean score: 20.0 - Episode return 25
Episode 2 - Mean score: 34.0 - Episode return 62
Episode 3 - Mean score: 34.75 - Episode return 37
Episode 4 - Mean score: 30.6 - Episode return 14
Episode 5 - Mean score: 28.5 - Episode return 18
Episode 6 - Mean score: 25.857142857142858 - Episode return 10
Episode 7 - Mean score: 23.75 - Episode return 9
Episode 8 - Mean score: 22.22222222222222 - Episode return 10
Episode 9 - Mean score: 20.8 - Episode return 8
Episode 10 - Mean score: 19.727272727272727 - Episode return 9
Episode 11 - Mean score: 19.0 - Episode return 11
Episode 12 - Mean score: 18.153846153846153 - Episode return 8
Episode 13 - Mean score: 17.5 - Episode return 9
Episode 14 - Mean score: 16.933333333333334 - Episode return 9
Episode 15 - Mean score: 16.4375 - Episode return 9
Episode 16 - Mean score: 16.0 - Episode return 9
Episode 17 - Mean scor

Episode 139 - Mean score: 32.37 - Episode return 27
Episode 140 - Mean score: 32.39 - Episode return 23
Episode 141 - Mean score: 32.44 - Episode return 24
Episode 142 - Mean score: 32.5 - Episode return 32
Episode 143 - Mean score: 32.59 - Episode return 31
Episode 144 - Mean score: 32.74 - Episode return 37
Episode 145 - Mean score: 32.85 - Episode return 28
Episode 146 - Mean score: 32.95 - Episode return 24
Episode 147 - Mean score: 33.12 - Episode return 35
Episode 148 - Mean score: 33.25 - Episode return 30
Episode 149 - Mean score: 33.27 - Episode return 19
Episode 150 - Mean score: 33.36 - Episode return 25
Episode 151 - Mean score: 33.31 - Episode return 20
Episode 152 - Mean score: 33.3 - Episode return 17
Episode 153 - Mean score: 33.2 - Episode return 19
Episode 154 - Mean score: 33.14 - Episode return 14
Episode 155 - Mean score: 33.0 - Episode return 14
Episode 156 - Mean score: 33.0 - Episode return 15
Episode 157 - Mean score: 32.95 - Episode return 14
Episode 158 - Mea

Episode 296 - Mean score: 24.44 - Episode return 78
Episode 297 - Mean score: 25.01 - Episode return 72
Episode 298 - Mean score: 25.49 - Episode return 66
Episode 299 - Mean score: 25.73 - Episode return 44
Episode 300 - Mean score: 25.94 - Episode return 42
[Episode 300] - Mean survival time over last 100 episodes was 25.94 ticks.
Episode 301 - Mean score: 26.24 - Episode return 53
Episode 302 - Mean score: 26.62 - Episode return 56
Episode 303 - Mean score: 26.96 - Episode return 55
Episode 304 - Mean score: 27.25 - Episode return 55
Episode 305 - Mean score: 27.4 - Episode return 52
Episode 306 - Mean score: 27.5 - Episode return 48
Episode 307 - Mean score: 27.22 - Episode return 46
Episode 308 - Mean score: 27.2 - Episode return 57
Episode 309 - Mean score: 27.09 - Episode return 49
Episode 310 - Mean score: 27.53 - Episode return 75
Episode 311 - Mean score: 27.82 - Episode return 48
Episode 312 - Mean score: 28.93 - Episode return 129
Episode 313 - Mean score: 29.4 - Episode re

Episode 451 - Mean score: 57.13 - Episode return 62
Episode 452 - Mean score: 57.25 - Episode return 42
Episode 453 - Mean score: 57.45 - Episode return 48
Episode 454 - Mean score: 57.77 - Episode return 58
Episode 455 - Mean score: 58.11 - Episode return 62
Episode 456 - Mean score: 57.82 - Episode return 44
Episode 457 - Mean score: 57.63 - Episode return 42
Episode 458 - Mean score: 57.19 - Episode return 31
Episode 459 - Mean score: 56.78 - Episode return 45
Episode 460 - Mean score: 56.77 - Episode return 42
Episode 461 - Mean score: 56.62 - Episode return 40
Episode 462 - Mean score: 56.72 - Episode return 54
Episode 463 - Mean score: 56.28 - Episode return 42
Episode 464 - Mean score: 55.83 - Episode return 40
Episode 465 - Mean score: 55.25 - Episode return 39
Episode 466 - Mean score: 54.9 - Episode return 36
Episode 467 - Mean score: 54.75 - Episode return 34
Episode 468 - Mean score: 54.76 - Episode return 62
Episode 469 - Mean score: 54.66 - Episode return 43
Episode 470 -

Episode 605 - Mean score: 115.1 - Episode return 200
Episode 606 - Mean score: 115.99 - Episode return 154
Episode 607 - Mean score: 117.38 - Episode return 200
Episode 608 - Mean score: 118.44 - Episode return 149
Episode 609 - Mean score: 119.88 - Episode return 200
Episode 610 - Mean score: 120.5 - Episode return 117
Episode 611 - Mean score: 120.77 - Episode return 114
Episode 612 - Mean score: 121.36 - Episode return 126
Episode 613 - Mean score: 122.5 - Episode return 159
Episode 614 - Mean score: 123.34 - Episode return 145
Episode 615 - Mean score: 123.99 - Episode return 134
Episode 616 - Mean score: 123.05 - Episode return 106
Episode 617 - Mean score: 123.65 - Episode return 111
Episode 618 - Mean score: 124.9 - Episode return 196
Episode 619 - Mean score: 125.84 - Episode return 181
Episode 620 - Mean score: 126.3 - Episode return 122
Episode 621 - Mean score: 126.95 - Episode return 154
Episode 622 - Mean score: 127.47 - Episode return 152
Episode 623 - Mean score: 128.62 

Episode 756 - Mean score: 135.68 - Episode return 127
Episode 757 - Mean score: 136.21 - Episode return 152
Episode 758 - Mean score: 136.36 - Episode return 133
Episode 759 - Mean score: 136.19 - Episode return 115
Episode 760 - Mean score: 136.32 - Episode return 113
Episode 761 - Mean score: 136.53 - Episode return 125
Episode 762 - Mean score: 136.79 - Episode return 120
Episode 763 - Mean score: 136.77 - Episode return 102
Episode 764 - Mean score: 136.88 - Episode return 115
Episode 765 - Mean score: 137.15 - Episode return 145
Episode 766 - Mean score: 136.92 - Episode return 105
Episode 767 - Mean score: 136.98 - Episode return 120
Episode 768 - Mean score: 137.33 - Episode return 134
Episode 769 - Mean score: 137.37 - Episode return 112
Episode 770 - Mean score: 137.77 - Episode return 137
Episode 771 - Mean score: 138.05 - Episode return 123
Episode 772 - Mean score: 138.13 - Episode return 118
Episode 773 - Mean score: 138.25 - Episode return 124
Episode 774 - Mean score: 13

Episode 906 - Mean score: 159.5 - Episode return 200
Episode 907 - Mean score: 159.99 - Episode return 200
Episode 908 - Mean score: 160.33 - Episode return 165
Episode 909 - Mean score: 160.3 - Episode return 182
Episode 910 - Mean score: 160.88 - Episode return 200
Episode 911 - Mean score: 161.35 - Episode return 200
Episode 912 - Mean score: 161.4 - Episode return 200
Episode 913 - Mean score: 161.81 - Episode return 170
Episode 914 - Mean score: 162.41 - Episode return 200
Episode 915 - Mean score: 163.05 - Episode return 200
Episode 916 - Mean score: 163.05 - Episode return 200
Episode 917 - Mean score: 162.83 - Episode return 178
Episode 918 - Mean score: 162.83 - Episode return 200
Episode 919 - Mean score: 163.02 - Episode return 200
Episode 920 - Mean score: 163.55 - Episode return 200
Episode 921 - Mean score: 163.55 - Episode return 200
Episode 922 - Mean score: 163.32 - Episode return 160
Episode 923 - Mean score: 163.49 - Episode return 163
Episode 924 - Mean score: 163.4