In [None]:
# 1. Install SAM (and dependencies)
# Inside a notebook, `%pip` installs into the running kernel's environment and
# `!` runs a shell command. SAM is installed straight from GitHub instead of
# clone + `cd`: an automagic `cd` would permanently change the notebook's
# working directory, breaking the relative "data/..." and checkpoint paths
# used by the training cell.
%pip install torch torchvision
%pip install git+https://github.com/facebookresearch/segment-anything.git  # installs `segment_anything`

# The training cell loads this ViT-B checkpoint from the working directory;
# -nc skips the download if the file is already present.
!wget -nc -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

# 2. (Optional) Also install albumentations for data augmentation
%pip install albumentations

In [None]:
# train_sam.py

import os
from glob import glob
from PIL import Image
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from segment_anything import sam_model_registry, SamPredictor

# 3. Dataset definition
class MaskDataset(Dataset):
    """Paired image / binary-mask dataset for segmentation fine-tuning.

    Every ``*.jpg`` in ``images_dir`` is paired with the ``.png`` of the same
    stem in ``masks_dir``.

    Args:
        images_dir: directory containing the RGB ``.jpg`` images.
        masks_dir: directory containing the single-channel ``.png`` masks.
        transform: callable applied to each PIL image. Defaults to a resize to
            512x512 followed by ``ToTensor`` (float tensor in [0, 1]).
        mask_transform: callable applied to each PIL mask; it must keep the
            mask spatially aligned with the image and should use
            nearest-neighbour resizing. Defaults to ``transform`` when one is
            given, otherwise to a nearest-neighbour resize to 512x512
            followed by ``ToTensor``.

    Raises:
        FileNotFoundError: if any image has no matching mask file.
    """

    def __init__(self,
                 images_dir: str,
                 masks_dir: str,
                 transform=None,
                 mask_transform=None):
        self.images = sorted(glob(os.path.join(images_dir, "*.jpg")))
        # splitext swaps only the final extension; str.replace would also
        # rewrite a ".jpg" appearing earlier in the file name.
        self.masks = [
            os.path.join(masks_dir,
                         os.path.splitext(os.path.basename(p))[0] + ".png")
            for p in self.images
        ]
        # Fail fast with a clear message instead of mid-epoch in a worker.
        missing = [m for m in self.masks if not os.path.isfile(m)]
        if missing:
            raise FileNotFoundError(
                f"{len(missing)} mask file(s) not found, e.g. {missing[0]}"
            )
        self.transform = transform or T.Compose([
            T.Resize((512, 512)),
            T.ToTensor(),
        ])
        # Masks get their own pipeline: nearest-neighbour resizing keeps label
        # values intact (bilinear would blur edges), and image-only steps such
        # as a 3-channel Normalize must not be applied to a 1-channel mask.
        self.mask_transform = mask_transform or transform or T.Compose([
            T.Resize((512, 512), interpolation=T.InterpolationMode.NEAREST),
            T.ToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for sample ``idx``.

        The mask is a float tensor of 0.0/1.0 values; with the default
        transforms it has shape (1, 512, 512).
        """
        with Image.open(self.images[idx]) as im:
            img = self.transform(im.convert("RGB"))
        with Image.open(self.masks[idx]) as im:
            msk = self.mask_transform(im.convert("L"))
        # Any non-zero pixel is foreground. This handles masks saved as 0/255
        # and as 0/1: ToTensor maps the latter to {0, 1/255}, which a 0.5
        # threshold would turn entirely into background.
        msk = (msk > 0).float()
        return img, msk

# 4. Set up dataloaders
# Train and val splits use identical directory layouts under data/.
train_ds = MaskDataset("data/images/train", "data/masks/train")
val_ds = MaskDataset("data/images/val", "data/masks/val")

# Both loaders share batch size and worker count; only shuffling differs
# (shuffle training batches, keep validation order deterministic).
loader_kwargs = dict(batch_size=8, num_workers=4)
train_dl = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_dl = DataLoader(val_ds, shuffle=False, **loader_kwargs)

# 5. Load a pretrained SAM and choose which parts to fine-tune.
#    The image encoder is frozen; the prompt encoder and mask decoder stay
#    trainable (uncomment the block below to train only the mask decoder).
sam = sam_model_registry["vit_b"](checkpoint="sam_vit_b_01ec64.pth")
# freeze image encoder
for p in sam.image_encoder.parameters():
    p.requires_grad = False

# optionally: freeze the prompt encoder and only train the mask decoder:
# for p in sam.prompt_encoder.parameters(): p.requires_grad = False

# Fall back to CPU so the script still runs (slowly) without a CUDA GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sam.to(device)
# Only parameters left trainable above are handed to the optimizer.
optimizer = torch.optim.Adam(
    [p for p in sam.parameters() if p.requires_grad], lr=1e-4
)
# The decoder emits raw logits, so use the fused (numerically stable)
# sigmoid + binary cross-entropy.
criterion = nn.BCEWithLogitsLoss()

# 6. Training loop
def _predict_mask_logits(model, imgs):
    """Run SAM without prompts and return mask logits shaped (B, 1, H, W).

    ``imgs`` is assumed to be a float batch (B, 3, H, W) with values in
    [0, 1], as produced by MaskDataset's default ``ToTensor`` transform;
    H and W must not exceed ``model.image_encoder.img_size``.
    """
    input_size = tuple(imgs.shape[-2:])
    # The image encoder is frozen, so skip building an autograd graph for it.
    with torch.no_grad():
        # Sam.preprocess normalises with pixel_mean/std (given on a 0-255
        # scale, hence the * 255) and zero-pads to the encoder's square input.
        embeddings = model.image_encoder(model.preprocess(imgs * 255.0))
    # No prompts: the prompt encoder yields an empty sparse embedding and its
    # learned "no mask" dense embedding, both with batch size 1.
    sparse, dense = model.prompt_encoder(points=None, boxes=None, masks=None)
    image_pe = model.prompt_encoder.get_dense_pe()
    low_res = []
    # MaskDecoder repeats image_embeddings once per prompt in the prompt
    # batch, so feed one image at a time to pair each image with the single
    # (empty) prompt explicitly.
    for emb in embeddings:
        logits, _iou_pred = model.mask_decoder(
            image_embeddings=emb.unsqueeze(0),
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse,
            dense_prompt_embeddings=dense,
            multimask_output=False,
        )
        low_res.append(logits)
    # Upsample the low-resolution decoder output to the padded encoder
    # frame, crop off the padding, then resize to the target size.
    return model.postprocess_masks(
        torch.cat(low_res, dim=0),
        input_size=input_size,
        original_size=input_size,
    )


def _as_nchw(masks):
    """Give targets an explicit channel axis: (B, H, W) -> (B, 1, H, W)."""
    return masks.unsqueeze(1) if masks.dim() == 3 else masks


for epoch in range(10):
    sam.train()
    # The encoder is frozen; keep it in eval mode so no train-only behaviour
    # runs inside it.
    sam.image_encoder.eval()
    total_loss = 0.0
    for imgs, masks in train_dl:
        imgs, masks = imgs.to(device), _as_nchw(masks.to(device))
        loss = criterion(_predict_mask_logits(sam, imgs), masks)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    # max(..., 1) avoids ZeroDivisionError on an empty loader.
    print(f"Epoch {epoch} train loss: {total_loss / max(len(train_dl), 1):.4f}")

    # validation
    sam.eval()
    with torch.no_grad():
        val_loss = 0.0
        for imgs, masks in val_dl:
            imgs, masks = imgs.to(device), _as_nchw(masks.to(device))
            val_loss += criterion(_predict_mask_logits(sam, imgs), masks).item()
        print(f"   val loss: {val_loss / max(len(val_dl), 1):.4f}")

# 7. Save your fine-tuned SAM
# Writes the complete model state_dict (frozen image-encoder weights included).
output_path = "sam_finetuned.pth"
weights = sam.state_dict()
torch.save(obj=weights, f=output_path)
