In [None]:
import os
import time
import math
import random
from datetime import datetime

from tqdm import tqdm

import numpy as np
import matplotlib.pyplot as plt

import gym

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Conv2D, Dense, Flatten
from tensorflow.keras.optimizers import RMSprop

In [None]:
class TransitionTable:
    """
    Replay memory that stores single frames and serves stacked-frame transitions.

    Frames, actions, rewards and terminal flags live in parallel pre-allocated
    arrays. An agent state is a stack of `histLen` consecutive frames (only the
    'linear' history type of the original implementation is supported, so there
    is no `histIndices`). Minibatches are cut from a buffer of `bufferSize`
    pre-sampled transitions that is refilled whenever it runs out.

    Parameters
    ----------
    stateDim : tuple of int
        Shape of one stored frame.

    histLen : int
        Number of consecutive frames stacked into one agent state.

    maxSize : int
        Capacity of the table, in frames.

    bufferSize : int
        Number of transitions pre-sampled on every buffer refill.
    """

    def __init__(
        self,
        stateDim=(105, 80),
        histLen=1,
        maxSize=1_000_000,
        bufferSize=1024,
    ):
        self.stateDim = stateDim
        self.histLen = histLen
        self.maxSize = maxSize
        self.bufferSize = bufferSize
        # read cursor into the pre-sampled buffer; None until the first fill
        self.buf_ind = None

        self.recentMemSize = self.histLen

        self.numEntries = 0
        self.insertIndex = 0

        # pre-allocated storage, one row per stored frame
        self.s = np.zeros(shape=(self.maxSize, *self.stateDim), dtype=np.uint8)
        self.a = np.zeros(self.maxSize, dtype=np.uint8)
        self.r = np.zeros(self.maxSize, dtype=np.float32)
        self.t = np.zeros(self.maxSize, dtype=np.uint8)

        # Tables for storing the last `histLen` states. They are used for
        # constructing the most recent agent state more easily
        self.recent_s = []
        self.recent_a = []
        self.recent_t = []

        # pre-allocated minibatch buffers
        # 'channels_first' layout: frames are stacked on the first axis so a
        # state can be built without reshaping
        s_size = (self.histLen, *self.stateDim)
        self.buf_a = np.zeros(self.bufferSize, dtype=np.uint8)
        self.buf_r = np.zeros(self.bufferSize, dtype=np.float32)
        self.buf_term = np.zeros(self.bufferSize, dtype=np.uint8)
        # shape = (bufferSize, histLen, height, width)
        # with the default arguments = (1024, 1, 105, 80)
        self.buf_s = np.zeros(shape=(self.bufferSize, *s_size), dtype=np.uint8)
        self.buf_s2 = np.zeros(shape=(self.bufferSize, *s_size), dtype=np.uint8)

    def reset(self):
        """Forget all stored transitions (the arrays are not cleared)."""
        self.numEntries = 0
        self.insertIndex = 0

    def size(self):
        """Return the number of stored entries."""
        return self.numEntries

    def empty(self):
        """Return True when nothing has been stored yet."""
        return self.numEntries == 0

    def fill_buffer(self):
        """Refill the minibatch buffer with `bufferSize` freshly sampled transitions."""
        assert self.numEntries >= self.bufferSize
        # Rewind the read cursor. Python arrays are 0-based (the Lua original
        # used 1), so starting at 1 would waste the first sample of every fill.
        self.buf_ind = 0

        for buf_ind in range(self.bufferSize):
            s, a, r, s2, term = self.sample_one()
            # s.shape == s2.shape == (histLen, *stateDim)
            self.buf_s[buf_ind] = s
            self.buf_a[buf_ind] = a
            self.buf_r[buf_ind] = r
            self.buf_s2[buf_ind] = s2
            self.buf_term[buf_ind] = term

    def sample_one(self):
        """
        Draw one random transition `(s, a, r, s2, term)`.

        A start index is rejected while the newest frame of the state built
        from it (`index + recentMemSize - 1`) is flagged terminal, and a new
        index is drawn.
        """
        assert self.numEntries > 1

        while True:
            # start at the second index because of previous action
            index = random.randrange(1, self.numEntries - self.recentMemSize)
            if self.t[index + self.recentMemSize - 1] == 0:
                return self.get(index)

    def sample(self, batch_size=1):
        """
        Return the next `batch_size` pre-sampled transitions.

        Returns
        -------
        tuple of ndarray
            `(s, a, r, s2, term)`, each a copy so callers cannot alter the buffer.
        """
        assert batch_size < self.bufferSize

        # refill when nothing was sampled yet or the remainder is too short
        if (self.buf_ind is None) or (self.buf_ind + batch_size) > self.bufferSize:
            self.fill_buffer()

        start = self.buf_ind
        end = start + batch_size
        self.buf_ind = end

        s = np.copy(self.buf_s[start:end])
        a = np.copy(self.buf_a[start:end])
        r = np.copy(self.buf_r[start:end])
        term = np.copy(self.buf_term[start:end])
        s2 = np.copy(self.buf_s2[start:end])

        return s, a, r, s2, term

    def concatFrames(self, index, use_recent=False):
        """
        Stack `histLen` consecutive frames starting at `index` into one state.

        Frames of different episodes are stored back to back. When one of the
        frames before the newest one is flagged terminal, it and every older
        frame in the stack belong to a previous episode and are zeroed, so the
        returned state only shows the most recent episode.

        Parameters
        ----------
        index : int
            Position of the oldest frame of the stack.

        use_recent : bool
            Read from the recent-frames lists instead of the main table.

        Returns
        -------
        ndarray
            uint8 array of shape `(histLen, *stateDim)`.
        """
        if use_recent:
            s, t = self.recent_s, self.recent_t
        else:
            s, t = self.s, self.t

        fullstate = np.zeros(shape=(self.histLen, *self.stateDim), dtype=np.uint8)

        # `end_index` is exclusive; slots past the end of the storage stay zero
        end_index = min(len(s), index + self.histLen)
        for fs_idx, i in enumerate(range(index, end_index)):
            fullstate[fs_idx] = s[i]

        # Walk back from the second-newest frame: the first terminal flag met
        # marks the end of the previous episode.
        for i in range(self.histLen - 2, -1, -1):
            idx = index + i
            if idx < len(t) and t[idx] == 1:
                fullstate[: i + 1] = 0
                break

        return fullstate

    def concatActions(self, index, use_recent=False):  # TODO 9
        pass

    def get_recent(self):
        """Return the state built from the most recently perceived frames."""
        # Assumes that the most recent state has been added, but the action has not
        return self.concatFrames(0, True)

    def get(self, index):
        """
        Build the transition whose state starts at frame `index`.

        Returns
        -------
        tuple
            `(s, a, r, s2, term)`. `a` and `r` are read at the newest frame of
            `s`; `term` is the terminal flag of the newest frame of `s2`.
        """
        s = self.concatFrames(index)
        s2 = self.concatFrames(index + 1)
        # position of the newest frame of `s`
        ar_index = index + self.recentMemSize - 1

        return s, self.a[ar_index], self.r[ar_index], s2, self.t[ar_index + 1]

    def add(self, s, a, r, term):
        """
        Store one `(frame, action, reward, terminal)` entry.

        The insert position is advanced before writing, so the very first entry
        lands at index 1; slot 0 is first written after the position wraps.
        """
        # Increment until at full capacity
        if self.numEntries < self.maxSize:
            self.numEntries += 1

        # Always insert at next index, then wrap around
        self.insertIndex += 1
        # Overwrite oldest experience once at capacity
        if self.insertIndex >= self.maxSize:
            self.insertIndex = 0

        # Overwrite (s, a, r, t) at `insertIndex`
        self.s[self.insertIndex] = s
        self.a[self.insertIndex] = a
        self.r[self.insertIndex] = r
        self.t[self.insertIndex] = 1 if term else 0

    def add_recent_state(self, s, term):
        """Push a frame and its terminal flag, keeping the last `recentMemSize`."""
        if not self.recent_s:
            # first call: pad the history with blank, non-terminal frames
            for _ in range(self.recentMemSize):
                self.recent_s.append(np.zeros_like(s))
                self.recent_t.append(0)

        self.recent_s.append(s)
        self.recent_t.append(1 if term else 0)

        # keep recentMemSize states
        if len(self.recent_t) > self.recentMemSize:
            self.recent_s.pop(0)
            self.recent_t.pop(0)

    def add_recent_action(self, a):
        """Push an action, keeping the last `recentMemSize`."""
        if not self.recent_a:
            # first call: pad the history with action 0
            self.recent_a.extend([0] * self.recentMemSize)

        self.recent_a.append(a)

        # keep recentMemSize steps
        if len(self.recent_a) > self.recentMemSize:
            self.recent_a.pop(0)

