In [0]:
from io import BytesIO
from json import loads
from typing import Any, TypedDict, Union
from collections import defaultdict

import numpy as np

import mlflow

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor

import pyspark.sql.functions as F

from PIL import Image

from tsdb.ml.utils import cut_square_detection
from models.common import Detections  # Detection object for YOLOv5 model

In [0]:
class ImageMetadata(TypedDict):
    """
    Metadata for a single map image, bundled with the decoded image itself.

    Keys:
        height: the image height
        width: the image width
        lat: the latitude of the image
        long: the longitude of the image
        image_id: the id of the image (-1 when no id is available)
        map_provider: the map provider the image is from
        image: the decoded PIL image object
    """
    height: int
    width: int
    lat: float
    long: float
    image_id: int
    map_provider: str
    # `Image` is the PIL *module*; the class of a decoded image is
    # `Image.Image`. Annotating with the module is rejected by `typing`
    # on Python versions before 3.11.
    image: Image.Image


def get_image_metadata(image_binary: bytes) -> ImageMetadata:  # pragma: no cover
    """
    Decode raw image bytes and extract the metadata stored in the image's
    EXIF UserComment tag (EXIF id 37510).

    The UserComment payload is decoded as UTF-8, single quotes are replaced
    by double quotes and the result is parsed as JSON. The keys "lat",
    "lng" and "mapProvider" are read from it, plus "id" when present.

    If the bytes cannot be opened as an image, the image carries no EXIF
    data, or the UserComment tag is absent, a placeholder record is returned:
    a black 640x640 RGB image with lat/long 0.0, image_id -1 and
    map_provider "unknown".

    Args:
        image_binary: the raw bytes of the image file

    Returns:
        An ImageMetadata dict holding the decoded image and its metadata.

    Raises:
        ValueError: if the UserComment payload is not valid UTF-8 or not
            valid JSON (json.JSONDecodeError is a ValueError subclass).
        KeyError: if the payload lacks "lat", "lng" or "mapProvider".
    """
    user_comment_exif_id = 37510

    image = None
    exif = None
    try:
        image = Image.open(BytesIO(image_binary))
        exif = image._getexif()
    except (OSError, AttributeError, UnicodeDecodeError):
        # Fall back to the null image for anything that means "unreadable":
        # - OSError: corrupt/unrecognised data (PIL's UnidentifiedImageError
        #   is an OSError subclass); also covers FileNotFoundError.
        # - AttributeError: image formats whose PIL class has no `_getexif`.
        # - UnicodeDecodeError: undecodable EXIF content.
        exif = None

    if exif is None or user_comment_exif_id not in exif:
        # we need to return with default values
        fake_image = Image.new('RGB', (640, 640), (0, 0, 0))
        return {
            "height": 640,
            "width": 640,
            "lat": 0.0,
            "long": 0.0,
            "image_id": -1,
            "map_provider": "unknown",
            "image": fake_image
        }

    try:
        user_comment = exif[user_comment_exif_id].decode("utf-8")
    except UnicodeDecodeError as e:
        # can we gracefully handle this?
        # Chain the original error so the failing bytes stay visible.
        raise ValueError(f"Unable to decode exif data: {e}") from e

    # The payload uses single quotes; JSON requires double quotes.
    exif_dict = loads(user_comment.replace("'", '"'))

    image_id = int(exif_dict["id"]) if "id" in exif_dict else -1
    return {
        "height": image.height,
        "width": image.width,
        "lat": exif_dict["lat"],
        "long": exif_dict["lng"],
        "image_id": image_id,
        "map_provider": exif_dict["mapProvider"],
        "image": image
    }


class ImageBinaryDataset(Dataset):
    """
    Dataset over a collection of raw image binaries.

    Each item is decoded on access by `get_image_metadata`, so indexing
    yields an ImageMetadata dict rather than the raw bytes.
    """

    def __init__(self, images):
        # Sized, indexable collection of raw image bytes.
        self.images = images

    def __len__(self):
        """Return how many image binaries the dataset holds."""
        image_count = len(self.images)
        return image_count

    def __getitem__(self, index) -> ImageMetadata:
        """Decode the binary stored at `index` into an ImageMetadata dict."""
        raw_image = self.images[index]
        return get_image_metadata(raw_image)
    

