In [1]:
import os
import sys
import json
import yaml
import numpy as np
import matplotlib.pyplot as plt
from copy import deepcopy
import logging

# Notebook-wide logger: timestamped INFO-level messages under the
# 'scaling_notebook' name.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('scaling_notebook')

# Make the project importable. Nothing is added when some sys.path entry
# already ends with 'dendritic_modeling'; otherwise the parent of the working
# directory is used when the notebook runs from a 'notebooks' folder, and the
# working directory itself in every other case.
current_dir = os.getcwd()
path_already_configured = any(
    entry.endswith('dendritic_modeling') for entry in sys.path
)
if not path_already_configured:
    launched_from_notebooks = os.path.basename(current_dir) == 'notebooks'
    project_root = os.path.dirname(current_dir) if launched_from_notebooks else current_dir
    sys.path.insert(0, project_root)
    logger.info(f"Added {project_root} to Python path")

# Import the training entry point; a failed import is logged and re-raised so
# the notebook stops here instead of failing later with a NameError.
try:
    from dendritic_modeling.train_experiments import main as train_main
except ImportError as e:
    logger.error(f"Failed to import: {e}")
    raise
else:
    logger.info("Successfully imported train_main from dendritic_modeling")

2025-02-25 10:33:42,112 - scaling_notebook - INFO - Added /n/holylabs/LABS/kempner_dev/Users/hsafaai/Code/dendritic-modeling to Python path
2025-02-25 10:34:07,045 - scaling_notebook - INFO - Successfully imported train_main from dendritic_modeling


In [2]:
def load_config_yaml(path):
    """Read the YAML file at ``path`` and return its parsed content.

    Any exception raised while opening or parsing the file is logged as an
    error and swallowed; ``None`` is returned in that case.
    """
    parsed = None
    try:
        with open(path, 'r') as stream:
            parsed = yaml.safe_load(stream)
    except Exception as err:
        logger.error(f"Failed to load config file {path}: {err}")
    return parsed

def save_config_yaml(cfg_dict, path):
    """Save a configuration dictionary to a YAML file.

    Args:
        cfg_dict: Configuration mapping to serialize.
        path: Destination file path. Missing parent directories are created.
            A bare filename (no directory component) is written to the
            current working directory.
    """
    parent_dir = os.path.dirname(path)
    # A bare filename has an empty dirname and os.makedirs('') raises
    # FileNotFoundError, so only create the parent when there is one.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(path, 'w') as f:
        # sort_keys=False keeps the key order of cfg_dict in the written file.
        yaml.dump(cfg_dict, f, sort_keys=False)

def parse_experiment_logs(exp_dir):
    """
    Read ``final_results.json`` inside ``exp_dir`` and extract its metrics.

    Returns the tuple ``(param_count, final_loss)`` taken from the keys of the
    same names. A key that is absent yields ``None`` in its slot (with a
    warning). ``(None, None)`` is returned when the file does not exist or
    cannot be read/parsed.
    """
    final_path = os.path.join(exp_dir, "final_results.json")

    if os.path.exists(final_path):
        try:
            with open(final_path, 'r') as handle:
                payload = json.load(handle)

            param_count, final_loss = payload.get("param_count"), payload.get("final_loss")

            # Report each metric that the file does not provide.
            if param_count is None:
                logger.warning(f"param_count not found in {final_path}")
            if final_loss is None:
                logger.warning(f"final_loss not found in {final_path}")
        except Exception as err:
            logger.error(f"Error parsing {final_path}: {err}")
            return None, None
        return param_count, final_loss

    logger.warning(f"final_results.json not found in {exp_dir}")
    return None, None

In [3]:
def run_single_experiment(config_base, output_dir, experiment_name):
    """
    Train one model from ``config_base`` and report its metrics.

    The configuration is written to ``<output_dir>/config.yaml``, ``train_main``
    is invoked with that file, and the metrics are then read back from
    ``output_dir`` via ``parse_experiment_logs``.

    Returns (param_count, final_loss). ``(None, None)`` is returned when
    training raises; either element may be None when the metrics could not be
    parsed.

    NOTE(review): the recorded notebook output shows train_main logging that it
    saves results under a separate timestamped directory, while metrics are
    parsed from ``output_dir`` - possibly why every run reported
    params=None/loss=None. Confirm where final_results.json is written.
    """
    os.makedirs(output_dir, exist_ok=True)
    config_path = os.path.join(output_dir, "config.yaml")
    save_config_yaml(config_base, config_path)

    logger.info(f"Running experiment: {experiment_name}")
    launch_kwargs = {
        "config_path": config_path,
        "output_dir": output_dir,
        "experiment_name": experiment_name,
    }
    try:
        train_main(**launch_kwargs)
    except Exception as err:
        logger.error(f"Training failed: {err}")
        return None, None

    # Training finished: pull param_count and final_loss from the run directory.
    logger.info(f"Parsing results for: {experiment_name}")
    metrics = parse_experiment_logs(output_dir)
    pcount, loss = metrics

    have_both = pcount is not None and loss is not None
    if have_both:
        logger.info(f"Successfully retrieved metrics: params={pcount}, loss={loss}")
    else:
        logger.warning(f"Failed to get param_count or final_loss for {experiment_name}")

    return pcount, loss

def json_serialize_fix(obj):
    """
    Convert numpy values into plain Python values that ``json`` can serialize.

    Handles numpy booleans, integers, floats and arrays, and recurses into
    dicts (keys and values), lists and tuples so nested numpy values are
    converted as well. Any other object is returned unchanged.

    Args:
        obj: Value to convert.

    Returns:
        A JSON-friendly equivalent of ``obj``. Containers are returned as new
        objects; the input is never modified.
    """
    # np.bool_ is not a subclass of np.integer, so it needs its own branch.
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, dict):
        # Keys are converted too: json rejects numpy integer keys.
        return {json_serialize_fix(k): json_serialize_fix(v) for k, v in obj.items()}
    if isinstance(obj, list):
        return [json_serialize_fix(v) for v in obj]
    if isinstance(obj, tuple):
        return tuple(json_serialize_fix(v) for v in obj)
    return obj

def save_results_json(results, output_dir, experiment_name_prefix):
    """
    Write ``results`` to ``<output_dir>/<experiment_name_prefix>_allresults.json``.

    ``results`` maps a network type to a list of data-point dicts; every value
    of every data point is passed through ``json_serialize_fix`` before
    dumping. The file is read back afterwards and a per-network summary is
    logged.

    Returns True on success. Any exception is logged and reported as False.
    """
    try:
        # Rebuild the structure with serializable values; the input is left untouched.
        serializable = {
            net_type: [
                {field: json_serialize_fix(raw) for field, raw in point.items()}
                for point in points
            ]
            for net_type, points in results.items()
        }

        out_json = os.path.join(output_dir, f"{experiment_name_prefix}_allresults.json")
        with open(out_json, 'w') as sink:
            json.dump(serializable, sink, indent=2)

        # Round-trip the file to confirm it contains valid JSON.
        with open(out_json, 'r') as source:
            reloaded = json.load(source)

        logger.info(f"Saved and verified results to {out_json}")
        logger.info(f"Saved data contains: {list(reloaded.keys())}")
        for net_type, points in reloaded.items():
            logger.info(f"  {net_type}: {len(points)} data points")
    except Exception as err:
        logger.error(f"Error saving results: {err}")
        return False
    else:
        return True

def _count_loaded_parameters(loaded):
    """Return the parameter count of a ``torch.load`` result, or None if it cannot be determined."""
    if hasattr(loaded, "parameters"):
        # Full module object.
        return sum(p.numel() for p in loaded.parameters())
    if isinstance(loaded, dict):
        # Presumably a state_dict-style mapping of tensors - count tensor-like values only.
        total = sum(v.numel() for v in loaded.values() if hasattr(v, "numel"))
        return total if total > 0 else None
    return None


