<img src="resources/inference.png" align='center' width=600 />

In [None]:
import pyspark

# Create (or reuse) the Spark session. `spark.jars.packages` names the Delta Lake
# artifact that Spark should make available to the session.
spark = (
    pyspark.sql.SparkSession.builder
    .appName("MyApp")
    .config("spark.jars.packages", "io.delta:delta-core_2.11:0.6.0")
    .getOrCreate()
)
sc = spark.sparkContext

# NOTE(review): the Delta jar is added to the Python path first, presumably because
# the `delta` Python package ships inside it -- which is why this import cannot
# live with the other imports at the top of the notebook. Confirm before moving it.
sc.addPyFile("/usr/lib/spark/jars/delta-core_2.11-0.6.0.jar")
from delta.tables import *  # noqa: E402,F401,F403

# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# Limit the number of records in each Arrow batch.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "128")

In [2]:
import os
import shutil
import tarfile
import time
import zipfile

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve

import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API

from pyspark.sql.functions import col, pandas_udf, PandasUDFType, monotonically_increasing_id
from pyspark.sql.types import ArrayType, FloatType
import determined as det

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

# Loading a Checkpoint from Determined

<img src="resources/checkpoint.png" align='center' width=500 />


In [3]:
from determined.experimental import Determined
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Ask the Determined master for experiment 2's top checkpoint.
checkpoint = (
    Determined(master="ec2-54-185-44-13.us-west-2.compute.amazonaws.com:8080")
    .get_experiment(2)
    .top_checkpoint()
)

# Load the checkpoint into the given local path, mapping tensors onto the CPU.
model = checkpoint.load(
    path="/home/.config/ckpt",
    map_location=torch.device('cpu'),
)

# Broadcast only the weights; get_model_for_eval() reads them back through
# `b_state_dict.value` when it rebuilds the network.
b_state_dict = sc.broadcast(model.state_dict())

def get_model_for_eval():
    """Rebuild the Faster R-CNN detector and load the broadcast weights into it.

    The network is created without any pretrained weights, its box predictor is
    replaced with a fresh ``FastRCNNPredictor``, and the state dict held in the
    ``b_state_dict`` broadcast variable is loaded on top.

    Returns:
        The detector, switched to evaluation mode.
    """
    detector = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    roi_heads = detector.roi_heads
    predictor_in_features = roi_heads.box_predictor.cls_score.in_features
    # NOTE(review): 20 is the class count handed to the predictor head; it
    # presumably has to match what the checkpoint was trained with -- confirm.
    roi_heads.box_predictor = FastRCNNPredictor(predictor_in_features, 20)
    detector.load_state_dict(b_state_dict.value)
    # nn.Module.eval() returns the module itself.
    return detector.eval()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

# Load Inference Dataset

<img src="resources/load_data.png" align='center' width=500 />



