In [None]:
import os
import numpy as np
import torch
from PIL import Image
import cv2
import open3d as o3d
import supervision as sv
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util import box_ops
from Grounded_Segment_Anything.GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate
from Grounded_Segment_Anything.segment_anything.segment_anything import sam_model_registry, SamPredictor
from transformers import BlipProcessor, BlipForConditionalGeneration

CONFIG_PATH = "./Grounded_Segment_Anything/GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
CHECKPOINT_PATH = "./models/groundingdino_swint_ogc.pth"
SAM_CHECKPOINT = "./models/sam_vit_h_4b8939.pth"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEXT_PROMPT = "bunny ear"
BOX_THRESHOLD = 0.3
TEXT_THRESHOLD = 0.25
VIEWS_DIR = "./render_views/bunny"
OUTPUT_DIR = "./new/new_bunny"
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
groundingdino_model = load_model(CONFIG_PATH, CHECKPOINT_PATH).to(DEVICE)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
sam_predictor = SamPredictor(sam)
view_list = []

def segment(image, sam_model, boxes):
  sam_model.set_image(image)
  H, W, _ = image.shape
  boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

  transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(DEVICE), image.shape[:2])
  masks, _, _ = sam_model.predict_torch(
      point_coords = None,
      point_labels = None,
      boxes = transformed_boxes,
      multimask_output = False,
      )
  return masks.cpu()
  

def draw_mask(mask, image, random_color=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

def extract_segmented_object(image, mask):
    """Extracts only the masked object from the image (black background)."""
    # Ensure mask is binary (0 or 1)
    binary_mask = (mask > 0).astype(np.uint8)

    # Apply the mask to each channel
    segmented = cv2.bitwise_and(image, image, mask=binary_mask)

    return segmented

def box_to_pixel(box, image_shape):
    h, w = image_shape[:2]
    cx, cy, bw, bh = box
    x1 = int((cx - bw / 2) * w)
    y1 = int((cy - bh / 2) * h)
    x2 = int((cx + bw / 2) * w)
    y2 = int((cy + bh / 2) * h)
    return np.array([x1, y1, x2, y2])

def get_masks_only(boxes, image_source, image_rgb):
    box = boxes[0].cpu().numpy()
    box_pixel = box_to_pixel(box, image_source.shape)

    sam_predictor.set_image(image_rgb)
    masks, scores, _ = sam_predictor.predict(
        box=box_pixel,
        multimask_output=True
    )

    best_mask = masks[np.argmax(scores)]

    return (best_mask.astype(np.uint8)) * 255


def segment_and_save_views():
    """Segment all views and save results as images."""
    view_files = sorted([f for f in os.listdir(VIEWS_DIR) if f.endswith(('.png', '.jpg'))])
    
    if not view_files:
        print(f"No images found in {VIEWS_DIR}")
        return
    
    print(f"Found {len(view_files)} views to process")
    
    for view_file in view_files:
        print(f"Processing {view_file}...")
        view_path = os.path.join(VIEWS_DIR, view_file)
        
        try:
            # Load and prepare image
            image_source, image = load_image(view_path)
            image_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)

            # Get boxes from GroundingDINO
            boxes, logits, _ = predict(
                model=groundingdino_model,
                image=image,
                caption=TEXT_PROMPT,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD,
                device=DEVICE
            )
            
            if len(boxes) == 0:
                print(f"No objects detected in {view_file}")
                continue
            

            # Save results
            base_name = os.path.splitext(view_file)[0]

            # Save annotation with boxes
            annotated = annotate(
                image_source=image_source,
                boxes=boxes,
                logits=logits,
                phrases=[TEXT_PROMPT]*len(boxes)
            )

            segmented_frame_masks = segment(image_source, sam_predictor, boxes=boxes)
            annotated_frame_with_mask = draw_mask(segmented_frame_masks[0][0], annotated)
            masked = get_masks_only(boxes, image_source, image_rgb)

            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_annotated.png"), annotated)
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_highlighted.png"), annotated_frame_with_mask)
            highlighted_on_original = extract_segmented_object(image_source, masked)
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_masked_overlay.png"), highlighted_on_original)

            
            view_list.append(highlighted_on_original)

            print(f"Saved results for {view_file}")
            
        except Exception as e:
            print(f"Error processing {view_file}: {str(e)}")

    return view_list