def inference_collate_fn(data: "list[ImageMetadata]") -> dict[str, list[Any]]:
    """
    Collate dataset items into a batch that keeps the images apart from
    their metadata.

    Args:
        data: the items of one batch; each is a mapping that contains an
            "image" entry alongside its metadata entries

    Returns:
        A dict with two parallel lists:
        - "images": the value of each item's "image" entry
        - "images_metadata": a copy of each item without the "image" entry
        Both keys are always present, also for an empty batch.

    Raises:
        KeyError: if an item has no "image" entry.
    """
    images = []
    images_metadata = []

    for item in data:
        # Work on a shallow copy so the caller's item is not mutated.
        metadata = dict(item)
        images.append(metadata.pop("image"))
        images_metadata.append(metadata)

    return {"images": images, "images_metadata": images_metadata}
    

In [0]:
# Location of the bronze-layer map images for a single user request.
user_id = "cnu4"
request_id = "be69e91f"
base_path = f"/Volumes/edav_dev_csels/towerscout/images/maps/bronze/{user_id}/{request_id}"

# Read a small sample of raw image binaries; only `content` is needed.
# TODO: parameterize the load path.
# Currently disabled steps: .repartition(8) and
# .withColumn("inference", yolo_inference_udf(F.col("content")))
image_df = spark.read.format("binaryFile").load(base_path).select("content").limit(20)

# Bring the sample to the driver; `image_df` is a pandas DataFrame from here on.
image_df = image_df.toPandas()
image_bins = image_df["content"]

# DataLoader settings.
batch_size = 4
num_workers = 4

bin_dataset = ImageBinaryDataset(image_bins)
loader = DataLoader(
    bin_dataset,
    collate_fn=inference_collate_fn,
    batch_size=batch_size,
    num_workers=num_workers,
)

# Resolve registered models through the "databricks-uc" registry.
mlflow.set_registry_uri("databricks-uc")

In [0]:
# Load both registered models through their "aws" alias.
yolo_model = mlflow.pytorch.load_model(
    model_uri="models:/edav_dev_csels.towerscout.yolo_autoshape@aws"
)
en_model = mlflow.pytorch.load_model(
    model_uri="models:/edav_dev_csels.towerscout.efficientnet@aws"
)

# Inference only: put both models in evaluation mode.
yolo_model.eval()
en_model.eval()

# Move the models to the GPU when one is available.
if torch.cuda.is_available():  # pragma: no cover
    en_model.cuda()
    yolo_model.cuda()

In [0]:
def parse_yolo_detections(
    images: "list[Image.Image]",
    images_metadata: list[dict[str, Any]],
    yolo_results: "Detections",
    secondary_model: Union[torch.nn.Module, None] = None,
    **kwargs
) -> list[list[dict[str, Any]]]:
    """
    Convert the detections of a YOLO batch into plain Python dicts.

    The result holds one list per image (in batch order); each of those
    lists holds one dict per detection with the following keys:
    - x1: the x1 coordinate of the bounding box
    - y1: the y1 coordinate of the bounding box
    - x2: the x2 coordinate of the bounding box
    - y2: the y2 coordinate of the bounding box
    - conf: the YOLO model confidence of the detection
    - class: the class of the detection
    - class_name: the name of the class of the detection
    - secondary: the secondary model confidence of the detection; 1 when
      no secondary model is supplied

    Coordinates are taken as-is from `yolo_results.xyxyn`.

    Args:
        images: the list of PIL images, parallel to `yolo_results.xyxyn`
        images_metadata: the metadata dicts for the images (currently not
            used by this function; kept for interface compatibility)
        yolo_results: the Detections object from the YOLO model; its
            `xyxyn` and `names` attributes are read
        secondary_model: the secondary model used to evaluate the
            detections, or None to skip that step
        **kwargs: additional keyword arguments forwarded to
            `apply_secondary_model` (e.g. min_conf, max_conf)

    Returns:
        A list with one entry per image; each entry is the list of
        detection dicts described above (empty when nothing was detected).
    """
    parsed_results = []

    for image, image_detections in zip(images, yolo_results.xyxyn):
        # Plain nested lists, so the secondary model can append its score
        # to each detection row.
        detection_rows = image_detections.cpu().numpy().tolist()

        if secondary_model is not None:
            apply_secondary_model(secondary_model, image, detection_rows, **kwargs)

        parsed_results.append(
            [
                {
                    "x1": row[0],
                    "y1": row[1],
                    "x2": row[2],
                    "y2": row[3],
                    "conf": row[4],
                    "class": int(row[5]),
                    "class_name": yolo_results.names[int(row[5])],
                    # A 7th column only exists when the secondary model ran.
                    "secondary": row[6] if len(row) > 6 else 1,
                }
                for row in detection_rows
            ]
        )

    return parsed_results


