In [1]:
import torch.multiprocessing as mp
import threading as th
import gym
import envs
import random
from torch.multiprocessing import Queue, Lock
import numpy as np
import torch
import time
from enum import Enum
from functools import reduce
from operator import add
import collections

In [2]:
# class Transitions():
#     def __init__(self, size=0):
#         self.size = size
#         self._data = []
        
#     def merge(self, item):
#         if isinstance(item[0], tuple):
#             return (self.merge(i) for i in item)
#         elif isinstance(item[0], torch.Tensor):
#             return torch.cat(item, axis=0)
#         else:
#             return item
    
#     @property
#     def collated(self):
#         values = map(self.merge, zip(*[transition._asdict().values()
#                                        for transition in self]))
        
#         return BatchedTransitions(*values)
    
#     def sample(self, n_items):
#          return Transitions().extend(
#             random.choices(self._data, k=n_items)
#         )
    
#     def push(self, transition):
#         self._data.append(transition)
        
#         if self.size != 0 and len(self._data) > self.size:
#             self._data.pop(0)
            
#     def extend(self, arr):
#         self._data += arr
#         return self
            
#     def __len__(self):
#         return len(self._data)
    
#     def __iter__(self):
#         self._idx = -1
#         return self

#     def __next__(self):
#         self._idx += 1
        
#         if self._idx < len(self._data):
#             return self._data[self._idx]
        
#         raise StopIteration
        

In [3]:
# Transition = namedtuple('Transition', ['state', 'next_state', 'objective', 'reward', 'done'])

In [4]:
# BatchedTransitionsBase = namedtuple('BatchedTransition', ['state', 'next_state', 'objective', 'reward', 'done'])

In [6]:
# class BatchedTransitions(BatchedTransitionsBase):
#     def __new__(cls, *args):
#         self = super(BatchedTransitions, cls).__new__(cls, *args)
#         return self
    
#     def pf(self, item):
#         if isinstance(item, torch.Tensor):
#             return item.pin_memory()
#         elif isinstance(item, tuple):
#             return tuple(self.pf(i) for i in item)
#         else:
#             return item
    
#     def prefetch(self):
#         return BatchedTransitions(*map(self.pf, self._asdict().values()))

In [2]:
class BatchSamplerStatus(Enum):
    READY = 1
    WAITING_TO_SAMPLE = 2
    SAMPLING = 3

class BatchSamplerInitiatorThread(th.Thread):
    def __init__(self, bs):
        super(BatchSamplerInitiatorThread, self).__init__()
        self.bs = bs
        
    def run(self):
        while True:
            if len(self.bs.batches) + self.bs.n_batches_sampling < 10:
                prop = [random.random() for _ in self.bs.explorers]
                prop = [int(p / sum(prop) * self.bs.BATCH_SIZE) for p in prop]
                prop[-1] = self.bs.BATCH_SIZE - sum(prop[:-1])
                
                for idx, (p, explorer) in enumerate(zip(prop, self.bs.explorers)):
                    if p > 0: explorer.instruction_queue.put((self.bs.current_sampler, p))
                    print(f'[INFO] Assigning Explorer {idx} to sample {p} transitions for batch {self.bs.current_sampler}')


                self.bs.n_batches_sampling += 1
                self.bs.current_sampler += 1
                self.bs.current_sampler %= self.bs.PREFETCH_BUFFER
                
#             for i in range(self.bs.PREFETCH_BUFFER):
#                 if self.bs.batch_status[i] != BatchSamplerStatus.WAITING_TO_SAMPLE:
#                     continue

#                 self.bs.batch_status[i] = BatchSamplerStatus.SAMPLING

#                 prop = [random.random() for _ in self.bs.explorers]
#                 prop = [int(p / sum(prop) * self.bs.BATCH_SIZE) for p in prop]
#                 prop[-1] = self.bs.BATCH_SIZE - sum(prop[:-1])
                
#                 for idx, (p, explorer) in enumerate(zip(prop, self.bs.explorers)):
#                     if p > 0: explorer.instruction_queue.put((i, p))
#                     print(f'[INFO] Assigning Explorer {idx} to sample {p} transitions for batch {i}')
                    