if __name__ == "__main__":
    segment_and_save_views()

    print("Segmentation complete! Check the output directory for results.")

final text_encoder_type: bert-base-uncased
Found 5 views to process
Processing view_0.png...




Saved results for view_0.png
Processing view_1.png...




Saved results for view_1.png
Processing view_2.png...




Saved results for view_2.png
Processing view_3.png...




Saved results for view_3.png
Processing view_4.png...




Saved results for view_4.png
Segmentation complete! Check the output directory for results.


## 3d fusion

In [None]:
# Initialize models
groundingdino_model = load_model(CONFIG_PATH, CHECKPOINT_PATH).to(DEVICE)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
sam_predictor = SamPredictor(sam)

def segment_views():
    """Segment all views using GroundingDINO + SAM"""
    masks = []
    view_files = sorted([f for f in os.listdir(VIEWS_DIR) if f.endswith('.png')])
    
    for view_file in view_files:
        # Load image
        view_path = os.path.join(VIEWS_DIR, view_file)
        image_source, image = load_image(view_path)
        
        # Get boxes from GroundingDINO
        boxes, logits, _ = predict(
            model=groundingdino_model,
            image=image,
            caption=TEXT_PROMPT,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD
        )
        
        if len(boxes) == 0:
            print(f"No objects detected in {view_file}")
            masks.append(np.zeros(image_source.shape[:2], dtype=np.uint8))
            continue
        
        # Get masks from SAM
        sam_predictor.set_image(image_source)
        box = boxes[0].cpu().numpy()  # Use the highest-scoring box
        mask, _, _ = sam_predictor.predict(box=box)
        mask = mask[0].astype(np.uint8) * 255
        
        masks.append(mask)
        cv2.imwrite(os.path.join(OUTPUT_DIR, f"mask_{view_file}"), mask)
        
        # Save visualization (optional)
        annotated = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=[TEXT_PROMPT]*len(boxes))
        cv2.imwrite(os.path.join(OUTPUT_DIR, f"annotated_{view_file}"), annotated)
    
    return masks

def reconstruct_3d(masks, camera_poses, K, voxel_size=0.01):
    """Reconstruct 3D from masks using TSDF fusion"""
    volume = o3d.pipelines.integration.ScalableTSDFVolume(
        voxel_length=voxel_size,
        sdf_trunc=3 * voxel_size,
        color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
    )
    
    for i, (mask, pose) in enumerate(zip(masks, camera_poses)):
        # Create depth map (assuming mask bounds the object)
        depth = np.where(mask > 0, 0.5, 0)  # Replace with actual depth if available
        
        # Convert to Open3D types
        color = o3d.io.read_image(os.path.join(VIEWS_DIR, f"view_{i}.png"))
        depth = o3d.geometry.Image(depth.astype(np.float32))
        
        # Integrate into TSDF volume
        intrinsic = o3d.camera.PinholeCameraIntrinsic(
            width=masks[0].shape[1],
            height=masks[0].shape[0],
            fx=K[0,0], fy=K[1,1],
            cx=K[0,2], cy=K[1,2],
        )
        extrinsic = np.linalg.inv(pose)  # Camera-to-world transform
        volume.integrate(
            o3d.geometry.RGBDImage.create_from_color_and_depth(
                color, depth, depth_scale=1.0, depth_trunc=1.0, convert_rgb_to_intensity=False
            ),
            intrinsic,
            extrinsic,
        )
    
    return volume.extract_triangle_mesh()

