# Subgrid Emulator - Basic Usage Examples

This notebook demonstrates how to use the `subgrid_emu` package to make predictions for various cosmological summary statistics.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from subgrid_emu import (
    load_emulator, 
    list_available_statistics,
    get_x_grid,
    get_plot_info,
    get_parameter_info
)


## 1. Input Parameters and Available Statistics

In [None]:
from IPython.display import display, Markdown

param_info = get_parameter_info()

def create_parameter_markdown_table(info=None):
    """Build a Markdown table describing the emulator input parameters.

    Parameters
    ----------
    info : dict, optional
        Parameter metadata as returned by ``get_parameter_info()``. Read via
        ``'names'`` (ordered list), ``'latex_names'`` (indexed like
        ``'names'``), and ``'ranges'`` / ``'descriptions'`` (keyed by name).
        Defaults to the module-level ``param_info`` so existing calls with no
        argument behave as before.

    Returns
    -------
    str
        Markdown source: a ``###`` heading followed by one table row per
        parameter (index, name, symbol, range, description).
    """
    if info is None:
        info = param_info

    def _cell(text):
        # A bare '|' would terminate the Markdown table cell, so escape it.
        return str(text).replace('|', '\\|')

    md = "### Input Parameters\n\n"
    md += "| # | Parameter | Symbol | Range | Description |\n"
    md += "|---|-----------|-------|-------|-------------|\n"

    for i, name in enumerate(info['names']):
        latex_name = info['latex_names'][i]

        # Wrap in math delimiters only if the label contains no '$' at all.
        # A startswith('$') test would double-wrap labels such as
        # 'x [$unit$]' into '$x [$unit$]$' and break the rendered math.
        if '$' not in latex_name:
            latex_name = f'${latex_name}$'

        md += (
            f"| {i + 1} | {_cell(name)} | {_cell(latex_name)} "
            f"| {_cell(info['ranges'][name])} | {_cell(info['descriptions'][name])} |\n"
        )

    return md

display(Markdown(create_parameter_markdown_table()))

In [None]:
from IPython.display import display, Markdown

stats = list_available_statistics()

def create_markdown_table(stats_dict, title):
    """Build a Markdown table describing a group of summary statistics.

    Parameters
    ----------
    stats_dict : iterable of str
        Statistic names. Despite the parameter name, it is only iterated,
        e.g. ``list_available_statistics()['5-parameter']``.
    title : str
        Text of the ``###`` heading placed above the table.

    Returns
    -------
    str
        Markdown source: heading plus one row per statistic giving its name,
        its y-axis label (symbol and units) and its plot title.
    """
    def _cell(text):
        # A bare '|' would terminate the Markdown table cell, so escape it.
        return str(text).replace('|', '\\|')

    md = f"### {title}\n\n"
    md += "| stat_name | Symbol [Units] | Description |\n"
    md += "|-----------|-------|-------------|\n"

    for stat in stats_dict:
        plot_info = get_plot_info(stat)
        latex_expr = plot_info['ylabel']

        # Wrap in math delimiters only if the label contains no '$' at all.
        # A startswith('$') test would double-wrap labels such as
        # 'Density [$unit$]' and break the rendered math.
        if '$' not in latex_expr:
            latex_expr = f'${latex_expr}$'

        md += f"| {_cell(stat)} | {_cell(latex_expr)} | {_cell(plot_info['title'])} |\n"

    return md

display(Markdown(create_markdown_table(stats['5-parameter'], "5-parameter models (smaller simulation box)")))
display(Markdown(create_markdown_table(stats['2-parameter'], "2-parameter models (larger simulation box)")))

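## 2. Parameter Information

The 5-parameter emulators used below take a scaled input vector `[kappa_w, e_w, M_seed/1e6, v_kin/1e4, eps/1e1]`. The 2-parameter emulators take only the last two entries, `[v_kin/1e4, eps/1e1]`. See the parameter table in Section 1 for the symbols, ranges and descriptions.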

## 3. All 5-Parameter Models

Display predictions for all available 5-parameter summary statistics.

In [None]:
# Define parameters: [kappa_w, e_w, M_seed/1e6, v_kin/1e4, eps/1e1]
params_5p = np.array([3.0, 0.5, 0.8, 0.5, 0.1])

# Get all 5-parameter statistics.
# To skip the power-spectrum emulator, an explicit subset can be used instead
# (noted as working without Pk):
# stats_5p = ['GSMF', 'BHMSM', 'fGas', 'CGD', 'CSFR']
stats_5p = list_available_statistics()['5-parameter']

# Grid layout: n_cols panels per row, as many rows as needed (at least one,
# so plt.subplots never receives nrows=0 for an empty statistics list).
n_stats = len(stats_5p)
n_cols = 3
n_rows = max(1, (n_stats + n_cols - 1) // n_cols)

# squeeze=False always returns a 2-D array of Axes, so ravel() yields exactly
# n_rows * n_cols panels. (The old `axes.flatten() if n_stats > 1 else [axes]`
# wrapped the whole 1x3 array in a list when n_stats == 1, so axes[0] was an
# ndarray instead of an Axes.)
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 4 * n_rows), squeeze=False)
axes = axes.ravel()

for ax, stat_name in zip(axes, stats_5p):
    # Load the emulator and predict mean and standard deviation
    emu = load_emulator(stat_name)
    mean, std = emu.predict(params_5p)

    # x-axis grid and axis/scale metadata for this statistic
    x_grid, _ = get_x_grid(stat_name)
    plot_info = get_plot_info(stat_name)

    ax.plot(x_grid, mean, 'r-', lw=2, label='Mean prediction')
    ax.fill_between(x_grid, mean - 2*std, mean + 2*std,
                    alpha=0.3, color='blue', label='95% CI (~2σ)')
    ax.set_xscale(plot_info['xscale'])
    ax.set_yscale(plot_info['yscale'])
    ax.set_xlabel(plot_info['xlabel'], fontsize='x-large')
    ax.set_ylabel(plot_info['ylabel'], fontsize='x-large')
    ax.set_title(plot_info['title'], fontsize='x-large')
    ax.legend(fontsize='large', frameon=True, framealpha=0.7, loc='lower center')

# Hide grid panels left over after the last statistic
for ax in axes[n_stats:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

print(f"\nShowing predictions for {n_stats} summary statistics with 5 parameters")
print(f"Parameters used: {params_5p}")

## 4. All 2-Parameter Models

Display predictions for all available 2-parameter summary statistics.

In [None]:
# Define 2-parameter inputs: [v_kin/1e4, eps/1e1]
params_2p = np.array([0.5, 0.1])

# Get all 2-parameter statistics
stats_2p = list_available_statistics()['2-parameter']

# One panel per statistic in a single row. At least one column is requested so
# plt.subplots never receives ncols=0. squeeze=False keeps `axes` a 2-D array
# even for a single statistic, so ravel() always gives a flat array of Axes.
n_stats = len(stats_2p)
n_panels = max(1, n_stats)
fig, axes = plt.subplots(1, n_panels, figsize=(7 * n_panels, 5), squeeze=False)
axes = axes.ravel()

for ax, stat_name in zip(axes, stats_2p):
    # Load the emulator and predict mean and standard deviation
    emu = load_emulator(stat_name)
    mean, std = emu.predict(params_2p)

    # x-axis grid and axis/scale metadata for this statistic
    x_grid, _ = get_x_grid(stat_name)
    plot_info = get_plot_info(stat_name)

    ax.plot(x_grid, mean, 'r-', lw=2, label='Mean prediction')
    ax.fill_between(x_grid, mean - 2*std, mean + 2*std,
                    alpha=0.3, color='blue', label='95% CI (~2σ)')
    ax.set_xscale(plot_info['xscale'])
    ax.set_yscale(plot_info['yscale'])
    ax.set_xlabel(plot_info['xlabel'], fontsize='x-large')
    ax.set_ylabel(plot_info['ylabel'], fontsize='x-large')
    ax.set_title(plot_info['title'], fontsize='x-large')
    # Legend font size matches the other figures in this notebook
    ax.legend(fontsize='large', frameon=True, framealpha=0.7)

# Hide the placeholder panel if there were no statistics to plot
for ax in axes[n_stats:]:
    ax.axis('off')

plt.tight_layout()
plt.show()

print(f"\nShowing predictions for {n_stats} summary statistics with 2 parameters")
print(f"Parameters used: {params_2p}")

## 5. Parameter Variation Study

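Study how the predicted cluster gas density profile (`CGD`) changes as the kinetic feedback efficiency $\epsilon_\mathrm{kin}$ is varied, with all other parameters held fixed.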

In [None]:
# Load emulator for gas density profile
emu_cgd = load_emulator('CGD')

# Base parameters: [kappa_w, e_w, M_seed/1e6, v_kin/1e4, eps/1e1]
base_params = np.array([3.0, 0.5, 0.8, 0.65, 0.1])

# Position of eps/1e1 (epsilon_kin) in the 5-parameter vector: the last entry
EPS_INDEX = 4

# Values of eps/1e1 to scan
eps_values = np.linspace(0.05, 0.5, 5)

# x-axis grid and axis/scale metadata for CGD
x_grid, _ = get_x_grid('CGD')
plot_info = get_plot_info('CGD')

fig, ax = plt.subplots(figsize=(7, 5))

for eps in eps_values:
    # Copy so base_params itself stays unchanged across iterations
    params = base_params.copy()
    params[EPS_INDEX] = eps

    mean, _ = emu_cgd.predict(params)

    # \mathrm is used rather than \text: it is supported by every matplotlib
    # mathtext version, whereas \text is not available in older releases.
    ax.plot(x_grid, mean, lw=2, label=r'$\epsilon_\mathrm{kin}/10^{1}$ = ' + f'{eps:.2f}')

ax.set_xscale(plot_info['xscale'])
ax.set_yscale(plot_info['yscale'])
ax.set_xlabel(plot_info['xlabel'], fontsize='x-large')
ax.set_ylabel(plot_info['ylabel'], fontsize='x-large')
ax.set_title('Cluster Gas Density profile vs Kinetic Feedback Efficiency', fontsize=14)
ax.legend(fontsize='large', frameon=True, framealpha=0.7)
plt.tight_layout()
plt.show()

## 6. Batch Predictions

Make predictions for several randomly sampled parameter sets, calling the emulator once per set, and overlay the resulting mean curves.

In [None]:
# Load emulator
emu = load_emulator('GSMF')

# Sampling box for the inputs [kappa_w, e_w, M_seed/1e6, v_kin/1e4, eps/1e1]
param_low = np.array([2.0, 0.2, 0.6, 0.1, 0.02])
param_high = np.array([4.0, 1.0, 1.2, 1.2, 0.5])

# A local, seeded Generator keeps the draw reproducible without reseeding
# NumPy's global random state (which other code might rely on).
rng = np.random.default_rng(42)
n_samples = 8
# One row per sample; the column count follows the bound arrays
params_batch = rng.uniform(low=param_low, high=param_high,
                           size=(n_samples, param_low.size))

print("Parameter samples:")
print(params_batch)
print()

# x-axis grid and axis/scale metadata for GSMF
x_grid, _ = get_x_grid('GSMF')
plot_info = get_plot_info('GSMF')

fig, ax = plt.subplots(figsize=(7, 5))

# One emulator call per parameter set
for i, params in enumerate(params_batch):
    mean, _ = emu.predict(params)
    ax.plot(x_grid, mean, lw=1.5, alpha=0.7, label=f'Sample {i+1}')

ax.set_xscale(plot_info['xscale'])
ax.set_yscale(plot_info['yscale'])
ax.set_xlabel(plot_info['xlabel'], fontsize='x-large')
ax.set_ylabel(plot_info['ylabel'], fontsize='x-large')
ax.set_title('GSMF for Different Parameter Sets', fontsize=14)
ax.legend(fontsize='large', frameon=True, framealpha=0.7, ncols=2)
plt.tight_layout()
plt.show()