# Reproduce sam1 processing
    
A snippet of the logic from the `generate_gsa_results.py` script. This notebook contains the minimal steps for generating the basic SAM (1.0) masks.

In [1]:
import os

# Paths and model settings. Each value can be overridden with an environment
# variable of the same name, so the notebook can run outside the original
# machine; the defaults reproduce the original setup.
ADT_DIR             = os.environ.get('ADT_DIR', '/home/ubuntu/cs-747-project/adt')
ADT_PROCESSED_DIR   = os.environ.get('ADT_PROCESSED_DIR', '/home/ubuntu/cs-747-project/adt_processed_new')
SCENE_NAME          = os.environ.get('SCENE_NAME', 'Apartment_release_golden_skeleton_seq100_10s_sample')
SAM_ENCODER_VERSION = os.environ.get('SAM_ENCODER_VERSION', 'vit_h')
SAM_CHECKPOINT_PATH = os.environ.get('SAM_CHECKPOINT_PATH', 'sam_vit_h_4b8939.pth')

In [2]:
from segment_anything import sam_model_registry
from segment_anything import SamAutomaticMaskGenerator

import torch

# Prefer the GPU, but fall back to the CPU so this cell still runs on a machine
# without CUDA (expect much slower inference there).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Build the SAM backbone for the configured encoder from its checkpoint and
# wrap it in the automatic mask generator (no box/point prompts needed).
sam = sam_model_registry[SAM_ENCODER_VERSION](checkpoint=SAM_CHECKPOINT_PATH)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

  state_dict = torch.load(f)


In [3]:
from typing import Any
import argparse
import json

# define classes (from generate_gsa_results.py)
class Dataset():
    """Base class for a per-scene frame dataset.

    On construction it checks that ``args.input_folder`` exists and creates two
    output folders directly beneath it:
    ``gsa_det_<class_set>_<sam_variant>`` (detections) and
    ``gsa_vis_<class_set>_<sam_variant>`` (visualisations).
    Subclasses must implement ``__getitem__`` and ``__len__``.
    """

    def __init__(self, args) -> None:
        root = args.input_folder
        self.input_folder = root
        assert root.exists(), f"Input folder {root} does not exist. "

        # Both output folders share the "<class_set>_<sam_variant>" suffix.
        tag = f"{args.class_set}_{args.sam_variant}"

        self.detection_save_folder = root / f"gsa_det_{tag}"
        self.detection_save_folder.mkdir(exist_ok=True)

        self.vis_save_folder = root / f"gsa_vis_{tag}"
        self.vis_save_folder.mkdir(exist_ok=True)

    def __getitem__(self, index: int) -> Any:
        """Return the item at ``index``; must be provided by a subclass."""
        raise NotImplementedError

    def __len__(self) -> int:
        """Return the number of items; must be provided by a subclass."""
        raise NotImplementedError

class AriaDataset(Dataset):
    """Dataset over the RGB frames listed in a scene's ``transforms.json``.

    Only entries of ``transforms.json["frames"]`` whose ``camera_name`` is
    ``'rgb'`` are kept, sorted by their ``image_path``. Indexing returns
    ``(input_folder / image_path, image_path_without_extension)``.
    """

    def __init__(self, args: argparse.Namespace) -> None:
        super().__init__(args)

        transform_path = self.input_folder / "transforms.json"
        with open(transform_path, encoding="utf-8") as json_file:
            frames = json.load(json_file)["frames"]

        # Only keep the RGB images; frames from any other camera are dropped.
        self.frames = sorted(
            (f for f in frames if f['camera_name'] == 'rgb'),
            key=lambda f: f["image_path"],
        )

    def __getitem__(self, index: int) -> Any:
        """Return ``(absolute_image_path, image_path_without_extension)`` for frame ``index``."""
        from os.path import splitext

        subpath = self.frames[index]["image_path"]
        image_path = self.input_folder / subpath
        # Strip the extension whatever its length (.png, .jpg, .jpeg, ...);
        # the result is used as the stem of the per-frame output files.
        image_filename = splitext(subpath)[0]

        return image_path, image_filename

    def __len__(self) -> int:
        """Number of RGB frames in the scene."""
        return len(self.frames)

In [4]:
import numpy as np

