# Automatically generating object masks with MobileSAM

## Environment Set-up

In [15]:
%pip install -q git+https://github.com/openai/CLIP.git


## Set-up

In [None]:
def show_anns(anns):
    """Draw every annotation mask as a translucent overlay on the current axes.

    Args:
        anns: Sequence of annotation dicts. Each dict is read for its
            ``'area'`` (used for ordering) and ``'segmentation'`` (used both
            for the overlay size and as an index into the overlay).

    Returns:
        None. Nothing is drawn when ``anns`` is empty.
    """
    if len(anns) == 0:
        return

    # Largest masks are painted first, so masks painted later overwrite them
    # wherever they overlap.
    by_area_desc = sorted(anns, key=lambda ann: ann['area'], reverse=True)

    axes = plt.gca()
    axes.set_autoscale_on(False)

    # RGBA canvas sized from the first mask; alpha starts at 0 so pixels not
    # covered by any mask stay fully transparent.
    seg_shape = by_area_desc[0]['segmentation'].shape
    overlay = np.ones((seg_shape[0], seg_shape[1], 4))
    overlay[:, :, 3] = 0

    for ann in by_area_desc:
        # Random RGB colour with a fixed 0.35 alpha for this mask.
        rgba = np.concatenate([np.random.random(3), [0.35]])
        overlay[ann['segmentation']] = rgba

    axes.imshow(overlay)


In [None]:
import sys

# torch is needed below for the CUDA availability check; import it here so
# this cell does not depend on another cell having imported it first.
import torch

# Make the local MobileSAM checkout importable before importing from it.
sys.path.append("MobileSAM")
from mobile_sam import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Checkpoint path and the registry key used to build the matching model.
sam_checkpoint = "MobileSAM/weights/mobile_sam.pt"
model_type = "vit_t"

# Use the GPU when one is available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# Kept as the cell's final expression so the notebook echoes the model repr.
sam.eval()


