In [1]:
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

In [2]:
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# SAM backbone variant and the checkpoint file that matches it; keep the two in sync.
SAM_MODEL_TYPE = "vit_h"
SAM_CHECKPOINT = "data/sam_vit_h_4b8939.pth"

sam = sam_model_registry[SAM_MODEL_TYPE](checkpoint=SAM_CHECKPOINT).to(DEVICE)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=16,           # start modest; increase if recall is low
    pred_iou_thresh=0.88,         # SAM's own confidence (tune later)
    stability_score_thresh=0.92,
    crop_n_layers=0,              # 0 to keep it fast at first
    output_mode="binary_mask",
)

In [3]:
import numpy as np
from typing import List

from utils import get_image

In [4]:
def generate_masks(img_id: str) -> List[dict]:
    """Run the global SAM ``mask_generator`` on a single image.

    Args:
        img_id: Identifier passed to ``get_image`` to load the image.

    Returns:
        The list of mask records returned by ``mask_generator.generate``. Each
        record is a dict; downstream cells read its ``"segmentation"`` entry.
    """
    img = get_image(img_id)
    return mask_generator.generate(img)

In [5]:
from utils import split_mask_into_blobs, filter_mask

In [6]:
def _blob_bbox_xywh(blob) -> List[int]:
    """Tight [x, y, w, h] box around the nonzero pixels of a 2-D mask, or [0, 0, 0, 0] if empty.

    Uses the same XYWH convention as SAM's own ``bbox`` field, where ``w``/``h`` are
    (max index - min index) rather than pixel counts.
    """
    rows, cols = np.nonzero(blob)  # assumes blob is a 2-D (H, W) array — TODO confirm
    if cols.size == 0:
        return [0, 0, 0, 0]
    x0, y0 = int(cols.min()), int(rows.min())
    return [x0, y0, int(cols.max()) - x0, int(rows.max()) - y0]


def filter_masks(masks) -> List[dict]:
    """Split each mask record into connected blobs and keep the blobs ``filter_mask`` accepts.

    Args:
        masks: Iterable of mask records (dicts) with a ``"segmentation"`` entry and,
            optionally, a ``"score"`` or ``"predicted_iou"`` confidence.

    Returns:
        One record per kept blob, with keys ``"segmentation"`` (the blob itself),
        ``"score"`` (inherited from the parent mask), ``"bbox"`` (XYWH box of the
        blob) and ``"area"`` (nonzero pixel count), sorted by ascending area.
    """
    kept = []
    for m in masks:
        # Prefer an explicit "score"; SamAutomaticMaskGenerator records call this
        # field "predicted_iou". Whether the utils NMS step adds "score" is not
        # visible here, so fall back rather than silently storing None.
        score = m.get("score", m.get("predicted_iou"))
        for blob in split_mask_into_blobs(m["segmentation"]):
            # filter_mask returns truthy for blobs that should be discarded.
            if filter_mask(blob):
                continue
            kept.append({
                "segmentation": blob,
                "score": score,
                # Recompute from the blob: the parent's bbox spans all of its blobs.
                "bbox": _blob_bbox_xywh(blob),
                "area": int(np.count_nonzero(blob)),
            })
    # sorted() is stable, so filtering before sorting yields the same order as before.
    return sorted(kept, key=lambda rec: rec["area"])

In [7]:
from typing import Dict, Any
from utils import use_cache, non_max_mask_suppression, enforce_no_overlap

In [8]:
def predict_masks(img_id: str) -> List[Dict[str, Any]]:
    """Run the full prediction pipeline for one image.

    Stages, in order: SAM generation wrapped by ``use_cache`` -> non-max mask
    suppression -> overlap removal -> blob splitting and filtering.

    Args:
        img_id: Image identifier forwarded to ``generate_masks``.

    Returns:
        The blob records produced by ``filter_masks``.
    """
    cached_generate = use_cache(generate_masks)
    suppressed = non_max_mask_suppression(cached_generate(img_id))
    disjoint = enforce_no_overlap(suppressed)
    return filter_masks(disjoint)

In [9]:
from utils import get_2d_mask_from_rle, iou_validation, load_masks

In [10]:
def evaluate_image(img_id: str) -> np.ndarray:
    """Score the predicted masks for one image against its ground-truth masks.

    Args:
        img_id: Value of the ``ImageId`` column in the table returned by
            ``load_masks()`` (e.g. ``"000155de5.jpg"``).

    Returns:
        The result of ``iou_validation(ground_truth_masks, predicted_masks)``;
        the recorded notebook output shows a float32 NumPy array.
    """
    masks_df = load_masks()

    # Skip rows with no RLE string (NaN) so they are not decoded into spurious
    # ground-truth masks. NOTE(review): presumably such rows mark images with no
    # objects — confirm how load_masks() represents them.
    rles = masks_df.loc[masks_df.ImageId == img_id, "EncodedPixels"].dropna().values
    ground_truth_masks = [get_2d_mask_from_rle(rle) for rle in rles]

    predicted_records = use_cache(predict_masks)(img_id)
    predicted_masks = [record["segmentation"] for record in predicted_records]

    return iou_validation(ground_truth_masks, predicted_masks)

# NOTE(review): the saved output for this call is an empty array of shape (1, 0).
# If iou_validation puts ground truth on rows and predictions on columns, no
# predicted mask survived predict_masks for this image — TODO check the thresholds.
evaluate_image(img_id="000155de5.jpg")

array([], shape=(1, 0), dtype=float32)