# In [None]:
import torch
import torchvision
import numpy as np
import cv2
from torchvision.ops import roi_align
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from segment_anything import sam_model_registry, SamPredictor

# ---------------------------
# Step 1: Load SAM and get masks
# ---------------------------
def load_sam(model_type="vit_h", checkpoint_path="sam_vit_h.pth"):
    """Build a SAM model from a checkpoint and wrap it in a ``SamPredictor``.

    Args:
        model_type: Key into ``sam_model_registry`` selecting the model
            builder to use.
        checkpoint_path: Value handed to that builder as ``checkpoint``.

    Returns:
        A ``SamPredictor`` constructed around the freshly built model.
    """
    build_model = sam_model_registry[model_type]
    return SamPredictor(build_model(checkpoint=checkpoint_path))

def get_sam_boxes(predictor, image, min_area=500):
    """Run SAM on ``image`` and turn each returned mask into a bounding box.

    The predictor is called without any point or box prompt, so the boxes
    come from whatever masks it proposes for the image as a whole.

    Args:
        predictor: Object exposing ``set_image(image)`` and
            ``predict(multimask_output=True)``. The first element of the
            ``predict`` result is iterated as a sequence of masks whose
            first axis is treated as y and second axis as x.
        image: Image passed unchanged to ``predictor.set_image``.
        min_area: A box is kept only when ``(x2 - x1) * (y2 - y1)`` is
            strictly greater than this value. Defaults to 500.

    Returns:
        A list of ``[x1, y1, x2, y2]`` boxes made of plain Python ints.
        The maxima are the last column/row that still contains a mask
        pixel (inclusive). The list is empty when no mask passes the filter.
    """
    predictor.set_image(image)
    masks, _, _ = predictor.predict(multimask_output=True)

    boxes = []
    for mask in masks:
        hits = np.nonzero(mask)
        rows, cols = hits[0], hits[1]
        if rows.size == 0:
            # Completely empty mask: there is nothing to bound.
            continue
        # int() strips the NumPy scalar type so the result is plain Python
        # data (serialisable, safe to compare and hash as ordinary ints).
        x1, y1 = int(cols.min()), int(rows.min())
        x2, y2 = int(cols.max()), int(rows.max())
        if (x2 - x1) * (y2 - y1) > min_area:
            boxes.append([x1, y1, x2, y2])
    return boxes

# ---------------------------
# Step 2: Load Backbone
# ---------------------------
def get_backbone_feature_map(image_tensor, device="cuda"):
    """Run ``image_tensor`` through an ImageNet-pretrained ResNet-50 trunk.

    The network is rebuilt on every call, with its last two child modules
    (average pool and fully connected layer) removed so the output is a
    spatial feature map rather than class logits.

    NOTE(review): pretrained ResNet weights are normally fed
    mean/std-normalised input; nothing in this function normalises
    ``image_tensor`` - confirm whether the caller is expected to.

    Args:
        image_tensor: Batched image tensor accepted by ResNet-50; it is
            moved to ``device`` before the forward pass.
        device: Device on which the backbone runs. Defaults to ``"cuda"``.

    Returns:
        The backbone output, located on ``device`` and computed under
        ``torch.no_grad()``.
    """
    models = torchvision.models
    weights_enum = getattr(models, "ResNet50_Weights", None)
    if weights_enum is not None:
        # ``pretrained=True`` is deprecated in newer torchvision releases;
        # IMAGENET1K_V1 is the checkpoint that flag resolves to, so the
        # loaded weights are unchanged.
        resnet = models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        # Older torchvision without the multi-weight API.
        resnet = models.resnet50(pretrained=True)

    # Drop the trailing avgpool + fc layers to keep the spatial layout.
    trunk = torch.nn.Sequential(*list(resnet.children())[:-2])
    trunk = trunk.to(device)
    trunk.eval()
    with torch.no_grad():
        return trunk(image_tensor.to(device))


