In [1]:
import pyspark

# Build (or attach to) the Spark session, requesting the Delta Lake package.
spark = (
    pyspark.sql.SparkSession.builder
    .appName("MyApp")
    .config("spark.jars.packages", "io.delta:delta-core_2.11:0.6.0")
    .getOrCreate()
)
sc = spark.sparkContext

# Register the Delta jar as a Python dependency. The `delta.tables` import is
# deliberately placed after this call rather than at the top of the cell
# (NOTE(review): presumably the module is only importable once the jar has
# been added -- confirm before moving the import).
sc.addPyFile("/usr/lib/spark/jars/delta-core_2.11-0.6.0.jar")
from delta.tables import *  # wildcard kept: later cells may rely on these names

# Enable Arrow support, with at most 128 records per Arrow batch.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "128")

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
14,application_1590608562184_0015,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [2]:
import os
import shutil
import tarfile
import time
import zipfile

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve

import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API

from pyspark.sql.functions import col, pandas_udf, PandasUDFType, monotonically_increasing_id
from pyspark.sql.types import ArrayType, FloatType

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
from determined.experimental import Determined
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

# Fetch the top checkpoint of experiment 18 from the Determined master and
# load it with its tensors mapped to the CPU.
determined_client = Determined(master="http://44.232.66.32:8080/")
checkpoint = determined_client.get_experiment(18).top_checkpoint()
cpu_device = torch.device('cpu')
model = checkpoint.load(path="/home/.config/ckpt", map_location=cpu_device)

# Broadcast only the weights; get_model_for_eval() reads them back through
# b_state_dict.value.
trained_weights = model.state_dict()
b_state_dict = sc.broadcast(trained_weights)


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…