def _recover_metrics_from_artifacts(exp_dir):
    """Best-effort fallback: derive (param_count, proxy_loss) from files saved in ``exp_dir``.

    ``param_count`` comes from ``best_model.pt`` and ``proxy_loss`` is
    ``1 - test accuracy`` read from ``performance/final.json``. Each element is
    None when its source is missing or unreadable; failures are logged, never raised.
    """
    manual_pcount = None
    proxy_loss = None

    model_path = os.path.join(exp_dir, "best_model.pt")
    if os.path.exists(model_path):
        try:
            logger.info(f"Attempting to load model to count parameters: {model_path}")
            import torch  # local import: torch is only needed on this fallback path
            # NOTE: torch.load unpickles the file; only use it on checkpoints
            # produced by these experiments, never on untrusted files.
            loaded = torch.load(model_path)
            manual_pcount = _count_loaded_parameters(loaded)
            if manual_pcount is None:
                logger.warning(f"Could not count parameters from object of type {type(loaded)}")
            else:
                logger.info(f"Manual parameter count: {manual_pcount}")
        except Exception as e:
            logger.error(f"Failed manual parameter counting: {e}")

    perf_file = os.path.join(exp_dir, "performance", "final.json")
    if os.path.exists(perf_file):
        try:
            with open(perf_file, 'r') as f:
                perf_data = json.load(f)
            test_acc = perf_data.get("test accuracy")
            if test_acc is not None:
                # Use 1-accuracy as a proxy for loss
                proxy_loss = 1.0 - float(test_acc)
                logger.info(f"Using 1-accuracy as proxy for loss: {proxy_loss}")
        except Exception as e:
            logger.error(f"Error reading performance file: {e}")

    return manual_pcount, proxy_loss


def run_scaling_experiment(base_config_path, output_base_dir, dim_values,
                           experiment_name_prefix="scaling_exp", epochs=20):
    """
    For each dim in dim_values, do 2 runs:
      1) EINet
      2) MLP
    Modify 'excitatory_branch_factors = [2, dim]' in config.

    Args:
        base_config_path: YAML file used as the template for every run.
        output_base_dir: Directory that receives one sub-directory per run
            plus the aggregated results JSON.
        dim_values: Iterable of branch-factor values to sweep.
        experiment_name_prefix: Prefix of run names and of the results file.
        epochs: Number of training epochs written into each run's config
            (default 20, a deliberately short run for this sweep).

    Returns:
        None when the base config cannot be loaded, otherwise a dictionary:
          results["EINet"] = list of { "dim":..., "param_count":..., "loss":... }
          results["MLP"]   = list of { "dim":..., "param_count":..., "loss":... }
        Runs whose metrics cannot be obtained are omitted. When the loss had to
        be approximated as 1 - test accuracy, the entry additionally carries
        "loss_is_proxy": True, because that value is not on the same scale as
        a training loss.
    """
    os.makedirs(output_base_dir, exist_ok=True)

    base_cfg = load_config_yaml(base_config_path)
    if not base_cfg:
        logger.error(f"Failed to load base config from {base_config_path}")
        return None

    results = {"EINet": [], "MLP": []}

    for dim in dim_values:
        for net_type in ["EINet", "MLP"]:
            # Work on a private copy so runs never leak settings into each other.
            cfg = deepcopy(base_cfg)

            # Set the network type
            cfg["model"]["network"]["type"] = net_type

            # Set excitatory_branch_factors = [2, dim]
            cfg["model"]["network"]["parameters"]["excitatory_branch_factors"] = [2, int(dim)]

            # Ensure we don't have mismatch in inhibitory factors
            cfg["model"]["network"]["parameters"]["inhibitory_branch_factors"] = []

            # Shortened training schedule for the sweep
            cfg["train"]["epochs"] = epochs

            # Build a unique experiment dir
            exp_name = f"{experiment_name_prefix}_{net_type}_dim{dim}"
            exp_dir = os.path.join(output_base_dir, exp_name)

            pcount, loss = run_single_experiment(cfg, exp_dir, exp_name)
            logger.info(f"Completed {net_type} dim={dim}: params={pcount}, loss={loss}")

            loss_is_proxy = False
            if pcount is None or loss is None:
                logger.warning(f"Could not get param_count/loss for {net_type} dim={dim}")
                # Fill in only what is missing, keeping any metric that was parsed.
                manual_pcount, proxy_loss = _recover_metrics_from_artifacts(exp_dir)
                if pcount is None:
                    pcount = manual_pcount
                if loss is None and proxy_loss is not None:
                    loss = proxy_loss
                    loss_is_proxy = True

            if (pcount is not None) and (loss is not None):
                entry = {"dim": dim, "param_count": pcount, "loss": loss}
                if loss_is_proxy:
                    entry["loss_is_proxy"] = True
                results[net_type].append(entry)
            else:
                logger.warning(f"No usable metrics for {net_type} dim={dim}; run omitted from results")

    # Save the results with enhanced JSON serialization
    save_success = save_results_json(results, output_base_dir, experiment_name_prefix)
    if not save_success:
        logger.warning("Results may not have been saved correctly")

    return results

