In [0]:
from io import BytesIO
from json import loads
from typing import Any, TypedDict

import numpy as np

from torch.utils.data import Dataset
from torchvision.transforms.functional import pil_to_tensor

import pyspark.sql.functions as F

import PIL

In [0]:
class ImageMetadata(TypedDict):
    """
    Metadata describing one map image, bundled with the decoded image itself.

    Keys:
        height: the image height, in pixels
        width: the image width, in pixels
        lat: the latitude of the image
        long: the longitude of the image
        image_id: the id of the image
        map_provider: the map provider the image is from
        image: the decoded PIL image object
    """
    height: int
    width: int
    lat: float
    long: float
    image_id: int
    map_provider: str
    # `PIL.Image` is a module, not a type; the image class is
    # `PIL.Image.Image`. The annotation is quoted so that defining this class
    # does not require the `PIL.Image` submodule to be loaded already (a bare
    # `import PIL` does not import it).
    image: "PIL.Image.Image"


def get_image_metadata(image_binary: bytes) -> ImageMetadata:  # pragma: no cover
    """
    Decode an encoded image and extract the metadata stored in its EXIF
    UserComment tag.

    The UserComment payload is expected to be UTF-8 text describing a mapping
    with the keys "lat", "lng", "mapProvider" and, optionally, "id". Single
    quotes in the payload are converted to double quotes before it is parsed
    as JSON.

    Args:
        image_binary: the raw bytes of an encoded image file.

    Returns:
        The decoded image together with its metadata. When the bytes cannot
        be opened as an image, or no UserComment tag is available, a black
        640x640 placeholder image is returned with lat/long 0.0, image_id -1
        and map_provider "unknown".

    Raises:
        ValueError: if the UserComment payload is not valid UTF-8 or cannot
            be parsed as JSON.
        KeyError: if the parsed payload lacks "lat", "lng" or "mapProvider".
    """
    # Imported here so the `PIL.Image` submodule is guaranteed to be loaded;
    # a bare `import PIL` does not import it.
    from PIL import Image

    user_comment_exif_id = 37510  # EXIF tag id of UserComment

    exif = None
    try:
        image = Image.open(BytesIO(image_binary))
        # `_getexif` is a private helper that not every Pillow image class
        # provides; treat its absence the same as "no EXIF data".
        read_exif = getattr(image, "_getexif", None)
        if read_exif is not None:
            exif = read_exif()
    except (OSError, UnicodeDecodeError):
        # Image.open signals unreadable image data with
        # UnidentifiedImageError, an OSError subclass. FileNotFoundError,
        # which was caught here before, is an OSError subclass as well.
        exif = None

    if exif is None or user_comment_exif_id not in exif:
        # we need to return with default values
        fake_image = Image.new("RGB", (640, 640), "black")
        return {
            "height": 640,
            "width": 640,
            "lat": 0.0,
            "long": 0.0,
            "image_id": -1,
            "map_provider": "unknown",
            "image": fake_image,
        }

    try:
        user_comment = exif[user_comment_exif_id].decode("utf-8")
    except UnicodeDecodeError as e:
        # can we gracefully handle this?
        raise ValueError(f"Unable to decode exif data: {e}") from e

    # The payload uses single-quoted strings; swap the quotes so it parses
    # as JSON. A payload that still is not valid JSON raises
    # json.JSONDecodeError, which is itself a ValueError.
    exif_dict = loads(user_comment.replace("'", "\""))

    image_id = int(exif_dict["id"]) if "id" in exif_dict else -1
    return {
        "height": image.height,
        "width": image.width,
        "lat": exif_dict["lat"],
        "long": exif_dict["lng"],
        "image_id": image_id,
        "map_provider": exif_dict["mapProvider"],
        "image": image,
    }


class ImageBinaryDataset(Dataset):
    """
    Map-style dataset over encoded image bytes.

    Items are decoded on access: each lookup passes the stored bytes to
    `get_image_metadata` and returns its result.
    """

    def __init__(self, images):
        """
        Args:
            images: a sized, indexable collection of encoded image bytes,
                e.g. a list or a pandas Series.
        """
        self.images = images

    def __len__(self):
        """Return the number of stored images."""
        return len(self.images)

    def __getitem__(self, index) -> ImageMetadata:
        """Return the decoded image and metadata at position `index`."""
        images = self.images
        # pandas objects resolve `obj[i]` by index *label*, which only equals
        # the position for a default RangeIndex. Use positional access when
        # it is available so that `index` always means "the i-th image".
        positional = getattr(images, "iloc", None)
        image_binary = images[index] if positional is None else positional[index]
        return get_image_metadata(image_binary)

In [0]:
# Volume housekeeping for the unit-test fixtures.
#
# Disabled snippet kept for reference; presumably it was used to copy one
# bronze request folder into the image_binary_dataset unit-test folder:
#   request_id = "be69e91f"
#   user_id = "cnu4"
#   dbutils.fs.cp(f"/Volumes/edav_dev_csels/towerscout/images/maps/bronze/{user_id}/{request_id}", "/Volumes/edav_dev_csels/towerscout/misc/unit_tests/image_binary_dataset/", recurse=True)

# Move index.json into the mosaic_streaming_unit_test folder.
# NOTE(review): this looks like a one-off action; re-running the cell after
# the source file has been moved will presumably fail - confirm whether the
# cell should remain in the notebook.
dbutils.fs.mv(
    "/Volumes/edav_dev_csels/towerscout/misc/unit_tests/index.json",
    "/Volumes/edav_dev_csels/towerscout/misc/unit_tests/mosaic_streaming_unit_test/",
    recurse=True,
)

In [0]:
# Bronze-layer request folder that the sample images are read from.
request_id = "be69e91f"
user_id = "cnu4"
base_path = f"/Volumes/edav_dev_csels/towerscout/images/maps/bronze/{user_id}/{request_id}"

# Read the files as raw bytes, keep only the "content" column, cap the result
# at 20 rows and collect it to the driver as a pandas DataFrame.
# TODO: parameterize the load path.
# Disabled experiments kept for reference:
#   .repartition(8)
#   .withColumn("inference", yolo_inference_udf(F.col("content")))
#   display(image_df)
image_df = (
    spark.read.format("binaryFile")
    .load(base_path)
    .select("content")
    .limit(20)
    .toPandas()
)
image_bins = image_df["content"]

print(image_df)

# Wrap the encoded images in the dataset and show one decoded sample.
bin_dataset = ImageBinaryDataset(image_bins)
display(bin_dataset[18])

In [0]:
# Smoke check: report the dataset's class and confirm it is the expected type.
print(bin_dataset.__class__)
assert isinstance(bin_dataset, ImageBinaryDataset)