In [2]:
import os
import yaml
import pandas as pd
import numpy as np
from scipy import stats
import glob
import math

def get_percentiles(param_info):
    """Return the (5th, 95th) percentile pair for a parameter specification.

    Args:
        param_info: Mapping describing one parameter. It either holds a
            ``'values'`` sequence, or a ``'distribution'`` name together with
            that distribution's settings (``mu``/``sigma`` or ``min``/``max``).
            The distribution name is matched by substring, so prefixed names
            such as ``q_normal`` are handled like their base distribution.

    Returns:
        Tuple ``(p05, p95)``. For ``'values'`` this is ``(min, max)`` of the
        sequence. ``(None, None)`` is returned when the distribution is not
        recognised, when ``'values'`` is empty, or when log-uniform bounds are
        not strictly positive.
    """
    # Explicit value lists: the "percentiles" are simply the extremes.
    if 'values' in param_info:
        values = param_info['values']
        if len(values) == 0:
            return None, None
        return min(values), max(values)

    dist = param_info.get('distribution', '').lower()

    if 'normal' in dist and 'log' not in dist:
        mu = param_info.get('mu', 0)
        sigma = param_info.get('sigma', 1)
        return (
            stats.norm.ppf(0.05, loc=mu, scale=sigma),
            stats.norm.ppf(0.95, loc=mu, scale=sigma),
        )

    if 'log_normal' in dist:
        # mu/sigma describe the underlying normal, hence scale=exp(mu).
        mu = param_info.get('mu', 0)
        sigma = param_info.get('sigma', 1)
        scale = math.exp(mu)
        return (
            stats.lognorm.ppf(0.05, s=sigma, scale=scale),
            stats.lognorm.ppf(0.95, s=sigma, scale=scale),
        )

    # This must be tested before the plain 'uniform' checks below, because
    # 'uniform' is a substring of 'log_uniform' and would otherwise win.
    if 'log_uniform' in dist:
        # NOTE(review): min/max are interpreted as bounds in value space (not
        # as exponents) - confirm this matches the convention of the configs.
        min_val = param_info.get('min', 0.01)
        max_val = param_info.get('max', 1)
        if min_val <= 0 or max_val <= 0:
            # The logarithm is undefined here, so no percentile can be given.
            return None, None

        # Interpolate linearly in log space, then map back.
        log_min = np.log(min_val)
        log_range = np.log(max_val) - log_min
        return np.exp(log_min + 0.05 * log_range), np.exp(log_min + 0.95 * log_range)

    if 'int_uniform' in dist:
        min_val = param_info.get('min', 0)
        max_val = param_info.get('max', 1)
        range_size = max_val - min_val + 1  # +1 because both integer ends are inclusive
        p05 = min_val + math.floor(0.05 * range_size)
        p95 = min_val + math.ceil(0.95 * range_size) - 1  # -1 compensates for rounding up
        return p05, p95

    if 'uniform' in dist:
        min_val = param_info.get('min', 0)
        max_val = param_info.get('max', 1)
        range_size = max_val - min_val
        return min_val + 0.05 * range_size, min_val + 0.95 * range_size

    # Unknown or missing distribution.
    return None, None

def get_distribution_params(param_info):
    """Render a parameter's distribution settings as a short display string.

    Args:
        param_info: Mapping describing one parameter. It either holds a
            ``'values'`` sequence, or a ``'distribution'`` name together with
            ``mu``/``sigma`` or ``min``/``max`` settings.

    Returns:
        ``str(values)`` for value lists, ``"mu = ..., sigma = ..."`` for
        normal and log-normal distributions, ``"[min, max]"`` for the uniform
        family, and ``""`` when the distribution is not recognised.
    """
    if 'values' in param_info:
        return f"{param_info['values']}"

    dist = param_info.get('distribution', '').lower()

    # 'normal' is a substring of 'log_normal'; both use the same notation.
    if 'normal' in dist:
        mu = param_info.get('mu', 0)
        sigma = param_info.get('sigma', 1)
        return f"mu = {mu}, sigma = {sigma}"

    # Checked before plain 'uniform' (a substring of 'log_uniform') so that the
    # log-uniform default lower bound of 0.01 is actually applied.
    if 'log_uniform' in dist:
        min_val = param_info.get('min', 0.01)
        max_val = param_info.get('max', 1)
        return f"[{min_val}, {max_val}]"

    # Covers both 'uniform' and 'int_uniform'.
    if 'uniform' in dist:
        min_val = param_info.get('min', 0)
        max_val = param_info.get('max', 1)
        return f"[{min_val}, {max_val}]"

    return ""

def clean_distribution_name(dist_name):
    """Map a raw distribution name to its display label.

    A leading ``'q_'`` is dropped before the lookup. Names without a known
    label are returned as-is (after the prefix has been dropped).
    """
    display_labels = {
        'normal': 'Normal',
        'log_normal': 'Log-normal',
        'uniform': 'Uniform',
        'int_uniform': 'Integer Uniform',
        'log_uniform': 'Log-uniform',
        'log_uniform_values': 'Log-uniform',
    }
    base_name = dist_name[2:] if dist_name.startswith('q_') else dist_name
    return display_labels.get(base_name, base_name)

def extract_quantization(param_info):
    """Return the value stored under ``'q'`` in param_info, or None if the key is absent."""
    return param_info['q'] if 'q' in param_info else None

def _format_percentiles(p05, p95):
    """Format a percentile pair as strings, using more decimals for small magnitudes.

    Non-numeric inputs (e.g. ``None``) are returned unchanged.
    """
    if not (isinstance(p05, (int, float)) and isinstance(p95, (int, float))):
        return p05, p95

    if abs(p05) < 0.01 or abs(p95) < 0.01:
        # Small numbers need more decimals to stay distinguishable.
        return f"{p05:.6f}", f"{p95:.6f}"
    if abs(p05) < 1 or abs(p95) < 1:
        return f"{p05:.4f}", f"{p95:.4f}"
    return f"{p05:.2f}", f"{p95:.2f}"


def process_yaml_files(yaml_data_list):
    """Build a table describing every parameter of every supplied config.

    Args:
        yaml_data_list: Mapping ``id -> (yaml_data, fixed_parameters)``.
            ``yaml_data`` is a parsed config whose ``'parameters'`` entry maps
            parameter names to specs containing ``'distribution'`` or
            ``'values'``; ``fixed_parameters`` maps parameter names to a
            single fixed value.

    Returns:
        A ``pandas.DataFrame`` sorted by ``'parameter'`` with the columns
        ``parameter, distribution, parameters, quantization, percentile_05,
        percentile_95, id``. The frame has these columns even when no rows
        were produced.

    Raises:
        ValueError: If a fixed parameter has the same name as a parameter
            already taken from the same config's ``'parameters'`` section.
    """
    columns = [
        'parameter',
        'distribution',
        'parameters',
        'quantization',
        'percentile_05',
        'percentile_95',
        'id',
    ]
    results = []

    for config_id, (yaml_data, fixed_parameters) in yaml_data_list.items():
        # Skip configs that are empty/None or have no 'parameters' section.
        if not yaml_data or 'parameters' not in yaml_data:
            continue

        # An empty 'parameters' section is treated as "no swept parameters".
        parameters = yaml_data['parameters'] or {}

        parameter_names = set()
        for param_name, param_info in parameters.items():
            # Only mapping entries with a distribution or a values list are reported.
            if not isinstance(param_info, dict):
                continue
            if 'distribution' not in param_info and 'values' not in param_info:
                continue

            # Value lists carry no percentiles in the table.
            if 'values' in param_info:
                p05, p95 = None, None
            else:
                p05, p95 = get_percentiles(param_info)
            p05_formatted, p95_formatted = _format_percentiles(p05, p95)

            results.append({
                'parameter': param_name,
                'distribution': clean_distribution_name(param_info.get('distribution', 'values')),
                'parameters': get_distribution_params(param_info),
                'quantization': extract_quantization(param_info),
                'percentile_05': p05_formatted,
                'percentile_95': p95_formatted,
                'id': config_id,
            })
            parameter_names.add(param_name)

        for parameter_name, value in fixed_parameters.items():
            if parameter_name in parameter_names:
                raise ValueError(f"Parameter name '{parameter_name}' is already present in the distribution list for {config_id}.")

            results.append({
                'parameter': parameter_name,
                'distribution': 'values',
                'parameters': [value],
                'quantization': None,
                'percentile_05': None,
                'percentile_95': None,
                'id': config_id,
            })

    # Passing the columns explicitly keeps the schema (and the sort below)
    # valid even when no rows were collected.
    df = pd.DataFrame(results, columns=columns)

    # Sort by parameter name for consistency
    return df.sort_values('parameter')

In [3]:
# Smoke-test process_yaml_files on a single config before processing every file.
# safe_load is the same loader the bulk-loading cell uses and is restricted to
# standard YAML tags, so both cells parse the config identically.
with open('exp/short/prototree.yaml') as f:
    data = yaml.safe_load(f)

# test it out
sample_df = process_yaml_files({'prototree': (data, {
    'num_add_on_layers': 1
})})
sample_df

Unnamed: 0,parameter,distribution,parameters,quantization,percentile_05,percentile_95,id
0,backbone_lr_multiplier,Log-normal,"mu = 0, sigma = 1.2",0.2,0.1389,7.1982,prototree
1,joint_phase_len_at_lr1,Normal,"mu = 70, sigma = 10",5.0,53.55,86.45,prototree
2,log_probabilities,values,[False],,,,prototree
3,lr_step_gamma,Normal,"mu = 0.5, sigma = 0.2",0.1,0.171,0.829,prototree
4,lr_weight_decay,Log-normal,"mu = -10, sigma = 1",5e-05,9e-06,0.000235,prototree
5,non_backbone_lr_multiplier,Log-normal,"mu = 0, sigma = 1.2",0.2,0.1389,7.1982,prototree
7,num_add_on_layers,values,[1],,,,prototree
6,warm_up_phase_len_at_lr1,Normal,"mu = 30, sigma = 5",5.0,21.78,38.22,prototree


In [4]:
# Collect every config matching exp/<group>/<name>.yaml
yaml_files = glob.glob('exp/*/*.yaml')

yaml_data_map = {}

# Parse each config, keyed by its path
for file_path in yaml_files:
    # sbatch files are not wanted, and accproto is identical to vanilla accuracy long
    is_excluded = 'sbatch' in file_path or file_path == 'exp/long/vanilla-accproto.yaml'
    if not is_excluded:
        try:
            with open(file_path, 'r') as file:
                yaml_data = yaml.safe_load(file)
                yaml_data_map[file_path] = yaml_data
        except Exception as e:
            # Best effort: report the unreadable file and carry on with the rest.
            print(f"Error reading {file_path}: {e}")

yaml_data_map.keys()

dict_keys(['exp/long/vanilla-accuracy.yaml', 'exp/short/deformable.yaml', 'exp/short/st-protopnet.yaml', 'exp/short/tesnet.yaml', 'exp/short/vanilla-accuracy.yaml', 'exp/short/prototree.yaml'])

In [5]:
# Single-valued parameters per config, keyed by config path.
# process_yaml_files reports each entry as a 'values' row for that config.
fixed_parameters = {
    'exp/long/vanilla-accuracy.yaml': {
        'joint_epochs_per_phase': 10,
        'last_only_epochs_per_phase': 20,
        'lr_step_gamma': 0.1,
    },
    'exp/short/deformable.yaml': {
        'post_project_phases': 32,
        'lr_step_gamma': 0.1,
    },
    'exp/short/st-protopnet.yaml': {
        'post_project_phases': 32,
        'ortho_p_norm': 1,
        'lr_step_gamma': 0.1,
    },
    'exp/short/tesnet.yaml': {
        'post_project_phases': 32,
        'lr_step_gamma': 0.1,
    },
    'exp/short/vanilla-accuracy.yaml': {
        'post_project_phases': 32,
        'latent_dim_multiplier_exp': -2,
        'lr_step_gamma': 0.1,
    },
    'exp/short/prototree.yaml': {
        'joint_epochs_before_lr_milestones': 30,
        'num_lr_milestones': 5,
        'adamw_weight_decay': 0.0,
        'adamw_eps': 1e-7,
        'depth': 9,
        'proto_channels': 256,
    },
}

In [6]:
# Pair each parsed config with its fixed parameters: path -> (config, fixed parameters).
# A path with no entry in fixed_parameters raises KeyError here.
yaml_data_map = {
    path: (config, fixed_parameters[path])
    for path, config in yaml_data_map.items()
}
(yaml_data_map.keys(), list(map(type, yaml_data_map.values())))

(dict_keys(['exp/long/vanilla-accuracy.yaml', 'exp/short/deformable.yaml', 'exp/short/st-protopnet.yaml', 'exp/short/tesnet.yaml', 'exp/short/vanilla-accuracy.yaml', 'exp/short/prototree.yaml']),
 [tuple, tuple, tuple, tuple, tuple, tuple])

In [7]:
# Build the combined table for all loaded configs: one row per parameter per config,
# sorted by parameter name.
result_df = process_yaml_files(yaml_data_map)
result_df

Unnamed: 0,parameter,distribution,parameters,quantization,percentile_05,percentile_95,id
79,adamw_eps,values,[1e-07],,,,exp/short/prototree.yaml
78,adamw_weight_decay,values,[0.0],,,,exp/short/prototree.yaml
69,backbone_lr_multiplier,Log-normal,"mu = 0, sigma = 1.2",0.20,0.1389,7.1982,exp/short/prototree.yaml
26,closeness_loss_coef,Normal,"mu = 1.0, sigma = 0.4",0.20,0.3421,1.6579,exp/short/st-protopnet.yaml
56,cluster_coef,Normal,"mu = -1.0, sigma = 0.4",0.20,-1.6579,-0.3421,exp/short/vanilla-accuracy.yaml
...,...,...,...,...,...,...,...
9,separation_coef,Normal,"mu = 0.08, sigma = 0.1",,-0.0845,0.2445,exp/long/vanilla-accuracy.yaml
65,separation_coef,Normal,"mu = 0.1, sigma = 0.04",0.02,0.0342,0.1658,exp/short/vanilla-accuracy.yaml
37,support_separation_coef,Normal,"mu = 0.5, sigma = 0.04",0.02,0.4342,0.5658,exp/short/st-protopnet.yaml
38,trivial_separation_coef,Normal,"mu = 0.1, sigma = 0.04",0.02,0.0342,0.1658,exp/short/st-protopnet.yaml


In [8]:
# Make sure the output directory exists so the export also works on a fresh checkout.
os.makedirs('analysis', exist_ok=True)
result_df.to_csv('analysis/parameter_ranges.csv', index=False)