In [4]:
def plot_loglog(results, output_dir, experiment_name_prefix):
    """
    Create a log-log plot of parameter count vs. loss for EINet and MLP.

    Args:
        results: Dict mapping "EINet"/"MLP" to lists of data-point dicts with
            "param_count", "loss" and (optionally) "dim". The dicts are only
            read; they are never modified.
        output_dir: Directory for the PNG; created if it does not exist.
        experiment_name_prefix: Used in the title and as the file-name prefix
            (``<prefix>_loglog.png``).

    Returns:
        The matplotlib Figure, or None when ``results`` is not a dict. The
        figure is saved and shown only if at least one valid point exists.

    Points are skipped (with a warning) when a value is missing, non-numeric,
    non-finite or not strictly positive, since such values cannot be drawn on
    logarithmic axes.
    """
    logger.info("Starting plot_loglog function")
    logger.info(f"Results type: {type(results)}")

    if not isinstance(results, dict):
        logger.error(f"Results is not a dictionary: {type(results)}")
        return None

    logger.info(f"Results keys: {results.keys()}")

    fig, ax = plt.subplots(figsize=(10, 8))

    ax.set_xscale("log")
    ax.set_yscale("log")

    has_data = False

    # We'll loop over the two network types
    for net_type, color, marker in [("EINet", "blue", "o"), ("MLP", "red", "s")]:
        if net_type not in results:
            logger.warning(f"Network type {net_type} not found in results")
            continue

        raw_points = results[net_type]
        logger.info(f"{net_type} data list: {raw_points}")
        if not isinstance(raw_points, (list, tuple)):
            logger.warning(f"{net_type} results are not a list: {type(raw_points)}")
            continue

        # Collect (param_count, loss, dim) tuples instead of writing converted
        # values back into the caller's dictionaries.
        points = []
        for d in raw_points:
            if not isinstance(d, dict):
                logger.warning(f"Data point is not a dictionary: {d}")
                continue

            if d.get("param_count") is None or d.get("loss") is None:
                logger.warning(f"Missing param_count or loss in data point: {d}")
                continue

            try:
                pc = float(d["param_count"])
                loss = float(d["loss"])
            except (ValueError, TypeError) as e:
                logger.warning(f"Non-numeric values in data point: {d}, Error: {e}")
                continue

            if not (np.isfinite(pc) and np.isfinite(loss) and pc > 0 and loss > 0):
                logger.warning(f"Value not plottable on log axes in data point: {d}")
                continue

            points.append((pc, loss, d.get("dim")))

        # Sort by param_count for a nice curve
        points.sort(key=lambda point: point[0])
        logger.info(f"{net_type} filtered data list: {points}")

        if not points:
            continue

        has_data = True
        x_param = [point[0] for point in points]
        y_loss = [point[1] for point in points]

        logger.info(f"{net_type} x values: {x_param}")
        logger.info(f"{net_type} y values: {y_loss}")

        # Add curve with line between points
        ax.plot(x_param, y_loss, marker=marker, color=color, label=net_type, linewidth=2, markersize=10)

        # Add point labels for individual dimensions
        for pc, loss, dim in points:
            ax.annotate(
                f"dim={dim}",
                xy=(pc, loss),
                xytext=(10, 0),
                textcoords='offset points',
                fontsize=10
            )

    if has_data:
        ax.set_xlabel("Parameter Count (log scale)", fontsize=12)
        ax.set_ylabel("Loss (log scale)", fontsize=12)
        ax.set_title(f"Scaling Behavior: {experiment_name_prefix}", fontsize=14)
        ax.grid(True, which="both", alpha=0.3)
        ax.legend(loc='best', fontsize=12)

        # Add minor gridlines
        ax.grid(which='minor', linestyle=':', alpha=0.2)

        fig.tight_layout()

        # Save through the figure object so the right figure is written even
        # if another figure has become pyplot's "current" one.
        os.makedirs(output_dir, exist_ok=True)
        savefig = os.path.join(output_dir, f"{experiment_name_prefix}_loglog.png")
        fig.savefig(savefig, dpi=300)
        logger.info(f"Saved log-log plot to {savefig}")

        # Display the plot in the notebook
        plt.show()
    else:
        logger.warning("No valid data to plot")

    return fig

