In [1]:
import os
import numpy as np
import torch
from PIL import Image
import cv2
import open3d as o3d
import torch.nn.functional as F
import supervision as sv
from transformers import CLIPProcessor, CLIPModel
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util import box_ops
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate
from Grounded_Segment_Anything.segment_anything.segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator


CONFIG_PATH = "./Grounded_Segment_Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
CHECKPOINT_PATH = "./models/groundingdino_swint_ogc.pth"
SAM_CHECKPOINT = "./models/sam_vit_h_4b8939.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEXT_PROMPT = "ear"
BOX_THRESHOLD = 0.3
TEXT_THRESHOLD = 0.25
VIEWS_DIR = "./render_views/bunny/"
OUTPUT_DIR = "./new/newnewnew/"
os.makedirs(OUTPUT_DIR, exist_ok=True)



In [5]:
groundingdino_model = load_model(CONFIG_PATH, CHECKPOINT_PATH).to(DEVICE)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(DEVICE)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
sam_predictor = SamPredictor(sam)

def rerank_boxes_with_clip(image_source, boxes, text_prompt, top_k=1):
    """
    Refine box selection using CLIP image-text similarity.
    """
    H, W, _ = image_source.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    boxes_xyxy = boxes_xyxy.int()

    cropped_images = []
    for box in boxes_xyxy.view(-1, 4):
        x1, y1, x2, y2 = box.tolist()
        crop = image_source[y1:y2, x1:x2]
        if crop.size == 0:  # avoid invalid crops
            continue
        crop_pil = Image.fromarray(crop).convert("RGB")
        cropped_images.append(crop_pil)

    if not cropped_images:
        return torch.empty((0, 4), dtype=boxes.dtype).to(boxes.device)

    inputs = clip_processor(
        text=[text_prompt] * len(cropped_images),
        images=cropped_images,
        return_tensors="pt",
        padding=True
    ).to(DEVICE)

    with torch.no_grad():
        outputs = clip_model(**inputs)
        logits_per_image = outputs.logits_per_image  # shape [N, 1]
        probs = F.softmax(logits_per_image.squeeze(), dim=0)

    # Select top-k indices and remove extra dimensions
    top_indices = probs.topk(top_k).indices
    filtered_boxes = boxes[top_indices]
    if filtered_boxes.ndim == 1:
        filtered_boxes = filtered_boxes.unsqueeze(0)

    return filtered_boxes


def segment(image, sam_model, boxes):
  sam_model.set_image(image)
  H, W, _ = image.shape
  boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

  transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(DEVICE), image.shape[:2])
  masks, _, _ = sam_model.predict_torch(
      point_coords = None,
      point_labels = None,
      boxes = transformed_boxes,
      multimask_output = False,
      )
  return masks.cpu()
  

def draw_mask(mask, image, random_color=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

def extract_segmented_object(image, mask):
    """Extracts only the masked object from the image (black background)."""
    binary_mask = (mask > 0).astype(np.uint8)

    # Apply the mask to each channel
    segmented = cv2.bitwise_and(image, image, mask=binary_mask)

    return segmented

def box_to_pixel(box, image_shape):
    h, w = image_shape[:2]
    cx, cy, bw, bh = box
    x1 = int((cx - bw / 2) * w)
    y1 = int((cy - bh / 2) * h)
    x2 = int((cx + bw / 2) * w)
    y2 = int((cy + bh / 2) * h)
    return np.array([x1, y1, x2, y2])

def get_masks_only(boxes, image_source, image_rgb):
    box = boxes[0].cpu().numpy()
    box_pixel = box_to_pixel(box, image_source.shape)

    sam_predictor.set_image(image_rgb)
    masks, scores, _ = sam_predictor.predict(
        box=box_pixel,
        multimask_output=True
    )

    best_mask = masks[np.argmax(scores)]

    return (best_mask.astype(np.uint8)) * 255

def auto_mask(image_source, base_name):
    mask_generator = SamAutomaticMaskGenerator(sam)
    masks = mask_generator.generate(image_source)

    # Pick the largest mask (or you could use CLIP scoring here)
    best_mask = sorted(masks, key=lambda x: x['area'], reverse=True)[0]['segmentation']
    auto_mask_render = extract_segmented_object(image_source, best_mask)

    cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_fallback_overlay.png"), auto_mask_render)
    return auto_mask_render

def segment_and_save_views():
    """Segment all views and save results as images."""
    view_files = sorted([f for f in os.listdir(VIEWS_DIR) if f.endswith(('.png', '.jpg'))])
    
    if not view_files:
        print(f"No images found in {VIEWS_DIR}")
        return
    
    print(f"Found {len(view_files)} views to process")
    
    for view_file in view_files:
        print(f"Processing {view_file}...")
        view_path = os.path.join(VIEWS_DIR, view_file)
        
        try:
            # Load and prepare image
            image_source, image = load_image(view_path)
            image_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)

            # Get boxes from GroundingDINO
            boxes, logits, _ = predict(
                model=groundingdino_model,
                image=image,
                caption=TEXT_PROMPT,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD,
                device=DEVICE
            )
            boxes = rerank_boxes_with_clip(image_source, boxes, TEXT_PROMPT, top_k=1)

            if len(boxes) == 0:
                print(f"No objects detected in {view_file}")
                continue
            

            # Save results
            base_name = os.path.splitext(view_file)[0]

            # Save annotation with boxes
            annotated = annotate(
                image_source=image_source,
                boxes=boxes,
                logits=logits,
                phrases=[TEXT_PROMPT]*len(boxes)
            )
            annotated = annotated[...,::-1]

            segmented_frame_masks = segment(image_source, sam_predictor, boxes=boxes)
            fused_mask = segmented_frame_masks[0].float().mean(dim=0).cpu().numpy()
            binary_mask = (fused_mask > 0.5).astype(np.uint8) * 255

            annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated)
            masked = get_masks_only(boxes, image_source, image_rgb)

            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_annotated.png"), annotated)
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_highlighted.png"), annotated_frame_with_mask)
            highlighted_on_original = extract_segmented_object(image_source, binary_mask)#masked
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_masked_overlay.png"), highlighted_on_original)

            auto_mask(image_source, base_name)

            print(f"Saved results for {view_file}")
            
        except Exception as e:
            print(f"Error processing {view_file}: {str(e)}")

if __name__ == "__main__":
    segment_and_save_views()

    print("Segmentation complete! Check the output directory for results.")

final text_encoder_type: bert-base-uncased
Found 5 views to process
Processing view_0.png...




Saved results for view_0.png
Processing view_1.png...
Error processing view_1.png: xyxy must be a 2D np.ndarray with shape (_, 4), but got shape (2, 1, 4)
Processing view_2.png...
Error processing view_2.png: xyxy must be a 2D np.ndarray with shape (_, 4), but got shape (2, 1, 4)
Processing view_3.png...




KeyboardInterrupt: 