In [1]:
# !sudo add-apt-repository ppa:openjdk-r/ppa
!sudo apt-get install openjdk-11-jdk
# To Install Oracke JDK varsion 8
# !sudo add-apt-repository ppa:webupd8team/java
# !sudo apt-get install oracle-java8-installer

Reading package lists... Done
Building dependency tree       
Reading state information... Done
openjdk-11-jdk is already the newest version (11.0.14.1+1-0ubuntu1~18.04).
0 upgraded, 0 newly installed, 0 to remove and 39 not upgraded.


In [2]:
# !wget -q https://downloads.apache.org/spark/spark-3.1.1/spark-3.1.1-bin-hadoop3.2.tgz
# !tar xvzf spark-3.1.1-bin-hadoop3.2.tgz
!pip install pyspark
!pip install -q findspark
!pip install pyarrow
try:
  # %tensorflow_version only exists in Colab.
  !pip install  tf-estimator-nightly==2.8.0.dev2021122109
except Exception:
  pass



In [3]:
import os

# Tell PySpark which JVM to launch: the JDK installed by the first cell
# (this path assumes an amd64 Ubuntu image).
os.environ.update(JAVA_HOME="/usr/lib/jvm/java-11-openjdk-amd64")
# SPARK_HOME is intentionally left unset (was: "/content/spark-3.2.1");
# the pip-installed pyspark package is used instead.


In [4]:
# `from __future__` must be the first statement of a cell/module. On Python 3
# these features are already the default, so the line only exists for compatibility.
from __future__ import absolute_import, division, print_function, unicode_literals

import io
import pathlib

# findspark.init() is meant to run before pyspark is first imported, so the
# Spark installation it locates is the one that gets imported.
import findspark
findspark.init()

import matplotlib.pyplot as plt
import pandas as pd
import tensorflow as tf
import torch
from PIL import Image
from pyspark.sql import SparkSession
from pyspark.sql.functions import PandasUDFType, col, pandas_udf, regexp_extract
from tensorflow.keras.applications.imagenet_utils import decode_predictions
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms

# Local Spark session using every available core.
spark = SparkSession.builder.master("local[*]").getOrCreate()

# Data formatting

In [5]:
!wget https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz

--2022-04-16 04:07:26--  https://www.cs.toronto.edu/%7Ekriz/cifar-10-python.tar.gz
Resolving www.cs.toronto.edu (www.cs.toronto.edu)... 128.100.3.30
Connecting to www.cs.toronto.edu (www.cs.toronto.edu)|128.100.3.30|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 170498071 (163M) [application/x-gzip]
Saving to: ‘cifar-10-python.tar.gz’


2022-04-16 04:07:38 (14.9 MB/s) - ‘cifar-10-python.tar.gz’ saved [170498071/170498071]



In [6]:
!tar xvf cifar-10-python.tar.gz

cifar-10-batches-py/
cifar-10-batches-py/data_batch_4
cifar-10-batches-py/readme.html
cifar-10-batches-py/test_batch
cifar-10-batches-py/data_batch_3
cifar-10-batches-py/batches.meta
cifar-10-batches-py/data_batch_2
cifar-10-batches-py/data_batch_5
cifar-10-batches-py/data_batch_1


In [7]:
!ls cifar-10-batches-py/

batches.meta  data_batch_2  data_batch_4  readme.html
data_batch_1  data_batch_3  data_batch_5  test_batch


In [8]:
def load_cifar10_batch(cifar10_dataset_folder_path, batch_id):
  """Load one CIFAR-10 training batch file, ``data_batch_<batch_id>``.

  Args:
    cifar10_dataset_folder_path: directory holding the extracted batch files.
    batch_id: batch number appended to ``data_batch_``.

  Returns:
    (features, labels): ``features`` is the pickled ``data`` array reshaped to
    (N, 3, 32, 32) and transposed to channel-last (N, 32, 32, 3); ``labels`` is
    the pickled ``labels`` entry, returned unchanged.

  Note: ``pickle.load`` can execute arbitrary code; only use this on trusted files.
  """
  import pickle

  batch_file = f"{cifar10_dataset_folder_path}/data_batch_{batch_id}"
  with open(batch_file, mode='rb') as fh:
    # 'latin1' decoding is required (presumably because the files were written by Python 2).
    batch = pickle.load(fh, encoding='latin1')

  raw = batch['data']
  channel_first = raw.reshape((len(raw), 3, 32, 32))
  return channel_first.transpose(0, 2, 3, 1), batch['labels']