# Main pipeline
if __name__ == "__main__":
    # 1. Segment all views
    masks = segment_views()
    
    # 2. Load/generate camera poses (replace with your actual poses)
    num_views = len(masks)
    camera_poses = []
    for i in range(num_views):
        angle = 2 * np.pi * i / num_views
        pose = np.eye(4)
        pose[:3, 3] = [np.cos(angle), np.sin(angle), 0.5]  # Circular path
        camera_poses.append(pose)
    
    # 3. Create intrinsic matrix (replace with your camera parameters)
    K = np.array([
        [500, 0, masks[0].shape[1]/2],
        [0, 500, masks[0].shape[0]/2],
        [0, 0, 1]
    ])
    
    # 4. Reconstruct 3D
    mesh = reconstruct_3d(masks, camera_poses, K)
    
    # 5. Save and visualize
    o3d.io.write_triangle_mesh("reconstructed.ply", mesh)
    o3d.visualization.draw_geometries([mesh], mesh_show_wireframe=True)

In [12]:
import numpy as np

# Configuration
IMAGE_WIDTH = 800  # Adjust to your image dimensions
IMAGE_HEIGHT = 600
FOCAL_LENGTH = 1000  # In pixels
DISTANCE_FROM_ORIGIN = 2.0  # Distance from object center
ANGLES = [0, 45, 90, 135, 180]  # Your specified angles

def generate_camera_parameters():
    """Generate camera intrinsics and extrinsics for each view"""
    cameras = {}
    
    # Intrinsic matrix (same for all cameras in this setup)
    K = np.array([
        [FOCAL_LENGTH, 0, IMAGE_WIDTH/2],
        [0, FOCAL_LENGTH, IMAGE_HEIGHT/2],
        [0, 0, 1]
    ])
    
    # Generate extrinsic parameters for each angle
    for angle in ANGLES:
        # Convert angle to radians
        theta = np.radians(angle)
        
        # Camera position (circular path around object)
        cam_pos = np.array([
            DISTANCE_FROM_ORIGIN * np.sin(theta),
            0,
            DISTANCE_FROM_ORIGIN * np.cos(theta)
        ])
        
        # Look at origin (object center)
        look_at = np.array([0, 0, 0])
        
        # Up vector (assuming Y is up)
        up = np.array([0, 1, 0])
        
        # Create view matrix (camera extrinsics)
        z_axis = (look_at - cam_pos)
        z_axis /= np.linalg.norm(z_axis)
        x_axis = np.cross(up, z_axis)
        x_axis /= np.linalg.norm(x_axis)
        y_axis = np.cross(z_axis, x_axis)
        
        R = np.vstack([x_axis, y_axis, z_axis]).T
        t = -R @ cam_pos
        
        # Create 4x4 transformation matrix
        extrinsic = np.eye(4)
        extrinsic[:3, :3] = R
        extrinsic[:3, 3] = t
        
        cameras[f"angle_{angle}"] = {
            "intrinsic": K,
            "extrinsic": extrinsic,
            "position": cam_pos,
            "rotation": R
        }
    
    return cameras

In [None]:
def reconstruct_3d_with_known_cameras(views_dir, masks_dir, output_path="reconstruction.ply"):
    cameras = generate_camera_parameters()
    point_cloud = o3d.geometry.PointCloud()
    
    mask_files = sorted([f for f in os.listdir(masks_dir) if f.endswith('_mask.png')])
    
    for mask_file, angle in zip(mask_files, ANGLES):
        view_file = mask_file.replace('_mask.png', '.png')
        view_path = os.path.join(views_dir, view_file)
        mask_path = os.path.join(masks_dir, mask_file)
        
        # Load images
        view = cv2.imread(view_path)
        mask = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        
        # Get camera parameters for this angle
        cam_data = cameras[f"angle_{angle}"]
        K = cam_data["intrinsic"]
        extrinsic = cam_data["extrinsic"]
        
        # Create depth map (simplified - assumes flat depth)
        # In a real scenario, you'd need proper depth estimation
        depth_map = np.where(mask > 128, DISTANCE_FROM_ORIGIN, 0).astype(np.float32)
        
        # Back-project to 3D using camera parameters
        h, w = depth_map.shape
        u, v = np.meshgrid(np.arange(w), np.arange(h))
        uv_homog = np.stack([u.flatten(), v.flatten(), np.ones_like(u.flatten())], axis=1)
        
        # Transform to camera coordinates
        points_cam = (np.linalg.inv(K) @ uv_homog.T).T * depth_map.flatten()[:, None]
        
        # Transform to world coordinates
        points_world = (extrinsic[:3, :3] @ points_cam.T + extrinsic[:3, [3]]).T
        
        # Filter valid points
        valid_points = points_world[depth_map.flatten() > 0]
        
        if len(valid_points) > 0:
            point_cloud.points.extend(o3d.utility.Vector3dVector(valid_points))
    
    # Post-processing
    point_cloud, _ = point_cloud.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0)
    
    # Surface reconstruction
    mesh, _ = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(point_cloud)
    
    # Save result
    o3d.io.write_triangle_mesh(output_path, mesh)
    print(f"Saved 3D reconstruction to {output_path}")

