In [None]:
# Fetch the SAM2 fine-tuning repository and make the checkout the working directory,
# so that later relative commands (e.g. the editable install) run inside it.
# NOTE(review): assumes the notebook starts in /kaggle/working; otherwise the clone
# lands elsewhere and the %cd below fails — confirm.
!git clone https://github.com/ckapelonis02/sam2-fine-tune.git
%cd /kaggle/working/sam2-fine-tune

In [None]:
# Editable install of the current directory into the kernel's environment.
# Relies on the working directory being the cloned repository (set by the %cd above).
%pip install -e .

In [None]:
import kagglehub
# Fetch the SAM2 hiera-tiny weights through kagglehub.
# NOTE(review): `path` is not read by the visible cells below, which hard-code the
# /kaggle/input/... location instead — confirm both point at the same files.
sam2_tiny_handle = "metaresearch/segment-anything-2/pyTorch/sam2-hiera-tiny"
path = kagglehub.model_download(sam2_tiny_handle)

In [None]:
import sys
# Put the attached model directory on the import search path.
# NOTE(review): the visible cells only read the checkpoint file from this directory by
# absolute path; confirm whether anything is actually imported from it.
sam2_tiny_dir = "/kaggle/input/segment-anything-2/pytorch/sam2-hiera-tiny/1/"
sys.path.append(sam2_tiny_dir)

In [None]:
import hydra
import numpy as np
import torch
import cv2
import os
import random
import matplotlib.pyplot as plt
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2.train_helper import read_batch
from sam2.train_helper import read_dataset
from sam2.train_helper import visualize_entry
from sam2.train_helper import cleanup

# Configurations
# Both allocator options must live in ONE comma-separated value. The previous code
# assigned the variable twice, so the second assignment replaced the first and
# max_split_size_mb was silently dropped.
# NOTE(review): the variable is set before cleanup() in case that helper touches CUDA;
# if CUDA was already used in this kernel, changing it here probably has no effect — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"

cleanup()

# Reset Hydra's global instance before initialising it again, so this cell can be
# re-run, then resolve configs relative to the `sam2` module.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize_config_module('sam2', version_base='1.2')

# Build the SAM2 hiera-tiny model on the GPU from the attached checkpoint.
sam2_model = build_sam2(
    config_file="../sam2_configs/sam2_hiera_t.yaml",
    ckpt_path="/kaggle/input/segment-anything-2/pytorch/sam2-hiera-tiny/1/sam2_hiera_tiny.pt",
    device="cuda",
    apply_postprocessing=False
)

# Only the mask decoder and the prompt encoder are switched to train mode here;
# the optimizer nevertheless receives every parameter of the model.
predictor = SAM2ImagePredictor(sam2_model)
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
optimizer = torch.optim.AdamW(
    params=predictor.model.parameters(),
    lr=1e-5,
    weight_decay=4e-5
)

# Gradient scaler used together with autocast in the training loop (mixed precision).
scaler = torch.cuda.amp.GradScaler()

# Disabled alternative: take the sample ids from a pre-sorted text file instead.
# with open("/kaggle/input/data-2k-cropped/sorted_ancient.txt", "r") as file:
#     file_names = [int(line.strip()) for line in file]

# Samples are identified by the integers 1..data_size.
data_size = 128
file_names = list(range(1, data_size + 1))
top_files = file_names[:data_size]

# Randomise the visiting order in place (no seed is set, so it differs between runs).
random.shuffle(top_files)

# Load the image/mask entries for the selected ids.
data_dict = read_dataset(
    file_names=top_files,
    images_path="/kaggle/input/real-data-128/images_final",
    masks_path="/kaggle/working/sam2-fine-tune/inverted",
)

