In [36]:
# Import necessary libraries
import os
import sys
import yaml
import numpy as np
import matplotlib.pyplot as plt
import torch
from copy import deepcopy
import logging

# Set up logging
# Configure the root logger for INFO-and-above records formatted as
# "timestamp - logger name - level - message", then create the named logger
# that the helper functions in the cells below write to.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger('scaling_notebook')

In [37]:
# Add project root to path if needed
# The project root is the parent directory when the notebook runs from a
# `notebooks/` folder, otherwise the current working directory itself.
current_dir = os.getcwd()
if os.path.basename(current_dir) == 'notebooks':
    project_root = os.path.dirname(current_dir)
else:
    project_root = current_dir

# Test for the exact root rather than for a path suffix: the repository
# directory name need not match the package name, so a suffix test can miss it
# and prepend a duplicate entry every time this cell is re-run.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
    logger.info(f"Added {project_root} to Python path")

2025-02-25 12:30:01,080 - scaling_notebook - INFO - Added /n/holylabs/LABS/kempner_dev/Users/hsafaai/Code/dendritic-modeling to Python path


In [38]:
# Import necessary modules from the repo
from dendritic_modeling.config import load_config
from dendritic_modeling.models import ProbabilisticClassifier, Classifier
from dendritic_modeling.networks import ExcitationInhibitionNetwork, MLPExcInhNetwork
from dendritic_modeling.synthetic_datasets import get_unified_datasets

In [39]:
# Function to load and modify a configuration file
def load_config_yaml(path):
    """Parse the YAML file at ``path`` with ``yaml.safe_load`` and return the result."""
    with open(path, 'r') as handle:
        parsed = yaml.safe_load(handle)
    return parsed