## blip

In [None]:
# Initialize models
print("Loading models...")
groundingdino_model = load_model(CONFIG_PATH, CHECKPOINT_PATH).to(DEVICE)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
sam_predictor = SamPredictor(sam)

print("Loading BLIP-2 model...")
blip_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(DEVICE)

def generate_prompt_with_blip2(image_pil):
    inputs = blip_processor(image_pil, return_tensors="pt").to(DEVICE)
    out = blip_model.generate(**inputs, max_new_tokens=20)
    caption = blip_processor.decode(out[0], skip_special_tokens=True)
    return caption

def segment(image, sam_model, boxes):
    sam_model.set_image(image)
    H, W, _ = image.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])
    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(DEVICE), image.shape[:2])
    masks, _, _ = sam_model.predict_torch(
        point_coords = None,
        point_labels = None,
        boxes = transformed_boxes,
        multimask_output = False,
    )
    return masks.cpu()

def draw_mask(mask, image, random_color=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image.cpu().numpy() * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

def get_camera_pose(angle):
    rad = np.deg2rad(angle)
    cam_pos = np.array([np.sin(rad), 0, np.cos(rad)])  # circular orbit
    look_at = np.array([0, 0, 0])
    forward = look_at - cam_pos
    forward /= np.linalg.norm(forward)
    right = np.cross(np.array([0, 1, 0]), forward)
    up = np.cross(forward, right)
    rot = np.stack([right, up, forward], axis=1)
    return rot, cam_pos

def backproject(mask, cam_rot, cam_pos, voxel_grid):
    indices = torch.nonzero(mask[0], as_tuple=False)
    for y, x in indices:
        direction = get_ray_direction(x.item(), y.item(), cam_rot)
        for depth in torch.linspace(0.5, 2.0, steps=64):
            pt = cam_pos + depth.item() * direction
            grid_pt = torch.from_numpy((pt + 1.0) * (VOXEL_GRID_SIZE / 2)).long()
            if ((grid_pt >= 0) & (grid_pt < VOXEL_GRID_SIZE)).all():
                voxel_grid[grid_pt[0], grid_pt[1], grid_pt[2]] = 1

def get_ray_direction(x, y, cam_rot):
    ndc_x = (x / IMAGE_SIZE - 0.5) * 2 * np.tan(np.deg2rad(FOV/2))
    ndc_y = (y / IMAGE_SIZE - 0.5) * 2 * np.tan(np.deg2rad(FOV/2))
    ray_cam = np.array([ndc_x, -ndc_y, -1.0])
    ray_cam /= np.linalg.norm(ray_cam)
    ray_world = cam_rot @ ray_cam
    return ray_world

def segment_and_fuse():
    view_files = sorted([f for f in os.listdir(VIEWS_DIR) if f.endswith(('.png', '.jpg'))])
    if not view_files:
        print(f"No images found in {VIEWS_DIR}")
        return

    print(f"Found {len(view_files)} views to process")

    for i, view_file in enumerate(view_files):
        print(f"Processing {view_file}...")
        view_path = os.path.join(VIEWS_DIR, view_file)
        image_source, image = load_image(view_path)
        image_source_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)
        image_pil = Image.fromarray(image_source_rgb)

        boxes, logits, _ = predict(
            model=groundingdino_model,
            image=image,
            caption=None,
            box_threshold=BOX_THRESHOLD,
            text_threshold=TEXT_THRESHOLD,
            device=DEVICE
        )

        if len(boxes) == 0:
            print(f"No objects detected in {view_file}")
            continue

        base_name = os.path.splitext(view_file)[0]
        annotated = annotate(
            image_source=image_source,
            boxes=boxes,
            logits=logits,
            phrases=[TEXT_PROMPT]*len(boxes)
        )

        segmented_masks = segment(image_source, sam_predictor, boxes=boxes)
        annotated_frame_with_mask = draw_mask(segmented_masks[0][0], annotated)
        cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_mask.png"), annotated_frame_with_mask)
        
        
