In [None]:
#%%     **************** Training and validation *****************

# This script is designed to run the SAM2.1 training and validation process.
import os
import subprocess
import sys

import yaml

# Directory to run from: the same one you use in the terminal (presumably the
# SAM2 repository root, since both ``training/train.py`` and the ``configs/``
# path below are resolved relative to the working directory). Fill this in,
# or leave it empty to stay in the current working directory.
# NOTE: ``os.chdir("")`` always raises FileNotFoundError, so an empty
# placeholder must not be passed to it.
WORK_DIR = ""
if WORK_DIR:
    os.chdir(WORK_DIR)

# Training config path, relative to the working directory above.
yaml_file = "configs/sam2.1_training/sam_train_val_json_win.yaml"

def load_yaml_config(yaml_path, base_dir="insert_path"):
    """Load and parse a YAML configuration file.

    Args:
        yaml_path: Path of the YAML file, relative to ``base_dir``. If it is
            absolute, ``base_dir`` is ignored (``os.path.join`` semantics).
        base_dir: Directory a relative ``yaml_path`` is resolved against.
            The default ``"insert_path"`` is a placeholder: replace it (or
            pass a value) with the directory the script runs training from,
            otherwise the file loaded here may differ from the one
            ``train.py`` reads.

    Returns:
        The document parsed by ``yaml.safe_load`` (``None`` for an empty
        file), or ``None`` if the file could not be read or parsed. In that
        case the error is printed, not raised.
    """
    abs_path = os.path.join(base_dir, yaml_path)
    try:
        # Explicit UTF-8: without it, open() uses the locale encoding
        # (e.g. cp1252 on Windows), which can mis-decode config files.
        with open(abs_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    # Only I/O and YAML syntax errors are expected here; anything else is a
    # programming error and should propagate. The two clauses are kept
    # separate so a missing file never needs to touch the yaml module.
    except OSError as e:
        print(f"Error loading YAML: {e}")
    except yaml.YAMLError as e:
        print(f"Error loading YAML: {e}")
    return None

def run_training(config_path, *, num_gpus=1, use_cluster=0):
    """Run ``training/train.py`` in a subprocess, streaming its output.

    The child is started with the interpreter running this script
    (``sys.executable``) rather than whatever ``python`` is first on PATH, so
    it sees the same installed packages. It is started with ``-u`` so its
    output is unbuffered: when stdout is a pipe, Python block-buffers it and
    log lines would otherwise arrive in large delayed chunks.

    Args:
        config_path: Config path passed to ``train.py`` via ``-c``, resolved
            by the child relative to the current working directory.
        num_gpus: Value passed as ``--num-gpus`` (default 1).
        use_cluster: Value passed as ``--use-cluster`` (default 0).

    Raises:
        subprocess.CalledProcessError: if the training process exits with a
            non-zero status.
    """
    print(f"Launching training with config: {config_path}")
    cmd = [
        sys.executable, "-u", "training/train.py",
        "-c", config_path,
        "--use-cluster", str(use_cluster),
        "--num-gpus", str(num_gpus),
    ]
    # stderr is merged into stdout so both streams are shown in order.
    with subprocess.Popen(
        cmd,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        universal_newlines=True,
    ) as process:
        try:
            for line in process.stdout:
                print(line, end="")  # line already ends with a newline
        except BaseException:
            # E.g. a notebook/kernel interrupt: do not leave the training
            # process running in the background. Popen.__exit__ reaps it.
            process.kill()
            raise
        returncode = process.wait()

    if returncode != 0:
        # The child's own output has already been streamed above.
        print(f"Error during training: process exited with code {returncode}")
        raise subprocess.CalledProcessError(returncode, cmd)

if __name__ == "__main__":
    config_path = yaml_file
    # NOTE(review): load_yaml_config resolves the path against its base
    # directory, while train.py resolves it against the current working
    # directory. Keep the two pointing at the same place.
    config = load_yaml_config(config_path)
    if config is None:
        # None means the file could not be read/parsed (the reason was
        # printed) or is empty. Either way, don't start a training run.
        raise SystemExit(
            f"Aborting: could not load configuration from {config_path!r}"
        )
    print("Loaded Configuration:", config)
    run_training(config_path)