In [None]:
class DQNAgent:
    """
    Deep Q-learning agent with epsilon-greedy exploration and experience replay.

    The agent turns raw screens into small grayscale frames, stores them in a
    `TransitionTable` and fits a Keras network on minibatches sampled from it.
    """

    def __init__(
        self,
        n_actions=4,
        ep_start=1.0,
        ep_end=0.1,
        ep_endt=1_000_000,
        lr=0.00025,
        minibatch_size=1,
        valid_size=512,
        discount=0.99,
        update_freq=1,
        n_replay=1,
        learn_start=0,
        replay_memory=1_000_000,
        hist_len=1,
        max_reward=None,
        min_reward=None,
        network=None,
    ):
        """
        Parameters
        ----------
        n_actions : int
            The number of actions that the agent can take.

        ep_start : float
            The initial epsilon value in epsilon-greedy.

        ep_end : float
            The final epsilon value in epsilon-greedy.

        ep_endt : int
            The number of timesteps over which the initial value of epsilon is
            linearly annealed to its final value.

        lr : float
            The learning rate used by RMSProp.

        minibatch_size : int
            Number of transitions sampled for one learning update.

        valid_size : int
            Number of transitions sampled by `sample_validation_data`.

        discount : float
            Discount factor applied to the best next-state Q-value.

        update_freq : int
            Learning happens on steps where `numSteps % update_freq == 0`.

        n_replay : int
            Number of minibatch updates per learning step.

        learn_start : int
            Number of steps after which learning starts.

        replay_memory : int
            Capacity of the transition table.

        hist_len : int
            Number of frames stacked into one state.

        max_reward, min_reward : number or None
            Bounds used to clip rewards; `None` disables the bound.

        network : keras model or None
            Q-network to use; a default one is created when omitted.
        """
        self.n_actions = n_actions

        # epsilon annealing
        self.ep_start = ep_start  # initial epsilon value
        self.ep = self.ep_start  # exploration probability
        self.ep_end = ep_end  # final epsilon value
        self.ep_endt = ep_endt  # number of timesteps of linear annealing

        self.lr = lr
        self.minibatch_size = minibatch_size
        self.valid_size = valid_size

        # Q-learning parameters
        self.discount = discount  # discount factor
        self.update_freq = update_freq
        # number of points to replay per learning step
        self.n_replay = n_replay
        # number of steps after which learning starts
        self.learn_start = learn_start
        # size of the transition table
        self.replay_memory = replay_memory
        self.hist_len = hist_len
        self.max_reward = max_reward
        self.min_reward = min_reward

        # create transition table
        self.transitions = TransitionTable(histLen=self.hist_len, maxSize=self.replay_memory)

        if network is None:
            # the input must match the states served by the transition table
            network = self.createNetwork(
                input_shape=(self.hist_len, *self.transitions.stateDim),
                n_actions=n_actions,
            )
        self.network = network
        self.compile_model(self.network, self.lr)

        self.numSteps = 0  # number of perceived states
        self.lastState = None
        self.lastAction = None
        self.lastTerminal = None

        self.valid_s = None
        self.valid_a = None
        self.valid_r = None
        self.valid_s2 = None
        self.valid_term = None

    def compile_model(self, model, lr=0.00025):
        """Compile `model` with an MSE loss and an RMSprop optimizer."""
        # `learning_rate` is the supported keyword; `lr` is a deprecated alias
        optimizer = RMSprop(learning_rate=lr)
        model.compile(
            loss='mse',
            optimizer=optimizer,
            metrics=['accuracy', 'mse'],
        )

    def reset(self, state):
        # TODO 9 Low-priority
        pass

    def preprocess(self, rawstate):
        """
        Convert a raw `(height, width, channels)` screen to a half-size
        grayscale uint8 frame.
        """
        # Keep every second row/column first, then average the channels in
        # floating point: a uint8 accumulator would wrap around at 256.
        return rawstate[::2, ::2].mean(axis=2).astype(np.uint8)

    def getQUpdate(self, s, a, r, s2, term):
        """
        Compute the Q-learning targets for a batch of transitions.

        Parameters
        ----------
        s, s2 : ndarray
            Batches of states and next states; they must share one shape.

        a : ndarray
            Index of the action taken in each state.

        r : ndarray
            Reward of each transition.

        term : ndarray
            1 where the next state is terminal, else 0.

        Returns
        -------
        targets : ndarray
            The network output for `s` with the entry of the taken action
            replaced by `r + (1 - term) * discount * max_a Q(s2, a)`.

        delta : ndarray
            Temporal-difference error: that target minus `Q(s, a)`.

        q2_max : ndarray
            `max_a Q(s2, a)` for every transition.
        """
        # `s` and `s2` have to have the same shape
        assert s.shape == s2.shape
        batch_size = s.shape[0]

        # merge `s` and `s2` together for one forward pass; pixel values are
        # only scaled to [0..1] here to keep the stored states small
        forward_batch = np.concatenate((s, s2), axis=0)
        q_batch = np.asarray(self.network.predict(forward_batch / 255.0))
        q_all = q_batch[:batch_size]

        # compute max_a Q(s_2, a)
        q2_max = np.max(q_batch[batch_size:], axis=1)

        # (1 - terminal) as floats: negating the uint8 flags would overflow
        not_terminal = 1.0 - np.asarray(term, dtype=np.float32)

        # target = r + (1 - terminal) * gamma * max_a Q(s2, a); the bootstrap
        # term vanishes on terminal transitions
        td_target = np.asarray(r, dtype=np.float32) + q2_max * self.discount * not_terminal

        # Q(s, a) of the action actually taken in every row
        rows = np.arange(batch_size)
        actions = np.asarray(a, dtype=np.intp)
        q = q_all[rows, actions]

        delta = td_target - q

        # only the taken action gets a new target, so the loss on the other
        # outputs is zero
        targets = q_all.copy()
        targets[rows, actions] = td_target

        return targets, delta, q2_max

    def _replay_ready(self):
        """Return True once the replay memory holds enough entries to sample a minibatch."""
        stored = self.transitions.size()
        return stored >= self.transitions.bufferSize and stored > self.minibatch_size

    def qLearnMinibatch(self, verbose=0):
        """
        Perform one minibatch Q-learning update:
        w += alpha * (r + gamma max Q(s2,a2) - Q(s,a)) * dQ(s,a)/dw
        """
        # TODO accumulate losses instead of update rightaway
        assert self.transitions.size() > self.minibatch_size

        s, a, r, s2, term = self.transitions.sample(self.minibatch_size)

        # targets.shape = (batch_size, n_action)
        targets, delta, q2_max = self.getQUpdate(s, a, r, s2, term)

        self.network.fit(
            x=(s / 255.0),
            y=targets,
            epochs=1,
            batch_size=self.minibatch_size,
            verbose=verbose,
        )

    def sample_validation_data(self):
        """Sample and keep `valid_size` transitions for validation."""
        s, a, r, s2, term = self.transitions.sample(self.valid_size)
        self.valid_s = s
        self.valid_a = a
        self.valid_r = r
        self.valid_s2 = s2
        self.valid_term = term

    def compute_validation_statistics(self):  # TODO 9
        """Return the mean (signed) TD error over the validation transitions."""
        targets, delta, q2_max = self.getQUpdate(
            s=self.valid_s,
            a=self.valid_a,
            r=self.valid_r,
            s2=self.valid_s2,
            term=self.valid_term,
        )
        avg_loss = delta.mean()

        return avg_loss

    def perceive(self, reward, rawstate, terminal, testing=False, testing_ep=None, verbose=0):
        """
        Observe one environment step, learn from replay and pick the next action.

        Parameters
        ----------
        reward : number
            The received reward from environment.

        rawstate : ndarray
            The game screen.

        terminal : int
            If the game end then `terminal = 1` else `terminal = 0`.

        testing : bool
            When true nothing is stored or learned and the step counter is
            left untouched.

        testing_ep : number
            Testing epsilon value for the epsilon-greedy algorithm.

        verbose : int
            Verbosity forwarded to the network's `fit`.

        Returns
        -------
        int
            The chosen action (0 when `terminal` is set).
        """
        # preprocess state
        state = self.preprocess(rawstate)

        # clip reward
        if self.max_reward is not None:
            reward = min(reward, self.max_reward)

        if self.min_reward is not None:
            reward = max(reward, self.min_reward)

        self.transitions.add_recent_state(state, terminal)

        # store transition s, a, r, s'
        if (self.lastState is not None) and not testing:
            self.transitions.add(self.lastState, self.lastAction, reward, self.lastTerminal)

        # stacked recent frames, shape (hist_len, height, width), turned into
        # a batch of one
        curState = np.array([self.transitions.get_recent()], dtype=np.uint8)

        # select action
        action = 0
        if not terminal:
            action = self.eGreedy(curState, testing_ep)

        # do some Q-learning updates; skipped until the replay memory holds
        # enough entries to be sampled
        if (
            not testing
            and self.numSteps > self.learn_start
            and self.numSteps % self.update_freq == 0
            and self._replay_ready()
        ):
            for _ in range(self.n_replay):
                self.qLearnMinibatch(verbose=verbose)

        if not testing:
            self.numSteps += 1

        self.lastState = state
        self.lastAction = action
        self.lastTerminal = terminal

        return action

    def eGreedy(self, state, testing_ep=None):
        """
        Pick an action epsilon-greedily.

        Without `testing_ep`, epsilon is annealed linearly from `ep_start` to
        `ep_end` over `ep_endt` steps counted from `learn_start`; otherwise
        `testing_ep` is used as epsilon.
        """
        if testing_ep is None:
            ep_range = self.ep_start - self.ep_end
            ep_prog = 1.0 - max(0, self.numSteps - self.learn_start) / self.ep_endt
            self.ep = self.ep_end + max(0, ep_range * ep_prog)
        else:
            self.ep = testing_ep

        if random.random() < self.ep:
            return random.randrange(0, self.n_actions)
        return self.greedy(state)

    def greedy(self, state):
        """Return the action with the highest Q-value, breaking ties at random."""
        q = self.network.predict(state / 255.0)[0]
        max_q = q[0]
        best_a = [0]

        # evaluate all other actions, collecting ties
        for a in range(1, self.n_actions):
            if q[a] > max_q:
                best_a = [a]
                max_q = q[a]
            elif q[a] == max_q:
                best_a.append(a)

        # random tie-breaking
        choice = best_a[random.randrange(0, len(best_a))]
        self.lastAction = choice
        return choice

    def createNetwork(self, input_shape=(4, 105, 80), n_actions=4):
        """
        Build the Q-network: three 'channels_first' convolutions, a 512-unit
        dense layer and one linear output per action.
        """
        model = keras.Sequential([
            Conv2D(
                filters=32,
                kernel_size=8,
                strides=4,
                activation='relu',
                input_shape=tuple(input_shape),
                data_format='channels_first',
            ),
            Conv2D(
                filters=64,
                kernel_size=4,
                strides=2,
                activation='relu',
                data_format='channels_first',
            ),
            Conv2D(
                filters=64,
                kernel_size=3,
                strides=1,
                activation='relu',
                data_format='channels_first',
            ),
            Flatten(),
            Dense(
                units=512,
                activation='relu',
            ),
            Dense(
                units=n_actions,
                activation='linear',
            ),
        ])

        return model