if __name__ == "__main__":
    segment_and_fuse()
    print("Segmentation and volumetric fusion complete!")
    


Loading models...
final text_encoder_type: bert-base-uncased


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Loading BLIP-2 model...
Found 5 views to process
Processing view_0.png...


AttributeError: 'NoneType' object has no attribute 'lower'

In [2]:
groundingdino_model = load_model(CONFIG_PATH, CHECKPOINT_PATH).to(DEVICE)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
sam_predictor = SamPredictor(sam)

def box_to_pixel(box, image_shape):
    h, w = image_shape[:2]
    return np.array([
        int(box[0] * w),
        int(box[1] * h),
        int(box[2] * w),
        int(box[3] * h)
    ])

def segment_and_save_views():
    view_files = sorted([f for f in os.listdir(VIEWS_DIR) if f.endswith(('.png', '.jpg'))])
    
    if not view_files:
        print(f"No images found in {VIEWS_DIR}")
        return
    
    print(f"Found {len(view_files)} views to process")

    for view_file in view_files:
        print(f"\nProcessing {view_file}...")
        view_path = os.path.join(VIEWS_DIR, view_file)

        try:
            # Load and prepare image
            image_source, image = load_image(view_path)
            image_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)
            h, w = image_source.shape[:2]

            # Get boxes from GroundingDINO
            boxes, logits, _ = predict(
                model=groundingdino_model,
                image=image,
                caption=TEXT_PROMPT,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD,
                device=DEVICE
            )

            if len(boxes) == 0:
                print(f"No objects detected in {view_file}")
                continue

            # Only use the first detected box
            box = boxes[0].cpu().numpy()
            box_pixel = box_to_pixel(box, image_source.shape)

            # Segment using box-prompted SAM
            sam_predictor.set_image(image_rgb)
            masks, scores, _ = sam_predictor.predict(
                box=box_pixel,
                multimask_output=True
            )

            best_mask = masks[np.argmax(scores)]

            # Convert to uint8 for OpenCV use
            mask_uint8 = best_mask.astype(np.uint8) * 255

            # Apply visual highlighting
            overlay = image_source.copy()
            overlay[best_mask > 0] = [0, 255, 0]
            highlighted = cv2.addWeighted(overlay, 0.5, image_source, 0.5, 0)

            # Draw the bounding box
            cv2.rectangle(highlighted,
                          (box_pixel[0], box_pixel[1]),
                          (box_pixel[2], box_pixel[3]),
                          (0, 0, 255), 2)

            # Save files
            base_name = os.path.splitext(view_file)[0]
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_mask.png"), mask_uint8)
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_highlighted.png"), highlighted)

            # Annotated version (from GroundingDINO)
            annotated = annotate(
                image_source=image_source,
                boxes=boxes,
                logits=logits,
                phrases=[TEXT_PROMPT] * len(boxes)
            )
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_annotated.png"), annotated)

            print(f"Successfully processed {view_file}")

        except Exception as e:
            print(f"Error processing {view_file}: {str(e)}")

if __name__ == "__main__":
    segment_and_save_views()
    print("Segmentation complete! Check the output directory for results.")




