# In [None]:
import random
from collections import deque
from collections import namedtuple

import numpy as np
import torch

# In [None]:
"""This file is copied/adapted from https://pytorch.org/tutorials/intermediate/reinforcement_q_learning.html"""

# Immutable record of a single environment step. Fields are accessible by name
# (e.g. ``t.reward``) or by position; positional arguments passed to
# ``ExperienceReplay.push`` are forwarded here in this order.
Transition = namedtuple('Transition',
                        ('state', 'action', 'next_state', 'reward'))


class ExperienceReplay:
    """Bounded FIFO memory of ``Transition`` records with uniform random sampling.

    When the memory already holds ``capacity`` items, pushing a new one
    discards the oldest.
    """

    def __init__(self, capacity):
        # ``maxlen`` makes the deque drop items from the left once it is full.
        self.memory = deque(maxlen=capacity)

    def push(self, *args):
        """Build a ``Transition`` from ``args`` and store it."""
        record = Transition(*args)
        self.memory.append(record)

    def sample(self, batch_size):
        """Return ``batch_size`` distinct stored items chosen uniformly at random.

        ``random.sample`` raises ``ValueError`` if ``batch_size`` exceeds the
        number of stored items.
        """
        pool = self.memory
        return random.sample(pool, k=batch_size)

    def __len__(self):
        """Return the number of items currently stored."""
        size = len(self.memory)
        return size
        

# In [None]:
class ReplayBuffer:
    """Fixed-size FIFO buffer of screens, actions, rewards and terminal flags.

    Time steps are stored as parallel deques: entry ``i`` of each deque belongs
    to the same step, and the "next screen" of step ``i`` is the screen stored
    at ``i + 1``. Every deque shares the same ``maxlen``, and ``push`` appends
    to all of them together, so they stay aligned when old entries are evicted.
    """

    def __init__(self, size, screen_shape=(84, 84)):
        """Create an empty buffer.

        Args:
            size: Maximum number of time steps kept; the oldest are dropped.
            screen_shape: Expected shape of one screen. It is stored for
                reference only and is not used to validate pushed screens.
        """
        self.size = size
        self.screen_shape = screen_shape
        self.num_in_buffer = 0
        self.screens = deque(maxlen=self.size)
        self.actions = deque(maxlen=self.size)
        self.rewards = deque(maxlen=self.size)
        self.terminal = deque(maxlen=self.size)

    def push(self, screen, action, reward, done=False):
        """Append one time step.

        Args:
            screen: Observation for this step (anything ``np.stack`` accepts).
            action: Action taken; stored as ``np.uint8``, so it must fit in
                the range 0-255.
            reward: Reward received for ``action``.
            done: Whether the episode ended at this step. Defaults to
                ``False`` so existing three-argument callers keep working.
        """
        self.screens.append(screen)
        self.actions.append(np.uint8(action))
        self.rewards.append(reward)
        # Must be appended together with the other fields: ``_encode_sample``
        # indexes ``terminal`` with the same indices as ``screens``.
        self.terminal.append(bool(done))

        self.num_in_buffer = len(self.screens)

    def can_sample(self, batch_size):
        """Return True if ``batch_size`` distinct transitions can be sampled.

        The newest step has no successor screen yet, so only
        ``num_in_buffer - 1`` steps are usable as transitions.
        """
        return batch_size + 1 <= self.num_in_buffer

    def _encode_sample(self, idxes):
        """Gather the transitions that start at each index in ``idxes``.

        Every index must be ``< num_in_buffer - 1`` so that the next screen
        (``idx + 1``) exists.

        Returns:
            ``(obs, actions, rewards, next_obs, done_mask)``. The first four
            are ``torch`` tensors whose leading dimension is ``len(idxes)``.
            ``done_mask`` is a ``float32`` numpy array holding 1.0 where the
            step was flagged terminal and 0.0 otherwise. For a terminal step,
            ``next_obs`` holds the first screen of the following episode;
            callers presumably mask it out with ``done_mask``.
        """
        # np.stack adds a new batch axis. np.concatenate rejects the 0-d
        # action/reward scalars, and it would merge 2-D screens along rows.
        obs_batch = torch.from_numpy(np.stack([self.screens[idx] for idx in idxes]))
        act_batch = torch.from_numpy(np.stack([self.actions[idx] for idx in idxes]))
        rew_batch = torch.from_numpy(np.stack([self.rewards[idx] for idx in idxes]))
        next_obs_batch = torch.from_numpy(np.stack([self.screens[idx + 1] for idx in idxes]))
        done_mask = np.array([1.0 if self.terminal[idx] else 0.0 for idx in idxes], dtype=np.float32)

        return obs_batch, act_batch, rew_batch, next_obs_batch, done_mask

    def sample(self, batch_size):
        """Uniformly sample ``batch_size`` distinct transitions (no replacement)."""
        assert self.can_sample(batch_size), (
            f"cannot sample {batch_size} transitions from {self.num_in_buffer} stored steps"
        )
        # Exclude the newest step: its next screen has not been pushed yet.
        inds = random.sample(range(self.num_in_buffer - 1), batch_size)

        return self._encode_sample(inds)
        
        
        
        