In [1]:
from queue import Queue, Empty
from threading import Lock, Thread
from abc import ABC, abstractmethod

import logging

In [2]:
class DataSupplier(ABC):
    
    def __init__(self, num_enqueueing_threads=2):
        self._num_batches_remaining = 10
        
        self._batch_queue = Queue(maxsize=10) # FIXME - maxsize should be a param
        
        # Have we run out of data to enqueue?
        # If this variable is True, AND the Queue is empty, then we're done.
        self._data_exhausted = False
        
        self._next_batch_ready = False
        
        # This lock is used to ensure that only one consumer is dequeue-ing at any one time.
        self._iteration_lock = Lock()
        
        # This lock is used to ensure that only one thread is being assigned it batch params at any one time.
        self._batch_params_lock = Lock()
        self._queueing_threads = [Thread(target=self._enqueue, args=(), daemon=True)
                                  for _ in range(num_enqueueing_threads)]
        for thread in self._queueing_threads:
            thread.start()
            
    def __iter__(self):
        while True:
            try:
                yield self._next()
            except StopIteration:
                pass
        
    def _next(self):
        with self._iteration_lock: # Ensure that we are the only thread running this function
            if self._data_exhausted is False:
                # If we get here, there is still data to enqueue.
                # We can safely wait for the queue to be ready, since we are the only thing dequeueing.
                return self._batch_queue.get(block=True, timeout=None)
            else:
                # There is no data left to enqueue.
                try:
                    # Short pause to ensure that final batch has actually been enqueued.
                    return self._batch_queue.get(block=True, timeout=0.5)
                except Empty:
                    # Finished the iterator
                    raise StopIteration

    @abstractmethod
    def _get_next_batch_params(self):
        raise NotImplementedError
        
    # FIXME - need some way to reset the params iteration
    
    @abstractmethod
    def _get_batch_from_params(self, params):
        raise NotImplementedError        
            
    def _enqueue(self):
        """
        Function that gets run its own thread.
        Continues running until we run out of batches.
        """
        logging.info('foo')
        while self._data_exhausted is False:
            with self._batch_params_lock:
                batch_params = self._get_next_batch_params()
                print('Got params: {}'.format(batch_params))
            if batch_params is None:
                self._data_exhausted = True
            else:
                this_batch = self._get_batch_from_params(batch_params)
                self._batch_queue.put(this_batch, block=True, timeout=None)


In [3]:
class TestDataSupplier(DataSupplier):
    
    def __init__(self, num_batches=100, **kwargs):
        self._next_batch_idx = 0
        self._num_batches = 100
        super(TestDataSupplier, self).__init__(**kwargs)
    
    def _get_next_batch_params(self):
        if self._next_batch_idx < self._num_batches:
            self._next_batch_idx += 1
            return (self._next_batch_idx-1)
        else:
            return None
    
    def _get_batch_from_params(self, params):
        return params

In [4]:
tds = TestDataSupplier()

Got params: 0
Got params: 1
Got params: 2
Got params: 3
Got params: 4
Got params: 5
Got params: 6
Got params: 7
Got params: 8
Got params: 9
Got params: 10
Got params: 11


In [None]:
for batch in tds:
    print(batch)

0
1
2
3
4
5
6
7
8
9
Got params: 1210
11

Got params: 1312

Got params: 1413

Got params: 15
Got params: 16
Got params: 17
Got params: 18
Got params: 1914
15
16
17
18

19Got params: 20

20Got params: 21

21
Got params: 22
Got params: 23
Got params: 24
Got params: 25
Got params: 26
Got params: 27
Got params: 28
Got params: 29
Got params: 30
Got params: 31
Got params: 32
Got params: 3322
Got params: 34

23
24
25
26
27
28
29
30
31
33
Got params: 3532
34

35Got params: 36

36Got params: 37

37
Got params: 38
38Got params: 39

39
Got params: 40
Got params: 4140

41
Got params: 42
42Got params: 43

43
Got params: 44
44Got params: 45

Got params: 4645

46Got params: 47

47
Got params: 48
Got params: 49
Got params: 5048
49

50Got params: 51

51
Got params: 52
52Got params: 53

53
Got params: 54
54Got params: 55

55Got params: 56

56
Got params: 57
57Got params: 58

58
Got params: 59
59Got params: 60

60
Got params: 61
61Got params: 62

62
Got params: 63
63Got params: 64

64
Got params: 65
65Got