### based on https://github.com/higgsfield/RL-Adventure and https://medium.com/swlh/introduction-to-reinforcement-learning-coding-sarsa-part-4-2d64d6e37617

In [None]:
%load_ext cython

In [None]:
import resource
# Set both the soft and the hard RLIMIT_RSS (resident set size) limit of this
# process to 32 << 30, i.e. 32 GiB if the unit is bytes.
# NOTE(review): the Linux getrlimit(2) man page describes RLIMIT_RSS as not
# enforced by current kernels — confirm this call actually limits memory here.
resource.setrlimit(resource.RLIMIT_RSS, (32 << 30, 32 << 30))
# Debug helper: uncomment to inspect the current resource usage of this process.
#resource.getrusage(resource.RUSAGE_SELF)

In [None]:
%matplotlib inline
import collections
import cv2
import gym
import matplotlib.pyplot as plot
import numpy as np
import random
import time
import torch as t
from IPython.display import clear_output

In [None]:
class LazyFrames(object):
    """Memory-saving container for a stack of observation frames.

    Only references to the individual frames are kept, so frames that are
    shared between consecutive observations are stored once.  This exists
    purely to reduce memory usage, which can be huge for DQN's 1M-frame
    replay buffers.  Convert the object to a numpy array only right before
    passing it to the model.
    """

    def __init__(self, frames):
        """Keep a reference to ``frames``, a sequence of arrays to be joined on axis 0."""
        self._frames = frames

    def __array__(self, dtype=None, copy=None):
        """Materialise the stacked observation (numpy array protocol).

        Args:
            dtype: optional dtype the result is cast to.
            copy: accepted for compatibility with the NumPy 2 ``__array__``
                signature; the value is ignored because ``np.concatenate``
                always builds a fresh array.

        Returns:
            The stored frames concatenated along axis 0.
        """
        out = np.concatenate(self._frames, axis=0)
        if dtype is not None:
            out = out.astype(dtype)
        return out

class ImageToPyTorch(gym.ObservationWrapper):
    """Move the channel axis to the front: (H, W, C) images become (C, H, W).

    Pixel values are not rescaled, so the declared observation space keeps
    the value range of the wrapped environment's space instead of claiming
    a [0, 1] range that the data does not have.
    """

    def __init__(self, env):
        super(ImageToPyTorch, self).__init__(env)
        old_space = self.observation_space
        old_shape = old_space.shape
        self.observation_space = gym.spaces.Box(
            # Scalar bounds covering every element of the wrapped space.
            low=float(old_space.low.min()),
            high=float(old_space.high.max()),
            shape=(old_shape[-1], old_shape[0], old_shape[1]),
        )

    def observation(self, observation):
        """Return ``observation`` with axes reordered from (0, 1, 2) to (2, 0, 1)."""
        return observation.transpose(2, 0, 1)
    
class FrameStack(gym.Wrapper):
    """Wrapper whose observations are the ``k`` most recent frames.

    Observations are returned as ``LazyFrames`` so that frames shared by
    consecutive observations are not duplicated in memory.

    See Also
    --------
    baselines.common.atari_wrappers.LazyFrames
    """

    def __init__(self, env, k):
        super().__init__(env)
        self.k = k
        # Bounded deque: appending the (k+1)-th frame drops the oldest one.
        self.frames = collections.deque(maxlen=k)
        frame_shape = env.observation_space.shape
        stacked_shape = (frame_shape[0] * k, frame_shape[1], frame_shape[2])
        self.observation_space = gym.spaces.Box(low=0, high=255, shape=stacked_shape)

    def reset(self):
        """Reset the wrapped env and fill the whole stack with its first frame."""
        first_frame = self.env.reset()
        self.frames.extend([first_frame] * self.k)
        return self._get_ob()

    def step(self, action):
        """Step the wrapped env, push the new frame and return the stacked observation."""
        frame, reward, done, info = self.env.step(action)
        self.frames.append(frame)
        return self._get_ob(), reward, done, info

    def _get_ob(self):
        """Snapshot the current stack as a ``LazyFrames`` object."""
        assert len(self.frames) == self.k
        snapshot = list(self.frames)
        return LazyFrames(snapshot)