Sam(
  (image_encoder): TinyViT(
    (patch_embed): PatchEmbed(
      (seq): Sequential(
        (0): Conv2d_BN(
          (c): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        (1): GELU(approximate='none')
        (2): Conv2d_BN(
          (c): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
          (bn): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
      )
    )
    (layers): ModuleList(
      (0): ConvLayer(
        (blocks): ModuleList(
          (0-1): 2 x MBConv(
            (conv1): Conv2d_BN(
              (c): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
              (bn): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            )
            (act1): GELU(approximate='none')
            (conv2): Conv2d_BN(
 

In [None]:
import numpy as np
import matplotlib.pyplot as plt


In [None]:
import clip
from PIL import Image

def generate_text_embeddings(classnames, templates, model):
    """Encode every class name, formatted into each template, with ``model``.

    Args:
        classnames: Iterable of class-name strings.
        templates: Iterable of format strings; each is filled with a class
            name via ``str.format``.
        model: Object exposing ``encode_text`` (called on the tokenised
            prompts of one class at a time).

    Returns:
        The per-class ``encode_text`` outputs stacked along ``dim=1`` and
        moved to the module-level ``device``.
    """
    per_class = []
    with torch.no_grad():
        for name in classnames:
            # One prompt per template for the current class.
            prompts = [tpl.format(name) for tpl in templates]
            tokens = clip.tokenize(prompts).to(device)
            per_class.append(model.encode_text(tokens))
        # Classes become the second axis of the stacked result.
        stacked = torch.stack(per_class, dim=1).to(device)
    return stacked

def create_clip(classes=city_classes):
    """Load the CLIP ViT-B/16 model and pre-compute text features for ``classes``.

    Args:
        classes: Class names to embed; defaults to ``city_classes``.

    Returns:
        A ``(model, preprocess, text_features)`` tuple: the two values
        returned by ``clip.load`` followed by the output of
        ``generate_text_embeddings`` for ``classes``.
    """
    # Single prompt template; 'a rendering of a weird {}.' was an earlier
    # alternative left in the original as a comment.
    prompt_templates = ['a clean origami {}.']

    model, image_transform = clip.load("ViT-B/16", device=device)
    model.eval()

    embeddings = generate_text_embeddings(classes, prompt_templates, model)
    return model, image_transform, embeddings


In [None]:
from utils import visualize

def predict(image_path, mask_generator, clip_model, clip_preprocess, text_features):
    """Segment an image and label each mask with its best-matching CLIP class.

    Every mask produced by ``mask_generator`` is cut out of the image (pixels
    outside the mask are zeroed, then the result is cropped to the mask's
    bounding box), encoded with ``clip_model`` and compared against
    ``text_features``. A mask is kept only when its top-scoring class name
    (looked up in the module-level ``city_classes``) is in the allow-list
    defined below.

    Args:
        image_path: Path of the image to open; it is converted to RGB.
        mask_generator: Object whose ``generate(image)`` returns an iterable
            of dicts with ``"segmentation"`` and ``"bbox"`` entries.
        clip_model: Model exposing ``eval()`` and ``encode_image``.
        clip_preprocess: Callable turning a PIL image into a tensor that
            ``clip_model.encode_image`` accepts after ``unsqueeze(0)``.
        text_features: Text embeddings as returned by ``create_clip``. The
            tensor is not modified; a normalised copy is used internally.

    Returns:
        Tuple ``(image, boxes, masks, class_ids, scores)``:
        the loaded image array, the converted boxes of the kept masks, the
        kept masks stacked along the last axis, the kept class indices, and
        the softmax score of each kept mask's top class. When no mask is
        kept, ``masks`` is an empty array with a trailing axis of length 0
        and the other arrays are empty.
    """
    # Class names worth keeping. Deliberately excluded from the allow-list:
    # rail track, on rails, caravan, trailer, building, wall, fence,
    # guard rail, pole, pole group, traffic sign, traffic light, dynamic,
    # static.
    kept_classes = frozenset((
        'road', 'sidewalk', 'parking',
        'person', 'rider',
        'car', 'truck', 'bus',
        'motorcycle', 'bicycle',
        'bridge', 'tunnel',
        'vegetation', 'terrain',
        'sky',
        'ground',
    ))

    # Close the file handle promptly and force three channels so palette,
    # grayscale or RGBA inputs do not break the downstream masking/cropping.
    with Image.open(image_path) as image_file:
        image = np.array(image_file.convert("RGB"))

    torch.cuda.empty_cache()  # Empty GPU memory
    outputs = mask_generator.generate(image)

    clip_model.eval()

    # Loop-invariant: normalise and squeeze the text features exactly once,
    # out of place, so the caller's tensor is left untouched and the squeeze
    # is not re-applied on every iteration.
    with torch.no_grad():
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)
        text_features = text_features.squeeze(0).float()

    boxes = []
    masks = []
    class_ids = []
    scores = []

    for output in outputs:
        mask = output["segmentation"]
        ind = np.where(mask > 0)
        if ind[0].size == 0:
            # Nothing to crop for an all-empty mask.
            continue

        masked_image = image.copy()
        masked_image[mask == 0] = 0
        # Tight bounding box of the mask, computed with vectorised reductions.
        y1, y2 = ind[0].min(), ind[0].max()
        x1, x2 = ind[1].min(), ind[1].max()
        masked_image = Image.fromarray(masked_image[y1:y2 + 1, x1:x2 + 1])

        with torch.no_grad():
            clip_input = clip_preprocess(masked_image).unsqueeze(0).to(device)
            image_features = clip_model.encode_image(clip_input)

            # Pick the single most similar label for the crop.
            image_features /= image_features.norm(dim=-1, keepdim=True)
            similarity = (100.0 * image_features.float() @ text_features.T).softmax(dim=-1)
            score, index = similarity[0].topk(1)

        del masked_image, clip_input  # Release image variables

        class_id = index.item()
        if city_classes[class_id] in kept_classes:
            boxes.append(convert_xywh_yxyx(output["bbox"]))
            masks.append(mask)
            scores.append(score.item())
            class_ids.append(class_id)

    boxes = np.array(boxes)
    if masks:
        masks = np.stack(masks, axis=-1)
    else:
        # np.stack rejects an empty list; return an empty stack instead.
        # bool is a placeholder dtype here since the array holds no values.
        masks = np.zeros(image.shape[:2] + (0,), dtype=bool)
    class_ids = np.array(class_ids)
    scores = np.array(scores)

    torch.cuda.empty_cache()  # Empty GPU memory

    return image, boxes, masks, class_ids, scores


In [None]:
%load_ext autoreload
%autoreload 2