# Fine-tuning loop: one sample per step, `epochs` passes over the `data_size` entries.
mean_iou = 0  # exponential moving average of the per-step IoU (reporting only)
max_masks = 150  # forwarded to read_batch for every sample
epochs = 10
for itr in range(data_size * epochs):
    # Fetch the sample before entering autocast; entries are cycled via itr % data_size.
    image, masks, input_point, input_label = read_batch(data_dict, itr % data_size, max_masks)
    if masks.shape[0] == 0:
        continue  # the sample carries no masks, nothing to train on
    # visualize_entry(image, masks, input_point)

    # Autocast wraps only the forward pass and the loss computation. PyTorch's AMP
    # guidance is to run backward() and the optimizer step outside of the context.
    with torch.cuda.amp.autocast():
        # Segment the image using SAM
        predictor.set_image(image)  # apply SAM image encoder to the image

        # Prompt encoding
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=None, mask_logits=None, normalize_coords=True
            )
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=None, masks=None
            )

        # Mask decoder
        batched_mode = unnorm_coords.shape[0] > 1  # multi object prediction
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features
            )

        # Upscale the masks to the original image resolution
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # Segmentation loss: binary cross-entropy on the first of the multimask outputs.
        # NOTE(review): dividing by 255 assumes the masks are stored as 0/255 — confirm.
        gt_mask = torch.tensor((masks / 255).astype(np.float32)).cuda()
        prd_mask = torch.sigmoid(prd_masks[:, 0])  # Turn logit map to probability map
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

        # Score loss: the predicted score should match the actual IoU of the thresholded mask.
        prd_binary = prd_mask > 0.5
        inter = (gt_mask * prd_binary).sum(1).sum(1)
        union = gt_mask.sum(1).sum(1) + prd_binary.sum(1).sum(1) - inter
        # Guard the division: an empty ground truth with an empty prediction used to give
        # 0/0 = NaN, which then poisoned the loss and the running mean_iou for good.
        iou = inter / union.clamp(min=1e-6)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()
        loss = seg_loss + score_loss * 0.05  # mix losses

    # Backpropagation (outside autocast)
    predictor.model.zero_grad()  # empty gradient
    scaler.scale(loss).backward()  # Backpropagate
    scaler.step(optimizer)
    scaler.update()  # Mix precision

    # Periodic checkpoint
    if itr % 500 == 0:
        torch.save(predictor.model.state_dict(), f"model{itr}.torch")
        print("Saved model.")

    # Smoothed IoU for progress reporting
    mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
    if itr % 100 == 0:
        print(f"step {itr} Accuracy (IoU) = {mean_iou}")

    # visualize_training_results(image, masks, prd_masks, mean_iou, itr)

# The periodic checkpoints stop at the last multiple of 500, so the weights produced by
# the remaining steps would otherwise be lost; always persist the final state as well.
torch.save(predictor.model.state_dict(), "model_final.torch")
print("Saved final model.")

In [None]:
# Create the output directory for the generated masks; -p makes the cell safe to
# re-run (plain mkdir errors out when the directory already exists).
!mkdir -p /kaggle/working/sam2-fine-tune/results

In [None]:
import numpy as np
import torch
import cv2
import hydra
import matplotlib.pyplot as plt
import os
import time
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.test_helper import test_generator
from sam2.train_helper import cleanup

# Configurations
# Both allocator options must live in ONE comma-separated value. The previous code
# assigned the variable twice, so the second assignment replaced the first and
# max_split_size_mb was silently dropped.
# NOTE(review): the variable is set before cleanup() in case that helper touches CUDA;
# if CUDA was already used in this kernel, changing it here probably has no effect — confirm.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"

cleanup()

# Reset Hydra's global instance before initialising it again, so this cell can be
# re-run, then resolve configs relative to the `sam2` module.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize_config_module('sam2', version_base='1.2')

# Build the SAM2 hiera-tiny model on the GPU from the attached (pre-trained) checkpoint.
sam2_model = build_sam2(
    config_file="../sam2_configs/sam2_hiera_t.yaml",
    ckpt_path="/kaggle/input/segment-anything-2/pytorch/sam2-hiera-tiny/1/sam2_hiera_tiny.pt",
    device="cuda",
    apply_postprocessing=False
)

# Settings handed to the automatic mask generator. Entries that are commented out are
# options that are currently not passed, so the generator's own defaults apply.
generator_options = dict(
    points_per_side=128,
    points_per_batch=32,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.88,
    stability_score_offset=1.0,
    mask_threshold=0.0,
    box_nms_thresh=0.7,
    crop_n_layers=2,
    crop_nms_thresh=0.7,
    # crop_overlap_ratio=0.5,
    crop_n_points_downscale_factor=2,
    # point_grids=None,
    min_mask_region_area=25.0,
    # output_mode="binary_mask",
    use_m2m=False,
    # multimask_output=True,
    # load_model="/kaggle/input/ancient-model-2k-cropped/pytorch/default/1/model3500ancient_2k_cropped.torch"
)

mask_generator = SAM2AutomaticMaskGenerator(model=sam2_model, **generator_options)

# Segment one image with the generator and report how long the whole run took.
start_time = time.time()