class ResizeObservation(gym.ObservationWrapper):
    """Shrink every observation to half (integer division) of its first two dimensions."""

    def __init__(self, env):
        super().__init__(env)
        source_shape = env.observation_space.shape
        half_dim0 = source_shape[0] // 2
        half_dim1 = source_shape[1] // 2
        self.observation_space = gym.spaces.Box(
            low=0,
            high=255,
            shape=(half_dim0, half_dim1, source_shape[2]),
        )
        # The target size is handed to cv2.resize with the two halved
        # dimensions swapped relative to the array shape.
        self.resize_to = (half_dim1, half_dim0)

    def observation(self, observation):
        """Return ``observation`` resized to ``self.resize_to`` with area interpolation."""
        resized = cv2.resize(observation, self.resize_to, interpolation=cv2.INTER_AREA)
        return resized

SUFFIX = 'NoFrameskip-v4'
# Build the environment 'Pong' + SUFFIX and wrap it, innermost wrapper first:
# ResizeObservation -> ImageToPyTorch -> FrameStack over the last 4 frames.
env = FrameStack(ImageToPyTorch(ResizeObservation(gym.make('Pong' + SUFFIX))), 4)

In [None]:
# Use the GPU whenever PyTorch reports CUDA as available
# (append "and False" to the expression below to force CPU while debugging).
USE_CUDA = t.cuda.is_available()
device = t.device('cuda' if USE_CUDA else 'cpu')

In [None]:
class Actor(object):
    """Epsilon-greedy action selection with exponentially decaying epsilon.

    Args:
        env: environment; only ``env.action_space.sample()`` is used.
        model: ``torch.nn.Module`` mapping a state tensor to action values.
        eps: initial exploration probability.
        eps_final: lower bound that epsilon decays towards.
        eps_steps: number of ``act`` calls after which epsilon reaches
            ``eps_final``.
        initial_explore: number of leading ``act`` calls that always return
            a random action and do not decay epsilon.
    """

    def __init__(self, env, model, eps, eps_final, eps_steps, initial_explore=0):
        self.env = env
        self.model = model
        self.eps = eps
        self.eps_final = eps_final
        # Per-step factor chosen so that eps * eps_decay ** eps_steps == eps_final.
        self.eps_decay = np.exp(np.log(eps_final / eps) / eps_steps)
        self.initial_explore = initial_explore

    def _model_device(self):
        """Return the device of the model's first parameter (global ``device`` if it has none)."""
        first_param = next(iter(self.model.parameters()), None)
        return device if first_param is None else first_param.device

    def act(self, state):
        """Choose an action for ``state``.

        Returns a random action during the initial exploration phase or with
        probability ``eps``; otherwise the index of the largest model output.
        Side effects: decrements ``initial_explore`` or decays ``eps``, and
        switches the model to eval mode before a greedy choice.
        """
        if self.initial_explore > 0:
            self.initial_explore -= 1
            return self.env.action_space.sample()
        self.eps = max(self.eps_final, self.eps * self.eps_decay)
        if random.random() < self.eps:
            return self.env.action_space.sample()
        self.model.eval()
        # Pure inference: skip autograd bookkeeping and feed the input on the
        # device the model actually lives on.
        with t.no_grad():
            state = t.FloatTensor(np.array(state)).to(self._model_device())
            q = self.model(state)
        return q.argmax().item()