def prepare_config(base_config_path, dim, net_type="EINet", epochs=5):
    """Load the base configuration and specialise it for one sweep point.

    Args:
        base_config_path: Path handed to ``load_config``.
        dim: Sweep dimension. For "EINet" it becomes the second excitatory
            branch factor; for "MLP" it sets the two hidden layer widths
            (``dim`` and ``dim // 2``).
        net_type: Value written to ``config.model.network.type``. Only
            "EINet" and "MLP" trigger size overrides; any other value leaves
            the network parameters as loaded.
        epochs: Value written to ``config.train.epochs``.

    Returns:
        The loaded configuration object with the overrides applied.
    """
    config = load_config(base_config_path)

    # Set network type
    config.model.network.type = net_type

    # Set dimensions according to the sweep parameter
    net_parameters = config.model.network.parameters
    if net_type == "EINet":
        net_parameters.excitatory_branch_factors = [2, int(dim)]
        net_parameters.inhibitory_branch_factors = []
    elif net_type == "MLP":
        # dim // 2 is 0 for dim == 1; keep the second hidden layer at least
        # one unit wide so the sweep never requests an empty layer.
        second_width = max(1, int(dim // 2))
        net_parameters.hidden_layer_sizes = [int(dim), second_width]

    # Set training parameters
    config.train.epochs = epochs

    return config

In [40]:
# Function to initialize model based on configuration
def initialize_model(model_cfg):
    """Build the classifier described by ``model_cfg``.

    Reads ``task``, ``probabilistic``, ``network.type`` and the attribute
    dictionary of ``network.parameters`` from ``model_cfg``. The parameters are
    passed as keyword arguments to the network constructor.

    Returns:
        ``ProbabilisticClassifier(net, output_dim)`` when
        ``model_cfg.probabilistic`` is truthy, otherwise ``Classifier(net)``.

    Raises:
        ValueError: If the task is not 'classification' or the network type is
            neither 'MLP' nor 'EINet'.
    """
    task = model_cfg.task
    probabilistic = model_cfg.probabilistic
    net_type = model_cfg.network.type
    net_params = model_cfg.network.parameters.__dict__

    # Reject unsupported configurations up front.
    if task != 'classification':
        raise ValueError(f"Invalid task: {task}")
    if net_type not in ('MLP', 'EINet'):
        raise ValueError(f"Invalid network type: {net_type}")

    if net_type == 'MLP':
        logger.info("Creating MLP network")
        net = MLPExcInhNetwork(**net_params)
        # Falls back to 10 outputs when the parameters carry no 'output_dim'.
        output_dim = net_params.get('output_dim', 10)
    else:
        logger.info("Creating EINet network")
        net = ExcitationInhibitionNetwork(**net_params)
        # The last excitatory layer size is used as the output dimension.
        output_dim = net_params['excitatory_layer_sizes'][-1]

    return ProbabilisticClassifier(net, output_dim) if probabilistic else Classifier(net)

In [41]:
# Function to evaluate model performance
def evaluate_model(model, test_loader, device):
    """Compute accuracy and mean per-batch loss of ``model`` over ``test_loader``.

    The model is switched to eval mode and all batches are processed under
    ``torch.no_grad()``. Loss is accumulated only when the model exposes
    ``compute_loss``; predictions come from ``model.predict`` when available,
    otherwise from the argmax over dimension 1 of the forward output.

    Args:
        model: Module to evaluate.
        test_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        device: Device the batches are moved to before evaluation.

    Returns:
        ``(accuracy, avg_loss)``: the fraction of correct predictions, and the
        summed batch losses divided by the number of batches. Either value is
        0 when there is nothing to average over.
    """
    model.eval()
    loss_sum = 0
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_inputs, batch_targets in test_loader:
            batch_inputs = batch_inputs.to(device)
            batch_targets = batch_targets.to(device)

            if hasattr(model, 'compute_loss'):
                loss_sum += model.compute_loss(batch_inputs, batch_targets).item()

            if hasattr(model, 'predict'):
                predicted = model.predict(batch_inputs)
            else:
                predicted = model(batch_inputs).argmax(dim=1)

            n_correct += (predicted == batch_targets).sum().item()
            n_seen += batch_targets.size(0)

    accuracy = n_correct / n_seen if n_seen > 0 else 0
    n_batches = len(test_loader)
    avg_loss = loss_sum / n_batches if n_batches > 0 else 0

    return accuracy, avg_loss

In [42]:
# Function to run a single experiment and return results
def run_single_experiment(base_config_path, dim, net_type="EINet", epochs=5):
    """Build one model for a sweep point and evaluate it on the test split.

    No training step is run here: the freshly initialised model is evaluated
    as-is, so ``epochs`` only ends up in the prepared configuration.

    Args:
        base_config_path: Base configuration path passed to ``prepare_config``.
        dim: Sweep dimension for this run.
        net_type: Network type for this run.
        epochs: Epoch count stored in the configuration.

    Returns:
        Dict with keys "dim", "param_count", "accuracy" and "loss".
    """
    cfg = prepare_config(base_config_path, dim, net_type, epochs)

    # Only the third (test) split is consumed; the other two are discarded.
    _, _, held_out = get_unified_datasets(cfg.task, cfg.train)
    loader = torch.utils.data.DataLoader(held_out, batch_size=64)

    model = initialize_model(cfg.model)

    # Total number of elements across all model parameters.
    n_params = sum(tensor.numel() for tensor in model.parameters())
    logger.info(f"{net_type} dim={dim}: Parameter count = {n_params}")

    # Prefer the GPU when one is available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    accuracy, loss = evaluate_model(model, loader, device)
    logger.info(f"{net_type} dim={dim}: Accuracy = {accuracy:.4f}, Loss = {loss:.4f}")

    return dict(dim=dim, param_count=n_params, accuracy=accuracy, loss=loss)

In [43]:
# Function to run a full sweep of experiments
def run_scaling_experiment(base_config_path, dim_list, net_types=["EINet", "MLP"], epochs=5):
    """Run ``run_single_experiment`` for every (network type, dimension) pair.

    Args:
        base_config_path: Base configuration path forwarded to each run.
        dim_list: Dimensions to sweep for every network type.
        net_types: Network types to sweep (the default list is never mutated).
        epochs: Epoch count forwarded to each run.

    Returns:
        Dict mapping each network type to the list of per-dimension result
        dicts, in ``dim_list`` order.
    """
    results = dict((name, []) for name in net_types)

    for name in net_types:
        results[name].extend(
            run_single_experiment(base_config_path, dim, name, epochs)
            for dim in dim_list
        )

    return results

In [44]:
# Function to plot results on log-log scale
def plot_loglog(results):
    """Draw loss and error rate against parameter count on log-log axes.

    Args:
        results: Mapping from network type to a list of result dicts, each
            with "dim", "param_count", "accuracy" and "loss" entries. Only the
            "EINet" and "MLP" keys are drawn; a warning is logged for either
            one that is missing or empty.

    Returns:
        The matplotlib figure holding the loss panel (left) and the error-rate
        panel (right).
    """
    fig, (loss_ax, error_ax) = plt.subplots(1, 2, figsize=(16, 8))

    for axis in (loss_ax, error_ax):
        axis.set_xscale("log")
        axis.set_yscale("log")

    # Colors and markers for different network types
    styles = {
        "EINet": {"color": "blue", "marker": "o"},
        "MLP": {"color": "red", "marker": "s"},
    }

    for net_type, style in styles.items():
        runs = results[net_type] if net_type in results else None
        if not runs:
            logger.warning(f"No data for {net_type}")
            continue

        ordered = sorted(runs, key=lambda run: run["param_count"])

        param_counts, losses, error_rates, dims = [], [], [], []
        for run in ordered:
            param_counts.append(run["param_count"])
            # Values are floored at 1e-10 so they stay drawable on a log axis.
            losses.append(max(1e-10, run["loss"]))
            error_rates.append(max(1e-10, 1.0 - run["accuracy"]))
            dims.append(run["dim"])

        line_kwargs = dict(marker=style["marker"], color=style["color"],
                           linewidth=2, markersize=10)

        # Loss curve on the left panel, dashed error-rate curve on the right.
        loss_ax.plot(param_counts, losses, label=f"{net_type} (Loss)", **line_kwargs)
        error_ax.plot(param_counts, error_rates, linestyle='--',
                      label=f"{net_type} (Error)", **line_kwargs)

        # Tag every point with the sweep dimension that produced it.
        for x, loss_y, error_y, dim in zip(param_counts, losses, error_rates, dims):
            tag = f"dim={dim}"
            loss_ax.annotate(tag, xy=(x, loss_y), xytext=(10, 0),
                             textcoords='offset points', fontsize=10)
            error_ax.annotate(tag, xy=(x, error_y), xytext=(10, 0),
                              textcoords='offset points', fontsize=10)

    # Set labels and titles
    panels = (
        (loss_ax, "Loss (log scale)", "Loss vs Parameter Count"),
        (error_ax, "Error Rate (log scale)", "Error Rate vs Parameter Count"),
    )
    for axis, y_label, title in panels:
        axis.set_xlabel("Parameter Count (log scale)", fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.set_title(title, fontsize=14)
        axis.grid(True, which="both", alpha=0.3)
        axis.legend(loc='best', fontsize=12)

    plt.tight_layout()
    return fig

In [45]:
# Function to calculate scaling exponents
def calculate_scaling_exponents(results):
    """Fit power laws of loss and error rate against parameter count.

    For "EINet" and "MLP", a straight line is fitted with ``np.polyfit`` to
    log(loss) and log(error rate) as functions of log(parameter count). The
    slope of each line is the scaling exponent. A summary is printed for each
    network type that has at least two data points.

    Args:
        results: Mapping from network type to a list of result dicts with
            "param_count", "loss" and "accuracy" entries.

    Returns:
        Dict mapping each fitted network type to a dict with "loss_slope",
        "error_slope", "loss_intercept" and "error_intercept". Network types
        with fewer than two data points are omitted.
    """
    scaling_data = {}

    for net_type in ["EINet", "MLP"]:
        if net_type not in results or not results[net_type]:
            print(f"No data for {net_type}")
            continue

        data_list = sorted(results[net_type], key=lambda x: x["param_count"])

        if len(data_list) < 2:
            print(f"Not enough data points for {net_type} to calculate scaling exponent")
            continue

        # Loss and error are floored at 1e-10 so the logarithm stays finite.
        log_params = np.log(np.array([d["param_count"] for d in data_list]))
        log_loss = np.log(np.array([max(1e-10, d["loss"]) for d in data_list]))
        log_error = np.log(np.array([max(1e-10, 1.0 - d["accuracy"]) for d in data_list]))

        loss_slope, loss_intercept = np.polyfit(log_params, log_loss, 1)
        error_slope, error_intercept = np.polyfit(log_params, log_error, 1)

        scaling_data[net_type] = {
            "loss_slope": loss_slope,
            "error_slope": error_slope,
            "loss_intercept": loss_intercept,
            "error_intercept": error_intercept
        }

        print(f"\n{net_type} Scaling:")
        print(f"  Loss scaling exponent: {loss_slope:.4f}")
        print(f"  Error scaling exponent: {error_slope:.4f}")

        # With y ∝ N^s, doubling N multiplies y by 2^s, so for s < 0 the
        # fractional decrease is 1 - 2^s (e.g. s = -1 halves y: a 50% drop).
        if loss_slope < 0:
            print(f"  Loss scales with parameters as: loss ∝ (params)^{loss_slope:.4f}")
            print(f"  Doubling parameter count decreases loss by {1 - 2**loss_slope:.2%}")

        if error_slope < 0:
            print(f"  Error scales with parameters as: error ∝ (params)^{error_slope:.4f}")
            print(f"  Doubling parameter count decreases error by {1 - 2**error_slope:.2%}")

    return scaling_data

In [46]:
# Function to visualize scaling laws with fitted lines
def plot_scaling_laws(results, scaling_data):
    """Scatter the measured points and overlay the fitted power laws.

    Args:
        results: Mapping from network type to a list of result dicts with
            "param_count", "loss" and "accuracy" entries. Only "EINet" and
            "MLP" are drawn; missing or empty entries are skipped silently.
        scaling_data: Mapping from network type to a dict with "loss_slope",
            "loss_intercept", "error_slope" and "error_intercept". A fit line
            is drawn only for network types present in this mapping.

    Returns:
        The matplotlib figure holding the loss panel (left) and the error-rate
        panel (right), both on log-log axes.
    """
    fig, (loss_ax, error_ax) = plt.subplots(1, 2, figsize=(16, 8))

    for axis in (loss_ax, error_ax):
        axis.set_xscale("log")
        axis.set_yscale("log")

    styles = {
        "EINet": {"color": "blue", "marker": "o"},
        "MLP": {"color": "red", "marker": "s"},
    }

    for net_type, style in styles.items():
        runs = results[net_type] if net_type in results else None
        if not runs:
            continue

        ordered = sorted(runs, key=lambda run: run["param_count"])

        # Loss and error are floored at 1e-10 so they stay drawable on a log axis.
        param_counts = np.array([run["param_count"] for run in ordered])
        losses = np.array([max(1e-10, run["loss"]) for run in ordered])
        error_rates = np.array([max(1e-10, 1.0 - run["accuracy"]) for run in ordered])

        # Plot data points
        point_kwargs = dict(marker=style["marker"], color=style["color"],
                            s=80, label=f"{net_type} Data")
        loss_ax.scatter(param_counts, losses, **point_kwargs)
        error_ax.scatter(param_counts, error_rates, **point_kwargs)

        if net_type not in scaling_data:
            continue

        # Evaluate the fitted lines on a log-spaced grid spanning the data.
        fit = scaling_data[net_type]
        x_range = np.logspace(np.log10(min(param_counts)), np.log10(max(param_counts)), 100)

        loss_slope = fit["loss_slope"]
        error_slope = fit["error_slope"]

        # The fit was done in log space, so y = exp(intercept) * x ** slope.
        y_loss_fit = np.exp(fit["loss_intercept"]) * x_range**loss_slope
        y_error_fit = np.exp(fit["error_intercept"]) * x_range**error_slope

        loss_ax.plot(x_range, y_loss_fit, '--', color=style["color"],
                     label=f"{net_type} Fit: ∝ N^{loss_slope:.3f}")
        error_ax.plot(x_range, y_error_fit, '--', color=style["color"],
                      label=f"{net_type} Fit: ∝ N^{error_slope:.3f}")

    panels = (
        (loss_ax, "Loss", "Scaling Law: Loss vs Parameters"),
        (error_ax, "Error Rate", "Scaling Law: Error Rate vs Parameters"),
    )
    for axis, y_label, title in panels:
        axis.set_xlabel("Parameter Count (N)", fontsize=12)
        axis.set_ylabel(y_label, fontsize=12)
        axis.set_title(title, fontsize=14)
        axis.grid(True, which="both", alpha=0.3)
        axis.legend(loc='best', fontsize=12)

    plt.tight_layout()
    return fig

In [None]:
# Main execution code to run the sweep
# Runs the full sweep, prints a per-run summary, and shows both figures.
# Everything assigned here lives at module level, so `results`, `data_list`,
# `scaling_data` and the figure handles remain defined after the cell finishes.
if __name__ == "__main__":
    # Configuration
    BASE_CONFIG = "../.vscode/config_exp.yaml"  # Update path as needed
    DIM_LIST = [2, 4, 8, 16, 32]  # Dimensions to sweep
    NET_TYPES = ["EINet", "MLP"]  # Network types to test
    # EPOCHS is written into each run's config; run_single_experiment itself
    # evaluates the model without training it.
    EPOCHS = 5  # Epochs for each run (set low for quick testing)

    # Run the experiment
    results = run_scaling_experiment(
        base_config_path=BASE_CONFIG,
        dim_list=DIM_LIST,
        net_types=NET_TYPES,
        epochs=EPOCHS
    )

    # Print results: one line per (network type, dimension) run
    for net_type, data_list in results.items():
        print(f"\n{net_type} Results:")
        for data in data_list:
            print(f"  dim={data['dim']}: params={data['param_count']}, "
                  f"accuracy={data['accuracy']:.4f}, loss={data['loss']:.4f}")

    # Create basic plot (raw curves, annotated with the sweep dimension)
    plot_figure = plot_loglog(results)
    plt.show()

    # Calculate and plot scaling exponents (data points plus fitted power laws)
    scaling_data = calculate_scaling_exponents(results)
    scaling_figure = plot_scaling_laws(results, scaling_data)
    plt.show()

2025-02-25 12:30:01,696 - scaling_notebook - INFO - Creating EINet network
2025-02-25 12:30:01,700 - scaling_notebook - INFO - EINet dim=2: Parameter count = 72210


In [None]:
# Debug output: `data_list` is the loop variable left over from the
# results-printing loop in the main cell, so this prints only the result
# list of the last network type that loop visited. It raises NameError if
# that cell has not been run first.
print(data_list)