In [1]:
import sys
import pyspark
from utility import NodeLookup

In [2]:
# Settings for this notebook

MODEL_URL = 'http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz'
model_dir = './imagenet'

IMAGES_INDEX_URL = 'http://127.0.0.1/imagenet_data/urls/imagenet_fall11_urls.tgz'
images_read_limit = 1000  # Increase this to read more images

# Number of images per batch.
# 1 batch corresponds to 1 RDD row.
image_batch_size = 20

num_top_predictions = 5

In [3]:
import numpy as np
import tensorflow as tf
import os
from tensorflow.python.platform import gfile
import os.path
import re
import sys
import tarfile
from subprocess import Popen, PIPE, STDOUT

# Download the model

We download a pre-trained model or find a pre-downloaded one.

In [4]:
def maybe_download_and_extract():
    """Download and extract model tar file."""
    from six.moves import urllib
    dest_directory = model_dir
    if not os.path.exists(dest_directory):
        os.makedirs(dest_directory)
    filename = MODEL_URL.split('/')[-1]
    filepath = os.path.join(dest_directory, filename)
    if not os.path.exists(filepath):
        filepath2, _ = urllib.request.urlretrieve(MODEL_URL, filepath)
        print("filepath2", filepath2)
        statinfo = os.stat(filepath)
        print('Succesfully downloaded', filename, statinfo.st_size, 'bytes.')
        tarfile.open(filepath, 'r:gz').extractall(dest_directory)
    else:
        print('Data already downloaded:', filepath, os.stat(filepath))

maybe_download_and_extract()

Data already downloaded: ./imagenet/inception-2015-12-05.tgz os.stat_result(st_mode=33188, st_ino=18481302, st_dev=2066, st_nlink=1, st_uid=0, st_gid=0, st_size=88931400, st_atime=1540708338, st_mtime=1536174560, st_ctime=1536174560)


In [5]:
sc

# Load model data

Load the model data, and broadcast it for use on Spark workers.


In [6]:
model_path = os.path.join(model_dir, 'classify_image_graph_def.pb')
with gfile.FastGFile(model_path, 'rb') as f:
    model_data = f.read()

In [7]:
model_data_bc = sc.broadcast(model_data)

# Node lookups

Concepts (as represented by synsets, or groups of synomymous terms) have integer node IDs. This code loads a mapping from node IDs to human-readable strings for each synset.


In [9]:
node_lookup = NodeLookup().node_lookup
# Broadcast node lookup table to use on Spark workers
node_lookup_bc = sc.broadcast(node_lookup)

# Read index of image files

We load an index of image file URLs. We will parallelize this index. Spark workers will process batches of URLs in parallel by downloading the images and running TensorFlow inference on the images.


In [10]:
# Helper methods for reading images

def run(cmd):
    p = Popen(cmd, shell=True, stdin=PIPE, stdout=PIPE, stderr=STDOUT, close_fds=True)
    return p.stdout.read()

def read_file_index():
    from six.moves import urllib
    content = urllib.request.urlopen(IMAGES_INDEX_URL)
    data = content.read(images_read_limit)
    tmpfile = "./imagenet.tgz"
    with open(tmpfile, 'wb') as f:
        f.write(data)
    run("tar -xOzf %s > ./imagenet.txt" % tmpfile)
    with open("./imagenet.txt", 'r') as f:
        lines = [l.split() for l in f]
        input_data = [tuple(elts) for elts in lines if len(elts) == 2]
        return [input_data[i:i+image_batch_size] for i in range(0,len(input_data), image_batch_size)]

In [11]:
batched_data = read_file_index()
print("There are %d batches" % len(batched_data))

There are 1 batches


# Distributed image processing: TensorFlow on Spark

This section contains the main processing code. We first define methods which will be run as tasks on Spark workers. We then use Spark to parallelize the execution of these methods on the image URL dataset.


In [12]:
def run_inference_on_image(sess, img_id, img_url, node_lookup):
    """Download an image, and run inference on it.

    Args:
    image: Image file URL

    Returns:
    (image ID, image URL, scores),
    where scores is a list of (human-readable node names, score) pairs
    """
    from six.moves import urllib
    from urllib.request import HTTPError
    try:
        image_data = urllib.request.urlopen(img_url, timeout=1.0).read()
    except:
        return (img_id, img_url, None)
    # Some useful tensors:
    # 'softmax:0': A tensor containing the normalized prediction across
    #   1000 labels.
    # 'pool_3:0': A tensor containing the next-to-last layer containing 2048
    #   float description of the image.
    # 'DecodeJpeg/contents:0': A tensor containing a string providing JPEG
    #   encoding of the image.
    # Runs the softmax tensor by feeding the image_data as input to the graph.
    softmax_tensor = sess.graph.get_tensor_by_name('softmax:0')
    try:
        predictions = sess.run(softmax_tensor,{
                                                'DecodeJpeg/contents:0': image_data
                                                })
    except:
        # Handle problems with malformed JPEG files
        return (img_id, img_url, None)
    predictions = np.squeeze(predictions)
    top_k = predictions.argsort()[-num_top_predictions:][::-1]
    scores = []
    for node_id in top_k:
        if node_id not in node_lookup:
            human_string = ''
        else:
            human_string = node_lookup[node_id]
    score = predictions[node_id]
    scores.append((human_string, score))
    return (img_id, img_url, scores)

def apply_inference_on_batch(batch):
    """Apply inference to a batch of images.

    We do not explicitly tell TensorFlow to use a GPU.
    It is able to choose between CPU and GPU based on its guess of which will be faster.
    """
    with tf.Graph().as_default() as g:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(model_data_bc.value)
        tf.import_graph_def(graph_def, name='')
        with tf.Session() as sess:
            labeled = [run_inference_on_image(sess, img_id, img_url, node_lookup_bc.value) for (img_id, img_url) in batch]
            return [tup for tup in labeled if tup[2] is not None]

In [13]:
urls = sc.parallelize(batched_data)
labeled_images = urls.flatMap(apply_inference_on_batch)

# Examine results

When we call collect(), we will finally run the Spark job to process our images.


In [14]:
local_labeled_images = labeled_images.collect()

KeyboardInterrupt: 

In [None]:
local_labeled_images

In [None]:
labeled_images.collect()