In [1]:
#from data.iterators import PolicyIterator

from data.collectors import EpisodeCollector
import numpy as np
from utils import setup_args
import random

args = setup_args()

args.env_name = "Pong-v0"
args.model_name = "vae"
dc = EpisodeCollector(args)

pygame 1.9.4
Hello from the pygame community. https://www.pygame.org/contribute.html
couldn't import doomish
Couldn't import doom


In [10]:
def get_frame_action_frame(ep_ind, frame_ind,num = 1, stride=1):
    ep = episodes[ep_ind]
    frames = []
    actions = []
    for _ in range(num):
        frame = ep._asdict()["xs"][frame_ind]
        action = ep._asdict()["actions"][frame_ind]
        actions.append(action)
        frames.append(frame)
        frame_ind += stride
    frame = ep._asdict()["xs"][frame_ind]
    frames.append(frame)
    
    trans = dc.get_transition_constructor()(xs=frames, actions=actions)
    return trans

In [34]:
from collections import namedtuple
import torch
import numpy as np
import random
from functools import partial
from data.collectors import EpisodeCollector
from functools import partial
import copy
from data.utils import setup_env, convert_frames,convert_frame
import math

class DataSampler(object):
    """buffer of episodes. you can sample it like a true replay buffer (with replacement) using self.sample
    or like normal data iterator used in most supervised learning problems with sellf.__iter__()"""
    """Memory is uint8 to save space, then when you sample it converts to float tensor"""
    def __init__(self,args, batch_size=64):
        self.args = args
        self.stride = self.args.stride
        self.num_frames = self.args.frames_per_trans
        self.DEVICE = self.args.device
        self.batch_size = batch_size
        self.episodes = []                
   
    def push(self, episode_trans):
        """Saves a transition."""
        self.episodes.append(episode_trans)

    def sample(self,batch_size=None, with_replacement=False):
        batch_size = self.batch_size if batch_size is None else batch_size
        
        #sample indices into frames
        all_inds = self.get_all_inds()
        all_arr = np.stack(random.choices(all_inds, k=batch_size))
        ep_inds, frame_inds = all_arr[:,0], all_arr[:,1]
        
        raw_sample = self.raw_sample(batch_size, ep_inds, frame_inds)
        batch = self._convert_raw_sample(raw_sample)
        return batch

    
    def get_all_inds(self):
        ep_lens = {i:len(ep) for i in  range(self.num_episodes)}
        all_possible_inds = np.concatenate([[(ep_ind, frame_ind) 
                                             for frame_ind in range(ep_lens[ep_ind] - self.stride)] 
                                            for ep_ind in range(self.num_episodes)])
        return all_possible_inds

    
    def raw_sample(self, ep_inds, frame_inds):
        transitions = [self._sample(ep_ind,frame_ind, num, stride) for 
                                    ep_ind, frame_ind in zip(ep_inds,frame_inds)]
        return transitions
        
    
    
    def _sample(self,ep_ind,frame_ind, num):
        raise NotImplementedError
    
    def _convert_raw_sample(self,transitions):
        """converts 8-bit RGB to float and pytorch tensor"""
        # puts all trans objects into one trans object
        trans = self._combine_transitions_into_one_big_one(transitions)
        batch = self._convert_fields_to_pytorch_tensors(trans)
        return batch
        
    def _combine_transitions_into_one_big_one(self,transitions):
        fields = []
        for i,field in enumerate(zip(*transitions)):
            if isinstance(field[0],list):
                new_field = np.stack([list_ for list_ in field])
                if str(new_field.dtype) == "bool":
                    new_field = new_field.astype("int")
                #print(field.shape,field)
            if isinstance(field[0],dict):
                new_field = {}
                for k in field[0].keys():
                    all_items_of_key_k = [dic[k] for dic in field]
                    array_of_items_of_key_k = np.stack([list_ for list_ in all_items_of_key_k])
                    new_field[k] = array_of_items_of_key_k

            fields.append(new_field)
        return Transition(*fields)
    
    def _convert_fields_to_pytorch_tensors(self,trans):
        tb_dict = trans._asdict()
        if "state_param_dict" in tb_dict:
            for k,v  in trans.state_param_dict.items():
                tb_dict["state_param_dict"][k] = torch.tensor(v).to(self.DEVICE)
                
        
        tb_dict["xs"] = torch.stack([convert_frames(np.asarray(trans.xs[i]),to_tensor=True,resize_to=(-1,-1)) for
                                                     i in range(len(trans.xs))]).to(self.DEVICE)
        
        if "actions" in tb_dict:
            tb_dict["actions"] = torch.from_numpy(np.asarray(trans.actions)).to(self.DEVICE)
        if "rewards" in tb_dict:
            tb_dict["rewards"] = torch.from_numpy(np.asarray(trans.rewards)).to(self.DEVICE)

        
        
        batch = Transition(*list(tb_dict.values()))
        return batch

    def __iter__(self):
        """Iterator that samples without replacement for replay buffer
        It's basically like a standard sgd setup
        If you want to sample with replacement like standard replay buffer use self.sample"""
        all_inds = self.get_all_inds()
        random.shuffle(all_inds)
        size = len(all_inds)
        for st in range(0, size, self.batch_size):
            end = st+self.batch_size if st+self.batch_size <= size else size
            batch_inds = np.stack(all_inds[st:end])
            ep_inds, frame_inds = batch_inds[:,0], batch_inds[:,1]
            raw_sample = self.raw_sample(ep_inds, frame_inds)
            yield self._convert_raw_sample(raw_sample)
    
    def __len__(self):
        return len(self.episodes)
    
    @property
    def num_episodes(self):
        return self.__len__()
    

In [35]:
class FramesSampler(DataSampler):
    def __init__(self,args,batch_size):
        super(FramesSampler,self).__init__(args, batch_size)
    
    def _sample(self,ep_ind,frame_ind, num=1, stride=1):
        ep = episodes[ep_ind]
        frames_to_go = len(ep.xs)  - frame_ind 
        frames_covered = num*stride
        diff = frames_to_go - frames_covered
        if diff < 0:
            #print(frame_ind, diff)
            frame_ind += diff
        frames = []
        for _ in range(num):
            frame = ep._asdict()["xs"][frame_ind]
            frames.append(frame)
            frame_ind += stride
        trans = dc.get_transition_constructor()(xs=frames)
        return trans
    
    

In [36]:
fs = FramesSampler(args,32)

In [37]:
fs.sample()

ValueError: need at least one array to concatenate

In [30]:
super(DataSampler,fs).__init__(args,32)

TypeError: object.__init__() takes no parameters

[0;31mType:[0m            super
[0;31mString form:[0m     <super: <class 'DataSampler'>, <FramesSampler object>>
[0;31mDocstring:[0m       The most base type
[0;31mClass docstring:[0m
super() -> same as super(__class__, <first argument>)
super(type) -> unbound super object
super(type, obj) -> bound super object; requires isinstance(obj, type)
super(type, type2) -> bound super object; requires issubclass(type2, type)
Typical use to call a cooperative superclass method:
class C(B):
    def meth(self, arg):
        super().meth(arg)
This works for class methods too:
class C(B):
    @classmethod
    def cmeth(cls, arg):
        super().cmeth(arg)


In [56]:
all_arr = np.stack(random.choices(all_possible, k=32))

ep_inds, frame_inds = all_arr[:,0], all_arr[:,1]

In [None]:
    
    

class BufferFiller(object):
    """creates and fills replay buffers with transitions"""
    def __init__(self,args,policy=None):
        #self.env = env
        self.args = args
        self.policy=policy

    def make_empty_buffer(self):
        return ReplayMemory(batch_size=self.args.batch_size, args=self.args)
    
    def fill(self,size):
        """fill with transitions by just following a policy"""
        buffer = self.make_empty_buffer()
        collector = EpisodeCollector(args=self.args,policy=self.policy)
        buffer = self._fill(size, collector, buffer)
        return buffer 
    
    def _fill(self,size, collector, buffer):
        cur_size = 0
        while cur_size < size:
            episode = collector.collect_episode_per_the_policy()
            cur_size += len(episode.xs)
            buffer.push(episode)
        return buffer