In [2]:
import imageio as io
import matplotlib.pyplot as plt
import numpy as np
import os
import tqdm

from PIL import Image

import torch

from transformers import OwlViTProcessor, OwlViTForObjectDetection
from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator

# Run on the GPU whenever CUDA is usable; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Used device: {DEVICE}")

Used device: cuda


In [None]:
class PoseApproximator:
    """Locate a query object in an image and segment it.

    Pipeline implemented so far:
      (1) image-guided object detection with OWL-ViT,
      (2) box-prompted segmentation with Segment Anything (SAM).
    Depth-based 3D position estimation (3) and temporal tracking (4) are
    still to be implemented.
    """

    def __init__(self, device):
        """Load the detection and segmentation models onto ``device``.

        Args:
            device: torch device (or device string such as ``"cuda"``) that
                both models and the detector inputs are placed on.
        """
        self.device = device

        self.od_processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")
        self.od_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(self.device)

        SAM_MODEL_TYPE = "vit_h" #vit_l, vit_b
        SAM_CHECKPOINT = "sam_checkpoint/sam_vit_h_4b8939.pth"
        sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT)
        # Keep SAM on the same device as the detector instead of leaving it
        # on whatever device the checkpoint was loaded to.
        sam.to(device=self.device)
        self.segm_model = SamPredictor(sam)

    def __call__(self, image, query_image, depth, od_score_thresh=0.1):
        """Detect the query object in ``image`` and segment the best detection.

        Args:
            image: target image. It is handed to both the OWL-ViT processor
                and ``SamPredictor.set_image``; presumably an HxWxC uint8
                array -- TODO confirm against callers.
            query_image: example image of the object to look for.
            depth: depth data for ``image``. Currently unused (reserved for
                the 3D position estimate).
            od_score_thresh: minimum detection score the best detection must
                reach for its box to be used as the SAM prompt.

        Returns:
            ``((od_scores, od_labels, od_bboxes), (segm_scores, segm_masks))``.
            The detection arrays cover all predicted boxes (unfiltered);
            ``od_bboxes`` are the detector's boxes in the same
            (cx, cy, w, h) layout that ``plot_boxes`` consumes. The
            segmentation arrays hold only the highest-scoring SAM mask.
            If no detection reaches ``od_score_thresh``, SAM is called
            without a box prompt, so the mask is not tied to a detection.
        """
        # Inputs must live on the same device as this instance's models,
        # not on the module-level default device.
        od_inputs = self.od_processor(
            images=image, query_images=query_image, return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            # (1) OD
            od_outputs = self.od_model.image_guided_detection(**od_inputs)
            od_logits = torch.max(od_outputs["logits"][0], dim=-1)

            od_scores = torch.sigmoid(od_logits.values).cpu().detach().numpy()
            od_labels = od_logits.indices.cpu().detach().numpy()
            od_bboxes = od_outputs["target_pred_boxes"][0].cpu().detach().numpy()

            # SAM expects the box prompt in pixel XYXY coordinates, whereas
            # the detector boxes are normalized (cx, cy, w, h).
            height, width = np.asarray(image).shape[:2]
            best_bbox = self._select_best_bbox(
                od_scores, od_bboxes, width, height, od_score_thresh
            )

            # (2) Segmentation
            self.segm_model.set_image(image)
            segm_masks, segm_scores, _ = self.segm_model.predict(
                box=best_bbox,
                multimask_output=True,
            )
            # Keep only the best mask, preserving the leading axis.
            max_idx = np.argmax(segm_scores)
            segm_masks = segm_masks[max_idx:max_idx + 1, ...]
            segm_scores = segm_scores[max_idx:max_idx + 1]

            # (3) Estimate 3D position (world) using the depth
            # (segment out the depth -> depth point estimate -> deproject bbox/centroid from 2D img space to 3D world Euclidean space -> (Xw, Yw, Zw))

            # (4) Tracker (track pose temporally)

        return (od_scores, od_labels, od_bboxes), (segm_scores, segm_masks)

    @staticmethod
    def _cxcywh_to_pixel_xyxy(box, width, height):
        """Convert a normalized (cx, cy, w, h) box to pixel (x0, y0, x1, y1).

        The result is clipped to the image rectangle ``[0, width] x [0, height]``.
        """
        cx, cy, w, h = np.asarray(box, dtype=np.float64)
        scale = np.array([width, height, width, height], dtype=np.float64)
        xyxy = np.array([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2]) * scale
        return np.clip(xyxy, 0.0, scale)

    @staticmethod
    def _select_best_bbox(scores, boxes, width, height, score_thresh):
        """Return the top-scoring box in pixel XYXY, or ``None``.

        ``None`` is returned when there are no detections or when the best
        score is below ``score_thresh``.
        """
        scores = np.asarray(scores)
        if scores.size == 0:
            return None
        best_idx = int(np.argmax(scores))
        if scores[best_idx] < score_thresh:
            return None
        return PoseApproximator._cxcywh_to_pixel_xyxy(boxes[best_idx], width, height)

    def plot_boxes(self, img, scores, boxes):
        """Draw each (cx, cy, w, h) box and its score over ``img``.

        The image is shown with a unit extent, so ``boxes`` are interpreted
        in normalized [0, 1] image coordinates.
        """
        fig, ax = plt.subplots(1, 1, figsize=(8, 8))
        ax.imshow(img, extent=(0, 1, 1, 0))
        ax.set_axis_off()

        for score, box in zip(scores, boxes):
            cx, cy, w, h = box
            # Closed rectangle outline: five points, the last repeats the first.
            ax.plot(
                [cx-w/2, cx+w/2, cx+w/2, cx-w/2, cx-w/2],
                [cy-h/2, cy-h/2, cy+h/2, cy+h/2, cy-h/2],
                "r"
            )

            ax.text(
                cx - w / 2,
                cy + h / 2 + 0.015,
                f"{score:1.2f}",
                ha="left",
                va="top",
                color="red",
                bbox={
                    "facecolor": "white",
                    "edgecolor": "red",
                    "boxstyle": "square,pad=.3"
                }
            )
        plt.show()
        # Release the figure rather than only clearing it.
        plt.close(fig)

    @staticmethod
    def _show_mask(mask, ax, random_color=False):
        """Overlay ``mask`` on ``ax`` as a semi-transparent RGBA image."""
        if random_color:
            color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        else:
            color = np.array([1.0, 0.0, 0.0, 0.6])
        h, w = mask.shape[-2:]
        mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
        ax.imshow(mask_image)

    def plot_masks(self, img, scores, masks):
        """Show one figure per mask, overlaid on ``img`` and titled with its score.

        Nothing is drawn when ``masks`` is empty.
        """
        for i, (mask, score) in enumerate(zip(masks, scores)):
            fig = plt.figure(figsize=(10,10))
            plt.imshow(img)
            self._show_mask(mask, plt.gca())
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
            plt.axis('off')
            plt.show()
            plt.close(fig)