In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, transforms
from torch.utils.data import DataLoader, Dataset
import os
from PIL import Image

from tqdm import tqdm

import hydra
from hydra import compose, initialize
from omegaconf import OmegaConf

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator

import numpy as np
import matplotlib.pyplot as plt
import cv2

  OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON = get_sdpa_settings()


In [2]:
CONFIG_PATH = "./sam2/sam2/configs/sam2.1/"
YAML_CONFIG = "sam2.1_hiera_b+.yaml"
# YAML_CONFIG = "sam2.1_hiera_t.yaml"

FULL_PATH = os.path.join(CONFIG_PATH, YAML_CONFIG)

TEST_IMG_PATH = "./test_images/"
images = os.listdir(TEST_IMG_PATH)
images = [f"{TEST_IMG_PATH}{img}" for img in images]

initialize(version_base=None, config_path=CONFIG_PATH, job_name="test_app")
cfg = compose(config_name=YAML_CONFIG, overrides=[])

# select the device for computation
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print(f"using device: {device}")

if device.type == "cuda":
    # use bfloat16 for the entire notebook
    torch.autocast("cuda", dtype=torch.bfloat16).__enter__()
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    if torch.cuda.get_device_properties(0).major >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
elif device.type == "mps":
    print(
        "\nSupport for MPS devices is preliminary. SAM 2 is trained with CUDA and might "
        "give numerically different outputs and sometimes degraded performance on MPS. "
        "See e.g. https://github.com/pytorch/pytorch/issues/84936 for a discussion."
    )

using device: cuda


In [3]:
checkpoint = "./sam2/checkpoints/sam2.1_hiera_base_plus.pt"
# checkpoint = "./sam2/checkpoints/sam2.1_hiera_tiny.pt"
model_cfg = YAML_CONFIG
# predictor = SAM2ImagePredictor(build_sam2(YAML_CONFIG, checkpoint))
sam2 = build_sam2(model_cfg, checkpoint, device=device, apply_postprocessing=False)
mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2,
    # points_per_side=16,
    # points_per_batch=16,
    pred_iou_thresh=0.9,
    # stability_score_thresh=0.92,
    # stability_score_offset=0.7,
    # crop_n_layers=1,
    # box_nms_thresh=0.7,
    # crop_n_points_downscale_factor=2,
    # min_mask_region_area=200.0,
    # use_m2m=True,
)

# mask_generator = SAM2ImagePredictor(
#     sam_model=sam2,
#     min_mask_region_area=2000.0,
# )

In [4]:
class ASSRDataset(Dataset):
    def __init__(self, image_dir, mask_dir, transform=None):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.image_filenames = os.listdir(image_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_filenames)

    def __getitem__(self, idx):
        img_name = os.path.join(self.image_dir, self.image_filenames[idx])
        mask_name = os.path.join(self.mask_dir, self.image_filenames[idx].replace('.jpg', '.png'))  # adjust if needed

        image = Image.open(img_name).convert('RGB')
        mask = Image.open(mask_name).convert('L')  # Assuming masks are grayscale

        if self.transform:
            image = self.transform(image)
            mask = self.transform(mask)

        return image, mask


In [5]:
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])

In [6]:
train_dataset = ASSRDataset(image_dir='./ASSR/images/train', mask_dir='./ASSR/gt/train', transform=transform)
train_loader = DataLoader(train_dataset, batch_size=8)

In [7]:
# Fine tune SAM2 on ASSR dataset
sam2.train()

# Define loss function
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(sam2.parameters(), lr=0.001)

# Train the model
num_epochs = 10
for epoch in range(num_epochs):
    for i, (images, masks) in enumerate(train_loader):
        images = images.to(device)
        masks = masks.to(device)

        # Forward pass
        outputs = sam2(images)
        loss = criterion(outputs, masks)

        # Backward and optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

NotImplementedError: Please use the corresponding methods in SAM2VideoPredictor for inference or SAM2Train for training/fine-tuningSee notebooks/video_predictor_example.ipynb for an inference example.