class BatchSamplerAssignmentThread(th.Thread):
    def __init__(self, bs):
        super(BatchSamplerAssignmentThread, self).__init__()
        self.bs = bs
        
    def run(self):
        while True:
            if not self.bs.int_queue.empty():
                print(f'[INFO] Aggregator received new batch')
                batch_id, batch = self.bs.int_queue.get()
              
                self.bs.batches.append(batch)
                self.bs.n_batches_sampling -= 1

#                 self.bs.batches[batch_id] = batch
#                 self.bs.batch_status[batch_id] = BatchSamplerStatus.READY
#                 print(f'[INFO] Batch {batch_id} ready.')
                
            time.sleep(1)
                    
class BatchSamplerProcessingProcess(mp.Process):
    def __init__(self, batch_size, in_queue, out_queue, prefetch):
        super(BatchSamplerProcessingProcess, self).__init__()
        self.batch_size = batch_size
        
        self.transitions = []
        
        self.in_queue = in_queue
        self.out_queue = out_queue
        
        self.prefetch = prefetch
        self.n_transitions = 0
        
    def merge(self, item, prefetch=False):
        if len(item) == 1:
            return item[0]
    
        if isinstance(item[0], tuple):
            return tuple(self.merge(i, prefetch=prefetch) for i in zip(*item))
        elif isinstance(item[0], torch.Tensor):
            if prefetch:
                return torch.cat(item, axis=0).pin_memory()
            else:
                return torch.cat(item, axis=0)
        elif isinstance(item[0], list):
            return list(reduce(add, item))
        else:
            return list(item)
        
    def run(self):
        while True:
            if not self.in_queue.empty():
                batch_id, data = self.in_queue.get()
                
                self.transitions.append(data)
                self.n_transitions += len(data[-1])
                
                print(f'[INFO] Batch {batch_id} received new data {self.n_transitions}')

                if self.n_transitions == self.batch_size:
                    batch = self.merge(self.transitions, prefetch=self.prefetch)
                    self.out_queue.put((batch_id, batch))

                    self.transitions = []
                    self.n_transitions = 0
        
class BatchSampler:
    PREFETCH_BUFFER = 10
    BATCH_SIZE = 64
    
    def __init__(self, explorers, exp2bs_queues, prefetch=True):
        self.explorers = explorers
        self.current = 0
        self.exp2bs_queues = exp2bs_queues
        self.int_queue = Queue()
        
        self.prefetch = prefetch
        
#         self.batches = [[] for _ in range(self.PREFETCH_BUFFER)]
        self.batches = collections.deque()
        self.batch_status = [
            BatchSamplerStatus.WAITING_TO_SAMPLE
            for _ in range(self.PREFETCH_BUFFER)
        ]
        self.n_batches_sampling = 0
        self.current_sampler = 0
        
        # Thread to monitor sampling progress
        for batch_id in range(self.PREFETCH_BUFFER):
            BatchSamplerProcessingProcess(
                self.BATCH_SIZE, self.exp2bs_queues[batch_id], self.int_queue, self.prefetch
            ).start()
        
        BatchSamplerAssignmentThread(self).start()
        
        # Thread to initiate sampling
        BatchSamplerInitiatorThread(self).start()
        
    def __iter__(self):
        self.current = 0
        return self
    
    def __next__(self):
        while len(self.batches) == 0:
            pass
        
        return self.batches.popleft()
#         while self.batch_status[self.current] != BatchSamplerStatus.READY:
#             pass
        
#         batch = self.batches[self.current]
        
#         self.batches[self.current] = []
#         self.batch_status[self.current] = BatchSamplerStatus.WAITING_TO_SAMPLE
        
#         self.current += 1
#         self.current %= self.PREFETCH_BUFFER
        
#         return batch

In [3]:
class ExplorerIOThread(th.Thread):
    def __init__(self, process):
        super(ExplorerIOThread, self).__init__()
        
        self.process = process
        
    def merge(self, item, prefetch=False):
        if len(item) == 1:
            return item[0]
    
        if isinstance(item[0], tuple):
            return tuple(self.merge(i, prefetch=prefetch) for i in zip(*item))
        elif isinstance(item[0], torch.Tensor):
            if prefetch:
                return torch.cat(item, axis=0).pin_memory()
            else:
                return torch.cat(item, axis=0)
        elif isinstance(item[0], list):
            return list(reduce(add, item))
        else:
            return list(item)
    
    def run(self):
        while True:
            if not self.process.instruction_queue.empty():
                ins = self.process.instruction_queue.get()
                
                if ins is None:
                    print(f'[INFO] Process {self.process.idx} - thread: Joining...')        
                    self.process.stop = True

                    break
                    
                batch_id, n_items = ins
                    
                sampled = random.choices(self.process.transitions, k=n_items)
                self.process.result_queue[batch_id].put((batch_id, self.merge(sampled)))
                print(f'[INFO] Process {self.process.idx}: Sampled {n_items} transitions')        

