<a href="https://colab.research.google.com/github/SauravMaheshkar/samv2/blob/main/examples/notebooks/samv2_prompted_segmentation_with_wandb_tables.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Before running this notebook, add a Colab secret named `W&B` that contains your Weights & Biases API key, and grant this notebook access to it.

## 📦 Packages and Basic Setup
---

In [None]:
%%capture
!pip install git+https://github.com/SauravMaheshkar/samv2.git wandb
!wget https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_tiny.pt

url = "https://github.com/SauravMaheshkar/SauravMaheshkar/blob/main/assets/text2img/llama_spiderman_coffee.png?raw=true"

In [None]:
import os

import wandb
from google.colab import userdata

# Authenticate with the API key stored in the Colab secret named "W&B".
os.environ["WANDB_API_KEY"] = userdata.get("W&B")

# Colab form fields attach to simple `name = value` assignments, so the project
# and entity are exposed as separate parameters instead of annotating the
# `wandb.init(...)` call itself. Change the entity to your own W&B user/team.
WANDB_PROJECT = "samv2"  # @param {type: "string"}
WANDB_ENTITY = "sauravmaheshkar"  # @param {type: "string"}

run = wandb.init(project=WANDB_PROJECT, entity=WANDB_ENTITY)

# Each prompt type in this notebook appends one row with these three fields.
columns = ["image", "mask", "score"]
wandb_table = wandb.Table(columns=columns)

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import requests
from PIL import Image

# Fail fast on a stalled connection or an HTTP error instead of handing PIL a
# response body that is not an image.
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
# The raw stream is not content-decoded by default; ask urllib3 to undo any
# gzip/deflate encoding before PIL reads from it.
response.raw.decode_content = True

# Decode straight to an RGB numpy array so `image` only ever names one kind of object.
image = np.array(Image.open(response.raw).convert("RGB"))

In [None]:
# Preview the input image; the axis ticks are pixel indices, which helps when
# reading off the point/box prompt coordinates used later in the notebook.
plt.imshow(image)

In [None]:
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.utils.misc import variant_to_config_mapping
from sam2.utils.visualization import show_masks

# Build the "tiny" variant from the checkpoint fetched in the setup cell.
tiny_config = variant_to_config_mapping["tiny"]
tiny_checkpoint = "/content/sam2_hiera_tiny.pt"
model = build_sam2(tiny_config, tiny_checkpoint)

# Wrap the model in an image predictor and hand it the image once; the
# predict calls further down only pass prompts.
image_predictor = SAM2ImagePredictor(model)
image_predictor.set_image(image)

## Perform Segmentation with a single point

In [None]:
# A single point prompt and its label.
# NOTE(review): coordinates are presumably (x, y) pixels and label 1 presumably
# marks foreground — confirm against the predictor's documentation.
input_point = np.array([(300, 600)])
input_label = np.array([1])

In [None]:
# Point-only prompt with multimask_output enabled; the candidates are ranked
# by score in a later cell.
single_point_prompt = dict(
    point_coords=input_point,
    point_labels=input_label,
    box=None,
    multimask_output=True,
)
masks, scores, logits = image_predictor.predict(**single_point_prompt)

In [None]:
# Render the predicted masks for the single-point prompt.
# NOTE(review): the result is displayed and wrapped in wandb.Image below, so it
# is presumably image-like — confirm against sam2.utils.visualization.show_masks.
output_mask = show_masks(image, masks, scores)

In [None]:
# Indices of the scores ordered from highest to lowest.
sorted_ind = np.argsort(scores)[::-1]
best_score = scores[sorted_ind[0]]
print(f"Top Score: {best_score}")

In [None]:
# Show the single-point result as this cell's output.
output_mask

In [None]:
# Table row for the single-point prompt: image, mask, score.
top_score = scores[sorted_ind[0]]
wandb_table.add_data(wandb.Image(image), wandb.Image(output_mask), top_score)

## Perform Segmentation using multiple points

In [None]:
# Two point prompts; every point gets the same label (1).
multi_point_coords = np.array([(300, 600), (700, 700)])
multi_point_labels = np.array([1] * len(multi_point_coords))

In [None]:
# Both points are sent in one call, with multimask_output switched off.
multi_point_prompt = dict(
    point_coords=multi_point_coords,
    point_labels=multi_point_labels,
    box=None,
    multimask_output=False,
)
masks, scores, _ = image_predictor.predict(**multi_point_prompt)

In [None]:
# Render the mask predicted from the two-point prompt.
# NOTE(review): return type is defined by show_masks (not visible here); it is
# displayed and logged via wandb.Image below.
output_mask = show_masks(image, masks, scores)

In [None]:
# Order the score indices from best to worst and report the leading one.
sorted_ind = np.argsort(scores)[::-1]
best_score = scores[sorted_ind[0]]
print(f"Top Score: {best_score}")

