In [1]:
import pyspark
spark = pyspark.sql.SparkSession.builder.appName("MyApp") \
    .config("spark.jars.packages", "io.delta:delta-core_2.11:0.6.0") \
    .getOrCreate()
sc = spark.sparkContext
sc.addPyFile("/usr/lib/spark/jars/delta-core_2.11-0.6.0.jar")
from delta.tables import *
# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "40")
sc = spark.sparkContext

VBox()

Starting Spark application


ID,YARN Application ID,Kind,State,Spark UI,Driver log,Current session?
9,application_1591129707896_0010,pyspark,idle,Link,Link,✔


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

SparkSession available as 'spark'.


FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [2]:
import os
import shutil
import tarfile
import time
import zipfile

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve

import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API

from pyspark.sql.functions import col, pandas_udf, PandasUDFType, monotonically_increasing_id
from pyspark.sql.types import ArrayType, FloatType

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [3]:
from determined.experimental import Determined
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor

checkpoint = Determined(master="ec2-54-185-44-13.us-west-2.compute.amazonaws.com:8080").get_experiment(1).top_checkpoint()
model = checkpoint.load(path="/home/.config/ckpt", map_location=torch.device('cpu'))

b_state_dict = sc.broadcast(model.state_dict())


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [4]:
def get_model_for_eval():
    """Gets the broadcasted model."""
    model = fasterrcnn_resnet50_fpn(pretrained=False, pretrained_backbone=False)
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 20)
    model.load_state_dict(b_state_dict.value)
    model.eval()
    return model

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [5]:
images_df = spark.read.format("delta").option("versionAsOf", 0).load("s3://david-voc-delta/val/")

images_df = images_df.select(col('key'), col('image')).withColumn("id", monotonically_increasing_id())


batch_size = 10
num_images = images_df.count()

num_splits = num_images / batch_size
images_df = images_df.withColumn("split", col("id") % num_splits)

images_df.show(5)


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+--------------------+--------------------+----------+-----+
|                 key|               image|        id|split|
+--------------------+--------------------+----------+-----+
|v1/JPEGImages/200...|[FF D8 FF E0 00 1...|8589934592| 92.0|
|v1/JPEGImages/200...|[FF D8 FF E0 00 1...|8589934593| 93.0|
|v1/JPEGImages/200...|[FF D8 FF E0 00 1...|8589934594| 94.0|
|v1/JPEGImages/200...|[FF D8 FF E0 00 1...|8589934595| 95.0|
|v1/JPEGImages/200...|[FF D8 FF E0 00 1...|8589934596| 96.0|
+--------------------+--------------------+----------+-----+
only showing top 5 rows

In [6]:
import io
from PIL import Image
from torchvision.transforms import Compose, ToTensor

class VOCDataset(Dataset):
  def __init__(self, images, ids):
    self.raw_images = images
    self.ids = ids
    
    transforms = []
    transforms.append(ToTensor())
    self.transform = Compose(transforms)
    
  def __len__(self):
    return len(self.raw_images)

  def __getitem__(self, index):
    raw_image = self.raw_images[index]
    image = Image.open(io.BytesIO(raw_image)).convert('RGB')
    image = self.transform(image)
    id = self.ids[index]
    return image, id

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [7]:
import torch
import numpy as np
import os

def collate_fn(batch):
    return tuple(zip(*batch))


def predict_batch(pdf):
    os.environ['LRU_CACHE_CAPACITY'] = '1'
    raw_images = pdf.image
    img_ids = pdf.id
    ds = VOCDataset(list(raw_images), list(img_ids))
    loader = torch.utils.data.DataLoader(ds, batch_size=2, num_workers=8, collate_fn=collate_fn)
    model = get_model_for_eval()
    boxes = []
    labels = []
    scores = []
    ids = []
    model.eval()
    with torch.no_grad():
        for images, img_ids in loader:
            predictions = list(model(list(images)))
            for prediction in predictions:
                boxes.append(prediction['boxes'].cpu().numpy().tolist())
                labels.append(prediction['labels'].cpu().numpy().tolist())
                scores.append(prediction['scores'].cpu().numpy().tolist())
            for id in img_ids:
                ids.append(ids)
    return pd.DataFrame({"boxes": boxes, "labels": labels, "scores": scores, "id": ids})

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [8]:
from pyspark.sql.types import StructType, StructField, IntegerType, BinaryType, StringType

schema = StructType([StructField("boxes", ArrayType(ArrayType(FloatType()))),
                     StructField("labels", ArrayType(IntegerType())),
                     StructField("scores", ArrayType(FloatType())),
                     StructField("id", IntegerType())],
                   )
predict_udf = pandas_udf(schema, PandasUDFType.GROUPED_MAP)(predict_batch)


VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [9]:
predictions_df = images_df.groupby("split").apply(predict_udf)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [10]:
(
  predictions_df
  .write
  .format("delta")
  .mode("overwrite")
  .option("compression", "gzip")
  .save("s3://david-voc-predictions/group_preds")
)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

KeyboardInterrupt: 

In [None]:
path ="s3://david-voc-predictions/group_preds"
df = spark.read.format("delta").load(path)
df.show(50)