In [None]:
import os
import sys
import re
import yaml

# Dataset-size label -> float written as the dataset's value under
# `dataset_mixer` by generate_config(). The label is also embedded in the
# generated config file name and in the model name.
# Presets used previously, retained for reference only (not active):
#   "0_5k": 0.05
#   "1k": 0.1
#   "5k": 0.5
#   "10k": 1.0
train_dataset_size = {"1k": 1.0}

# Trajectory type -> value written as `dataset_message_key` in the generated
# config. The trajectory type also appears in the output file name and the
# model name, and the key order below is the order configs are generated in.
# Every active trajectory type currently reads the same "messages" field.
# Earlier per-type field names, retained for reference only (not active):
#   "optimal": "messages_optimal"
#   "search": "messages_sos"
#   "search-react": "messages_sos_react"
#   "deepseek_r1_distill_llama_70b": "messages_deepseek_r1_distill_llama_70b"
#   "deepseek": "messages_deepseek"
messages_field_keys = dict.fromkeys(
    ("deepseek", "sos_react", "sos", "optimal"),
    "messages",
)

def _split_chat_template(text: str):
    """Detach a top-level ``chat_template: |`` literal block from ``text``.

    The block is the ``chat_template:`` header line plus every following line
    that is indented or blank, i.e. it stops at the next line starting in
    column 0. Returns ``(text_without_block, block)``; ``block`` is ``""`` and
    ``text`` is returned unchanged when no such block exists.
    """
    match = re.search(
        r"^chat_template:[ \t]*\|.*(?:\n(?:[ \t].*|[ \t]*$))*",
        text,
        flags=re.MULTILINE,
    )
    if match is None:
        return text, ""
    return text[:match.start()] + text[match.end():], match.group(0)


def _dump_yaml_mapping(mapping) -> str:
    """Serialize ``mapping`` as block-style YAML in insertion order.

    An empty mapping yields ``""`` instead of PyYAML's ``{}``, which would be
    invalid once further block-style keys are appended to the same document.
    """
    if not mapping:
        return ""
    return yaml.dump(mapping, default_flow_style=False, sort_keys=False)


