## Import libraries

In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import logging
import random
import sys
import traceback

import pandas as pd
import numpy as np

from sam2.build_sam import build_sam2_video_predictor
import nibabel as nib
from PIL import Image
import nrrd
import yaml

from training.train import *
import submitit
import torch

from hydra import compose, initialize_config_module, initialize
from hydra.utils import instantiate, get_original_cwd
from hydra.core.global_hydra import GlobalHydra

from iopath.common.file_io import g_pathmgr
from omegaconf import OmegaConf

from training.utils.train_utils import makedir, register_omegaconf_resolvers

import argparse
import sys

os.environ["HYDRA_FULL_ERROR"] = "1"

## Predictor

In [2]:
class SAM2AnnotationTool:
    def __init__(
        self,
        ckpt_folder="checkpoints",
        cfg_folder="sam2/configs/sam2.1_training",
        ckpt_file="sam2.1_hiera_tiny.pt",
        cfg_file="lstm_sam2.1_hiera_t.yaml",
        search_path="configs/sam2.1_training",
        verbose=True
    ):
        self.search_path = search_path
        self.current_directory = os.path.dirname(os.getcwd())
        self.sam2_checkpoint = os.path.join(self.current_directory, ckpt_folder, ckpt_file)
        self.model_cfg = os.path.join(self.current_directory, cfg_folder, cfg_file)
        self.model_cfg_search_path = os.path.join(search_path, cfg_file)
        
        self.verbose = verbose
        if self.verbose:
            print("Welcome to SAM 2 Annotation Tool")

    def load_data(
        self,
        img_path,
        mask_path,
        data_type="npy",
        return_vol=True
    ):
        assert data_type in ["npy", "nib"]
        
        if data_type == "nib":
            self.img_vol = nib.load(img_path)
            self.img_vol_data = self.img_vol.get_fdata()
    
            self.mask_vol_data, self.mask_header = nrrd.read(mask_path)
            self.name_value_dict = self.get_label_name_value()
        elif data_type == "npy":
            self.img_vol_data = np.load(img_path)
            self.mask_vol_data = np.load(mask_path)

        self.length = self.img_vol_data.shape[2]

        if self.verbose:
            print("-" * 30)
            print(f"Image path: {img_path}")
            print(f"Mask path: {mask_path}")
            if data_type == "nib":
                print(f"Labels: {self.name_value_dict}")
            print("Load data successfully!")
            print("-" * 30)

        if return_vol:
            return self.img_vol_data, self.mask_vol_data

    def get_label_name_value(self):
        i = 0
        name_value_dict = {}
        while True:
            if f"Segment{i}_LabelValue" in self.mask_header:
                name_value_dict[self.mask_header[f"Segment{i}_Name"]] = int(self.mask_header[f"Segment{i}_LabelValue"])
                i = i + 1
            else:
                break
        return name_value_dict

    def get_label_value_by_name(self, label_name):
        return self.name_value_dict[label_name]

    def get_label_name_by_value(self, label_value):
        return [key for key, value in self.name_value_dict.items() if value == label_value][0]
        
    def preprocess_data_to_sam2_format(self, data_save_directory, img_save_directory, mask_save_directory, volume_name, img_vol_data=None, mask_vol_data=None, verbose=True):
        os.makedirs(os.path.join(img_save_directory, volume_name), exist_ok=True)
        os.makedirs(os.path.join(mask_save_directory, volume_name), exist_ok=True)

        if img_vol_data is not None:
            self.img_vol_data = img_vol_data
            self.mask_vol_data = mask_vol_data

        self.data_save_directory = data_save_directory
        self.img_save_directory = img_save_directory
        self.mask_save_directory = mask_save_directory
        self.txt_path = os.path.join(data_save_directory, "train.txt")
        
        for i in range(self.length):
            img_ = Image.fromarray(self.img_vol_data[i].astype("uint8"))
            mask_ = Image.fromarray(self.mask_vol_data[i].astype("uint8"))

            img_.save(os.path.join(img_save_directory, volume_name, f"{i:05d}.jpg"))
            mask_.save(os.path.join(mask_save_directory, volume_name, f"{i:05d}.png"))

        with open(self.txt_path, "w") as f:
            f.write(volume_name)
            
        if verbose:
            print("-" * 30)
            print(f"Preprocess {i} slices into SAM 2 format.")
            print(f"Img save directory: {img_save_directory}")
            print(f"Mask save directory: {mask_save_directory}")
            print("-" * 30)

    def update_yaml(self, verbose=True):
        # with open(self.model_cfg, "r") as file:
        #     self.config = yaml.safe_load(file)

        ## Update more parameters if wanted
        # self.config["dataset"]["img_folder"] = self.img_save_directory
        # self.config["dataset"]["gt_folder"] = self.mask_save_directory
        # self.config["dataset"]["file_list_txt"] = self.txt_path

        self.config = OmegaConf.load(self.model_cfg)

        self.config.dataset.img_folder = self.img_save_directory
        self.config.dataset.gt_folder = self.mask_save_directory
        self.config.dataset.file_list_txt = self.txt_path

        if verbose:
            print("-" * 30)
            print(f"Update image folder: {self.img_save_directory}.")
            print(f"Update mask folder: {self.mask_save_directory}.")
            print(f"Update txt path: {self.txt_path}.")
            print("-" * 30)

        # with open(self.model_cfg, 'w') as file:
        #     yaml.safe_dump(self.config, file)

        OmegaConf.save(self.config, self.model_cfg)

    def flatten_config(cfg, target_key="scratch"):
        """Finds and extracts the subdict that contains `target_key` at any depth."""
        if target_key in cfg:
            return cfg[target_key]
        
        for key, value in cfg.items():
            if isinstance(value, dict) and target_key in value:
                return value[target_key]  # Extract directly
    
        return cfg  # Return original if key not found

    def train(
        self, 
        data_save_directory, 
        img_save_directory, 
        mask_save_directory, 
        volume_name, 
        img_vol_data=None, 
        mask_vol_data=None,
        config=None,
        use_cluster=0,
        partition=None,
        account=None,
        qos=None,
        num_gpus=1,
        num_nodes=None
    ):
        self.preprocess_data_to_sam2_format(
            data_save_directory=data_save_directory,
            img_save_directory=img_save_directory,
            mask_save_directory=mask_save_directory,
            volume_name=volume_name,
            img_vol_data=img_vol_data,
            mask_vol_data=mask_vol_data
        )

        self.update_yaml()

        if config is None:
            config = self.model_cfg_search_path

        if GlobalHydra.instance().is_initialized():
            GlobalHydra.instance().clear()
        initialize_config_module(config_module="sam2", version_base="1.2")
        sys.argv = ["notebook"]

        # Ensure config is provided
        if config is not None:
            sys.argv.extend(["--config", str(config)])
        
        if use_cluster is not None:
            sys.argv.extend(["--use-cluster", str(use_cluster)])
        
        if partition is not None:
            sys.argv.extend(["--partition", str(partition)])
        
        if account is not None:
            sys.argv.extend(["--account", str(account)])
        
        if qos is not None:
            sys.argv.extend(["--qos", str(qos)])
        
        if num_gpus is not None:
            sys.argv.extend(["--num-gpus", str(num_gpus)])
        
        if num_nodes is not None:
            sys.argv.extend(["--num-nodes", str(num_nodes)])

        parser = ArgumentParser()
        parser.add_argument(
            "-c",
            "--config",
            required=True,
            type=str,
            help="path to config file (e.g. configs/sam2.1_training/sam2.1_hiera_b+_MOSE_finetune.yaml)",
        )
        parser.add_argument(
            "--use-cluster",
            type=int,
            default=None,
            help="whether to launch on a cluster, 0: run locally, 1: run on a cluster",
        )
        parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
        parser.add_argument("--account", type=str, default=None, help="SLURM account")
        parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
        parser.add_argument(
            "--num-gpus", type=int, default=None, help="number of GPUS per node"
        )
        parser.add_argument("--num-nodes", type=int, default=None, help="Number of nodes")

        try:
            self.args, self.unknown = parser.parse_known_args()
        except SystemExit as e:
            print(f"Argument parsing failed: {e}")
        self.args.use_cluster = bool(self.args.use_cluster) if self.args.use_cluster is not None else None
        register_omegaconf_resolvers()
        main(self.args)

