In [None]:
import typing
import json, sys, os
from packaging import version

import cv2
import numpy as np
import torch
import transformers
from transformers import ViTFeatureExtractor

# Both extractors are loaded from the same pretrained checkpoint, so any
# difference between the two preprocessing paths comes from the processor
# class or the input type, not from the checkpoint configuration.
_VIT_CHECKPOINT = "google/vit-base-patch16-224-in21k"

# Extractor used by the legacy (numpy / cv2) preprocessing path.
LEGACY_FEATURE_EXTRACTOR = ViTFeatureExtractor.from_pretrained(
    _VIT_CHECKPOINT,
    return_tensors="pt",
)
# Extractor used by the torch preprocessing path; starts out as the same
# ViTFeatureExtractor class as the legacy one.
UPDATED_FEATURE_EXTRACTOR = ViTFeatureExtractor.from_pretrained(
    _VIT_CHECKPOINT,
    return_tensors="pt",
)
# torchvision crop will pad output with 0s if input_image_shape < output_image_shape
from torchvision.transforms.functional import crop

sys.path.append(os.environ["BUILD_WORKSPACE_DIRECTORY"])
from core.labeling.label_store.label_reader import LabelReader
from core.utils.video_reader import S3VideoReader, S3VideoReaderInput
from core.structs.actor import Actor
from core.structs.attributes import RectangleXYWH, RectangleXYXY

# From this transformers release onwards, the torch path uses
# ViTImageProcessor instead of ViTFeatureExtractor (same checkpoint).
_IMAGE_PROCESSOR_CUTOFF = version.parse("4.25.1")

if version.parse(transformers.__version__) >= _IMAGE_PROCESSOR_CUTOFF:
    # Imported here because the name is only needed on this branch.
    from transformers import ViTImageProcessor

    UPDATED_FEATURE_EXTRACTOR = ViTImageProcessor.from_pretrained(
        "google/vit-base-patch16-224-in21k",
        return_tensors="pt",
    )

In [None]:
VIDEO_UUID = "uscold/quakertown/0001/cha/detector/1660283690_1660283700_0000"

# Labels: for every labelled frame keep only the PERSON_V2 actors, keyed by the
# frame's relative timestamp in milliseconds.
labeled_frames = json.loads(LabelReader().read(VIDEO_UUID))["frames"]
actor_labels = {
    frame_label["relative_timestamp_ms"]: [
        Actor.from_dict(actor_dict)
        for actor_dict in frame_label["actors"]
        if actor_dict["category"] == "PERSON_V2"
    ]
    for frame_label in labeled_frames
}

# Video: frame images keyed by each frame's relative_timestamp_ms attribute.
video_reader = S3VideoReader(S3VideoReaderInput(VIDEO_UUID))
video_frames = {
    video_frame.relative_timestamp_ms: video_frame.image
    for video_frame in video_reader.read()
}

## Crop Images Test

In [None]:
def crop_image_legacy(
    actor_bboxes: typing.List["Actor"],
    frame: np.ndarray,
    padding: int,
) -> typing.List[np.ndarray]:
    """Crop one padded RGB patch per actor out of a numpy frame.

    Args:
        actor_bboxes: Actors whose ``polygon`` is converted to an
            x/y/width/height rectangle with ``RectangleXYWH.from_polygon``.
        frame: Image array indexed as ``[row, column, ...]``; it is passed to
            ``cv2.cvtColor`` with ``COLOR_BGR2RGB`` after cropping.
        padding: Number of pixels added on every side of the rectangle before
            the crop window is clamped to the frame bounds.

    Returns:
        A list with one cropped, colour-converted image per actor, in the same
        order as ``actor_bboxes``. Empty when ``actor_bboxes`` is empty.
    """
    frame_height, frame_width = frame.shape[0], frame.shape[1]

    cropped_images = []
    # Iterate the argument, not a same-named notebook global, so the function
    # gives the right answer regardless of what the calling cell has defined.
    for actor in actor_bboxes:
        rect = RectangleXYWH.from_polygon(actor.polygon)
        left = rect.top_left_vertice.x
        top = rect.top_left_vertice.y

        # Grow the box by `padding`, then clamp so the slice never leaves the
        # frame (a negative start would otherwise wrap around in numpy).
        cropped_image = frame[
            max(0, top - padding) : min(frame_height, top + rect.h + padding),
            max(0, left - padding) : min(frame_width, left + rect.w + padding),
        ]
        cropped_images.append(cv2.cvtColor(cropped_image, cv2.COLOR_BGR2RGB))
    return cropped_images

