In [1]:
import findspark
findspark.init()
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark.sql.types import *
from pyspark.sql.functions import col, split
from pyspark.ml.feature import StringIndexer
import pyspark.sql.functions as fn
import shutil
import io
import numpy as np
import pandas as pd
from PIL import Image
import warnings
import time
import os
import torch

# Silence every warning raised through Python's ``warnings`` module to keep
# notebook output readable. This does not affect Spark's JVM log "WARN" lines,
# which are emitted by the JVM logger rather than Python's warnings machinery.
warnings.filterwarnings('ignore')

In [2]:
def timing(start):
    """Print the wall-clock seconds elapsed since ``start``.

    ``start`` is a timestamp previously obtained from ``time.time()``; the
    result is printed with two decimals, e.g. ``Elapsed time: 2.62 s``.
    """
    elapsed = time.time() - start
    print(f'Elapsed time: {elapsed:.2f} s')

# Start Spark Session

In [3]:
start = time.time()

# Reuse an existing SparkSession if one is active, otherwise create one.
# NOTE(review): spark.driver.memory only takes effect when this call launches
# the driver JVM; it is ignored if getOrCreate() returns an existing session.
spark = (
    SparkSession.builder
    .appName('SparkCPU')
    .config("spark.driver.memory", "15g")
    .getOrCreate()
)

timing(start)

23/08/01 09:44:21 WARN Utils: Your hostname, bdai-desktop resolves to a loopback address: 127.0.1.1; using 165.132.118.198 instead (on interface enp0s31f6)
23/08/01 09:44:21 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
23/08/01 09:44:21 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable


Elapsed time: 2.62 s


# 1. Extract

In [4]:
# imagezip_path = "/home/bdai/covid_data/covidx-cxr2.zip"
# Root of the extracted dataset; later cells expect train/, test/, train.txt
# and test.txt directly under it. Set the COVID_DATASET_DIR environment
# variable to point at a different location without editing the notebook.
image_path = os.environ.get("COVID_DATASET_DIR", "/home/bdai/spark_work/covid_dataset")

# One-time extraction of the archive; uncomment on a machine without the data.
# shutil.unpack_archive(imagezip_path, image_path)

In [5]:
def get_dir_size(path='.'):
    """Return the total size in bytes of the files under ``path``, recursively.

    Symbolic links to directories are not descended into. This keeps a link
    cycle (e.g. a link pointing back at an ancestor) from causing unbounded
    recursion, and stops a linked tree from being counted twice. Symbolic
    links to files are still followed and counted at their target's size.

    Parameters
    ----------
    path : str or os.PathLike, default '.'
        Root directory to measure.

    Returns
    -------
    int
        Sum of ``st_size`` over every file found.
    """
    total = 0
    with os.scandir(path) as entries:
        for entry in entries:
            if entry.is_file():
                total += entry.stat().st_size
            # follow_symlinks=False: never recurse through a directory symlink.
            elif entry.is_dir(follow_symlinks=False):
                total += get_dir_size(entry.path)
    return total
# Report the on-disk dataset size in units of 1024**3 bytes (GiB, labelled "GBs").
dir_size = round(get_dir_size(image_path) / 1024 ** 3, 2)
print(f"Total dataset size : {dir_size} GBs")

Total dataset size : 13.07 GBs


In [6]:
start = time.time()

# Read every file under train/ and test/ (recursively) as raw bytes; later
# cells use the resulting `path` and `content` columns.
train_images = (spark.read.format("binaryFile")
                .option("recursiveFileLookup", "true")
                .load(os.path.join(image_path, "train")))
test_images = (spark.read.format("binaryFile")
               .option("recursiveFileLookup", "true")
               .load(os.path.join(image_path, "test")))

# Metadata lines: [patient id] [filename] [class] [data source]
# Derived from image_path (previously hard-coded) so all inputs share one root.
train_txt = spark.read.text(os.path.join(image_path, "train.txt"))
test_txt = spark.read.text(os.path.join(image_path, "test.txt"))

timing(start)

Elapsed time: 2.34 s


# 2. Transform

In [7]:
def extract_size(content):
    """Return ``(width, height)`` of an image given its raw encoded bytes.

    Parameters
    ----------
    content : bytes
        Raw file content, e.g. one value of the ``content`` column read by
        Spark's ``binaryFile`` source.

    Returns
    -------
    tuple[int, int]
        ``(width, height)`` as reported by PIL's ``Image.size``.
    """
    # Image.open is lazy: only the header is parsed to obtain the size. The
    # context manager closes PIL's handle promptly instead of relying on GC.
    with Image.open(io.BytesIO(content)) as image:
        return image.size

