## Using Spark to score batches of images against a model deployed in Azure ML

In [0]:
# Cap each Arrow record batch at 10 rows so that every pandas DataFrame handed
# to the scoring function (and therefore every API call) holds at most 10 images.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", 10)

In [0]:
# Define the pandas function that calls the scoring REST endpoint
import ast
import base64
import json
import time
from io import BytesIO
from typing import Iterator

import pandas as pd
import requests
from pyspark.sql.functions import col, pandas_udf
from pyspark.sql.types import StringType, StructType, StructField, FloatType
# Text encoding used when turning base64 bytes into a string for the JSON payload.
ENCODING = 'utf-8'
# Number of retries when the API does not answer with HTTP 200 (e.g. throttling).
max_retries = 10

# REST endpoint settings used by the scoring function.
# NOTE(review): prefer loading the key from a secret store instead of pasting it here.
endpoint = ''  # Replace this with the API URI for the web service
api_key = ''  # Replace this with the API key for the web service
headers = {
    'Content-Type': 'application/json',
    'Authorization': 'Bearer ' + api_key,
}

# Output schema of the scoring function: four float box coordinates plus a string label.
schema = StructType(
    [StructField(name, FloatType()) for name in ("startX", "startY", "endX", "endY")]
    + [StructField("label", StringType())]
)
def score_images(iterator: Iterator[pd.DataFrame]) -> Iterator[pd.DataFrame]:
  """Score batches of images against the REST endpoint (for use with ``mapInPandas``).

  For every pandas DataFrame in ``iterator`` the binary ``content`` column is
  base64-encoded, sent to ``endpoint`` in a single POST request, and the
  predictions in the response are turned into one output row each.

  Relies on the module-level names ``endpoint``, ``headers``, ``max_retries``
  and ``ENCODING``.

  Args:
    iterator: Iterator of pandas DataFrames, each with a ``content`` column
      holding the raw image bytes.

  Yields:
    A DataFrame with columns ``startX``, ``startY``, ``endX``, ``endY`` and
    ``label``. When the endpoint still does not answer with HTTP 200 after
    ``max_retries`` retries, a single sentinel row (999 for every coordinate,
    ``"error"`` as label) is yielded for that batch. Batches with no rows or
    with an empty prediction list yield nothing.
  """
  for df in iterator:
    if df.empty:
      # Nothing to score; avoid a pointless API call.
      continue

    # Each image is sent as the text "b'<base64>'"; this wrapping is kept
    # unchanged because the scoring service's input format is not visible here.
    base64_string_list = [
      "b'{0}'".format(base64.b64encode(content).decode(ENCODING))
      for content in df['content']
    ]
    body = json.dumps({"data": base64_string_list})

    # Send the same authenticated headers on the first attempt and on retries.
    response = requests.post(url=endpoint, data=body, headers=headers)
    retries = 0
    while response.status_code != 200 and retries < max_retries:
      time.sleep(20)  # back off before retrying (e.g. when throttled)
      retries += 1
      response = requests.post(url=endpoint, data=body, headers=headers)

    if response.status_code != 200:
      # Retries exhausted: emit a sentinel row so failures can be found later.
      yield pd.DataFrame({"startX": [999.0], "startY": [999.0],
                          "endX": [999.0], "endY": [999.0],
                          "label": ["error"]})
      continue

    # The response body is parsed as a Python literal: a sequence of
    # (startX, startY, endX, endY, label) entries.
    pred_results = ast.literal_eval(response.text)
    if not pred_results:
      # No predictions: unpacking zip(*[]) would raise, so emit no rows.
      continue
    startX, startY, endX, endY, label = zip(*pred_results)
    yield pd.DataFrame({"startX": startX, "startY": startY,
                        "endX": endX, "endY": endY, "label": label})

      

# Read the images in binary format
df = (
    spark.read.format("binaryFile")
    .option("pathGlobFilter", "*.jpg")
    .option("recursiveFileLookup", "true")
    .load("/mnt/aml-mlflow-object-detection/dataset/images")
)

# Run the scoring function over every batch of images with mapInPandas.
images_score = df.mapInPandas(score_images, schema=schema)

In [0]:
# Preview the scored bounding boxes and labels.
display(images_score)

startX,startY,endX,endY,label
0.12996936,0.12943318,0.82521915,0.7459129,airplane
0.22097903,0.16232991,0.83722997,0.83005524,airplane
0.16602665,0.0765706,0.84160006,0.8030168,airplane
0.14294857,0.20163828,0.84063244,0.79306126,airplane
0.13298759,0.084847,0.8681232,0.86245644,airplane
0.14573008,0.2037864,0.8601751,0.8209756,airplane
0.12395829,0.17410663,0.87076133,0.7992951,airplane
0.13916388,0.1896129,0.8462864,0.799316,airplane
0.15537542,0.19219005,0.8708534,0.8285449,airplane
0.124807596,0.1330117,0.87569255,0.8016786,airplane


In [0]:
# Persist the scored results as Parquet, replacing any previous output.
(
    images_score.write
    .format("parquet")
    .mode("overwrite")
    .save("/mnt/aml-mlflow-object-detection/output/result_aks")
)

In [0]:
# Load the persisted scoring results back from Parquet.
df = (
    spark.read
    .format("parquet")
    .load("/mnt/aml-mlflow-object-detection/output/result_aks")
)

In [0]:
# Expose the results to SQL cells as a temporary view.
# `registerTempTable` is deprecated; `createOrReplaceTempView` is its replacement.
df.createOrReplaceTempView("result_tbl")

In [0]:
%sql SELECT * FROM result_tbl WHERE label = 'error'

In [0]:
%sql SELECT COUNT(*) FROM result_tbl