## Test

In [3]:
tool = SAM2AnnotationTool()

Welcome to SAM 2 Annotation Tool


In [4]:
# tool.register_model()
img_vol_data, mask_vol_data = tool.load_data(
    img_path="../breast_test_data/img.npy",
    mask_path="../breast_test_data/mask.npy"
)

# tool.preprocess_data_to_sam2_format(
#     data_save_directory="./breast_test_data/",
#     img_save_directory="./breast_test_data/images/",
#     mask_save_directory="./breast_test_data/masks/",
#     volume_name="volume1"
# )

# tool.update_yaml()

------------------------------
Image path: ../breast_test_data/img.npy
Mask path: ../breast_test_data/mask.npy
Load data successfully!
------------------------------


In [5]:
tool.train(
    data_save_directory="../breast_test_data/",
    img_save_directory="../breast_test_data/images/",
    mask_save_directory="../breast_test_data/masks/",
    volume_name="volume1",
    img_vol_data=img_vol_data,
    mask_vol_data=mask_vol_data,
    config=None,
    use_cluster=0,
    partition=None,
    account=None,
    qos=None,
    num_gpus=1,
    num_nodes=None
)

------------------------------
Preprocess 207 slices into SAM 2 format.
Img save directory: ../breast_test_data/images/
Mask save directory: ../breast_test_data/masks/
------------------------------
------------------------------
Update image folder: ../breast_test_data/images/.
Update mask folder: ../breast_test_data/masks/.
Update txt path: ../breast_test_data/train.txt.
------------------------------


MissingConfigException: Cannot find primary config 'configs/sam2.1_training/lstm_sam2.1_hiera_t.yaml'. Check that it's in your config search path.

Config search path:
	provider=hydra, path=pkg://hydra.conf
	provider=main, path=pkg://sam2
	provider=schema, path=structured://