def load_existing_results(output_dir, experiment_name_prefix):
    """
    Load and validate ``<output_dir>/<experiment_name_prefix>_allresults.json``.

    Returns:
        The parsed dict, guaranteed to contain list values for "EINet" and
        "MLP" whose items are all dicts, or None when the file is missing,
        is not valid JSON, cannot be read, or does not hold a dict.

    A missing or non-list "EINet"/"MLP" value is replaced by an empty list and
    non-dict items inside those lists are dropped (each with a log message).
    Any other top-level keys are kept as they are and not inspected.
    """
    results_path = os.path.join(output_dir, f"{experiment_name_prefix}_allresults.json")

    if not os.path.exists(results_path):
        logger.warning(f"Results file does not exist: {results_path}")
        return None

    try:
        with open(results_path, 'r') as f:
            results = json.load(f)
    except json.JSONDecodeError:
        logger.error(f"Results file contains invalid JSON: {results_path}")
        return None
    except Exception as e:
        logger.error(f"Error loading results: {e}")
        return None

    # Validate structure
    if not isinstance(results, dict):
        logger.error("Results file has invalid structure - not a dictionary")
        return None

    expected_types = ['EINet', 'MLP']
    for net_type in expected_types:
        if net_type not in results:
            logger.warning(f"Results file missing network type: {net_type}")
            results[net_type] = []
        elif not isinstance(results[net_type], list):
            logger.error(f"Results for {net_type} is not a list")
            results[net_type] = []
        else:
            # Drop malformed entries so later code can rely on dict items.
            valid_items = [item for item in results[net_type] if isinstance(item, dict)]
            dropped = len(results[net_type]) - len(valid_items)
            if dropped:
                logger.warning(f"Dropped {dropped} non-dict entries from {net_type}")
            results[net_type] = valid_items

    logger.info(f"Successfully loaded results from {results_path}")
    # Log what was loaded. Only the validated network types are walked, so an
    # unrelated extra key in the file cannot make the whole load fail.
    for net_type in expected_types:
        logger.info(f"  {net_type}: {len(results[net_type])} data points")
        for item in results[net_type]:
            logger.info(f"    dim={item.get('dim')}, params={item.get('param_count')}, loss={item.get('loss')}")

    return results

In [None]:
# Sweep configuration. The two paths can be overridden through environment
# variables so the notebook runs on machines other than the original cluster
# account; when a variable is unset or empty the original default is used.
BASE_CONFIG = os.environ.get("SCALING_BASE_CONFIG") or "../.vscode/config_exp.yaml"  # Path to base configuration file
OUTPUT_DIR = (
    os.environ.get("SCALING_OUTPUT_DIR")
    or "/n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment"
)  # Directory to save results
DIM_LIST = [2, 4, 8, 16, 32]  # Dimensions to sweep through
EXPERIMENT_NAME = "scaling_exp"  # Prefix for experiment names

# Create the results directory up front so later cells can write into it.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Reuse results cached by a previous sweep when they contain at least one data
# point; otherwise run the (long) sweep. Delete the *_allresults.json file in
# OUTPUT_DIR to force a fresh run.
existing_results = load_existing_results(OUTPUT_DIR, EXPERIMENT_NAME)

has_cached_points = isinstance(existing_results, dict) and any(
    existing_results.get(net_type) for net_type in ("EINet", "MLP")
)

if has_cached_points:
    print("Using existing results from previous runs.")
    results = existing_results
else:
    print("No existing results found. Running scaling experiments...")
    results = run_scaling_experiment(
        base_config_path=BASE_CONFIG,
        output_base_dir=OUTPUT_DIR,
        dim_values=DIM_LIST,
        experiment_name_prefix=EXPERIMENT_NAME
    )