class ExplorerProcess(mp.Process):
    def __init__(self, instruction_queue, result_queue, idx, env_name, **kwargs):
        super(ExplorerProcess, self).__init__()
        
        self.instruction_queue = instruction_queue
        self.result_queue = result_queue
        
        self.env = gym.make('StopSkip-v1', **kwargs)
        self.env.reset()
        
        self.idx = idx
        self.transitions = []
        
        self.stop = False
        
    def run(self):
        print(f'[INFO] Process {self.idx}: {id(self.transitions)}')
        ExplorerIOThread(self).start()
              
        j = 1
        
        s = self.env._obs()
        while not self.stop:
            action = self.env.action_space.sample()
            ns, o, r, d = self.env.step(action)
            
            self.transitions.append((s, ns, o, r, d))
                                              
            s = ns
            
            if len(self.transitions) > 100:
                self.transitions.pop(0)
            
            if j % 10 == 0:
                print(f'[INFO] Process {self.idx}: {len(self.transitions)} transitions')
                j = 0
            
            j += 1
            
        print(f'[INFO] Process {self.idx}: Joining...')

In [4]:
NUMBER_OF_PROCESSES = 1
PREFETCH_BUFFER = 20
explorers = []
bs2exp_queues = []
exp2bs_queue = [Queue() for _ in range(PREFETCH_BUFFER)]

for i in range(0, NUMBER_OF_PROCESSES):
    bs2exp_queues.append(Queue())

    p = ExplorerProcess(
        instruction_queue=bs2exp_queues[-1],
        result_queue=exp2bs_queue,
        idx=i,
        env_name='StopSkip-v1'
    )
    
    p.start()
    explorers.append(p)

[INFO] Process 0: 139849177679232
[INFO] Process 0: 10 transitions
[INFO] Process 0: 20 transitions
[INFO] Process 0: 30 transitions
[INFO] Process 0: 40 transitions
[INFO] Process 0: 50 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: 60 transitions
[INFO] Process 0: 70 transitions
[INFO] Process 0: 80 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: 90 transitions
[INFO] Process 0: 100 transitions
[INFO] Process 0: 100 transitions
[INFO] Process 0: 100 transitions
[INFO] Process 0: 100 transitions
[INFO] Process 0: Sampled 64 transitions
[INFO] Process 0: Sampled 64 transitions


In [5]:
bs = BatchSampler(explorers, exp2bs_queue)

[INFO] Batch 0 received new data 64
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 0
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 1
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 2
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 3
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 4
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 5
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 6
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 7
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 8
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 9
[INFO] Batch 1 received new data 64
[INFO] Batch 2 received new data 64
[INFO] Batch 3 received new data 64
[INFO] Batch 4 received new data 64
[INFO] Batch 5 received new data 64
[INFO] Batch 6 received new data 64
[INFO] Batch 7 received new data 64
[INFO] Batch 8 received new data 64
[INFO] Batch 9 received ne

In [6]:
_ = next(bs)

In [7]:
t = time.time()

n_batches = 0
while time.time() - t < 30:
    _ = next(bs)
    time.sleep(0.1)
    
    n_batches += 1
    
t = time.time() - t

[INFO] Assigning Explorer 0 to sample 64 transitions for batch 1
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 2
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 3
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 4
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 5
[INFO] Aggregator received new batch
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 6
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 7
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 8
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 9
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 0
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 1
[INFO] Aggregator received new batch
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 2
[INFO] Aggregator received new batch
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 3
[INFO] Aggregator received new batch
[INFO] 

In [8]:
bpm = n_batches / t
print(f'{bpm:.2f} BPS')

[INFO] Aggregator received new batch
0.97 BPM
[INFO] Aggregator received new batch
[INFO] Aggregator received new batch
[INFO] Aggregator received new batch
[INFO] Assigning Explorer 0 to sample 64 transitions for batch 1
[INFO] Aggregator received new batch
[INFO] Aggregator received new batch
[INFO] Aggregator received new batch