final text_encoder_type: bert-base-uncased
Found 5 views to process

Processing view_0.png...




Successfully processed view_0.png

Processing view_1.png...




Successfully processed view_1.png

Processing view_2.png...




Successfully processed view_2.png

Processing view_3.png...




Successfully processed view_3.png

Processing view_4.png...
Successfully processed view_4.png
Segmentation complete! Check the output directory for results.




## llm integration

In [3]:
import os
import numpy as np
import torch
from PIL import Image
import cv2
from GroundingDINO.groundingdino.util import box_ops
from GroundingDINO.groundingdino.util.inference import load_model, load_image, predict, annotate
from segment_anything.segment_anything import sam_model_registry, SamPredictor
from transformers import AutoProcessor, LlavaForConditionalGeneration
from typing import List, Optional

# Configuration
CONFIG_PATH = "./GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py"
CHECKPOINT_PATH = "./groundingdino_swint_ogc.pth"
SAM_CHECKPOINT = "../models/sam_vit_h_4b8939.pth"
LLAVA_MODEL_NAME = "llava-hf/llava-1.5-7b-hf"  # Open-source VLM
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TEXT_PROMPT = "tail"  # Default prompt
BOX_THRESHOLD = 0.3
TEXT_THRESHOLD = 0.25
VIEWS_DIR = "../view/"
OUTPUT_DIR = "../segmented_views/"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Initialize models
print("Loading models...")
groundingdino_model = load_model(CONFIG_PATH, CHECKPOINT_PATH).to(DEVICE)
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(DEVICE)
sam_predictor = SamPredictor(sam)

# Initialize LLaVA (vision-language model)
print("Loading LLaVA model...")
llava_processor = AutoProcessor.from_pretrained(LLAVA_MODEL_NAME)
llava_model = LlavaForConditionalGeneration.from_pretrained(
    LLAVA_MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    low_cpu_mem_usage=True
).to(DEVICE)

class PartSegmenter:
    def __init__(self):
        self.device = DEVICE
        
    def refine_prompt_with_llm(self, image: np.ndarray, initial_prompt: str) -> str:
        """Use LLaVA to refine the text prompt based on image content."""
        try:
            # Convert numpy array to PIL Image
            image_pil = Image.fromarray(image)
            
            # Prepare prompt for LLaVA - using the correct format
            prompt = f"USER: <image>\nWhat is the most precise way to describe the {initial_prompt} in this image for segmentation purposes? Just provide the description without additional text.\nASSISTANT:"
            
            # Process inputs with correct text format
            inputs = llava_processor(
                text=prompt,  # Single string input
                images=image_pil,
                return_tensors="pt",
                padding=True
            ).to(DEVICE, torch.float16)
            
            # Generate response
            generate_ids = llava_model.generate(
                **inputs,
                max_new_tokens=50,
                do_sample=True
            )
            
            # Decode response
            response = llava_processor.decode(
                generate_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            
            # Extract just the assistant's response
            if "ASSISTANT:" in response:
                refined_prompt = response.split("ASSISTANT:")[-1].strip()
            else:
                refined_prompt = initial_prompt
                
            # Fallback if response is not useful
            if len(refined_prompt) < 3 or initial_prompt.lower() not in refined_prompt.lower():
                return initial_prompt
                
            return refined_prompt
            
        except Exception as e:
            print(f"LLaVA prompt refinement failed: {str(e)}")
            return initial_prompt
    
    def validate_mask_with_llm(self, image: np.ndarray, mask: np.ndarray, prompt: str) -> bool:
        """Use LLaVA to validate if the masked region matches the prompt."""
        try:
            # Create masked image (ensure it's the right type)
            masked_image = (image * (mask[:, :, np.newaxis] if mask.ndim == 2 else mask)).astype(np.uint8)
            masked_image_pil = Image.fromarray(masked_image)
            
            # Prepare prompt for LLaVA - single string format
            validation_prompt = (
                f"USER: <image>\nDoes this image show a {prompt}? "
                f"Answer with just 'yes' or 'no'.\nASSISTANT:"
            )
            
            # Process inputs with correct text format
            inputs = llava_processor(
                text=validation_prompt,  # Single string input
                images=masked_image_pil,
                return_tensors="pt",
                padding=True
            ).to(DEVICE, torch.float16)
            
            # Generate response
            generate_ids = llava_model.generate(
                **inputs,
                max_new_tokens=2,  # We just want yes/no
                do_sample=False  # Disable sampling for deterministic output
            )
            
            # Decode response
            response = llava_processor.decode(
                generate_ids[0],
                skip_special_tokens=True,
                clean_up_tokenization_spaces=False
            )
            
            # Check for affirmative response
            return "yes" in response.lower()
            
        except Exception as e:
            print(f"LLaVA validation failed: {str(e)}")
            return True  # Default to keeping mask if validation fails

def segment(image: np.ndarray, sam_model: SamPredictor, boxes: torch.Tensor) -> List[np.ndarray]:
    """Enhanced segmentation with multi-mask output"""
    sam_model.set_image(image)
    H, W, _ = image.shape
    boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

    transformed_boxes = sam_model.transform.apply_boxes_torch(boxes_xyxy.to(DEVICE), image.shape[:2])
    
    # Get multiple masks per box
    masks, _, _ = sam_model.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=True,
    )
    
    # Convert to numpy and filter empty masks
    all_masks = []
    for box_masks in masks:
        for mask in box_masks:
            mask_np = mask.cpu().numpy()
            if mask_np.sum() > 100:  # Minimum pixel threshold
                all_masks.append(mask_np)
    
    return all_masks

