In [None]:
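# 1. Install PyTorch (the SAM / `segment_anything` install steps below are commented out — uncomment them if the package is not installed yet)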
%pip install torch torchvision
#!git clone https://github.com/facebookresearch/segment-anything.git
#%cd segment-anything
#%pip install -e .             # installs `segment_anything`

# 2. (Optional) Install albumentations for data augmentation (not used by the training code below)
%pip install albumentations

In [None]:
#!/usr/bin/env python3
import os
from glob import glob

import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from PIL import Image, ImageDraw
import torch.nn.functional as F
import tqdm

from segment_anything import sam_model_registry

# --- USER CONFIG ---
# Dataset root; each split lives in DATA_DIR/<split>/{images,labels}.
DATA_DIR = "/datax/scratch/jliang/augmented"


def _split_dir(split, kind):
    """Return the ``DATA_DIR/<split>/<kind>`` directory path."""
    return os.path.join(DATA_DIR, split, kind)


TRAIN_IMG_DIR, TRAIN_LBL_DIR = _split_dir("train", "images"), _split_dir("train", "labels")
VAL_IMG_DIR, VAL_LBL_DIR = _split_dir("val", "images"), _split_dir("val", "labels")

# Pretrained SAM ViT-B weights to start from, and the file that receives the
# best fine-tuned state_dict.
CHECKPOINT = "/home/jliang/gbt-rfi/segment-anything/sam_vit_b_01ec64.pth"
OUT_WEIGHTS = "sam_finetuned.pth"

# Optimisation hyper-parameters.
BATCH_SIZE = 8
LR = 1e-3
EPOCHS = 30

# Size handed to T.Resize for every image (and its mask).
IMG_SIZE = (1024,) * 2
DEVICE = torch.device("cpu")
# -------------------

class YoloMaskDataset(Dataset):
    """Images paired with binary masks rasterized from YOLO box labels.

    For each image ``<stem>.<ext>`` in ``images_dir`` the label file
    ``labels_dir/<stem>.txt`` is read. Every non-blank line
    ``class xc yc w h`` (box centre and size normalized to [0, 1]; the class
    id is ignored) is filled with 1 in a single-channel mask. A missing label
    file is treated as an image with no boxes (all-zero mask).

    Args:
        images_dir: directory of input images; every ``*.*`` file is used.
        labels_dir: directory holding the ``.txt`` label files.
        img_size: size passed to ``T.Resize`` for both image and mask.
        transform: optional image-only transform. Defaults to resize +
            ``ToTensor`` (float values in [0, 1]).
        mask_transform: optional mask transform. Defaults to nearest-neighbour
            resize + ``PILToTensor`` so mask values stay exactly 0 / 1.
    """

    def __init__(self, images_dir, labels_dir, img_size, transform=None, mask_transform=None):
        self.images = sorted(glob(os.path.join(images_dir, "*.*")))
        self.labels = [
            os.path.join(labels_dir, os.path.splitext(os.path.basename(p))[0] + ".txt")
            for p in self.images
        ]
        self.transform = transform or T.Compose([
            T.Resize(img_size),
            T.ToTensor(),
        ])
        # The mask must NOT go through the image pipeline: ToTensor rescales a
        # uint8 "L" image by 1/255 (label 1 would become ~0.004), and image
        # normalisation/augmentation would corrupt the labels. Nearest-neighbour
        # resizing keeps the mask binary; PILToTensor keeps the raw 0/1 values.
        self.mask_transform = mask_transform or T.Compose([
            T.Resize(img_size, interpolation=T.InterpolationMode.NEAREST),
            T.PILToTensor(),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image_tensor, mask_tensor)`` for sample ``idx``; the mask is float."""
        # Context manager closes the underlying file once the pixels are decoded.
        with Image.open(self.images[idx]) as src:
            img = src.convert("RGB")
        W, H = img.size
        mask = Image.new("L", (W, H), 0)
        draw = ImageDraw.Draw(mask)
        for x0, y0, x1, y1 in self._read_boxes(self.labels[idx], W, H):
            draw.rectangle([x0, y0, x1, y1], fill=1)
        img_t = self.transform(img)
        mask_t = self.mask_transform(mask).float()
        return img_t, mask_t

    @staticmethod
    def _read_boxes(label_path, W, H):
        """Yield pixel-space ``(x0, y0, x1, y1)`` boxes from a YOLO label file.

        A missing file yields nothing; blank lines are skipped.

        Raises:
            ValueError: if a non-blank line does not have exactly 5 fields.
        """
        try:
            f = open(label_path)
        except FileNotFoundError:
            return
        with f:
            for lineno, line in enumerate(f, 1):
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 5:
                    raise ValueError(
                        f"{label_path}:{lineno}: expected 'class xc yc w h', got {line.strip()!r}"
                    )
                xc, yc, w, h = map(float, fields[1:])
                yield (
                    (xc - w / 2) * W,
                    (yc - h / 2) * H,
                    (xc + w / 2) * W,
                    (yc + h / 2) * H,
                )


def make_dataloaders():
    """Build the training and validation DataLoaders from the configured dirs.

    Returns:
        ``(train_dl, val_dl)``. The training loader reshuffles every epoch;
        the validation loader keeps the sorted file order.
    """
    split_specs = (
        (TRAIN_IMG_DIR, TRAIN_LBL_DIR, True),
        (VAL_IMG_DIR, VAL_LBL_DIR, False),
    )
    loaders = []
    for image_dir, label_dir, reshuffle in split_specs:
        dataset = YoloMaskDataset(image_dir, label_dir, IMG_SIZE)
        loaders.append(
            DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=reshuffle, num_workers=2)
        )
    return tuple(loaders)


def build_model():
    """Load the ViT-B SAM checkpoint, freeze its image encoder, move it to DEVICE.

    Parameters outside ``image_encoder`` keep their default ``requires_grad``,
    so they are the ones the optimizer in ``train`` updates.
    """
    sam = sam_model_registry["vit_b"](checkpoint=CHECKPOINT)
    # requires_grad_ toggles the flag on every parameter of the submodule.
    sam.image_encoder.requires_grad_(False)
    # Module.to moves in place and returns the same module.
    return sam.to(DEVICE)


def _predict_mask_logits(model, imgs, out_size):
    """Run prompt-free SAM on an image batch and return mask logits.

    Args:
        model: SAM model as returned by ``build_model``.
        imgs: image batch ``[B, 3, H, W]`` with values in [0, 1] (``T.ToTensor`` output).
        out_size: ``(H, W)`` of the ground-truth masks.

    Returns:
        Logits of shape ``[B, 1, *out_size]``.
    """
    # SAM's encoder expects 0-255 pixels normalised with the model's
    # pixel_mean/pixel_std; Sam.preprocess applies that normalisation (and pads
    # to the encoder's square input). Feeding raw [0, 1] tensors straight into
    # image_encoder gives the frozen encoder out-of-distribution input.
    image_embeddings = model.image_encoder(model.preprocess(imgs * 255.0))

    # No points / boxes / masks: the decoder predicts from the image alone.
    sparse_embeddings, dense_embeddings = model.prompt_encoder(points=None, boxes=None, masks=None)
    logits, _ = model.mask_decoder(
        image_embeddings=image_embeddings,
        image_pe=model.prompt_encoder.get_dense_pe().to(imgs.device),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,
    )
    # Upsample the low-res logits to encoder resolution, crop away any padding
    # added by preprocess, then resize to the mask resolution.
    return model.postprocess_masks(logits, tuple(imgs.shape[-2:]), tuple(out_size))


def _run_epoch(model, loader, criterion, optimizer=None, desc=None):
    """Run one pass over ``loader`` and return the mean per-sample loss.

    With an ``optimizer`` the model is put in train mode and updated after every
    batch. Without one it is evaluated in eval mode with gradients disabled.
    ``desc`` (if given) wraps the loader in a tqdm progress bar.
    """
    training = optimizer is not None
    model.train(training)
    batches = tqdm.tqdm(loader, desc=desc) if desc is not None else loader
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for imgs, masks in batches:
            imgs, masks = imgs.to(DEVICE), masks.to(DEVICE)  # masks: [B,1,H,W]
            logits = _predict_mask_logits(model, imgs, masks.shape[-2:])
            loss = criterion(logits, masks)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Weight by batch size so the final mean is per sample, not per batch.
            total_loss += loss.item() * imgs.size(0)
    return total_loss / len(loader.dataset)


def train():
    """Fine-tune SAM for prompt-free segmentation of the box-derived masks.

    Optimises (Adam, lr=LR) every parameter ``build_model`` left trainable,
    i.e. everything except the frozen image encoder. Uses BCE-with-logits
    against the rasterised masks. ``ReduceLROnPlateau`` (patience=2,
    factor=0.5) is stepped on the validation loss. The full ``state_dict`` is
    saved to ``OUT_WEIGHTS`` whenever validation loss improves.

    Raises:
        ValueError: if the training or validation split contains no images.
    """
    train_dl, val_dl = make_dataloaders()
    n_train, n_val = len(train_dl.dataset), len(val_dl.dataset)
    if n_train == 0 or n_val == 0:
        # Fail before loading the checkpoint instead of dividing by zero later.
        raise ValueError(
            f"Found {n_train} training and {n_val} validation images under "
            f"{DATA_DIR!r}; both splits must be non-empty."
        )

    model = build_model()
    optimizer = torch.optim.Adam((p for p in model.parameters() if p.requires_grad), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=2, factor=0.5)
    criterion = nn.BCEWithLogitsLoss()

    best_val_loss = float("inf")
    for epoch in range(1, EPOCHS + 1):
        print(f"\n=== Epoch {epoch}/{EPOCHS} ===", flush=True)
        avg_train = _run_epoch(model, train_dl, criterion, optimizer, desc="Train batches")
        avg_val = _run_epoch(model, val_dl, criterion)
        scheduler.step(avg_val)

        print(f"Epoch {epoch}/{EPOCHS} — train loss: {avg_train:.4f}, val loss: {avg_val:.4f}")

        # Save best
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            torch.save(model.state_dict(), OUT_WEIGHTS)
            print("→ New best model saved.")

# Entry point: run fine-tuning when this file is executed as the main program.
if __name__ == "__main__":
    train()