def load_label_names():
  """Return a fresh list of the ten CIFAR-10 class names, ordered by integer label 0-9."""
  return 'airplane automobile bird cat deer dog frog horse ship truck'.split()

In [9]:
import numpy as np

# Load the five CIFAR-10 training batches and stack them into single arrays.
# Collecting the batches and concatenating once avoids re-copying the growing
# array on every iteration. It also binds features_acc / labels_acc no matter
# how many batches are loaded; the old `if batch_id-1` branch only bound them
# from the second batch onwards.
feature_batches = []
label_batches = []
for batch_id in range(1, 6):
  features, labels = load_cifar10_batch('./cifar-10-batches-py', batch_id)
  # Labels arrive as a flat list; store them as an (N, 1) integer column.
  labels = np.asarray(labels).reshape(-1, 1)
  feature_batches.append(features)
  label_batches.append(labels)

features_acc = np.concatenate(feature_batches, axis=0)
labels_acc = np.concatenate(label_batches, axis=0)
# Keep the earlier names bound to the same arrays so other cells can still use them.
features_p, labels_p = features_acc, labels_acc

In [10]:
# Human-readable class names; position i holds the name for integer label i.
label_names = load_label_names()

In [12]:
from PIL import Image as im
import os
import shutil
from tqdm.notebook import tqdm

def write_imagenet_format(features_acc, labels_acc, data_path):
  """Write images to disk in an ImageFolder-style layout: ``data_path/<class>/<i>.jpg``.

  Args:
    features_acc: image array indexed along axis 0. Each item is passed to
      ``PIL.Image.fromarray(..., 'RGB')``, so it must be an (H, W, 3) pixel array.
    labels_acc: integer class ids, one per image. Both a flat list and an
      (N, 1) column are accepted.
    data_path: output root. It is created if missing, along with one
      sub-folder per class name.

  Raises:
    ValueError: if the number of labels differs from the number of images.
  """
  label_names = load_label_names()

  n_images = features_acc.shape[0]
  if len(labels_acc) != n_images:
    raise ValueError(f"got {n_images} images but {len(labels_acc)} labels")

  # One sub-folder per class, so the directory name doubles as the label.
  # makedirs also creates data_path itself and is a no-op if the folder exists.
  for label in label_names:
    os.makedirs(os.path.join(data_path, label), exist_ok=True)

  for i in tqdm(range(n_images)):
    # np.squeeze accepts both scalar labels and (1,)-shaped rows; int() makes
    # the list index an explicit Python int.
    label = int(np.squeeze(labels_acc[i]))
    image = im.fromarray(features_acc[i], 'RGB')
    # The file name is the image's index in features_acc.
    image.save(os.path.join(data_path, label_names[label], str(i) + '.jpg'))

In [13]:
# Write the 50,000 training images to c10_data/train_data/<class>/.
train_path = 'c10_data/train_data'
write_imagenet_format(features_acc, labels_acc, data_path=train_path)

  0%|          | 0/50000 [00:00<?, ?it/s]

In [14]:
import pickle

# The held-out split is a single pickle, 'test_batch', with the same layout as
# the training batches: 'data' rows are reshaped to channel-first (3, 32, 32)
# and then transposed to channel-last; 'labels' is used as-is.
# NOTE: pickle.load can execute arbitrary code, so only load trusted files.
with open('./cifar-10-batches-py/test_batch', mode='rb') as file:
    batch = pickle.load(file, encoding='latin1')

raw_test = batch['data']
test_features = raw_test.reshape((len(raw_test), 3, 32, 32)).transpose(0, 2, 3, 1)
test_labels = batch['labels']

