# Installation

In [None]:
# Clone the sam2-fine-tune repository and make it the working directory for the
# remaining cells. The %cd target assumes the clone was made from /kaggle/working.
!git clone https://github.com/ckapelonis02/sam2-fine-tune.git
%cd /kaggle/working/sam2-fine-tune

In [None]:
# Install the cloned repository (the current directory) as an editable package
# (-e), so edits to its source take effect without reinstalling.
%pip install -e .

## Troubleshooting: `_C.so: undefined symbol`
If the following error appears when `sam2` tries to load its compiled extension `_C`:
```text
/kaggle/working/sam2-fine-tune/sam2/utils/misc.py in get_connected_components(mask)
     59               components for foreground pixels and 0 for background pixels.
     60     """
---> 61     from sam2 import _C
     62 
     63     return _C.get_connected_componnets(mask.to(torch.uint8).contiguous())

ImportError: /kaggle/working/sam2-fine-tune/sam2/_C.so: undefined symbol
```
Run the cell below to rebuild the compiled extension in place, then re-run the cell that failed.

In [None]:
# Rebuild the package's compiled extension(s) in place, next to the sources
# (e.g. sam2/_C.so from the error above). Only needed if that ImportError occurs.
!python setup.py build_ext --inplace

# Training

In [None]:
# Fine-tune SAM2 (hiera-large) on the augmented mosaic dataset. train_sam2 also
# receives automatic-mask-generator settings and validation data paths
# (presumably for periodic validation -- see sam2.train_helper).
import torch
import numpy as np
import random
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from sam2.train_helper import *  # supplies train_sam2 and cleanup used below
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
# Imported explicitly instead of relying on sam2.train_helper re-exporting them,
# and placed after the wildcard import so it cannot shadow these names.
import os
import hydra
from datetime import datetime

# Used both to build the model here and by train_sam2 (which takes the same
# config/checkpoint), so keep them in one place.
SAM2_CONFIG = "../sam2_configs/sam2_hiera_l.yaml"
SAM2_CHECKPOINT = "/kaggle/input/segment-anything-2/pytorch/sam2-hiera-large/1/sam2_hiera_large.pt"
EPOCHS = 10

# PYTORCH_CUDA_ALLOC_CONF takes a single comma-separated option list; assigning
# it twice keeps only the last value. The CUDA caching allocator reads it when it
# is first initialised, so it must be set before any CUDA work in this kernel.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"

cleanup()

# Reset Hydra's global state so re-running this cell does not fail on a second
# initialisation, then register the configs shipped in the sam2 package.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize_config_module('sam2', version_base='1.2')

sam2_model = build_sam2(
    config_file=SAM2_CONFIG,
    ckpt_path=SAM2_CHECKPOINT,
    device="cuda",
    apply_postprocessing=False
)

predictor = SAM2ImagePredictor(sam2_model)
# Only the mask decoder and prompt encoder are switched to training mode; the
# image encoder keeps whatever mode build_sam2 left it in unless the line below
# is uncommented.
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
# predictor.model.image_encoder.train(True)

# NOTE(review): AdamW receives every model parameter, including the image
# encoder's; whether those are updated depends on whether train_sam2 lets
# gradients reach them -- confirm if the encoder is meant to stay frozen.
optimizer = optim.AdamW(predictor.model.parameters(), lr=1e-5, weight_decay=1e-5)
# T_max is tied to EPOCHS; assumes train_sam2 steps the scheduler once per
# epoch -- TODO confirm.
scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=1e-6)

train_sam2(
    images_path="/kaggle/input/mosaic-training-dataset/real_data_1k/augmented_images",
    masks_path="/kaggle/input/mosaic-training-dataset/real_data_1k/augmented_masks",
    data_size=984,
    epochs=EPOCHS,
    grad_steps=4,
    log_dir=f"runs/sam2_training_{datetime.now().strftime('%Y%m%d_%H%M%S')}",
    predictor=predictor,
    optimizer=optimizer,
    scheduler=scheduler,
    seed=22,
    train_percentage=0.8,
    score_weight=0.2,
    config_file=SAM2_CONFIG,
    ckpt_path=SAM2_CHECKPOINT,
    points_per_side=16,
    points_per_batch=4,
    pred_iou_thresh=0.7,
    stability_score_thresh=0.7,
    stability_score_offset=1.0,
    mask_threshold=0.5,
    output_path="val_results",
    crops_csv_file="/kaggle/input/validation-dataset/validate_crops.csv",
    gt_path="/kaggle/input/validation-dataset/masks"
)

