In [0]:
# Enable Arrow support.
# Turn on Arrow-based data exchange for this Spark session.
spark.conf.set("spark.sql.execution.arrow.enabled", "true")
# Cap each Arrow record batch at 64 records.
# NOTE(review): presumably this also bounds how many rows a pandas UDF receives per call -- confirm against the Spark configuration docs.
spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "64")
import os
import shutil
import uuid
from typing import Iterator, Tuple
 
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torch import Tensor
 
import torch
from torch.utils.data import Dataset
from torchvision import datasets, models, transforms
from torchvision.datasets.folder import default_loader  # private API
 
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.types import ArrayType, FloatType
import requests
from io import BytesIO
from PIL import Image

# Choose the compute device once: CUDA when this process can see a GPU, the CPU otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [0]:
class ImageDataset(Dataset):
    """Map-style dataset that downloads each image over HTTP when it is indexed.

    Args:
        paths: Indexable, sized collection of image URLs; ``paths[i]`` is
            handed to ``requests.get``.
        transform: Optional callable applied to the decoded RGB PIL image.
            When ``None`` (the default) the PIL image is returned as-is.
        timeout: Seconds ``requests`` waits on the server before raising,
            so a single unresponsive host cannot hang the worker forever.

    Raises (from ``__getitem__``):
        requests.exceptions.RequestException: on connection problems,
            timeouts, or a non-success HTTP status.
    """

    def __init__(self, paths, transform=None, timeout=10):
        self.paths = paths
        self.transform = transform
        self.timeout = timeout

    def __getitem__(self, item):
        image_path = self.paths[item]

        response = requests.get(image_path, timeout=self.timeout)
        # Surface HTTP errors directly instead of letting PIL choke on an
        # error page that is not image data.
        response.raise_for_status()
        img = Image.open(BytesIO(response.content)).convert('RGB')
        # ``transform`` is optional; calling ``None`` would raise TypeError.
        if self.transform is not None:
            img = self.transform(img)
        return img

    def __len__(self):
        return len(self.paths)
            

def get_model_for_eval():
    """Return a freshly built model carrying the broadcast weights, in eval mode."""
    net = load_model()
    # Pull the state dict out of the broadcast variable and copy it into the new network.
    broadcast_weights = bc_model_state.value
    net.load_state_dict(broadcast_weights)
    # Inference only: switch the module out of training mode.
    net.eval()
    return net


def predict_batch(paths):
    """Run the broadcast model over a batch of image URLs.

    The function is written to be wrapped as a pandas UDF (the wrapping is
    done with ``pandas_udf(...)`` elsewhere in this notebook).

    Args:
        paths: Iterable of image URLs (a pandas Series when called as a UDF).

    Returns:
        pandas Series with one entry per input URL, in input order; each
        entry is a NumPy array holding the model's raw output for that image.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # NOTE(review): these look like the usual ImageNet channel mean/std --
        # confirm they match what the model was trained with.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])
    # A plain list is indexed by position, so the dataset does not depend on
    # whatever index labels the incoming Series happens to carry.
    images = ImageDataset(list(paths), transform=transform)
    loader = torch.utils.data.DataLoader(images, batch_size=32, num_workers=8)
    model = get_model_for_eval()
    # The model may have been placed on a GPU when it was built; inputs must
    # be moved to the same device or the forward pass fails.
    model_device = next(model.parameters()).device
    all_predictions = []
    with torch.no_grad():
        for batch in loader:
            outputs = model(batch.to(model_device))
            # Iterating the 2-D result yields one row (one image) at a time.
            all_predictions.extend(outputs.cpu().numpy())
    return pd.Series(all_predictions)

def load_model(model_path=None, num_classes=13):
    """Build a ResNet-50 whose final layer is resized to ``num_classes`` outputs.

    Args:
        model_path: Optional path to a saved state dict. When given (and
            truthy), its weights are loaded into the model.
        num_classes: Output width of the replacement fully connected layer.
            Defaults to 13, the value that used to be hard-coded here.

    Returns:
        The model, moved to the module-level ``device``.
    """
    model = models.resnet50(pretrained=False)
    # Swap the classification head for one with the requested number of outputs.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    if model_path:
        # Deserialize onto the CPU first; the whole model is moved to
        # ``device`` below, so this works with or without a GPU.
        state_dict = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state_dict)
    return model.to(device)
    
    
# Load the trained weights once on the driver and broadcast them to the executors.
# The state dict is taken from a CPU copy of the model: tensors that live on a GPU
# can only be deserialized where CUDA is available, so broadcasting them as-is
# would break on CPU-only executors.
bc_model_state = sc.broadcast(load_model('/dbfs/FileStore/model.h5').cpu().state_dict())
# Wrap predict_batch as a scalar pandas UDF whose result type is array<float>.
predict_udf = pandas_udf(ArrayType(FloatType()), PandasUDFType.SCALAR)(predict_batch)

In [0]:
from pyspark import SparkConf,SparkContext
from pyspark.streaming import StreamingContext
from pyspark.sql import Row,SQLContext
from pyspark.sql.functions import monotonically_increasing_id
import sys
import requests
import copy
# Streaming parameters, named here so they are easy to find and tune.
BATCH_INTERVAL_SECONDS = 8
STREAM_HOST = "localhost"
STREAM_PORT = 9017

# create spark configuration
conf = SparkConf()
conf.setAppName("TwitterStreamApp")
# reuse the running spark context if there is one, otherwise create it with the above configuration
sc = SparkContext.getOrCreate(conf=conf)
sc.setLogLevel("WARN")
# create the Streaming Context from the above spark context with a batch interval of BATCH_INTERVAL_SECONDS seconds
ssc = StreamingContext(sc, BATCH_INTERVAL_SECONDS)
# checkpointing (to allow RDD recovery) is currently left disabled:
# ssc.checkpoint("checkpoint_TwitterApp")
# read text data from the socket at STREAM_HOST:STREAM_PORT
dataStream = ssc.socketTextStream(STREAM_HOST, STREAM_PORT)

In [0]:
def process_rdd(time, rdd):
    """Score one micro-batch of image URLs and save a sample of the predictions.

    Args:
        time: Batch time handed over by the stream; used in the log line and
            in the output file name.
        rdd: RDD for this batch; every element becomes one row of a
            single-column ``url`` DataFrame.

    Returns:
        None. Ordinary failures are reported on stdout and swallowed so that
        one bad batch does not stop the stream; ``KeyboardInterrupt`` and
        ``SystemExit`` are allowed to propagate.
    """
    # time rdd arrived
    print("----------- %s -----------" % str(time))
    try:
        if rdd.isEmpty():
            return
        df = rdd.map(lambda x: (x, )).toDF(['url'])
        predictions_df = df.select(col('url'), predict_udf(col('url')).alias("prediction"))
        # Only the first 4 rows of the batch are brought back and persisted.
        res = predictions_df.take(4)
        # Each row pairs a string with a list of floats. dtype=object is
        # spelled out because recent NumPy versions raise instead of
        # inferring an object array from such ragged input.
        np.save(f"/dbfs/FileStore/animals/{str(time)}.npy", np.array(res, dtype=object))
    except Exception as exc:
        # Report the exception type and its message, not just the class.
        print("Error: %s: %s" % (type(exc).__name__, exc))

In [0]:
# register process_rdd as the handler for every batch of the stream
dataStream.foreachRDD(process_rdd)
# dataStream.pprint()

# start the streaming computation
ssc.start()


# wait for the streaming to finish; no timeout is passed here, so the wait is unbounded
ssc.awaitTermination()