In [5]:
# Read version 0 of the validation Delta table and keep only the two columns
# used downstream: the encoded image and its key.
images_df = (
    spark.read.format("delta")
    .option("versionAsOf", 0)
    .load("s3://david-voc-delta/val/")
    .select(col('image'), col('key'))
)
images_df.show(5)
images_df.count()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+--------------------+
|               image|                 key|
+--------------------+--------------------+
|[FF D8 FF E0 00 1...|v1/JPEGImages/200...|
|[FF D8 FF E0 00 1...|v1/JPEGImages/200...|
|[FF D8 FF E0 00 1...|v1/JPEGImages/200...|
|[FF D8 FF E0 00 1...|v1/JPEGImages/200...|
|[FF D8 FF E0 00 1...|v1/JPEGImages/200...|
+--------------------+--------------------+
only showing top 5 rows

1000

# Defining the Inference Process

We need to tell Spark how to load the data and perform inference.  For actual inference, we use a Pandas UDF to efficiently batch the data and minimize the time spent loading the model weights.

In [6]:
import io
from PIL import Image
from torchvision.transforms import Compose, ToTensor

class VOCDataset(Dataset):
    """Dataset over a sequence of encoded image byte strings.

    Each item is decoded with PIL, converted to RGB, and passed through the
    ``ToTensor`` transform.
    """

    def __init__(self, images):
        """Store the encoded images and build the transform pipeline.

        Args:
            images: indexable, sized collection of encoded image bytes.
        """
        self.raw_images = images
        # The pipeline currently holds ToTensor only.
        self.transform = Compose([ToTensor()])

    def __len__(self):
        """Return the number of stored images."""
        return len(self.raw_images)

    def __getitem__(self, index):
        """Decode the image at ``index`` and return its transformed form."""
        encoded = self.raw_images[index]
        rgb_image = Image.open(io.BytesIO(encoded)).convert('RGB')
        return self.transform(rgb_image)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
import torch
from pyspark.sql.types import StructType, StructField, FloatType, ArrayType, IntegerType, StringType
import numpy as np
import os

def collate_fn(batch):
    """Return the samples of ``batch`` as a new plain list, without stacking them."""
    return [sample for sample in batch]

def predict_batch(raw_images):
    """Run the detector over a batch of encoded images.

    Args:
        raw_images: iterable of encoded image bytes; it is materialised into a
            list and wrapped in a ``VOCDataset``.

    Returns:
        pandas.Series with one entry per model prediction, in loader order
        (shuffling is off). Each entry is ``[boxes, labels, scores]``: the boxes
        flattened into a single list, the labels cast to float32, and the scores.
    """
    # NOTE(review): presumably a workaround for PyTorch CPU memory growth -- confirm.
    os.environ['LRU_CACHE_CAPACITY'] = '1'

    dataset = VOCDataset(list(raw_images))
    batches = torch.utils.data.DataLoader(
        dataset,
        batch_size=2,
        num_workers=8,
        shuffle=False,
        collate_fn=collate_fn,
    )
    detector = get_model_for_eval()

    rows = []
    with torch.no_grad():
        for image_batch in batches:
            for detection in detector(image_batch):
                boxes = detection['boxes'].cpu().numpy().flatten().tolist()
                # Labels are cast to float32 so all three lists hold floats.
                labels = detection['labels'].cpu().numpy().astype(np.float32).tolist()
                scores = detection['scores'].cpu().numpy().tolist()
                rows.append([boxes, labels, scores])
    return pd.Series(rows)

# Wrap predict_batch as a scalar pandas UDF whose result for each row is an
# array of float arrays.
predict_udf = pandas_udf(
    predict_batch,
    returnType=ArrayType(ArrayType(FloatType())),
    functionType=PandasUDFType.SCALAR,
)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

# Make Predictions

<img src="resources/preds.png" align='center' width=300 />

In [11]:
# Spread the images over 40 partitions, then attach the UDF output to each key.
input_df = images_df.repartition(40)
predictions_df = input_df.select(
    col('key'),
    predict_udf(col('image')).alias("prediction"),
)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [12]:
# Save the predictions as a Delta table, replacing whatever is already stored
# at the target path.
predictions_df.write.format("delta").mode("overwrite").option(
    "compression", "gzip"
).save("s3://david-voc-predictions/preds")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

# Inspect the Predictions

We can then inspect the predictions on the fly.

In [13]:
# Read the saved predictions back from the Delta table and preview 20 rows.
path = "s3://david-voc-predictions/preds"
df = (
    spark.read
    .format("delta")
    .load(path)
)
df.show(20)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+--------------------+
|                 key|          prediction|
+--------------------+--------------------+
|v1/JPEGImages/201...|[[19.41631, 86.18...|
|v1/JPEGImages/201...|[[32.854958, 25.9...|
|v1/JPEGImages/201...|[[396.55264, 160....|
|v1/JPEGImages/201...|[[36.59609, 36.16...|
|v1/JPEGImages/201...|[[167.91513, 23.1...|
|v1/JPEGImages/201...|[[97.14731, 22.16...|
|v1/JPEGImages/201...|[[308.70505, 266....|
|v1/JPEGImages/200...|[[130.96548, 29.4...|
|v1/JPEGImages/200...|[[204.10103, 67.2...|
|v1/JPEGImages/200...|[[106.18939, 116....|
|v1/JPEGImages/200...|[[41.099865, 133....|
|v1/JPEGImages/200...|[[89.92501, 40.92...|
|v1/JPEGImages/200...|[[54.83442, 0.0, ...|
|v1/JPEGImages/200...|[[145.85356, 0.0,...|
|v1/JPEGImages/200...|[[213.14555, 117....|
|v1/JPEGImages/201...|[[258.44434, 203....|
|v1/JPEGImages/201...|[[238.99425, 99.7...|
|v1/JPEGImages/201...|[[299.58047, 127....|
|v1/JPEGImages/200...|[[51.827614, 218....|
|v1/JPEGImages/200...|[[321.1475