In [1]:
import pickle
from pathlib import Path
from typing import List, Union, Optional, Any
import json

def compare_model_args(
    model_paths: List[Union[str, Path]],
    keys_to_compare: Optional[List[str]] = None,
    keys_to_ignore: Optional[List[str]] = None,
) -> None:
    """
    Load and compare arguments across multiple model checkpoints.
    
    Args:
        model_paths: List of paths to model directories (each should contain 'args' file)
        keys_to_compare: Optional list of specific keys to compare. If None, shows all keys.
        keys_to_ignore: Optional list of keys to skip. Default ignores common noisy keys.
    """
    # Default keys to ignore
    if keys_to_ignore is None:
        keys_to_ignore = ['wandb', 'participant_suffixes', 'interctc', 
                         'manifest_paths', 'milestones', 'outputDir']
    
    # Load args from each model
    model_args = []
    valid_paths = []
    
    for model_path in model_paths:
        path = Path(model_path)
        args_file = path / "args"
        
        if not args_file.exists():
            print(f"Warning: {args_file} does not exist, skipping.")
            continue
            
        with open(args_file, "rb") as f:
            args = pickle.load(f)
            model_args.append(args)
            valid_paths.append(path.name)
    
    if not model_args:
        print("No valid model args found.")
        return
    
    # Collect all keys if not specified
    if keys_to_compare is None:
        all_keys = set()
        for args in model_args:
            # Handle both dict and object with __dict__
            if isinstance(args, dict):
                all_keys.update(args.keys())
            else:
                all_keys.update(vars(args).keys())
        # Filter out ignored keys
        all_keys = all_keys - set(keys_to_ignore)
        keys_to_compare = sorted(all_keys)
    
    def format_transformer_specs(transformer_dict: dict) -> str:
        """Format transformer specs in a clean, readable way."""
        lines = []
        for key, value in sorted(transformer_dict.items()):
            if isinstance(value, dict):
                # Nested dict (like chunked_attention)
                lines.append(f"{key}:")
                for k, v in sorted(value.items()):
                    if isinstance(v, dict):
                        lines.append(f"  {k}:")
                        for kk, vv in sorted(v.items()):
                            lines.append(f"    {kk}: {vv}")
                    else:
                        lines.append(f"  {k}: {v}")
            elif isinstance(value, float):
                if abs(value) < 0.001 or abs(value) > 10000:
                    lines.append(f"{key}: {value:.2e}")
                else:
                    lines.append(f"{key}: {value:.4f}")
            else:
                lines.append(f"{key}: {value}")
        return "\n  " + "\n  ".join(lines)
    
    def format_value(value: Any, key: str = "", max_width: int = 50) -> str:
        """Format value for display with pretty printing for nested structures."""
        # Handle numeric types with appropriate precision
        if isinstance(value, float):
            if abs(value) < 0.001 or abs(value) > 10000:
                return f"{value:.2e}"
            else:
                return f"{value:.4f}"
        
        # Special handling for model dict - filter out GRU, format transformer nicely
        if key == "model" and isinstance(value, dict):
            filtered_model = {k: v for k, v in value.items() if k != 'gru'}
            if 'transformer' in filtered_model:
                lines = []
                for model_key, model_val in filtered_model.items():
                    if model_key == 'transformer':
                        lines.append(f"transformer:{format_transformer_specs(model_val)}")
                    else:
                        lines.append(f"{model_key}: {model_val}")
                return "\n  " + "\n  ".join(lines)
            return json.dumps(filtered_model, indent=2).replace("\n", "\n  ")
        
        # Pretty print dicts on multiple lines if they're nested
        if isinstance(value, dict):
            # For model specs and nested dicts, format nicely
            if len(str(value)) > max_width:
                formatted = json.dumps(value, indent=2)
                return "\n  " + formatted.replace("\n", "\n  ")
            return str(value)
        
        # Convert to string and truncate if needed
        value_str = str(value)
        if len(value_str) > max_width:
            return value_str[:max_width-3] + "..."
        return value_str
    
    # Print comparison
    print(f"\n{'='*100}")
    print(f"Comparing {len(valid_paths)} models:")
    for i, name in enumerate(valid_paths, 1):
        print(f"  [{i}] {name}")
    print(f"{'='*100}\n")
    
    # Group related keys
    masking_keys = ['max_mask_pct', 'num_masks']
    
    # Find differences and group masking params
    differences = []
    same = []
    masking_values = {k: [] for k in masking_keys}
    
    for key in keys_to_compare:
        values = []
        for args in model_args:
            # Handle both dict and object access
            if isinstance(args, dict):
                value = args.get(key, "N/A")
            else:
                value = getattr(args, key, "N/A")
            values.append(value)
        
        # Track masking params separately
        if key in masking_keys:
            masking_values[key] = values
            continue
        
        # Check if all values are the same
        if len(set(str(v) for v in values)) == 1:
            same.append(key)
        else:
            differences.append((key, values))
    
    # Add masking params as a group if there are differences
    masking_differs = any(
        len(set(str(v) for v in masking_values[k])) > 1 
        for k in masking_keys if masking_values[k]
    )
    if masking_differs:
        differences.insert(0, ('masking_params', masking_values))
    
    # Print differences first (most important)
    if differences:
        print(f"\n{'DIFFERENCES':-^100}\n")
        for key, values in differences:
            print(f"\n{key}:")
            print("-" * 100)
            
            # Special handling for grouped masking params
            if key == 'masking_params':
                for idx, name in enumerate(valid_paths):
                    mask_info = ", ".join(f"{k}={masking_values[k][idx]}" for k in masking_keys if masking_values[k])
                    print(f"  [{idx+1}] {name:<40} = {mask_info}")
                continue
            
            for i, (name, value) in enumerate(zip(valid_paths, values), 1):
                formatted_val = format_value(value, key=key)
                if "\n" in formatted_val:
                    print(f"  [{i}] {name}:{formatted_val}")
                else:
                    print(f"  [{i}] {name:<40} = {formatted_val}")
    
    # Print same values (less important, collapsed)
    if same:
        print(f"\n\n{'SAME VALUES (collapsed)':-^100}\n")
        for key in same:
            value = model_args[0].get(key) if isinstance(model_args[0], dict) else getattr(model_args[0], key)
            formatted = format_value(value, key=key, max_width=80)
            if "\n" in formatted:
                print(f"\n{key}:{formatted}")
            else:
                print(f"{key:<40} = {formatted}")

In [4]:
compare_model_args([
    "/data2/brain2text/b2t_combined/outputs/baseline_hpo_b2t_25_trial_38", 
    "/data2/brain2text/b2t_combined/outputs/fully_chunked_25_trial_38_seed_0"
], keys_to_compare=None)



Comparing 2 models:
  [1] baseline_hpo_b2t_25_trial_38
  [2] fully_chunked_25_trial_38_seed_0


--------------------------------------------DIFFERENCES---------------------------------------------


device:
----------------------------------------------------------------------------------------------------
  [1] baseline_hpo_b2t_25_trial_38             = cuda:1
  [2] fully_chunked_25_trial_38_seed_0         = cuda:0

early_stopping_enabled:
----------------------------------------------------------------------------------------------------
  [1] baseline_hpo_b2t_25_trial_38             = True
  [2] fully_chunked_25_trial_38_seed_0         = False

l2_decay:
----------------------------------------------------------------------------------------------------
  [1] baseline_hpo_b2t_25_trial_38             = 1.91e-05
  [2] fully_chunked_25_trial_38_seed_0         = 1.91e-05

model:
----------------------------------------------------------------------------------------------------
  [1] b