In [None]:
# Show the multi-point result as this cell's output.
output_mask

In [None]:
# Table row for the multi-point prompt: image, mask, score.
top_score = scores[sorted_ind[0]]
wandb_table.add_data(wandb.Image(image), wandb.Image(output_mask), top_score)

## Perform Segmentation using a single bounding box

In [None]:
# Four-number box prompt.
# NOTE(review): presumably (x_min, y_min, x_max, y_max) in pixels — confirm
# against the predictor's documentation.
single_box_coords = np.array((656, 655, 798, 816))

In [None]:
# Box-only prompt: no points or labels are supplied.
single_box_prompt = dict(
    point_coords=None,
    point_labels=None,
    box=single_box_coords,
    multimask_output=False,
)
masks, scores, _ = image_predictor.predict(**single_box_prompt)

In [None]:
# Render the mask for the single-box prompt. Unlike the point-prompt cells, no
# scores are passed and display_image is explicitly disabled.
# NOTE(review): the exact effect of these flags is defined by show_masks — confirm.
output_mask = show_masks(image, masks, scores=None, display_image=False)

In [None]:
# Score indices in descending order; the first entry is the best one.
sorted_ind = np.argsort(scores)[::-1]
best_score = scores[sorted_ind[0]]
print(f"Top Score: {best_score}")

In [None]:
# Show the single-box result as this cell's output.
output_mask

In [None]:
# Table row for the single-box prompt: image, mask, score.
top_score = scores[sorted_ind[0]]
wandb_table.add_data(wandb.Image(image), wandb.Image(output_mask), top_score)

## Perform Segmentation using multiple bounding boxes

In [None]:
# Two box prompts stacked into one array, one box per row.
first_box = [656, 655, 798, 816]
second_box = [263, 518, 408, 653]
multi_box_coords = np.array([first_box, second_box])

In [None]:
# Both boxes are submitted in a single predict call, without point prompts.
multi_box_prompt = dict(
    point_coords=None,
    point_labels=None,
    box=multi_box_coords,
    multimask_output=False,
)
masks, scores, _ = image_predictor.predict(**multi_box_prompt)

In [None]:
# Render the masks for the multi-box prompt, with only_best turned off.
# NOTE(review): only_best=False presumably keeps the mask of every box rather
# than a single best one — confirm against show_masks.
output_mask = show_masks(
    image, masks, scores=None, only_best=False, display_image=False
)

In [None]:
# `scores` is 2-D for the multi-box prompt (presumably one row per box with a
# single column — confirm). The default last-axis argsort over a one-element
# row is always 0, which made "Top Score" report the first box regardless of
# its value. Ranking along axis 0 orders the boxes themselves, so the first
# entry after the reversal points at the highest-scoring box.
sorted_ind = np.argsort(scores, axis=0)[::-1]
print(f"Top Score: {scores[sorted_ind[0]][0][0]}")

In [None]:
# Show the multi-box result as this cell's output.
output_mask

In [None]:
# `scores` is 2-D for the multi-box prompt, so take the maximum over the whole
# array directly; this guarantees the logged value is the best score across
# all boxes. float() turns the numpy scalar into a plain Python number.
top_score = float(np.max(scores))
wandb_table.add_data(wandb.Image(image), wandb.Image(output_mask), top_score)

## Perform Segmentation using a collection of boxes and points

In [None]:
# A point prompt (with label 1) that is combined with a box prompt.
point, label = np.array([(300, 600)]), np.array([1])
box = np.array((263, 518, 408, 653))

In [None]:
# Box and point prompts are passed together in the same predict call.
box_and_point_prompt = dict(
    point_coords=point,
    point_labels=label,
    box=box,
    multimask_output=False,
)
masks, scores, _ = image_predictor.predict(**box_and_point_prompt)

In [None]:
# Render the mask for the combined box + point prompt; no scores are passed
# and display_image is disabled.
# NOTE(review): the exact effect of these flags is defined by show_masks — confirm.
output_mask = show_masks(image, masks, scores=None, display_image=False)

In [None]:
# Descending order of score indices; report the score at the front.
sorted_ind = np.argsort(scores)[::-1]
best_score = scores[sorted_ind[0]]
print(f"Top Score: {best_score}")

In [None]:
# Show the box + point result as this cell's output.
output_mask

In [None]:
# Table row for the box + point prompt: image, mask, score.
top_score = scores[sorted_ind[0]]
wandb_table.add_data(wandb.Image(image), wandb.Image(output_mask), top_score)

In [None]:
# Log the accumulated table under a single key, then end the W&B run.
table_payload = {"samv2_prompt_segmentation": wandb_table}
run.log(table_payload)

wandb.finish()