def draw_mask(mask: np.ndarray, image: np.ndarray, random_color: bool = True) -> np.ndarray:
    """Draw mask on image with optional random color"""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.8])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    
    h, w = mask.shape
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask_image * 255).astype(np.uint8)).convert("RGBA")

    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

def segment_and_save_views():
    """Enhanced segmentation pipeline with LLM integration"""
    view_files = sorted([f for f in os.listdir(VIEWS_DIR) if f.endswith(('.png', '.jpg'))])
    
    if not view_files:
        print(f"No images found in {VIEWS_DIR}")
        return
    
    print(f"Found {len(view_files)} views to process")
    segmenter = PartSegmenter()
    
    for view_file in view_files:
        print(f"Processing {view_file}...")
        view_path = os.path.join(VIEWS_DIR, view_file)
        
        try:
            # Load and prepare image
            image_source, image = load_image(view_path)
            image_source_rgb = cv2.cvtColor(image_source, cv2.COLOR_BGR2RGB)
            
            # Refine prompt using LLaVA
            refined_prompt = segmenter.refine_prompt_with_llm(image_source_rgb, TEXT_PROMPT)
            print(f"Original prompt: {TEXT_PROMPT}, Refined prompt: {refined_prompt}")
            
            # Get boxes from GroundingDINO
            boxes, logits, _ = predict(
                model=groundingdino_model,
                image=image,
                caption=refined_prompt,
                box_threshold=BOX_THRESHOLD,
                text_threshold=TEXT_THRESHOLD,
                device=DEVICE
            )
            
            if len(boxes) == 0:
                print(f"No objects detected in {view_file}")
                continue
            
            # Get masks
            masks = segment(image_source_rgb, sam_predictor, boxes=boxes)
            
            # Validate masks with LLaVA
            valid_masks = []
            for mask in masks:
                if segmenter.validate_mask_with_llm(image_source_rgb, mask, refined_prompt):
                    valid_masks.append(mask)
                else:
                    print("LLaVA rejected a mask as not matching the prompt")
            
            if not valid_masks:
                print(f"No valid masks found for {view_file}")
                continue
            
            # Save results
            base_name = os.path.splitext(view_file)[0]
            
            # Save annotation with boxes
            annotated = annotate(
                image_source=image_source,
                boxes=boxes,
                logits=logits,
                phrases=[refined_prompt]*len(boxes)
            )
            
            # Save each valid mask
            for i, mask in enumerate(valid_masks):
                annotated_frame_with_mask = draw_mask(mask, annotated)
                cv2.imwrite(
                    os.path.join(OUTPUT_DIR, f"{base_name}_mask_{i}.png"),
                    cv2.cvtColor(annotated_frame_with_mask, cv2.COLOR_RGBA2BGR)
                )
            
            cv2.imwrite(os.path.join(OUTPUT_DIR, f"{base_name}_annotated.png"), annotated)
            print(f"Saved results for {view_file}")
            
        except Exception as e:
            print(f"Error processing {view_file}: {str(e)}")

