In [None]:
from io import BytesIO
import os
import sys
from pathlib import Path
import requests
import warnings

import matplotlib.pyplot as plt
import numpy as np
from PIL import Image

from lang_sam import LangSAM

# Suppress warning messages. No category or module filter is given, so every
# warning raised after this point is silenced for the rest of the session,
# including deprecation notices that may matter when upgrading dependencies.
warnings.filterwarnings("ignore")

In [None]:
def download_image(url, timeout=30):
    """Download an image over HTTP and return it as an RGB PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Passed
            straight to ``requests.get``; without it a stalled connection
            would block the notebook forever.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        requests.HTTPError: If the server answers with an error status
            (raised by ``response.raise_for_status()``).
    """
    response = requests.get(url, timeout=timeout)
    response.raise_for_status()
    return Image.open(BytesIO(response.content)).convert("RGB")

def save_mask(mask_np, filename):
    """Write a mask array to ``filename`` as an 8-bit image.

    The array is multiplied by 255 and cast to ``uint8`` before saving.
    NOTE(review): this presumably expects mask values in [0, 1] (boolean or
    float) — confirm against the caller.

    Args:
        mask_np: NumPy array holding the mask.
        filename: Destination path handed to ``PIL.Image.Image.save``.
    """
    scaled = mask_np * 255
    pixels = scaled.astype(np.uint8)
    Image.fromarray(pixels).save(filename)

def _draw_boxes(ax, boxes, logits):
    """Overlay every bounding box and its confidence score on ``ax``.

    Args:
        ax: Matplotlib axes that already shows the image.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` quadruples.
        logits: Iterable of scalar-like confidence values, paired with
            ``boxes`` by position.
    """
    for box, logit in zip(boxes, logits):
        # float() accepts 0-d tensors, NumPy scalars and plain numbers alike,
        # so matplotlib always receives ordinary Python floats.
        x_min, y_min, x_max, y_max = (float(coord) for coord in box)
        confidence_score = round(float(logit), 2)

        # Draw bounding box
        rect = plt.Rectangle(
            (x_min, y_min),
            x_max - x_min,
            y_max - y_min,
            fill=False,
            edgecolor='red',
            linewidth=2,
        )
        ax.add_patch(rect)

        # Add confidence score as text
        ax.text(x_min, y_min, f"Confidence: {confidence_score}", fontsize=8, color='red', verticalalignment='top')


def display_image_with_masks(image, masks, boxes, logits, figwidth=15, savefig=None, all_masks=True):
    """Show the image with its boxes next to one panel per mask.

    Args:
        image: Image to display (anything ``Axes.imshow`` accepts).
        masks: Sequence of 2-D mask arrays, one panel each.
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` quadruples.
        logits: Iterable of scalar-like confidence values, one per box.
        figwidth: Figure width in inches; the height is fixed at 5.
        savefig: Optional path; when given the figure is also written there.
        all_masks: When false, only the first mask is shown. All boxes are
            still drawn on the image panel.
    """
    if not all_masks:
        masks = masks[:1]
    num_masks = len(masks)

    # squeeze=False keeps `axes` two-dimensional even for a single panel;
    # otherwise an empty `masks` would yield a bare Axes and indexing fails.
    fig, axes = plt.subplots(1, num_masks + 1, figsize=(figwidth, 5), squeeze=False)
    image_ax, *mask_axes = axes[0]

    image_ax.imshow(image)
    image_ax.set_title("Image with Bounding Boxes")
    image_ax.axis('off')
    _draw_boxes(image_ax, boxes, logits)

    for panel_no, (mask_ax, mask_np) in enumerate(zip(mask_axes, masks), start=1):
        mask_ax.imshow(mask_np, cmap='gray')
        mask_ax.set_title(f"Mask {panel_no}")
        mask_ax.axis('off')

    # Lay out before saving so the file on disk matches what is displayed.
    fig.tight_layout()
    if savefig is not None:
        fig.savefig(savefig)
    plt.show()
    # Release the figure; this function is called once per image in a loop.
    plt.close(fig)

