<a href="https://colab.research.google.com/github/kumar-piyush12/semantic_seg_oral_cancer/blob/main/SAM_w/out_mask_samples.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

[![Roboflow Notebooks](https://media.roboflow.com/notebooks/template/bannertest2-2.png?ik-sdk-version=javascript-1.4.3&updatedAt=1672932710194)](https://github.com/roboflow/notebooks)

# Segment Anything Model (SAM)

---

[![GitHub](https://badges.aleen42.com/src/github.svg)](https://github.com/facebookresearch/segment-anything) [![arXiv](https://img.shields.io/badge/arXiv-2304.02643-b31b1b.svg)](https://arxiv.org/abs/2304.02643)

Segment Anything Model (SAM): a new AI model from Meta AI that can "cut out" any object, in any image, with a single click. SAM is a promptable segmentation system with zero-shot generalization to unfamiliar objects and images, without the need for additional training. This notebook is an extension of the [official notebook](https://colab.research.google.com/github/facebookresearch/segment-anything/blob/main/notebooks/automatic_mask_generator_example.ipynb) prepared by Meta AI.

![segment anything model](https://media.roboflow.com/notebooks/examples/segment-anything-model-paper.png)

## Complementary Materials

---

[![Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/how-to-segment-anything-with-sam.ipynb) [![YouTube](https://badges.aleen42.com/src/youtube.svg)](https://youtu.be/D-D6ZmadzPE) [![Roboflow](https://raw.githubusercontent.com/roboflow-ai/notebooks/main/assets/badges/roboflow-blogpost.svg)](https://blog.roboflow.com/how-to-use-segment-anything-model-sam)

We recommend that you follow along in this notebook while reading the blog post on Segment Anything Model.

![segment anything model blogpost](https://media.roboflow.com/notebooks/examples/segment-anything-model-blogpost.png)

## Pro Tip: Use GPU Acceleration

If you are running this notebook in Google Colab, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`. This will ensure your notebook uses a GPU, which will significantly speed up SAM inference (this notebook does not train a model).

## Steps in this Tutorial

In this tutorial, we are going to cover:

- **Before you start** - Make sure you have access to the GPU
- Install Segment Anything Model (SAM)
- Download SAM weights
- Mount Google Drive and inspect the dataset
- Load Model
- Generate Segmentation with Bounding Box
- Evaluate the masks against the polygon annotations (IoU, Dice, pixel accuracy)

## Let's begin!

## Before you start

Let's make sure that we have access to a GPU. We can use the `nvidia-smi` command to do that. In case of any problems, navigate to `Edit` -> `Notebook settings` -> `Hardware accelerator`, set it to `GPU`, and then click `Save`.

In [None]:
!nvidia-smi

Wed Apr 16 16:24:19 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   60C    P8             11W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

**NOTE:** To make it easier for us to manage datasets, images and models we create a `HOME` constant.

In [None]:
import os
HOME = os.getcwd()
print("HOME:", HOME)

HOME: /content


## Install Segment Anything Model (SAM) and other dependencies

In [None]:
!pip install -q 'git+https://github.com/facebookresearch/segment-anything.git'

  Preparing metadata (setup.py) ... done
  Building wheel for segment_anything (setup.py) ... done


In [None]:
!pip install -q jupyter_bbox_widget roboflow dataclasses-json supervision==0.23.0

### Download SAM weights

In [None]:
!mkdir -p {HOME}/weights
!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth -P {HOME}/weights

In [None]:
CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")
print(CHECKPOINT_PATH, "; exist:", os.path.isfile(CHECKPOINT_PATH))

/content/weights/sam_vit_h_4b8939.pth ; exist: True
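The `os.path.isfile` check above confirms only that a file exists, not that the download finished. A minimal sketch, assuming a hypothetical helper `file_size_and_sha256` (not part of the original notebook), that reports the size and SHA-256 digest of a file so they can be compared against a trusted copy. No reference digest is known here, so none is asserted; the demonstration uses a small temporary file standing in for the checkpoint.

```python
import hashlib
import os
import tempfile

def file_size_and_sha256(path, chunk_size=1 << 20):
    """Return (size_in_bytes, hex SHA-256 digest) of a file, reading it in chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return os.path.getsize(path), digest.hexdigest()

# Demonstration on a small temporary file (hypothetical stand-in for the weights).
with tempfile.NamedTemporaryFile(delete=False, suffix=".pth") as tmp:
    tmp.write(b"not a real checkpoint")
    weights_demo_path = tmp.name

demo_size, demo_digest = file_size_and_sha256(weights_demo_path)
print(demo_size, demo_digest)
os.remove(weights_demo_path)
```

In the notebook, the same helper could be called on `CHECKPOINT_PATH` right after the download.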


## Mounting Google Drive: To access D3 images

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# List all files and folders in the dataset directory on Google Drive
dataset_path = "/content/drive/My Drive/Oral Cancer Dataset"
print(os.listdir(dataset_path))

['Annotation.json', 'Imagewise_Data.csv', 'Patientwise_Data.csv', 'Images.zip', 'Images', 'SAM_checkpoints', 'Healthy (1).csv', '.ipynb_checkpoints', 'Healthy.csv']


## Listing the first 10 images extracted from Images.zip

In [None]:
images_folder = "/content/drive/My Drive/Oral Cancer Dataset/Images/Images"
print(os.listdir(images_folder)[:10])  # Show first 10 extracted files

['R-235-01.jpg', 'R-235-02.jpg', 'R-235-03.jpg', 'R-235-04.jpg', 'R-235-05.jpg', 'R-235-06.jpg', 'R-235-07.jpg', 'R-235-08.jpg', 'R-236-01.jpg', 'R-236-02.jpg']


## Printing total number of images in D3

In [None]:
# Count number of files in the extracted images directory
num_images = len(os.listdir(images_folder))
print(num_images)

3000


## Load the SAM Model with GPU support

In [None]:
import os
import json
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
from segment_anything import sam_model_registry, SamPredictor
from sklearn.metrics import jaccard_score, precision_score, recall_score

# Paths
CHECKPOINT_PATH = "/content/weights/sam_vit_h_4b8939.pth"
ANNOTATION_PATH = "/content/drive/MyDrive/Oral Cancer Dataset/Annotation.json"
IMAGE_FOLDER = "/content/drive/My Drive/Oral Cancer Dataset/Images/Images"

# Select model type and device
model_type = "vit_h"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load model and move it to GPU
sam = sam_model_registry[model_type](checkpoint=CHECKPOINT_PATH)
sam.to(device)

# Initialize predictor
predictor = SamPredictor(sam)


## Ensuring GPU Usage

In [None]:
print(next(sam.parameters()).device)

cuda:0


## Loading the annotations of images

In [None]:
with open(ANNOTATION_PATH, "r") as f:
    ann_data = json.load(f)

# Map image ID to file name
id_to_filename = {img["id"]: img["file_name"] for img in ann_data["images"]}
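The evaluation loop later in this notebook tracks progress by `image_id`, so only the first annotation of each image is evaluated. A minimal sketch, on a made-up COCO-style dictionary (`toy_data` and its ids are hypothetical), of how annotations can be grouped per image to see whether any image carries more than one lesion annotation:

```python
from collections import defaultdict

# Toy COCO-style structure; ids, file names and boxes are invented for illustration.
toy_data = {
    "images": [{"id": 1, "file_name": "a.jpg"}, {"id": 2, "file_name": "b.jpg"}],
    "annotations": [
        {"id": 10, "image_id": 1, "bbox": [0, 0, 5, 5]},
        {"id": 11, "image_id": 1, "bbox": [6, 6, 2, 2]},
        {"id": 12, "image_id": 2, "bbox": [1, 1, 3, 3]},
    ],
}

# Group annotations by the image they belong to.
anns_by_image = defaultdict(list)
for toy_ann in toy_data["annotations"]:
    anns_by_image[toy_ann["image_id"]].append(toy_ann)

# Images with more than one annotation, mapped to their annotation count.
multi_annotated = {img_id: len(anns) for img_id, anns in anns_by_image.items() if len(anns) > 1}
print(multi_annotated)  # → {1: 2}
```

Running the same grouping on `ann_data["annotations"]` would show whether this applies to the real dataset.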

## Polygon to Mask Function (polygon outlines of the lesions are given in Annotation.json)

In [None]:
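def polygon_to_mask(polygons, img_shape):
    """Rasterize COCO-style polygon(s) into a binary (H, W) uint8 mask."""
    mask = np.zeros(img_shape, dtype=np.uint8)
    # Accept one flat [x1, y1, x2, y2, ...] list or a list of such lists
    if len(polygons) > 0 and not isinstance(polygons[0], (list, tuple)):
        polygons = [polygons]
    for polygon in polygons:
        pts = np.array(polygon).reshape((-1, 2)).astype(np.int32)
        cv2.fillPoly(mask, [pts], color=1)
    return mask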

## Evaluation Metrics Design

In [None]:
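def dice_score(gt, pred):
    """Dice coefficient 2|A∩B| / (|A| + |B|) for binary {0, 1} masks."""
    intersection = np.sum(gt * pred)
    # The small epsilon avoids division by zero when both masks are empty
    return (2. * intersection) / (np.sum(gt) + np.sum(pred) + 1e-7)

def evaluate(gt_mask, pred_mask):
    """Return (IoU, Dice, pixel accuracy) for a pair of binary masks."""
    gt_flat = gt_mask.flatten()
    pred_flat = pred_mask.flatten()
    iou = jaccard_score(gt_flat, pred_flat)  # IoU of the foreground class (label 1)
    dice = dice_score(gt_mask, pred_mask)
    pixel_acc = np.mean(gt_flat == pred_flat)  # fraction of pixels that agree
    return iou, dice, pixel_acc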
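To make the three metrics concrete, here is a small worked example on two hand-made 2×4 masks, using NumPy only. The masks and the `toy_` names are invented for illustration. It also checks the identity Dice = 2·IoU / (1 + IoU), which holds for a single pair of masks (it does not carry over to averages across many images).

```python
import numpy as np

toy_gt = np.array([[1, 1, 0, 0],
                   [1, 1, 0, 0]], dtype=np.uint8)
toy_pred = np.array([[1, 1, 1, 0],
                     [0, 0, 0, 0]], dtype=np.uint8)

toy_intersection = np.logical_and(toy_gt, toy_pred).sum()  # pixels set in both masks
toy_union = np.logical_or(toy_gt, toy_pred).sum()          # pixels set in either mask

toy_iou = toy_intersection / toy_union
toy_dice = 2 * toy_intersection / (toy_gt.sum() + toy_pred.sum())
toy_pixel_acc = np.mean(toy_gt == toy_pred)

print(f"IoU={toy_iou:.4f} Dice={toy_dice:.4f} Acc={toy_pixel_acc:.4f}")
# → IoU=0.4000 Dice=0.5714 Acc=0.6250
```

Pixel accuracy also counts agreeing background pixels, so it tends to be higher than IoU and Dice when the lesion covers a small part of the image.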

## Loop Over Dataset and Evaluate

SAM generates a mask for each image from its bounding-box annotation. The predicted mask is then cross-checked against the polygon annotation in the .json file, and the metrics are calculated from that comparison.

In [None]:
# Paths to store progress safely in Google Drive
# Note: CHECKPOINT_PATH is reused here; until now it held the SAM weights path
BASE_FOLDER = "/content/drive/MyDrive/Oral Cancer Dataset/SAM_checkpoints"
CHECKPOINT_PATH = os.path.join(BASE_FOLDER, "sam_checkpoint.json")

#Ensure the folder exists
os.makedirs(BASE_FOLDER, exist_ok=True)
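The evaluation loop rewrites this JSON file after every image, so a session that dies mid-write could leave a truncated checkpoint. A minimal sketch, assuming a hypothetical helper `save_checkpoint_atomically` (not in the original notebook), that writes to a temporary file first and then renames it over the target. Whether the rename is truly atomic on a mounted Google Drive folder is an assumption that has not been verified here.

```python
import json
import os
import tempfile

def save_checkpoint_atomically(path, processed_ids, metrics):
    """Write the checkpoint to a temporary file, then rename it over `path`."""
    payload = {"processed_ids": sorted(processed_ids), "metrics": metrics}
    directory = os.path.dirname(os.path.abspath(path))
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "w") as f:
            json.dump(payload, f)
        os.replace(tmp_path, path)  # the old file stays intact until this call
    except BaseException:
        # Clean up the partial temporary file before re-raising.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

# Demonstration in a throwaway directory with made-up metrics.
ckpt_demo_dir = tempfile.mkdtemp()
ckpt_demo_path = os.path.join(ckpt_demo_dir, "sam_checkpoint.json")
save_checkpoint_atomically(ckpt_demo_path, {3, 1, 2}, [{"image_id": 1, "iou": 0.5}])
with open(ckpt_demo_path) as f:
    restored = json.load(f)
print(restored["processed_ids"])  # → [1, 2, 3]
```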

In [None]:
#Load previous checkpoint (if it exists)
if os.path.exists(CHECKPOINT_PATH):
    with open(CHECKPOINT_PATH, "r") as f:
        checkpoint = json.load(f)
    processed_ids = set(checkpoint["processed_ids"])
    all_metrics = checkpoint["metrics"]
    print(f"🔄 Resuming from checkpoint. {len(processed_ids)} images already processed.")
else:
    processed_ids = set()
    all_metrics = []
    print("🆕 Starting fresh evaluation...")

skipped = 0

🆕 Starting fresh evaluation...


In [None]:
#Load and format healthy image names (e.g., add '.jpg')
import pandas as pd
healthy_df = pd.read_csv('/content/drive/MyDrive/Oral Cancer Dataset/Healthy.csv')
healthy_images = set(name.strip() + ".jpg" for name in healthy_df['Image Name'].dropna())

In [None]:
healthy_df

Unnamed: 0,Image Name,Category,Clinical Diagnosis,Lesion Annotation Count
0,R-02-05,Healthy,Normal Mucosa,0
1,R-02-06,Healthy,Normal Mucosa,0
2,R-02-07,Healthy,Normal Mucosa,0
3,R-02-08,Healthy,Normal Mucosa,0
4,R-03-01,Healthy,Normal Mucosa,0
...,...,...,...,...
724,N-351-05,Healthy,Normal Mucosa,0
725,N-351-06,Healthy,Normal Mucosa,0
726,N-351-07,Healthy,Normal Mucosa,0
727,N-351-08,Healthy,Normal Mucosa,0


In [None]:
#Loop through annotations and evaluate
for ann in ann_data["annotations"]:
    image_id = ann["image_id"]
    if image_id in processed_ids:
        continue  # Already processed

    image_name = id_to_filename[image_id]

    #Skip healthy images
    if image_name in healthy_images:
        print(f"⏭️ Skipping healthy image: {image_name}")
        skipped += 1
        continue

    image_path = os.path.join(IMAGE_FOLDER, image_name)

    #Load image
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ Image not found or unreadable: {image_path}")
        skipped += 1
        continue

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)

    # Bounding box as SAM prompt
    x, y, w, h = ann["bbox"]
    bbox = np.array([x, y, x + w, y + h])

    # Predict mask using SAM
    masks, _, _ = predictor.predict(box=bbox, multimask_output=False)
    pred_mask = masks[0].astype(np.uint8)

    # Ground truth from polygon segmentation
    gt_mask = polygon_to_mask(ann["segmentation"], image.shape[:2])
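    if pred_mask.shape != gt_mask.shape:
        # Nearest-neighbour interpolation avoids blending labels at the mask edges
        pred_mask = cv2.resize(pred_mask, (gt_mask.shape[1], gt_mask.shape[0]), interpolation=cv2.INTER_NEAREST)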

    # Evaluate metrics
    iou, dice, acc = evaluate(gt_mask, pred_mask)

    # Store metrics
    all_metrics.append({
        "image_id": image_id,
        "image_name": image_name,
        "iou": iou,
        "dice": dice,
        "accuracy": acc
    })
    processed_ids.add(image_id)

    #Save progress (checkpoint)
    with open(CHECKPOINT_PATH, "w") as f:
        json.dump({
            "processed_ids": list(processed_ids),
            "metrics": all_metrics
        }, f)

    #Visualize results
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    axs[0].imshow(image)
    axs[0].set_title(f"Original Image\n{image_name}")
    axs[0].axis('off')

    axs[1].imshow(image)
    axs[1].imshow(pred_mask, alpha=0.5, cmap='jet')
    axs[1].set_title(f"SAM Predicted Mask\nIoU: {iou:.4f}")
    axs[1].axis('off')

    axs[2].imshow(image)
    axs[2].imshow(gt_mask, alpha=0.5, cmap='gray')
    axs[2].set_title(f"Ground Truth Mask\nDice: {dice:.4f}")
    axs[2].axis('off')

    plt.suptitle(f"Pixel Accuracy: {acc:.4f}", fontsize=14, y=1.05)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow with every image

    print(f"✅ Processed: {image_name} | IoU: {iou:.4f}, Dice: {dice:.4f}, Acc: {acc:.4f}")
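The loop converts each COCO-style `[x, y, width, height]` box into the `[x0, y0, x1, y1]` (XYXY) form that `SamPredictor.predict` expects for its `box` argument. A minimal sketch of that conversion as a standalone function, with clipping to the image bounds added as an extra safeguard; the helper name `coco_bbox_to_xyxy` and the clipping step are assumptions, not part of the original loop.

```python
import numpy as np

def coco_bbox_to_xyxy(bbox, img_height, img_width):
    """Convert a COCO [x, y, width, height] box to [x0, y0, x1, y1], clipped to the image."""
    x, y, w, h = bbox
    xyxy = np.array([x, y, x + w, y + h], dtype=np.float32)
    xyxy[[0, 2]] = np.clip(xyxy[[0, 2]], 0, img_width)   # x coordinates
    xyxy[[1, 3]] = np.clip(xyxy[[1, 3]], 0, img_height)  # y coordinates
    return xyxy

# A box that overshoots the right edge of a 100x100 image is clipped to x1 = 100.
demo_box = coco_bbox_to_xyxy([90, 40, 30, 20], img_height=100, img_width=100)
print(demo_box.tolist())  # → [90.0, 40.0, 100.0, 60.0]
```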

## Restarting from the checkpoint (saved in sam_checkpoint.json) after Colab ran out of memory

In [None]:
#Load previous checkpoint (if it exists)
if os.path.exists(CHECKPOINT_PATH):
    with open(CHECKPOINT_PATH, "r") as f:
        checkpoint = json.load(f)
    processed_ids = set(checkpoint["processed_ids"])
    all_metrics = checkpoint["metrics"]
    print(f"🔄 Resuming from checkpoint. {len(processed_ids)} images already processed.")
else:
    processed_ids = set()
    all_metrics = []
    print("🆕 Starting fresh evaluation...")

skipped = 0

🔄 Resuming from checkpoint. 1383 images already processed.


In [None]:
#Loop through annotations and evaluate
for ann in ann_data["annotations"]:
    image_id = ann["image_id"]
    if image_id in processed_ids:
        continue  # Already processed

    image_name = id_to_filename[image_id]

    #Skip healthy images
    if image_name in healthy_images:
        print(f"⏭️ Skipping healthy image: {image_name}")
        skipped += 1
        continue

    image_path = os.path.join(IMAGE_FOLDER, image_name)

    #Load image
    image = cv2.imread(image_path)
    if image is None:
        print(f"❌ Image not found or unreadable: {image_path}")
        skipped += 1
        continue

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)  # computes the image embedding for this image

    # Bounding box as SAM prompt
    x, y, w, h = ann["bbox"]
    bbox = np.array([x, y, x + w, y + h])

    # Predict mask using SAM
    masks, _, _ = predictor.predict(box=bbox, multimask_output=False)
    pred_mask = masks[0].astype(np.uint8)

    # Ground truth from polygon segmentation
    gt_mask = polygon_to_mask(ann["segmentation"], image.shape[:2])
    if pred_mask.shape != gt_mask.shape:
        # Nearest-neighbour interpolation avoids blending labels at the mask edges
        pred_mask = cv2.resize(pred_mask, (gt_mask.shape[1], gt_mask.shape[0]), interpolation=cv2.INTER_NEAREST)

    # Evaluate metrics
    iou, dice, acc = evaluate(gt_mask, pred_mask)

    # Store metrics
    all_metrics.append({
        "image_id": image_id,
        "image_name": image_name,
        "iou": iou,
        "dice": dice,
        "accuracy": acc
    })
    processed_ids.add(image_id)

    #Save progress (checkpoint)
    with open(CHECKPOINT_PATH, "w") as f:
        json.dump({
            "processed_ids": list(processed_ids),
            "metrics": all_metrics
        }, f)

    #Visualize results
    fig, axs = plt.subplots(1, 3, figsize=(18, 6))

    axs[0].imshow(image)
    axs[0].set_title(f"Original Image\n{image_name}")
    axs[0].axis('off')

    axs[1].imshow(image)
    axs[1].imshow(pred_mask, alpha=0.5, cmap='jet')
    axs[1].set_title(f"SAM Predicted Mask\nIoU: {iou:.4f}")
    axs[1].axis('off')

    axs[2].imshow(image)
    axs[2].imshow(gt_mask, alpha=0.5, cmap='gray')
    axs[2].set_title(f"Ground Truth Mask\nDice: {dice:.4f}")
    axs[2].axis('off')

    plt.suptitle(f"Pixel Accuracy: {acc:.4f}", fontsize=14, y=1.05)
    plt.tight_layout()
    plt.show()
    plt.close(fig)  # release the figure so memory does not grow with every image

    print(f"✅ Processed: {image_name} | IoU: {iou:.4f}, Dice: {dice:.4f}, Acc: {acc:.4f}")

## Checking whether all 3000 images have been accounted for
(Cancerous: 2271, Healthy: 729)

In [None]:
#Load previous checkpoint (if it exists)
if os.path.exists(CHECKPOINT_PATH):
    with open(CHECKPOINT_PATH, "r") as f:
        checkpoint = json.load(f)
    processed_ids = set(checkpoint["processed_ids"])
    all_metrics = checkpoint["metrics"]
    print(f"{len(processed_ids)} images already processed.")

2271 images already processed.
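The accounting (evaluated + healthy = total) can also be checked with set arithmetic instead of reading printed counts. A minimal sketch on invented file names; the helper `account_for_images` is hypothetical and not part of the original notebook.

```python
def account_for_images(all_names, evaluated_names, healthy_names):
    """Report images that fall in neither group, and images that fall in both."""
    all_names = set(all_names)
    evaluated_names = set(evaluated_names)
    healthy_names = set(healthy_names)
    return {
        "unaccounted": sorted(all_names - evaluated_names - healthy_names),
        "evaluated_and_healthy": sorted(evaluated_names & healthy_names),
    }

# Toy example: six images, four evaluated, one healthy, one missed entirely.
toy_all = ["a.jpg", "b.jpg", "c.jpg", "d.jpg", "e.jpg", "f.jpg"]
toy_evaluated = ["a.jpg", "b.jpg", "c.jpg", "d.jpg"]
toy_healthy = ["e.jpg"]

report = account_for_images(toy_all, toy_evaluated, toy_healthy)
print(report)  # → {'unaccounted': ['f.jpg'], 'evaluated_and_healthy': []}
```

On the real data, the three inputs would be the file names in `id_to_filename`, the `image_name` entries in `all_metrics`, and `healthy_images`.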


In [None]:
#Loop through annotations and evaluate
skipped = 0
for ann in ann_data["annotations"]:
    image_id = ann["image_id"]
    if image_id in processed_ids:
        continue  # Already processed

    image_name = id_to_filename[image_id]

    #Skip healthy images
    if image_name in healthy_images:
        print(f"⏭️ Skipping healthy image: {image_name}")
        skipped += 1
        continue

print(f'No. of Healthy images, which were skipped for mask formation by SAM model: {skipped}')

⏭️ Skipping healthy image: R-02-05.jpg
⏭️ Skipping healthy image: R-02-06.jpg
⏭️ Skipping healthy image: R-02-07.jpg
⏭️ Skipping healthy image: R-02-08.jpg
⏭️ Skipping healthy image: R-03-01.jpg
⏭️ Skipping healthy image: R-03-05.jpg
⏭️ Skipping healthy image: R-04-03.jpg
⏭️ Skipping healthy image: R-04-04.jpg
⏭️ Skipping healthy image: R-04-05.jpg
⏭️ Skipping healthy image: R-06-04.jpg
⏭️ Skipping healthy image: R-06-05.jpg
⏭️ Skipping healthy image: R-06-06.jpg
⏭️ Skipping healthy image: R-06-07.jpg
⏭️ Skipping healthy image: R-07-01.jpg
⏭️ Skipping healthy image: R-07-02.jpg
⏭️ Skipping healthy image: R-07-03.jpg
⏭️ Skipping healthy image: R-07-04.jpg
⏭️ Skipping healthy image: R-07-05.jpg
⏭️ Skipping healthy image: R-08-06.jpg
⏭️ Skipping healthy image: R-13-02.jpg
⏭️ Skipping healthy image: R-14-02.jpg
⏭️ Skipping healthy image: R-14-03.jpg
⏭️ Skipping healthy image: R-14-04.jpg
⏭️ Skipping healthy image: R-14-05.jpg
⏭️ Skipping healthy image: R-15-04.jpg
⏭️ Skipping healthy image

## Final Metrics of SAM model (2271 of 3000 images evaluated; 729 healthy images skipped)

In [None]:
#Final model-level evaluation
ious = [m["iou"] for m in all_metrics]
dices = [m["dice"] for m in all_metrics]
accs = [m["accuracy"] for m in all_metrics]

print("\n📊 Final Evaluation of SAM:")
print(f"✅ Total images evaluated: {len(ious)}")
print(f"📉 Mean IoU: {np.mean(ious):.4f}")
print(f"📉 Mean Dice: {np.mean(dices):.4f}")
print(f"📉 Mean Pixel Accuracy: {np.mean(accs):.4f}")
print(f"❌ Skipped images: {skipped}")


📊 Final Evaluation of SAM:
✅ Total images evaluated: 2271
📉 Mean IoU: 0.6139
📉 Mean Dice: 0.7405
📉 Mean Pixel Accuracy: 0.8044
❌ Skipped images: 729
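A mean alone can hide a long tail of poorly segmented images. A minimal sketch, using NumPy and a made-up metrics list (`toy_metrics` is hypothetical, not the notebook's results), of how the spread of each metric could be summarized from a list shaped like `all_metrics`:

```python
import numpy as np

# Invented per-image metrics, in the same shape as the entries of all_metrics.
toy_metrics = [
    {"iou": 0.8, "dice": 0.89, "accuracy": 0.95},
    {"iou": 0.6, "dice": 0.75, "accuracy": 0.90},
    {"iou": 0.1, "dice": 0.18, "accuracy": 0.60},
    {"iou": 0.7, "dice": 0.82, "accuracy": 0.93},
]

def summarize(values):
    """Return mean, median, standard deviation, min and max of a list of numbers."""
    arr = np.asarray(values, dtype=float)
    return {
        "mean": float(arr.mean()),
        "median": float(np.median(arr)),
        "std": float(arr.std()),
        "min": float(arr.min()),
        "max": float(arr.max()),
    }

iou_summary = summarize([m["iou"] for m in toy_metrics])
print({k: round(v, 4) for k, v in iou_summary.items()})
```

In this toy case the median IoU (0.65) sits above the mean (0.55) because a single poor prediction pulls the mean down.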


## To view "Mask by SAM" and "Ground Truth Mask"
Downloaded .ipynb just after completion:
Screen-Recording: