In [None]:
import requests
import torchvision
from torchvision import transforms
from torchvision.io import read_image
from torchvision.utils import draw_bounding_boxes
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection 
import os
from skimage import exposure
from torchvision.transforms.functional import to_pil_image, to_tensor
import numpy as np


# Grounding DINO (tiny) zero-shot object detector, loaded from the Hugging Face Hub.
model_id = "IDEA-Research/grounding-dino-tiny"

# Pick the compute device. MPS (Apple silicon) was the original hard-coded target and is
# still preferred; fall back to CUDA, then CPU, so the notebook runs on other machines too.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(model_id)
model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)

In [None]:

# Which object class to detect; selects the Grounding DINO text prompt below.
image_type = "car"

# Root data directory. Set the SATANA_IMAGE_ROOT environment variable to run on another
# machine instead of editing this absolute, user-specific default.
root_folder = os.environ.get(
    "SATANA_IMAGE_ROOT", "/Users/jkerlin/PycharmProjects/satana/data/images/"
)
# Trailing slashes are kept so these values match the original string-concatenated paths.
input_image_folder = os.path.join(root_folder, "gold_standard/")
output_image_folder = os.path.join(root_folder, "gold_standard_crop/")

# If True, the processing cell draws the filtered boxes on the first image, shows it,
# and stops without saving any crops.
show_example = False

# Text prompt per supported image_type.
# NOTE(review): "an satellite" is a typo, but changing the prompt changes the detections;
# left byte-identical until the detection thresholds are re-validated against a new wording.
text_prompts = {
    "car": "individual cars from an satellite view.",
}
if image_type not in text_prompts:
    # Fail here with a clear message instead of a NameError on `text` in a later cell.
    raise ValueError(
        f"Unsupported image_type {image_type!r}; expected one of {sorted(text_prompts)}"
    )
text = text_prompts[image_type]

In [None]:


# Collect the image files to process. Sorting makes "the first image" (used by
# show_example) deterministic, since os.listdir order is arbitrary.
valid_extensions = {".jpg", ".jpeg", ".png", ".bmp", ".gif"}
image_list = sorted(
    file_name
    for file_name in os.listdir(input_image_folder)
    if os.path.splitext(file_name)[1].lower() in valid_extensions
)

# Detection thresholds passed to post_process_grounded_object_detection.
box_threshold = 0.15
text_threshold = 0.3

# Size criteria in pixels. A box is kept only if both its width and its height lie
# strictly between min_dim and max_dim.
min_dim = 15
max_dim = 150

# Optional contrast stretching before detection (was a commented-out experiment).
# Left off by default, so detections are unchanged.
apply_contrast_stretch = False

# Create the destination up front; Image.save does not create missing directories.
os.makedirs(output_image_folder, exist_ok=True)

# Loop through all the images in input_image_folder.
for image_idx, image_file in enumerate(image_list):
    image_filepath = os.path.join(input_image_folder, image_file)
    # convert() returns a new in-memory image, so the file handle can be closed right away.
    with Image.open(image_filepath) as raw_image:
        image = raw_image.convert("RGB")

    if apply_contrast_stretch:
        image_np = np.array(image)
        p_low, p_high = np.percentile(image_np, (0, 20))
        image_rescale = exposure.rescale_intensity(image_np, in_range=(p_low, p_high))
        image = Image.fromarray(image_rescale)

    inputs = processor(images=image, text=text, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)

    # target_sizes expects (height, width); PIL's size is (width, height).
    results = processor.post_process_grounded_object_detection(
        outputs,
        inputs.input_ids,
        box_threshold=box_threshold,
        text_threshold=text_threshold,
        target_sizes=[image.size[::-1]],
    )

    # Boxes are [xmin, ymin, xmax, ymax] in pixels; drop implausibly small or large ones.
    filtered_boxes = [
        box
        for box in results[0]["boxes"].tolist()
        if min_dim < box[2] - box[0] < max_dim and min_dim < box[3] - box[1] < max_dim
    ]

    if show_example and image_idx == 0:
        # reshape(-1, 4) keeps the (N, 4) shape even when no box survived filtering;
        # torch.tensor([]) would be 1-D and rejected by draw_bounding_boxes.
        tensor_boxes = torch.tensor(filtered_boxes, dtype=torch.float).reshape(-1, 4)
        # PILToTensor yields uint8 directly, avoiding the float round-trip
        # (ToTensor().mul(255).byte()) that can truncate e.g. 254.9999 down to 254.
        tensor_image = transforms.PILToTensor()(image)

        bbox_img = draw_bounding_boxes(tensor_image, tensor_boxes, width=3, colors="red")
        transforms.ToPILImage()(bbox_img).show()
        break
    else:
        # splitext handles extensions of any length (".jpeg" as well as ".jpg").
        stem = os.path.splitext(image_file)[0]
        for crop_idx, box in enumerate(filtered_boxes):
            # PIL's crop takes (left, upper, right, lower), which matches [xmin, ymin, xmax, ymax].
            cropped_image = image.crop(tuple(box))
            cropped_image.save(
                os.path.join(output_image_folder, f"{stem}_crop_{crop_idx}.jpg")
            )
            

