In [3]:
from pathlib import Path
import json
import torch
import json
from sklearn.metrics import r2_score, mean_squared_error
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from train import MLP


def map_runs_by_split():
    """
    Scans through all run directories and creates a dictionary mapping
    split types to their corresponding run directories
    """
    runs_dir = Path("runs")
    splits_dict = {}

    # Iterate through all directories in runs/
    for run_dir in runs_dir.iterdir():
        if not run_dir.is_dir() or run_dir.name == "wandb":
            continue

        # Try to load the config file
        config_path = run_dir / "config.json"
        if not config_path.exists():
            continue

        try:
            with open(config_path, "r") as f:
                config = json.load(f)

            # Get the split type from config
            split_type = config.get("split_type")
            if split_type:
                if split_type not in splits_dict:
                    splits_dict[split_type] = []
                splits_dict[split_type].append(str(run_dir))

        except json.JSONDecodeError:
            print(f"Error reading config from {run_dir}")
            continue

    # Print the results in a formatted way
    print("=== Runs by Split Type ===")
    for split_type, runs in splits_dict.items():
        print(f"\n{split_type}:")
        for run in runs:
            print(f"  - {run}")

    return splits_dict


# Create the mapping
splits_dict = map_runs_by_split()


ModuleNotFoundError: No module named 'plotly'

In [11]:
from pathlib import Path
import torch
import json
from pprint import pprint

def load_checkpoint_run(run_dir):
    """
    Load checkpoint and configuration from a specific run directory
    
    Args:
        run_dir (str): Path to the run directory (e.g., 'runs/20240315_123456_wandb_id')
        
    Returns:
        tuple: (checkpoint_data, config_data, feature_mask_data)
    """
    run_path = Path(run_dir)
    
    # Load latest checkpoint
    checkpoint_path = run_path / "checkpoints" / "checkpoint_latest.pt"
    checkpoint = torch.load(checkpoint_path)
    
    # Load config
    config_path = run_path / "config.json"
    with open(config_path, 'r') as f:
        config = json.load(f)
        
    # Load feature mask
    feature_mask_path = run_path / "feature_mask.json"
    with open(feature_mask_path, 'r') as f:
        feature_mask = json.load(f)
    
    return checkpoint, config, feature_mask

# Example usage:
run_dir = "../runs/20250109_002346_pfqj68dv"  # Replace with your actual run directory
checkpoint, config, feature_mask = load_checkpoint_run(run_dir)

print("Configuration:")
pprint(config)
print("\nFeature Mask:")
pprint(feature_mask)
print("\nCheckpoint Contents:")
print(f"Epoch: {checkpoint['epoch']}")
print("Available keys:", checkpoint.keys())

Configuration:
{'_wandb': {},
 'batch_size': 64,
 'device': 7,
 'dropout_rate': 0.1,
 'epochs': 25,
 'feature_groups': ['token_probs', 'model_arch', 'training'],
 'hidden_size': 64,
 'learning_rate': 0.001,
 'normalize_data': False,
 'num_layers': 3,
 'r2_eval_frequency': 1,
 'save_checkpoint_frequency': 0,
 'seed': 43,
 'split_type': 'size_largerthan_500m_split',
 'weight_decay': 0.01}