# The timestamp in the file name keeps results of successive runs from overwriting each other.
output_path = f"/kaggle/working/sam2-fine-tune/results/masks_{time.time()}.png"
test_generator(
    mask_generator=mask_generator,
    img_path="/kaggle/input/mosaic-images/butterfly.jpg",
    output_path=output_path,
    rows=1,
    cols=1,
    max_mask_crop_region=0.1,
    show_masks=True,
)

print(f"Time taken: {time.time() - start_time}")


In [None]:
# Use the %pip magic rather than !pip so the package is installed into the environment
# of the running kernel (a shell `pip` may belong to a different interpreter).
%pip install optuna


In [None]:
import optuna
import numpy as np
import time
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.test_helper import test_generator
from evaluate import *

# Hyperparameter optimization using Bayesian Optimization
def objective(trial):
    """Optuna objective: score one mask-generator configuration on the butterfly image.

    Four thresholds are sampled from ``trial``; all remaining generator settings are
    fixed. For every trial a SAM2 hiera-tiny model is built from the checkpoint, the
    evaluation image is segmented through ``test_generator`` and the result is written
    to a fixed output file. That file is then compared with the ground-truth mask.

    Args:
        trial: the Optuna trial used to sample the hyperparameters.

    Returns:
        The value stored under ``'IoU'`` in the metrics returned by ``evaluate_pred``.
    """
    # Hyperparameters explored by the study (sampled in this order).
    sampled = {
        'pred_iou_thresh': trial.suggest_float('pred_iou_thresh', 0.5, 0.9),
        'stability_score_thresh': trial.suggest_float('stability_score_thresh', 0.7, 0.95),
        'stability_score_offset': trial.suggest_float('stability_score_offset', 0.7, 1.2),
        'mask_threshold': trial.suggest_float('mask_threshold', 0.0, 0.6),
    }

    # Settings that stay constant across all trials.
    fixed = dict(
        points_per_side=128,
        points_per_batch=32,
        box_nms_thresh=0.7,
        crop_n_layers=2,
        crop_nms_thresh=0.7,
        crop_overlap_ratio=0.3,
        crop_n_points_downscale_factor=2,
        min_mask_region_area=25.0,
        use_m2m=False,
    )

    # A new model instance is created for every trial.
    sam2_model = build_sam2(
        config_file="../sam2_configs/sam2_hiera_t.yaml",
        ckpt_path="/kaggle/input/segment-anything-2/pytorch/sam2-hiera-tiny/1/sam2_hiera_tiny.pt",
        device="cuda",
        apply_postprocessing=False,
    )
    mask_generator = SAM2AutomaticMaskGenerator(model=sam2_model, **fixed, **sampled)

    # Every trial writes to the same output file; it is read back right below.
    img_path = "/kaggle/input/evaluation-dataset/images_set/butterfly.jpg"
    output_path = "/kaggle/working/sam2-fine-tune/results/butterfly.png"

    start_time = time.time()
    test_generator(
        mask_generator=mask_generator,
        img_path=img_path,
        output_path=output_path,
        rows=1,
        cols=1,
        max_mask_crop_region=0.1,
        show_masks=False,
    )
    print(f"Test run took {time.time() - start_time} seconds")

    # Compare the stitched prediction with the ground truth and hand the IoU to Optuna.
    gt, pred = read_masks("/kaggle/input/evaluation-dataset/masks_set/butterfly.png", output_path)
    metrics = evaluate_pred(gt, pred)
    return metrics['IoU']

# Run the hyperparameter search: the objective returns an IoU, so the study maximises it.
study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20)  # 20 trials in total

# Report the best configuration found and the score it reached.
best_params, best_iou = study.best_params, study.best_value
print("Best Hyperparameters:", best_params)
print("Best IoU Score:", best_iou)


[I 2025-03-26 16:28:56,077] A new study created in memory with name: no-name-f36d6655-90b7-46da-b395-56993241e75b


Processing 1 of 1
1861 masks found


[I 2025-03-26 16:32:15,662] Trial 0 finished with value: 0.94 and parameters: {'pred_iou_thresh': 0.7701764472168076, 'stability_score_thresh': 0.9195053653804519, 'stability_score_offset': 0.7873863652006368, 'mask_threshold': 0.2545665866945082}. Best is trial 0 with value: 0.94.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 198.76357698440552 seconds
Processing 1 of 1
2328 masks found