if __name__ == "__main__":
    segment_and_save_views()
    print("Segmentation complete! Check the output directory for results.")

Loading models...
final text_encoder_type: bert-base-uncased
Loading LLaVA model...


Downloading shards:   0%|          | 0/3 [01:39<?, ?it/s]


KeyboardInterrupt: 

In [None]:
from GroundingDINO.groundingdino.util.inference import load_model as load_groundingdino, predict
from segment_anything.segment_anything import sam_model_registry, SamPredictor

# Load Grounding DINO
grounding_dino = load_groundingdino(
    "GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py",
    "groundingdino_swint_ogc.pth"
)

# Load SAM
sam = sam_model_registry["vit_h"](checkpoint="sam_vit_h_4b8939.pth")
sam_predictor = SamPredictor(sam)

# Load Open-source LLM (e.g., LLaMA-3 via HuggingFace)
llm_tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-3B-Instruct")
llm = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-3B-Instruct", device_map="auto")

In [None]:
def generate_part_prompts(object_description):
    prompt = f"""
    You are a vision assistant. List unambiguous text prompts to segment all parts of: {object_description}.
    Use spatial references (e.g., 'left ear', 'central loop of the knot'). Return as a Python list.
    Example output for 'elephant':
    ['elephant trunk', 'left front leg', 'right ear', 'tail']
    """
    inputs = llm_tokenizer(prompt, return_tensors="pt").to("cpu")
    outputs = llm.generate(**inputs, max_new_tokens=100)
    return eval(llm_tokenizer.decode(outputs[0], skip_special_tokens=True).split("```")[0])

In [None]:
def segment_parts(image_path, object_class):
    # Load image
    image = cv2.imread(image_path)
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    # Step 1: Detect whole object
    boxes, _, _ = predict(
        grounding_dino, image_rgb, caption=object_class, box_threshold=0.3
    )
    if len(boxes) == 0:
        return []

    # Step 2: Generate part prompts with LLM
    part_prompts = generate_part_prompts(object_class)

    # Step 3: Detect and segment parts
    sam_predictor.set_image(image_rgb)
    all_masks = []
    for prompt in part_prompts:
        part_boxes, _, _ = predict(
            grounding_dino, image_rgb, caption=prompt, box_threshold=0.2
        )
        for box in part_boxes:
            # Convert box to SAM input (normalized XYXY)
            box_xyxy = box / np.array([image.shape[1], image.shape[0]] * 2)
            masks, _, _ = sam_predictor.predict(
                box=box_xyxy[None, :],
                multimask_output=True  # Get 3 candidate masks
            )
            all_masks.append(masks[0])  # Select the best mask

    return all_masks

In [None]:
def clean_masks(masks, iou_threshold=0.5):
    # Non-Max Suppression (NMS) to remove duplicates
    from torchvision.ops import nms
    boxes = torch.tensor([m.sum(axis=(0, 1)) for m in masks])  # Dummy boxes
    scores = torch.ones(len(masks))  # Dummy scores
    keep = nms(boxes, scores, iou_threshold)
    return [masks[i] for i in keep]

In [None]:
# Example: Segment parts of an elephant
masks = segment_parts("objects/1a6f615e8b1b5ae4dbbc9440457e303e/rendered_data/0_softflat_gray.png", "chair")
masks = clean_masks(masks)

# Visualize
for mask in masks:
    cv2.imshow("Part", mask.astype(np.uint8) * 255)
    cv2.waitKey(0)