2025-02-25 10:34:07,260 - scaling_notebook - INFO - Successfully loaded results from /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment/scaling_exp_allresults.json
2025-02-25 10:34:07,261 - scaling_notebook - INFO -   EINet: 0 data points
2025-02-25 10:34:07,261 - scaling_notebook - INFO -   MLP: 0 data points


Using existing results from previous runs.

Existing Results Summary:

EINet Results:

MLP Results:


2025-02-25 11:05:30,934 - scaling_notebook - INFO - Running experiment: scaling_exp_EINet_dim2



Running scaling experiments...
INFO logging_config.py:52  Log file set to /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment/scaling_exp_EINet_dim2/dendritic_modeling.log.
INFO train_experiments.py:102  Loading dataset...
Files already downloaded and verified
Files already downloaded and verified
INFO train_experiments.py:55  Running EINet with probabilistic classifier
INFO train_experiments.py:139  Using learning strategy: mle
INFO train_experiments.py:152  Saving results and visualizations to: /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/dendritic_modeling/results/scaling_exp_EINet_dim2_2025-02-25_11-05-40
INFO train_experiments.py:188  Completed epoch 1/20
INFO train_experiments.py:188  Completed epoch 2/20
INFO train_experiments.py:188  Completed epoch 3/20
INFO train_experiments.py:188  Completed epoch 4/20
INFO train_experiments.py:188  Completed epoch 5/20
INFO train_experiments.py:188  Completed epoch 6/20
INFO train_experiments.py:188  Completed epoch

2025-02-25 11:09:37,370 - scaling_notebook - INFO - Parsing results for: scaling_exp_EINet_dim2
2025-02-25 11:09:37,370 - scaling_notebook - INFO - Completed EINet dim=2: params=None, loss=None
2025-02-25 11:09:37,374 - scaling_notebook - INFO - Running experiment: scaling_exp_MLP_dim2


INFO logging_config.py:52  Log file set to /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment/scaling_exp_MLP_dim2/dendritic_modeling.log.
INFO train_experiments.py:102  Loading dataset...
Files already downloaded and verified
Files already downloaded and verified
INFO train_experiments.py:46  Running feedforward with probabilistic classifier
INFO networks.py:341  Building MLPExcInhNetwork...
INFO train_experiments.py:139  Using learning strategy: mle
INFO train_experiments.py:152  Saving results and visualizations to: /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/dendritic_modeling/results/scaling_exp_MLP_dim2_2025-02-25_11-09-42
INFO train_experiments.py:188  Completed epoch 1/20
INFO train_experiments.py:188  Completed epoch 2/20
INFO train_experiments.py:188  Completed epoch 3/20
INFO train_experiments.py:188  Completed epoch 4/20
INFO train_experiments.py:188  Completed epoch 5/20
INFO train_experiments.py:188  Completed epoch 6/20
INFO train_experiments.py

2025-02-25 11:10:16,669 - scaling_notebook - INFO - Parsing results for: scaling_exp_MLP_dim2
2025-02-25 11:10:16,670 - scaling_notebook - INFO - Completed MLP dim=2: params=None, loss=None
2025-02-25 11:10:16,673 - scaling_notebook - INFO - Running experiment: scaling_exp_EINet_dim4


INFO logging_config.py:52  Log file set to /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment/scaling_exp_EINet_dim4/dendritic_modeling.log.
INFO train_experiments.py:102  Loading dataset...
Files already downloaded and verified
Files already downloaded and verified
INFO train_experiments.py:55  Running EINet with probabilistic classifier
INFO train_experiments.py:139  Using learning strategy: mle
INFO train_experiments.py:152  Saving results and visualizations to: /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/dendritic_modeling/results/scaling_exp_EINet_dim4_2025-02-25_11-10-21
INFO train_experiments.py:188  Completed epoch 1/20
INFO train_experiments.py:188  Completed epoch 2/20
INFO train_experiments.py:188  Completed epoch 3/20
INFO train_experiments.py:188  Completed epoch 4/20
INFO train_experiments.py:188  Completed epoch 5/20
INFO train_experiments.py:188  Completed epoch 6/20
INFO train_experiments.py:188  Completed epoch 7/20
INFO train_experiments.py:

2025-02-25 11:14:04,209 - scaling_notebook - INFO - Parsing results for: scaling_exp_EINet_dim4
2025-02-25 11:14:04,209 - scaling_notebook - INFO - Completed EINet dim=4: params=None, loss=None
2025-02-25 11:14:04,213 - scaling_notebook - INFO - Running experiment: scaling_exp_MLP_dim4