def apply_secondary_model(
    secondary_model: torch.nn.Module,
    image: "Image.Image",
    detections: list[list[float]],
    min_conf: float = 0.25,
    max_conf: float = 0.65,
) -> None:
    """
    Score the YOLO detections of one image with the secondary model, in place.

    For every detection whose YOLO confidence lies in [min_conf, max_conf]
    (both bounds inclusive) the image is cropped around the bounding box with
    `cut_square_detection`, the crop is resized and normalized, and the
    secondary model's output is converted into the probability that the crop
    contains a cooling tower. Detections below `min_conf` get 0, detections
    above `max_conf` get 1. The value is appended to the detection's list.

    Args:
        secondary_model: the secondary model to apply to the cropped image
        image: the image to crop
        detections: the detections from the YOLO model for the input image;
            each is a list whose first five entries are x1, y1, x2, y2, conf
        min_conf: the minimum confidence to apply the secondary model
        max_conf: the maximum confidence to apply the secondary model

    Returns:
        None. Each list in `detections` is extended by one element.
    """
    # Built on first use, so nothing is constructed when no detection falls
    # inside the confidence window.
    transform = None

    for detection in detections:
        x1, y1, x2, y2, conf = detection[0:5]

        # Use secondary model only for certain confidence range
        if min_conf <= conf <= max_conf:
            if transform is None:
                transform = transforms.Compose(
                    [
                        transforms.Resize([456, 456]),
                        transforms.ToTensor(),
                        transforms.Normalize(
                            mean=(0.5553, 0.5080, 0.4960), std=(0.1844, 0.1982, 0.2017)
                        ),
                    ]
                )

            bbox_cropped_image = cut_square_detection(image, x1, y1, x2, y2)

            # apply transformations and add the batch dimension
            model_input = transform(bbox_cropped_image).unsqueeze(0)

            if torch.cuda.is_available():  # pragma: no cover
                model_input = model_input.cuda()

            # Pure inference: do not record an autograd graph.
            with torch.no_grad():
                logit = secondary_model(model_input)

            # subtract from 1 because the secondary has class 0 as tower
            p2 = 1 - torch.sigmoid(logit.cpu()).item()
        elif conf < min_conf:
            # set secondary classifier probability to 0
            p2 = 0
        else:
            # if > max_conf set secondary classifier probability to 1
            p2 = 1

        detection.append(p2)


# for batch in loader:
#     #print("INPUT BATCH:\n", batch['images'])
#     #print(batch['image'][0].size())
#     yolo_output = yolo_model(batch["images"])
#     print(f"yolo output: {yolo_output.xyxyn}")
#     parsed_results = parse_yolo_detections(batch["images"], batch["images_metadata"], yolo_output, en_model, min_conf=0.65, max_conf=0.95)
#     print("PARSED RESULTS:\n", parsed_results)
#     break

