In [16]:
import pyspark

# Build (or reuse) the Spark session and request the Delta Lake package.
spark = (
    pyspark.sql.SparkSession.builder.appName("MyApp")
    .config("spark.jars.packages", "io.delta:delta-core_2.11:0.6.0")
    .getOrCreate()
)
sc = spark.sparkContext

# Register the Delta jar as a Python dependency before importing from it.
# NOTE(review): the import below presumably relies on this call making the
# `delta` module importable -- keep the two statements in this order.
sc.addPyFile("/usr/lib/spark/jars/delta-core_2.11-0.6.0.jar")
from delta.tables import *

# Enable Arrow support.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "10")

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [17]:
import os
import shutil
import tarfile
import time
import zipfile

try:
    from urllib.request import urlretrieve
except ImportError:
    from urllib import urlretrieve

import pandas as pd

import torch
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API

from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [18]:
from determined.experimental import Determined

def get_model_for_eval(experiment_id, master="http://44.232.66.32:8080/", checkpoint_path="/home/.config/ckpt"):
    """Load an experiment's checkpoint from a Determined master onto the CPU.

    Args:
        experiment_id: ID of the Determined experiment to look up.
        master: URL of the Determined master. Defaults to the address this
            notebook was written against.
        checkpoint_path: Local directory handed to ``checkpoint.load`` as
            ``path``.

    Returns:
        Whatever ``checkpoint.load`` returns for the experiment's
        ``top_checkpoint()``, loaded with ``map_location`` set to the CPU.
    """
    experiment = Determined(master=master).get_experiment(experiment_id)
    checkpoint = experiment.top_checkpoint()
    return checkpoint.load(path=checkpoint_path, map_location=torch.device('cpu'))

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [19]:
import boto3
import os

def load_voc_image_names(bucket_name, prefix):
    """Return the S3 keys of all ``.jpg`` objects under ``<prefix>/JPEGImages``.

    Args:
        bucket_name: Name of the S3 bucket to list.
        prefix: Key prefix that contains the ``JPEGImages`` folder.

    Returns:
        list of str: Matching object keys, in the order the bucket listing
        yields them.
    """
    # S3 keys always use "/" as separator, so build the prefix with posixpath
    # instead of the platform-dependent os.path.
    import posixpath

    bucket = boto3.resource('s3').Bucket(bucket_name)
    image_prefix = posixpath.join(prefix, "JPEGImages")
    return [
        obj.key
        for obj in bucket.objects.filter(Prefix=image_prefix)
        if obj.key.endswith('.jpg')
    ]
    

# Keys of the .jpg objects found under the 'v1' prefix of the VOC bucket.
keys = load_voc_image_names('david-voc-data', 'v1')

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [20]:
import io

def load_val_list(bucket_name, prefix):
    """Read ``<prefix>/ImageSets/Main/val.txt`` from S3 and return its entries.

    Args:
        bucket_name: Name of the S3 bucket holding the dataset.
        prefix: Key prefix that contains the ``ImageSets`` folder.

    Returns:
        list of str: One entry per non-blank line of the file, with
        surrounding whitespace (including any ``\\r``) removed.
    """
    # S3 keys always use "/" as separator, independent of the local platform.
    import posixpath

    key = posixpath.join(prefix, "ImageSets", "Main", "val.txt")
    body = boto3.resource('s3').Bucket(bucket_name).Object(key).get()['Body'].read()
    # splitlines() + strip() avoids the empty trailing entry and CR residue
    # that a plain split('\n') leaves behind.
    stripped = (line.strip() for line in body.decode('utf8').splitlines())
    return [line for line in stripped if line]

# Entries of val.txt, then the image keys whose base name (text before the
# first '.') appears in that list, capped at the first 400 matches.
val_list = load_val_list('david-voc-data', 'v1')
# A set keeps the membership test O(1) per key instead of scanning the list.
val_ids = set(val_list)
files = [k for k in keys if os.path.basename(k).split('.')[0] in val_ids][:400]

