<a href="https://colab.research.google.com/github/geoaigroup/geoaigroup-website/blob/main/content/media/SAM_26May2023/SAM_GEOAI_Demo.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Object masks from prompts with SAM
This section was prepared by Ali Mayladan.

The Segment Anything Model (SAM) predicts object masks given prompts that indicate the desired object. The model first converts the image into an image embedding, which allows high-quality masks to be produced efficiently from a prompt.

The `SamPredictor` class provides a simple interface for prompting the model. The user first sets an image with the `set_image` method, which computes the necessary image embedding. Prompts are then passed to the `predict` method, which predicts masks from them efficiently. The model accepts point and box prompts, as well as masks from a previous prediction iteration.

## Environment Set-up

If running locally using Jupyter, first install `segment_anything` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything#installation) in the repository. If running from Google Colab, set `using_colab=True` below and run the cell. In Colab, be sure to select 'GPU' under 'Edit'->'Notebook Settings'->'Hardware accelerator'.

In [None]:
!wget https://github.com/geoaigroup/geoaigroup-website/raw/main/content/media/SAM_26May2023/data.zip
!unzip data.zip

In [None]:
using_colab = True

In [None]:
if using_colab:
    import torch
    import torchvision
    print("PyTorch version:", torch.__version__)
    print("Torchvision version:", torchvision.__version__)
    print("CUDA is available:", torch.cuda.is_available())
    import sys
    !{sys.executable} -m pip install opencv-python matplotlib
    !{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'

    # !mkdir images
    # !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/truck.jpg
    # !wget -P images https://raw.githubusercontent.com/facebookresearch/segment-anything/main/notebooks/images/groceries.jpg

    !wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

## Set-up

Necessary imports and helper functions for displaying points, boxes, and masks.

In [None]:
!pip install geopandas

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import geopandas as gpd
import os
import json
import glob
from tqdm import tqdm
from shapely.geometry import Point, Polygon
import random
from PIL import Image, ImageDraw



In [None]:
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on `ax` as a semi-transparent colour layer."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_mask_box(mask, ax, random_color=False):
    """Same overlay as `show_mask`; kept as a separate name for the box-prompt cells."""
    show_mask(mask, ax, random_color=random_color)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))



## Selecting objects with SAM

First, load the SAM model and predictor. Change the path below to point to the SAM checkpoint. Running on CUDA and using the default model are recommended for best results.

In [None]:
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

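# Prefer CUDA as recommended above, but fall back to CPU so the cell still runs.
device = "cuda" if torch.cuda.is_available() else "cpu"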

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

Process the image to produce an image embedding by calling `SamPredictor.set_image`. `SamPredictor` remembers this embedding and will use it for subsequent mask prediction.

To select an object (here, a building), choose a point on it. Points are input to the model in (x,y) format and come with labels 1 (foreground point) or 0 (background point). Multiple points can be input; here we use only one per object. The chosen points will be shown as stars on the image.
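
As a minimal sketch of the prompt format described above, the following builds the two arrays that a point prompt needs. The coordinates are hypothetical values chosen only for illustration.

```python
import numpy as np

# Hypothetical pixel coordinates, chosen only for illustration.
foreground = [(120.0, 85.0), (300.5, 210.0)]  # points on the object
background = [(40.0, 400.0)]                  # point that should be excluded

# Coordinates form an (N, 2) array in (x, y) order; labels form an (N,)
# array where 1 marks a foreground point and 0 a background point.
point_coords = np.array(foreground + background, dtype=np.float32)
point_labels = np.array([1] * len(foreground) + [0] * len(background))

print(point_coords.shape, point_labels.shape)  # (3, 2) (3,)
```

Arrays of this shape are what the notebook passes as `point_coords` and `point_labels` to `predictor.predict`. Note that (x, y) means (column, row), which is the reverse of NumPy's row-first indexing.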

In [None]:
def calculateIoU(gtMask, predMask):
    """Intersection over union (IoU) of two masks of the same shape."""
    # Masks summed over several prompts can hold values above 1 where
    # predictions overlap, so treat every non-zero pixel as foreground.
    gt = np.asarray(gtMask) > 0
    pred = np.asarray(predMask) > 0

    tp = np.sum(gt & pred)    # true positives
    fp = np.sum(~gt & pred)   # false positives
    fn = np.sum(gt & ~pred)   # false negatives

    union = tp + fp + fn
    if union == 0:
        return 0.0  # both masks are empty
    return float(tp / union)
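
To make the IoU formula concrete, here is a small worked example on 4x4 masks. It is a sketch with made-up masks, and it also illustrates one assumption worth stating: when per-object masks are summed into a tile mask, overlapping pixels hold values above 1, so the example binarises the sum before counting.

```python
import numpy as np

# Ground truth: a 2x2 block in the top-left corner (4 pixels).
gt = np.zeros((4, 4), dtype=int)
gt[0:2, 0:2] = 1

# Two hypothetical predicted masks that overlap in column 2.
mask_a = np.zeros((4, 4), dtype=int)
mask_a[0:2, 1:3] = 1
mask_b = np.zeros((4, 4), dtype=int)
mask_b[0:2, 2:4] = 1
accumulated = mask_a + mask_b  # column 2 now holds the value 2

# Binarise so that overlapping predictions still count as foreground.
pred = accumulated > 0
truth = gt > 0

tp = int(np.sum(truth & pred))   # predicted and present in ground truth
fp = int(np.sum(~truth & pred))  # predicted but absent from ground truth
fn = int(np.sum(truth & ~pred))  # present in ground truth but missed
iou = tp / (tp + fp + fn)
print(tp, fp, fn, iou)  # 2 4 2 0.25
```

Comparing pixel values against exactly 1 would ignore the overlapping column, which is why the masks are thresholded with `> 0` first.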

In [None]:
width = 512
height = 512

def convert_polygon_to_mask(geo):
    """Rasterise a series of shapely (Multi)Polygons into one (height, width) mask."""
    gtmask = np.zeros((height, width))
    for geom in geo:
        # A MultiPolygon is drawn part by part; joining all parts into a single
        # vertex list would connect separate buildings with spurious edges.
        parts = [geom] if geom.geom_type == "Polygon" else list(geom.geoms)
        for part in parts:
            img = Image.new('L', (width, height), 0)
            ImageDraw.Draw(img).polygon(list(part.exterior.coords), outline=1, fill=1)
            gtmask = gtmask + np.array(img)
    return gtmask


In [None]:
shapefile="data/shapefile"
images='data/images'
orig_shp="data/orig_shp"
output_dir="data/output_images"
score_dir="data/scores"

flag=0
width = 512
height = 512

for image in tqdm(os.listdir(images)):

    name=image.split('.')[0]

    image = cv2.imread(images+'/'+image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    for i,orig_g in zip(glob.glob(shapefile+'/'+name),glob.glob(orig_shp+'/'+name)):
        input_point=[]
        input_label=[]
        mask_tile=np.zeros((3,512,512))
        mask_tile0=np.zeros((512,512))
        mask_tile1=np.zeros((512,512))
        mask_tile2=np.zeros((512,512))
        score_tile_buil=[]
        score_tile=[]
        gt_m=np.zeros((512,512))
        try:
          f=gpd.read_file(i)
          geo=f['geometry']
          f1=gpd.read_file(orig_g)
          geo1=f1['geometry']
          object_id=f['OBJECTID']
          gt_m=convert_polygon_to_mask(geo1)

        except Exception as e :
          # Skip this tile if a shapefile is missing or cannot be read.
          print(e)
          continue
        count=0

        for row,oid in zip(geo,object_id):

            x=row.x
            y=row.y
            i=[x,y]


            input_point.append(i)
            input_label.append(1)
            i=np.array([i])
            lab=np.array([1])

            masks, scores, logits = predictor.predict(
            point_coords=i,
            point_labels=lab,
            multimask_output=True,
            )
            masks=np.array(masks)
            mask_tile=mask_tile+masks


            msk0=masks[0]
            mask_tile0=mask_tile0+msk0
            msk1=masks[1]
            mask_tile1=mask_tile1+msk1
            msk2=masks[2]
            mask_tile2=mask_tile2+msk2

        iou0=calculateIoU(gt_m,mask_tile0)
        iou1=calculateIoU(gt_m,mask_tile1)
        iou2=calculateIoU(gt_m,mask_tile2)


        scores=[iou0,iou1,iou2]
        input_point=np.array(input_point)
        input_label=np.array(input_label)

        for i,(mask,score) in enumerate(zip(mask_tile,scores)):
            plt.figure(figsize=(10,10))
            plt.imshow(image)
            show_mask(mask, plt.gca())
            show_points(input_point, input_label, plt.gca())
            plt.title(f"Mask {i+1}, Score: {score:.3f}", fontsize=18)
            plt.axis('off')
            plt.show()



## Batched boxes inputs

`SamPredictor` can take multiple input prompts for the same image using the `predict_torch` method. This method assumes the inputs are already torch tensors and have already been transformed to the model's input frame. Here the boxes are the bounding boxes of the polygons in each tile's shapefile; in other settings they could be the outputs of an object detector.
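
The phrase "transformed to the input frame" can be illustrated with a short NumPy sketch. It assumes the default SAM input size of 1024 pixels on the longest side; `rescale_boxes` is a hypothetical helper written for this example, and in the notebook the equivalent step is performed by `predictor.transform.apply_boxes_torch`.

```python
import numpy as np

def rescale_boxes(boxes_xyxy, original_size, target_length=1024):
    """Scale XYXY boxes from the original image to the resized input frame.

    Assumption: the image is resized so that its longest side equals
    target_length while the aspect ratio is preserved.
    """
    old_h, old_w = original_size
    scale = target_length / max(old_h, old_w)
    new_h = int(old_h * scale + 0.5)
    new_w = int(old_w * scale + 0.5)

    # View each box as two (x, y) corner points and scale each axis.
    corners = np.array(boxes_xyxy, dtype=np.float32).reshape(-1, 2, 2)
    corners[..., 0] *= new_w / old_w
    corners[..., 1] *= new_h / old_h
    return corners.reshape(-1, 4)

# Hypothetical boxes in a 512x512 tile, as (xmin, ymin, xmax, ymax).
tile_boxes = [[10, 20, 100, 200], [256, 256, 512, 512]]
print(rescale_boxes(tile_boxes, original_size=(512, 512)))
```

For a 512x512 tile every coordinate is doubled, so the first box becomes (20, 40, 200, 400). For non-square images the same scale is applied to both axes, up to rounding of the resized dimensions.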

In [None]:
shapefile="data/orig_shp"
images='data/images'
output_dir="data/output_images"
score_dir="data/scores"
gt="data/gt"

score_val={}
score_=[]
score_mean=[]
for image in tqdm(os.listdir(images)):
    name=image.split('.')[0]
    image = cv2.imread(images+'/'+image)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    os.makedirs(score_dir,exist_ok=True)
    os.makedirs(output_dir+"/"+f'{name}',exist_ok=True)
    predictor.set_image(image)
  ######
    for i in glob.glob(shapefile+'/'+name):
        mask_tile=np.zeros((512,512))

        tile_boxes=[]
        f=gpd.read_file(i)
        geo=f['geometry']

        for p in geo:
            inbox=[]
            poly=p
            xmin,ymin,xmax,ymax=poly.bounds

            inbox=[xmin,ymin,xmax,ymax]

            tile_boxes.append(inbox)


        input_boxes=torch.tensor(tile_boxes, device=predictor.device)

        transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
        masks, _, _ = predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

 #####calculating scores######
        gtmask=convert_polygon_to_mask(geo)
        #print(gtmask.shape)


        msk=masks.clone()
        msk=msk.int()
        msk=msk.cpu().numpy()


        for i in range(msk.shape[0]):

          batch=msk[i]

          for b in range(batch.shape[0]):
            mask_tile=mask_tile+batch[b]

        #print("mask_tile",mask_tile.shape)
        iou=calculateIoU(gtmask,mask_tile)
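        # Build a fresh dict for each tile; reusing one dict would make every
        # entry in score_ refer to the same (last) name and value.
        score_.append({"name": name, "val": iou})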
        score_mean.append(iou)
 ######



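        # Start a new figure and draw the tile first so the masks and boxes
        # are overlaid on the image rather than on empty axes.
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask_box(mask.cpu().numpy(),plt.gca(), random_color=True)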


        for box in input_boxes:
            show_box(box.cpu().numpy(),plt.gca())

        plt.axis('on')

        plt.savefig(f'{output_dir}/{name}/{name}_batch_box.png')
        plt.show()


with open(score_dir+'/scores.json', 'w') as ff:
        json.dump(score_, ff)
scores_mean=np.array(score_mean)
sa="average_score for all masks: "+str(np.mean(scores_mean))
with open(score_dir+'/all_box_scores.json', 'w') as ff:
        json.dump(sa, ff)

# Language Segment-Anything (LangSAM)

Language Segment-Anything ([LangSAM](https://github.com/luca-medeiros/lang-segment-anything)) is an open-source project that combines instance segmentation with text prompts to generate masks for specific objects in images. It is built on Meta's [Segment Anything](https://github.com/facebookresearch/segment-anything) model and the [GroundingDINO](https://github.com/IDEA-Research/GroundingDINO) detection model, and is intended as an easy-to-use tool for object detection and image segmentation.

This section was prepared by Hasan Moughnieh.

In [None]:
!wget https://github.com/geoaigroup/geoaigroup-website/raw/main/content/media/SAM_26May2023/LangSAM.zip
!unzip LangSAM.zip

In [None]:
import sys
!{sys.executable} -m pip install opencv-python matplotlib
!{sys.executable} -m pip install 'git+https://github.com/facebookresearch/segment-anything.git'
!pip install git+https://github.com/openai/CLIP.git

In [None]:
import os
HOME = os.getcwd()
print(HOME)

In [None]:
%cd {HOME}
!git clone https://github.com/IDEA-Research/GroundingDINO.git
%cd {HOME}/GroundingDINO
!pip install -q -e .
!pip install -q roboflow

In [None]:
#useful functions for displaying results

def load_ground_truth_masks(ground_truth_masks):
    """Load a ground-truth mask from an image file path as a greyscale array."""
    mask_image = Image.open(ground_truth_masks).convert('L')
    return np.array(mask_image)

def compute_accuracy(predicted_masks, ground_truth_masks):
    """Return the IoU between the union of the predicted masks and the ground truth."""
    combined_predicted_mask = np.any(predicted_masks.numpy(), axis=0)
    intersection = np.sum(np.logical_and(combined_predicted_mask, ground_truth_masks))
    union = np.sum(np.logical_or(combined_predicted_mask, ground_truth_masks))
    # Guard against division by zero when both masks are empty.
    return intersection / union if union > 0 else 0.0

def display_images_with_masks(image, masks):
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 6))
    ax1.axis('off')
    ax1.imshow(image)
    ax1.set_title('Original Image')

    ax2.axis('off')
    ax2.imshow(image)
    ax2.set_title('Image with Masks')

    num_masks = masks.shape[0]
    for i in range(num_masks):
        mask = masks[i].numpy()
        mask = np.ma.masked_where(mask < 0.5, mask)
        ax2.imshow(mask, alpha=0.5, cmap='jet')

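    # Note: `ground_truth_masks` is read from the notebook's global scope
    # and must be loaded before this function is called.
    accuracy = compute_accuracy(masks, ground_truth_masks)
    # Adjust spacing between subplots
    plt.tight_layout()

    # Display the figure in the notebook
    plt.show()
    print("Accuracy (IoU):",accuracy)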

In [None]:
import cv2
import numpy as np
import torch
from PIL import Image
from torchvision.utils import draw_bounding_boxes
from torchvision.utils import draw_segmentation_masks

MIN_AREA = 100


def load_image(image_path: str):
    return Image.open(image_path).convert("RGB")


def draw_image(image, masks, boxes, labels, alpha=0.4):
    image = torch.from_numpy(image).permute(2, 0, 1)
    image = draw_bounding_boxes(image, boxes, colors=['red'] * len(boxes), labels=labels, width=2)
    image = draw_segmentation_masks(image, masks=masks, colors=['cyan'] * len(boxes), alpha=alpha)
    return image.numpy().transpose(1, 2, 0)


def get_contours(mask):
    if len(mask.shape) > 2:
        mask = np.squeeze(mask, 0)  # drop the leading axis to get a 2-D array
    mask = mask.astype(np.uint8)  # values in {0, 1}
    mask *= 255  # values in {0, 255}
    # findContours takes three arguments: the binary mask, the contour retrieval
    # mode (external contours only), and the contour approximation method.
    # It returns (contours, hierarchy); the hierarchy is discarded as "_".
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    effContours = []
    for c in contours:
        area = cv2.contourArea(c)
        if area > MIN_AREA:
            effContours.append(c)
    return effContours


def contour_to_points(contour):
    pointsNum = len(contour)
    contour = contour.reshape(pointsNum, -1).astype(np.float32)
    points = [point.tolist() for point in contour]
    return points


def generate_labelme_json(binary_masks, labels, image_size, image_path=None):
    """Generate a LabelMe format JSON file from binary mask tensor.
    Args:
        binary_masks: Binary mask tensor of shape [N, H, W].
        labels: List of labels for each mask.
        image_size: Tuple of (height, width) for the image size.
        image_path: Path to the image file (optional).
    Returns:
        A dictionary representing the LabelMe JSON file.
    """
    num_masks = binary_masks.shape[0]
    binary_masks = binary_masks.numpy()

    json_dict = {
        "version": "4.5.6",
        "imageHeight": image_size[0],
        "imageWidth": image_size[1],
        "imagePath": image_path,
        "flags": {},
        "shapes": [],
        "imageData": None
    }

    # Loop through the masks and add them to the JSON dictionary
    for i in range(num_masks):
        mask = binary_masks[i]
        label = labels[i]
        effContours = get_contours(mask)

        for effContour in effContours:
            points = contour_to_points(effContour)
            shape_dict = {
                "label": label,
                "line_color": None,
                "fill_color": None,
                "points": points,
                "shape_type": "polygon"
            }

            json_dict["shapes"].append(shape_dict)

    return json_dict
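
For orientation, the sketch below shows the shape of the dictionary that `generate_labelme_json` assembles, using only the keys that appear in the function above. The label, image size, and polygon points are hypothetical values, and the round trip through `json` simply confirms that the structure serialises cleanly.

```python
import json

# Hypothetical values for illustration; a real call derives them from the masks.
example = {
    "version": "4.5.6",
    "imageHeight": 512,
    "imageWidth": 512,
    "imagePath": None,
    "flags": {},
    "shapes": [
        {
            "label": "house",
            "line_color": None,
            "fill_color": None,
            # One contour, stored as a list of [x, y] vertices.
            "points": [[10.0, 10.0], [60.0, 10.0], [60.0, 40.0], [10.0, 40.0]],
            "shape_type": "polygon",
        }
    ],
    "imageData": None,
}

text = json.dumps(example, indent=2)  # None is written as JSON null
restored = json.loads(text)
print(len(restored["shapes"]), restored["shapes"][0]["shape_type"])  # 1 polygon
```

Each contour that passes the `MIN_AREA` filter becomes one entry in `"shapes"`, so a single mask can contribute several polygons.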

In [None]:
import os
from urllib import request

import groundingdino.datasets.transforms as T
import numpy as np
import torch
import torch.nn as nn
from groundingdino.models import build_model
from groundingdino.util import box_ops
from groundingdino.util.inference import predict
from groundingdino.util.slconfig import SLConfig
from groundingdino.util.utils import clean_state_dict
from huggingface_hub import hf_hub_download
from segment_anything import sam_model_registry
from segment_anything import SamPredictor

SAM_MODELS = {
    "vit_h": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
    "vit_l": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
    "vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
    cache_config_file = hf_hub_download(repo_id=repo_id, filename=ckpt_config_filename)

    args = SLConfig.fromfile(cache_config_file)
    model = build_model(args)
    args.device = device

    cache_file = hf_hub_download(repo_id=repo_id, filename=filename)
    checkpoint = torch.load(cache_file, map_location='cpu')
    log = model.load_state_dict(clean_state_dict(checkpoint['model']), strict=False)
    print(f"Model loaded from {cache_file} \n => {log}")
    model.eval()
    return model


def transform_image(image) -> torch.Tensor:
    transform = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    image_transformed, _ = transform(image, None)
    return image_transformed


class LangSAM():

    def __init__(self, sam_type="vit_h"):
        self.sam_type = sam_type
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.build_groundingdino()
        self.build_sam(sam_type)

    def build_sam(self, sam_type):
        url = SAM_MODELS[sam_type]
        sam_checkpoint = os.path.basename(url)
        if not os.path.exists(sam_checkpoint):
            request.urlretrieve(url, sam_checkpoint)
        sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
        sam.to(device=self.device)
        self.sam = SamPredictor(sam)

    def build_groundingdino(self):
        ckpt_repo_id = "ShilongLiu/GroundingDINO"
        ckpt_filename = "groundingdino_swinb_cogcoor.pth"
        ckpt_config_filename = "GroundingDINO_SwinB.cfg.py"
        self.groundingdino = load_model_hf(ckpt_repo_id, ckpt_filename, ckpt_config_filename)

    def predict_dino(self, image_pil, text_prompt, box_threshold, text_threshold):
        image_trans = transform_image(image_pil)
        boxes, logits, phrases = predict(model=self.groundingdino,
                                         image=image_trans,
                                         caption=text_prompt,
                                         box_threshold=box_threshold,
                                         text_threshold=text_threshold)
        W, H = image_pil.size
        boxes = box_ops.box_cxcywh_to_xyxy(boxes) * torch.Tensor([W, H, W, H])

        return boxes, logits, phrases

    def predict_sam(self, image_pil, boxes):
        image_array = np.asarray(image_pil)
        self.sam.set_image(image_array)
        transformed_boxes = self.sam.transform.apply_boxes_torch(boxes, image_array.shape[:2])
        masks, _, _ = self.sam.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes.to(self.sam.device),
            multimask_output=False,
        )
        return masks.cpu()

    def predict(self, image_pil, text_prompt, box_threshold=0.28, text_threshold=0.25):
        boxes, logits, phrases = self.predict_dino(image_pil, text_prompt, box_threshold, text_threshold)
        masks = torch.tensor([])
        if len(boxes) > 0:
            masks = self.predict_sam(image_pil, boxes)
            masks = masks.squeeze(1)
        return masks, boxes, phrases, logits
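
The box conversion in `predict_dino` can be sketched in plain NumPy. The example assumes the detector returns boxes as normalised (centre x, centre y, width, height) values in [0, 1], which is what the multiplication by `[W, H, W, H]` in the class implies; `cxcywh_to_xyxy_pixels` is a hypothetical helper name used only here.

```python
import numpy as np

def cxcywh_to_xyxy_pixels(boxes, width, height):
    """Convert normalised (cx, cy, w, h) boxes to pixel (xmin, ymin, xmax, ymax)."""
    boxes = np.asarray(boxes, dtype=np.float32)
    cx, cy, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    # Move from centre/size to corner coordinates, still normalised.
    xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    # Scale x values by the image width and y values by the image height.
    return xyxy * np.array([width, height, width, height], dtype=np.float32)

# A hypothetical box centred in a 512x512 image.
print(cxcywh_to_xyxy_pixels([[0.5, 0.5, 0.25, 0.5]], width=512, height=512))
```

The centred box above maps to the pixel corners (192, 128) and (320, 384), which is the XYXY format that `predict_sam` then passes to `apply_boxes_torch`.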



`box_threshold`: This value is used for object detection in the image. A higher value makes the model more selective, identifying only the most confident object instances, leading to fewer overall detections. A lower value, conversely, makes the model more tolerant, leading to increased detections, including potentially less confident ones.

`text_threshold`: This value is used to associate the detected objects with the provided text prompt. A higher value requires a stronger association between the object and the text prompt, leading to more precise but potentially fewer associations. A lower value allows for looser associations, which could increase the number of associations but also introduce less precise matches.

The optimal threshold can vary depending on the quality and nature of your images, as well as the specificity of your text prompts.
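
The interaction between the two thresholds can be illustrated with a simplified sketch. It is not GroundingDINO's code: the per-token scores, the token list, and the `filter_detections` helper are all hypothetical, and the example only assumes that each candidate box carries one score per prompt token.

```python
import numpy as np

def filter_detections(token_scores, tokens, box_threshold, text_threshold):
    """Keep confident boxes, then attach the prompt tokens that match each one.

    Assumption: token_scores has shape (num_boxes, num_tokens), with values in [0, 1].
    """
    token_scores = np.asarray(token_scores, dtype=float)
    # A box survives if its best-matching token exceeds box_threshold.
    keep = token_scores.max(axis=1) > box_threshold
    results = []
    for idx in np.flatnonzero(keep):
        # The phrase is built from every token scoring above text_threshold.
        phrase = " ".join(
            tok for tok, score in zip(tokens, token_scores[idx])
            if score > text_threshold
        )
        results.append((int(idx), phrase))
    return results

tokens = ["red", "house"]
scores = [
    [0.10, 0.60],  # confident match for "house"
    [0.30, 0.35],  # weaker box, both tokens above 0.25
    [0.05, 0.20],  # best score below the box threshold
]
print(filter_detections(scores, tokens, box_threshold=0.28, text_threshold=0.25))
# [(0, 'house'), (1, 'red house')]
```

Raising `box_threshold` to 0.5 in this toy setup would leave only the first box, while raising `text_threshold` would shorten the phrases attached to the boxes that remain.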


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

index = '2_37'

image = f'/content/{index}_img.png'
ground_truth_masks = f'/content/{index}_gt.png'

model = LangSAM()

image_pil = Image.open(image).convert("RGB")
text_prompt = "house"
masks, boxes, phrases, logits = model.predict(image_pil, text_prompt)
ground_truth_masks = load_ground_truth_masks(ground_truth_masks)

#This function displays the original image , predicted masks , and accuracy compared to ground truth
display_images_with_masks(image_pil, masks)