# Milky Way Disc Age-Metallicity Explorer

This notebook loads trained normalizing flow models and generates visualizations of the age-metallicity relationship across different galactic radial bins.

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import torch.nn as nn
import torch.nn.functional as F
from scipy.stats import gaussian_kde
from IPython.display import display, HTML
import ipywidgets as widgets
import warnings

# Ignore sklearn warnings about unpickling from different versions
warnings.filterwarnings("ignore", category=UserWarning)

# Need these imports to reconstruct flow model
from nflows.distributions.normal import StandardNormal
from nflows.flows.base import Flow
from nflows.transforms.base import CompositeTransform
from nflows.transforms.permutations import ReversePermutation
from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform

## Recreate the Flow5D Model Class

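This must stay an exact copy of the `Flow5D` class in `flow_model.py`. The layer structure determines the parameter names and shapes in the saved state dict, so any divergence prevents the trained checkpoints from loading.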

In [None]:
class Flow5D(nn.Module):
    """5D normalizing flow for complete Galactic analysis.

    Jointly models [age, [Fe/H], [Mg/Fe], sqrt(Jz), Lz].
    NOTE(review): sample_flow_model in this notebook treats dimension 0 as
    log10(age) after the scaler's inverse transform, so "age" here is
    presumably log-age - confirm against flow_model.py.

    The architecture is a 5-feature StandardNormal base distribution followed by
    ``n_transforms`` repetitions of (ReversePermutation -> masked autoregressive
    rational-quadratic spline transform), all wrapped in a single nflows ``Flow``.

    The module structure (and therefore the state-dict keys) must match the
    Flow5D used to train the saved checkpoints; change it only in lockstep with
    flow_model.py.
    """
    
    def __init__(
        self,
        n_transforms=16,
        hidden_dims=None,
        num_bins=32,
        tail_bound=5.0,
        use_residual_blocks=True,
        dropout_probability=0.1,
    ):
        """Build the flow.

        Args:
            n_transforms: number of (ReversePermutation, spline transform) pairs.
            hidden_dims: list of hidden widths; only ``hidden_dims[0]`` is used, as
                the hidden width of every autoregressive network. Defaults to
                [256, 256].
            num_bins: number of spline bins per transform.
            tail_bound: bound passed to the spline transforms; per nflows, inputs
                outside [-tail_bound, tail_bound] go through the linear tails.
            use_residual_blocks: forwarded to each autoregressive network.
            dropout_probability: dropout used inside each autoregressive network
                (inactive once the model is put in eval mode).
        """
        super().__init__()
        if hidden_dims is None:
            hidden_dims = [256, 256]
        
        # Base distribution (5D standard normal)
        base_dist = StandardNormal(shape=[5])
        
        # Build a sequence of transforms; the append order defines the
        # submodule indices stored in the checkpoint state dict.
        transforms = []
        for i in range(n_transforms):
            # Reverse the feature order before each autoregressive layer so every
            # feature gets conditioned on the others across layers
            transforms.append(ReversePermutation(features=5))
            
            # Use masked autoregressive transform with rational quadratic splines
            transforms.append(
                MaskedPiecewiseRationalQuadraticAutoregressiveTransform(
                    features=5,
                    hidden_features=hidden_dims[0],
                    context_features=None,
                    num_bins=num_bins,
                    tails="linear",
                    tail_bound=tail_bound,
                    num_blocks=4,  # must match the trained checkpoints (earlier versions used 2)
                    use_residual_blocks=use_residual_blocks,
                    random_mask=False,
                    activation=F.relu,
                    dropout_probability=dropout_probability,
                    use_batch_norm=True,
                )
            )
        
        # Create the flow model
        self.flow = Flow(
            transform=CompositeTransform(transforms), distribution=base_dist
        )
    
    def log_prob(self, x):
        """Return the flow's log-density of ``x`` (one value per row).

        NOTE(review): ``x`` is presumably in the scaled feature space the
        checkpoint's scaler produces - confirm against the training code.
        """
        return self.flow.log_prob(x)
    
    def sample(self, n):
        """Draw ``n`` samples from the flow as a tensor with 5 features per row.

        Samples are in the model's (scaled) feature space; sample_flow_model
        applies the scaler's ``inverse_transform`` to recover physical units.
        """
        return self.flow.sample(n)

## Create Model Loading Functions

Let's define functions to load the trained models and their scalers.

In [None]:
def load_model_from_checkpoint(checkpoint_path, device=None):
    """Load a trained Flow5D model and its feature scaler from a checkpoint file.

    Args:
        checkpoint_path: path to a ``torch.save``-d checkpoint dict.
        device: device to map tensors and the model onto. Defaults to CUDA when
            available, otherwise CPU.

    Returns:
        (model, scaler, checkpoint): the Flow5D in eval mode on ``device``, the
        object stored under ``checkpoint["scaler"]``, and the raw checkpoint dict.

    Raises:
        ValueError: if the checkpoint has neither "flow_state" nor "model_state",
            or has no "scaler".
        RuntimeError: raised by ``load_state_dict`` when the stored weights do
            not match the rebuilt architecture.

    Security: the checkpoint is unpickled with ``weights_only=False`` so that the
    pickled scaler can be restored. Unpickling can execute arbitrary code, so
    only load trusted checkpoint files.
    """
    import inspect

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle the
    # scaler object; torch < 1.13 does not accept the argument at all.
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    checkpoint = torch.load(checkpoint_path, **load_kwargs)

    # Get model configuration.
    # NOTE(review): these fallbacks differ from Flow5D's own defaults
    # (16, [256, 256], 32); presumably they match checkpoints saved without a
    # config - confirm.
    model_config = checkpoint.get("model_config", {})
    n_transforms = model_config.get("n_transforms", 4)
    hidden_dims = model_config.get("hidden_dims", [128, 128])
    num_bins = model_config.get("num_bins", 24)
    # Use stored values when the config provides them; the defaults are the
    # values this loader always used before.
    tail_bound = model_config.get("tail_bound", 5.0)
    use_residual_blocks = model_config.get("use_residual_blocks", True)
    dropout_probability = model_config.get("dropout_probability", 0.1)

    print(f"Creating model with {n_transforms} transforms, hidden_dims={hidden_dims}, num_bins={num_bins}")

    model = Flow5D(
        n_transforms=n_transforms,
        hidden_dims=hidden_dims,
        num_bins=num_bins,
        tail_bound=tail_bound,
        use_residual_blocks=use_residual_blocks,
        dropout_probability=dropout_probability,
    ).to(device)

    # Different training scripts stored the weights under different keys.
    if "flow_state" in checkpoint:
        model.load_state_dict(checkpoint["flow_state"])
    elif "model_state" in checkpoint:
        model.load_state_dict(checkpoint["model_state"])
    else:
        raise ValueError("Checkpoint doesn't contain flow_state or model_state")

    # Disable dropout and use running batch-norm statistics for inference.
    model.eval()

    scaler = checkpoint.get("scaler", None)
    if scaler is None:
        raise ValueError("Checkpoint doesn't contain scaler")

    return model, scaler, checkpoint


def create_backup_model(n_transforms=8, hidden_dims=None, num_bins=24, device=None):
    """Return a freshly constructed Flow5D in eval mode, with no checkpoint weights loaded.

    load_all_models uses this as a stand-in when a checkpoint cannot be loaded.
    Because no trained weights are loaded, samples from the returned model do
    not reflect any training data.

    Args:
        n_transforms: number of permutation/spline pairs in the flow.
        hidden_dims: hidden widths for the flow; ``None`` means [128, 128].
        num_bins: number of spline bins.
        device: target device; ``None`` picks CUDA when available, else CPU.
    """
    target_device = (
        device
        if device is not None
        else torch.device("cuda" if torch.cuda.is_available() else "cpu")
    )
    widths = hidden_dims if hidden_dims is not None else [128, 128]

    backup = Flow5D(n_transforms=n_transforms, hidden_dims=widths, num_bins=num_bins)
    backup = backup.to(target_device)
    backup.eval()
    return backup

## Create Direct Inference Function

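Now let's create a function that samples from a trained model. It converts the samples back to physical units with the stored scaler and keeps only those that fall within the requested age and [Fe/H] ranges.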

In [None]:
def sample_flow_model(model, scaler, n_samples=5000, age_range=(0, 14), feh_range=(-1.5, 0.5),
                      bin_name=None, max_rounds=10):
    """Sample a flow model and keep only samples inside the age/[Fe/H] window.

    The flow is sampled in rounds of ``int(1.5 * n_samples)`` points until at
    least ``n_samples`` samples pass the filter or ``max_rounds`` rounds have
    run. Narrow ranges therefore no longer return far fewer points than
    requested after a single batch. If anything raises, synthetic data from
    generate_synthetic_data is returned instead (deliberate best-effort fallback).

    Args:
        model: flow exposing ``eval()`` and ``sample(n)`` returning a torch tensor.
        scaler: object whose ``inverse_transform`` maps model samples back to
            physical units. NOTE(review): presumably the fitted scaler stored in
            the checkpoint.
        n_samples: maximum number of samples to return.
        age_range: inclusive (min, max) age window in Gyr.
        feh_range: inclusive (min, max) [Fe/H] window.
        bin_name: optional radial-bin name, forwarded to the synthetic fallback
            so it uses that bin's parameters.
        max_rounds: upper bound on the number of sampling rounds.

    Returns:
        (ages, fehs, mgfes, success): 1-D arrays with at most ``n_samples``
        entries (fewer if the filter rejects most draws), and ``success`` which
        is True when the data came from the model and False for synthetic data.
    """
    model.eval()
    # Over-sample each round because the range filter rejects part of every batch.
    buffer_factor = 1.5
    batch_size = max(int(n_samples * buffer_factor), 1)

    try:
        kept_ages, kept_fehs, kept_mgfes = [], [], []
        n_kept = 0
        for _ in range(max_rounds):
            if n_kept >= n_samples:
                break
            with torch.no_grad():
                samples = model.sample(batch_size).cpu().numpy()

            # Undo the feature scaling applied at training time.
            samples_original = scaler.inverse_transform(samples)

            # Column order: 0 = log10(age), 1 = [Fe/H], 2 = [Mg/Fe].
            ages = 10 ** samples_original[:, 0]
            fehs = samples_original[:, 1]
            mgfes = samples_original[:, 2]

            mask = (
                (ages >= age_range[0])
                & (ages <= age_range[1])
                & (fehs >= feh_range[0])
                & (fehs <= feh_range[1])
            )
            kept_ages.append(ages[mask])
            kept_fehs.append(fehs[mask])
            kept_mgfes.append(mgfes[mask])
            n_kept += int(np.count_nonzero(mask))

        if kept_ages:
            ages_filtered = np.concatenate(kept_ages)
            fehs_filtered = np.concatenate(kept_fehs)
            mgfes_filtered = np.concatenate(kept_mgfes)
        else:
            ages_filtered = np.empty(0)
            fehs_filtered = np.empty(0)
            mgfes_filtered = np.empty(0)

        # Trim any surplus with a random subset so no round is over-represented.
        if len(ages_filtered) > n_samples:
            indices = np.random.choice(len(ages_filtered), n_samples, replace=False)
            ages_filtered = ages_filtered[indices]
            fehs_filtered = fehs_filtered[indices]
            mgfes_filtered = mgfes_filtered[indices]

        success = True

    except Exception as e:  # best-effort: keep the notebook usable if sampling fails
        print(f"Error sampling from model: {e}\nUsing synthetic data instead.")
        ages_filtered, fehs_filtered, mgfes_filtered = generate_synthetic_data(
            bin_name=bin_name,
            n_samples=n_samples,
            age_range=age_range,
            feh_range=feh_range
        )
        success = False

    return ages_filtered, fehs_filtered, mgfes_filtered, success

## Backup: Synthetic Data Generator

As a backup in case model sampling fails, we generate synthetic data from a hand-tuned, two-population toy model that mimics known Galactic trends. It is illustrative only and is not fitted to any data.

In [None]:
def generate_synthetic_data(bin_name=None, n_samples=5000, age_range=(0, 14), feh_range=(-1.5, 0.5)):
    """Generate a toy two-population stellar sample for a galactic radial bin.

    This is the fallback when a flow model cannot be sampled. Population
    parameters are hand-tuned per bin, matched by substring: "0.0-6.0",
    "6.0-8.0", "8.0-10.0"; any other name is treated as the outer disc, and
    ``None`` uses the inner-disc setup. On top of the two populations the
    function adds simple age-[Fe/H], age-[Mg/Fe] and [Fe/H]-[Mg/Fe] trends, plus
    a +0.2 [Mg/Fe] offset for a random 30% of stars. The values are
    illustrative, not fitted to data.

    The output is deterministic for a given ``bin_name`` and ``n_samples``. A
    private RNG is seeded from the bin name (42 when no name is given), so the
    notebook-wide ``np.random`` state is left untouched.

    Args:
        bin_name: radial-bin name, e.g. "R8.0-10.0", or None.
        n_samples: number of stars drawn before range filtering.
        age_range: inclusive (min, max) age window in Gyr.
        feh_range: inclusive (min, max) [Fe/H] window.

    Returns:
        (ages, fehs, mgfes): equal-length 1-D arrays (at most ``n_samples``
        entries) holding the stars inside both windows. Ages lie in
        [0.1, 14] Gyr, [Fe/H] in [-1.5, 0.5] and [Mg/Fe] in [-0.2, 0.5].
    """
    if n_samples <= 0:
        return np.empty(0), np.empty(0), np.empty(0)

    # RandomState(seed) reproduces exactly the stream np.random.seed(seed) would
    # give, but without resetting the global RNG used elsewhere in the notebook.
    seed_val = 42 if bin_name is None else sum(ord(c) for c in bin_name)
    rng = np.random.RandomState(seed_val)

    # Adjust parameters based on the galactic radial bin
    # Inner disc: older stars, lower metallicity spread
    # Outer disc: younger stars on average, wider metallicity spread
    if bin_name is None or "0.0-6.0" in bin_name:  # Inner disc
        age_mean, age_std = 10.0, 3.0
        feh_mean, feh_std = -0.3, 0.3
        mgfe_mean, mgfe_std = 0.25, 0.1
        # Second population (older, metal-poor)
        age_mean2, age_std2 = 12.5, 1.5
        feh_mean2, feh_std2 = -0.8, 0.2
        mgfe_mean2, mgfe_std2 = 0.35, 0.08
        mix_ratio = 0.7  # 70% from first distribution, 30% from second

    elif "6.0-8.0" in bin_name:  # Inner-middle disc
        age_mean, age_std = 8.0, 3.5
        feh_mean, feh_std = -0.2, 0.3
        mgfe_mean, mgfe_std = 0.2, 0.12
        age_mean2, age_std2 = 11.0, 2.0
        feh_mean2, feh_std2 = -0.6, 0.25
        mgfe_mean2, mgfe_std2 = 0.3, 0.1
        mix_ratio = 0.75

    elif "8.0-10.0" in bin_name:  # Solar neighborhood
        age_mean, age_std = 7.0, 4.0
        feh_mean, feh_std = -0.1, 0.3
        mgfe_mean, mgfe_std = 0.15, 0.15
        age_mean2, age_std2 = 10.5, 2.5
        feh_mean2, feh_std2 = -0.5, 0.3
        mgfe_mean2, mgfe_std2 = 0.25, 0.12
        mix_ratio = 0.8

    else:  # Outer disc
        age_mean, age_std = 5.5, 4.0
        feh_mean, feh_std = 0.0, 0.25
        mgfe_mean, mgfe_std = 0.1, 0.15
        age_mean2, age_std2 = 9.0, 3.0
        feh_mean2, feh_std2 = -0.4, 0.3
        mgfe_mean2, mgfe_std2 = 0.2, 0.15
        mix_ratio = 0.85

    n1 = int(n_samples * mix_ratio)
    n2 = n_samples - n1

    # Draw order is kept fixed so a given seed always produces the same sample.
    raw_ages1 = rng.normal(age_mean, age_std, n1)
    fehs1 = rng.normal(feh_mean, feh_std, n1)
    mgfes1 = rng.normal(mgfe_mean, mgfe_std, n1)

    raw_ages2 = rng.normal(age_mean2, age_std2, n2)
    fehs2 = rng.normal(feh_mean2, feh_std2, n2)
    mgfes2 = rng.normal(mgfe_mean2, mgfe_std2, n2)

    # Clip ages to 0.1-14 Gyr in linear space. Taking log10 of a non-positive
    # draw yields NaN, and a single NaN would make the min/max normalisation
    # below NaN for every star, emptying the output.
    ages = np.clip(np.concatenate([raw_ages1, raw_ages2]), 0.1, 14.0)
    fehs = np.concatenate([fehs1, fehs2])
    mgfes = np.concatenate([mgfes1, mgfes2])

    def _unit_interval(values):
        """Linearly rescale values onto [0, 1]; a constant array maps to zeros."""
        lo, hi = np.min(values), np.max(values)
        if hi == lo:
            return np.zeros_like(values)
        return (values - lo) / (hi - lo)

    # Age-metallicity correlation (older stars tend to be more metal-poor)
    corr_strength = 0.5
    age_norm = _unit_interval(ages)
    corr_factor = corr_strength * (1 - age_norm)
    fehs += corr_factor * 0.5

    # Age-alpha correlation (older stars tend to be alpha-enhanced)
    alpha_corr = 0.4
    alpha_factor = alpha_corr * age_norm
    mgfes += alpha_factor * 0.3

    # Metallicity-alpha correlation (metal-poor stars tend to be alpha-enhanced)
    feh_norm = _unit_interval(fehs)
    mgfes -= 0.3 * feh_norm

    # Offset a random 30% of stars in [Mg/Fe] to mimic a thin/thick disc split
    bimodal_mask = rng.choice([True, False], size=n_samples, p=[0.3, 0.7])
    mgfes[bimodal_mask] += 0.2

    fehs = np.clip(fehs, -1.5, 0.5)
    mgfes = np.clip(mgfes, -0.2, 0.5)

    mask = (
        (ages >= age_range[0])
        & (ages <= age_range[1])
        & (fehs >= feh_range[0])
        & (fehs <= feh_range[1])
    )

    return ages[mask], fehs[mask], mgfes[mask]

## Load Models

Now let's load the trained models from `outputs/models`. Each radial bin is expected in a file named `<bin>_model.pt` (e.g. `R8.0-10.0_model.pt`).

In [None]:
def load_all_models(models_dir="outputs/models"):
    """Load the trained flow model and scaler for every known radial bin.

    Looks in ``models_dir`` for files named ``<bin>_model.pt`` for the bins
    R0.0-6.0, R6.0-8.0, R8.0-10.0 and R10.0-15.0; other files are ignored.

    If a checkpoint fails to load (or its test sample fails), an untrained
    backup model is registered for that bin, provided the checkpoint still
    yields a scaler. If nothing loads at all, including when ``models_dir`` is
    missing, a dummy untrained model with scaler ``None`` is registered for
    R0.0-6.0. sample_flow_model then fails on it and falls back to synthetic data.

    Security: checkpoints are unpickled with ``weights_only=False``, which can
    execute arbitrary code, so only use trusted model files.

    Returns:
        (flow_models, scalers): dicts keyed by bin name, in radial order.
    """
    import inspect

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    flow_models = {}
    scalers = {}

    radial_bin_order = ["R0.0-6.0", "R6.0-8.0", "R8.0-10.0", "R10.0-15.0"]
    radial_bins_set = set(radial_bin_order)

    # torch >= 2.6 defaults to weights_only=True, which cannot unpickle the
    # stored scaler; torch < 1.13 does not accept the argument at all.
    load_kwargs = {"map_location": device}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False

    # A missing or unreadable directory falls through to the dummy model below
    # instead of aborting the notebook.
    try:
        filenames = os.listdir(models_dir)
    except OSError as err:
        print(f"Could not read models directory '{models_dir}': {err}")
        filenames = []

    model_files = {}
    for filename in filenames:
        if filename.endswith("_model.pt"):
            bin_name = filename.split("_model.pt")[0]
            if bin_name in radial_bins_set:
                model_files[bin_name] = os.path.join(models_dir, filename)

    for bin_name in radial_bin_order:
        if bin_name not in model_files:
            continue
        model_path = model_files[bin_name]
        print(f"\nLoading model for {bin_name} from {model_path}")

        try:
            model, scaler, _ = load_model_from_checkpoint(model_path, device)
            flow_models[bin_name] = model
            scalers[bin_name] = scaler
            print(f"Successfully loaded model for {bin_name}")

            # Smoke-test sampling so a broken model is caught at load time.
            print("Testing model sampling...")
            with torch.no_grad():
                samples = model.sample(10).cpu().numpy()
            print(f"✓ Sampling successful - got {samples.shape} samples")

        except Exception as e:
            print(f"Error loading model: {e}")
            print("Creating a backup model for this bin")

            # Re-read the checkpoint only for its scaler; this can fail too
            # (e.g. a corrupt file) and must not abort loading the other bins.
            try:
                checkpoint = torch.load(model_path, **load_kwargs)
            except Exception as load_err:
                print(f"Failed to load backup - could not read checkpoint: {load_err}")
                continue

            if isinstance(checkpoint, dict) and "scaler" in checkpoint:
                flow_models[bin_name] = create_backup_model(device=device)
                scalers[bin_name] = checkpoint["scaler"]
                print(f"Loaded backup model and scaler for {bin_name}")
                print("Warning: the backup model is untrained, so its samples do not "
                      "reflect the trained model for this bin.")
            else:
                print("Failed to load backup - no scaler in checkpoint")

    # If we still have no models, register a dummy for R0.0-6.0. Its scaler is
    # None, so sampling from it falls back to synthetic data.
    if not flow_models:
        bin_name = "R0.0-6.0"
        print(f"\nNo models loaded. Creating a dummy model for {bin_name}")
        flow_models[bin_name] = create_backup_model(device=device)
        scalers[bin_name] = None

    print(f"\nLoaded {len(flow_models)} models: {list(flow_models.keys())}")
    return flow_models, scalers


# Load the models
flow_models, scalers = load_all_models()

## Create Visualization Functions

Now let's define functions to visualize the data.

In [None]:
def plot_age_metallicity_kde(ages, metallicities, bin_name=None, flip_age_axis=True,
                             age_range=(0, 14), feh_range=(-1.5, 0.5), figsize=(10, 8),
                             title_suffix=""):
    """Plot a filled Gaussian-KDE density map of stellar age against [Fe/H].

    A ``scipy.stats.gaussian_kde`` is fitted to the (age, [Fe/H]) pairs and
    evaluated on a 100 x 100 grid covering ``age_range`` x ``feh_range``. The
    raw samples are drawn on top as faint black dots. gaussian_kde raises on
    degenerate input (e.g. too few points), so callers should check the sample
    count first.

    Args:
        ages: sample ages in Gyr.
        metallicities: sample [Fe/H] values, same length as ``ages``.
        bin_name: radial-bin label included in the title when truthy.
        flip_age_axis: invert the age axis so older ages appear on the left.
        age_range, feh_range: axis limits, also the KDE evaluation window.
        figsize: figure size in inches.
        title_suffix: extra text appended to the title.

    Returns:
        (fig, ax) of the new figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    density = gaussian_kde(np.vstack([ages, metallicities]))

    grid_age, grid_feh = np.meshgrid(
        np.linspace(age_range[0], age_range[1], 100),
        np.linspace(feh_range[0], feh_range[1], 100),
    )
    grid_density = density(
        np.vstack([grid_age.ravel(), grid_feh.ravel()])
    ).reshape(grid_age.shape)

    filled = ax.contourf(grid_age, grid_feh, grid_density, levels=20, cmap="viridis", alpha=0.8)
    colorbar = plt.colorbar(filled, ax=ax)
    colorbar.set_label("Density")

    # Tiny, mostly transparent markers reveal individual samples without hiding the contours.
    ax.scatter(ages, metallicities, s=0.5, color="k", alpha=0.1)

    ax.set_xlabel("Age (Gyr)")
    ax.set_ylabel("[Fe/H]")
    title = (
        f"Age-Metallicity Relation - {bin_name} {title_suffix}"
        if bin_name
        else f"Age-Metallicity Relation {title_suffix}"
    )
    ax.set_title(title)

    ax.set_xlim(age_range)
    ax.set_ylim(feh_range)
    if flip_age_axis:
        ax.invert_xaxis()  # larger (older) ages end up on the left

    ax.grid(True, linestyle="--", alpha=0.5)

    plt.tight_layout()
    return fig, ax


def plot_age_metallicity_heatmap(ages, metallicities, bin_name=None, flip_age_axis=True,
                                age_range=(0, 14), feh_range=(-1.5, 0.5), nbins=(100, 100),
                                figsize=(10, 8), title_suffix=""):
    """Plot a 2D-histogram heatmap of stellar age against [Fe/H].

    Counts are binned with ``np.histogram2d`` over ``age_range`` x ``feh_range``
    and shown as ``log(1 + count)``. This keeps empty cells at zero while
    compressing the dynamic range of dense cells.

    Args:
        ages: sample ages in Gyr.
        metallicities: sample [Fe/H] values, same length as ``ages``.
        bin_name: radial-bin label included in the title when truthy.
        flip_age_axis: invert the age axis so older ages appear on the left.
        age_range, feh_range: histogram range and axis limits.
        nbins: number of (age, [Fe/H]) histogram bins.
        figsize: figure size in inches.
        title_suffix: extra text appended to the title.

    Returns:
        (fig, ax) of the new figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    counts, _age_edges, _feh_edges = np.histogram2d(
        ages, metallicities, bins=nbins, range=[age_range, feh_range]
    )
    # histogram2d indexes counts as [age, feh]; imshow expects rows along y, hence the transpose.
    log_counts = np.log1p(counts.T)

    image = ax.imshow(
        log_counts,
        origin="lower",
        aspect="auto",
        extent=[age_range[0], age_range[1], feh_range[0], feh_range[1]],
        cmap="viridis",
        interpolation="nearest",
    )
    colorbar = plt.colorbar(image, ax=ax)
    colorbar.set_label("log(1 + count)")

    ax.set_xlabel("Age (Gyr)")
    ax.set_ylabel("[Fe/H]")
    title = (
        f"Age-Metallicity Relation - {bin_name} {title_suffix}"
        if bin_name
        else f"Age-Metallicity Relation {title_suffix}"
    )
    ax.set_title(title)

    ax.set_xlim(age_range)
    ax.set_ylim(feh_range)
    if flip_age_axis:
        ax.invert_xaxis()  # larger (older) ages end up on the left

    ax.grid(True, linestyle="--", alpha=0.5)

    plt.tight_layout()
    return fig, ax


def plot_mgfe_feh_kde(fehs, mgfes, bin_name=None, feh_range=(-1.5, 0.5),
                      mgfe_range=(-0.2, 0.5), figsize=(10, 8), title_suffix=""):
    """Plot a filled Gaussian-KDE density map of [Mg/Fe] against [Fe/H].

    A ``scipy.stats.gaussian_kde`` is fitted to the ([Fe/H], [Mg/Fe]) pairs and
    evaluated on a 100 x 100 grid over ``feh_range`` x ``mgfe_range``. The raw
    samples are drawn on top as faint black dots.

    Args:
        fehs: sample [Fe/H] values.
        mgfes: sample [Mg/Fe] values, same length as ``fehs``.
        bin_name: radial-bin label included in the title when truthy.
        feh_range, mgfe_range: axis limits, also the KDE evaluation window.
        figsize: figure size in inches.
        title_suffix: extra text appended to the title.

    Returns:
        (fig, ax) of the new figure.
    """
    fig, ax = plt.subplots(figsize=figsize)

    density = gaussian_kde(np.vstack([fehs, mgfes]))

    grid_feh, grid_mgfe = np.meshgrid(
        np.linspace(feh_range[0], feh_range[1], 100),
        np.linspace(mgfe_range[0], mgfe_range[1], 100),
    )
    grid_density = density(
        np.vstack([grid_feh.ravel(), grid_mgfe.ravel()])
    ).reshape(grid_feh.shape)

    filled = ax.contourf(grid_feh, grid_mgfe, grid_density, levels=20, cmap="plasma", alpha=0.8)
    colorbar = plt.colorbar(filled, ax=ax)
    colorbar.set_label("Density")

    # Tiny, mostly transparent markers reveal individual samples without hiding the contours.
    ax.scatter(fehs, mgfes, s=0.5, color="k", alpha=0.1)

    ax.set_xlabel("[Fe/H]")
    ax.set_ylabel("[Mg/Fe]")
    title = (
        f"[Mg/Fe] vs. [Fe/H] Relation - {bin_name} {title_suffix}"
        if bin_name
        else f"[Mg/Fe] vs. [Fe/H] Relation {title_suffix}"
    )
    ax.set_title(title)

    ax.set_xlim(feh_range)
    ax.set_ylim(mgfe_range)

    ax.grid(True, linestyle="--", alpha=0.5)

    plt.tight_layout()
    return fig, ax

## Test: Sample from a Model

Let's test sampling from one of the models and visualize the results.

In [None]:
# Smoke test: sample the first available bin and plot both 2D projections.
if flow_models:
    bin_name = next(iter(flow_models))  # first bin in load order
    model = flow_models[bin_name]
    scaler = scalers[bin_name]

    print(f"Testing sampling from {bin_name} model...")
    ages, fehs, mgfes, success = sample_flow_model(
        model, scaler, n_samples=5000,
        age_range=(0, 14), feh_range=(-1.5, 0.5)
    )

    title_suffix = "(Direct Model Inference)" if success else "(Synthetic Data)"

    # gaussian_kde fails on too few points; use the same threshold as the explorers below.
    if len(ages) < 10:
        print(f"Warning: only {len(ages)} samples in range for {bin_name}; skipping test plots.")
    else:
        fig, ax = plot_age_metallicity_kde(
            ages, fehs, bin_name=bin_name,
            age_range=(0, 14), feh_range=(-1.5, 0.5),
            title_suffix=title_suffix
        )
        plt.show()

        fig, ax = plot_mgfe_feh_kde(
            fehs, mgfes, bin_name=bin_name,
            title_suffix=title_suffix
        )
        plt.show()
else:
    print("No models available for testing.")

## Interactive Visualization Interface

Create an interactive widget to explore the models and visualize the age-metallicity relationship.

In [None]:
def create_interactive_explorer():
    """Display an interactive Age-[Fe/H] explorer for the loaded flow models.

    The widgets select the radial bin, the plot type (KDE or 2D histogram), the
    sample count, the age and [Fe/H] windows, and whether the age axis is
    flipped. Clicking "Update Plot" (and the initial call at the end) samples
    the chosen model with sample_flow_model and redraws inside an Output widget.

    Reads the module-level ``flow_models`` and ``scalers`` dicts; returns None.
    """
    if not flow_models:
        print("No models available for visualization.")
        return

    bin_dropdown = widgets.Dropdown(
        options=list(flow_models.keys()),
        description='Radial Bin:',
        disabled=False,
    )

    plot_type = widgets.RadioButtons(
        options=['KDE', 'Heatmap'],
        description='Plot Type:',
        disabled=False
    )

    n_samples_slider = widgets.IntSlider(
        value=5000,
        min=1000,
        max=20000,
        step=1000,
        description='Samples:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )

    min_age = widgets.FloatSlider(
        value=0,
        min=0,
        max=15,
        step=0.5,
        description='Min Age:',
        disabled=False,
        continuous_update=False
    )

    max_age = widgets.FloatSlider(
        value=14,
        min=5,
        max=20,
        step=0.5,
        description='Max Age:',
        disabled=False,
        continuous_update=False
    )

    min_feh = widgets.FloatSlider(
        value=-1.5,
        min=-2.0,
        max=0,
        step=0.1,
        description='Min [Fe/H]:',
        disabled=False,
        continuous_update=False
    )

    max_feh = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.1,
        description='Max [Fe/H]:',
        disabled=False,
        continuous_update=False
    )

    flip_age = widgets.Checkbox(
        value=True,
        description='Flip Age Axis (Oldest Left)',
        disabled=False
    )

    update_button = widgets.Button(
        description='Update Plot',
        disabled=False,
        button_style='',
        tooltip='Click to update the plot'
    )

    output = widgets.Output()

    def update_plot(_button):
        """Resample the selected bin and redraw; ``_button`` is the clicked Button or None."""
        bin_name = bin_dropdown.value
        n_samples = n_samples_slider.value
        age_range = (min_age.value, max_age.value)
        feh_range = (min_feh.value, max_feh.value)
        flip_age_axis = flip_age.value

        output.clear_output(wait=True)
        with output:
            # The slider ranges overlap (e.g. Min Age can exceed Max Age), so reject
            # empty or inverted windows with a clear message.
            if age_range[0] >= age_range[1] or feh_range[0] >= feh_range[1]:
                print("Warning: each minimum must be smaller than its maximum. Adjust the range sliders.")
                return

            # Sample inside the Output context so messages printed by sample_flow_model
            # (e.g. the synthetic-data fallback notice) stay visible: print output from
            # a button callback is otherwise not shown in the notebook.
            ages, fehs, _mgfes, success = sample_flow_model(
                flow_models[bin_name], scalers[bin_name], n_samples=n_samples,
                age_range=age_range, feh_range=feh_range
            )

            title_suffix = "(Direct Model Inference)" if success else "(Synthetic Data)"

            if len(ages) < 10:
                print("Warning: Not enough samples to plot. Try adjusting your ranges.")
                return

            print(f"Plotting {len(ages)} samples for {bin_name} {'using model inference' if success else 'using synthetic data'}")

            plot_fn = (
                plot_age_metallicity_kde if plot_type.value == 'KDE'
                else plot_age_metallicity_heatmap
            )
            plot_fn(
                ages, fehs, bin_name=bin_name, flip_age_axis=flip_age_axis,
                age_range=age_range, feh_range=feh_range, title_suffix=title_suffix
            )
            plt.show()

    update_button.on_click(update_plot)

    controls1 = widgets.HBox([bin_dropdown, plot_type, n_samples_slider])
    controls2 = widgets.HBox([min_age, max_age, min_feh, max_feh])
    controls3 = widgets.HBox([flip_age, update_button])

    display(widgets.VBox([controls1, controls2, controls3, output]))

    # Draw once with the default settings.
    update_plot(None)

In [None]:
# Show the interactive Age-[Fe/H] explorer only when at least one model is available.
if not flow_models:
    print("No models loaded. Please check the 'outputs/models/' directory.")
else:
    create_interactive_explorer()

## Compare All Radial Bins Side-by-Side

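This function compares the age-metallicity relationship across all loaded radial bins. Each panel's density contours are scaled independently, so compare the shapes of the distributions rather than their absolute density levels.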

In [None]:
def compare_all_bins(n_samples=2000, age_range=(0, 14), feh_range=(-1.5, 0.5), flip_age_axis=True):
    """Plot side-by-side Age-[Fe/H] KDE panels, one per loaded radial bin.

    Each bin is sampled with sample_flow_model. Bins with fewer than 10
    in-range samples, or whose KDE is degenerate, keep an empty labelled panel
    instead of aborting the figure. The suptitle states whether the panels come
    from the models, from synthetic fallback data, or from a mix of both.

    Args:
        n_samples: samples requested per bin.
        age_range, feh_range: window for sampling, KDE evaluation and axes.
        flip_age_axis: put older ages on the left.

    Returns:
        The matplotlib Figure, or None when no models are loaded.
    """
    if not flow_models:
        print("No models available for comparison.")
        return None

    available_bins = list(flow_models.keys())
    n_bins = len(available_bins)
    # squeeze=False always yields a 2-D array of Axes, so one bin needs no special case.
    fig, axes = plt.subplots(
        1, n_bins, figsize=(5 * n_bins, 5), sharex=True, sharey=True, squeeze=False
    )
    axes = axes[0]

    # The KDE evaluation grid is identical for every panel, so build it once.
    xx, yy = np.meshgrid(
        np.linspace(age_range[0], age_range[1], 100),
        np.linspace(feh_range[0], feh_range[1], 100),
    )
    grid_points = np.vstack([xx.ravel(), yy.ravel()])
    # Explicit reversed limits give a deterministic orientation; invert_xaxis()
    # toggles and propagates across the shared x-axes.
    xlim = (age_range[1], age_range[0]) if flip_age_axis else (age_range[0], age_range[1])

    n_model = 0
    n_synthetic = 0
    last_contour = None
    last_ax = None

    for ax, bin_name in zip(axes, available_bins):
        ages, fehs, _, success = sample_flow_model(
            flow_models[bin_name], scalers[bin_name], n_samples=n_samples,
            age_range=age_range, feh_range=feh_range
        )
        if success:
            n_model += 1
        else:
            n_synthetic += 1

        ax.set_title(f"Bin: {bin_name}")
        ax.set_xlabel("Age (Gyr)")
        ax.set_xlim(xlim)
        ax.set_ylim(feh_range)
        ax.grid(True, linestyle="--", alpha=0.5)

        if len(ages) < 10:
            print(f"Warning: Not enough samples for bin {bin_name}. Skipping.")
            ax.text(0.5, 0.5, "Insufficient samples", transform=ax.transAxes,
                    ha="center", va="center")
            continue

        try:
            kde = gaussian_kde(np.vstack([ages, fehs]))
        except np.linalg.LinAlgError as err:
            # gaussian_kde raises this for degenerate (singular-covariance) samples.
            print(f"Warning: KDE failed for bin {bin_name} ({err}). Skipping.")
            ax.text(0.5, 0.5, "KDE failed", transform=ax.transAxes,
                    ha="center", va="center")
            continue

        zz = kde(grid_points).reshape(xx.shape)
        last_contour = ax.contourf(xx, yy, zz, levels=20, cmap="viridis", alpha=0.8)
        last_ax = ax
        ax.scatter(ages, fehs, s=0.3, color="k", alpha=0.1)

    axes[0].set_ylabel("[Fe/H]")

    # Contour levels are chosen per panel, so this colorbar describes only the
    # right-most plotted panel; it is attached to that panel.
    if last_contour is not None:
        cbar = fig.colorbar(last_contour, ax=last_ax)
        cbar.set_label("Density")

    title = "Age-Metallicity Relation Across Radial Bins"
    if n_synthetic == 0:
        title += " (Model Inference)"
    elif n_model == 0:
        title += " (Synthetic Data)"
    else:
        title += " (Model Inference + Synthetic Fallback)"
    fig.suptitle(title, fontsize=16)

    fig.tight_layout()
    fig.subplots_adjust(top=0.88)  # Make room for the suptitle

    return fig

In [None]:
# Side-by-side comparison of every loaded radial bin.
if not flow_models:
    print("No models loaded.")
else:
    fig = compare_all_bins()
    plt.show()

## Explore [Mg/Fe] vs. [Fe/H] Relation

Let's also explore the alpha-element abundance pattern, which is another key diagnostic for Galactic evolution.

In [None]:
def create_mgfe_explorer():
    """Display an interactive [Mg/Fe] vs. [Fe/H] explorer for the loaded flow models.

    The widgets select the radial bin, the sample count, and the [Fe/H] and
    [Mg/Fe] windows. Each update samples the chosen model over ages 0-14 Gyr
    and the selected [Fe/H] window. It then applies the [Mg/Fe] window and draws
    a KDE plot inside an Output widget.

    Reads the module-level ``flow_models`` and ``scalers`` dicts; returns None.
    """
    if not flow_models:
        print("No models available for visualization.")
        return

    bin_dropdown = widgets.Dropdown(
        options=list(flow_models.keys()),
        description='Radial Bin:',
        disabled=False,
    )

    n_samples_slider = widgets.IntSlider(
        value=5000,
        min=1000,
        max=20000,
        step=1000,
        description='Samples:',
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='d'
    )

    min_feh = widgets.FloatSlider(
        value=-1.5,
        min=-2.0,
        max=0,
        step=0.1,
        description='Min [Fe/H]:',
        disabled=False,
        continuous_update=False
    )

    max_feh = widgets.FloatSlider(
        value=0.5,
        min=0,
        max=1.0,
        step=0.1,
        description='Max [Fe/H]:',
        disabled=False,
        continuous_update=False
    )

    min_mgfe = widgets.FloatSlider(
        value=-0.2,
        min=-0.5,
        max=0.0,
        step=0.05,
        description='Min [Mg/Fe]:',
        disabled=False,
        continuous_update=False
    )

    max_mgfe = widgets.FloatSlider(
        value=0.5,
        min=0.0,
        max=0.7,
        step=0.05,
        description='Max [Mg/Fe]:',
        disabled=False,
        continuous_update=False
    )

    update_button = widgets.Button(
        description='Update Plot',
        disabled=False,
        button_style='',
        tooltip='Click to update the plot'
    )

    output = widgets.Output()

    def update_plot(_button):
        """Resample the selected bin, apply the [Mg/Fe] window, and redraw."""
        bin_name = bin_dropdown.value
        n_samples = n_samples_slider.value
        feh_range = (min_feh.value, max_feh.value)
        mgfe_range = (min_mgfe.value, max_mgfe.value)

        output.clear_output(wait=True)
        with output:
            # Min and max sliders can meet at 0, giving an empty window; say so explicitly.
            if feh_range[0] >= feh_range[1] or mgfe_range[0] >= mgfe_range[1]:
                print("Warning: each minimum must be smaller than its maximum. Adjust the range sliders.")
                return

            # Sample inside the Output context so messages printed by sample_flow_model
            # (e.g. the synthetic-data fallback notice) stay visible: print output from
            # a button callback is otherwise not shown in the notebook.
            _ages, fehs, mgfes, success = sample_flow_model(
                flow_models[bin_name], scalers[bin_name], n_samples=n_samples,
                age_range=(0, 14), feh_range=feh_range
            )

            title_suffix = "(Direct Model Inference)" if success else "(Synthetic Data)"

            # sample_flow_model filters only on age and [Fe/H]; apply the [Mg/Fe] window here.
            in_window = (mgfes >= mgfe_range[0]) & (mgfes <= mgfe_range[1])
            fehs_filtered = fehs[in_window]
            mgfes_filtered = mgfes[in_window]

            if len(fehs_filtered) < 10:
                print(f"Warning: Only {len(fehs_filtered)} samples remain after filtering. Try adjusting your ranges.")
                return

            print(f"Plotting {len(fehs_filtered)} samples for {bin_name} {'using model inference' if success else 'using synthetic data'}")

            plot_mgfe_feh_kde(
                fehs_filtered, mgfes_filtered, bin_name=bin_name,
                feh_range=feh_range, mgfe_range=mgfe_range, title_suffix=title_suffix
            )
            plt.show()

    update_button.on_click(update_plot)

    controls1 = widgets.HBox([bin_dropdown, n_samples_slider])
    controls2 = widgets.HBox([min_feh, max_feh, min_mgfe, max_mgfe])
    controls3 = widgets.HBox([update_button])

    display(widgets.VBox([controls1, controls2, controls3, output]))

    # Draw once with the default settings.
    update_plot(None)

In [None]:
# Show the [Mg/Fe] vs. [Fe/H] explorer only when at least one model is available.
if not flow_models:
    print("No models loaded. Please check the 'outputs/models/' directory.")
else:
    create_mgfe_explorer()

## About This Notebook

This notebook provides a simple interface for exploring age-metallicity and [Mg/Fe]-[Fe/H] patterns in different regions of the Milky Way disc. It samples directly from the trained normalizing flow models and falls back to synthetic data if sampling from a model fails. If a checkpoint cannot be loaded but still contains its scaler, an untrained backup model is substituted for that bin, so plots for that bin do not reflect the trained model.

The normalizing flow models represent complex density distributions in a 5-dimensional space of stellar properties:
- Age (the models work with log10 age; samples are converted to Gyr for plotting)
- [Fe/H] (metallicity)
- [Mg/Fe] (alpha element abundance)
- √Jz (square root of the vertical action, which measures how far a star's orbit oscillates above and below the disc plane)
- Lz (the z-component of angular momentum, which sets the guiding-centre radius of a star's orbit)

By looking at 2D projections of this 5D space, we can understand how different stellar populations are distributed across the Galaxy and gain insights into the formation and evolution of the Milky Way disc.