# Assign each file a split id in [0, num_splits): contiguous runs of files
# share an id, so grouping by it yields roughly equal-sized groups.
num_splits = 100
splits = [i * num_splits // len(files) for i in range(len(files))]

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [21]:
# One row per image: its split id and its S3 key. Uses the `spark` session
# created at the top of the notebook rather than an implicitly provided
# `sqlContext`, and materialises the zip iterator into a list of tuples.
rows_df = spark.createDataFrame(list(zip(splits, files)), ['split', 'path'])
rows_df.show(5)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

+-----+--------------------+
|split|                path|
+-----+--------------------+
|    0|v1/JPEGImages/200...|
|    0|v1/JPEGImages/200...|
|    0|v1/JPEGImages/200...|
|    0|v1/JPEGImages/200...|
|    1|v1/JPEGImages/200...|
+-----+--------------------+
only showing top 5 rows

In [22]:
import io
from PIL import Image
from torchvision.transforms import Compose, ToTensor

class VOCDataset(Dataset):
    """Dataset that fetches images from an S3 bucket on demand.

    Each item is an ``(image, path)`` pair: the object at ``path`` is
    downloaded, decoded as RGB, passed through ``self.transform`` (a
    ``Compose`` holding a single ``ToTensor``), and returned with its key.
    """

    def __init__(self, paths, bucket):
        """
        Args:
            paths: Iterable of S3 object keys (list, tuple, pandas Series, ...).
            bucket: Name of the S3 bucket that holds those keys.
        """
        # Copy into a list so __getitem__ is always positional. A pandas
        # Series indexes by label, which fails when its index is not 0..n-1.
        self.paths = list(paths)
        self.bucket = bucket
        self.transform = Compose([ToTensor()])

    def __len__(self):
        """Number of image keys in the dataset."""
        return len(self.paths)

    def __getitem__(self, index):
        """Download and decode the image at position ``index``.

        Returns:
            tuple: ``(image, path)`` where ``image`` is the transformed RGB
            image and ``path`` is its S3 key.
        """
        path = self.paths[index]
        # NOTE(review): a new client is created for every item; caching one
        # per worker process would be faster -- confirm before changing.
        s3 = boto3.client('s3')
        response = s3.get_object(Bucket=self.bucket, Key=path)
        raw = response["Body"].read()
        image = Image.open(io.BytesIO(raw)).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, path

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [23]:
import torch
from pyspark.ml.linalg import Vectors
from pyspark.ml.linalg import VectorUDT
from pyspark.sql.types import StructType, StructField, FloatType, ArrayType, IntegerType, StringType

def collate_fn(batch):
    """Transpose a batch of samples into a tuple of per-field tuples.

    For example ``[(img1, p1), (img2, p2)]`` becomes
    ``((img1, img2), (p1, p2))``; an empty batch yields ``()``.
    """
    per_field = zip(*batch)
    return tuple(per_field)

# Output schema of the prediction UDF: per image, a list of boxes (each a
# list of floats), integer labels, float scores, and the image's key.
schema = (
    StructType()
    .add("boxes", ArrayType(ArrayType(FloatType())))
    .add("labels", ArrayType(IntegerType()))
    .add("scores", ArrayType(FloatType()))
    .add("path", StringType())
)

def predict_batch(pdf):
    """Run the detection model on every image listed in ``pdf``.

    This function is wrapped as a GROUPED_MAP pandas UDF further down, so its
    signature is deliberately kept to the single ``pdf`` argument.

    Args:
        pdf: pandas DataFrame with a ``path`` column of S3 object keys.

    Returns:
        pandas.DataFrame: One row per image with the columns ``boxes``,
        ``labels``, ``scores`` (the model's outputs as nested Python lists)
        and ``path`` (the image's key).
    """
    # Use a plain list so the dataset indexes by position, whatever index
    # the incoming frame carries.
    image_paths = list(pdf.path)
    columns = {"boxes": [], "labels": [], "scores": [], "path": []}
    if not image_paths:
        # Nothing to predict: skip the checkpoint download entirely.
        return pd.DataFrame(columns)

    dataset = VOCDataset(image_paths, 'david-voc-data')
    loader = torch.utils.data.DataLoader(dataset, batch_size=2, num_workers=8, collate_fn=collate_fn)
    model = get_model_for_eval(18)
    model.eval()
    with torch.no_grad():
        for batch_images, batch_paths in loader:
            # The model is called with a list of images and yields one
            # prediction dict per image.
            for prediction in model(list(batch_images)):
                for field in ("boxes", "labels", "scores"):
                    columns[field].append(prediction[field].cpu().numpy().tolist())
            columns["path"].extend(batch_paths)
    return pd.DataFrame(columns)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [24]:
# Local smoke test: run the prediction function on the first two images
# directly (outside Spark) and print the labels predicted for the second one.
pdf = pd.DataFrame({"path": files[:2]})
preds = predict_batch(pdf)
print(preds['labels'][1])

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

[18, 14, 5, 12, 2, 9, 3]
	nonzero(Tensor input, *, Tensor out)
Consider using one of the following signatures instead:
	nonzero(Tensor input, *, bool as_tuple)

In [25]:
# Wrap predict_batch as a grouped-map pandas UDF whose output rows follow `schema`.
predict_udf = pandas_udf(schema, PandasUDFType.GROUPED_MAP)(predict_batch)



VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [28]:
# Keep the DataFrame itself in `test`: `.show()` only prints and returns
# None, so chaining it into the assignment would leave `test` unusable.
test = rows_df.groupby("split").apply(predict_udf)
test.show(40)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

An error was encountered:
"cannot resolve '`blah`' given input columns: [split, path];;\n'FlatMapGroupsInPandas ['blah], predict_batch(split#81L, path#82), [boxes#118, labels#119, scores#120, path#121]\n+- 'Project ['blah, split#81L, path#82]\n   +- LogicalRDD [split#81L, path#82], false\n"
Traceback (most recent call last):
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/group.py", line 275, in apply
    jdf = self._jgd.flatMapGroupsInPandas(udf_column._jc.expr())
  File "/usr/lib/spark/python/lib/py4j-0.10.7-src.zip/py4j/java_gateway.py", line 1257, in __call__
    answer, self.gateway_client, self.target_id, self.name)
  File "/usr/lib/spark/python/lib/pyspark.zip/pyspark/sql/utils.py", line 69, in deco
    raise AnalysisException(s.split(': ', 1)[1], stackTrace)
