# Segment images by using text prompts
The purpose of this notebook is to illustrate different methods for segmenting images using text prompts. Text prompts are used to semantically identify the parts of the image to be segmented.

## Segment Anything Model
The Segment Anything Model (SAM) is the main component of the proposed methods for segmenting satellite imagery using text prompts. SAM produces high quality object masks from input prompts such as points or boxes, and can be used to generate masks for all objects in an image. However, it doesn't accept text prompts, and implementing this feature requires combining SAM with other visual foundation models.

### SAM Automatic mask generator
SAM can generate masks for all objects in an image. However, the masks are per instance: two objects of the same semantic class receive two separate masks rather than sharing one mask index. A notebook showing how to use this feature is available [here](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb).
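
To make the instance-versus-semantic distinction concrete, here is a minimal sketch that merges several instance masks into one semantic mask. The `anns` list is synthetic and only mimics the `segmentation` and `area` keys used later in this notebook; `merge_instance_masks` is a hypothetical helper, not part of SAM.

```python
import numpy as np


def merge_instance_masks(anns, selected):
    """Union the boolean masks of the selected annotations into one mask."""
    if not selected:
        # nothing selected: return an empty mask with the image shape
        return np.zeros_like(anns[0]["segmentation"], dtype=bool)
    stacked = np.stack([anns[i]["segmentation"] for i in selected])
    return stacked.any(axis=0)


# two synthetic 4x4 instance masks standing in for two objects of one class
a = np.zeros((4, 4), dtype=bool)
a[0, :2] = True
b = np.zeros((4, 4), dtype=bool)
b[3, 2:] = True
anns = [
    {"segmentation": a, "area": int(a.sum())},
    {"segmentation": b, "area": int(b.sum())},
]

semantic_mask = merge_instance_masks(anns, selected=[0, 1])
print(semantic_mask.astype(int))
```

Deciding which instances belong to the same class is the part that needs a second model, which is what the text-prompt pipelines below try to provide.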

### SAM Predictor
SAM is able to predict object masks given prompts that indicate the desired object. Ideally, if a box is given as a prompt, SAM will segment the main object within the box, while if a point is given, SAM will segment the object to which the point belongs. A notebook showing how to use this feature is available [here](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb).

## Set-up

Create a new virtual environment.

In [None]:
# python -m venv .venv
# . .venv/bin/activate
# pip install ipykernel
# python -m ipykernel install --user --name text-based-seg

Install the torch version that matches your installed CUDA version.

In [None]:
%pip install torch==1.10.1+cu111 torchvision==0.11.2+cu111 torchaudio==0.10.1 -f https://download.pytorch.org/whl/cu111/torch_stable.html
%pip install ipywidgets pandas scikit-learn jinja2

Switch to the new kernel before proceeding.

In [None]:
# install Grounding DINO
!git clone 'https://github.com/IDEA-Research/GroundingDINO.git'
%cd GroundingDINO/
%pip install -e .
%cd ..
%mkdir weights
%cd weights
!wget -N 'https://github.com/IDEA-Research/GroundingDINO/releases/download/v0.1.0-alpha/groundingdino_swint_ogc.pth'
%cd ..

In [None]:
# install Segment Anything Model
%pip install 'git+https://github.com/facebookresearch/segment-anything.git'
%cd weights
!wget -N 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
%cd ..

In [None]:
# install CLIP repository
%pip install git+https://github.com/openai/CLIP.git

In [None]:
# download CLIP Surgery repository
!git clone https://github.com/xmed-lab/CLIP_Surgery.git

If you want more images to experiment with, you can download the Cityscapes dataset. The image used in the examples of this notebook comes from this dataset.

In [None]:
# register https://www.cityscapes-dataset.com/register/
# change myusername and mypassword with your login data
# download Cityscapes dataset
%mkdir -p data/cityscapes
%cd data/cityscapes
!wget --keep-session-cookies --save-cookies=cookies.txt --post-data 'username=myusername&password=mypassword&submit=Login' 'https://www.cityscapes-dataset.com/login/'
!wget -nc --load-cookies cookies.txt --content-disposition 'https://www.cityscapes-dataset.com/file-handling/?packageID=1'
!wget -nc --load-cookies cookies.txt --content-disposition 'https://www.cityscapes-dataset.com/file-handling/?packageID=3'
!unzip -n gtFine_trainvaltest.zip
!unzip -n leftImg8bit_trainvaltest.zip
%cd ../..

Restart the runtime to load the new packages before continuing; otherwise `groundingdino` will not be found.

In [None]:
import cv2
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from pathlib import Path
from PIL import Image
import random
from sklearn.metrics import accuracy_score, precision_score, recall_score
import torch
from torch.utils.data import DataLoader, Dataset, Subset
import torchvision.transforms as T


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

## Utils

Define methods to show the results.

In [None]:
# given a mask and an image, this method highlights the parts of the image referenced by the mask.
def draw_mask(mask, image):
    color_map = np.array([255/255, 51/255, 51/255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.reshape(h, w, 1) * color_map.reshape(1, 1, -1)
    annotated_frame_pil = Image.fromarray(image).convert("RGBA")
    mask_image_pil = Image.fromarray((mask * 255).astype(np.uint8)).convert("RGBA")
    return np.array(Image.alpha_composite(annotated_frame_pil, mask_image_pil))

# plot image, ground truth and prediction
def plot_result(image, gt, pred):
    _, axes = plt.subplots(1, 3, figsize=(15, 5))
    # show input
    axes[0].imshow(image)
    axes[0].set_title("Image")
    # show ground truth
    axes[1].imshow(gt)
    axes[1].set_title("Ground truth")
    # show prediction
    axes[2].imshow(pred)
    axes[2].set_title("Prediction")

    plt.tight_layout()
    plt.show()

# show annotations of object segmentation masks
def show_anns(anns):
    if len(anns) == 0:
        return
    sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)

    img = np.ones((sorted_anns[0]['segmentation'].shape[0], sorted_anns[0]['segmentation'].shape[1], 4))
    img[:,:,3] = 0
    for ann in sorted_anns:
        m = ann['segmentation']
        color_mask = np.concatenate([np.random.random(3), [0.35]])
        img[m] = color_mask

    return (img * 255).astype(np.uint8)

## Grounding DINO + SAM
Grounding DINO is a model that detects objects given one or more text inputs. It returns a box for each detection, and these boxes can then be passed to SAM as prompts to perform the segmentation.
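
The hand-off from detector to SAM involves a change of box format. The sketch below assumes, as the prediction code in this section does, that the detector returns boxes as normalized `(cx, cy, w, h)` while SAM expects pixel `(x1, y1, x2, y2)`. `cxcywh_to_xyxy_pixels` is a hypothetical NumPy stand-in for the `box_ops.box_cxcywh_to_xyxy` call plus scaling, written only for illustration.

```python
import numpy as np


def cxcywh_to_xyxy_pixels(boxes, width, height):
    """Convert normalized (cx, cy, w, h) boxes to pixel (x1, y1, x2, y2)."""
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    cx, cy, w, h = boxes.T
    xyxy = np.stack([cx - w / 2, cy - h / 2, cx + w / 2, cy + h / 2], axis=1)
    # scale x coordinates by the image width and y coordinates by the height
    return xyxy * np.array([width, height, width, height], dtype=float)


# one box centred in a 2048x1024 image, covering 20% of its width and 40% of its height
boxes_px = cxcywh_to_xyxy_pixels([[0.5, 0.5, 0.2, 0.4]], width=2048, height=1024)
print(boxes_px)
```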

In [None]:
from groundingdino.util import box_ops
from groundingdino.util.inference import load_model, load_image, predict, annotate
from segment_anything import SamPredictor, sam_model_registry

### Load models

In [None]:
groundingDINO_checkpoint = Path("weights", "groundingdino_swint_ogc.pth")
groundingDINO = load_model(Path("GroundingDINO", "groundingdino", "config", "GroundingDINO_SwinT_OGC.py"), groundingDINO_checkpoint)
groundingDINO.to(device)

sam_checkpoint = Path("weights", "sam_vit_h_4b8939.pth")
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device)
sam_predictor = SamPredictor(sam)

### Define predict methods

In [None]:
# returns the boxes of the detections from an image given a text_prompt
def predict_dino(image_path, text_prompt):
    image_source, image = load_image(image_path)
    if type(text_prompt) == str:
        text_prompt = [text_prompt]

    boxes_lst = []
    for prompt in text_prompt:
        boxes, logits, phrases = predict(
            model=groundingDINO,
            image=image,
            caption=prompt,
            box_threshold=0.35,
            text_threshold=0.25,
            device=device
        )
        annotated_frame = annotate(image_source=image_source, boxes=boxes, logits=logits, phrases=phrases)
        annotated_frame = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
        boxes_lst.extend(boxes)

    # convert the boxes once, after all prompts have been processed
    if len(boxes_lst) > 0:
        boxes_lst = torch.stack(boxes_lst)
        H, W, _ = image_source.shape
        boxes_xyxy = box_ops.box_cxcywh_to_xyxy(boxes_lst) * torch.Tensor([W, H, W, H])
    else:
        boxes_xyxy = None
    return boxes_xyxy, annotated_frame

# returns the segmentation mask given a set of boxes and/or points
def predict_sam(image_path, boxes=None, point_coords=None, point_labels=None):
    image_source, image = load_image(image_path)
    sam_predictor.set_image(image_source)

    if boxes is not None:  # Grounding DINO
        boxes = sam_predictor.transform.apply_boxes_torch(boxes.to(device), image_source.shape[:2])
        masks, _, _ = sam_predictor.predict_torch(
            point_coords = None,
            point_labels = None,
            boxes = boxes,
            multimask_output = False
        )

        masks = ((masks.sum(dim=0)>0)[0]*1).cpu().numpy()
    else:  # CLIPS
        masks, _, _ = sam_predictor.predict(
            point_coords = point_coords,
            point_labels = point_labels,
            multimask_output = False
        )
        masks = np.array(masks)[0, :, :]

    annotated_frame_with_mask = draw_mask(masks, image_source)

    return masks, annotated_frame_with_mask

### Image example

In [None]:
image_path = Path("data", "cityscapes", "leftImg8bit", "val", "lindau", "lindau_000000_000019_leftImg8bit.png")
text_prompt = "tree"

# get boxes of objects related to the text prompt
boxes_xyxy, annotated_frame = predict_dino(image_path, [text_prompt])
# for each box segment the object inside
masks, annotated_frame_with_mask = predict_sam(image_path, boxes=boxes_xyxy)

print("Grounding DINO Output")
plt.imshow(annotated_frame)
plt.show()

print("SAM Output")
plt.imshow(annotated_frame_with_mask)
plt.show()

## CLIP Surgery + SAM

CLIP Surgery (CLIPS) is a model made to enhance explainability of CLIP. CLIP Surgery can be used to convert text to point prompts and then these points can be used by SAM to perform the segmentation.
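
The text-to-points step can be pictured as thresholding a similarity map. The sketch below is a simplified illustration under stated assumptions, not the implementation of `similarity_map_to_points` in the CLIP Surgery repository: it takes a coarse 2D score map, marks cells scoring at least `t` as positive points, takes the same number of lowest-scoring cells as negative points, and maps cell centres to pixel coordinates. `similarity_to_points` is a hypothetical name.

```python
import numpy as np


def similarity_to_points(sim, image_hw, t=0.8):
    """Pick positive/negative point prompts from a coarse similarity map.

    sim: 2D array of scores in [0, 1]; image_hw: (height, width) of the image.
    Returns (x, y) pixel coordinates and labels (1 = positive, 0 = negative).
    """
    sim = np.asarray(sim, dtype=float)
    rows, cols = sim.shape
    flat = sim.ravel()
    order = np.argsort(flat)          # cell indices sorted by ascending score
    n_pos = int((flat >= t).sum())
    pos = order[::-1][:n_pos]         # highest-scoring cells
    neg = order[:n_pos]               # same number of lowest-scoring cells
    idx = np.concatenate([pos, neg])
    # map each grid cell to the pixel at its centre
    ys = (idx // cols + 0.5) * image_hw[0] / rows
    xs = (idx % cols + 0.5) * image_hw[1] / cols
    coords = np.stack([xs, ys], axis=1).astype(int)
    labels = np.concatenate([np.ones(n_pos, dtype=int), np.zeros(len(neg), dtype=int)])
    return coords, labels


toy_coords, toy_labels = similarity_to_points([[0.9, 0.1], [0.2, 0.85]], image_hw=(100, 200))
print(toy_coords, toy_labels)
```

The `(x, y)` ordering and the 1/0 labels match what `predict_sam` passes to SAM as `point_coords` and `point_labels`.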

In [None]:
import CLIP_Surgery.clip as clips

### Load Model

In [None]:
clips_model, _ = clips.load("CS-ViT-B/16", device=device)

### Define predict method

In [None]:
# returns the points of the features from an image given a text_prompt
def predict_CLIPS(image_path, text_prompt):
    pil_img = Image.open(image_path)
    image_source = np.array(pil_img)

    if type(text_prompt) == str:
        text_prompt = [text_prompt]
    
    # you can change height and width of the resize transformation based on the dataset
    height, width = 1024, 1024

    # predict
    preprocess =  T.Compose([T.Resize((height, width), interpolation=T.InterpolationMode.BICUBIC),
                             T.ToTensor(), T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))])
    cv2_img = cv2.cvtColor(image_source, cv2.COLOR_RGB2BGR)
    image = preprocess(pil_img).unsqueeze(0).to(device)
    with torch.no_grad():
        # CLIP architecture surgery acts on the image encoder
        image_features = clips_model.encode_image(image)
        image_features = image_features / image_features.norm(dim=1, keepdim=True)

        # Prompt ensemble for text features with normalization
        text_features = clips.encode_text_with_prompt_ensemble(clips_model, text_prompt, device)

        # Extract redundant features from an empty string
        redundant_features = clips.encode_text_with_prompt_ensemble(clips_model, [""], device)

        # CLIP feature surgery with custom redundant features
        similarity = clips.clip_feature_surgery(image_features, text_features, redundant_features)[0]
            
        # Convert the similarity map into point prompts for SAM
        coords, labels = clips.similarity_map_to_points(similarity[1:, 0], cv2_img.shape[:2], t=0.8)

        # Annotate with points
        annotated_frame = cv2_img.copy()
        for i, [x, y] in enumerate(coords):
            cv2.circle(annotated_frame, (x, y), 3, (0, 102, 255) if labels[i] == 1 else (255, 102, 51), 3)
        annotated_frame = cv2.cvtColor(annotated_frame.astype('uint8'), cv2.COLOR_BGR2RGB)

        return np.array(coords), labels, annotated_frame

### Image Example

In [None]:
image_path = Path("data", "cityscapes", "leftImg8bit", "val", "lindau", "lindau_000000_000019_leftImg8bit.png")
text_prompt = "tree"

# get points of objects related to the text prompt
coords, labels, annotated_frame = predict_CLIPS(image_path, [text_prompt])
# segment the objects related with the given points
masks, annotated_frame_with_mask = predict_sam(image_path, point_coords=coords, point_labels=labels)

print("CLIP Surgery Output")
plt.imshow(annotated_frame)
plt.show()

print("SAM Output")
plt.imshow(annotated_frame_with_mask)
plt.show()

## SAM + CLIP
CLIP is a model used to compute image-text similarity. Another method to segment based on a text prompt is to use SAM Automatic mask generation to generate masks for each object and then let CLIP compute the similarity between the text prompt and each object.
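
The selection step can be sketched without loading any model. The example below assumes cosine similarities between one prompt and each mask crop are already available (the numbers are made up), applies a softmax across the crops, and keeps those above a threshold. `select_by_softmax` is a hypothetical helper; the scale of 100 and the threshold of 0.05 mirror the values used in the cells below.

```python
import numpy as np


def select_by_softmax(similarities, threshold=0.05, scale=100.0):
    """Softmax the similarities across crops and keep indices above threshold."""
    logits = scale * np.asarray(similarities, dtype=float)
    logits = logits - logits.max()    # shift for numerical stability
    weights = np.exp(logits)
    probs = weights / weights.sum()
    keep = [i for i, p in enumerate(probs) if p > threshold]
    return probs, keep


# hypothetical cosine similarities between the prompt and four mask crops
probs, keep = select_by_softmax([0.30, 0.22, 0.29, 0.10])
print(np.round(probs, 3), keep)
```

Because the softmax is taken across crops, the scores sum to 1. A fixed threshold therefore becomes stricter as the mask generator returns more masks, which is worth keeping in mind when changing `points_per_side`.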

In [None]:
import clip
from segment_anything import SamAutomaticMaskGenerator

### Load Model

In [None]:
generic_mask_generator = SamAutomaticMaskGenerator(sam, points_per_side=10)
clip_model, clip_preprocess = clip.load("ViT-B/16", device=device)

### Define predict method

In [None]:
def predict_sam_auto(image_path):
    image_source, _ = load_image(image_path)
    
    # use SAM to generate masks
    segmented_frame_masks = generic_mask_generator.generate(image_source)

    # object annotations
    annotated_frame = show_anns(segmented_frame_masks)
    annotated_frame = annotated_frame[:, :, :3]

    return segmented_frame_masks, annotated_frame


@torch.no_grad()
def retriev(elements, search_text):
    preprocessed_images = [clip_preprocess(image).to(device) for image in elements]
    tokenized_text = clip.tokenize([search_text]).to(device)
    stacked_images = torch.stack(preprocessed_images)
    image_features = clip_model.encode_image(stacked_images)
    text_features = clip_model.encode_text(tokenized_text)
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    probs = 100. * image_features @ text_features.T
    return probs[:, 0].softmax(dim=0)


def get_indices_of_values_above_threshold(values, threshold):
    return [i for i, v in enumerate(values) if v > threshold]


def segment_image(image, segmentation_mask):
    image_array = image
    segmented_image_array = np.zeros_like(image_array)
    segmented_image_array[segmentation_mask] = image_array[segmentation_mask]
    segmented_image = Image.fromarray(segmented_image_array)
    # PIL expects the size as (width, height), while NumPy shapes are (height, width)
    black_image = Image.new("RGB", (image.shape[1], image.shape[0]), (0, 0, 0))
    transparency_mask = np.zeros_like(segmentation_mask, dtype=np.uint8)
    transparency_mask[segmentation_mask] = 255
    transparency_mask_image = Image.fromarray(transparency_mask, mode='L')
    black_image.paste(segmented_image, mask=transparency_mask_image)
    return black_image


def convert_box_xywh_to_xyxy(box):
    x1 = box[0]
    y1 = box[1]
    x2 = box[0] + box[2]
    y2 = box[1] + box[3]
    return [x1, y1, x2, y2]


def predict_CLIP(image_path, segmented_frame_masks, text_prompt):
    image_source, _ = load_image(image_path)

    if type(text_prompt) == str:
        text_prompt = [text_prompt]

    # Cut out all masks
    cropped_boxes = []
    for mask in segmented_frame_masks:
        cropped_boxes.append(segment_image(image_source, mask["segmentation"]).crop(convert_box_xywh_to_xyxy(mask["bbox"])))  
        
    indices_lst = []
    for prompt in text_prompt:
        scores = retriev(cropped_boxes, prompt)
        indices = get_indices_of_values_above_threshold(scores, 0.05)
        indices_lst.extend(indices)      
            
    segmentation_masks = []
    for seg_idx in np.unique(indices_lst):
        segmentation_mask_image = segmented_frame_masks[seg_idx]["segmentation"]
        segmentation_masks.append(segmentation_mask_image)
    segmentation_masks = np.array(segmentation_masks).sum(axis=0)>0

    if segmentation_masks.ndim == 0:
        segmentation_masks = np.zeros(shape=image_source.shape[0:2])

    annotated_frame_with_mask = draw_mask(segmentation_masks, image_source)

    return segmentation_masks, annotated_frame_with_mask

### Example Image

In [None]:
image_path = Path("data", "cityscapes", "leftImg8bit", "val", "lindau", "lindau_000000_000019_leftImg8bit.png")
text_prompt = "tree"

# get object segmentation masks
segmented_frame_masks, annotated_frame = predict_sam_auto(image_path)
# check which objects are related to the text prompt
masks, annotated_frame_with_mask = predict_CLIP(image_path, segmented_frame_masks, text_prompt)

print("SAM Output")
plt.imshow(annotated_frame)
plt.show()

print("CLIP Output")
plt.imshow(annotated_frame_with_mask)
plt.show()

## Visual comparison

Compare the models with different text prompts.

In [None]:
image_path = Path("data", "cityscapes", "leftImg8bit", "val", "lindau", "lindau_000000_000019_leftImg8bit.png")
text_prompts = ["banner", "car", "plate", "road", "roof", "tree"]

for text_prompt in text_prompts:
    # DINO + SAM
    boxes_xyxy, _ = predict_dino(image_path, [text_prompt])
    _, annotated_frame_with_mask_dino = predict_sam(image_path, boxes=boxes_xyxy)
    # CLIPS + SAM
    coords, labels, _ = predict_CLIPS(image_path, [text_prompt])
    _, annotated_frame_with_mask_clips = predict_sam(image_path, point_coords=coords, point_labels=labels)
    # SAM + CLIP
    segmented_frame_masks, _ = predict_sam_auto(image_path)
    _, annotated_frame_with_mask_clip = predict_CLIP(image_path, segmented_frame_masks, text_prompt)
    # plot
    fig, axes = plt.subplots(1, 3, figsize=(14, 4))
    fig.suptitle(text_prompt.capitalize())
    axes[0].set_title("Grounding DINO + SAM")
    axes[0].imshow(annotated_frame_with_mask_dino)
    axes[1].set_title("CLIPS + SAM")
    axes[1].imshow(annotated_frame_with_mask_clips)
    axes[2].set_title("SAM + CLIP")
    axes[2].imshow(annotated_frame_with_mask_clip)
    plt.tight_layout()
    plt.show()

After trying the three models with different prompts, visual inspection did not show any one of them to be clearly better than the others.

## Metrics
Compute metrics over the Cityscapes dataset. Each class name is used as a text prompt, and the resulting binary mask is compared with the ground-truth mask of that class. Because the masks predicted for different classes can overlap, the same pixel may be assigned to several classes, so a single multi-class confusion matrix cannot be built properly; metrics are computed per class instead.
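
As a minimal sketch of what is measured for one class, the example below computes accuracy, precision and recall for a pair of binary masks with plain NumPy. `binary_metrics` is a hypothetical helper and the two masks are synthetic. Returning 0.0 when a denominator is empty is an assumption made here for simplicity.

```python
import numpy as np


def binary_metrics(y_true, y_pred):
    """Pixel-wise accuracy, precision and recall for two binary masks."""
    y_true = np.asarray(y_true).astype(bool).ravel()
    y_pred = np.asarray(y_pred).astype(bool).ravel()
    tp = int((y_true & y_pred).sum())      # class pixels correctly predicted
    fp = int((~y_true & y_pred).sum())     # predicted but not in the ground truth
    fn = int((y_true & ~y_pred).sum())     # in the ground truth but missed
    tn = int((~y_true & ~y_pred).sum())    # background correctly left out
    accuracy = (tp + tn) / y_true.size
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return accuracy, precision, recall


gt_mask = np.array([[1, 1, 0, 0], [1, 1, 0, 0]])
pred_mask = np.array([[1, 0, 0, 0], [1, 1, 1, 0]])
print(binary_metrics(gt_mask, pred_mask))
```

Accuracy counts true negatives, so for classes that cover few pixels it can stay high even when the predicted mask is empty. Precision and recall are more informative in that case.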

### Define Cityscapes dataset class

In [None]:
class Cityscapes(Dataset):
    def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None):
        self.root = os.path.expanduser(root)
        self.mode = 'gtFine'
        self.target_type = target_type
        self.images_dir = os.path.join(self.root, 'leftImg8bit', split)

        self.targets_dir = os.path.join(self.root, self.mode, split)
        self.transform = transform

        self.split = split
        self.images = []
        self.targets = []
        
        for city in os.listdir(self.images_dir):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)

            for file_name in os.listdir(img_dir):
                self.images.append(os.path.join(img_dir, file_name))
                target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
                                             self._get_target_suffix(self.mode, self.target_type))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index):
        image_path = self.images[index]
        target = Image.open(self.targets[index])
        target = np.array(target)
        return image_path, target

    def __len__(self):
        return len(self.images)

    def _get_target_suffix(self, mode, target_type):
        if target_type == 'instance':
            return '{}_instanceIds.png'.format(mode)
        elif target_type == 'semantic':
            return '{}_labelIds.png'.format(mode)
        elif target_type == 'color':
            return '{}_color.png'.format(mode)
        elif target_type == 'polygon':
            return '{}_polygons.json'.format(mode)
        elif target_type == 'depth':
            return '{}_disparity.png'.format(mode)

A subset of the test_dataset is used, otherwise metrics evaluation takes too long.
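
Drawing a reproducible subset can be sketched with the standard library alone. `sample_subset_indices` is a hypothetical helper, and 500 is only an illustrative dataset size. Note that `random.sample` draws without replacement, whereas `random.choices` may return the same index more than once.

```python
import random


def sample_subset_indices(n_items, fraction=0.1, seed=42):
    """Return a reproducible list of distinct indices covering `fraction` of the data."""
    rng = random.Random(seed)    # local generator, leaves the global state untouched
    k = int(n_items * fraction)
    return rng.sample(range(n_items), k)


indices = sample_subset_indices(500)
print(len(indices), len(set(indices)))
```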

In [None]:
test_dataset = Cityscapes('./data/cityscapes', split='val', mode='fine', target_type='semantic')
random.seed(42)
# sample without replacement so that no image is evaluated twice
test_subset_indices = random.sample(range(len(test_dataset)), k=int(len(test_dataset) * 0.1))
test_subset = Subset(test_dataset, test_subset_indices)
n_cpu = os.cpu_count()
test_dataloader = DataLoader(test_subset, batch_size=1, shuffle=False, num_workers=n_cpu)

### Define index-prompt relation
Each entry pairs a Cityscapes label ID with the name of the class, which is also used as the input prompt.<br>
For example, every pixel that belongs to the road class has the value 7 in the Cityscapes ground truth. Ideally, a model prompted with the string "road" should return the same mask as the ground truth restricted to label 7.
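
The ground-truth side of this comparison is a simple equality test on the label map. The sketch below uses a tiny synthetic label map (not real Cityscapes data) with the label IDs from the table that follows.

```python
import numpy as np

# tiny synthetic label map using Cityscapes label IDs (7 = road, 23 = sky, 26 = car)
label_ids = np.array([
    [23, 23, 23],
    [26, 26, 7],
    [7, 7, 7],
])

# binary ground-truth mask for the prompt "road"
road_mask = (label_ids == 7).astype(int)
print(road_mask)
```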

In [None]:
classes = (
    (7, 'road'),
    (8, 'sidewalk'),
    (11, 'building'),
    (12, 'wall'),
    (13, 'fence'),
    (17, 'pole'),
    (18, 'pole group'),
    (19, 'traffic light'),
    (20, 'traffic sign'),
    (21, 'vegetation'),
    (22, 'terrain'),
    (23, 'sky'),
    (24, 'person'),
    (25, 'rider'),
    (26, 'car'),
    (27, 'truck'),
    (28, 'bus'),
    (31, 'train'),
    (32, 'motorcycle'),
    (33, 'bicycle'),
)

### Compute Metrics

In [None]:
def get_metrics(y_true, y_pred):
    y_true = y_true.flatten()
    y_pred = y_pred.flatten()
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred)
    recall = recall_score(y_true, y_pred)
    return accuracy, precision, recall

metrics = np.zeros(shape=(3, len(classes), 3))
for i, (image_path, gt) in enumerate(test_dataloader):
    image_path, gt = image_path[0], gt.squeeze()

    for j, (index, text_prompt) in enumerate(classes):
        print(f"{i * len(classes) + j + 1}/{len(test_dataloader) * len(classes)}")
        gt_index = (gt == index).numpy().astype(int)
        
        # DINO + SAM
        boxes_xyxy, _ = predict_dino(image_path, [text_prompt])
        mask_dino, _ = predict_sam(image_path, boxes=boxes_xyxy)
        # CLIPS + SAM
        coords, labels, _ = predict_CLIPS(image_path, [text_prompt])
        mask_clips, _ = predict_sam(image_path, point_coords=coords, point_labels=labels)
        # SAM + CLIP
        segmented_frame_masks, _ = predict_sam_auto(image_path)
        mask_clip, _ = predict_CLIP(image_path, segmented_frame_masks, text_prompt)
        # compute metrics
        metrics[0, j] = metrics[0, j] + get_metrics(gt_index, mask_dino)
        metrics[1, j] = metrics[1, j] + get_metrics(gt_index, mask_clips)
        metrics[2, j] = metrics[2, j] + get_metrics(gt_index, mask_clip)

metrics = metrics / len(test_dataloader)
np.save('metrics.npy', metrics)

### Print comparison results

Export results as LaTeX tables.

In [None]:
np.set_printoptions(suppress=True)
metrics = np.load('metrics.npy')
dino_metrics, clips_metrics, clip_metrics = metrics[0], metrics[1], metrics[2]

dino_df = pd.DataFrame({
    'class': [x[1] for x in classes],
    'accuracy': dino_metrics[:, 0],
    'precision': dino_metrics[:, 1],
    'recall': dino_metrics[:, 2]
})
dino_df.loc[len(dino_df)] = {
    'class': 'Overall',
    'accuracy': dino_df['accuracy'].mean(),
    'precision': dino_df['precision'].mean(),
    'recall': dino_df['recall'].mean()
}
dino_df = dino_df.round(6)
print(dino_df.to_latex())

clips_df = pd.DataFrame({
    'class': [x[1] for x in classes],
    'accuracy': clips_metrics[:, 0],
    'precision': clips_metrics[:, 1],
    'recall': clips_metrics[:, 2]
})
clips_df.loc[len(clips_df)] = {
    'class': "Overall",
    'accuracy': clips_df['accuracy'].mean(),
    'precision': clips_df['precision'].mean(),
    'recall': clips_df['recall'].mean()
}
clips_df = clips_df.round(6)
print(clips_df.to_latex())

clip_df = pd.DataFrame({
    'class': [x[1] for x in classes],
    'accuracy': clip_metrics[:, 0],
    'precision': clip_metrics[:, 1],
    'recall': clip_metrics[:, 2]
})
clip_df.loc[len(clip_df)] = {
    'class': "Overall",
    'accuracy': clip_df['accuracy'].mean(),
    'precision': clip_df['precision'].mean(),
    'recall': clip_df['recall'].mean()
}
clip_df = clip_df.round(6)
print(clip_df.to_latex())

## Concluding Remarks
With the pipelines tested in this notebook, segmentation based on text prompts is not reliable.

There are other combinations worth trying:
- Grounding DINO, then SAM and finally CLIP
- CLIP Surgery, then SAM and finally CLIP
- Both Grounding DINO and CLIP Surgery, then SAM and finally CLIP