In [0]:
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

In [0]:
%pip install git+https://github.com/facebookresearch/sam2
# %pip (rather than `%sh pip`) installs into the Python environment of this notebook's kernel.
# facebookresearch/sam2 is the repository's current name (the old segment-anything-2 URL redirects),
# matching the checkpoint-script URL used in the next cell.
# TODO: pin to a specific commit (git+https://...sam2@<sha>) so reruns are reproducible.

In [0]:
%%bash
# Fetch SAM2's download_ckpts.sh into ./checkpoints and run it, once.
# Skip only when actual checkpoint files are present: a merely non-empty directory check is also
# satisfied by a leftover download_ckpts.sh from a failed attempt, which would block every retry.
compgen -G "checkpoints/*.pt" > /dev/null && exit 0

mkdir -p checkpoints
cd checkpoints || exit 1

# -f: fail on HTTP errors instead of saving the error page as the script; -L: follow redirects.
if ! curl -fsSL -O https://raw.githubusercontent.com/facebookresearch/sam2/main/checkpoints/download_ckpts.sh; then
    echo "WARNING: could not fetch download_ckpts.sh; no checkpoints were downloaded." >&2
    exit 0
fi

# Best-effort as before (the cell does not fail), but report the failure instead of hiding it.
bash download_ckpts.sh || echo "WARNING: download_ckpts.sh failed; checkpoints/ may be incomplete." >&2

In [0]:
import os
import mlflow
from argparse import ArgumentParser
from hydra import initialize_config_dir
from hydra.core.global_hydra import GlobalHydra
from sam2.build_sam import build_sam2
from training.train import main as train_main
from training.utils.train_utils import register_omegaconf_resolvers
from dataset_4channels import SnuplassDataset
from inference import vos_inference
from image_predictor import SAM2ImagePredictor

In [0]:
# Hydra keeps a single global initialization per process; clear it so this cell can be re-run
# (initialize_config_dir below would otherwise refuse to initialize a second time).
GlobalHydra.instance().clear()

# Root of the Snuplass dataset; its img/, dom/ and lab/ sub-folders are passed to SnuplassDataset.
SNUPLASS_DATA_ROOT = "/Volumes/land_topografisk-gdb_dev/external_dev/static_data/DL_SNUPLASSER"


if __name__ == "__main__":
    # NOTE(review): mlflow.pytorch.autolog() hooks PyTorch Lightning's Trainer.fit; the trainer
    # started by train_main may log nothing through it -- verify metrics actually reach MLflow.
    mlflow.pytorch.autolog()

    # Use the current working directory (absolute) as Hydra's config search directory.
    initialize_config_dir(config_dir=os.getcwd(), version_base="1.3")

    parser = ArgumentParser()
    parser.add_argument(
        "-c", "--config", default="./configs/sam2.1_training/sam2.1_4channels_finetune.yaml", type=str,
        help="path to config file",
    )
    parser.add_argument(
        "--use-cluster", type=int, default=0,
        help="0: run locally, 1: run on a cluster",
    )
    parser.add_argument("--partition", type=str, default=None, help="SLURM partition")
    parser.add_argument("--account", type=str, default=None, help="SLURM account")
    parser.add_argument("--qos", type=str, default=None, help="SLURM qos")
    parser.add_argument("--num-gpus", type=int, default=1, help="number of GPUS per node")
    parser.add_argument("--num-nodes", type=int, default=1, help="Number of nodes")
    # parse_known_args ignores extra argv entries (e.g. ones injected by the notebook kernel)
    # instead of erroring out.
    args, unknown = parser.parse_known_args()

    # Register resolvers for omegaconf
    register_omegaconf_resolvers()

    with mlflow.start_run():
        train_main(args)

        # Derive the checkpoint path from --config instead of hard-coding the default config's
        # path, so a different -c value is not silently paired with another run's checkpoint.
        # Layout: sam2_logs/<config path>/checkpoints/checkpoint.pt (the layout the previous
        # hard-coded path used) -- TODO confirm the training config does not override the log dir.
        model_ckpt = os.path.normpath(
            os.path.join("sam2_logs", args.config, "checkpoints", "checkpoint.pt")
        )

        if os.path.isfile(model_ckpt):
            model_cfg = "./configs/sam2.1/sam2.1_hiera_t_4channels.yaml"
            sam2_model = build_sam2(model_cfg, model_ckpt, device="cuda")
            predictor = SAM2ImagePredictor(sam2_model)

            # Trailing slashes are kept exactly as before in case the dataset concatenates paths.
            image_dir = f"{SNUPLASS_DATA_ROOT}/img/"
            dom_dir = f"{SNUPLASS_DATA_ROOT}/dom/"
            mask_dir = f"{SNUPLASS_DATA_ROOT}/lab/"
            dataset = SnuplassDataset(image_dir, mask_dir, dom_dir, "val")
            # NOTE(review): the meaning of the trailing 4 is defined by vos_inference -- confirm
            # and replace with a named constant.
            vos_inference(predictor, dataset, 4)
        elif args.use_cluster:
            # A cluster submission may return before training finishes, so the checkpoint
            # need not exist yet; don't fail the run because of that.
            print(f"Training was sent to a cluster; run inference once {model_ckpt} exists.")
        else:
            raise FileNotFoundError(
                f"No checkpoint found at {model_ckpt} after local training."
            )