INFO logging_config.py:52  Log file set to /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment/scaling_exp_MLP_dim4/dendritic_modeling.log.
INFO train_experiments.py:102  Loading dataset...
Files already downloaded and verified
Files already downloaded and verified
INFO train_experiments.py:46  Running feedforward with probabilistic classifier
INFO networks.py:341  Building MLPExcInhNetwork...
INFO train_experiments.py:139  Using learning strategy: mle
INFO train_experiments.py:152  Saving results and visualizations to: /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/dendritic_modeling/results/scaling_exp_MLP_dim4_2025-02-25_11-14-09
INFO train_experiments.py:188  Completed epoch 1/20
INFO train_experiments.py:188  Completed epoch 2/20
INFO train_experiments.py:188  Completed epoch 3/20
INFO train_experiments.py:188  Completed epoch 4/20
INFO train_experiments.py:188  Completed epoch 5/20
INFO train_experiments.py:188  Completed epoch 6/20
INFO train_experiments.py

2025-02-25 11:14:43,872 - scaling_notebook - INFO - Parsing results for: scaling_exp_MLP_dim4
2025-02-25 11:14:43,873 - scaling_notebook - INFO - Completed MLP dim=4: params=None, loss=None
2025-02-25 11:14:43,876 - scaling_notebook - INFO - Running experiment: scaling_exp_EINet_dim8


INFO logging_config.py:52  Log file set to /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/scaling_experiment/scaling_exp_EINet_dim8/dendritic_modeling.log.
INFO train_experiments.py:102  Loading dataset...
Files already downloaded and verified
Files already downloaded and verified
INFO train_experiments.py:55  Running EINet with probabilistic classifier
INFO train_experiments.py:139  Using learning strategy: mle
INFO train_experiments.py:152  Saving results and visualizations to: /n/holylabs/LABS/kempner_dev/Users/hsafaai/results/dendritic_modeling/results/scaling_exp_EINet_dim8_2025-02-25_11-14-48
INFO train_experiments.py:188  Completed epoch 1/20
INFO train_experiments.py:188  Completed epoch 2/20
INFO train_experiments.py:188  Completed epoch 3/20
INFO train_experiments.py:188  Completed epoch 4/20
INFO train_experiments.py:188  Completed epoch 5/20


In [None]:
# Plot the results. A results dict whose lists are all empty is truthy, so
# check for actual data points before plotting and reporting success.
has_data_points = isinstance(results, dict) and any(
    results.get(net_type) for net_type in ("EINet", "MLP")
)
if has_data_points:
    print("\nPlotting results...")
    plot_loglog(results, OUTPUT_DIR, EXPERIMENT_NAME)
    print("\nScaling experiment completed successfully!")
else:
    print("\nNo results to plot. Check logs for errors.")

In [None]:
# Estimate the scaling exponent of each network type: the slope of a straight
# line fitted to log(loss) versus log(param_count).
if results:
    for net_type in ["EINet", "MLP"]:
        data_list = results.get(net_type, [])
        # Filter out entries missing param_count or loss
        data_list = [d for d in data_list if (d.get("param_count") and d.get("loss"))]
        
        if len(data_list) >= 2:  # Need at least 2 points to calculate slope
            # Sort by param_count
            data_list.sort(key=lambda x: x["param_count"])
            
            # Extract log values
            log_params = np.log(np.array([d["param_count"] for d in data_list]))
            log_loss = np.log(np.array([d["loss"] for d in data_list]))
            
            # Simple linear regression to find slope
            if len(log_params) > 1:  # Check for valid array length
                slope = np.polyfit(log_params, log_loss, 1)[0]
                print(f"\n{net_type} scaling exponent (slope in log-log space): {slope:.4f}")
                
                if slope < 0:
                    print(f"Loss scales with parameters as: loss ∝ (params)^{slope:.4f}")
                    # Doubling params multiplies the loss by 2**slope (< 1 for a
                    # negative slope), so the relative decrease is 1 - 2**slope.
                    print(f"This means that doubling parameter count decreases loss by {1 - 2**slope:.2%}")