<a href="https://colab.research.google.com/github/dsubedi753/TORTOISE/blob/main/notebooks/DL_fine_tune.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# References: SAM 2 / SAM 2.1 fine-tuning guides
# https://colab.research.google.com/github/roboflow-ai/notebooks/blob/main/notebooks/fine-tune-sam-2.1.ipynb#scrollTo=CbDFNKNDw6Pq
# https://learnopencv.com/finetuning-sam2/
# https://www.datacamp.com/tutorial/sam2-fine-tuning

In [None]:
from google.colab import drive
drive.mount('/content/drive')
import os
GOOGLE_DRIVE_FOLDER = '/content/drive/MyDrive/GeoCompassSegmentations'

# Install Dependencies
!pip install rasterio
!git clone https://github.com/facebookresearch/sam2.git
%cd ./sam2
!pip install -e .


Mounted at /content/drive
Cloning into 'sam2'...
remote: Enumerating objects: 1070, done.[K
remote: Total 1070 (delta 0), reused 0 (delta 0), pack-reused 1070 (from 1)[K
Receiving objects: 100% (1070/1070), 128.11 MiB | 38.08 MiB/s, done.
Resolving deltas: 100% (381/381), done.
/content/sam2
Obtaining file:///content/sam2
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core>=1.3.2 (from SAM-2==1.0)
  Downloading hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting iopath>=0.1.10 (from SAM-2==1.0)
  Downloading iopath-0.1.10.tar.gz (42 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.2/42.2 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker (from iopath>=0.

In [None]:
# Unzip images into local machine (content directory)
!rm -rf '/content/data/tiles'
!mkdir '/content/data'
!cp "/content/drive/MyDrive/GeoCompassSegmentations/data/tile_index.csv" '/content/data/tile_index.csv'
!unzip -d '/content/data/' "/content/drive/MyDrive/GeoCompassSegmentations/data/tiles.zip"

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  inflating: /content/data/tiles/tile_rgb_06951.png.msk  
  inflating: /content/data/tiles/tile_rgb_01052.png.msk  
  inflating: /content/data/tiles/tile_rgb_03608.png.msk  
  inflating: /content/data/tiles/tile_label_02954.tif  
  inflating: /content/data/tiles/tile_label_11130.tif  
  inflating: /content/data/tiles/tile_rgb_01253.png.aux.xml  
  inflating: /content/data/tiles/tile_ms_14106.tif  
 extracting: /content/data/tiles/tile_rgb_05773.png  
  inflating: /content/data/tiles/tile_label_14018.tif  
  inflating: /content/data/tiles/tile_rgb_10797.png.msk  
  inflating: /content/data/tiles/tile_ms_03295.tif  
  inflating: /content/data/tiles/tile_rgb_10954.png.msk  
  inflating: /content/data/tiles/tile_rgb_09938.png.msk  
  inflating: /content/data/tiles/tile_label_10017.tif  
 extracting: /content/data/tiles/tile_rgb_05017.png  
  inflating: /content/data/tiles/tile_label_08049.tif  
 extracting: /content/data/tile

In [None]:
import sys

# Make the project's `tortoise` package (stored on Drive) importable.
# The membership check keeps re-runs of this cell from appending duplicates.
TORTOISE_SRC = GOOGLE_DRIVE_FOLDER + '/Colab_Notebooks/TORTOISE/src'
if TORTOISE_SRC not in sys.path:
    sys.path.append(TORTOISE_SRC)

In [None]:
from tortoise.dataloader import build_dataloaders

BATCH_SIZE = 16
TILES_DIR = "/content/data/tiles"
TILE_INDEX_CSV = "/content/data/tile_index.csv"

# Split / modality options forwarded to build_dataloaders. test_ratio is passed as None;
# presumably the loader derives it from the other ratios — confirm in tortoise.dataloader.
# Only the RGB modality is requested.
loader_kwargs = dict(
    seed=42,
    train_ratio=0.8,
    val_ratio=0.1,
    test_ratio=None,
    use_rgb=True,
    use_ms=False,
    num_workers=8,
)
train_loader, val_loader, test_loader, _ = build_dataloaders(
    TILES_DIR, TILE_INDEX_CSV, BATCH_SIZE, **loader_kwargs
)

In [None]:
import torch
import numpy as np
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


# SAM 2.1 Hiera-Small weights, their download URL, and the matching model config.
# The checkpoint path is relative to the current working directory.
sam2_checkpoint = "sam2.1_hiera_small.pt"
url = "https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt"
model_cfg = "configs/sam2.1/sam2.1_hiera_s.yaml"

device = "cuda"

# download_url_to_file downloads into a temporary file and moves it to the destination
# only on completion, so an interrupted download cannot leave a truncated checkpoint
# that this exists() guard would then skip on the next run.
if not os.path.exists(sam2_checkpoint):
    torch.hub.download_url_to_file(url, sam2_checkpoint)

sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

--2025-12-08 05:19:06--  https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_small.pt
Resolving dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)... 13.227.219.70, 13.227.219.59, 13.227.219.33, ...
Connecting to dl.fbaipublicfiles.com (dl.fbaipublicfiles.com)|13.227.219.70|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 184416285 (176M) [application/vnd.snesdev-page-table]
Saving to: ‘sam2.1_hiera_small.pt’


2025-12-08 05:19:08 (91.0 MB/s) - ‘sam2.1_hiera_small.pt’ saved [184416285/184416285]



In [None]:
import glob
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [None]:
def get_batch_bounding_boxes(masks, image_size=1024):
    """
    Build one XYXY box prompt per mask from the extent of its foreground pixels.

    Args:
        masks: (B, 1, H, W) or (B, H, W) tensor. Any non-zero value counts as
            foreground.
        image_size: side length of the square frame the box coordinates are scaled
            to. The default 1024 matches the size the training loop resizes the RGB
            images to before encoding.

    Returns:
        float32 tensor of shape (B, 1, 4) holding [x_min, y_min, x_max, y_max]:
        the indices of the first/last foreground column and row, rescaled from the
        mask's W x H grid to image_size x image_size. Lives on masks.device.

    Note:
        A mask with no foreground pixels yields the full-frame box
        [0, 0, (W - 1) * sx, (H - 1) * sy], because argmax over an all-zero
        vector returns index 0.

    Raises:
        ValueError: if masks is neither (B, 1, H, W) nor (B, H, W).
    """
    if masks.dim() == 4 and masks.shape[1] == 1:
        masks = masks[:, 0]  # (B, 1, H, W) -> (B, H, W)
    if masks.dim() != 3:
        raise ValueError(
            f"expected masks of shape (B, 1, H, W) or (B, H, W), got {tuple(masks.shape)}"
        )
    _, H, W = masks.shape

    rows = masks.any(dim=2).float()  # (B, H): 1.0 where the row has foreground
    cols = masks.any(dim=1).float()  # (B, W): 1.0 where the column has foreground

    # argmax returns the first maximal index, i.e. the first foreground row/column;
    # applying it to the flipped vector locates the last one.
    y_min = rows.argmax(dim=1)
    y_max = (H - 1) - rows.flip(1).argmax(dim=1)
    x_min = cols.argmax(dim=1)
    x_max = (W - 1) - cols.flip(1).argmax(dim=1)

    boxes = torch.stack([x_min, y_min, x_max, y_max], dim=1).float()  # (B, 4)

    # Rescale from mask pixel coordinates to the image_size x image_size frame.
    scale_x = image_size / W
    scale_y = image_size / H
    boxes[:, [0, 2]] *= scale_x
    boxes[:, [1, 3]] *= scale_y

    return boxes.unsqueeze(1)  # (B, 1, 4)


In [None]:
# Training configuration
lr = 1e-5
epochs = 5
acc_steps = 4  # gradient accumulation: the optimizer steps once every acc_steps batches

# Only the mask decoder is trained; the image encoder and prompt encoder are frozen.
sam2_model.image_encoder.requires_grad_(False)
sam2_model.sam_prompt_encoder.requires_grad_(False)
sam2_model.sam_mask_decoder.requires_grad_(True)

# Optimizer over the mask-decoder parameters only.
optimizer = torch.optim.AdamW(sam2_model.sam_mask_decoder.parameters(), lr=lr)

# torch.cuda.amp.GradScaler() is deprecated in favour of torch.amp.GradScaler("cuda").
# The training loop routes backward()/step() through this scaler. (It previously
# appeared twice, once with a stray indent that made this cell fail to parse.)
scaler = torch.amp.GradScaler("cuda")

# The decoder emits mask logits, so the sigmoid is folded into the loss.
bce_loss_func = nn.BCEWithLogitsLoss()


In [None]:
# Train loop: the image encoder runs under no_grad; the optimizer updates only the
# mask-decoder parameters (see the config cell).
sam2_model.train()

for epoch in range(epochs):
    epoch_loss = 0.0
    optimizer.zero_grad()
    pbar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f'epoch {epoch+1}')

    for batch_idx, batch in pbar:
        images = batch['rgb'].to(device).float()
        gt_masks = batch['label'].to(device)

        # Resize to the 1024x1024 frame the box prompts are scaled to.
        images_resized = F.interpolate(images, size=(1024, 1024), mode='bilinear', align_corners=False)

        # Frozen image encoder. The no_mem_embed addition also stays inside no_grad:
        # no_mem_embed is not in the optimizer, so tracking it would only accumulate
        # an unused .grad on that parameter.
        with torch.no_grad(), torch.autocast('cuda', dtype=torch.bfloat16):
            backbone_out = sam2_model.forward_image(images_resized)
            _, vision_feats, _, feat_sizes = sam2_model._prepare_backbone_features(backbone_out)
            if sam2_model.directly_add_no_mem_embed:
                vision_feats[-1] = vision_feats[-1] + sam2_model.no_mem_embed

        # Reshape each level (L, B, C) -> (B, C, L) -> (B, C, H, W), using the spatial
        # sizes SAM2 reports rather than assuming a square sqrt(L) grid.
        feats = []
        for feat, (feat_h, feat_w) in zip(vision_feats, feat_sizes):
            feat = feat.permute(1, 2, 0)
            B, C = feat.shape[:2]
            feats.append(feat.view(B, C, feat_h, feat_w))
        # The last level is the decoder's image embedding; the earlier levels are
        # passed as its high_res_features.
        image_embeddings = feats[-1]
        high_res_features = feats[:-1]

        # Box prompts from the ground-truth masks, in the 1024x1024 frame.
        boxes = get_batch_bounding_boxes(gt_masks)

        with torch.autocast('cuda', dtype=torch.bfloat16):
            # Prompt encoder (frozen) -> mask decoder (trained); one mask per image.
            sparse_embeddings, dense_embeddings = sam2_model.sam_prompt_encoder(
                points=None,
                boxes=boxes,
                masks=None,
            )
            low_res_masks, iou_predictions, _, _ = sam2_model.sam_mask_decoder(
                image_embeddings=image_embeddings,
                image_pe=sam2_model.sam_prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings=sparse_embeddings,
                dense_prompt_embeddings=dense_embeddings,
                multimask_output=False,
                repeat_image=False,
                high_res_features=high_res_features,
            )
            # Resize the decoder's mask logits to the label resolution.
            upscaled_masks = F.interpolate(
                low_res_masks,
                size=tuple(gt_masks.shape[-2:]),
                mode='bilinear',
                align_corners=False,
            )
            # Divide by acc_steps so the accumulated gradient averages over the window.
            loss = bce_loss_func(upscaled_masks, gt_masks.float()) / acc_steps

        scaler.scale(loss).backward()
        # Step every acc_steps batches and also on the final batch, so a trailing
        # partial window is applied instead of being wiped by the next epoch's zero_grad().
        is_last_batch = (batch_idx + 1) == len(train_loader)
        if (batch_idx + 1) % acc_steps == 0 or is_last_batch:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        batch_loss = loss.item() * acc_steps  # undo the accumulation scaling for logging
        epoch_loss += batch_loss
        pbar.set_postfix({'Loss': batch_loss})

    print(f"Epoch {epoch+1} Average Loss: {epoch_loss / len(train_loader):.4f}")

epoch 1: 100%|██████████| 2118/2118 [04:44<00:00,  7.44it/s, Loss=0.168]

Epoch 1 Average Loss: 0.1726



epoch 2: 100%|██████████| 2118/2118 [04:44<00:00,  7.43it/s, Loss=0.00448]

Epoch 2 Average Loss: 0.1621



epoch 3: 100%|██████████| 2118/2118 [04:45<00:00,  7.43it/s, Loss=0.00215]

Epoch 3 Average Loss: 0.1554



epoch 4: 100%|██████████| 2118/2118 [04:44<00:00,  7.43it/s, Loss=0.104]

Epoch 4 Average Loss: 0.1506



epoch 5: 100%|██████████| 2118/2118 [04:45<00:00,  7.43it/s, Loss=0.145]

Epoch 5 Average Loss: 0.1473



epoch 6: 100%|██████████| 2118/2118 [04:44<00:00,  7.43it/s, Loss=0.241]

Epoch 6 Average Loss: 0.1444



epoch 7: 100%|██████████| 2118/2118 [04:45<00:00,  7.43it/s, Loss=0.0602]

Epoch 7 Average Loss: 0.1417



epoch 8: 100%|██████████| 2118/2118 [04:45<00:00,  7.43it/s, Loss=0.0122]

Epoch 8 Average Loss: 0.1394



epoch 9: 100%|██████████| 2118/2118 [04:45<00:00,  7.43it/s, Loss=0.0866]

Epoch 9 Average Loss: 0.1376



epoch 10: 100%|██████████| 2118/2118 [04:45<00:00,  7.43it/s, Loss=0.022]

Epoch 10 Average Loss: 0.1360