In [None]:
class Model(t.nn.Module):
    """Convolutional network: three conv + PReLU stages, then two linear layers.

    Args:
        input_shape: shape of one unbatched input, used to size the first
            linear layer by running a zero tensor through the conv stack.
        input_frames: number of stacked frames; the first convolution takes
            ``3 * input_frames`` input channels.
        n_out: number of outputs of the final linear layer.
    """

    def __init__(self, input_shape, input_frames, n_out):
        super().__init__()
        conv_stages = [
            t.nn.Conv2d(3 * input_frames, 32, kernel_size=8, stride=4),
            t.nn.PReLU(),
            t.nn.Conv2d(32, 64, kernel_size=4, stride=2),
            t.nn.PReLU(),
            t.nn.Conv2d(64, 64, kernel_size=3, stride=1),
            t.nn.PReLU(),
        ]
        # Original author's note on the conv output: "-> 64 9 6".
        self.cnn = t.nn.Sequential(*conv_stages)
        flat_features = self.feature_size(self.cnn, input_shape)
        head = [
            t.nn.Linear(flat_features, 512),
            t.nn.PReLU(),
            t.nn.Linear(512, n_out),
        ]
        self.fc = t.nn.Sequential(*head)
        self.apply(self.weights_init)

    def feature_size(self, cnn, shape):
        """Return the flattened feature count ``cnn`` produces for one input of ``shape``."""
        probe = t.zeros(1, *shape)
        flattened = cnn(probe).view(1, -1)
        return flattened.size(1)

    def weights_init(self, m):
        """Re-initialise Linear and Conv2d modules; every other module is left untouched."""
        if not isinstance(m, (t.nn.Linear, t.nn.Conv2d)):
            return
        # NOTE(review): the 2 lands in kaiming_normal_'s `a` parameter
        # (negative slope of the rectifier) — confirm that this is intended.
        t.nn.init.kaiming_normal_(m.weight, a=2)
        t.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Run the network; inputs with fewer than 4 dims get a leading batch axis."""
        batch = x.unsqueeze(0) if x.dim() < 4 else x
        features = self.cnn(batch)
        features = features.view(features.size(0), -1)
        return self.fc(features)

In [None]:
%%cython
# distutils: language = c++
cimport numpy as np
import numpy as np

cdef class SumTree(object):
    """Array-backed binary sum tree over ``capacity`` priority leaves.

    ``tree`` holds ``2 * capacity - 1`` float32 nodes: the first
    ``capacity - 1`` entries are internal nodes, each kept equal to the sum
    of its two children, and the last ``capacity`` entries are the leaf
    priorities.  ``data`` stores one payload per leaf and is written like a
    ring buffer: once every slot has been used, ``add`` overwrites slot 0,
    then slot 1, and so on.
    """
    cdef size_t capacity
    cdef np.ndarray tree
    cdef size_t data_pointer
    cdef np.ndarray data

    def __init__(self, capacity : int):
        self.capacity = capacity
        self.tree = np.zeros(capacity * 2 - 1, dtype=np.float32)
        self.data_pointer = 0
        self.data = np.zeros(capacity, dtype=object)

    def add(self, p : float, data : object):
        """Store ``data`` with priority ``p`` in the next slot, wrapping around at capacity."""
        cdef size_t tree_idx
        tree_idx = self.data_pointer + self.capacity - 1
        self.data[self.data_pointer] = data
        self.update(tree_idx, p)

        self.data_pointer += 1
        if self.data_pointer >= self.capacity:
            self.data_pointer = 0

    def update(self, tree_idx : size_t, p : float):
        """Set node ``tree_idx`` to ``p`` and add the difference to all of its ancestors."""
        cdef float change = p - self.tree[tree_idx]
        self.tree[tree_idx] = p
        while tree_idx != 0:
            tree_idx = (tree_idx - 1) // 2
            self.tree[tree_idx] += change

    def get_total_p(self) -> float:
        """Return the root value, i.e. the sum of all leaf priorities."""
        return self.tree[0]

    def get_leaf(self, v : float):
        """Walk down from the root to the leaf selected by the cumulative value ``v``.

        Returns:
            ``(leaf_idx, priority, data)`` where ``leaf_idx`` indexes ``tree``.
        """
        cdef size_t parent_idx = 0
        cdef size_t cl_idx
        cdef size_t cr_idx
        while True:
            cl_idx = 2 * parent_idx + 1
            cr_idx = cl_idx + 1
            if cl_idx >= self.capacity * 2 - 1:
                # No children left: parent_idx is a leaf.
                leaf_idx = parent_idx
                break
            # Go left when v falls into the left subtree, and also whenever the
            # right subtree has zero mass so an empty branch is never entered.
            if v <= self.tree[cl_idx] or self.tree[cr_idx] == 0.0:
                parent_idx = cl_idx
            else:
                v -= self.tree[cl_idx]
                parent_idx = cr_idx
        data_idx = leaf_idx - self.capacity + 1
        return leaf_idx, self.tree[leaf_idx], self.data[data_idx]

    def get_max_p(self) -> float:
        """Return the largest leaf priority (0 while nothing has been stored)."""
        # Leaves start at index capacity - 1; slicing from there avoids
        # negating the unsigned ``capacity``.
        return np.max(self.tree[self.capacity - 1:])

    def get_min_p(self) -> float:
        """Return the smallest strictly positive leaf priority.

        Leaves that are still 0 (never written) are ignored; otherwise the
        minimum would be 0 until the tree is completely filled.  Returns 0.0
        when no leaf has a positive priority.
        """
        leaves = self.tree[self.capacity - 1:]
        filled = leaves[leaves > 0.0]
        if filled.size == 0:
            return 0.0
        return np.min(filled)

    def get_capacity(self) -> size_t:
        """Return the number of leaves."""
        return self.capacity

class PriorExpReplay(object):
    """Prioritized replay memory that samples transitions from a ``SumTree``.

    Args:
        capacity: maximum number of stored transitions.
        alpha: exponent applied to the clipped absolute errors to obtain
            priorities.
        beta: initial exponent of the importance weights; it is increased by
            ``beta_increment`` on every ``sample`` call, up to 1.
        beta_increment: per-sample increase of ``beta``.
        epsilon: lower clipping bound for absolute errors.
        abs_err_upper: upper clipping bound for absolute errors, also used as
            the priority of the very first stored transition.
    """

    def __init__(self, capacity : size_t, alpha=0.6, beta=0.4, beta_increment=0.001,
                 epsilon=0.01, abs_err_upper=1.0):
        self.tree = SumTree(capacity)
        self.alpha = alpha
        self.beta = beta
        self.beta_increment = beta_increment
        self.epsilon = epsilon
        self.abs_err_upper = abs_err_upper
        self._size = 0

    def __len__(self):
        """Return the number of transitions currently held (at most the capacity)."""
        return self._size

    def store(self, transition : object):
        """Add ``transition`` with the current maximum priority."""
        cdef float max_p = self.tree.get_max_p()
        if max_p == 0.0:
            # Nothing stored yet: start from the upper priority bound.
            max_p = self.abs_err_upper
        self.tree.add(max_p, transition)
        if self._size < self.tree.get_capacity():
            self._size += 1

    def sample(self, n : size_t):
        """Draw ``n`` transitions, one from each of ``n`` equal priority segments.

        Returns:
            ``(tree_indices, transitions, weights)`` as numpy arrays of dtype
            int32, object and float32; ``tree_indices`` can be passed back to
            ``batch_update``.

        Raises:
            ValueError: if the total priority is not positive (nothing stored).
        """
        cdef float total_p, pri_seg, min_prob, v, p
        cdef object data
        cdef size_t i, idx

        b_idx = np.empty((n,), dtype=np.int32)
        b_memory = np.empty((n,), dtype=object)
        probs = np.empty((n,), dtype=np.float32)
        self.beta = min(1.0, self.beta + self.beta_increment)
        if n == 0:
            return b_idx, b_memory, probs

        total_p = self.tree.get_total_p()
        if total_p <= 0.0:
            raise ValueError("cannot sample from an empty replay memory")

        pri_seg = total_p / n
        for i in range(n):
            v = np.random.uniform(pri_seg * i, pri_seg * (i + 1))
            idx, p, data = self.tree.get_leaf(v)
            b_idx[i] = idx
            b_memory[i] = data
            probs[i] = p / total_p

        min_prob = self.tree.get_min_p() / total_p
        if min_prob <= 0.0:
            # The tree reported a zero minimum (unfilled leaves): normalise by
            # the smallest sampled probability instead of dividing by zero.
            min_prob = probs.min()
        weights = np.power(probs / min_prob, -self.beta).astype(np.float32)
        return b_idx, b_memory, weights

    def batch_update(self, tree_idx, abs_errors):
        """Set the priorities of the leaves ``tree_idx`` from ``abs_errors``.

        Both arguments may be any array-like of equal length; errors are
        clipped to ``[epsilon, abs_err_upper]`` and raised to ``alpha``.
        """
        indices = np.asarray(tree_idx).reshape(-1)
        errors = np.asarray(abs_errors).reshape(-1)
        priorities = np.power(np.clip(errors, self.epsilon, self.abs_err_upper), self.alpha)
        for leaf, priority in zip(indices.tolist(), priorities.tolist()):
            self.tree.update(leaf, priority)

In [None]:
class Replay(object):
    """Replay buffer front-end that turns sampled transitions into tensors.

    Transitions are ``(state, action, next_state, reward, done)`` tuples kept
    in a ``PriorExpReplay`` of size ``maxlen``.  After every ``sample`` call
    the tree indices and importance weights of that batch are available as
    ``last_indices`` (numpy array) and ``last_weights`` (tensor on
    ``device``), so they can be used for the loss and for ``batch_update``.
    """

    def __init__(self, maxlen):
        self.memory = PriorExpReplay(maxlen)
        self._capacity = maxlen
        self._size = 0
        self.last_indices = None
        self.last_weights = None

    def __len__(self):
        """Return the number of transitions stored so far (at most ``maxlen``)."""
        return self._size

    def add(self, state, action, next_state, reward, done):
        """Store one transition."""
        self.memory.store((state, action, next_state, reward, done))
        if self._size < self._capacity:
            self._size += 1

    def sample(self, n):
        """Sample ``n`` transitions.

        Returns:
            ``(states, actions, next_states, rewards, masks)`` as tensors on
            ``device``; ``masks`` is ``1 - done``.
        """
        with t.no_grad():
            indices, memory, weights = self.memory.sample(n)
            # Keep what the caller needs to re-prioritise this batch.
            self.last_indices = np.asarray(indices)
            self.last_weights = t.as_tensor(
                np.asarray(weights, dtype=np.float32), device=device)
            states, actions, next_states, rewards, masks = zip(*memory)
            actions = t.LongTensor(actions).to(device)
            rewards = t.FloatTensor(rewards).to(device)
            masks = 1 - t.FloatTensor(masks).to(device)
            states = Replay.stack_states(states)
            next_states = Replay.stack_states(next_states)
            return states, actions, next_states, rewards, masks

    def batch_update(self, indices, abs_errors):
        """Update the priorities at ``indices`` from ``abs_errors`` (array-like or tensor)."""
        if isinstance(abs_errors, t.Tensor):
            abs_errors = abs_errors.detach().cpu().numpy()
        self.memory.batch_update(np.asarray(indices), np.asarray(abs_errors))

    @staticmethod
    def stack_states(states):
        """Stack ``states`` along a new leading axis and return a float tensor on ``device``."""
        s = np.stack([np.asarray(x) for x in states])
        # Transferred as bytes first, then widened to float on the target device.
        return t.ByteTensor(s).to(device).float()