Feature Mask:
{'feature_groups': ['token_probs', 'model_arch', 'training'],
 'feature_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1]}

Checkpoint Contents:
Epoch: 24
Available keys: dict_keys(['epoch', 'model_state_dict', 'optimizer_state_dict', 'metrics'])



You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [13]:
### If config is normalized, use norm_stats to denormalize predictions
def load_model_for_prediction(checkpoint_path):
    """
    Load model and normalization statistics from checkpoint
    """
    # Load checkpoint, config and feature mask
    checkpoint, config, feature_mask_data = load_checkpoint_run(checkpoint_path)
    
    # Extract feature mask and convert to list of booleans
    feature_mask = feature_mask_data.get('feature_mask', [True] * 9)
    if isinstance(feature_mask, dict):
        feature_mask = [bool(feature_mask[str(i)]) if str(i) in feature_mask else True for i in range(9)]
    
    # Initialize model
    model = MLP(
        input_size=9,
        hidden_size=config['hidden_size'],
        output_size=11,
        num_layers=config['num_layers'],
        dropout_rate=config['dropout_rate'],
        feature_mask=feature_mask
    )
    
    model.load_state_dict(checkpoint["model_state_dict"])
    norm_stats = checkpoint["normalization_stats"]
    
    return model, norm_stats

def predict(model, X, norm_stats=None):
    """
    Make predictions with proper normalization/denormalization
    """
    model.eval()
    with torch.no_grad():
        # Normalize input if needed
        if norm_stats is not None:
            # Add small epsilon to avoid division by zero
            eps = 1e-8
            X_std = np.where(norm_stats["X_std"] == 0, eps, norm_stats["X_std"])
            X_normalized = (X - norm_stats["X_mean"]) / X_std
            X_tensor = torch.FloatTensor(X_normalized)
        else:
            X_tensor = torch.FloatTensor(X)
            
        # Get predictions
        y_pred_normalized = model(X_tensor).numpy()
        
        # Denormalize predictions if needed
        if norm_stats is not None:
            # Add small epsilon to avoid division by zero
            y_std = np.where(norm_stats["y_std"] == 0, eps, norm_stats["y_std"])
            y_pred = y_pred_normalized * y_std + norm_stats["y_mean"]
        else:
            y_pred = y_pred_normalized
            
        return y_pred


In [15]:
def analyze_token_proportion_smoothness(split_type, num_points=1000):
    """
    Analyzes how smoothly the model responds to changes in token proportions
    by sweeping across one proportion while maintaining sum = 1
    """
    if split_type not in splits_dict:
        raise ValueError(f"Invalid split type. Must be one of: {list(splits_dict.keys())}")
        
    model_path = splits_dict[split_type][0]  # Use first model from the split type
    # checkpoint_path = Path(model_path) / "checkpoints" / "checkpoint_latest.pt"
    
    # Load model and normalization stats using the modular function
    model, norm_stats = load_model_for_prediction(model_path)
    model.eval()
    
    # Create sweep values for first proportion (0 to 1)
    sweep_values = np.linspace(0, 1, num_points)
    
    # Initialize results dictionary
    results = {
        'sweep_values': sweep_values,
        'predictions': []
    }
    
    # For each sweep value, distribute remaining probability equally
    for p1 in sweep_values:
        # Remaining probability to distribute among other 4 proportions
        remaining = 1 - p1
        if remaining < 0:
            continue
            
        # Distribute remaining equally among other proportions
        other_props = remaining / 4
        
        # Create input array [p1, p2, p3, p4, p5, *other_features]
        x = np.zeros(9)  # Assuming 9 total features
        x[0] = p1
        x[1:5] = other_props  # Equal distribution for other proportions
        
        # Set other features to some reasonable default
        x[5] = 20  # Model size in millions
        x[6] = 256  # d_model dimension
        x[7] = 8  # Number of attention heads
        x[8] = 15000  # Training steps
        
        # Get prediction using the modular predict function
        pred = predict(model, x.reshape(1, -1), norm_stats)
        results['predictions'].append(pred.squeeze())
    
    # Convert predictions to numpy array
    results['predictions'] = np.array(results['predictions'])
    
    # Create visualization
    fig = go.Figure()
    
    # Plot prediction for each output dimension
    datasets = {
        'train_cross_entropy': 0,
        'commoncrawl': 1,
        'c4': 2,
        'wikipedia': 3,
        'stackexchange': 4,
        'github': 5,
        'arxiv': 6,
        'book': 7,
        'hellaswag': 8,
        'piqa': 9,
        'arc_easy': 10
    }
    
    for dataset, idx in datasets.items():
        fig.add_trace(go.Scatter(
            x=sweep_values,
            y=results['predictions'][:, idx],
            name=dataset,
            mode='lines'
        ))
    
    fig.update_layout(
        title='Model Predictions vs First Token Proportion (Denormalized)',
        xaxis_title='First Token Proportion',
        yaxis_title='Predicted Value',
        height=600,
        width=1000
    )
    
    return fig, results

split_type = "single_step_15000_split"
fig, results = analyze_token_proportion_smoothness(split_type)
fig.show()


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [18]:
def analyze_token_proportion_smoothness(split_type, token_idx=0, num_points=1000):
    """
    Analyzes how smoothly the model responds to changes in token proportions
    by sweeping across one proportion while maintaining sum = 1
    
    Args:
        split_type (str): Type of model split to analyze
        token_idx (int): Index of token proportion to sweep (0-4)
        num_points (int): Number of points to sample
    """
    if split_type not in splits_dict:
        raise ValueError(f"Invalid split type. Must be one of: {list(splits_dict.keys())}")
    
    if not 0 <= token_idx <= 4:
        raise ValueError("token_idx must be between 0 and 4")
        
    # Token names for better labeling
    token_names = {
        0: "RedPajamaWikipedia",
        1: "RedPajamaStackExchange",
        2: "RedPajamaGithub",
        3: "RedPajamaArXiv",
        4: "RedPajamaBook"
    }
    
    model_path = splits_dict[split_type][0]
    model, norm_stats = load_model_for_prediction(model_path)
    model.eval()
    
    # Create sweep values for selected proportion (0 to 1)
    sweep_values = np.linspace(0, 1, num_points)
    
    results = {
        'sweep_values': sweep_values,
        'predictions': []
    }
    
    # For each sweep value, distribute remaining probability equally
    for p_sweep in sweep_values:
        # Remaining probability to distribute among other 4 proportions
        remaining = 1 - p_sweep
        if remaining < 0:
            continue
            
        # Distribute remaining equally among other proportions
        other_props = remaining / 4
        
        # Create input array [p1, p2, p3, p4, p5, *other_features]
        x = np.zeros(9)
        
        # Set all token proportions to equal distribution first
        x[0:5] = other_props
        
        # Override the swept token proportion
        x[token_idx] = p_sweep
        
        # Set other features to default values
        x[5] = 20  # Model size in millions
        x[6] = 256  # d_model dimension
        x[7] = 8  # Number of attention heads
        x[8] = 15000  # Training steps
        
        # Get prediction
        pred = predict(model, x.reshape(1, -1), norm_stats)
        results['predictions'].append(pred.squeeze())
    
    results['predictions'] = np.array(results['predictions'])
    
    # Create visualization
    fig = go.Figure()
    
    datasets = {
        'train_cross_entropy': 0,
        'commoncrawl': 1,
        'c4': 2,
        'wikipedia': 3,
        'stackexchange': 4,
        'github': 5,
        'arxiv': 6,
        'book': 7,
        'hellaswag': 8,
        'piqa': 9,
        'arc_easy': 10
    }
    
    for dataset, idx in datasets.items():
        fig.add_trace(go.Scatter(
            x=sweep_values,
            y=results['predictions'][:, idx],
            name=dataset,
            mode='lines'
        ))
    
    fig.update_layout(
        title=f'Model Predictions vs {token_names[token_idx]} Token Proportion (Denormalized)',
        xaxis_title=f'{token_names[token_idx]} Token Proportion',
        yaxis_title='Predicted Value',
        height=600,
        width=1000
    )
    
    return fig, results

token_names = {
    0: "RedPajamaWikipedia",
    1: "RedPajamaStackExchange",
    2: "RedPajamaGithub",
    3: "RedPajamaArXiv",
    4: "RedPajamaBook"
}

# Example usage:
split_type = "single_step_15000_split"
token_idx = 3  # Analyze GitHub proportion (index 2)
fig, results = analyze_token_proportion_smoothness(split_type, token_idx=token_idx)
fig.show()


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.



In [24]:
def analyze_model_smoothness_3d(split_type, x_dim=0, y_dim=5, num_points=100):
    """
    Creates a 3D visualization of model predictions by sweeping across two dimensions.
    
    Args:
        split_type (str): Type of model split to analyze
        x_dim (int): Index of feature to sweep on x-axis (0-8)
        y_dim (int): Index of feature to sweep on y-axis (0-8)
        num_points (int): Number of points to sample for each dimension
    """
    if split_type not in splits_dict:
        raise ValueError(f"Invalid split type. Must be one of: {list(splits_dict.keys())}")
        
    model_path = splits_dict[split_type][0]  # Use first model from the split type
    
    # Feature names for better labeling
    feature_names = {
        0: "RedPajamaWikipedia",
        1: "RedPajamaStackExchange", 
        2: "RedPajamaGithub",
        3: "RedPajamaArXiv",
        4: "RedPajamaBook",
        5: "Model Size (M)",
        6: "d_model",
        7: "Num Heads", 
        8: "Training Steps"
    }
    
    # Load model and normalization stats using the modular function
    model, norm_stats = load_model_for_prediction(model_path)
    model.eval()
    
    # Create sweep values
    if x_dim <= 4:  # Token proportion
        x_values = np.linspace(0, 1, num_points)
    else:  # Other features
        x_values = np.linspace(0, 2000, num_points) if x_dim == 5 else np.linspace(0, 1000, num_points)
        
    if y_dim <= 4:  # Token proportion
        y_values = np.linspace(0, 1, num_points)
    else:  # Other features
        y_values = np.linspace(0, 2000, num_points) if y_dim == 5 else np.linspace(0, 1000, num_points)
    
    X, Y = np.meshgrid(x_values, y_values)
    Z = np.zeros((num_points, num_points, 11))  # 11 outputs
    
    # For each combination of x and y values
    for i, x_val in enumerate(x_values):
        for j, y_val in enumerate(y_values):
            # Skip invalid combinations if both dims are token proportions
            if x_dim <= 4 and y_dim <= 4 and (x_val + y_val > 1):
                Z[j, i, :] = np.nan  # Use NaN to indicate invalid regions
                continue
                
            # Create input array with default values
            x = np.zeros(9)
            
            # Set default values for token proportions (if not being swept)
            if x_dim <= 4 or y_dim <= 4:
                x[0:5] = 0.2  # Equal distribution by default
            
            # Set default values for other features
            x[5] = 20  # Model size in millions
            x[6] = 256  # d_model dimension
            x[7] = 8  # Number of attention heads
            x[8] = 15000  # Training steps
            
            # Set sweep values
            x[x_dim] = x_val
            x[y_dim] = y_val
            
            # If either dimension is a token proportion, redistribute others
            if x_dim <= 4 or y_dim <= 4:
                token_props = x[0:5]
                remaining = 1.0 - token_props[x_dim] - (token_props[y_dim] if y_dim <= 4 else 0)
                if remaining >= 0:
                    # Count how many other token proportions need to share the remaining
                    num_others = 3 if (x_dim <= 4 and y_dim <= 4) else 4
                    other_val = remaining / num_others
                    
                    # Distribute remaining to other token proportions
                    for k in range(5):
                        if k != x_dim and k != y_dim:
                            x[k] = other_val
            
            # Get prediction using the modular predict function with denormalization
            pred = predict(model, x.reshape(1, -1), norm_stats)
            Z[j, i, :] = pred.squeeze()
    
    # Create 3D plots for each output
    datasets = {
        'train_cross_entropy': 0,
        'commoncrawl': 1,
        'c4': 2,
        'wikipedia': 3,
        'stackexchange': 4,
        'github': 5,
        'arxiv': 6,
        'book': 7,
        'hellaswag': 8,
        'piqa': 9,
        'arc_easy': 10
    }
    
    figs = {}
    for dataset, idx in datasets.items():
        fig = go.Figure(data=[go.Surface(x=X, y=Y, z=Z[:,:,idx])])
        
        fig.update_layout(
            title=f'{split_type} - {dataset} Predictions (Denormalized)',  # Added denormalized note
            scene=dict(
                xaxis_title=feature_names[x_dim],
                yaxis_title=feature_names[y_dim],
                zaxis_title=f'{dataset} (Denormalized)'
            ),
            width=800,
            height=800,
            autosize=False
        )
        figs[dataset] = fig
    
    return figs, {'X': X, 'Y': Y, 'Z': Z}

# Example usage:
split_type = "single_step_15000_split"
figs, results = analyze_model_smoothness_3d(split_type, x_dim=4, y_dim=5)

# Display one of the plots
figs['book'].show()


You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.

