# SAM

In [1]:
import os
import cv2
from copy import deepcopy
from typing import Tuple

import numpy as np

import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn import functional as F
import torchvision.transforms as transforms
from torchvision.transforms.functional import resize, to_pil_image  # type: ignore

from pycocotools.coco import COCO

from transformers import SamModel, SamProcessor

# Run on the first visible GPU when CUDA is available, otherwise on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

In [16]:
class COCODataset(Dataset):
    """Detection dataset that turns COCO bounding boxes into box-shaped binary masks.

    Each item is ``(image, bboxes, masks)``:
      * ``bboxes`` has one ``[x_min, y_min, x_max, y_max]`` row per annotation.
      * ``masks`` has one binary mask per annotation, with the box area set to 1.

    Args:
        image_dir: Directory containing the image files named in the annotations.
        annotation_file: Path to a COCO-format annotation JSON file.
        transform: Optional callable ``(image, masks, bboxes) -> (image, masks, bboxes)``
            applied before stacking (e.g. ``ResizeAndPad``).
    """

    def __init__(self, image_dir, annotation_file, transform=None):
        self.coco = COCO(annotation_file)
        self.image_dir = image_dir
        self.transform = transform

        # Keep only images with at least one annotation: __getitem__ stacks the
        # per-annotation boxes/masks, which fails on an empty list.
        self.image_ids = [
            image_id
            for image_id in self.coco.imgs.keys()
            if len(self.coco.getAnnIds(imgIds=image_id)) > 0
        ]

    def __len__(self):
        return len(self.image_ids)

    @staticmethod
    def _bbox_to_mask(image_height, image_width, bbox):
        """Return an (H, W) uint8 mask with the COCO ``[x, y, w, h]`` box area set to 1.

        COCO stores box coordinates as floats, which cannot be used as slice
        indices. The box is snapped outward to whole pixels (floor of the start,
        ceil of the end) and clamped to the image. A negative start would
        otherwise wrap around.
        """
        x, y, w, h = bbox
        x0 = max(int(np.floor(x)), 0)
        y0 = max(int(np.floor(y)), 0)
        x1 = min(int(np.ceil(x + w)), image_width)
        y1 = min(int(np.ceil(y + h)), image_height)
        mask = np.zeros((image_height, image_width), dtype=np.uint8)
        mask[y0:y1, x0:x1] = 1
        return mask

    def __getitem__(self, idx):
        image_id = self.image_ids[idx]
        image_info = self.coco.loadImgs(image_id)[0]

        image_path = os.path.join(self.image_dir, image_info['file_name'])
        image = cv2.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"Failed to load image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
        image_height, image_width, _ = image.shape  # (height, width, 3)

        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))
        bboxes = []
        masks = []
        for ann in anns:
            x, y, w, h = ann['bbox']  # COCO format: [x, y, width, height]
            bboxes.append([x, y, x + w, y + h])
            masks.append(self._bbox_to_mask(image_height, image_width, ann['bbox']))

        if self.transform:
            image, masks, bboxes = self.transform(image, masks, np.array(bboxes))

        bboxes = np.stack(bboxes, axis=0)
        masks = np.stack(masks, axis=0)

        return image, torch.tensor(bboxes), torch.tensor(masks).float()