def generate_config(base_config_path: str, dataset_size: str, traj_type: str, dataset_name: str = "yeok/stream-of-search-dataset_deepseek") -> None:
    """Write a per-run training config derived from a base YAML config.

    The base file is read, the dataset/model specific fields are overridden,
    and the result is written to ``config_{traj_type}_{dataset_size}.yaml`` in
    the directory two levels above ``base_config_path`` (the current directory
    if the path has no such ancestor). The path written is printed.

    If the base file contains the line
    ``# These fields will be modified by the script`` it is treated as three
    sections: everything before the marker is re-emitted as parsed; the section
    between the marker and ``# SFT trainer config`` is replaced by
    ``dataset_mixer``, ``dataset_splits``, ``preprocessing_num_workers`` and
    ``dataset_message_key`` (other keys in that section are not carried over);
    the trainer section is re-emitted with ``hub_model_id``, ``output_dir`` and
    ``logging_dir`` overridden. Without the marker the whole file is parsed as
    one mapping and the same keys are overridden in place.

    A top-level ``chat_template: |`` literal block is copied through verbatim
    rather than round-tripped through the YAML dumper, so its formatting is
    preserved. YAML comments other than the two section markers are lost.

    Args:
        base_config_path: Path of the base YAML config to read.
        dataset_size: Key into the module-level ``train_dataset_size``.
        traj_type: Key into the module-level ``messages_field_keys``.
        dataset_name: Dataset identifier used as the ``dataset_mixer`` key.

    Raises:
        OSError: If the base file cannot be read or the output cannot be written.
        KeyError: If ``dataset_size`` or ``traj_type`` is not a known key.
        yaml.YAMLError: If a section of the base file is not valid YAML.
    """
    modifiable_marker = "# These fields will be modified by the script"
    trainer_marker = "# SFT trainer config"

    # Explicit encoding: the raw chat template is copied byte-for-byte and may
    # contain non-ASCII characters that a locale default could mangle.
    with open(base_config_path, "r", encoding="utf-8") as f:
        content = f.read()

    # Resolve everything that depends on the arguments before any file is
    # created, so an unknown label cannot leave a partial output behind.
    mixer = {dataset_name: float(train_dataset_size[dataset_size])}
    message_key = messages_field_keys[traj_type]
    model_name = f"qwen-2.5-1.5B-instruct-sft-lora-countdown-{traj_type}-{dataset_size}"
    model_paths = {
        "hub_model_id": model_name,
        "output_dir": f"./models/{model_name}",
        "logging_dir": f"./logs/{model_name}",
    }

    head, marker_found, tail = content.partition(modifiable_marker)
    pieces = []

    if marker_found:
        # Pull the raw template out of whichever side of the marker holds it,
        # so it is never both dumped as a parsed key and appended verbatim.
        head, template_before = _split_chat_template(head)
        template_after = ""
        if not template_before:
            tail, template_after = _split_chat_template(tail)

        fixed_fields = yaml.safe_load(head) or {}
        modifiable_text, _, trainer_text = tail.partition(trainer_marker)
        previous_modifiable = yaml.safe_load(modifiable_text) or {}
        trainer_fields = dict(yaml.safe_load(trainer_text) or {})
        trainer_fields.update(model_paths)

        modifiable_fields = {
            "dataset_mixer": mixer,
            "dataset_splits": previous_modifiable.get("dataset_splits", ["train", "test"]),
            "preprocessing_num_workers": previous_modifiable.get("preprocessing_num_workers", 12),
            "dataset_message_key": message_key,
        }

        pieces.append(_dump_yaml_mapping(fixed_fields))
        if template_before:
            pieces.append(template_before + "\n\n")
        pieces.append(modifiable_marker + "\n")
        pieces.append(_dump_yaml_mapping(modifiable_fields))
        pieces.append("\n" + trainer_marker + "\n")
        pieces.append(_dump_yaml_mapping(trainer_fields))
        if template_after:
            pieces.append("\n" + template_after)
    else:
        # No section marker: treat the file as one flat mapping.
        config = yaml.safe_load(content) or {}
        config["dataset_mixer"] = mixer
        config["hub_model_id"] = model_paths["hub_model_id"]
        config["output_dir"] = model_paths["output_dir"]
        config["logging_dir"] = model_paths["logging_dir"]
        config["dataset_message_key"] = message_key

        _, template = _split_chat_template(content)
        if template:
            # Only drop the parsed key when the verbatim block replaces it;
            # otherwise the template would silently disappear from the output.
            config.pop("chat_template", None)

        pieces.append(_dump_yaml_mapping(config))
        if template:
            pieces.append("\n" + template)

    # dirname() twice can yield "" for short relative paths; makedirs("")
    # raises, so fall back to the current directory.
    output_dir = os.path.dirname(os.path.dirname(base_config_path)) or "."
    os.makedirs(output_dir, exist_ok=True)
    output_file = f"{output_dir}/config_{traj_type}_{dataset_size}.yaml"

    with open(output_file, "w", encoding="utf-8") as f:
        f.write("".join(pieces))

    print(f"Generated config file: {output_file}")

# Generate one config per (dataset size, trajectory type) pair, all derived
# from the same base config file.
for dataset_size in train_dataset_size:
    for traj_type in messages_field_keys:
        generate_config(
            "./recipes/qwen-2.5/sft/base_configs/config_lora.yaml",
            dataset_size,
            traj_type,
        )

Generated config file: ./recipes/qwen-2.5/sft/config_deepseek_1k.yaml
Generated config file: ./recipes/qwen-2.5/sft/config_sos_react_1k.yaml
Generated config file: ./recipes/qwen-2.5/sft/config_sos_1k.yaml
Generated config file: ./recipes/qwen-2.5/sft/config_optimal_1k.yaml


: 