In [None]:
import os
import re
import json
import numpy as np
from tqdm import tqdm
from dotenv import load_dotenv
from PIL import Image

from utils import ObjectDetectionPredictor

In [None]:
# Load KEY=VALUE pairs from the file named "env" (relative to the kernel's working
# directory) into os.environ. The cells below read OBJECT_DETECTION_MODEL_ID,
# OBJECT_DETECTION_MODEL_VERSION, INFERENCE_INSTANCE_TYPE and INFERENCE_INSTANCE_COUNT
# from os.environ — presumably defined in that file (TODO confirm).
load_dotenv(dotenv_path="env")

In [None]:
bucket_dir = os.path.join("s3://ava-cv-raw-photo-bucket", "10-plants", "GoogleImages")
local_dir = os.path.join("images", "10-plants", "GoogleImages")
if not os.path.isdir(local_dir):
    !aws s3 cp {bucket_dir} {local_dir} --recursive --only-show-errors

In [None]:
# Client for the object-detection model; every setting comes from environment variables
# (a missing variable raises KeyError here rather than later during deployment).
predictor_kwargs = dict(
    model_id=os.environ["OBJECT_DETECTION_MODEL_ID"],
    model_version=os.environ["OBJECT_DETECTION_MODEL_VERSION"],
    instance_type=os.environ["INFERENCE_INSTANCE_TYPE"],
)
model_predictor = ObjectDetectionPredictor(**predictor_kwargs)

In [None]:
%%time
model_predictor.deploy(
    instance_type=os.environ["INFERENCE_INSTANCE_TYPE"],
    instance_count=int(os.environ["INFERENCE_INSTANCE_COUNT"]),
)

In [None]:
%%time
min_confidence = 0.1
classes_to_keep = ["potted plant"]

for root, dirs, files in os.walk(os.path.join("images", "10-plants", "GoogleImages")):
    if ".ipynb_checkpoints" in root:
        continue
    
    if dirs and not files:
        for dir in dirs:
            new_dir = os.path.join(root, dir).replace("GoogleImages", "CroppedGoogleImages")
            if not os.path.isdir(new_dir):
                os.makedirs(new_dir)
        continue
            
    image_filenames = [file for file in files if file.endswith((".jpg", ".jpeg", ".png"))]
    with tqdm(image_filenames, position=0, leave=True) as pbar:
        for image_filename in pbar:
            image_path = os.path.join(root, image_filename)
            pbar.set_description(image_path)
            image = Image.open(image_path).convert('RGB')
            image_np = np.array(image)
            with open(image_path, "rb") as file:
                image_binary = file.read()
            try:
                normalized_boxes, classes_names, confidences, labels = model_predictor.predict(image_binary)
            except Exception as exc:
                print(f"Exception occured when predicting bounding boxes. Skipping {image_path}...")
                continue

            n_boxes = len(normalized_boxes)
            normalized_boxes = [
                normalized_boxes[i] for i in range(n_boxes)
                if confidences[i] >= min_confidence and classes_names[i] in classes_to_keep
            ]

            cropped_images = []
            for normalized_box in normalized_boxes:
                left, top, right, bot = normalized_box
                left, right = [val * image_np.shape[1] for val in [left, right]]
                bot, top = [val * image_np.shape[0] for val in [bot, top]]
                cropped_image = image.crop((left, top, right, bot))
                cropped_images.append(cropped_image)
                
            # if no bounding boxes were found, save the original image
            if not cropped_images:
                cropped_images = [image]

            for i, cropped_image in enumerate(cropped_images):
                save_path = os.path.join(
                    root.replace("GoogleImages", "CroppedGoogleImages"),
                    re.sub(r"\.(jpe?g|png)", f"-{i}.jpg", image_filename)
                )
                cropped_image.save(save_path)

In [None]:
# Tear down the predictor deployed above once cropping is done. Run this cell even if the
# cropping cell failed part-way — presumably the deployed instances keep running (and
# billing) until deleted (TODO confirm against utils.ObjectDetectionPredictor).
model_predictor.delete()

In [None]:
%%time
local_dir = local_dir.replace("GoogleImages", "CroppedGoogleImages")
bucket_dir = bucket_dir.replace("GoogleImages", "CroppedGoogleImages")
!aws s3 cp {local_dir} {bucket_dir} --recursive --only-show-errors --exclude "*" --include "*.jpg" --include "*.jpeg" --include "*.png"