{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Milky Way Disc Age-Metallicity Explorer (3D Version)\n",
    "\n",
    "This notebook loads trained 3D normalizing flow models and generates visualizations of the age-metallicity relationship across different galactic radial bins. This is a simplified version that focuses only on age, [Fe/H], and [Mg/Fe]."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import torch\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from scipy.stats import gaussian_kde\n",
    "from IPython.display import display, HTML\n",
    "import ipywidgets as widgets\n",
    "import warnings\n",
    "\n",
    "# Ignore sklearn warnings about unpickling from different versions\n",
    "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
    "\n",
    "# Need these imports to reconstruct flow model\n",
    "from nflows.distributions.normal import StandardNormal\n",
    "from nflows.flows.base import Flow\n",
    "from nflows.transforms.base import CompositeTransform\n",
    "from nflows.transforms.permutations import ReversePermutation\n",
    "from nflows.transforms.autoregressive import MaskedPiecewiseRationalQuadraticAutoregressiveTransform\n",
    "from nflows.transforms.coupling import PiecewiseRationalQuadraticCouplingTransform"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Recreate the Flow3D Model Class\n",
    "\n",
    "The class below is a copy of the `Flow3D` class from `flow_model.py`. The architecture must match the training code exactly; otherwise the saved state dict will not load."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_alternating_binary_mask(features, even=True):\n",
    "    \"\"\"\n",
    "    Creates a binary mask of a given dimension which alternates between 0 and 1.\n",
    "    \"\"\"\n",
    "    mask = torch.zeros(features)\n",
    "    start = 0 if even else 1\n",
    "    mask[start::2] = 1\n",
    "    return mask\n",
    "\n",
    "class Flow3D(nn.Module):\n",
    "    \"\"\"\n",
    "    3D normalizing flow for analyzing [age, [Fe/H], [Mg/Fe]] jointly.\n",
    "    Enhanced version with more flexibility.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, n_transforms=16, hidden_dims=None, num_bins=32, tail_bound=5.0):\n",
    "        super().__init__()\n",
    "        if hidden_dims is None:\n",
    "            hidden_dims = [256, 256]\n",
    "\n",
    "        # Base distribution (3D standard normal)\n",
    "        base_dist = StandardNormal(shape=[3])\n",
    "\n",
    "        # Build a sequence of transforms\n",
    "        transforms = []\n",
    "        for i in range(n_transforms):\n",
    "            # Add alternating permutation and coupling transforms\n",
    "            transforms.append(ReversePermutation(features=3))\n",
    "\n",
    "            # Use advanced transform with higher capacity\n",
    "            transforms.append(\n",
    "                MaskedPiecewiseRationalQuadraticAutoregressiveTransform(\n",
    "                    features=3,\n",
    "                    hidden_features=hidden_dims[0],\n",
    "                    context_features=None,\n",
    "                    num_bins=num_bins,\n",
    "                    tails=\"linear\",\n",
    "                    tail_bound=tail_bound,\n",
    "                    num_blocks=4,  # Increased from 2\n",
    "                    use_residual_blocks=True,\n",
    "                    random_mask=False,\n",
    "                    activation=F.relu,\n",
    "                    dropout_probability=0.1,\n",
    "                    use_batch_norm=True,\n",
    "                )\n",
    "            )\n",
    "\n",
    "        # Create the flow model\n",
    "        self.flow = Flow(\n",
    "            transform=CompositeTransform(transforms), distribution=base_dist\n",
    "        )\n",
    "\n",
    "    def log_prob(self, x):\n",
    "        \"\"\"Compute log probability of x\"\"\"\n",
    "        return self.flow.log_prob(x)\n",
    "\n",
    "    def sample(self, n):\n",
    "        \"\"\"Sample n points from the flow\"\"\"\n",
    "        return self.flow.sample(n)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Model Loading Functions\n",
    "\n",
    "Let's define functions to load the trained 3D models and their scalers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_model_from_checkpoint(checkpoint_path, device=None):\n",
    "    \"\"\"Load a flow model from a checkpoint file.\"\"\"\n",
    "    if device is None:\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    \n",
    "    # Load checkpoint\n",
    "    checkpoint = torch.load(checkpoint_path, map_location=device)\n",
    "    \n",
    "    # Get model configuration\n",
    "    model_config = checkpoint.get(\"model_config\", {})\n",
    "    n_transforms = model_config.get(\"n_transforms\", 12)  # Default to 12\n",
    "    hidden_dims = model_config.get(\"hidden_dims\", [128, 128])\n",
    "    num_bins = model_config.get(\"num_bins\", 24)\n",
    "    \n",
    "    print(f\"Creating model with {n_transforms} transforms, hidden_dims={hidden_dims}, num_bins={num_bins}\")\n",
    "    \n",
    "    # Create Flow3D model with correct config\n",
    "    model = Flow3D(\n",
    "        n_transforms=n_transforms,\n",
    "        hidden_dims=hidden_dims,\n",
    "        num_bins=num_bins,\n",
    "        tail_bound=5.0,\n",
    "    ).to(device)\n",
    "    \n",
    "    # Check which model state key is present and load state dict\n",
    "    if \"flow_state\" in checkpoint:\n",
    "        model.load_state_dict(checkpoint[\"flow_state\"])\n",
    "    elif \"model_state\" in checkpoint:\n",
    "        model.load_state_dict(checkpoint[\"model_state\"])\n",
    "    else:\n",
    "        raise ValueError(\"Checkpoint doesn't contain flow_state or model_state\")\n",
    "    \n",
    "    # Set to evaluation mode\n",
    "    model.eval()\n",
    "    \n",
    "    # Extract scaler\n",
    "    scaler = checkpoint.get(\"scaler\", None)\n",
    "    if scaler is None:\n",
    "        raise ValueError(\"Checkpoint doesn't contain scaler\")\n",
    "    \n",
    "    return model, scaler, checkpoint\n",
    "\n",
    "\n",
    "def create_backup_model(n_transforms=8, hidden_dims=None, num_bins=24, device=None):\n",
    "    \"\"\"Create a backup model for inference in case loading fails\"\"\"\n",
    "    if device is None:\n",
    "        device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    if hidden_dims is None:\n",
    "        hidden_dims = [128, 128]\n",
    "    \n",
    "    model = Flow3D(\n",
    "        n_transforms=n_transforms,\n",
    "        hidden_dims=hidden_dims,\n",
    "        num_bins=num_bins,\n",
    "    ).to(device)\n",
    "    model.eval()\n",
    "    \n",
    "    return model"
   ]
  },
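  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Direct Inference Function\n",
    "\n",
    "The function below draws samples from a trained flow, maps them back to physical units with the saved scaler, converts log10(age) to age in Gyr, and keeps only samples inside the requested age and [Fe/H] ranges. If sampling fails, it falls back to the synthetic generator defined in the next section."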
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sample_flow_model(model, scaler, n_samples=5000, age_range=(0, 14), feh_range=(-1.5, 0.5)):\n",
    "    \"\"\"Sample from a flow model and filter results by age and metallicity.\"\"\"\n",
    "    device = next(model.parameters()).device\n",
    "    model.eval()\n",
    "    \n",
    "    try:\n",
    "        # Sample from the flow model\n",
    "        with torch.no_grad():\n",
    "            # We'll sample more points than requested to account for filtering\n",
    "            buffer_factor = 1.5\n",
    "            samples = model.sample(int(n_samples * buffer_factor)).cpu().numpy()\n",
    "        \n",
    "        # Inverse transform the samples to get original scale\n",
    "        samples_original = scaler.inverse_transform(samples)\n",
    "        \n",
    "        # Extract age and metallicity and Mg/Fe\n",
    "        log_ages = samples_original[:, 0]  # First dimension is log(age)\n",
    "        fehs = samples_original[:, 1]      # Second dimension is [Fe/H]\n",
    "        mgfes = samples_original[:, 2]     # Third dimension is [Mg/Fe]\n",
    "        \n",
    "        # Convert log age to linear age\n",
    "        ages = 10**log_ages\n",
    "        \n",
    "        # Filter by age and metallicity ranges\n",
    "        mask = (\n",
    "            (ages >= age_range[0]) & \n",
    "            (ages <= age_range[1]) & \n",
    "            (fehs >= feh_range[0]) & \n",
    "            (fehs <= feh_range[1])\n",
    "        )\n",
    "        \n",
    "        # Extract filtered values\n",
    "        ages_filtered = ages[mask]\n",
    "        fehs_filtered = fehs[mask]\n",
    "        mgfes_filtered = mgfes[mask]\n",
    "        \n",
    "        # If we have too many points after filtering, take a subset\n",
    "        if len(ages_filtered) > n_samples:\n",
    "            indices = np.random.choice(len(ages_filtered), n_samples, replace=False)\n",
    "            ages_filtered = ages_filtered[indices]\n",
    "            fehs_filtered = fehs_filtered[indices]\n",
    "            mgfes_filtered = mgfes_filtered[indices]\n",
    "        \n",
    "        success = True\n",
    "        \n",
    "    except Exception as e:\n",
    "        print(f\"Error sampling from model: {e}\\nUsing synthetic data instead.\")\n",
    "        ages_filtered, fehs_filtered, mgfes_filtered = generate_synthetic_data(\n",
    "            n_samples=n_samples, \n",
    "            age_range=age_range, \n",
    "            feh_range=feh_range\n",
    "        )\n",
    "        success = False\n",
    "    \n",
    "    return ages_filtered, fehs_filtered, mgfes_filtered, success"
   ]
  },
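  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `sample_flow_model` draws a fixed 1.5x buffer and then filters, it can return fewer than `n_samples` points when the requested ranges are narrow. The cell below is a minimal sketch of one way to top up the sample, assuming stand-in `draw` and `keep` callables (both hypothetical, not part of this notebook) in place of the flow and the range filter. It is an illustration only, not the sampling code used elsewhere here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def sample_until_enough(draw, keep, n_samples, batch_factor=1.5, max_rounds=20):\n",
    "    \"\"\"Draw batches and keep rows passing `keep` until n_samples rows are collected.\n",
    "\n",
    "    `draw(n)` returns an (n, 3) array and `keep(rows)` returns a boolean mask.\n",
    "    Both are hypothetical stand-ins. Fewer rows are returned if max_rounds is hit.\n",
    "    \"\"\"\n",
    "    kept = []\n",
    "    n_kept = 0\n",
    "    for _ in range(max_rounds):\n",
    "        # Request only what is still missing, inflated by the buffer factor\n",
    "        n_draw = int(np.ceil((n_samples - n_kept) * batch_factor))\n",
    "        batch = draw(n_draw)\n",
    "        batch = batch[keep(batch)]\n",
    "        kept.append(batch)\n",
    "        n_kept += len(batch)\n",
    "        if n_kept >= n_samples:\n",
    "            break\n",
    "    return np.concatenate(kept, axis=0)[:n_samples]\n",
    "\n",
    "\n",
    "# Stand-in for model.sample + scaler.inverse_transform: columns are\n",
    "# log10(age / Gyr), [Fe/H], [Mg/Fe], drawn from arbitrary Gaussians.\n",
    "rng = np.random.default_rng(0)\n",
    "\n",
    "\n",
    "def draw(n):\n",
    "    return rng.normal([0.8, -0.2, 0.1], [0.3, 0.4, 0.1], size=(n, 3))\n",
    "\n",
    "\n",
    "def keep(rows):\n",
    "    ages = 10 ** rows[:, 0]\n",
    "    return (ages <= 14) & (rows[:, 1] >= -1.5) & (rows[:, 1] <= 0.5)\n",
    "\n",
    "\n",
    "rows = sample_until_enough(draw, keep, n_samples=5000)\n",
    "print(rows.shape)"
   ]
  },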
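  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Backup: Synthetic Data Generator\n",
    "\n",
    "If model sampling fails, the notebook falls back to synthetic data: a two-component Gaussian mixture per radial bin with hand-picked parameters that mimic broad Galactic trends (older, more metal-poor, more alpha-enhanced stars toward the inner disc). These samples are illustrative only and are not fitted to any data."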
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generate_synthetic_data(bin_name=None, n_samples=5000, age_range=(0, 14), feh_range=(-1.5, 0.5)):\n",
    "    \"\"\"Generate synthetic data for a given bin based on astronomical knowledge.\"\"\"\n",
    "    # Set a random seed for reproducibility\n",
    "    if bin_name is not None:\n",
    "        # Use bin name as a seed\n",
    "        seed_val = sum(ord(c) for c in bin_name)\n",
    "    else:\n",
    "        # Use a constant seed if no bin name provided\n",
    "        seed_val = 42\n",
    "    np.random.seed(seed_val)\n",
    "    \n",
    "    # Adjust parameters based on the galactic radial bin\n",
    "    # Inner disc: older stars, lower metallicity spread\n",
    "    # Outer disc: younger stars on average, wider metallicity spread\n",
    "    if bin_name is None or \"0.0-6.0\" in bin_name:  # Inner disc\n",
    "        age_mean, age_std = 10.0, 3.0\n",
    "        feh_mean, feh_std = -0.3, 0.3\n",
    "        mgfe_mean, mgfe_std = 0.25, 0.1\n",
    "        # Add a second population (older, metal-poor)\n",
    "        age_mean2, age_std2 = 12.5, 1.5\n",
    "        feh_mean2, feh_std2 = -0.8, 0.2\n",
    "        mgfe_mean2, mgfe_std2 = 0.35, 0.08\n",
    "        mix_ratio = 0.7  # 70% from first distribution, 30% from second\n",
    "        \n",
    "    elif \"6.0-8.0\" in bin_name:  # Inner-middle disc\n",
    "        age_mean, age_std = 8.0, 3.5\n",
    "        feh_mean, feh_std = -0.2, 0.3\n",
    "        mgfe_mean, mgfe_std = 0.2, 0.12\n",
    "        # Add a second population\n",
    "        age_mean2, age_std2 = 11.0, 2.0\n",
    "        feh_mean2, feh_std2 = -0.6, 0.25\n",
    "        mgfe_mean2, mgfe_std2 = 0.3, 0.1\n",
    "        mix_ratio = 0.75\n",
    "        \n",
    "    elif \"8.0-10.0\" in bin_name:  # Solar neighborhood\n",
    "        age_mean, age_std = 7.0, 4.0\n",
    "        feh_mean, feh_std = -0.1, 0.3\n",
    "        mgfe_mean, mgfe_std = 0.15, 0.15\n",
    "        # Add a second population\n",
    "        age_mean2, age_std2 = 10.5, 2.5\n",
    "        feh_mean2, feh_std2 = -0.5, 0.3\n",
    "        mgfe_mean2, mgfe_std2 = 0.25, 0.12\n",
    "        mix_ratio = 0.8\n",
    "        \n",
    "    else:  # Outer disc\n",
    "        age_mean, age_std = 5.5, 4.0\n",
    "        feh_mean, feh_std = 0.0, 0.25\n",
    "        mgfe_mean, mgfe_std = 0.1, 0.15\n",
    "        # Add a second population\n",
    "        age_mean2, age_std2 = 9.0, 3.0\n",
    "        feh_mean2, feh_std2 = -0.4, 0.3\n",
    "        mgfe_mean2, mgfe_std2 = 0.2, 0.15\n",
    "        mix_ratio = 0.85\n",
    "    \n",
    "    # Generate samples from both populations\n",
    "    n1 = int(n_samples * mix_ratio)\n",
    "    n2 = n_samples - n1\n",
    "    \n",
    "    # First population - generate log ages\n",
    "    log_ages1 = np.log10(np.random.normal(age_mean, age_std, n1))\n",
    "    fehs1 = np.random.normal(feh_mean, feh_std, n1)\n",
    "    mgfes1 = np.random.normal(mgfe_mean, mgfe_std, n1)\n",
    "    \n",
    "    # Second population - generate log ages\n",
    "    log_ages2 = np.log10(np.random.normal(age_mean2, age_std2, n2))\n",
    "    fehs2 = np.random.normal(feh_mean2, feh_std2, n2)\n",
    "    mgfes2 = np.random.normal(mgfe_mean2, mgfe_std2, n2)\n",
    "    \n",
    "    # Combine populations\n",
    "    log_ages = np.concatenate([log_ages1, log_ages2])\n",
    "    fehs = np.concatenate([fehs1, fehs2])\n",
    "    mgfes = np.concatenate([mgfes1, mgfes2])\n",
    "    \n",
    "    # Ensure log_ages are in reasonable range (0.1 to 14 Gyr in log space)\n",
    "    log_ages = np.clip(log_ages, np.log10(0.1), np.log10(14))\n",
    "    \n",
    "    # Convert log ages to linear ages\n",
    "    ages = 10**log_ages\n",
    "    \n",
    "    # Add age-metallicity correlation (older stars tend to be more metal-poor)\n",
    "    corr_strength = 0.5\n",
    "    age_norm = (ages - np.min(ages)) / (np.max(ages) - np.min(ages))\n",
    "    corr_factor = corr_strength * (1 - age_norm)\n",
    "    fehs += corr_factor * 0.5  # Scale the correlation effect\n",
    "    \n",
    "    # Add age-alpha correlation (older stars tend to be alpha-enhanced)\n",
    "    alpha_corr = 0.4\n",
    "    alpha_factor = alpha_corr * age_norm\n",
    "    mgfes += alpha_factor * 0.3\n",
    "    \n",
    "    # Add metallicity-alpha correlation (metal-poor stars tend to be alpha-enhanced)\n",
    "    feh_norm = (fehs - np.min(fehs)) / (np.max(fehs) - np.min(fehs))\n",
    "    mgfes -= 0.3 * feh_norm\n",
    "    \n",
    "    # Add some scatter and bimodality in the [Mg/Fe] distribution\n",
    "    # This simulates the thin/thick disc separation\n",
    "    bimodal_mask = np.random.choice([True, False], size=n_samples, p=[0.3, 0.7])\n",
    "    mgfes[bimodal_mask] += 0.2\n",
    "    \n",
    "    # Clip to reasonable ranges\n",
    "    fehs = np.clip(fehs, -1.5, 0.5)\n",
    "    mgfes = np.clip(mgfes, -0.2, 0.5)\n",
    "    \n",
    "    # Filter by age and metallicity ranges\n",
    "    mask = (\n",
    "        (ages >= age_range[0]) & \n",
    "        (ages <= age_range[1]) & \n",
    "        (fehs >= feh_range[0]) & \n",
    "        (fehs <= feh_range[1])\n",
    "    )\n",
    "    \n",
    "    ages_filtered = ages[mask]\n",
    "    fehs_filtered = fehs[mask]\n",
    "    mgfes_filtered = mgfes[mask]\n",
    "    \n",
    "    return ages_filtered, fehs_filtered, mgfes_filtered"
   ]
  },
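  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`generate_synthetic_data` seeds NumPy's global random state with `np.random.seed`, which also affects any later `np.random` call in the notebook, and its character-sum seed is identical for bin names that are permutations of each other. The cell below is a hedged sketch of an alternative: a local `Generator` seeded from a CRC32 hash of the bin name. The helper `rng_for_bin` is hypothetical and is not used elsewhere in this notebook."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import zlib\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "\n",
    "def rng_for_bin(bin_name, default_seed=42):\n",
    "    \"\"\"Return a local Generator seeded from the bin name (hypothetical helper).\"\"\"\n",
    "    if bin_name is None:\n",
    "        return np.random.default_rng(default_seed)\n",
    "    # CRC32 gives a stable, order-sensitive integer for the name\n",
    "    return np.random.default_rng(zlib.crc32(bin_name.encode(\"utf-8\")))\n",
    "\n",
    "\n",
    "# The same bin name reproduces the same draws; a different name does not\n",
    "a = rng_for_bin(\"R0.0-6.0\").normal(10.0, 3.0, 5)\n",
    "b = rng_for_bin(\"R0.0-6.0\").normal(10.0, 3.0, 5)\n",
    "c = rng_for_bin(\"R6.0-8.0\").normal(10.0, 3.0, 5)\n",
    "print(np.array_equal(a, b), np.array_equal(a, c))"
   ]
  },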
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load Models\n",
    "\n",
    "Now let's load the trained 3D models from the outputs directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_all_models(models_dir=\"outputs/models\", model_suffix=\"_3d_model.pt\"):\n",
    "    \"\"\"Load all trained 3D flow models from a directory.\"\"\"\n",
    "    device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "    print(f\"Using device: {device}\")\n",
    "    \n",
    "    flow_models = {}\n",
    "    scalers = {}\n",
    "    \n",
    "    # Define radial bin order\n",
    "    radial_bin_order = [\"R0.0-6.0\", \"R6.0-8.0\", \"R8.0-10.0\", \"R10.0-15.0\"]\n",
    "    radial_bins_set = set(radial_bin_order)\n",
    "    \n",
    "    # Find all model files\n",
    "    model_files = {}\n",
    "    for filename in os.listdir(models_dir):\n",
    "        if filename.endswith(model_suffix):\n",
    "            bin_name = filename.split(model_suffix)[0]\n",
    "            if bin_name in radial_bins_set:\n",
    "                model_path = os.path.join(models_dir, filename)\n",
    "                model_files[bin_name] = model_path\n",
    "    \n",
    "    # Load models\n",
    "    for bin_name in radial_bin_order:\n",
    "        if bin_name in model_files:\n",
    "            model_path = model_files[bin_name]\n",
    "            print(f\"\\nLoading model for {bin_name} from {model_path}\")\n",
    "            \n",
    "            try:\n",
    "                # Attempt to load the model with correct parameters\n",
    "                model, scaler, _ = load_model_from_checkpoint(model_path, device)\n",
    "                flow_models[bin_name] = model\n",
    "                scalers[bin_name] = scaler\n",
    "                print(f\"Successfully loaded model for {bin_name}\")\n",
    "                \n",
    "                # Test sampling from the model\n",
    "                print(f\"Testing model sampling...\")\n",
    "                with torch.no_grad():\n",
    "                    samples = model.sample(10).cpu().numpy()\n",
    "                print(f\"✓ Sampling successful - got {samples.shape} samples\")\n",
    "                \n",
    "            except Exception as e:\n",
    "                print(f\"Error loading model: {e}\")\n",
    "                print(\"Creating a backup model for this bin\")\n",
    "                \n",
    "                # Create a backup model for this bin\n",
    "                backup_model = create_backup_model(device=device)\n",
    "                \n",
    "                # Load just the checkpoint to get the scaler\n",
    "                try:\n",
    "                    checkpoint = torch.load(model_path, map_location=device)\n",
    "                    if \"scaler\" in checkpoint:\n",
    "                        flow_models[bin_name] = backup_model\n",
    "                        scalers[bin_name] = checkpoint[\"scaler\"]\n",
    "                        print(f\"Loaded backup model and scaler for {bin_name}\")\n",
    "                    else:\n",
    "                        print(f\"Failed to load backup - no scaler in checkpoint\")\n",
    "                except Exception as e:\n",
    "                    print(f\"Error loading checkpoint: {e}\")\n",
    "                continue\n",
    "    \n",
    "    # If we still have no models, create a dummy for R0.0-6.0\n",
    "    if not flow_models:\n",
    "        bin_name = \"R0.0-6.0\"  # Default bin\n",
    "        print(f\"\\nNo models loaded. Creating a dummy model for {bin_name}\")\n",
    "        flow_models[bin_name] = create_backup_model(device=device)\n",
    "        scalers[bin_name] = None  # No scaler available\n",
    "    \n",
    "    print(f\"\\nLoaded {len(flow_models)} models: {list(flow_models.keys())}\")\n",
    "    return flow_models, scalers\n",
    "\n",
    "# Load the models - specify the suffix to look for 3D models specifically\n",
    "flow_models, scalers = load_all_models(model_suffix=\"_3d_model.pt\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Create Visualization Functions\n",
    "\n",
    "Now let's define functions to visualize the data."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def plot_age_metallicity_kde(ages, metallicities, bin_name=None, flip_age_axis=True, \n",
    "                             age_range=(0, 14), feh_range=(-1.5, 0.5), figsize=(10, 8),\n",
    "                             title_suffix=\"\"):\n",
    "    \"\"\"Create a KDE-based visualization of Age vs. [Fe/H].\"\"\"\n",
    "    # Create figure\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    \n",
    "    # Calculate KDE\n",
    "    xy = np.vstack([ages, metallicities])\n",
    "    kde = gaussian_kde(xy)\n",
    "    \n",
    "    # Create grid for KDE evaluation\n",
    "    x_grid = np.linspace(age_range[0], age_range[1], 100)\n",
    "    y_grid = np.linspace(feh_range[0], feh_range[1], 100)\n",
    "    xx, yy = np.meshgrid(x_grid, y_grid)\n",
    "    \n",
    "    # Evaluate KDE on grid\n",
    "    zz = kde(np.vstack([xx.ravel(), yy.ravel()]))\n",
    "    zz = zz.reshape(xx.shape)\n",
    "    \n",
    "    # Plot KDE as contours with filled colors\n",
    "    contour = ax.contourf(xx, yy, zz, levels=20, cmap=\"viridis\", alpha=0.8)\n",
    "    \n",
    "    # Add colorbar\n",
    "    cbar = plt.colorbar(contour, ax=ax)\n",
    "    cbar.set_label(\"Density\")\n",
    "    \n",
    "    # Add scatter points with very small size for detail\n",
    "    ax.scatter(ages, metallicities, s=0.5, color=\"k\", alpha=0.1)\n",
    "    \n",
    "    # Set labels and title\n",
    "    ax.set_xlabel(\"Age (Gyr)\")\n",
    "    ax.set_ylabel(\"[Fe/H]\")\n",
    "    if bin_name:\n",
    "        ax.set_title(f\"Age-Metallicity Relation - {bin_name} {title_suffix}\")\n",
    "    else:\n",
    "        ax.set_title(f\"Age-Metallicity Relation {title_suffix}\")\n",
    "    \n",
    "    # Set axis ranges\n",
    "    ax.set_xlim(age_range)\n",
    "    ax.set_ylim(feh_range)\n",
    "    if flip_age_axis:\n",
    "        ax.invert_xaxis()  # Flip x-axis to show oldest at left\n",
    "    \n",
    "    # Add grid\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.5)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    return fig, ax\n",
    "\n",
    "\n",
    "def plot_mgfe_feh_kde(fehs, mgfes, bin_name=None, feh_range=(-1.5, 0.5), \n",
    "                      mgfe_range=(-0.2, 0.5), figsize=(10, 8), title_suffix=\"\"):\n",
    "    \"\"\"Create a KDE-based visualization of [Mg/Fe] vs. [Fe/H].\"\"\"\n",
    "    # Create figure\n",
    "    fig, ax = plt.subplots(figsize=figsize)\n",
    "    \n",
    "    # Calculate KDE\n",
    "    xy = np.vstack([fehs, mgfes])\n",
    "    kde = gaussian_kde(xy)\n",
    "    \n",
    "    # Create grid for KDE evaluation\n",
    "    x_grid = np.linspace(feh_range[0], feh_range[1], 100)\n",
    "    y_grid = np.linspace(mgfe_range[0], mgfe_range[1], 100)\n",
    "    xx, yy = np.meshgrid(x_grid, y_grid)\n",
    "    \n",
    "    # Evaluate KDE on grid\n",
    "    zz = kde(np.vstack([xx.ravel(), yy.ravel()]))\n",
    "    zz = zz.reshape(xx.shape)\n",
    "    \n",
    "    # Plot KDE as contours with filled colors\n",
    "    contour = ax.contourf(xx, yy, zz, levels=20, cmap=\"plasma\", alpha=0.8)\n",
    "    \n",
    "    # Add colorbar\n",
    "    cbar = plt.colorbar(contour, ax=ax)\n",
    "    cbar.set_label(\"Density\")\n",
    "    \n",
    "    # Add scatter points with very small size for detail\n",
    "    ax.scatter(fehs, mgfes, s=0.5, color=\"k\", alpha=0.1)\n",
    "    \n",
    "    # Set labels and title\n",
    "    ax.set_xlabel(\"[Fe/H]\")\n",
    "    ax.set_ylabel(\"[Mg/Fe]\")\n",
    "    if bin_name:\n",
    "        ax.set_title(f\"[Mg/Fe] vs. [Fe/H] Relation - {bin_name} {title_suffix}\")\n",
    "    else:\n",
    "        ax.set_title(f\"[Mg/Fe] vs. [Fe/H] Relation {title_suffix}\")\n",
    "    \n",
    "    # Set axis ranges\n",
    "    ax.set_xlim(feh_range)\n",
    "    ax.set_ylim(mgfe_range)\n",
    "    \n",
    "    # Add grid\n",
    "    ax.grid(True, linestyle=\"--\", alpha=0.5)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    return fig, ax"
   ]
  },
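  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Both plotting functions evaluate a `gaussian_kde` on a 100 x 100 grid that spans only the plotted window. The cell below is a small self-contained sketch, using made-up Gaussian stand-in data rather than model samples, that evaluates a KDE the same way and estimates how much probability mass falls inside the window. A value well below 1 would suggest that the window clips a noticeable part of the distribution."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from scipy.stats import gaussian_kde\n",
    "\n",
    "# Stand-in data (assumed values, for illustration only)\n",
    "rng = np.random.default_rng(1)\n",
    "x = rng.normal(7.0, 1.0, 2000)   # ages in Gyr\n",
    "y = rng.normal(-0.2, 0.1, 2000)  # [Fe/H]\n",
    "\n",
    "# gaussian_kde expects an array of shape (n_dims, n_points)\n",
    "kde = gaussian_kde(np.vstack([x, y]))\n",
    "\n",
    "# Same grid construction as in the plotting functions\n",
    "x_grid = np.linspace(0, 14, 100)\n",
    "y_grid = np.linspace(-1.5, 0.5, 100)\n",
    "xx, yy = np.meshgrid(x_grid, y_grid)\n",
    "zz = kde(np.vstack([xx.ravel(), yy.ravel()])).reshape(xx.shape)\n",
    "\n",
    "# Riemann-sum estimate of the probability mass inside the window\n",
    "dx = x_grid[1] - x_grid[0]\n",
    "dy = y_grid[1] - y_grid[0]\n",
    "mass = zz.sum() * dx * dy\n",
    "print(f\"Probability mass inside the plotted window: {mass:.3f}\")"
   ]
  },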
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Test: Sample from a Model\n",
    "\n",
    "Let's test sampling from one of the models and visualize the results."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Test model sampling\n",
    "if flow_models:\n",
    "    bin_name = list(flow_models.keys())[0]  # Get first available bin\n",
    "    model = flow_models[bin_name]\n",
    "    scaler = scalers[bin_name]\n",
    "    \n",
    "    print(f\"Testing sampling from {bin_name} model...\")\n",
    "    ages, fehs, mgfes, success = sample_flow_model(\n",
    "        model, scaler, n_samples=5000, \n",
    "        age_range=(0, 14), feh_range=(-1.5, 0.5)\n",
    "    )\n",
    "    \n",
    "    if success:\n",
    "        title_suffix = \"(Direct Model Inference)\"\n",
    "    else:\n",
    "        title_suffix = \"(Synthetic Data)\"\n",
    "        \n",
    "    fig, ax = plot_age_metallicity_kde(\n",
    "        ages, fehs, bin_name=bin_name, \n",
    "        age_range=(0, 14), feh_range=(-1.5, 0.5),\n",
    "        title_suffix=title_suffix\n",
    "    )\n",
    "    plt.show()\n",
    "    \n",
    "    fig, ax = plot_mgfe_feh_kde(\n",
    "        fehs, mgfes, bin_name=bin_name,\n",
    "        title_suffix=title_suffix\n",
    "    )\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No models available for testing.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Interactive Visualization Interface\n",
    "\n",
    "Create an interactive widget to explore the models and visualize the age-metallicity relationship."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def create_interactive_explorer():\n",
    "    \"\"\"Create an interactive interface to explore and visualize models.\"\"\"\n",
    "    if not flow_models:\n",
    "        print(\"No models available for visualization.\")\n",
    "        return\n",
    "    \n",
    "    # Create widget elements\n",
    "    bin_dropdown = widgets.Dropdown(\n",
    "        options=list(flow_models.keys()),\n",
    "        description='Radial Bin:',\n",
    "        disabled=False,\n",
    "    )\n",
    "    \n",
    "    mode_toggle = widgets.ToggleButtons(\n",
    "        options=['Age-[Fe/H]', '[Mg/Fe]-[Fe/H]'],\n",
    "        description='Plot Mode:',\n",
    "        disabled=False\n",
    "    )\n",
    "    \n",
    "    n_samples_slider = widgets.IntSlider(\n",
    "        value=5000,\n",
    "        min=1000,\n",
    "        max=20000,\n",
    "        step=1000,\n",
    "        description='Samples:',\n",
    "        disabled=False,\n",
    "        continuous_update=False,\n",
    "        readout=True,\n",
    "        readout_format='d'\n",
    "    )\n",
    "    \n",
    "    min_age = widgets.FloatSlider(\n",
    "        value=0,\n",
    "        min=0,\n",
    "        max=15,\n",
    "        step=0.5,\n",
    "        description='Min Age:',\n",
    "        disabled=False,\n",
    "        continuous_update=False\n",
    "    )\n",
    "    \n",
    "    max_age = widgets.FloatSlider(\n",
    "        value=14,\n",
    "        min=5,\n",
    "        max=20,\n",
    "        step=0.5,\n",
    "        description='Max Age:',\n",
    "        disabled=False,\n",
    "        continuous_update=False\n",
    "    )\n",
    "    \n",
    "    min_feh = widgets.FloatSlider(\n",
    "        value=-1.5,\n",
    "        min=-2.0,\n",
    "        max=0,\n",
    "        step=0.1,\n",
    "        description='Min [Fe/H]:',\n",
    "        disabled=False,\n",
    "        continuous_update=False\n",
    "    )\n",
    "    \n",
    "    max_feh = widgets.FloatSlider(\n",
    "        value=0.5,\n",
    "        min=0,\n",
    "        max=1.0,\n",
    "        step=0.1,\n",
    "        description='Max [Fe/H]:',\n",
    "        disabled=False,\n",
    "        continuous_update=False\n",
    "    )\n",
    "    \n",
    "    min_mgfe = widgets.FloatSlider(\n",
    "        value=-0.2,\n",
    "        min=-0.5,\n",
    "        max=0.0,\n",
    "        step=0.05,\n",
    "        description='Min [Mg/Fe]:',\n",
    "        disabled=False,\n",
    "        continuous_update=False\n",
    "    )\n",
    "    \n",
    "    max_mgfe = widgets.FloatSlider(\n",
    "        value=0.5,\n",
    "        min=0.0,\n",
    "        max=0.7,\n",
    "        step=0.05,\n",
    "        description='Max [Mg/Fe]:',\n",
    "        disabled=False,\n",
    "        continuous_update=False\n",
    "    )\n",
    "    \n",
    "    flip_age = widgets.Checkbox(\n",
    "        value=True,\n",
    "        description='Flip Age Axis (Oldest Left)',\n",
    "        disabled=False\n",
    "    )\n",
    "    \n",
    "    update_button = widgets.Button(\n",
    "        description='Update Plot',\n",
    "        disabled=False,\n",
    "        button_style='', \n",
    "        tooltip='Click to update the plot'\n",
    "    )\n",
    "    \n",
    "    output = widgets.Output()\n",
    "    \n",
    "    # Create the update function\n",
    "    def update_plot(b):\n",
    "        bin_name = bin_dropdown.value\n",
    "        n_samples = n_samples_slider.value\n",
    "        age_range = (min_age.value, max_age.value)\n",
    "        feh_range = (min_feh.value, max_feh.value)\n",
    "        mgfe_range = (min_mgfe.value, max_mgfe.value)\n",
    "        flip_age_axis = flip_age.value\n",
    "        plot_mode = mode_toggle.value\n",
    "        \n",
    "        model = flow_models[bin_name]\n",
    "        scaler = scalers[bin_name]\n",
    "        \n",
    "        # Try to sample from the model\n",
    "        ages, fehs, mgfes, success = sample_flow_model(\n",
    "            model, scaler, n_samples=n_samples, \n",
    "            age_range=age_range, feh_range=feh_range\n",
    "        )\n",
    "        \n",
    "        title_suffix = \"(Direct Model Inference)\" if success else \"(Synthetic Data)\"\n",
    "        \n",
    "        output.clear_output(wait=True)\n",
    "        with output:\n",
    "            if len(ages) < 10:\n",
    "                print(f\"Warning: Not enough samples to plot. Try adjusting your ranges.\")\n",
    "                return\n",
    "                \n",
    "            print(f\"Plotting {len(ages)} samples for {bin_name} {'using model inference' if success else 'using synthetic data'}\")\n",
    "            \n",
    "            if plot_mode == 'Age-[Fe/H]':\n",
    "                # Plot Age vs [Fe/H]\n",
    "                fig, ax = plot_age_metallicity_kde(\n",
    "                    ages, fehs, bin_name=bin_name, flip_age_axis=flip_age_axis, \n",
    "                    age_range=age_range, feh_range=feh_range, title_suffix=title_suffix\n",
    "                )\n",
    "            else:\n",
    "                # Additional filtering for [Mg/Fe]\n",
    "                mask = (mgfes >= mgfe_range[0]) & (mgfes <= mgfe_range[1])\n",
    "                fehs_filtered = fehs[mask]\n",
    "                mgfes_filtered = mgfes[mask]\n",
    "                \n",
    "                # Plot [Mg/Fe] vs [Fe/H]\n",
    "                if len(fehs_filtered) < 10:\n",
    "                    print(f\"Warning: Not enough samples remain after [Mg/Fe] filtering. Try adjusting your ranges.\")\n",
    "                    return\n",
    "                    \n",
    "                fig, ax = plot_mgfe_feh_kde(\n",
    "                    fehs_filtered, mgfes_filtered, bin_name=bin_name, \n",
    "                    feh_range=feh_range, mgfe_range=mgfe_range, title_suffix=title_suffix\n",
    "                )\n",
    "                \n",
    "            plt.show()\n",
    "    \n",
    "    update_button.on_click(update_plot)\n",
    "    \n",
    "    # Layout\n",
    "    controls1 = widgets.HBox([bin_dropdown, mode_toggle, n_samples_slider])\n",
    "    controls2 = widgets.HBox([min_age, max_age, min_feh, max_feh])\n",
    "    controls3 = widgets.HBox([min_mgfe, max_mgfe, flip_age, update_button])\n",
    "    \n",
    "    # Display widgets\n",
    "    display(widgets.VBox([controls1, controls2, controls3, output]))\n",
    "    \n",
    "    # Initial plot\n",
    "    update_plot(None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create the interactive explorer\n",
    "if flow_models:\n",
    "    create_interactive_explorer()\n",
    "else:\n",
    "    print(\"No models loaded. Please first run run_training_3d.py to generate 3D models.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Compare All Radial Bins Side-by-Side\n",
    "\n",
    "This function allows you to compare the age-metallicity relationship across all loaded radial bins."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def compare_all_bins(n_samples=2000, age_range=(0, 14), feh_range=(-1.5, 0.5), flip_age_axis=True):\n",
    "    \"\"\"Compare age-metallicity relations across all radial bins.\"\"\"\n",
    "    if not flow_models:\n",
    "        print(\"No models available for comparison.\")\n",
    "        return\n",
    "    \n",
    "    # Get all available bins\n",
    "    available_bins = list(flow_models.keys())\n",
    "    if not available_bins:\n",
    "        print(\"No valid models available for comparison.\")\n",
    "        return\n",
    "    \n",
    "    # Set up figure\n",
    "    n_bins = len(available_bins)\n",
    "    fig, axes = plt.subplots(1, n_bins, figsize=(5 * n_bins, 5), sharex=True, sharey=True)\n",
    "    \n",
    "    # Handle the case of a single bin\n",
    "    if n_bins == 1:\n",
    "        axes = [axes]\n",
    "    \n",
    "    # Track if any direct model inference succeeded\n",
    "    any_direct_success = False\n",
    "    \n",
    "    # Plot each bin\n",
    "    for i, bin_name in enumerate(available_bins):\n",
    "        model = flow_models[bin_name]\n",
    "        scaler = scalers[bin_name]\n",
    "        \n",
    "        # Sample from the model\n",
    "        ages, fehs, _, success = sample_flow_model(\n",
    "            model, scaler, n_samples=n_samples,\n",
    "            age_range=age_range, feh_range=feh_range\n",
    "        )\n",
    "        \n",
    "        if success:\n",
    "            any_direct_success = True\n",
    "        \n",
    "        if len(ages) < 10:\n",
    "            print(f\"Warning: Not enough samples for bin {bin_name}. Skipping.\")\n",
    "            continue\n",
    "        \n",
    "        # Calculate KDE\n",
    "        xy = np.vstack([ages, fehs])\n",
    "        kde = gaussian_kde(xy)\n",
    "        \n",
    "        # Create grid for KDE evaluation\n",
    "        x_grid = np.linspace(age_range[0], age_range[1], 100)\n",
    "        y_grid = np.linspace(feh_range[0], feh_range[1], 100)\n",
    "        xx, yy = np.meshgrid(x_grid, y_grid)\n",
    "        \n",
    "        # Evaluate KDE on grid\n",
    "        zz = kde(np.vstack([xx.ravel(), yy.ravel()]))\n",
    "        zz = zz.reshape(xx.shape)\n",
    "        \n",
    "        # Plot KDE as contours with filled colors\n",
    "        contour = axes[i].contourf(xx, yy, zz, levels=20, cmap=\"viridis\", alpha=0.8)\n",
    "        \n",
    "        # Add scatter points with small size\n",
    "        axes[i].scatter(ages, fehs, s=0.3, color=\"k\", alpha=0.1)\n",
    "        \n",
    "        # Set title and labels\n",
    "        axes[i].set_title(f\"Bin: {bin_name}\")\n",
    "        axes[i].set_xlabel(\"Age (Gyr)\")\n",
    "        if i == 0:\n",
    "            axes[i].set_ylabel(\"[Fe/H]\")\n",
    "        \n",
    "        # Set axis ranges\n",
    "        axes[i].set_xlim(age_range)\n",
    "        axes[i].set_ylim(feh_range)\n",
    "        if flip_age_axis:\n",
    "            axes[i].invert_xaxis()  # Flip x-axis\n",
    "        \n",
    "        # Add grid\n",
    "        axes[i].grid(True, linestyle=\"--\", alpha=0.5)\n",
    "    \n",
    "    # Add colorbar for the last plot\n",
    "    cbar = fig.colorbar(contour, ax=axes[-1])\n",
    "    cbar.set_label(\"Density\")\n",
    "    \n",
    "    title = \"Age-Metallicity Relation Across Radial Bins (3D Models)\"\n",
    "    if any_direct_success:\n",
    "        title += \" (Model Inference)\"\n",
    "    else:\n",
    "        title += \" (Synthetic Data)\"\n",
    "    plt.suptitle(title, fontsize=16)\n",
    "    \n",
    "    plt.tight_layout()\n",
    "    plt.subplots_adjust(top=0.88)  # Make room for the suptitle\n",
    "    \n",
    "    return fig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compare all radial bins\n",
    "if flow_models:\n",
    "    fig = compare_all_bins()\n",
    "    plt.show()\n",
    "else:\n",
    "    print(\"No models loaded.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## About This Notebook\n",
    "\n",
    "This notebook provides a simplified interface for exploring age-metallicity and [Mg/Fe]-[Fe/H] patterns in different regions of the Milky Way disc. It uses 3D normalizing flow models that focus only on these three key chemical evolution parameters.\n",
    "\n",
    "The normalizing flow models represent complex density distributions in the 3-dimensional space of:\n",
    "- Age\n",
    "- [Fe/H] (metallicity)\n",
    "- [Mg/Fe] (alpha element abundance)\n",
    "\n",
    "By examining these parameters, we can understand how different stellar populations are distributed across the Galaxy and gain insights into the formation and evolution of the Milky Way disc.\n",
    "\n",
    "### How to Use\n",
    "\n",
    "1. First, run the `run_training_3d.py` script to train the 3D models for each radial bin\n",
    "2. Then return to this notebook to load and visualize the trained models\n",
    "3. Use the interactive widgets to explore different parameters and visualization options"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}