In [0]:
def make_towerscout_predict_udf(
    catalog: str,
    schema: str,
    yolo_alias: str = "aws",
    efficientnet_alias: str = "aws",
    batch_size: int = 100
) -> DataFrame:  # pragma: no cover
    """
    Build a SCALAR_ITER pandas UDF that runs the YOLO detector (with the
    EfficientNet model as secondary) on a column of raw image binaries.

    For a pandas UDF, we need the outer function to initialize the models
    and the inner function to perform the inference process. (The original
    text pointed to an NVIDIA reference here, but the link is missing.)

    NOTE(review): this function looks unfinished and should be confirmed
    before use:
      - `DataFrame`, `set_registry_uri`, `YOLOv5_Detector`, `T`,
        `MODEL_VERSION_STRUCT`, `no_grad`, `pd`, `pandas_udf` and
        `PandasUDFType` are not imported in the import cell visible in
        this notebook.
      - `yolo_detector` is bound to the result of
        `mlflow.pytorch.load_model`, while the result of
        `YOLOv5_Detector.from_uc_registry(...)` is discarded; the code below
        then reads `uc_version`, `return_type` and `predict(...)` from
        `yolo_detector` and `uc_version` from `en_classifier`. Verify that
        the MLflow-loaded objects actually provide these.
      - The `-> DataFrame` annotation does not match the returned value,
        which is whatever `pandas_udf(...)(predict)` produces.

    Args:
        catalog (str): First part of the registered model names
            (`{catalog}.{schema}.<model>`).
        schema (str): Second part of the registered model names.
        yolo_alias (str): Alias used in the YOLO model URI.
        efficientnet_alias (str): Alias used in the EfficientNet model URI.
        batch_size (int): Passed to `YOLOv5_Detector.from_uc_registry`.

    Returns:
        The result of `pandas_udf(return_type, PandasUDFType.SCALAR_ITER)(predict)`.
    """
    set_registry_uri("databricks-uc")
 
    yolo_model_name = f"{catalog}.{schema}.yolo_autoshape"
    en_model_name = f"{catalog}.{schema}.efficientnet"  
 
    # Retrieves models by alias and create inference objects
    yolo_detector = mlflow.pytorch.load_model(
            model_uri=f"models:/{yolo_model_name}@{yolo_alias}"
        )
    
    # NOTE(review): the return value of this call is not assigned to anything;
    # presumably it was meant to become `yolo_detector` — confirm.
    YOLOv5_Detector.from_uc_registry(
        model_name=yolo_model_name,
        alias=yolo_alias,
        batch_size=batch_size,
    )
 
    # We nearly always use efficientnet for classification but you don't have to
    en_classifier = mlflow.pytorch.load_model(
            model_uri=f"models:/{en_model_name}@{efficientnet_alias}"
        )
 
    # Model provenance attached to every output row.
    metadata = {
        "yolo_model": "yolo_autoshape",
        "yolo_model_version": yolo_detector.uc_version,
        "efficientnet_model": "efficientnet",
        "efficientnet_model_version": en_classifier.uc_version,
    }
 
    # Output schema of the UDF: the detector's own return type plus the
    # model-version struct.
    return_type = T.StructType([
        T.StructField("bboxes", yolo_detector.return_type),
        T.StructField("model_version", MODEL_VERSION_STRUCT)
    ])
 
    @no_grad()
    def predict(content_series_iter: pd.Series):  # pragma: no cover
        """
        This predict function is distributed across executors to perform inference.

        YOLOv5 library expects the following image formats:
        For size(height=640, width=1280), RGB images example inputs are:
        #   file:        ims = 'data/images/zidane.jpg'  # str or PosixPath
        #   URI:             = 'https://ultralytics.com/images/zidane.jpg'
        #   OpenCV:          = cv2.imread('image.jpg')[:,:,::-1]  # HWC BGR to RGB x(640,1280,3)
        #   PIL:             = Image.open('image.jpg') or ImageGrab.grab()  # HWC x(640,1280,3)
        #   numpy:           = np.zeros((640,1280,3))  # HWC
        #   torch:           = torch.zeros(16,3,320,640)  # BCHW (scaled to size=640, 0-1 values)
        #   multiple:        = [Image.open('image1.jpg'), Image.open('image2.jpg'), ...]  # list of images
        - Source: https://github.com/ultralytics/yolov5/blob/master/models/common.py

        For the image formats accepted by the ultralytics lib, see:
        - Source: https://github.com/ultralytics/ultralytics/blob/main/ultralytics/engine/model.py

        No need to resize for the yolov5 lib as it does it for you
        - Source: letterbox and exif_transpose funcs in:
            https://github.com/ultralytics/yolov5/blob/master/models/common.py

        Args:
            content_series_iter: Iterator over content series; every element
                of a series is opened as an image from its raw bytes.
                NOTE(review): annotated as `pd.Series` but iterated as an
                iterator of series — confirm the intended annotation.

        Yields:
            DataFrame: one pandas DataFrame per input series, with the
            columns "bboxes" (the detector output for an image) and
            "model_version" (the model metadata dict).
        """
        for content_series in content_series_iter:
            # Decode each binary into an RGB PIL image
            image_batch = [
                Image.open(BytesIO(content)).convert("RGB")
                for content in content_series
            ]
 
            # Perform inference on batch
            outputs = yolo_detector.predict(
                model_input=image_batch,
                secondary=en_classifier
            )
 
            # Pair every per-image output with the model provenance
            outputs = [
                {"bboxes": output, "model_version": metadata}
                for output in outputs
            ]
            yield pd.DataFrame(outputs)
 
    return pandas_udf(return_type, PandasUDFType.SCALAR_ITER)(predict)