# The SAM based on automatic mask generation, without bbox prompting
def get_sam_segmentation_dense(
    variant: str, model: Any, image: np.ndarray
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    '''
    The SAM based on automatic mask generation, without bbox prompting.

    Args:
        variant: Which backend ``model`` is. Only ``"sam"`` is implemented;
            ``"fastsam"`` and any other value raise NotImplementedError.
        model: For ``"sam"``, an object whose ``generate(image)`` returns a list
            of dicts with ``"segmentation"``, ``"bbox"`` (x, y, w, h) and
            ``"predicted_iou"`` keys (e.g. SamAutomaticMaskGenerator).
        image: (H, W, 3), in RGB color space, in range [0, 255]

    Returns:
        mask: (N, H, W)
        xyxy: (N, 4), boxes as (x_min, y_min, x_max, y_max)
        conf: (N,), the predicted IoU of each mask
        When no masks are found the shapes are (0, H, W), (0, 4) and (0,), so
        callers can still index columns such as ``xyxy[:, 2]``.

    Raises:
        NotImplementedError: for any variant other than ``"sam"``.
    '''
    if variant == "sam":
        results = model.generate(image)
        if len(results) == 0:
            height, width = image.shape[:2]
            return (
                np.zeros((0, height, width), dtype=bool),
                np.zeros((0, 4)),
                np.zeros((0,)),
            )

        mask = np.array([r["segmentation"] for r in results])
        # Convert each box from xywh format to xyxy format without mutating
        # the generator's own result dicts.
        xyxy = np.array([
            [x, y, x + w, y + h]
            for x, y, w, h in (r["bbox"] for r in results)
        ])
        conf = np.array([r["predicted_iou"] for r in results])
        return mask, xyxy, conf

    if variant == "fastsam":
        # Not implemented yet. Raise before running the model so no inference
        # is wasted. The GSA repo calls FastSAM with these arguments:
        #   model(image, imgsz=1024, device="cuda", retina_masks=True,
        #         iou=0.9, conf=0.4, max_det=100)
        raise NotImplementedError("The 'fastsam' variant is not implemented yet")

    raise NotImplementedError(f"Unknown segmentation variant: {variant!r}")

In [5]:
import supervision as sv
from supervision.draw.color import Color, ColorPalette
import dataclasses

def vis_result_fast(
    image: np.ndarray, 
    detections: sv.Detections, 
    classes: list[str], 
    color: Color | ColorPalette = ColorPalette.DEFAULT, 
    instance_random_color: bool = False,
    draw_bbox: bool = True,
) -> tuple[np.ndarray, list[str]]:
    '''
    Annotate the image with the detection results.

    This is fast, but the annotation is drawn at the input image's resolution,
    so it can look blurry.

    Args:
        image: image to draw on; a copy is annotated, the input is not modified.
        detections: detections whose masks (and optionally boxes and labels) are drawn.
        classes: class names, indexed by ``detections.class_id``.
        color: color or palette used for masks and boxes.
        instance_random_color: if True, give every detection its own color by
            annotating a shallow copy whose ``class_id`` is ``0..N-1``; the
            returned labels still use the original class ids.
        draw_bbox: if True, also draw boxes and text labels.

    Returns:
        ``(annotated_image, labels)``, where ``labels[i]`` is
        ``"<class name> <confidence:.2f>"`` for detection ``i``.
    '''
    box_annotator = sv.BoxAnnotator(color=color)
    label_annotator = sv.LabelAnnotator(
        text_scale=0.3,
        text_thickness=1,
        text_padding=2,
    )
    mask_annotator = sv.MaskAnnotator(color=color, opacity=0.35)

    # Read the columns directly rather than unpacking the tuples yielded by
    # iterating sv.Detections: that tuple's length varies between supervision
    # releases.
    labels = [
        f"{classes[class_id]} {confidence:0.2f}"
        for confidence, class_id in zip(detections.confidence, detections.class_id)
    ]

    if instance_random_color:
        # Shallow copy so the caller's detections keep their class ids; the
        # per-instance ids below only drive the annotator colors.
        detections = dataclasses.replace(detections)
        detections.class_id = np.arange(len(detections))

    annotated_image = mask_annotator.annotate(scene=image.copy(), detections=detections)

    if draw_bbox:
        annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
        annotated_image = label_annotator.annotate(
            scene=annotated_image, detections=detections, labels=labels)
    return annotated_image, labels

In [6]:
from pathlib import Path

# Processed scene folder: <ADT_PROCESSED_DIR>/<SCENE_NAME>
input_folder = Path(ADT_PROCESSED_DIR, SCENE_NAME)
print(f"Input folder: {input_folder}")

Input folder: /home/ubuntu/cs-747-project/adt_processed_new/Apartment_release_golden_skeleton_seq100_10s_sample


In [7]:
classes = ['item']

# Initialize the dataset. Only scene folders containing global_points.csv.gz
# (treated as Aria data) are supported; for those, frames get rotated back
# during loading (rotate_back).
rotate_back = False
if not (input_folder / "global_points.csv.gz").exists():
    # Not implemented yet
    raise NotImplementedError("Only Aria data set is supported for now!")

print("Found global_points.csv.gz file, assuming Aria data set!")
dataset = AriaDataset(args=argparse.Namespace(
    input_folder=input_folder,
    class_set='none',
    sam_variant='sam',
))
rotate_back = True

Found global_points.csv.gz file, assuming Aria data set!


In [8]:
from tqdm import trange
import numpy as np
import imageio
import supervision as sv
import pickle
import gzip
from PIL import Image

stride = 1
# TODO experiment with longer side... might be causing aliasing issues?
max_longer_side = 512
# Boxes covering at least this fraction of the image are dropped: they tend to
# capture the entire image rather than a single object.
max_box_area_ratio = 0.6

annotated_frames = []
global_classes = []


def _load_rgb_frame(image_path, max_longer_side, rotate_back):
    """Load an image as an (H, W, 3) RGB array whose longer side is at most max_longer_side."""
    # The context manager closes the file handle Image.open keeps.
    with Image.open(image_path) as image_pil:
        longer_side = min(max(image_pil.size), max_longer_side)
        resize_scale = float(longer_side) / max(image_pil.size)
        image_pil = image_pil.resize(
            (int(image_pil.size[0] * resize_scale), int(image_pil.size[1] * resize_scale))
        )
    # SAM needs 3 channels: drop alpha and expand grayscale/palette images too.
    if image_pil.mode != "RGB":
        image_pil = image_pil.convert("RGB")
    if rotate_back:
        image_pil = image_pil.rotate(-90, expand=True)
    return np.array(image_pil)


def _drop_large_boxes(detections, image_shape, max_area_ratio):
    """Return only the detections whose box covers less than max_area_ratio of the image."""
    xyxy = detections.xyxy
    areas = (xyxy[:, 2] - xyxy[:, 0]) * (xyxy[:, 3] - xyxy[:, 1])
    area_ratios = areas / (image_shape[0] * image_shape[1])
    # Boolean indexing filters every per-detection field consistently.
    return detections[area_ratios < max_area_ratio]


for idx in trange(0, len(dataset), stride):
    image_path, image_filename = dataset[idx]
    image_rgb = _load_rgb_frame(image_path, max_longer_side, rotate_back)

    # add classes to global classes (generate_gsa_results.py can optionally
    # switch `classes` to all classes seen so far; that option is not used here)
    for c in classes:
        if c not in global_classes:
            global_classes.append(c)

    ### Detection and segmentation ###
    # Directly use SAM in dense sampling mode to get segmentation
    mask, xyxy, conf = get_sam_segmentation_dense('sam', mask_generator, image_rgb)
    # Normalize shapes so a frame with zero masks still yields (0, 4) boxes and
    # (0, H, W) masks, which the filtering below can index.
    xyxy = np.asarray(xyxy).reshape(-1, 4)
    mask = np.asarray(mask, dtype=bool).reshape(-1, *image_rgb.shape[:2])

    detections = sv.Detections(
        xyxy=xyxy,
        confidence=conf,
        class_id=np.zeros_like(conf).astype(int),
        mask=mask,
    )

    # Remove the bounding boxes that are too large (they tend to capture the entire image)
    detections = _drop_large_boxes(detections, image_rgb.shape, max_box_area_ratio)

    ### CLIP features ###
    # generate_gsa_results.py can compute CLIP crops/features here; this
    # notebook skips that step and stores None placeholders.
    image_crops, image_feats, text_feats = None, None, None

    ### Save the detection results ###
    detection_save_path = dataset.detection_save_folder / f"{image_filename}.pkl.gz"
    detection_save_path.parent.mkdir(exist_ok=True, parents=True)
    det_results = {
        "image_path": image_path,
        "xyxy": detections.xyxy,
        "confidence": detections.confidence,
        "class_id": detections.class_id,
        "mask": detections.mask,
        "classes": classes,
        "image_crops": image_crops,
        "image_feats": image_feats,
        "text_feats": text_feats,
    }
    with gzip.open(str(detection_save_path), 'wb') as f:
        pickle.dump(det_results, f)

    ### Visualize results and save ###
    # In generate_gsa_results.py both flags derive from class_set == "none",
    # which is the class set used here.
    annotated_image, labels = vis_result_fast(
        image_rgb, detections, classes,
        instance_random_color=True,
        draw_bbox=False,
    )

    vis_save_path = dataset.vis_save_folder / f"{image_filename}.png"
    vis_save_path.parent.mkdir(exist_ok=True, parents=True)
    imageio.imwrite(vis_save_path, annotated_image)
    annotated_frames.append(annotated_image)

# Save the annotated frames as a video named after the visualization folder
# (gsa_vis_<class_set>_<sam_variant>.mp4).
if annotated_frames:
    annotated_frames = np.stack(annotated_frames, axis=0)
    imageio.mimwrite(
        input_folder / f"{dataset.vis_save_folder.name}.mp4",
        annotated_frames,
        fps=20,
    )
else:
    print("No frames were processed; skipping video export.")

  0%|          | 0/267 [00:00<?, ?it/s]

  3%|▎         | 9/267 [00:31<14:49,  3.45s/it]


KeyboardInterrupt: 

## sam2

Next, we adapt the workflow above to SAM 2 (work in progress).

In [66]:
#TODO