# Image iterators using the multiprocessing module
*Jonas Teuwen*

In this notebook we show how to combine the multiprocessing module in Python with iterators.

This can be useful in deep learning scripts when you, for instance, want to write an iterator which extracts and augments patches from your image on-the-fly to feed to your convolutional neural network (CNN).

We simulate the following problem:
- You have a list of images, in this case represented by filenames in `self.images`.
- We continuously load images and put them into a queue, ready to be read.
- `next` should return the next item; in this case we only output the filenames.

Each process will load one of the images, but as the processes are now separated we need some way to track which images have already been passed and which ones have not. To implement such a counter we can use shared memory in Python. The multiprocessing `Value` wraps a `ctypes` object allocated in shared memory, which allows multiple processes to read from and write to the same variable. There is one problem: race conditions. If two processes read the counter before either of them has updated it, both will output the same image. To prevent this we use a `Lock` to lock the counter whenever we read from it or write to it.
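The locked counter described above can be sketched in isolation as follows. This is a minimal illustration rather than the class used later in this notebook: the names `claim_indices` and `run_workers` are hypothetical, the code is written for Python 3, and it assumes the `fork` start method is available (Linux or macOS).

```python
import multiprocessing as mp


def claim_indices(counter, lock, n_items, out_queue):
    """Worker: repeatedly claim the next unclaimed index until none remain."""
    while True:
        # Read and increment under the lock so that no two workers
        # can obtain the same index.
        with lock:
            idx = counter.value
            counter.value += 1
        if idx >= n_items:
            break
        out_queue.put(idx)


def run_workers(n_items=20, n_workers=4):
    # Assumption: the 'fork' start method exists on this platform.
    ctx = mp.get_context("fork")
    counter = ctx.Value("i", 0)  # shared integer, initially 0
    lock = ctx.Lock()
    out_queue = ctx.Queue()
    workers = [
        ctx.Process(target=claim_indices,
                    args=(counter, lock, n_items, out_queue))
        for _ in range(n_workers)
    ]
    for w in workers:
        w.start()
    # Drain the queue before joining so that no worker blocks on a full pipe.
    claimed = [out_queue.get() for _ in range(n_items)]
    for w in workers:
        w.join()
    return claimed
```

Calling `run_workers(20, 4)` should return every index from 0 to 19 exactly once, although not necessarily in increasing order, since the workers run concurrently.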

Note that your output will not necessarily be in the same order as your input, as some of the processes might complete faster than others. For CNNs trained with stochastic gradient descent this is no problem, as shuffling improves the result. Check out the Stochastic Gradient Descent Tricks (Sec. 4) at http://cilvr.cs.nyu.edu/diglib/lsml/bottou-sgd-tricks-2012.pdf.

For more information check https://docs.python.org/2/library/multiprocessing.html.

In [1]:
from __future__ import division  # Must be the first statement; makes integer division return a float as in Python 3.
import math
import numpy as np
from multiprocessing import Process, Value, Lock, Queue

Basic imports. In CNNs you often wish to pass batches of images through the network. We do this by first computing the number of batches we will need. The final batch does not need to have the same size as the others; neural network frameworks often make it possible to skip the "missing" values, which only requires some small changes.
Be careful about shared variables between processes!
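The batch arithmetic can be checked separately from the multiprocessing code. The sketch below uses a hypothetical helper, `batch_slices`, which is not part of the class in this notebook; it only shows how rounding up the number of batches leaves a shorter final batch.

```python
import math


def batch_slices(n_items, batch_size):
    """Return (start, stop) index pairs that cover n_items in order."""
    # Round up so that a partial final batch is still counted.
    n_batches = int(math.ceil(n_items / float(batch_size)))
    return [
        (b * batch_size, min((b + 1) * batch_size, n_items))
        for b in range(n_batches)
    ]
```

For 20 items and a batch size of 3 this gives 7 batches, the last of which covers only indices 18 and 19.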

In [2]:
class Iter(object):
    def __init__(self,batch_size):
        self.images = ['image_{}.jpg'.format(i) for i in range(20)]
        self.n_images = len(self.images)
        self.batch_size = batch_size
        self.n_batches = int(math.ceil(self.n_images / batch_size))
        self.cur_batch = 0
    
        self.reset()
            
    def reset(self):
        self.shared_val = Value('i', 0)
        self.lock = Lock()
        self.q = Queue(maxsize=2)
        self.procs = [Process(target=self.write, args=(self.shared_val, self.lock)) for i in range(4)]
        for p in self.procs:
            p.daemon = True
            p.start()
    
    def get_batch(self, offset):
        i = 0
        batch = []
        while i < self.batch_size:
            if offset + i < self.n_images:
                batch.append(self.images[offset + i])
                i += 1
            else:
                break
        return batch
            
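    def write(self, v, lock):
        while True:
            # Claim the next offset under the lock so no two workers get the same batch.
            with lock:
                idx = v.value
                v.value += self.batch_size

            if idx < self.n_images:
                batch = self.get_batch(idx)
                self.q.put(batch, block=True)
            else:
                # All batches have been claimed; stop instead of spinning forever.
                break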
    
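    def next(self):
        self.cur_batch += 1
        if self.cur_batch > self.n_batches:
            raise StopIteration()
        if self.q.empty():
            # This is called when the queue is empty, for instance, when waiting for data.
            pass
        # q.get() blocks until a worker has put a batch on the queue.
        return self.q.get()

    __next__ = next  # Alias for the Python 3 iterator protocol.

    def __iter__(self):
        return self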

Let us test. Remember there are 20 images, so a batch size of 2 gives 10 batches.

In [None]:
s = Iter(2)
vals = []
for i in range(10):
    vals.append(s.next())
vals

Or a batch size of 3, which does not divide the total number of images. This gives 7 batches, of which we read the first 6.

In [None]:
s = Iter(3)
vals = []
for i in range(6):
    vals.append(s.next())
vals

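What happens when we request more batches than there are? With a batch size of 4 there are only 5 batches, so the sixth call to `next` raises `StopIteration`.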

In [None]:
s = Iter(4)
vals = []
for i in range(6):
    vals.append(s.next())
vals

Good!