pyspark.sql.utils.AnalysisException: "cannot resolve '`blah`' given input columns: [split, path];;\n'FlatMapGroupsInPandas ['blah], predict_batch(split#81L, path#82), [boxes#118, labels#119, scores#120, path#121]\

In [26]:
# Apply the prediction UDF to each group of rows sharing a `split` value.
predictions_df = rows_df.groupby("split").apply(predict_udf)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

In [27]:
# Write the predictions in Delta format to S3, overwriting whatever is
# already stored at the target path.
# NOTE(review): the recorded output for this cell shows the save aborting
# with a SparkException; the root cause is not visible in the truncated trace.
(
  predictions_df
  .write
  .format("delta")
  .mode("overwrite")
  .option("compression", "gzip")
  .save("s3://david-voc-predictions/preds")
)

VBox()

FloatProgress(value=0.0, bar_style='info', description='Progress:', layout=Layout(height='25px', width='50%'),…

An error was encountered:
An error occurred while calling o325.save.
: org.apache.spark.SparkException: Job aborted.
	at org.apache.spark.sql.execution.datasources.FileFormatWriter$.write(FileFormatWriter.scala:198)
	at org.apache.spark.sql.delta.files.TransactionalWrite$$anonfun$writeFiles$1.apply(TransactionalWrite.scala:152)
	at org.apache.spark.sql.delta.files.TransactionalWrite$$anonfun$writeFiles$1.apply(TransactionalWrite.scala:134)
	at org.apache.spark.sql.execution.SQLExecution$.org$apache$spark$sql$execution$SQLExecution$$executeQuery$1(SQLExecution.scala:83)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExecutionId$1$$anonfun$apply$1.apply(SQLExecution.scala:94)
	at org.apache.spark.sql.execution.QueryExecutionMetrics$.withMetrics(QueryExecutionMetrics.scala:141)
	at org.apache.spark.sql.execution.SQLExecution$.org$apache$spark$sql$execution$SQLExecution$$withMetrics(SQLExecution.scala:178)
	at org.apache.spark.sql.execution.SQLExecution$$anonfun$withNewExe