In [1]:
import sys, os
from pathlib import Path
import numpy as np
import cv2
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader

sys.path.append("/home/gregoryc25/CMAP/segment_anything_source_code")

from segment_anything.build_sam import sam_model_registry
from segment_anything.predictor import SamPredictor

# Define checkpoint path
home_dir = Path.home()
sam_checkpoint = home_dir / "CMAP/segment_anything_source_code/sam_vit_h.pth"

# Ensure the checkpoint file exists
if not sam_checkpoint.exists():
    raise FileNotFoundError(f"Checkpoint file not found: {sam_checkpoint}")

# Load the SAM ViT-H model and create the predictor
sam = sam_model_registry["vit_h"](checkpoint=str(sam_checkpoint))
predictor = SamPredictor(sam)

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
sam.to(device)
print("Using device:", device)


Using device: cpu


In [2]:
import os
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
from PIL import Image

class AerialDataset(Dataset):
    def __init__(self, image_dir, mask_dir, mask_prefix="mask_"):
        self.image_dir = image_dir
        self.mask_dir = mask_dir
        self.mask_prefix = mask_prefix
        
        # List all image files with .tif extension
        self.all_image_files = sorted([f for f in os.listdir(image_dir) if f.lower().endswith(".tif")])
        self.image_files = []
        self.mask_files = []
        for f in self.all_image_files:
            mask_file = self.mask_prefix + f  # prepend 'mask_' to the image filename
            mask_path = os.path.join(mask_dir, mask_file)
            if os.path.exists(mask_path):
                self.image_files.append(f)
                self.mask_files.append(mask_file)
            else:
                print(f"Warning: mask for {f} not found, expecting {mask_file}")
        print(f"Found {len(self.image_files)} image-mask pairs out of {len(self.all_image_files)} images.")

    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        mask_path = os.path.join(self.mask_dir, self.mask_files[idx])
        
        # Open image and mask
        image = Image.open(img_path).convert("RGB")
        mask_img = Image.open(mask_path).convert("L")
        
        # Resize image and mask so that the longest side equals 1024 pixels
        w, h = image.size
        r = min(1024.0 / w, 1024.0 / h)
        new_w, new_h = int(w * r), int(h * r)
        if (new_w, new_h) != (w, h):
            image = image.resize((new_w, new_h), resample=Image.BILINEAR)
            mask_img = mask_img.resize((new_w, new_h), resample=Image.NEAREST)
        
        image_np = np.array(image)
        mask_np = np.array(mask_img)
        
        # Binarize mask (all non-zero as foreground)
        mask_np = (mask_np > 0).astype(np.uint8)
        
        # Erode mask to avoid selecting boundary points (using a 5x5 kernel)
        if mask_np.max() > 0:
            kernel = np.ones((5, 5), np.uint8)
            eroded_mask = cv2.erode(mask_np, kernel, iterations=1)
        else:
            eroded_mask = mask_np
        if eroded_mask.max() == 0:
            eroded_mask = mask_np
        
        # Sample a random foreground point from the eroded mask
        coords = np.argwhere(eroded_mask > 0)
        if coords.size == 0:
            point = np.array([0, 0], dtype=np.int32)
        else:
            iy, ix = coords[np.random.randint(len(coords))]
            point = np.array([int(ix), int(iy)], dtype=np.int32)
        
        return image_np, mask_np, point

train_image_dir = "/net/projects/cmap/data/KC-images"
train_mask_dir  = "/net/projects/cmap/data/KC-masks/single-band-masks"

# Create dataset with updated naming convention
train_dataset = AerialDataset(train_image_dir, train_mask_dir)

# collate function
def sam_collate_fn(batch):
    images, masks, points = zip(*batch)
    return list(images), list(masks), list(points)

if len(train_dataset) > 0:
    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=2, collate_fn=sam_collate_fn)
    print(f"Loaded {len(train_dataset)} training samples.")
else:
    print("Error: No training samples found. Please check your mask directory or file naming convention.")


Found 50 image-mask pairs out of 50 images.
Loaded 50 training samples.