In [4]:
def get_model_for_eval(num_classes=20):
    """Build a Faster R-CNN and load the broadcast weights into it.

    A fresh ``fasterrcnn_resnet50_fpn`` is created without any pretrained
    weights, its box predictor is replaced by a ``FastRCNNPredictor`` with
    ``num_classes`` outputs, and the state dict held in the ``b_state_dict``
    broadcast variable is loaded into it.

    Args:
        num_classes: Number of outputs of the replacement box predictor.
            Defaults to 20, the value previously hard-coded here; it has to
            match the head of the checkpoint stored in ``b_state_dict``,
            otherwise ``load_state_dict`` will reject the weights.

    Returns:
        The model, switched to evaluation mode.
    """
    # No pretrained download: every weight comes from the broadcast state dict.
    model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)

    # Swap the default detection head for one sized to `num_classes`.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)

    model.load_state_dict(b_state_dict.value)
    model.eval()
    return model

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
# Read version 0 of the Delta table, keep only the image column and tag each
# row with a generated id.
images_df = (
    spark.read.format("delta")
    .option("versionAsOf", 0)
    .load("s3://david-voc-delta/val/")
    .select('image')
    .withColumn("id", monotonically_increasing_id())
)
images_df.show(5)
images_df.count()

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+----------+
|               image|        id|
+--------------------+----------+
|[FF D8 FF E0 00 1...|8589934592|
|[FF D8 FF E0 00 1...|8589934593|
|[FF D8 FF E0 00 1...|8589934594|
|[FF D8 FF E0 00 1...|8589934595|
|[FF D8 FF E0 00 1...|8589934596|
+--------------------+----------+
only showing top 5 rows

9244

In [6]:
import io
from PIL import Image
from torchvision.transforms import Compose, ToTensor

class VOCDataset(Dataset):
    """Dataset over a sequence of encoded image byte strings.

    Each item is decoded with PIL, converted to RGB and passed through a
    ``Compose`` pipeline containing a single ``ToTensor`` step.
    """

    def __init__(self, images):
        """Keep a reference to ``images`` and build the transform pipeline."""
        self.raw_images = images
        self.transform = Compose([ToTensor()])

    def __len__(self):
        """Return the number of stored images."""
        return len(self.raw_images)

    def __getitem__(self, index):
        """Decode the image at ``index`` and return its transformed form."""
        encoded = io.BytesIO(self.raw_images[index])
        rgb_image = Image.open(encoded).convert('RGB')
        return self.transform(rgb_image)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
import torch
from pyspark.sql.types import StructType, StructField, FloatType, ArrayType, IntegerType, StringType
import numpy as np
import os

def collate_fn(batch):
    """Return the samples of ``batch`` as a new list, in their original order."""
    return [sample for sample in batch]

def predict_batch(raw_images):
    """Run the detection model over a batch of encoded images.

    Args:
        raw_images: Iterable of encoded image bytes; it is materialised into
            a list and wrapped in a ``VOCDataset``.

    Returns:
        A ``pd.Series`` with one entry per image, in input order. Each entry
        is ``[boxes, labels, scores]``: the flattened box coordinates, the
        labels cast to float32, and the scores, each as a list of floats.
    """
    # NOTE(review): presumably caps a PyTorch-internal cache to limit memory
    # growth on the workers -- confirm the intent of this variable.
    os.environ['LRU_CACHE_CAPACITY'] = '1'

    dataset = VOCDataset(list(raw_images))
    batches = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        num_workers=8,
        shuffle=False,  # keep the output aligned with the input rows
        collate_fn=collate_fn,
    )
    detector = get_model_for_eval()

    rows = []
    with torch.no_grad():
        for image_batch in batches:
            for detection in detector(image_batch):
                boxes = detection['boxes'].cpu().numpy().flatten().tolist()
                # Labels are cast to float32 so all three lists hold floats.
                labels = detection['labels'].cpu().numpy().astype(np.float32).tolist()
                scores = detection['scores'].cpu().numpy().tolist()
                rows.append([boxes, labels, scores])
    return pd.Series(rows)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [8]:
# Scalar pandas UDF around predict_batch; every row is an array of float arrays.
prediction_type = ArrayType(ArrayType(FloatType()))
predict_udf = pandas_udf(prediction_type, PandasUDFType.SCALAR)(predict_batch)



VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
# Spread the images over 64 partitions, then attach the UDF output to each id.
input_df = images_df.repartition(64)
prediction_column = predict_udf(col('image')).alias("prediction")
predictions_df = input_df.select(col('id'), prediction_column)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [10]:
# Overwrite the predictions Delta table with the freshly computed rows.
predictions_writer = (
    predictions_df.write
    .format("delta")
    .mode("overwrite")
    .option("compression", "gzip")
)
predictions_writer.save("s3://david-voc-predictions/preds")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [11]:
# Read the stored predictions back and preview the first 50 rows.
path = "s3://david-voc-predictions/preds"
df = spark.read.load(path, format="delta")
df.show(50)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+------------+--------------------+
|          id|          prediction|
+------------+--------------------+
|  8589935229|[[12.843936, 18.8...|
|  8589934646|[[0.95023984, 183...|
|  8589935294|[[1.938038, 9.295...|
|  8589935647|[[430.304, 128.06...|
|  8589934656|[[136.34753, 12.2...|
|  8589935676|[[52.576393, 104....|
|  8589935000|[[180.81972, 106....|
|  8589935649|[[0.0, 29.932024,...|
|  8589935045|[[415.22015, 20.9...|
|  8589935711|[[63.349686, 72.7...|
|  8589935682|[[38.227417, 183....|
|  8589935289|[[43.355602, 22.4...|
|  8589935197|[[1.3848653, 0.0,...|
|  8589935118|[[8.791582, 95.85...|
|  8589934793|[[18.786577, 21.8...|
|  8589935082|[[213.70203, 100....|
|  8589935517|[[420.54333, 222....|
|  8589935383|[[338.31122, 158....|
|137438954297|[[119.47388, 332....|
|137438953676|[[3.0561523, 78.6...|
|137438954435|[[221.24753, 157....|
|137438953979|[[62.66492, 152.8...|
|137438954203|[[55.591507, 95.9...|
|137438954237|[[12.8235, 293.44...|
|137438954408|[[40.804195, 5