In [None]:
import os
import re
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv

from utils.s3_bucket import S3Bucket
from utils.predictors import ObjectDetectionPredictor

In [None]:
# Populate os.environ from the local `env` file (model ids, instance settings,
# region, confidence threshold) before any cell reads those variables.
load_dotenv(dotenv_path="env")

In [None]:
# S3 locations: raw Google Images scrape (input) and cropped images (output).
# S3 keys always use "/" as the separator, so they are written out literally
# instead of with os.path.join, which would produce "\\" on Windows.
bucket_name = "ava-cv-raw-photo-bucket"
input_prefix = "10-plants/GoogleImages"
# NOTE(review): "Imagess" looks like a typo, but renaming it would change where
# crops are written (and orphan any earlier outputs); confirm before fixing.
output_prefix = "10-plants/CroppedGoogleImagess"

In [None]:
# Identify which detector model/version to serve and on what instance type.
# Every value comes from the environment; a missing variable raises KeyError here.
predictor_kwargs = {
    "model_id": os.environ["OBJECT_DETECTION_MODEL_ID"],
    "model_version": os.environ["OBJECT_DETECTION_MODEL_VERSION"],
    "instance_type": os.environ["INFERENCE_INSTANCE_TYPE"],
}
model_predictor = ObjectDetectionPredictor(**predictor_kwargs)

In [None]:
%%time
model_predictor.deploy(
    instance_type=os.environ["INFERENCE_INSTANCE_TYPE"],
    instance_count=int(os.environ["INFERENCE_INSTANCE_COUNT"]),
)

In [None]:
%%time
min_confidence = float(os.environ["OBJECT_DETECTION_MIN_CONFIDENCE"])
classes_to_keep = ["potted plant"]

bucket = S3Bucket(
    bucket_name=bucket_name,
    region_name=os.environ["REGION_NAME"]
)
with tqdm(bucket.filter(prefix=input_prefix), position=0, leave=True) as pbar:
    for obj in pbar:
        key = obj.key
        s3_image_path = os.path.join("s3://", bucket_name, key)
        pbar.set_description(s3_image_path)
        
        image = bucket[key]
        image_np = np.array(image)
        try:
            normalized_boxes, classes_names, confidences, labels = model_predictor.predict(image)
        except KeyboardInterrupt as exc:
            raise exc
        except Exception:
            print(f"Exception occured when predicting bounding boxes. Skipping {s3_image_path}...")
            continue
            
        n_boxes = len(normalized_boxes)
        normalized_boxes = [
            normalized_boxes[i] for i in range(n_boxes)
            if confidences[i] >= min_confidence and classes_names[i] in classes_to_keep
        ]

        cropped_images = []
        for normalized_box in normalized_boxes:
            left, top, right, bot = normalized_box
            left, right = [val * image_np.shape[1] for val in [left, right]]
            bot, top = [val * image_np.shape[0] for val in [bot, top]]
            cropped_image = image.crop((left, top, right, bot))
            cropped_images.append(cropped_image)

        # if no bounding boxes were found, save the original image
        if not cropped_images:
            cropped_images = [image]

        for i, cropped_image in enumerate(cropped_images):
            new_key = re.sub(r"\.(jpe?g|png)", f"-{i}.jpg", key)
            new_key = new_key.replace(input_prefix, output_prefix)
            bucket[new_key] = cropped_image

In [None]:
# Tear down the model deployed by `model_predictor.deploy(...)`, presumably
# releasing the inference instances it requested.
# NOTE(review): if the cropping cell raises (e.g. under Restart & Run All),
# execution stops before reaching this cell and the deployment is left
# running; run this cell manually in that case.
model_predictor.delete()