In [3]:
# Freeze image and prompt encoders; only fine-tune the mask decoder
for param in sam.image_encoder.parameters():
    param.requires_grad = False
for param in sam.prompt_encoder.parameters():
    param.requires_grad = False

sam.image_encoder.eval()
sam.prompt_encoder.eval()
sam.mask_decoder.train()

# Set up optimizer for the mask decoder only
optimizer = torch.optim.AdamW(sam.mask_decoder.parameters(), lr=1e-4, weight_decay=1e-4)

# Initialize mixed-precision gradient scaler (if running on GPU)
scaler = torch.cuda.amp.GradScaler()

print("Model encoders frozen. Optimizer for mask decoder is ready.")

Model encoders frozen. Optimizer for mask decoder is ready.


  scaler = torch.cuda.amp.GradScaler()


In [None]:
import torch
import numpy as np

num_epochs       = 5
grad_accum_steps = 4
log_interval     = 10

for epoch in range(num_epochs):
    running_loss = 0.0
    running_iou  = 0.0

    for i, (images, masks, points) in enumerate(train_loader, start=1):
        # 1) Unpack sample
        image_np   = images[0]
        gt_mask_np = masks[0]
        point      = points[0]

        # 2) Cache image embedding
        predictor.set_image(image_np)
        image_embedding = predictor.get_image_embedding().to(device)

        # 3) Prepare prompts
        input_point = np.expand_dims(point, 0).astype(np.int32)
        input_label = np.array([[1]], dtype=np.int32)
        pt_t = torch.from_numpy(input_point).to(device).unsqueeze(0)
        lbl_t = torch.from_numpy(input_label).to(device)

        # 4) Encode prompts
        sparse_emb, dense_emb = sam.prompt_encoder(
            (pt_t, lbl_t), None, None
        )

        # 5) Forward through decoder
        with torch.cuda.amp.autocast():
            low_res_masks, iou_scores = sam.mask_decoder(
                image_embeddings         = image_embedding,
                image_pe                 = sam.prompt_encoder.get_dense_pe().to(device),
                sparse_prompt_embeddings = sparse_emb.to(device),
                dense_prompt_embeddings  = dense_emb.to(device),
                multimask_output         = True,
            )
            low_res = low_res_masks[:, 0:1, ...]
            H, W = gt_mask_np.shape
            up_mask = torch.nn.functional.interpolate(
                low_res, size=(H, W),
                mode='bilinear', align_corners=False
            )

            # **Make GT tensor [1,1,H,W]**
            gt_tensor = torch.from_numpy(gt_mask_np).float().to(device)
            gt_tensor = gt_tensor.unsqueeze(0).unsqueeze(0)

            loss = torch.nn.functional.binary_cross_entropy_with_logits(up_mask, gt_tensor)

        # 6) Backprop + optimizer step
        scaler.scale(loss).backward()
        if i % grad_accum_steps == 0 or i == len(train_loader):
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        running_loss += loss.item()

        # 7) Compute IoU
        with torch.no_grad():
            pred = (torch.sigmoid(up_mask) > 0.5).float()
            inter = (pred * gt_tensor).sum().item()
            union = (pred + gt_tensor - pred*gt_tensor).sum().item()
            running_iou += (inter / union) if union > 0 else 0.0

        # 8) Log
        if i % log_interval == 0:
            avg_l = running_loss / i
            avg_i = running_iou  / i
            print(f"Epoch [{epoch+1}/{num_epochs}], "
                  f"Step [{i}/{len(train_loader)}] – "
                  f"Loss: {avg_l:.4f}, IoU: {avg_i:.4f}")

    # end of epoch
    avg_loss = running_loss / len(train_loader)
    avg_iou  = running_iou  / len(train_loader)
    print(f"Epoch {epoch+1} complete: Avg Loss: {avg_loss:.4f}, Avg IoU: {avg_iou:.4f}\n")


  with torch.cuda.amp.autocast():


In [None]:
# Save the updated model's state_dict (which now contains the fine-tuned mask decoder)
save_path = "sam_vit_h_finetuned.pth"
torch.save(sam.state_dict(), save_path)
print(f"Fine-tuned model saved as {save_path}")