In [15]:
# Same layout for the 10,000 held-out images, under c10_data/test_data/<class>/.
test_path = 'c10_data/test_data'
write_imagenet_format(test_features, test_labels, data_path=test_path)

  0%|          | 0/10000 [00:00<?, ?it/s]

# Read into Spark

In [16]:
import io

from tensorflow.keras.applications.imagenet_utils import decode_predictions
import pandas as pd
from pyspark.sql.functions import col, pandas_udf, PandasUDFType
from pyspark.sql.functions import col, pandas_udf, regexp_extract
import pyspark.sql.functions as sqlf
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import models, transforms
from PIL import Image

In [17]:
# Load every .jpg below the training root (including class sub-folders) as
# raw bytes, one row per file.
images = (
  spark.read.format("binaryFile")
  .option("recursiveFileLookup", "true")
  .option("pathGlobFilter", "*.jpg")
  .load('./c10_data/train_data')
)

In [18]:

def extract_label(path_col):
  """Extract label from file path using built-in SQL functions.

  The label is the name of the image's parent directory, matching the
  ``<root>/<class>/<file>.jpg`` layout written by ``write_imagenet_format``.
  The previous pattern ``./c10_data/train_data/([^/]+)`` had an unescaped
  ``.`` (any character) and only matched the training folder. This one works
  for any root, including the test split and absolute ``file:`` URIs.
  Paths with no directory component yield an empty string.
  """
  return regexp_extract(path_col, r"([^/]+)/[^/]+$", 1)

def extract_size(content):
  """Extract image size from its raw content.

  Returns the PIL ``(width, height)`` tuple for the encoded image bytes. The
  image is opened in a ``with`` block so it is closed as soon as the size has
  been read, instead of being left open.
  """
  with Image.open(io.BytesIO(content)) as image:
    return image.size

@pandas_udf("width: int, height: int")
def extract_size_udf(content_series):
  """Pandas UDF: map a Series of raw image bytes to a (width, height) struct.

  The output columns are named after the declared struct fields. They exist
  even when the incoming batch is empty, so the result always has exactly the
  two columns the schema expects.
  """
  sizes = content_series.apply(extract_size)
  return pd.DataFrame(list(sizes), columns=["width", "height"])

# One row per image: file metadata, class label parsed from the path, pixel
# size decoded from the bytes, and the raw bytes themselves for inference.
df = images.select(
  "path",
  "modificationTime",
  extract_label(col("path")).alias("label"),
  extract_size_udf(col("content")).alias("size"),
  "content",
)


In [19]:
class ImageNetDataset(Dataset):
  """
  Converts image contents into a PyTorch Dataset with standard ImageNet preprocessing.

  Each item of ``contents`` is the raw encoded bytes of one image.
  ``__getitem__`` decodes it and returns the tensor produced by the
  resize / center-crop / to-tensor / normalize pipeline.
  """
  def __init__(self, contents):
    """
    Args:
      contents: indexable sequence of raw encoded image bytes (e.g. a list).
    """
    self.contents = contents
    # Built once per dataset instead of once per image. None of the four steps
    # is random, so reusing one pipeline gives identical results.
    self.transform = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

  def __len__(self):
    return len(self.contents)

  def __getitem__(self, index):
    return self._preprocess(self.contents[index])

  def _preprocess(self, content):
    """
    Preprocesses the input image content using standard ImageNet normalization.

    The image is converted to RGB first, so grayscale / RGBA / palette inputs
    still produce the 3-channel tensor the 3-value Normalize step expects.

    See https://pytorch.org/docs/stable/torchvision/models.html.
    """
    image = Image.open(io.BytesIO(content)).convert("RGB")
    return self.transform(image)