In [3]:
class ResizeLongestSide:
    """
    Scale images so that their longer side equals ``target_length``, and rescale
    point/box coordinates to match. Supports single NumPy images as well as
    batched torch tensors.
    """

    def __init__(self, target_length: int) -> None:
        self.target_length = target_length

    def _target_hw(self, height: int, width: int) -> Tuple[int, int]:
        # Output (H, W) of an image of the given size at this transform's target length.
        return self.get_preprocess_shape(height, width, self.target_length)

    def _rescale_xy(self, coords, original_size: Tuple[int, ...]):
        # Scale x by the width ratio and y by the height ratio, writing into ``coords``.
        old_h, old_w = original_size
        new_h, new_w = self._target_hw(original_size[0], original_size[1])
        coords[..., 0] = coords[..., 0] * (new_w / old_w)
        coords[..., 1] = coords[..., 1] * (new_h / old_h)
        return coords

    def apply_image(self, image: np.ndarray) -> np.ndarray:
        """
        Resize an HxWxC uint8 NumPy image; the result is again a NumPy array.
        """
        out_hw = self._target_hw(image.shape[0], image.shape[1])
        resized = resize(to_pil_image(image), out_hw)
        return np.array(resized)

    def apply_coords(self, coords: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Rescale NumPy (x, y) coordinates (last dimension of size 2) from an image of
        ``original_size`` given as (H, W). Returns a new float array; the input is untouched.
        """
        scaled = deepcopy(coords).astype(float)
        return self._rescale_xy(scaled, original_size)

    def apply_boxes(self, boxes: np.ndarray, original_size: Tuple[int, ...]) -> np.ndarray:
        """
        Rescale NumPy boxes of shape Bx4 (x0, y0, x1, y1) from an image of
        ``original_size`` given as (H, W).
        """
        corners = boxes.reshape(-1, 2, 2)
        return self.apply_coords(corners, original_size).reshape(-1, 4)

    def apply_image_torch(self, image: torch.Tensor) -> torch.Tensor:
        """
        Resize a float BxCxHxW tensor batch with antialiased bilinear interpolation.
        The result may differ slightly from ``apply_image``, which is the
        transformation the model expects.
        """
        out_hw = self._target_hw(image.shape[2], image.shape[3])
        return F.interpolate(image, out_hw, mode="bilinear", align_corners=False, antialias=True)

    def apply_coords_torch(
        self, coords: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Rescale torch (x, y) coordinates (last dimension of size 2) from an image of
        ``original_size`` given as (H, W). Returns a float32 copy.
        """
        scaled = deepcopy(coords).to(torch.float)
        return self._rescale_xy(scaled, original_size)

    def apply_boxes_torch(
        self, boxes: torch.Tensor, original_size: Tuple[int, ...]
    ) -> torch.Tensor:
        """
        Rescale torch boxes of shape Bx4 (x0, y0, x1, y1) from an image of
        ``original_size`` given as (H, W).
        """
        corners = boxes.reshape(-1, 2, 2)
        return self.apply_coords_torch(corners, original_size).reshape(-1, 4)

    @staticmethod
    def get_preprocess_shape(oldh: int, oldw: int, long_side_length: int) -> Tuple[int, int]:
        """
        Return the (H, W) whose longer side is ``long_side_length``, keeping the aspect ratio.
        """
        scale = long_side_length * 1.0 / max(oldh, oldw)
        # Adding 0.5 before int() rounds half-up for these non-negative sizes.
        return (int(oldh * scale + 0.5), int(oldw * scale + 0.5))


In [4]:
class ResizeAndPad:
    """
    Resize an image, its masks and its boxes so the longest side equals
    ``target_size``. Then zero-pad the image and masks symmetrically to a square
    and shift the boxes by the same offsets.
    """

    def __init__(self, target_size):
        self.target_size = target_size
        self.transform = ResizeLongestSide(target_size)
        self.to_tensor = transforms.ToTensor()

    def __call__(self, image, masks, bboxes):
        """
        Args:
            image: HxWxC NumPy image.
            masks: iterable of HxW NumPy masks.
            bboxes: NumPy array of shape Nx4 as (x0, y0, x1, y1) in original pixels.
        Returns:
            (padded CxSxS image tensor, list of padded mask tensors, list of shifted boxes)
        """
        orig_h, orig_w, _ = image.shape

        resized_image = self.to_tensor(self.transform.apply_image(image))
        resized_masks = [torch.tensor(self.transform.apply_image(m)) for m in masks]

        # Square canvas of the longer side; the odd leftover pixel goes right/bottom.
        _, new_h, new_w = resized_image.shape
        side = max(new_w, new_h)
        left = (side - new_w) // 2
        top = (side - new_h) // 2
        pad = transforms.Pad((left, top, side - new_w - left, side - new_h - top))
        padded_image = pad(resized_image)
        padded_masks = [pad(m) for m in resized_masks]

        # Scale the boxes like the image, then move them by the left/top padding.
        scaled = self.transform.apply_boxes(bboxes, (orig_h, orig_w))
        shifted = [[x0 + left, y0 + top, x1 + left, y1 + top] for x0, y0, x1, y1 in scaled]

        return padded_image, padded_masks, shifted
    
def collate_fn(batch):
    """
    Collate ``(image, bboxes, masks)`` samples into one batch.

    Images are stacked into a single tensor, so they must share a shape.
    Boxes and masks stay as per-sample tuples because each image can have a
    different number of them.
    """
    image_col, bbox_col, mask_col = zip(*batch)
    stacked = torch.stack(image_col)
    return stacked, bbox_col, mask_col
    

In [17]:
batch_size = 12
image_size = 256
num_workers = 4

transform = ResizeAndPad(image_size)

train = COCODataset(image_dir="/scratch/students/danae/data/images", 
                    annotation_file="/scratch/students/danae/data/model_data_format/yolo/coco_annotations.json",
                    transform=transform)

# NOTE(review): val uses the same image directory and annotation file as train,
# so "validation" is evaluated on the training data. Point this at a held-out split.
val = COCODataset(image_dir="/scratch/students/danae/data/images", 
                  annotation_file="/scratch/students/danae/data/model_data_format/yolo/coco_annotations.json",
                  transform=transform)

train_dataloader = DataLoader(train,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=num_workers,
                              collate_fn=collate_fn)

# The validation loader must iterate `val`; it previously wrapped `train`.
# Sample order does not matter for evaluation, so shuffling is disabled.
val_dataloader = DataLoader(val,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=num_workers,
                            collate_fn=collate_fn)

loading annotations into memory...
Done (t=0.72s)
creating index...
index created!
loading annotations into memory...
Done (t=0.68s)
creating index...
index created!


In [18]:
class SAMDataset(Dataset):
    """Turn ``(image, bboxes, masks)`` samples into SamProcessor model inputs.

    All boxes of an image are merged into one enclosing box prompt. The
    ground-truth mask is the union of the per-box masks, as an int tensor.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def get_bounding_box(self, bboxes):
        """Return the smallest ``[x0, y0, x1, y1]`` box enclosing every row of ``bboxes``.

        Args:
            bboxes: tensor or array-like with one (x0, y0, x1, y1) box per row.
        Returns:
            A list of four Python numbers (floats for floating-point input).
        Raises:
            ValueError: if ``bboxes`` contains no boxes.
        """
        boxes = torch.as_tensor(bboxes).detach()
        if boxes.numel() == 0:
            raise ValueError("get_bounding_box needs at least one box")
        xs = boxes[:, [0, 2]]
        ys = boxes[:, [1, 3]]
        return [xs.min().item(), ys.min().item(), xs.max().item(), ys.max().item()]

    @staticmethod
    def _already_rescaled(image):
        """True for floating-point images whose values all lie in [0, 1] (e.g. ToTensor output)."""
        if torch.is_tensor(image):
            values = image.detach().cpu().numpy()
        else:
            values = np.asarray(image)
        if values.size == 0 or not np.issubdtype(values.dtype, np.floating):
            return False
        return bool(values.min() >= 0.0 and values.max() <= 1.0)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item[0]
        prompt = self.get_bounding_box(item[1])  # one box enclosing all annotations

        # SamProcessor rescales pixel values by 1/255 by default. Images already in
        # [0, 1] (ToTensor output from ResizeAndPad) must not be rescaled again,
        # or they collapse to near-zero.
        inputs = self.processor(image, input_boxes=[[prompt]], return_tensors="pt",
                                do_rescale=not self._already_rescaled(image))

        # Drop the batch dimension the processor adds; DataLoader re-batches.
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # Union of the per-box masks; `!= 0` works for any mask dtype.
        masks = item[2]
        inputs["ground_truth_mask"] = (masks != 0).any(dim=0).int()

        return inputs
  

In [19]:
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)

# Freeze the image and prompt encoders so that only the mask decoder is trained.
frozen_prefixes = ("vision_encoder", "prompt_encoder")
for name, param in model.named_parameters():
    if name.startswith(frozen_prefixes):
        param.requires_grad_(False)

In [20]:
# Wrap the resized/padded training split so each item is a SamProcessor input dict.
train_dataset = SAMDataset(train, processor)

In [22]:
train_dataloader = DataLoader(train_dataset, batch_size=5, shuffle=True)

# The validation loop indexes each batch as a dict ("pixel_values", "input_boxes",
# "ground_truth_mask"), so the validation loader must also yield SAMDataset items.
# The earlier val_dataloader yields collate_fn tuples of raw COCODataset samples.
val_dataset = SAMDataset(dataset=val, processor=processor)
val_dataloader = DataLoader(val_dataset, batch_size=5, shuffle=False)

In [23]:
from torch.optim import Adam
import monai

# Note: hyperparameter tuning could improve performance here
# Only the mask decoder's parameters are handed to the optimizer.
decoder_params = model.mask_decoder.parameters()
optimizer = Adam(decoder_params, lr=1e-5, weight_decay=0)

# Dice + cross-entropy computed on raw logits (sigmoid=True), mean-reduced.
seg_loss = monai.losses.DiceCELoss(
    sigmoid=True,
    squared_pred=True,
    reduction="mean",
)

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

num_epochs = 100

# Resolve the device again and make sure the model lives on it before training.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()
for epoch in range(num_epochs):
    epoch_losses = []
    for batch in tqdm(train_dataloader):
        # Forward pass: box prompts in, a single mask per prompt out.
        pixel_values = batch["pixel_values"].to(device)
        input_boxes = batch["input_boxes"].to(device)
        outputs = model(
            pixel_values=pixel_values,
            input_boxes=input_boxes,
            multimask_output=False,
        )

        # Compare the squeezed predictions with targets given a matching channel dim.
        predicted_masks = outputs.pred_masks.squeeze(1)
        ground_truth_masks = batch["ground_truth_mask"].float().to(device)
        loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

        # Clear stale gradients, backpropagate, then update the trainable weights.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')

In [None]:
num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def _batch_loss(batch):
    """Run the model on one SAMDataset batch and return the segmentation loss tensor."""
    outputs = model(pixel_values=batch["pixel_values"].to(device),
                    input_boxes=batch["input_boxes"].to(device),
                    multimask_output=False)
    predicted_masks = outputs.pred_masks.squeeze(1)
    ground_truth_masks = batch["ground_truth_mask"].float().to(device)
    return seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))


# Training loop with validation
model.train()
for epoch in range(num_epochs):
    # Training phase: switch back to train mode after the previous epoch's eval.
    model.train()
    epoch_losses = []
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Training"):
        loss = _batch_loss(batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())

    print(f'EPOCH: {epoch+1}')
    print(f'Mean Training Loss: {mean(epoch_losses):.4f}')

    # Validation phase: eval mode, and no autograd graph is recorded.
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in tqdm(val_dataloader, desc=f"Epoch {epoch+1}/{num_epochs} - Validation"):
            val_losses.append(_batch_loss(batch).item())

    print(f'Mean Validation Loss: {mean(val_losses):.4f}')

In [4]:
class CocoSegmentationDataset(Dataset):
    """Segmentation dataset with one binary mask per image, filled over every COCO box.

    Args:
        annotation_file: Path to a COCO-format annotation JSON file.
        img_dir: Directory containing the image files.
        transforms: Optional callable invoked as ``transforms(image=..., mask=...)``
            that returns a dict with ``"image"`` and ``"mask"`` keys.
    """

    def __init__(self, annotation_file, img_dir, transforms=None):
        self.coco = COCO(annotation_file)
        self.img_dir = img_dir
        self.transforms = transforms
        self.image_ids = list(self.coco.imgs.keys())

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        img_id = self.image_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]

        # COCO image records name the file under "file_name" (reading "filename"
        # raised KeyError). The old key is still accepted for non-standard files.
        file_name = img_info.get('file_name', img_info.get('filename'))
        if file_name is None:
            raise KeyError(f"Image {img_id} has no 'file_name' entry")
        img_path = os.path.join(self.img_dir, file_name)
        image = cv2.imread(img_path)
        if image is None:
            raise FileNotFoundError(f"Failed to load image: {img_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR

        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)

        # Size the mask from the decoded image so image and mask always agree.
        mask = np.zeros(image.shape[:2], dtype=np.uint8)
        for ann in anns:
            x, y, w, h = map(int, ann['bbox'])  # COCO format: [x, y, width, height]
            # thickness=-1 draws a filled rectangle.
            cv2.rectangle(mask, (x, y), (x + w, y + h), color=1, thickness=-1)

        if self.transforms:
            augmented = self.transforms(image=image, mask=mask)
            image = augmented["image"]
            mask = augmented["mask"]

        return image, mask

In [None]:
class SAMDataset(Dataset):
    """Turn ``(image, bboxes, ...)`` samples with [x, y, w, h] boxes into SamProcessor inputs.

    Every box is passed to the processor as a prompt. The ground-truth mask
    marks the union of all box areas at the image's own resolution.

    NOTE(review): ``COCODataset`` in this notebook already yields
    [x_min, y_min, x_max, y_max] boxes. Wrapping it here would convert them a
    second time; confirm which source this class is meant for.
    NOTE(review): the ground-truth mask has the input image's size. The training
    loop compares it with ``outputs.pred_masks``; confirm the two shapes match.
    """

    def __init__(self, dataset, processor, image_size=(512, 512)):
        self.dataset = dataset
        self.processor = processor
        # NOTE(review): stored for callers but not used inside this class.
        self.image_size = image_size

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        image = item[0]  # PIL image, HxW(xC) array, or CxHxW tensor
        prompts = item[1]  # bounding boxes in [x, y, w, h] format

        input_boxes = self._convert_bboxes_to_prompts(prompts)

        # SamProcessor expects boxes nested as [image][box][coordinate]. This item
        # is a single image, so its list of boxes is wrapped once more.
        inputs = self.processor(image, input_boxes=[input_boxes], return_tensors="pt")

        # Remove batch dimension added by the processor
        inputs = {k: v.squeeze(0) for k, v in inputs.items()}

        # The mask helper expects [x_min, y_min, x_max, y_max], i.e. the converted boxes.
        inputs["ground_truth_mask"] = self._create_ground_truth_map(input_boxes, self._image_wh(image))

        return inputs

    @staticmethod
    def _image_wh(image):
        """Return ``(width, height)`` of a PIL image, NumPy array or torch tensor.

        Assumes arrays are laid out (H, W) or (H, W, C), and tensors (H, W) or (C, H, W).
        """
        size = getattr(image, "size", None)
        if isinstance(size, tuple) and len(size) == 2:  # PIL.Image.size is (width, height)
            return size
        if torch.is_tensor(image):
            height, width = image.shape[-2:]
            return int(width), int(height)
        height, width = np.asarray(image).shape[:2]
        return int(width), int(height)

    def _convert_bboxes_to_prompts(self, bboxes):
        """
        Converts bounding boxes into the format required by the processor.
        Args:
            bboxes: Iterable of bounding boxes in [x, y, w, h] format (lists, arrays or tensor rows).
        Returns:
            List of bounding boxes in [x_min, y_min, x_max, y_max] format, as Python floats.
        """
        converted = []
        for box in bboxes:
            x, y, w, h = (float(v) for v in box)
            converted.append([x, y, x + w, y + h])
        return converted

    def _create_ground_truth_map(self, bboxes, image_size):
        """
        Creates a binary mask where bounding box areas are marked as 1.
        Args:
            bboxes: List of bounding boxes in [x_min, y_min, x_max, y_max] format.
            image_size: Tuple of (width, height) of the image.
        Returns:
            A float32 PyTorch tensor of shape (height, width).
        """
        width, height = image_size[0], image_size[1]
        mask = torch.zeros(height, width, dtype=torch.float32)
        for bbox in bboxes:
            x_min, y_min, x_max, y_max = (max(int(v), 0) for v in bbox)
            # Clamp to 0: a negative slice index would wrap to the far edge.
            mask[y_min:y_max, x_min:x_max] = 1.0
        return mask

In [112]:
from transformers import SamProcessor

checkpoint = "facebook/sam-vit-base"
processor = SamProcessor.from_pretrained(checkpoint)
model = SamModel.from_pretrained(checkpoint)
model = model.to(device)

# Train only the mask decoder: switch off gradients for the image and prompt encoders.
for name, param in model.named_parameters():
    is_encoder = name.startswith(("vision_encoder", "prompt_encoder"))
    if is_encoder:
        param.requires_grad_(False)

In [22]:
# `! export ...` runs in a throw-away subshell, so it never changed this kernel's
# environment. Set the variable on the Python process itself instead.
# NOTE(review): CUDA generally reads CUDA_VISIBLE_DEVICES only when it initialises,
# and earlier cells already moved the model to `device`. Move this line to the first
# cell (before any torch.cuda use) for it to take effect.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [23]:
# NOTE(review): `dataset` is not defined in any visible cell, so `dataset["train"]`
# will raise NameError even once SAMDataset is defined. The NameError recorded below
# was about SAMDataset itself, because its defining cell is marked `In [None]`
# (never run). The COCO training split built earlier is named `train`; confirm the
# intended data source (and its box format) before substituting it.
train_dataset = SAMDataset(dataset=dataset["train"], processor=processor)
train_dataloader = DataLoader(train_dataset, batch_size=2, shuffle=True)

NameError: name 'SAMDataset' is not defined

In [None]:
from torch.optim import Adam
import monai

# Hyperparameter tuning could improve performance here
# Adam over the mask decoder's parameters only (lr 1e-5, no weight decay).
optimizer = Adam(params=model.mask_decoder.parameters(), lr=1e-5, weight_decay=0)

# Combined Dice + cross-entropy loss computed on raw logits (sigmoid=True).
seg_loss = monai.losses.DiceCELoss(reduction='mean', sigmoid=True, squared_pred=True)

In [None]:
from tqdm import tqdm
from statistics import mean
import torch
from torch.nn.functional import threshold, normalize

num_epochs = 100

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)


def _train_step(batch):
    """Run one optimisation step on a SAMDataset batch and return its loss as a float."""
    outputs = model(pixel_values=batch["pixel_values"].to(device),
                    input_boxes=batch["input_boxes"].to(device),
                    multimask_output=False)
    predicted_masks = outputs.pred_masks.squeeze(1)
    ground_truth_masks = batch["ground_truth_mask"].float().to(device)
    loss = seg_loss(predicted_masks, ground_truth_masks.unsqueeze(1))

    # Clear stale gradients, backpropagate, then update the trainable weights.
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()


model.train()
for epoch in range(num_epochs):
    epoch_losses = [_train_step(batch) for batch in tqdm(train_dataloader)]
    print(f'EPOCH: {epoch}')
    print(f'Mean loss: {mean(epoch_losses)}')
     