# Testing

In [None]:
# Run SAM2's automatic mask generator with the fine-tuned weights over the
# evaluation image set and write the predicted masks under results/.
import torch
import hydra
import os
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.test_helper import *  # supplies test_generator
from sam2.train_helper import cleanup

# PYTORCH_CUDA_ALLOC_CONF takes a single comma-separated option list; assigning
# it twice keeps only the last value. It is only honoured if CUDA has not been
# initialised yet in this kernel (e.g. it has no effect after the training cell).
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:256,expandable_segments:True"

cleanup()

# Reset Hydra's global state so this cell can run after another cell that
# already initialised it, then register the configs in the sam2 package.
hydra.core.global_hydra.GlobalHydra.instance().clear()
hydra.initialize_config_module('sam2', version_base='1.2')

sam2_model = build_sam2(
    config_file="../sam2_configs/sam2_hiera_l.yaml",
    ckpt_path="/kaggle/input/segment-anything-2/pytorch/sam2-hiera-large/1/sam2_hiera_large.pt",
    device="cuda",
    apply_postprocessing=False
)

mask_generator = SAM2AutomaticMaskGenerator(
    model=sam2_model,
    points_per_side=32,
    points_per_batch=4,
    # These four thresholds are the ones searched by the Bayesian Optimization
    # cell; the values are presumably a result of that search.
    pred_iou_thresh=0.7346384369116341,
    stability_score_thresh=0.8644191593971671,
    stability_score_offset=0.90922774537074,
    mask_threshold=0.38300043389010024,
    box_nms_thresh=0.7,
    crop_n_layers=2,
    crop_nms_thresh=0.7,
    crop_overlap_ratio=0.3,
    crop_n_points_downscale_factor=2,
    point_grids=None,
    min_mask_region_area=25.0,
    output_mode="binary_mask",
    use_m2m=False,
    multimask_output=True,
    # Fork-specific argument: presumably loads these fine-tuned weights into
    # sam2_model -- see sam2.automatic_mask_generator in this repo.
    load_model="/kaggle/input/run-7b/pytorch/default/1/best_model_9.torch"
)

test_generator(
    mask_generator=mask_generator,
    images_path="/kaggle/input/evaluation-dataset/evaluation_dataset/images_set",
    output_path="/kaggle/working/sam2-fine-tune/results/",
    crops_csv_file="/kaggle/working/sam2-fine-tune/sam2/testing_crops.csv",
    max_mask_crop_region=0.1,
    show_masks=False
)

# Bayesian Optimization

In [None]:
# Tune the automatic mask generator's thresholds with Optuna, scoring each trial
# by the mean Dice coefficient over the images in EVAL_IMAGES.
import optuna
import logging
import time
from sam2.build_sam import build_sam2
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
from sam2.test_helper import test_generator
from evaluate import *  # supplies read_masks and evaluate_pred
import os
import hydra
from hydra.core.global_hydra import GlobalHydra

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

SAM2_CONFIG = "../sam2_configs/sam2_hiera_l.yaml"
SAM2_CHECKPOINT = "/kaggle/input/segment-anything-2/pytorch/sam2-hiera-large/1/sam2_hiera_large.pt"
# NOTE(review): the Testing cell loads best_model_9.torch from
# /kaggle/input/run-7b/...; confirm this working-directory copy is the same model.
FINE_TUNED_WEIGHTS = "/kaggle/working/sam2-fine-tune/best_model_9.torch"
EVAL_ROOT = "/kaggle/input/evaluation-dataset/evaluation_dataset"
# Trailing slash kept: test_generator receives it as-is and result files are
# addressed as RESULTS_DIR + "<name>.png".
RESULTS_DIR = "/kaggle/working/sam2-fine-tune/results/"
# Images whose predicted masks are scored; a trial's score is their mean Dice.
EVAL_IMAGES = ("chicken", "rabbit")

# build_sam2 relies on Hydra's global config module (the other cells initialise
# it before calling build_sam2). Initialise it only if nothing has yet, so this
# cell also works in a fresh kernel without disturbing an existing setup.
if not GlobalHydra.instance().is_initialized():
    hydra.initialize_config_module('sam2', version_base='1.2')


