In [3]:
!git clone https://github.com/sagieppel/fine-tune-train_segment_anything_2_in_60_lines_of_code.git
%cd /kaggle/working/fine-tune-train_segment_anything_2_in_60_lines_of_code
%pip install -e .

fatal: destination path 'fine-tune-train_segment_anything_2_in_60_lines_of_code' already exists and is not an empty directory.
/kaggle/working/fine-tune-train_segment_anything_2_in_60_lines_of_code
Obtaining file:///kaggle/working/fine-tune-train_segment_anything_2_in_60_lines_of_code
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports build_editable ... [?25l[?25hdone
  Getting requirements to build editable ... [?25l[?25hdone
  Preparing editable metadata (pyproject.toml) ... [?25l[?25hdone
Collecting hydra-core>=1.3.2 (from SAM-2==1.0)
  Using cached hydra_core-1.3.2-py3-none-any.whl.metadata (5.5 kB)
Collecting iopath>=0.1.10 (from SAM-2==1.0)
  Using cached iopath-0.1.10.tar.gz (42 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting portalocker (from iopath>=0.1.10->SAM-2==1.0)
  Using cached portalocker-3.1.1-py3-none-any.whl.metadata (8.6 kB)
Using cached hydra_core-1.3.2-py3-none-any.whl (154 kB)
Using cached portalock

In [5]:
import kagglehub

# Fetch the SAM 2 "hiera tiny" model files through kagglehub; `path` holds
# whatever location model_download() returns.
# NOTE(review): none of the visible cells read `path` -- later cells use a
# hard-coded /kaggle/input/... location instead. Confirm the two agree.
path = kagglehub.model_download(
    "metaresearch/segment-anything-2/pyTorch/sam2-hiera-tiny"
)

In [7]:
import sys
# Append the attached model directory to the module search path.
# NOTE(review): the same directory is where the checkpoint file is loaded from
# later on; confirm it also contains Python modules that need to be importable,
# otherwise this line is unnecessary.
sys.path.append("/kaggle/input/segment-anything-2/pytorch/sam2-hiera-tiny/1/")

In [None]:
import hydra
import numpy as np
import torch
import cv2
import os
import gc
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

# Configurations
# PYTORCH_CUDA_ALLOC_CONF is a single value made of comma-separated
# "option:value" pairs. Assigning the variable twice keeps only the last
# assignment (the split-size limit was silently dropped), so both allocator
# options are combined into one string.
# NOTE: the setting is only picked up if it is in place before the first CUDA
# allocation of the process -- keep this ahead of any `.cuda()` / device="cuda" use.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"

# ---- Model, predictor and optimizer setup ----
# Clear Hydra's global instance first (so this cell can be re-run), then
# register the `sam2` package as the config module.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize_config_module("sam2", version_base="1.2")

# Longest-side target (pixels) that read_batch() rescales samples to.
max_res = 1024

# Config file and checkpoint for the "hiera tiny" SAM 2 variant.
model_cfg = "../sam2_configs/sam2_hiera_t.yaml"
sam2_checkpoint = "/kaggle/input/segment-anything-2/pytorch/sam2-hiera-tiny/1/sam2_hiera_tiny.pt"

# Build the network on the GPU and wrap it in the image predictor.
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")
predictor = SAM2ImagePredictor(sam2_model)

# Switch the mask decoder and the prompt encoder to training mode.
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)

# AdamW receives every model parameter; the GradScaler handles loss scaling
# for the mixed-precision training loop.
optimizer = torch.optim.AdamW(
    params=predictor.model.parameters(),
    lr=1e-5,
    weight_decay=4e-5,
)
scaler = torch.cuda.amp.GradScaler()

# Path to your dataset folder.
# The folder is scanned further down for `.jpg` images and their same-named
# `.png` annotation masks.
data_dir = "/kaggle/input/subset-mosaic/"
data = []  # filled below with {"image": <path>, "annotation": <path>} entries

def cleanup():
    """Free memory: run Python's garbage collector, then empty PyTorch's CUDA cache.

    The collector runs first so that objects it frees no longer hold cached
    GPU blocks when the cache is emptied.
    """
    gc.collect()
    torch.cuda.empty_cache()


# Start from a clean slate before the dataset and training code run.
cleanup()

# Prepare dataset
# Pair every .jpg image with its same-named .png mask. Images whose mask file
# is missing are skipped here; otherwise cv2.imread() returns None for them
# inside read_batch() and training crashes at a random later step.
# The listing is sorted so the order of `data` does not depend on the
# filesystem's directory order.
for name in sorted(os.listdir(data_dir)):
    if not name.endswith('.jpg'):  # Only process .jpg images
        continue
    image_path = os.path.join(data_dir, name)
    mask_path = os.path.join(data_dir, name[:-4] + ".png")
    if not os.path.isfile(mask_path):
        print(f"Skipping {name}: no matching mask file {mask_path}")
        continue
    data.append({"image": image_path, "annotation": mask_path})

# Fail early with a clear message instead of an obscure randint() error later.
if not data:
    raise RuntimeError(f"No image/mask pairs found in {data_dir}")

def read_batch(data, max_instances=75):
    """Load one random sample and build per-instance masks with point prompts.

    A random entry of `data` is read, rescaled so that it fits inside a
    `max_res` x `max_res` square (module-level `max_res`, aspect ratio kept),
    and the annotation is binarised. Every external contour of the *inverted*
    binary map becomes one instance: a filled 0/255 mask plus its centroid as
    the prompt point.

    Args:
        data: sequence of dicts with "image" and "annotation" file paths.
        max_instances: only the first `max_instances` contours are considered.

    Returns:
        Tuple `(img, masks, points, labels)`:
        img    -- resized image with the channel order reversed after loading;
        masks  -- uint8 array, one filled 0/255 mask per kept instance;
        points -- array of `[[x, y]]` centroids, one per kept instance;
        labels -- array of ones with shape `(len(masks), 1)`.
        When no instance is kept, `masks` and `points` are empty arrays.

    Raises:
        OSError: if the image or the annotation file cannot be read.
    """
    ent = data[np.random.randint(len(data))]

    img = cv2.imread(ent["image"])
    ann_map = cv2.imread(ent["annotation"], cv2.IMREAD_GRAYSCALE)
    # cv2.imread signals failure by returning None; report the offending file
    # instead of failing later with an opaque TypeError.
    if img is None:
        raise OSError(f"Could not read image file: {ent['image']}")
    if ann_map is None:
        raise OSError(f"Could not read annotation file: {ent['annotation']}")
    img = img[..., ::-1]  # reverse the channel order (OpenCV loads BGR)

    # Common scale factor that brings the longer side to max_res.
    r = min(max_res / img.shape[1], max_res / img.shape[0])
    img = cv2.resize(img, (int(img.shape[1] * r), int(img.shape[0] * r)))
    # Nearest-neighbour keeps the annotation free of interpolated grey levels.
    ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

    # Ensure binary
    ann_map = (ann_map > 127).astype(np.uint8) * 255

    masks = []
    points = []

    # Find contours.
    # NOTE(review): contours come from the inverted map, so instances are the
    # regions that are dark in the annotation -- confirm this matches the
    # dataset's mask polarity.
    contours, _ = cv2.findContours(255 - ann_map, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

    for contour in contours[:max_instances]:
        if len(contour) < 3:
            continue  # fewer than three points cannot enclose an area

        # Compute centroid using image moments
        M = cv2.moments(contour)
        if M["m00"] == 0:
            continue  # zero area: centroid undefined
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
        if cx == 0 or cy == 0:
            continue  # centroids on the first row/column are discarded

        # Rasterise the instance only after the cheap rejection tests above.
        mask = np.zeros_like(ann_map)
        cv2.drawContours(mask, [contour], -1, 255, thickness=cv2.FILLED)

        # For a concave contour the centroid can fall outside the shape; a
        # positive prompt outside its own mask would contradict the target.
        if mask[cy, cx] == 0:
            continue

        points.append([[cx, cy]])
        masks.append(mask)

    return img, np.array(masks), np.array(points), np.ones([len(masks), 1])

def visualize_entry(img, masks, points):
    """Display one sample as three figures: image, merged mask, image with points.

    Args:
        img: image array accepted by `plt.imshow`.
        masks: sequence/array of per-instance masks of identical shape; may be
            empty, in which case an all-black mask of the image size is shown.
        points: iterable of prompt points; each item is read as
            `point[0][0]` (x) and `point[0][1]` (y).
    """
    # Plot the input image
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    plt.title("Input Image (Resized)")
    plt.axis("on")
    plt.show()

    # Plot the combined binary annotation mask (pixel-wise maximum of all masks)
    plt.figure(figsize=(8, 8))
    if len(masks) > 0:
        combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
        for mask in masks:
            combined_mask = np.maximum(combined_mask, mask)
    else:
        # No instances: masks[0] would raise IndexError, so fall back to an
        # empty canvas with the image's height and width.
        combined_mask = np.zeros(np.shape(img)[:2], dtype=np.uint8)
    plt.imshow(combined_mask, cmap='gray')
    plt.title("Combined Mask (Tesserae in White)")
    plt.axis("on")
    plt.show()

    # Plot the image with the prompt points overlaid.
    # NOTE(review): the title says "random", but read_batch() supplies contour
    # centroids, not random points.
    plt.figure(figsize=(8, 8))
    plt.imshow(img)
    for point in points:
        plt.plot(point[0][0], point[0][1], 'ro', markersize=2)
    plt.title("Image with Randomly Distributed Points")
    plt.axis("on")
    plt.show()

# for i in range(40):
#     image, masks, input_point, input_label = read_batch(data)
#     visualize_entry(image, masks, input_point)


def visualize_training_results(image, masks, predicted_mask, iou, itr):
    """Display training diagnostics as four figures.

    Figures: input image, merged ground-truth mask, the `[0, 0]` slice of
    `predicted_mask`, and IoU plotted against the iteration.

    Args:
        image: image array accepted by `plt.imshow`.
        masks: sequence/array of ground-truth masks of identical shape; may be
            empty, in which case an all-black mask of the image size is shown.
        predicted_mask: torch tensor indexed as `predicted_mask[0, 0]`.
        iou: IoU value(s) to plot -- a scalar or a sequence.
        itr: iteration value(s) matching `iou`.
    """
    # Plot the input image
    plt.figure(figsize=(8, 8))
    plt.imshow(image)
    plt.title("Input Image (Resized)")
    plt.axis("on")
    plt.show()

    # Plot the ground truth mask (pixel-wise maximum of all instance masks)
    plt.figure(figsize=(8, 8))
    if len(masks) > 0:
        combined_mask = np.zeros_like(masks[0], dtype=np.uint8)
        for mask in masks:
            combined_mask = np.maximum(combined_mask, mask)
    else:
        # No instances: masks[0] would raise IndexError, so fall back to an
        # empty canvas with the image's height and width.
        combined_mask = np.zeros(np.shape(image)[:2], dtype=np.uint8)
    plt.imshow(combined_mask, cmap='gray')
    plt.title("Ground Truth Mask")
    plt.axis("on")
    plt.show()

    # Plot the predicted mask (first entry along the first two dimensions)
    plt.figure(figsize=(8, 8))
    plt.imshow(predicted_mask[0, 0].cpu().detach().numpy(), cmap='gray')
    plt.title("Predicted Mask")
    plt.axis("on")
    plt.show()

    # Plot the IoU score over time.
    # A single (itr, iou) pair is a one-point line, which is invisible without
    # a marker; sequences keep the plain line.
    single_point = np.size(iou) == 1
    plt.figure(figsize=(8, 8))
    plt.plot(itr, iou, marker="o" if single_point else None, label="IoU")
    plt.xlabel('Iterations')
    plt.ylabel('IoU')
    plt.title("IoU Over Time")
    plt.legend()
    plt.show()
    
torch.cuda.empty_cache()
gc.collect()

NUM_STEPS = 100000        # total number of training iterations
SAVE_EVERY = 1000         # checkpoint interval, in iterations
SCORE_LOSS_WEIGHT = 0.05  # weight of the IoU-score loss in the total loss

# Exponential moving average of the per-step IoU. It starts at 0, so the
# values printed during the first few hundred steps are biased low.
mean_iou = 0

# Training loop: only the forward pass and the loss run under autocast.
# The backward pass, optimizer step and scaler update are done after leaving
# the autocast context, as the PyTorch AMP documentation recommends.
for itr in range(NUM_STEPS):
    # Data loading is NumPy/OpenCV work and does not need autocast.
    image, masks, input_point, input_label = read_batch(data)
    if masks.shape[0] == 0:
        continue  # sample without usable instances
    # visualize_entry(image, masks, input_point)

    with torch.cuda.amp.autocast():  # cast to mixed precision
        # Segment the image using SAM
        predictor.set_image(image)  # apply SAM image encoder to the image

        # Prompt encoding
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
        )
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
        )

        # Mask decoder
        batched_mode = unnorm_coords.shape[0] > 1  # multi object prediction
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )

        # Upscale the masks to the original image resolution
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # Segmentation loss: pixel-wise binary cross-entropy on the first
        # predicted mask of every prompt (masks hold 0/255, hence the /255).
        gt_mask = torch.tensor((masks / 255).astype(np.float16)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

        # Score loss: the predicted score should match the actual IoU of the
        # thresholded prediction against the ground truth.
        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * SCORE_LOSS_WEIGHT  # mix losses

    # Backpropagation (outside autocast)
    predictor.model.zero_grad()  # empty gradient
    scaler.scale(loss).backward()  # Backpropagate the scaled loss
    scaler.step(optimizer)
    scaler.update()  # Mix precision

    if itr % SAVE_EVERY == 0:
        torch.save(predictor.model.state_dict(), "model.torch")
        print("Saved model.")

    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
    print(f"step {itr} Accuracy (IoU) = {mean_iou}")

    # visualize_training_results(image, masks, prd_masks, mean_iou, itr)


Saved model.
step 0 Accuracy (IoU) = 6.256641820073128e-05
step 1 Accuracy (IoU) = 0.0001366119908168912
step 2 Accuracy (IoU) = 0.0001975794444959611
step 3 Accuracy (IoU) = 0.0003218547161545791
step 4 Accuracy (IoU) = 0.00046944520556084046
step 5 Accuracy (IoU) = 0.0006993483209741136
step 6 Accuracy (IoU) = 0.0008327832877229353
step 7 Accuracy (IoU) = 0.0009829001537456903
step 8 Accuracy (IoU) = 0.0011296000460664277
step 9 Accuracy (IoU) = 0.0013178687968433943
step 10 Accuracy (IoU) = 0.0015247658800875382
step 11 Accuracy (IoU) = 0.0016138274486042345
step 12 Accuracy (IoU) = 0.0019940228089715554
step 13 Accuracy (IoU) = 0.0023870510152752458
step 14 Accuracy (IoU) = 0.0028576113898413883
step 15 Accuracy (IoU) = 0.0029800553595194152
step 16 Accuracy (IoU) = 0.0035281063770668942
step 17 Accuracy (IoU) = 0.003848554401483443
step 18 Accuracy (IoU) = 0.004483121526040081
step 19 Accuracy (IoU) = 0.004894046088436712
step 20 Accuracy (IoU) = 0.005285397911444977
step 21 Accur