def display_image_with_boxes(image, boxes, logits):
    """Show the image with its bounding boxes and confidence scores.

    Args:
        image: Image to display (anything ``Axes.imshow`` accepts).
        boxes: Iterable of ``(x_min, y_min, x_max, y_max)`` quadruples.
        logits: Iterable of scalar-like confidence values, one per box.
    """
    fig, ax = plt.subplots()
    ax.imshow(image)
    ax.set_title("Image with Bounding Boxes")
    ax.axis('off')

    _draw_boxes(ax, boxes, logits)

    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def print_bounding_boxes(boxes):
    """Print a header followed by each box on its own 1-based numbered line."""
    print("Bounding Boxes:")
    for number, bounding_box in enumerate(boxes, start=1):
        print(f"Box {number}: {bounding_box}")

def print_detected_phrases(phrases):
    """Print a header (preceded by a blank line) and each phrase, numbered from 1."""
    print("\nDetected Phrases:")
    for number, text in enumerate(phrases, start=1):
        print(f"Phrase {number}: {text}")

def print_logits(logits):
    """Print a header (preceded by a blank line) and each logit, numbered from 1."""
    print("\nConfidence:")
    for number, value in enumerate(logits, start=1):
        print(f"Logit {number}: {value}")

In [None]:
# Build the LangSAM model with its default constructor arguments. The same
# instance serves every `model.predict` call in the processing loop below.
# NOTE(review): construction presumably loads model weights and may be slow
# or download files on first use — confirm against the lang_sam documentation.
model = LangSAM()

In [None]:
# Input: directory holding the source images to segment.
image_dir = Path("data/barrelddt1")

# Output: everything goes under results/<input-dir-name>-reconstr/.
reconstr_dir = Path(f"results/{image_dir.name}-reconstr")
mask_dir = reconstr_dir / "masks"                 # one PNG per detected mask
maskcomp_dir = reconstr_dir / "image_with_masks"  # image + masks overview figures
# Create both output directories up front so later cells can write freely.
for out_dir in (mask_dir, maskcomp_dir):
    out_dir.mkdir(parents=True, exist_ok=True)

# Text prompt handed to the model for every image.
text_prompt = "underwater barrel"

# Sorted so images are processed in a reproducible order.
imgpaths = sorted(image_dir.glob("*.jpg"))
if not imgpaths:
    # A wrong path would otherwise make the processing loop a silent no-op.
    print(f"Warning: no '*.jpg' images found in {image_dir.resolve()}")

In [None]:
# Make sure the overview-figure directory exists once, not on every iteration.
maskcomp_dir.mkdir(parents=True, exist_ok=True)

for imgpath in imgpaths:
    print(f"Processing image: {imgpath}")
    # The context manager releases the file handle once the pixels are decoded.
    with Image.open(imgpath) as img_file:
        image_pil = img_file.convert("RGB")

    masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)

    if len(masks) == 0:
        # Nothing detected: show the plain image so the miss can be inspected.
        print(f"No objects of the '{text_prompt}' prompt detected in the image.")
        fig, ax = plt.subplots(1, 1)
        ax.imshow(image_pil)
        plt.show()
        plt.close(fig)  # avoid piling up open figures across iterations
        continue

    # Convert masks to numpy arrays
    masks_np = [mask.squeeze().cpu().numpy() for mask in masks]

    # Display the original image and masks side by side, and save that figure
    bbox_mask_path = maskcomp_dir / f"{imgpath.stem}_img_with_mask.png"
    display_image_with_masks(image_pil, masks_np, boxes, logits, figwidth=13, savefig=bbox_mask_path, all_masks=True)

    # Save the masks, numbered from 1. A dedicated index name avoids shadowing
    # any outer loop variable.
    for mask_idx, mask_np in enumerate(masks_np, start=1):
        mask_path = mask_dir / f"{imgpath.stem}_mask_{mask_idx}.png"
        save_mask(mask_np, mask_path)

    # Optional diagnostics while debugging:
    # display_image_with_boxes(image_pil, boxes, logits)
    # print_bounding_boxes(boxes)
    # print_detected_phrases(phrases)
    # print_logits(logits)