@fn.pandas_udf("width: int, height: int")
def extract_size_udf(content_series):
    """Vectorised Spark UDF mapping raw image bytes to a ``(width, height)`` struct.

    Parameters
    ----------
    content_series : pandas.Series
        A batch of raw image bytes (the ``content`` column).

    Returns
    -------
    pandas.DataFrame
        One row per input element with ``width`` and ``height`` columns,
        matching the declared struct return type.
    """
    sizes = content_series.apply(extract_size)
    # Explicit column names let Spark match the struct fields by name, and keep
    # the result two-columned for an empty batch (pd.DataFrame([]) has none).
    return pd.DataFrame(sizes.tolist(), columns=["width", "height"])


def transform_merge(image, text, string_order_type="alphabetAsc"):
    """Join image bytes with their metadata and add size and numeric label columns.

    Parameters
    ----------
    image : pyspark.sql.DataFrame
        Output of the ``binaryFile`` reader; the ``path`` and ``content``
        columns are used.
    text : pyspark.sql.DataFrame
        Output of ``spark.read.text`` on a metadata file (single ``value``
        column) whose lines are ``[patient id] [filename] [class] [data source]``.
    string_order_type : str, default "alphabetAsc"
        Passed to ``StringIndexer.stringOrderType``. A split-independent order
        such as ``"alphabetAsc"`` keeps the class -> label mapping identical
        across train and test, which are indexed in separate calls.

    Returns
    -------
    pyspark.sql.DataFrame
        Columns ``path``, ``file_name``, ``size`` (struct of width/height),
        ``content``, ``class`` and ``label``. Only images that have a matching
        metadata line are kept (inner join on the file name).
    """
    image = image.withColumn("file_name", fn.substring_index(image.path, "/", -1))

    # Parse the space-separated metadata once and pick the fields we need.
    fields = fn.split(fn.col("value"), " ")
    text = text.select(fields.getItem(0).alias("patient_id"),
                       fields.getItem(1).alias("file_name"),
                       fields.getItem(2).alias("class"))

    df = image.join(text, ["file_name"], how="inner")
    df = df.select(fn.col("path"),
                   fn.col("file_name"),
                   extract_size_udf(fn.col("content")).alias("size"),
                   fn.col("content"),
                   fn.col("class"))

    # The StringIndexer default ("frequencyDesc") orders classes by their
    # frequency in *this* split, so separately fitted train/test indexers
    # could assign different labels to the same class. A fixed ordering avoids that.
    indexer = StringIndexer(inputCol="class", outputCol="label",
                            stringOrderType=string_order_type)
    return indexer.fit(df).transform(df)




In [8]:
start = time.time()

# Build the labelled train/test DataFrames. Most of this is lazy; the elapsed
# time presumably comes mainly from the StringIndexer fit inside transform_merge.
train_df = transform_merge(image=train_images, text=train_txt)
test_df = transform_merge(image=test_images, text=test_txt)

timing(start)

                                                                                

Elapsed time: 5.19 s


In [None]:
# temp = test_df.select("content").collect()

# from torchvision import transforms
# for i in range(100):
#     temp_image = Image.open(io.BytesIO(temp[i]["content"]))
#     trans = transforms.ToTensor()
#     print(trans(temp_image).shape)

# 3. Load

In [9]:
start = time.time()

# Write both splits as uncompressed parquet tables by temporarily overriding
# the session-wide codec. The previous value is restored in `finally`, so a
# failed write cannot leave the session writing uncompressed parquet afterwards.
compression = spark.conf.get("spark.sql.parquet.compression.codec")
spark.conf.set("spark.sql.parquet.compression.codec", "uncompressed")
try:
    # NOTE(review): mergeSchema is documented as a parquet *read* option; it
    # presumably has no effect on these writes and is kept only for parity.
    train_df.write.format("parquet").mode("overwrite").option("mergeSchema", True).saveAsTable("covid_train_binary")
    test_df.write.format("parquet").mode("overwrite").option("mergeSchema", True).saveAsTable("covid_test_binary")
finally:
    spark.conf.set("spark.sql.parquet.compression.codec", compression)

timing(start)



Elapsed time: 213.07 s


                                                                                

In [10]:
# Stop the Spark session (and its underlying SparkContext) once both tables are written.
spark.stop()