def objective(trial):
    """Run one Optuna trial and return its mean Dice score.

    Samples four mask-generator thresholds, builds SAM2 plus an automatic mask
    generator that uses the fine-tuned weights, writes predicted masks into
    RESULTS_DIR via test_generator, and scores the masks of EVAL_IMAGES
    against their ground-truth masks.

    Args:
        trial: the optuna Trial used to sample hyperparameters.

    Returns:
        Mean 'Dice Coefficient' over EVAL_IMAGES (the study maximises it).
    """
    # Fixed generator settings (these match the values used in the Testing cell).
    points_per_side = 32
    points_per_batch = 4
    # Searched hyperparameters.
    pred_iou_thresh = trial.suggest_float('pred_iou_thresh', 0.5, 0.9)
    stability_score_thresh = trial.suggest_float('stability_score_thresh', 0.6, 0.9)
    stability_score_offset = trial.suggest_float('stability_score_offset', 0.7, 1.2)
    mask_threshold = trial.suggest_float('mask_threshold', 0.0, 0.6)
    box_nms_thresh = 0.7
    crop_n_layers = 2
    crop_nms_thresh = 0.7
    crop_overlap_ratio = 0.3
    crop_n_points_downscale_factor = 2
    min_mask_region_area = 25.0
    use_m2m = False

    logger.info(f"Trial {trial.number}: Starting trial with parameters: "
                f"pred_iou_thresh={pred_iou_thresh}, stability_score_thresh={stability_score_thresh}, "
                f"stability_score_offset={stability_score_offset}, mask_threshold={mask_threshold}")

    # NOTE(review): the base model is rebuilt every trial although only the
    # generator thresholds change; hoisting it out of objective() could save time
    # if SAM2AutomaticMaskGenerator/load_model do not mutate it -- confirm first.
    sam2_model = build_sam2(
        config_file=SAM2_CONFIG,
        ckpt_path=SAM2_CHECKPOINT,
        device="cuda",
        apply_postprocessing=False
    )

    mask_generator = SAM2AutomaticMaskGenerator(
        model=sam2_model,
        points_per_side=points_per_side,
        points_per_batch=points_per_batch,
        pred_iou_thresh=pred_iou_thresh,
        stability_score_thresh=stability_score_thresh,
        stability_score_offset=stability_score_offset,
        mask_threshold=mask_threshold,
        box_nms_thresh=box_nms_thresh,
        crop_n_layers=crop_n_layers,
        crop_nms_thresh=crop_nms_thresh,
        crop_overlap_ratio=crop_overlap_ratio,
        crop_n_points_downscale_factor=crop_n_points_downscale_factor,
        min_mask_region_area=min_mask_region_area,
        use_m2m=use_m2m,
        load_model=FINE_TUNED_WEIGHTS
    )

    start_time = time.time()
    logger.info(f"Trial {trial.number}: Running test_generator...")

    # NOTE(review): RESULTS_DIR is shared by all trials (and the Testing cell);
    # if test_generator skips an image, the score below would silently use a
    # stale mask from an earlier run.
    test_generator(
        mask_generator=mask_generator,
        images_path=f"{EVAL_ROOT}/images_set",
        output_path=RESULTS_DIR,
        crops_csv_file="/kaggle/input/optimizer/optimizing_crops.csv",
        max_mask_crop_region=0.1,
        show_masks=False
    )

    elapsed_time = time.time() - start_time
    logger.info(f"Trial {trial.number}: Test run took {elapsed_time} seconds")

    dice_scores = []
    for name in EVAL_IMAGES:
        gt, pred = read_masks(
            f"{EVAL_ROOT}/masks_set/{name}.png",
            f"{RESULTS_DIR}{name}.png"
        )
        dice_scores.append(evaluate_pred(gt, pred)['Dice Coefficient'])

    avg_dice = sum(dice_scores) / len(dice_scores)
    logger.info(f"Trial {trial.number}: Dice Score = {avg_dice}")

    return avg_dice

# Without study_name, each run of this cell creates a new auto-named study in
# db.sqlite3; pass study_name=... and load_if_exists=True to resume one instead.
study = optuna.create_study(storage="sqlite:///db.sqlite3", direction="maximize")
logger.info("Starting optimization process...")
study.optimize(objective, n_trials=10)

logger.info(f"Optimization complete. Best hyperparameters: {study.best_params}")
logger.info(f"Best mean score: {study.best_value}")