[I 2025-03-26 16:35:47,378] Trial 1 finished with value: 0.914 and parameters: {'pred_iou_thresh': 0.5916063441187859, 'stability_score_thresh': 0.752410173944368, 'stability_score_offset': 1.0745639640562423, 'mask_threshold': 0.1012628398893242}. Best is trial 0 with value: 0.94.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 210.9339189529419 seconds
Processing 1 of 1
2572 masks found


[I 2025-03-26 16:39:21,907] Trial 2 finished with value: 0.868 and parameters: {'pred_iou_thresh': 0.5265245145322136, 'stability_score_thresh': 0.7096591671018168, 'stability_score_offset': 0.9601039633893467, 'mask_threshold': 0.5740862281086649}. Best is trial 0 with value: 0.94.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 213.8806402683258 seconds
Processing 1 of 1
2411 masks found


[I 2025-03-26 16:42:52,183] Trial 3 finished with value: 0.906 and parameters: {'pred_iou_thresh': 0.5483758716429645, 'stability_score_thresh': 0.768688191433413, 'stability_score_offset': 0.9273714394962258, 'mask_threshold': 0.5806184031358964}. Best is trial 0 with value: 0.94.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 209.63797640800476 seconds
Processing 1 of 1
1810 masks found


[I 2025-03-26 16:46:10,072] Trial 4 finished with value: 0.943 and parameters: {'pred_iou_thresh': 0.7467155072206236, 'stability_score_thresh': 0.927620940299105, 'stability_score_offset': 0.9176923697930055, 'mask_threshold': 0.40962544501462417}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 197.254376411438 seconds
Processing 1 of 1
1926 masks found


[I 2025-03-26 16:49:28,878] Trial 5 finished with value: 0.934 and parameters: {'pred_iou_thresh': 0.8155265038661501, 'stability_score_thresh': 0.7995016710168169, 'stability_score_offset': 1.0660359342025623, 'mask_threshold': 0.12795675378904883}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 198.17264223098755 seconds
Processing 1 of 1
1933 masks found


[I 2025-03-26 16:52:48,394] Trial 6 finished with value: 0.935 and parameters: {'pred_iou_thresh': 0.8039072283559566, 'stability_score_thresh': 0.8288776310084137, 'stability_score_offset': 0.9621058808445495, 'mask_threshold': 0.33484133596334326}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 198.7157633304596 seconds
Processing 1 of 1
1980 masks found


[I 2025-03-26 16:56:11,097] Trial 7 finished with value: 0.934 and parameters: {'pred_iou_thresh': 0.6735528186033698, 'stability_score_thresh': 0.8537786659419957, 'stability_score_offset': 1.1267202192387926, 'mask_threshold': 0.2718032845741301}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 201.9125382900238 seconds
Processing 1 of 1
1986 masks found


[I 2025-03-26 16:59:34,856] Trial 8 finished with value: 0.933 and parameters: {'pred_iou_thresh': 0.5315011894960423, 'stability_score_thresh': 0.9258735488846309, 'stability_score_offset': 0.7143012737426908, 'mask_threshold': 0.08978692148500318}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 202.97058749198914 seconds
Processing 1 of 1
2066 masks found


[I 2025-03-26 17:02:58,055] Trial 9 finished with value: 0.933 and parameters: {'pred_iou_thresh': 0.747879942377363, 'stability_score_thresh': 0.7889945858826082, 'stability_score_offset': 0.842475454013801, 'mask_threshold': 0.433435527557335}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 202.55713057518005 seconds
Processing 1 of 1
1775 masks found


[I 2025-03-26 17:06:12,860] Trial 10 finished with value: 0.937 and parameters: {'pred_iou_thresh': 0.8706528575203315, 'stability_score_thresh': 0.8773481888588376, 'stability_score_offset': 1.182364167815913, 'mask_threshold': 0.4422149529318393}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 194.1377341747284 seconds
Processing 1 of 1
1812 masks found


[I 2025-03-26 17:09:32,085] Trial 11 finished with value: 0.942 and parameters: {'pred_iou_thresh': 0.6843746051872427, 'stability_score_thresh': 0.9414298230112895, 'stability_score_offset': 0.7725378413309506, 'mask_threshold': 0.2764547643376969}. Best is trial 4 with value: 0.943.


Final stitched segmentation saved as /kaggle/working/sam2-fine-tune/results/butterfly.png
Test run took 198.55978441238403 seconds
Processing 1 of 1