In [20]:
def imagenet_model_udf(batch_size=64):
  """
  Wraps an ImageNet model into a Pandas UDF that makes predictions.

  The returned UDF maps a column of raw image bytes to a struct
  ``class: string, desc: string`` holding the top-1 ImageNet class id and its
  description.

  Args:
    batch_size: images per forward pass (DataLoader batch size). Defaults to 64,
      the value that was previously hard-coded.

  You might consider the following customizations for your own use case:
    - Tune batch_size and DataLoader's num_workers for better performance.
    - Use GPU for acceleration.
    - Change prediction types.
  """
  model = models.mobilenet_v2(pretrained=True)
  model.eval()

  def predict(content_series_iter):
    for content_series in content_series_iter:
      dataset = ImageNetDataset(list(content_series))
      loader = DataLoader(dataset, batch_size=batch_size)
      with torch.no_grad():
        for image_batch in loader:
          # No softmax is applied here, so the score decode_predictions returns
          # is presumably a raw logit. The top-1 ranking is unaffected.
          predictions = model(image_batch).numpy()
          # Each entry is a list of (class_id, description, score) tuples. Keep
          # only the id and description of the top-1 tuple, as named columns
          # matching the declared struct. Previously an unnamed third column was
          # silently dropped.
          top1 = [entry[0][:2] for entry in decode_predictions(predictions, top=1)]
          yield pd.DataFrame(top1, columns=["class", "desc"])

  return_type = "class: string, desc: string"
  return pandas_udf(return_type, PandasUDFType.SCALAR_ITER)(predict)

In [21]:
# Attach MobileNetV2's top-1 ImageNet prediction to every image and preview it
# next to the true CIFAR-10 label.
mobilenet_v2_udf = imagenet_model_udf()
predictions = df.withColumn("prediction", mobilenet_v2_udf(col("content")))
prediction_mobil = predictions.select(
  col("label"),
  col("prediction.desc").alias("mobilenetv2 prediction"),
)
prediction_mobil.show(25, truncate=False)



+----------+----------------------+
|label     |mobilenetv2 prediction|
+----------+----------------------+
|frog      |rock_python           |
|bird      |pinwheel              |
|truck     |bearskin              |
|automobile|mousetrap             |
|truck     |oil_filter            |
|truck     |thresher              |
|frog      |jaguar                |
|truck     |moving_van            |
|airplane  |waffle_iron           |
|automobile|panpipe               |
|frog      |sidewinder            |
|truck     |airliner              |
|automobile|maraca                |
|truck     |thresher              |
|frog      |clog                  |
|truck     |thresher              |
|truck     |moving_van            |
|frog      |jersey                |
|truck     |thresher              |
|truck     |thresher              |
|cat       |fire_screen           |
|truck     |moving_van            |
|frog      |sidewinder            |
|truck     |tobacco_shop          |
|bird      |howler_monkey   

In [22]:
# Bring a fixed-size slice of predictions back to the driver as pandas.
# NOTE(review): `limit` keeps whichever rows Spark produces first; it is neither
# random nor class-balanced, and the per-class counts below are very uneven.
# Consider sampling `df` with a fixed seed before running the model.
preview_rows = 2500
prediction_mobil_ser = prediction_mobil.limit(preview_rows).toPandas()

In [23]:
# For each true CIFAR-10 class, list the ImageNet labels MobileNetV2 predicts most often.
top_num = 5
for label_name in label_names:
  print("\n\n ####### Top {} predictions for class {} #######".format(top_num, label_name))
  filt_rows = prediction_mobil_ser[prediction_mobil_ser['label'].eq(label_name)]
  prediction_counts = filt_rows['mobilenetv2 prediction'].value_counts()
  final_rows = prediction_counts.nlargest(top_num).to_frame('counts')
  print(final_rows)



 ####### Top 5 predictions for class airplane #######
               counts
moving_van          7
rock_beauty         4
thresher            4
assault_rifle       4
chain_saw           4


 ####### Top 5 predictions for class automobile #######
                 counts
moving_van          244
thresher             47
chain_saw            41
amphibian            25
cassette_player      15


 ####### Top 5 predictions for class bird #######
                  counts
fox_squirrel          10
three-toed_sloth       8
rock_beauty            5
platypus               3
lesser_panda           3


 ####### Top 5 predictions for class cat #######
                  counts
EntleBucher            9
fox_squirrel           7
Japanese_spaniel       5
bearskin               5
Windsor_tie            4


 ####### Top 5 predictions for class deer #######
                  counts
fox_squirrel          13
sorrel                 5
barn_spider            5
cardoon                4
Japanese_spaniel       3


 ##