# ---------------------------
# Step 3: Apply ROI Align
# ---------------------------
def apply_roi_align(feature_map, boxes, image_size, output_size=(7, 7)):
    """Pool a fixed-size feature patch for every image-space box.

    Boxes are rescaled from image pixels to feature-map cells (separately
    for x and y) and handed to ``roi_align`` with its default spatial scale.

    Args:
        feature_map: Tensor indexed as ``(batch, channels, height, width)``.
            Every box is assigned to batch element 0.
        boxes: Sequence of ``[x1, y1, x2, y2]`` boxes in image pixel
            coordinates. May be empty.
        image_size: ``(height, width)`` of the image the boxes refer to.
        output_size: Pooled size forwarded to ``roi_align``.

    Returns:
        The ``roi_align`` result for the rescaled boxes, on the same device
        as ``feature_map``.
    """
    # reshape(-1, 4) keeps the tensor two-dimensional even when ``boxes`` is
    # empty; otherwise the concatenation below would not yield the
    # (K, 5) layout roi_align requires.
    coords = torch.tensor(boxes, dtype=torch.float32).reshape(-1, 4)
    batch_column = torch.zeros((coords.shape[0], 1), dtype=coords.dtype)
    rois = torch.cat([batch_column, coords], dim=1)  # (idx, x1, y1, x2, y2)

    # Map image pixels onto feature-map cells.
    scale_y = feature_map.shape[2] / image_size[0]
    scale_x = feature_map.shape[3] / image_size[1]
    rois[:, 1::2] *= scale_x  # columns 1 and 3: x1, x2
    rois[:, 2::2] *= scale_y  # columns 2 and 4: y1, y2

    # Scale in float32 first, then match the feature map's device and dtype
    # so a reduced-precision feature map does not cause a dtype mismatch.
    rois = rois.to(device=feature_map.device, dtype=feature_map.dtype)

    return roi_align(feature_map, rois, output_size=output_size)


# ---------------------------
# Step 4: Classification Head
# ---------------------------
class DetectionHead(torch.nn.Module):
    """Prediction head applied to pooled RoI features.

    Thin wrapper around ``FastRCNNPredictor``: each RoI's features are
    flattened to a single vector before being passed to the predictor.
    """

    def __init__(self, in_channels, num_classes):
        """Create the wrapped predictor.

        Args:
            in_channels: Length of each flattened RoI feature vector.
            num_classes: Forwarded to ``FastRCNNPredictor``.
        """
        super(DetectionHead, self).__init__()
        self.head = FastRCNNPredictor(in_channels, num_classes)

    def forward(self, roi_feats):
        """Flatten all but the first dimension and return the predictor output."""
        flat_feats = roi_feats.flatten(start_dim=1)
        return self.head(flat_feats)

# ---------------------------
# Full Pipeline
# ---------------------------
def run_pipeline(image_path, sam_checkpoint, device="cuda"):
    """Run the SAM -> backbone -> RoI Align -> head pipeline on one image.

    Steps: read the image, get candidate boxes from SAM, extract a backbone
    feature map, pool one feature patch per box, score the patches with a
    newly constructed ``DetectionHead``, write the image to ``output.jpg``
    and print the head output. The head is created here and not loaded from
    a checkpoint, so its scores are placeholders until a trained head is
    substituted.

    Args:
        image_path: Path of the image to read with ``cv2.imread``.
        sam_checkpoint: Checkpoint path forwarded to ``load_sam``.
        device: Device used for the image tensor, the backbone and the head.
            Defaults to ``"cuda"``.

    Returns:
        None. Returns early (after printing a message) when SAM yields no
        usable boxes.

    Raises:
        OSError: If ``cv2.imread`` cannot load ``image_path``.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising;
        # fail here with a clear message instead of deep inside cvtColor.
        raise OSError(f"cv2.imread could not load image: {image_path!r}")
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image_tensor = torchvision.transforms.functional.to_tensor(image_rgb).unsqueeze(0).to(device)

    predictor = load_sam(checkpoint_path=sam_checkpoint)
    boxes = get_sam_boxes(predictor, image_rgb)

    if not boxes:
        print("No valid boxes found by SAM.")
        return

    # Forward the requested device so the backbone does not silently fall
    # back to its own default.
    feature_map = get_backbone_feature_map(image_tensor, device=device)
    roi_feats = apply_roi_align(feature_map, boxes, image_rgb.shape[:2])

    # Use dummy head for now; in real training, replace with trained head.
    # The flattened size is derived from the pooled tensor itself so it stays
    # correct if the RoI output size changes.
    flat_features = roi_feats.shape[1] * roi_feats.shape[2] * roi_feats.shape[3]
    detection_head = DetectionHead(in_channels=flat_features, num_classes=91).to(device)
    outputs = detection_head(roi_feats)

    # cv2.imwrite expects BGR channel order, so save the array as loaded;
    # writing the RGB copy would swap red and blue in the file.
    cv2.imwrite("output.jpg", image)

    print("Predicted class scores:", outputs)

# Example usage: run the full pipeline once on "rail.jpg".
run_pipeline(image_path="rail.jpg", sam_checkpoint="sam_vit_h.pth")