In [None]:
def crop_images_tensor(
    actors_xyxy: typing.List[torch.Tensor],
    frame: torch.Tensor,
    padding: int,
) -> typing.List[torch.Tensor]:
    """Crop one padded, channel-reversed patch per box out of a tensor frame.

    Args:
        actors_xyxy: Boxes indexed as ``[x_min, y_min, x_max, y_max]``.
        frame: Three-dimensional tensor indexed as ``[row, column, channel]``.
        padding: Number of pixels added on every side of the box before the
            crop window is clamped to the frame bounds.

    Returns:
        A list with one cropped tensor per box, in the same order as
        ``actors_xyxy``. Empty when ``actors_xyxy`` is empty.
    """
    frame_height = int(frame.shape[0])
    frame_width = int(frame.shape[1])

    cropped_images = []
    # Single pass over the boxes: each box produces exactly one crop.
    for xyxy in actors_xyxy:
        # Grow the box by `padding`, then clamp to the frame so a negative
        # start index cannot wrap around.
        top = max(0, int(xyxy[1] - padding))
        bottom = min(frame_height, int(xyxy[3] + padding))
        left = max(0, int(xyxy[0] - padding))
        right = min(frame_width, int(xyxy[2] + padding))

        cropped_image = frame[top:bottom, left:right, :]
        # Reverse the channel axis; for a 3-channel BGR frame this yields RGB,
        # matching the cv2 BGR->RGB conversion used by the numpy path.
        cropped_images.append(cropped_image.flip(2))
    return cropped_images

In [None]:
def legacy_preprocess(
    actor_bboxes: typing.List[Actor],
    frame: np.ndarray,
    padding: int,
):
    """Run the numpy/cv2 preprocessing path for one frame.

    Crops every actor with ``crop_image_legacy`` and feeds the crops to
    ``LEGACY_FEATURE_EXTRACTOR``.

    Returns:
        The extractor's ``"pixel_values"`` entry, or ``None`` when there are
        no crops for this frame.
    """
    crops = crop_image_legacy(actor_bboxes, frame, padding)
    if not crops:
        # Nothing to preprocess for this frame.
        return None
    encoded = LEGACY_FEATURE_EXTRACTOR(crops, return_tensors="pt")
    return encoded["pixel_values"]


def torch_preprocess(
    actors_xyxy: typing.List[torch.Tensor],
    frame: torch.Tensor,
    padding: int,
):
    """Run the torch-tensor preprocessing path for one frame.

    Crops every box with ``crop_images_tensor`` and feeds the crops to
    ``UPDATED_FEATURE_EXTRACTOR``.

    Returns:
        The extractor's ``"pixel_values"`` entry, or ``None`` when there are
        no crops for this frame.
    """
    crops = crop_images_tensor(actors_xyxy, frame, padding)
    if not crops:
        # Nothing to preprocess for this frame.
        return None
    encoded = UPDATED_FEATURE_EXTRACTOR(crops, return_tensors="pt")
    return encoded["pixel_values"]

In [None]:
# Run both preprocessing paths on every labelled frame and check that they
# produce identical model inputs.
PADDING = 5
preprocessed_images_legacy = {}
preprocessed_images_torch = {}
for ts, actors in actor_labels.items():
    frame = video_frames[ts]
    preprocessed_images_legacy[ts] = legacy_preprocess(actors, frame, PADDING)
    preprocessed_images_torch[ts] = torch_preprocess(
        [
            torch.tensor(RectangleXYXY.from_polygon(actor.polygon).to_list())
            for actor in actors
        ],
        torch.from_numpy(frame),
        PADDING,
    )

for ts, legacy_input in preprocessed_images_legacy.items():
    new_input = preprocessed_images_torch[ts]
    # Both paths return None for a frame without any actors; torch.equal only
    # accepts tensors, so compare that case explicitly instead of crashing.
    if legacy_input is None or new_input is None:
        assert legacy_input is None and new_input is None, (
            f"only one preprocessing path produced output at ts={ts}"
        )
        continue
    assert torch.equal(legacy_input, new_input), (
        f"legacy and torch preprocessing differ at ts={ts}"
    )