# Practical Machine Learning for X-ray Astronomy: Building a Neural-Network Surrogate for a Power-Law Spectrum

This tutorial is a hands-on introduction to practical machine learning for X-ray astronomers. 

**Note**: it should run entirely or almost entirely on **Google Colab**, so if for any reason you can't install the libraries, upload it to Colab and try running it there!

As a note in advance: none of the results you'll get out of this are science-worthy. This tutorial is meant to give you a first idea for how to set up your own machine learning model. But the first, and most important lesson, is this: **don't blindly trust your ML results.** 

As with any other science project, reporting or using results from a machine learning classifier or regressor requires careful understanding of the biases and caveats, assumptions and limitations that come with the data and algorithms chosen. When building surrogate models especially, you can expect to find all the funny weird corner cases (both subtle and not) of the model you're trying to emulate. In a real-world setting, you need to really understand the limitations of the model you're using before drawing any scientific conclusions from your machine learning.

With that out of the way, let's have some fun with machine learning! In this tutorial, we will use python, a library called `scikit-learn` and `pytorch` to do our machine learning, `pandas` to deal with data structures, and `matplotlib` and `seaborn` to do our plotting. 

You will build a **neural-network surrogate model** for a simple X-ray spectral model: a **power law**. This is obviously a silly endeavour: a power law is a simple model you can just write down analytically. But here, it serves as a simple (fast) analogue to much more complex numerical models you can't write down analytically, and which might be expensive to compute. Our goal is to build a neural network replacement for this model, which accepts sets of parameters and produces the X-ray spectrum of the given physical model, ideally orders of magnitude faster than the original physics model. The basic idea is this: if we need to evaluate this model *a lot* (hundreds of thousands or millions of times, e.g. for inferring posterior distributions of parameters via Markov Chain Monte Carlo sampling), then evaluating it a smaller number of times up front in order to generate a training data set for a surrogate model might be much more computationally efficient than using the model directly for inference. 

The goal is not to build the best model ever, but to learn the workflow and the *failure modes* of ML in scientific settings.

We will keep the focus on:
- train/validation/test splits
- overfitting and generalization
- scaling/normalization
- cross-validation and hyperparameter selection
- robust evaluation

We will start deliberately too small:
1. Train a single-neuron network on one training example  
2. Show it fails on validation data  
3. Scale up the training set and the model capacity in a controlled way



### Imports

There are basically two major packages for building and working with neural networks: [`tensorflow`](https://www.tensorflow.org) and [`pytorch`](https://pytorch.org). They are somewhat different in syntax, but both support all of the major things you're likely going to want to do, so it doesn't matter much which one you pick. Here, we're going to use `pytorch`. 

`scikit-learn` is a more general machine learning package that has a wide range of algorithms and infrastructure for preprocessing data, exploring the performance of your machine learning models and so on. So this one's useful to have in your arsenal. 

If you're exploring different models (e.g. different architectures, activation functions, hyperparameters, ...), [`weights and biases`](https://wandb.ai) can be a useful tool.

In [None]:
# UNCOMMENT THIS CELL IF YOU RUN ON COLAB
#!pip install emcee
#!pip install corner

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from scipy.special import gamma as scipy_gamma
from scipy.special import gammaln as scipy_gammaln

from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from scipy.stats import qmc

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

import emcee
import corner

# Reproducibility (still not perfect on GPU, but good enough for this tutorial)
rng = np.random.default_rng(123)
torch.manual_seed(123)

print("numpy:", np.__version__)
print("torch:", torch.__version__)

## A useful plotting function

We'll be predicting an entire spectrum (counts as a function of energy).  
The helper below makes quick “data vs prediction” plots.

In [None]:
def plot_spectrum(energy, y_true, y_pred=None, title=None, ax=None):
    """Plot a spectrum in counts vs energy.

    Parameters
    ----------
    energy : array, shape (nE,)
        Energy grid in keV.
    y_true : array, shape (nE,)
        True target spectrum (counts).
    y_pred : array, shape (nE,), optional
        Predicted spectrum (counts).
    title : str, optional
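        Plot title.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on. If None, a new figure is created.

    Returns
    -------
    ax : matplotlib.axes.Axes
        The axes containing the plot.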
    """

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(7,4))

    ax.plot(energy, y_true, marker='o', lw=1, label="target")
    if y_pred is not None:
        ax.plot(energy, y_pred, marker='.', markersize=3, lw=1, label="prediction")
    ax.set_yscale("log")
    ax.set_xlabel("Energy (keV)")
    ax.set_ylabel("Counts (arb.)")
    if title:
        ax.set_title(title)
    ax.legend()

    return ax

## The scenario

In X-ray spectral fitting we often evaluate a forward model (i.e. a physics model, usually modified to account for instrumental distortions) many times:
- to explore a likelihood surface
- to infer posterior distributions of model parameters via sampling
- population inference with many spectra

If each forward-model evaluation is expensive (response folding, complex physical models, etc.), a **(neural network) surrogate** can accelerate inference. You train a neural network to approximate the forward model and then use it as a fast emulator.

Here we'll emulate a simple model:

$$
N(E) = A\,(E/E_0)^{-\Gamma}
$$

As mentioned above, this is a bit of a silly exercise: a power law is an analytic function that's really fast to evaluate. However, for more expensive numerical models, this might be a way to significantly speed up evaluation time. We're using the simple model here to give you something reasonable to work with as part of the tutorial, and so that you don't have to spend ages waiting for your training data to generate.
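
As a quick sanity check of the formula above, the sketch below evaluates the power law on a small grid and confirms that it is a straight line in log-log space, since $\log N(E) = \log A - \Gamma \log(E/E_0)$. The helper name `powerlaw` and the parameter values are illustrative choices made for this check only.

```python
import numpy as np

def powerlaw(energy, amp, gamma, e0=1.0):
    """Evaluate N(E) = A * (E / E0)**(-Gamma) (illustrative helper)."""
    return amp * (energy / e0) ** (-gamma)

# illustrative values, chosen only for this check
energy_demo = np.logspace(np.log10(0.3), np.log10(10.0), 20)
amp_demo, gamma_demo = 1e-2, 1.7

flux_demo = powerlaw(energy_demo, amp_demo, gamma_demo)

# fit a straight line to log(flux) vs log(energy)
slope, intercept = np.polyfit(np.log(energy_demo), np.log(flux_demo), 1)

print("slope (should be close to -Gamma):", slope)
print("exp(intercept) (should be close to A):", np.exp(intercept))
```

Because $E_0 = 1$ keV here, the intercept of the fitted line is simply $\log A$.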

We're also not going to worry about instrument responses here. The idea of the surrogate model is that you use it as a **drop-in replacement** for the physical model, and then apply the instrument responses to the output as you would for your physical model during inference. You *could* also learn the model convolved with the response (and that may be efficient when the response calculations are suddenly the most expensive part of the calculation, see also: XRISM), but be aware that this would tie your trained surrogate model very strongly to a single instrument, and you'd have to re-train for any new instrument you want to use it for.


### Generate synthetic spectra

In order to train a neural network to emulate the physical model above, we need to pre-generate a **training data set**. This is a data set that contains matched pairs of **parameters** ($A$ and $\Gamma$) and the corresponding model output fluxes at a set of photon energies. How big that data set should be is hard to estimate in advance: it depends on the complexity of the model and the number of parameters. 

We choose:
- energy grid: 0.3–10 keV
- parameters: amplitude $A$ and photon index $\Gamma$

We will *learn* the mapping:

$$
(A, \Gamma) \rightarrow \{\mathrm{counts}(E_i)\}_{i=1}^{n_E}
$$

This is a **multi-output regression** problem. Here, we're going to simulate data and build a surrogate model for a **fixed energy grid**. However, neural network architectures that are independent of the input grid have recently been introduced (e.g. Neural Operators, DeepONets, ...).

**Note**: generally, we don't want to train a neural network on the amplitude parameter here. This is because for many models, the amplitude parameter just moves the spectrum up and down in flux space, which is a really cheap mathematical operation we can just apply ourselves after calculating the spectrum, and thus the neural network doesn't have to learn it. I'm leaving it in here deliberately so that we have more than one parameter to train on, and also so you can explore training on parameters that span multiple orders of magnitude.
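
To make the point about the amplitude concrete, here is a minimal sketch of applying it after the fact. The helper `unit_powerlaw` and the numbers are hypothetical and not used elsewhere in this tutorial: the idea is that a surrogate would only need to learn the spectral *shape*, and the amplitude becomes a plain multiplication.

```python
import numpy as np

def unit_powerlaw(energy, gamma, e0=1.0):
    """Power law with the amplitude fixed to 1 (hypothetical helper)."""
    return (energy / e0) ** (-gamma)

energy_demo = np.logspace(np.log10(0.3), np.log10(10.0), 20)
gamma_demo = 2.0

# shape of the spectrum: this is the part a surrogate would have to learn
shape = unit_powerlaw(energy_demo, gamma_demo)

# the amplitude is applied afterwards as a plain multiplication
amps = np.array([5e-4, 5e-3, 5e-2])
spectra = amps[:, None] * shape[None, :]

print(spectra.shape)  # one spectrum per amplitude
```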

In [None]:
def powerlaw_counts(energy, amp, gamma, e0=1.0, exposure=1.0):
    """
    Deterministic toy counts spectrum (no noise).
    
    Parameters
    ----------
    energy : array, shape (nE,)
        Energy grid in keV.

    amp : float
        The power law amplitude

    gamma: float
        The power law index

    e0: float
        The reference (pivot) energy in keV, at which the
        photon flux equals `amp`

    exposure : float
        The exposure time of the observation
    """
    photon_flux = amp * (energy/e0)**(-gamma)  # photons / (keV * cm^2 * s) up to a constant
    return photon_flux * exposure

def simulate_dataset_random(n_samples, energy, 
                          logamp_range=(np.log(5e-4), np.log(5e-2)), 
                          gamma_range=(1.0, 3.0), 
                          exposure=1.0, noisy=False):
    """
    Simulate a dataset of (parameters -> spectrum), 
    using simple uniform random sampling (in log-space for the amplitude).
    
    Parameters
    ----------
    n_samples : int
        The total number of samples (parameter pairs) to draw

    energy : array
        The energy grid to simulate the data over

    logamp_range : (float, float)
        The lower and upper boundary between which to sample 
        the natural logarithm of the amplitude

    gamma_range : (float, float)
        The lower and upper boundary between which to sample 
        the power law index

    exposure : float
        The exposure of the observation

    noisy : bool, default False
        If True, add Poisson noise to the spectrum

    Returns
    -------
    X : array of shape (nsamples, 2)
        The array containing all the pairs of parameters

    Y : array of (nsamples, len(energy))
        The array with all simulated spectra
    """
    logamp = rng.uniform(logamp_range[0], logamp_range[1], size=n_samples)
    amp = np.exp(logamp)
    gamma = rng.uniform(gamma_range[0], gamma_range[1], size=n_samples)

    X = np.column_stack([logamp, gamma]).astype(np.float32)  # use logA for dynamic range
    Y = np.stack([powerlaw_counts(energy, ampi, gammai, exposure=exposure) for ampi, gammai in zip(amp, gamma)]).astype(np.float32)

    if noisy:
        # Poisson noise around the expected counts
        Y = rng.poisson(Y).astype(np.float32)

    return X, Y


def simulate_dataset_lhc(n_samples, energy, 
                          logamp_range=(np.log(5e-4), np.log(5e-2)), 
                          gamma_range=(1.0, 3.0), 
                          exposure=1.0, noisy=False):
    """
    Simulate a dataset of (parameters -> spectrum), 
    using a Latin Hypercube Sampling.

    Parameters
    ----------
    n_samples : int
        The total number of samples (parameter pairs) to draw

    energy : array
        The energy grid to simulate the data over

    logamp_range : (float, float)
        The lower and upper boundary between which to sample 
        the natural logarithm of the amplitude

    gamma_range : (float, float)
        The lower and upper boundary between which to sample 
        the power law index

    exposure : float
        The exposure of the observation

    noisy : bool, default False
        If True, add Poisson noise to the spectrum

    Returns
    -------
    X : array of shape (nsamples, 2)
        The array containing all the pairs of parameters

    Y : array of (nsamples, len(energy))
        The array with all simulated spectra 
    """
    # set up Latin Hypercube Sampling
    # (seeded with the global `rng` so that results are reproducible)
    sampler = qmc.LatinHypercube(d=2, seed=rng)

    # sample randomly using LHS
    sample = sampler.random(n=n_samples)

    # set array of lower and upper bounds
    l_bounds = [logamp_range[0], gamma_range[0]]
    u_bounds = [logamp_range[1], gamma_range[1]]

    # scale sample to bounds
    # (cast to float32 to match the output of `simulate_dataset_random`)
    X = qmc.scale(sample, l_bounds, u_bounds).astype(np.float32)

    Y = np.stack([powerlaw_counts(energy, np.exp(log_ampi), gammai, exposure=exposure) for log_ampi, gammai in X]).astype(np.float32)

    if noisy:
        # Poisson noise around the expected counts
        Y = rng.poisson(Y).astype(np.float32)

    return X, Y

Above, there are two sampling functions. One uses uniform random sampling, the other uses [Latin Hypercube Sampling](https://en.wikipedia.org/wiki/Latin_hypercube_sampling) (LHS). Both tend to work better for training than a regularly spaced grid, and LHS tends to do better than uniform random sampling in higher dimensions: it is designed to fill the space evenly with points, whereas uniform random sampling can produce clumps and gaps.

Grids scale very badly with the number of dimensions: if you have a 10-dimensional problem and want just 5 points in each dimension, you already need nearly ten million grid points in total ($5^{10} \approx 9.8 \times 10^6$), and each axis is still only coarsely sampled. Neural networks also tend to overfit on the grid points and do much worse in between, and some of that can be mitigated by some form of random sampling.
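
The sketch below illustrates both points: it counts the points in a regular grid with 5 points along each of 10 dimensions, and it compares how evenly 100 random and 100 LHS points cover each axis of the unit square. The helper `occupied_bins` and the seed are illustrative choices; the only library call beyond `numpy` is `scipy.stats.qmc.LatinHypercube`, which the sampling function above already uses.

```python
import numpy as np
from scipy.stats import qmc

# size of a regular grid with 5 points per dimension in 10 dimensions
n_grid = 5 ** 10
print("grid points:", n_grid)

# compare random and Latin Hypercube samples in the unit square
n_points = 100
rng_demo = np.random.default_rng(0)
random_sample = rng_demo.uniform(size=(n_points, 2))
lhs_sample = qmc.LatinHypercube(d=2, seed=rng_demo).random(n=n_points)

def occupied_bins(sample, n_bins):
    """Count how many of n_bins equal-width bins are hit along each dimension."""
    idx = np.floor(sample * n_bins).astype(int)
    return [len(np.unique(idx[:, j])) for j in range(sample.shape[1])]

print("random:", occupied_bins(random_sample, n_points))
print("LHS:   ", occupied_bins(lhs_sample, n_points))
```

With LHS, every one of the 100 bins along each axis contains exactly one point by construction, whereas random sampling leaves some bins empty and puts several points into others.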

Let's try it out:

In [None]:
# generate energy grid
energy = np.logspace(np.log10(0.3), np.log10(10.0), 60).astype(np.float32)

# make some simulated data 
X_demo, Y_demo = simulate_dataset_random(100, energy)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))

ax1.scatter(X_demo[:,0], X_demo[:,1], s=2)

for i in range(3):
    plot_spectrum(energy, Y_demo[i], ax=ax2,
                  title=f"Example spectrum {i}  (logA={X_demo[i,0]:.2f}, Γ={X_demo[i,1]:.2f})")
plt.show()

Let's try the same with the Latin Hypercube Sampling:

In [None]:
# make some simulated data 
X_lhs, Y_lhs = simulate_dataset_lhc(100, energy)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))

ax1.scatter(X_lhs[:,0], X_lhs[:,1], s=2)

for i in range(3):
    plot_spectrum(energy, Y_lhs[i], ax=ax2,
                  title=f"Example spectrum {i}  (logA={X_lhs[i,0]:.2f}, Γ={X_lhs[i,1]:.2f})")
plt.show()

Looks different, because logarithms.

## Building a first surrogate: one node, one training example

To start, we'll train the simplest neural network we can: a single node in a single layer.

A “single node” (a single linear layer) with no hidden layers is:

$$
\hat{\mathbf{y}} = W\mathbf{x} + \mathbf{b}
$$

where:
- $\mathbf{x}$ has 2 inputs: $\log(A)$, $\Gamma$
- $\hat{\mathbf{y}}$ has $n_E$ outputs (one per energy bin)

This model has:
- $2n_E$ weights + $n_E$ biases  
For $n_E = 60$, that's 180 parameters.
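
As a quick check of that count, the following sketch builds the same kind of layer in `pytorch`, counts its parameters, and confirms that it computes $W\mathbf{x} + \mathbf{b}$. The input values are made up for illustration.

```python
import torch
import torch.nn as nn

n_energy = 60  # number of energy bins, as in the text above
layer = nn.Linear(2, n_energy)

# count weights and biases
n_params = sum(p.numel() for p in layer.parameters())
print("trainable parameters:", n_params)

# the layer computes x @ W.T + b; check against a manual evaluation
x = torch.tensor([[-5.0, 2.0]])  # one (log A, Gamma) pair, illustrative values
manual = x @ layer.weight.T + layer.bias
print(torch.allclose(layer(x), manual))
```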

**Important**: in surrogate models, the word "parameter" is somewhat overloaded. This is because we're training the neural network to emulate the physical model's relationship between the model's input parameters (here the power law index and amplitude) and the model's output (here, the flux for each energy bin). However, the neural network *also* has parameters, which are often called *weights* and *biases*. These parameters we adjust during **training** such that the neural network produces outputs that closely resemble the spectrum we are trying to emulate. I'll try to make it clear from context which parameters we're talking about, but keep in mind that this might occasionally cause confusion as we go on. Machine learning also knows the concept of **hyperparameters**: these are properties of the neural network (e.g. the architecture, the number of layers/nodes, the details of the training process) that you can change in order to try and improve your neural network performance. We'll return to hyperparameters later.

Training it on **one** example is guaranteed to overfit: the network already has many more parameters than we have examples in the training data.

**Note**: Here, we call the collection of pairs of physics model parameters and physics model outputs (spectra) **"training data"**, even though they're not--strictly speaking--data in the sense of what we would usually consider data (i.e. observations coming from a telescope). In machine learning, any collection of input-output pairs used to train a machine learning algorithm is called "training data".

We do it anyway to see what “overfitting” looks like in practice.

In [None]:
# Create the tiniest dataset possible: 1 training example and a few validation examples
X_train_tiny, Y_train_tiny = simulate_dataset_random(1, energy, noisy=False)
X_val_tiny,   Y_val_tiny   = simulate_dataset_random(10, energy, noisy=False)

# We'll scale inputs and outputs (this matters for neural nets!)
# What that means we'll get into a little later
x_scaler = StandardScaler().fit(X_train_tiny)
y_scaler = StandardScaler().fit(Y_train_tiny)

Xtr = x_scaler.transform(X_train_tiny).astype(np.float32)
Ytr = y_scaler.transform(Y_train_tiny).astype(np.float32)

Xva = x_scaler.transform(X_val_tiny).astype(np.float32)
Yva = y_scaler.transform(Y_val_tiny).astype(np.float32)

# Torch tensors
Xtr_t = torch.from_numpy(Xtr)
Ytr_t = torch.from_numpy(Ytr)
Xva_t = torch.from_numpy(Xva)
Yva_t = torch.from_numpy(Yva)

nE = Y_train_tiny.shape[1]

Now we can write down our model. It's really simple at this point: just one line of code.

In [None]:
# this is our model
tiny_model = nn.Linear(2, nE)  # single node per output dimension (linear map)

The goal of the neural network is to predict the X-ray spectrum given some parameters. For this, we compare the neural network output to training examples (pairs of parameters and corresponding spectra) and move the parameters of the neural network (the *weights* and *biases*) such that they produce output that is *more* similar to the training examples.

To do this, we have to have a metric of comparison: how do we know when the neural network output is "close" to the training examples? This is done via a **loss function**, which quantifies the distance between the neural network output and the "true" training examples. Here, we're going to use the Mean Squared Error as a distance metric, which effectively just uses the square of the difference between the neural network output and the training examples to quantify distance. 
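
To see what the Mean Squared Error computes, here is a small sketch that evaluates it by hand and compares the result with `nn.MSELoss`. The numbers are made up and have nothing to do with real spectra.

```python
import torch
import torch.nn as nn

# two "predicted" and "true" spectra with three energy bins each (made-up numbers)
pred = torch.tensor([[1.0, 2.0, 3.0], [0.0, 0.0, 0.0]])
true = torch.tensor([[1.0, 1.0, 1.0], [0.0, 2.0, 0.0]])

# by hand: mean over all elements of the squared difference
mse_manual = ((pred - true) ** 2).mean()

# the same quantity from pytorch (default reduction="mean")
mse_torch = nn.MSELoss()(pred, true)

print(mse_manual.item(), mse_torch.item())
```

Note that the mean runs over both examples and energy bins, so every bin of every spectrum contributes equally to the loss.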

In [None]:
# We use the Mean Squared Error loss:
loss_fn = nn.MSELoss()

How do we know how to wiggle around the weights and biases in order to make the neural network produce outputs that ideally closely resemble the training examples? 

The answer here is optimization. It's worth mentioning here that neural network training rests on a concept called *backpropagation*, which effectively means that people figured out how to efficiently calculate *gradients* of the loss function with respect to the weights and biases (the parameters) of the neural network. This is the one weird trick that makes neural networks work in practice, and allows us to train models even when they have millions (or hundreds of millions) of parameters!

There are a range of different optimization algorithms that machine learning researchers have developed to optimize neural networks, most of them based on an algorithm called **Stochastic Gradient Descent**. 

Here, we're using a variant of that algorithm called Adam (Adaptive Moment Estimation) as an optimizer.
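
To illustrate what "using gradients to update parameters" means, here is a minimal sketch of a single step of *plain* gradient descent on a toy loss with one parameter. This is deliberately simpler than Adam, which additionally adapts the step size for each parameter; the loss function and learning rate are made up for illustration.

```python
import torch

# a single "weight" and a toy loss with its minimum at w = 3
w = torch.tensor(0.0, requires_grad=True)
loss = (w - 3.0) ** 2

# backpropagation: fills w.grad with d(loss)/dw = 2 * (w - 3)
loss.backward()
grad = w.grad.item()

# one plain gradient-descent step with a hypothetical learning rate
lr = 0.1
with torch.no_grad():
    w -= lr * w.grad

print("gradient:", grad, "updated w:", w.item())
```

The gradient is negative at $w = 0$, so the update moves $w$ towards the minimum at $w = 3$.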

In [None]:
# a standard(ish) choice for optimization:
opt = torch.optim.Adam(tiny_model.parameters(), lr=1e-2)

To train, we usually write a function that performs optimization for a fixed number of steps. Here, each step is one pass over the full training set, which we call an **epoch**. At each epoch, we compute the outputs of the neural network for our training examples, compare them to the outputs of the original model by computing the loss function, and then use the gradients to move the parameters of the neural network in the direction that decreases the loss function. Usually, in between, we also compute the loss function for examples that the neural network doesn't get to see during training, called the **validation set**. This set allows us to test how well the model **generalizes**, i.e. how well it does on data it hasn't seen before.

In [None]:
def train(model, Xtr, Ytr, Xva, Yva, epochs=2000, print_every=400):
    """
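    Training function for neural network training.

    Note: this function relies on the global `opt` (optimizer) and
    `loss_fn` (loss function) defined in the cells above, so `opt`
    must have been created from the parameters of `model`.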

    Parameters
    ----------
    model : pytorch NN model
        The neural network model to use

    Xtr, Ytr : arrays
        The training data: Xtr are the input parameters of the 
        physical model, Ytr the spectra each set of parameters 
        generates

    Xva, Yva: arrays
        The validation data, of similar form as Xtr and Ytr

    epochs : int
        The number of iterations (epochs) to train for

    print_every : int
        How often to print training and validation performance

    Returns
    -------
    train_losses, val_losses : list, list
        The loss at each epoch for both the training
        and validation data sets
    """

    # empty lists for training and validation losses
    train_losses, val_losses = [], []

    # loop through each epoch:
    for ep in range(1, epochs+1):

        # put the model into training mode
        model.train()

        # set gradients to zero
        opt.zero_grad(set_to_none=True)

        # predict the spectra for the training 
        # set parameters
        pred = model(Xtr)

        # calculate the loss between the predicted
        # spectra and the training examples
        loss = loss_fn(pred, Ytr)

        # backpropagation of gradients
        loss.backward()

        # take a step in the dimensions of the 
        # neural network parameters
        opt.step()

        # switch the model to evaluation mode
        model.eval()

        # calculate validation losses
        with torch.no_grad():
            val = loss_fn(model(Xva), Yva)

        # append losses to arrays
        train_losses.append(loss.item())
        val_losses.append(val.item())

        # print losses
        if ep % print_every == 0:
            print(f"epoch {ep:4d}  train MSE={loss.item():.4e}  val MSE={val.item():.4e}")
    return np.array(train_losses), np.array(val_losses)

Let's try it:

In [None]:
train_l, val_l = train(tiny_model, Xtr_t, Ytr_t, Xva_t, Yva_t, epochs=500, print_every=50)

Usually a good plot to make is a plot of the loss function as a function of epoch, for both the training and the validation data set. 

The training loss is usually smaller than the validation loss: it would be hard for the model to perform systematically better on data that it hasn't seen than on data that it has! One common exception is if you include **drop-out** in your model (see suggestions for further reading at the end), where you *can* have a validation loss that's smaller than your training loss, because drop-out is active during training but switched off during evaluation. 

Let's see what the loss function looks like for our data set:

In [None]:
plt.figure(figsize=(6,4))
plt.plot(train_l, label="train")
plt.plot(val_l, label="val")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("MSE (scaled space)")
plt.title("Overfitting with 1 training example")
plt.legend()
plt.tight_layout()
plt.show()

So we see that the training loss is indeed smaller than the validation loss and going down. The validation loss remains high and barely changes (as far as we can see). You could train this for longer (until your training loss stops going down), but let's take a look at what our neural network predicts for now:

### Inspect what the tiny model learned

We'll compare:
- the single training target vs prediction (it should fit very well)
- a validation example vs prediction (it should be bad)

In [None]:
# evaluate our tiny model
tiny_model.eval()

# disable gradient tracking (not needed for prediction)
with torch.no_grad():
    # predict fluxes for training data
    yhat_train = tiny_model(Xtr_t).numpy()
    # predict fluxes for validation data
    yhat_val0  = tiny_model(Xva_t[:1]).numpy()

# invert scaling to counts space for training data
yhat_train_counts = y_scaler.inverse_transform(yhat_train)[0]
ytrain_counts     = Y_train_tiny[0]

# invert scaling to counts space for validation data
yhat_val_counts = y_scaler.inverse_transform(yhat_val0)[0]
yval_counts     = Y_val_tiny[0]

# plot the spectra
plot_spectrum(energy, ytrain_counts, yhat_train_counts, title="Training example: target vs prediction")
plot_spectrum(energy, yval_counts,  yhat_val_counts,   title="Validation example: target vs prediction (fails)")
plt.show()

Okay, so for the training data, that looks pretty good. For the validation example, this looks terrible! This is expected: the neural network has only seen a *single* training example, which means that it has no idea what the model function should look like for other parameter sets, and it will do terribly.

**Exercise:** Try it with a different training example/validation example. You can also decide on a set of specific parameters you'd like to emulate. Do you get a better/different answer?

The important lesson here is to **not use a neural network to extrapolate, but to interpolate**. Neural networks can be excellent interpolators, but they are *terrible* at extrapolating, which is what we're doing above. So in order to be able to interpolate well, we need more data. 

Oh, and one more point on the figure above: you might be able to tell that the neural network also seems to predict *negative* values. This is bad, because by definition, the model flux cannot be negative. So we'll have to do something about that, too. 

Trying this out with a single training example is fast, and it allows you to spot issues like these early on, before you get lost in complex architectures and spend weeks searching for bugs.

## Machine Learning: From Start To Finish

Okay, so we've explored this on a single example, but that's not enough. A good exploratory process (especially when first getting started with machine learning) is to scale up both your neural network and training data set slowly: you keep adding training examples until your model can no longer do well on the training data, and then you add nodes/layers (or a different structure) to your neural network to increase its complexity (and therefore its flexibility in modelling complex functions). You keep doing this until you reach a model that does well over the full range of parameters for which you want to use it as a surrogate. 

Here, we're going to compress this somewhat for the purpose of this tutorial, and do the full workflow with a realistic dataset size. Here's an overview of how to build a machine learning project:

1. Decide the goal and target metric
2. Generate / collect data
3. Split into train/val/test
4. Scale features (and sometimes targets)
5. Choose a baseline model
6. Train, monitor learning curves
7. Tune hyperparameters (with cross-validation where appropriate)
8. Evaluate on a held-out test set
9. Interpret failure modes and report

In spectroscopy terms: you should think of this like building a calibration pipeline (and if you ask David Hogg, he'll tell you that instrument calibration has basically always been a machine learning process, but that's a topic to discuss over drinks).

### Deciding on a goal and target metric

In our case, the overall goal is pretty clear: build a surrogate model that can approximate the physical model in generating an X-ray spectrum, but ideally much faster than the original model.

However, there are subtleties:
* How well does the neural network have to approximate the original model? Is 1% relative error enough? Does it have to be better?
* Where and how can you tolerate biases in the model (and where can't you)?
* In which parts of parameter space does the original model have issues (may be less precise, more biased, etc)? How will you treat those?

These are questions that you'll want to consider when building your model, and be explicit about in the documentation/paper: where can people rely on your model well, and where can they not?

### Generate/collect data

Let's now generate more data. Above, I wrote two different functions to do so: one that samples uniformly at random from the parameter space within some bounds, and one that uses [Latin Hypercube Sampling](https://en.wikipedia.org/wiki/Latin_hypercube_sampling) to sample the space. In two dimensions, the two are fairly equivalent, but in higher dimensions, uniform random sampling tends to produce clusters that we don't want (we ideally want to explore parameter space pretty broadly). Let's use Latin Hypercube Sampling to generate some training data.

We'll also write a small function that turns the flux into log-flux, which we're going to train on. This ensures that our neural network cannot ever generate negative fluxes:

In [None]:
def transform_targets(y):
    """
    Transform targets (flux as a function of energy)
    into logarithmic space, which keeps predicted fluxes positive
    and compresses their dynamic range.

    Parameters
    ----------
    y : array
        The array with flux in linear space

    Returns
    -------
    y_log : array
        The array with flux in log space

    """
    # np.log requires strictly positive values; the noiseless power-law
    # spectra used here satisfy this (zero counts would map to -inf)
    y_log = np.log(y)
    return y_log

def inverse_transform_targets(y_log):
    """
    Inverse-transform targets (flux as a function of energy)
    from logarithmic space (back) into linear space.

    Parameters
    ----------
    y_log : array
        The array with flux in log space

    Returns
    -------
    y : array
        The array with flux in linear space
    """
    return np.exp(y_log)
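
To see why training in log space rules out negative fluxes, here is a small stand-alone sketch using made-up numbers: whatever real value the network outputs in log space, exponentiating it gives a strictly positive flux. It also shows the one caveat, which is that a flux of exactly zero cannot be represented.

```python
import numpy as np

# arbitrary network outputs in log space, including large negative values
y_log_pred = np.array([-30.0, -2.5, 0.0, 4.0])

# mapping back to linear space always gives strictly positive flux
flux_pred = np.exp(y_log_pred)
print(flux_pred > 0)

# round trip on some made-up positive fluxes
flux_true = np.array([1e-6, 3e-3, 0.2, 15.0])
print(np.allclose(np.exp(np.log(flux_true)), flux_true))

# caveat: a bin with exactly zero flux maps to -inf in log space
with np.errstate(divide="ignore"):
    print(np.log(0.0))
```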


In [None]:
# number of training examples: we will experiment with this later!
n_samples = 5000

# simulate the data 
# (using Latin Hypercube Sampling, as described above)
X, Y_linear = simulate_dataset_lhc(n_samples, energy, noisy=False)

# transform fluxes into log-space
Y = transform_targets(Y_linear).astype(np.float32)


## Training, validation, and test sets

In real projects you typically want **three** disjoint sets:

- **Training set:** Fit the model parameters (weights/biases). This is the data we're training the model on.
- **Validation set:** This is data that the model doesn't see during learning, and we use it like we did above to check how the model is doing, figure out bugs (like the fact that currently our model is allowed to predict negative fluxes), and also tune hyperparameters. Here, hyperparameters are the details of the neural network (its architecture, like the number of nodes and layers), of the optimization (e.g. the learning rate parameter some optimization algorithms use) and other features such as dropout, the batch size, etc.
- **Test set:** We will also generate a test set, which we do not touch at all during training/validation. We will reserve it for reporting the final performance at the end. These are the numbers that go into the paper. But we already have a validation set, you say? Yes, but there's a subtle, but important point here: while we don't use the validation set directly in training, we do use it to make decisions about the model and the training algorithms. In this way, the model indirectly does "see" the validation data, and it has been shown to be possible to overfit on the validation data. The test set is to make sure we get an unbiased, independent measure of our performance at the end.

**Important:**  
You should not repeatedly look at the test set while iterating, or it stops being a good (unbiased) test.

We can use a function from the library `scikit-learn` to generate a train-test split. We'll do this twice: once to first split out the dataset we'll train on, and from the remainder, we'll split out our test set (which we'll put aside until the end), and the validation data set.

In [None]:
# Train/val/test split: 70/15/15
X_train, X_tmp, Y_train, Y_tmp = train_test_split(X, Y, test_size=0.30, random_state=123)
X_val, X_test, Y_val, Y_test = train_test_split(X_tmp, Y_tmp, test_size=0.50, random_state=123)

print(X_train.shape, X_val.shape, X_test.shape)

### Scaling features (and targets)

Neural networks are sensitive to feature scaling. This means that they don't work very well when the input or output dimensions differ from each other by orders of magnitude. For example, our power law index is confined to a fairly narrow range of order unity, whereas the amplitude can vary over multiple orders of magnitude. This can be hard to learn for neural networks, whose weights are usually initialized to small values around zero, so we'll make it easier on our neural network and scale our **features** (the inputs to the neural networks, here the parameters of the physical model) and the **targets** (the output of the neural network, here the model spectra).

A common practice is to scale features such that they have a mean of zero and a variance of one. Note that this scaling is applied separately to each column, i.e. to each parameter for the features and to each energy bin for the spectra. This makes the spectrum look *really* weird during training, but because it's an easily reversible operation, it's not a concern for our actual model prediction.

There are many other scalers (see e.g. [here](https://scikit-learn.org/stable/auto_examples/preprocessing/plot_all_scaling.html#sphx-glr-auto-examples-preprocessing-plot-all-scaling-py)), and which one you use heavily depends on your problem. There are also more advanced options such as [Fourier Feature Encoding](https://sair.synerise.com/fourier-feature-encoding/) (see also [this paper](https://bmild.github.io/fourfeat/)) that can help you with your neural network performance. This is a huge rabbit hole to fall down into, which we're not going to do today.

Note that you should fit scalers **only on the training set** and then apply them to the validation/test data.
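To make the scaling step above concrete, here is a minimal by-hand sketch of what a standard scaler does, using only `numpy`. The toy array `Y` and the helper functions `scale` and `unscale` are made up for illustration; they are not part of this tutorial's pipeline, which uses `StandardScaler` instead.

```python
import numpy as np

rng = np.random.default_rng(0)

# toy "log-spectra": 200 examples, 5 energy bins with very different offsets
Y = rng.normal(loc=[0.0, -2.0, -4.0, -6.0, -8.0],
               scale=[1.0, 1.5, 2.0, 2.5, 3.0],
               size=(200, 5))
Y_train, Y_val = Y[:150], Y[150:]

# statistics come from the TRAINING set only, one value per column (energy bin)
mu = Y_train.mean(axis=0)
sigma = Y_train.std(axis=0)  # population standard deviation (ddof=0)

def scale(y):
    """Standardize using the training-set mean and standard deviation."""
    return (y - mu) / sigma

def unscale(y_scaled):
    """Undo the standardization."""
    return y_scaled * sigma + mu

Ytr_scaled = scale(Y_train)
Yva_scaled = scale(Y_val)

print("train mean per bin:", Ytr_scaled.mean(axis=0))
print("train std per bin: ", Ytr_scaled.std(axis=0))
print("val mean per bin:  ", Yva_scaled.mean(axis=0))

# the operation is exactly reversible
roundtrip_ok = np.allclose(unscale(Yva_scaled), Y_val)
print("round trip OK:", roundtrip_ok)
```

The scaled training data have zero mean and unit variance by construction, while the validation data are only approximately standardized. That is expected: the validation set must not influence the scaling.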

In [None]:
# scaler for the inputs
x_scaler = StandardScaler().fit(X_train)

# scaler for the outputs
y_scaler = StandardScaler().fit(Y_train)

# transform train/val/test features
Xtr = x_scaler.transform(X_train).astype(np.float32)
Xva = x_scaler.transform(X_val).astype(np.float32)
Xte = x_scaler.transform(X_test).astype(np.float32)

# transform train/val/test targets
Ytr = y_scaler.transform(Y_train).astype(np.float32)
Yva = y_scaler.transform(Y_val).astype(np.float32)
Yte = y_scaler.transform(Y_test).astype(np.float32)

nE = Ytr.shape[1]
print("nE:", nE)

### Baseline model: linear surrogate

Before reaching for “deep learning”, we'll start with something that's easier to understand.

A linear model cannot represent everything, but here it might do surprisingly well because
a power law is “simple”. This is useful as a *sanity baseline*.

We'll:
- train the linear surrogate (same as before, but with enough data)
- monitor the learning curves

#### Some housekeeping

When we have large data sets, the whole data set may not fit into the memory of your GPU. In these cases, we train in **batches**: during each epoch, the training data set is split up randomly into smaller batches of $M$ training examples (often a power of two, like 64 or 128; *not* using a power of two is a great way to get some computer scientists **really** angry). In each epoch, you iterate through all of your batches, updating the model parameters after each one, and only then start the next training epoch. This has the advantage of often being significantly faster and computationally less demanding. It also makes your optimization noisier. Surprisingly, this can actually *help* with training, because you might avoid local minima better than if you trained on all of your training data simultaneously.
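As a rough illustration of the batching described above, here is a minimal sketch of splitting a data set into random mini-batches by hand with `numpy`. The helper `iterate_minibatches` and the toy arrays are hypothetical and only serve to show the idea.

```python
import numpy as np

def iterate_minibatches(X, Y, batch_size, rng):
    """Yield (X_batch, Y_batch) pairs that cover the data once, in random order."""
    n = X.shape[0]
    # a new random order for every call, i.e. for every epoch
    order = rng.permutation(n)
    for start in range(0, n, batch_size):
        idx = order[start:start + batch_size]
        yield X[idx], Y[idx]

rng = np.random.default_rng(42)
X = rng.normal(size=(1000, 2))    # toy features
Y = rng.normal(size=(1000, 60))   # toy targets

# one "epoch": run through all batches once
batch_sizes = [xb.shape[0] for xb, yb in iterate_minibatches(X, Y, 256, rng)]
print(batch_sizes)  # [256, 256, 256, 232]
```

Note that the last batch is smaller whenever the batch size does not divide the number of training examples evenly.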

`pytorch`, the library we're using today, has a bunch of infrastructure that abstracts a lot of that away from you, if you put your training data into the right format. Specifically, the `Dataset` and `DataLoader` classes are really helpful in making that work reliably (see the tutorial [here](https://docs.pytorch.org/tutorials/beginner/basics/data_tutorial.html)).

Let's put our data into the right format:

In [None]:
def make_dataloader(X, Y, batch_size=128, shuffle=True):
    """
    Take our data and turn it into a DataLoader object

    Parameters
    ----------
    X, Y : array, array
        The (training) data features and targets

    batch_size : int, default 128
        The size for each training batch

    shuffle : bool, default True
        If true, shuffle the data into new batches 
        for every epoch.

    Returns
    -------
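    dl : torch.utils.data.DataLoader object
        The DataLoader with the training/test/validation data
    """
    # first put data into a `TensorDataset` object
    # (X and Y must be numpy arrays; float32 matches the default dtype of the network weights)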
    ds = TensorDataset(torch.from_numpy(X), torch.from_numpy(Y))

    # now build DataLoader
    dl = DataLoader(ds, batch_size=batch_size, shuffle=shuffle)
    return dl


Let's try it:

In [None]:
train_loader = make_dataloader(Xtr, Ytr, batch_size=256, shuffle=True)
# no need to shuffle the validation data: we only evaluate on it
val_loader   = make_dataloader(Xva, Yva, batch_size=256, shuffle=False)
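
A quick way to convince yourself that a `DataLoader` does what you expect is to loop over it once and look at the batch shapes. The sketch below uses made-up random arrays (`X_toy`, `Y_toy`) rather than the tutorial's data, so the sizes are purely illustrative.

```python
import numpy as np
import torch
from torch.utils.data import TensorDataset, DataLoader

rng = np.random.default_rng(0)

# toy data: 1000 examples, 2 parameters in, 60 energy bins out
X_toy = rng.normal(size=(1000, 2)).astype(np.float32)
Y_toy = rng.normal(size=(1000, 60)).astype(np.float32)

ds = TensorDataset(torch.from_numpy(X_toy), torch.from_numpy(Y_toy))
loader = DataLoader(ds, batch_size=256, shuffle=True)

# number of batches per epoch
n_batches = len(loader)

# shapes of the feature and target tensors in every batch
batch_shapes = [(tuple(xb.shape), tuple(yb.shape)) for xb, yb in loader]

print("number of batches:", n_batches)
for shapes in batch_shapes:
    print(shapes)
```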

Now we have to rewrite our training loop in order to include the DataLoaders. We'll also include a couple of nifty machine learning tricks:

* we'll save the best-performing model so that we can pull it out again later
* we'll also use **early stopping**: this is a concept where we want to stop training when the performance on the *validation data set* no longer improves, because then we're moving into the realm of overfitting.
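
Before looking at the full training loop, here is the early-stopping logic in isolation, as a minimal sketch in plain Python. The function `early_stopping_epoch`, its `min_delta` parameter and the list of validation losses are invented for illustration.

```python
def early_stopping_epoch(val_losses, patience=3, min_delta=1e-6):
    """
    Walk through a sequence of validation losses and report where
    early stopping would halt the training.

    Returns (stop_epoch, best_epoch, best_loss), with epochs counted from 1.
    """
    best = float("inf")
    best_epoch = None
    bad = 0  # consecutive epochs without improvement

    for epoch, loss in enumerate(val_losses, start=1):
        if loss < best - min_delta:
            # improvement: remember it and reset the counter
            best, best_epoch, bad = loss, epoch, 0
        else:
            bad += 1
            if bad >= patience:
                return epoch, best_epoch, best

    # never ran out of patience
    return len(val_losses), best_epoch, best

# made-up validation losses: improve, plateau, improve once more, then get worse
val_losses = [1.0, 0.5, 0.3, 0.31, 0.32, 0.29, 0.30, 0.31, 0.33, 0.35]

stop_epoch, best_epoch, best_loss = early_stopping_epoch(val_losses, patience=3)
print(stop_epoch, best_epoch, best_loss)  # 9 6 0.29
```

Training stops at epoch 9, but the model we would keep is the one from epoch 6. This is why the training loop stores a copy of the best model state as it goes.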

In [None]:
def train_epochs(model, train_loader, val_loader, epochs=200, patience=20,
                print_every=50):

    """
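    Training loop for our neural network, with early stopping.

    Note: this function uses the optimizer `opt` and the loss 
    function `loss_fn` defined in the notebook's global scope, 
    so both must exist before you call it.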

    Parameters
    ----------
    model : pytorch NN model
        The neural network model

    train_loader, val_loader : pytorch.DataLoader objects
        The data loader objects with our training/validation data

    epochs : int, default 200
        The number of epochs to run for

    patience : int, default 20
        The number of epochs to train without improvement
        before stopping

    print_every : int, default 50
        Print losses every `print_every` epochs

    Returns
    -------
    hist : dict
        A dictionary with the training and validation losses
    """
    # initialize best validation performance and model state
    best_val = np.inf
    best_state = None

    # counter for the number of validation losses worse 
    # than the best one
    bad = 0

    # initialize dictionary for storing the losses
    hist = {"train": [], "val": []}

    # run through epochs
    for ep in range(1, epochs+1):
        model.train()
        tl=[]

        # iterate through the training DataLoader (i.e. over batches)
        for xb, yb in train_loader:
            # zero out the gradients
            opt.zero_grad(set_to_none=True)

            # predict the model fluxes
            pred = model(xb)

            # compute the loss
            loss = loss_fn(pred, yb)

            # backprop
            loss.backward()

            # step in parameter space
            opt.step()
            tl.append(loss.item())

        # store mean loss over all batches as 
        # the training loss
        train_loss = float(np.mean(tl))

        # run through validation batches and 
        # compute mean loss
        model.eval()
        vl=[]
        with torch.no_grad():
            for xb, yb in val_loader:
                vl.append(loss_fn(model(xb), yb).item())
        val_loss = float(np.mean(vl))

        hist["train"].append(train_loss)
        hist["val"].append(val_loss)
        
        # print losses
        if ep % print_every == 0:
            print(f"epoch {ep:4d}  train MSE={train_loss:.4e}  val MSE={val_loss:.4e}")


        # if validation loss is better than previous 
        # best value, then keep this as the new best model
        # and reset "bad model" counter
        if val_loss < best_val - 1e-6:
            best_val = val_loss
            best_state = {k: v.detach().clone() for k,v in model.state_dict().items()}
            bad = 0
        else:
            bad += 1
            # if we've had more bad models than our `patience` parameter
            # stop training, because it ain't getting better!
            if bad >= patience:
                break

    if best_state is not None:
        model.load_state_dict(best_state)

    return hist



In [None]:
# our linear model
linear_model = nn.Linear(2, nE)

# set the learning rate for the optimizer
lr = 1e-3

# same optimizer as previously
opt = torch.optim.Adam(linear_model.parameters(), lr=lr)

# same loss function as previously
loss_fn = nn.MSELoss()

#let's train!
hist_lin = train_epochs(linear_model, train_loader, val_loader, epochs=200, patience=100)

In [None]:
# plot the loss function for the training and validation data

plt.figure(figsize=(6,4))
plt.plot(hist_lin["train"], label="train")
plt.plot(hist_lin["val"], label="val")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("MSE (scaled log-count space)")
plt.title("Learning curves: linear surrogate")
plt.legend()
plt.tight_layout()
plt.show()

We're going to make a function that plots an example spectrum, the neural network prediction, and also the relative error (the absolute value of the fractional residuals), defined as

$$
r = \frac{|y - \hat{y}|}{y}
$$

In [None]:
def plot_example(energy, ytrue, ypred, title=None):
    """
    Plot an example of the data and the neural network prediction,
    as well as the relative error.

    Parameters
    ----------
    energy : array
        The array of energies

    ytrue : array
        The array with the true model values

    ypred : array
        The array with the model values predicted 
        by the neural network

    title : str, optional
        The title for the top panel

    Returns
    -------
    fig : plt.Figure object
        The figure containing both panels

    ax1, ax2 : plt.Axes objects
        The axes with the spectra (top panel) and 
        the relative error (bottom panel)
    """

    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(6,6), height_ratios=(3,1), sharex=True)

    ax1.plot(energy, ytrue, lw=1, marker="^", color="black", markersize=3, label="true model")
    ax1.plot(energy, ypred, lw=1, marker="o", color=sns.color_palette()[1], markersize=3, label="NN prediction")
    ax1.set_ylabel("Flux [arbitrary units]")
    ax1.set_yscale("log")
    ax1.legend()

    rel_error = np.abs(ytrue - ypred)/ytrue
    ax2.plot(energy, rel_error, lw=1, color="black", label="relative error")
    ax2.axhline(0.0, lw=2, color="red")
    ax2.set_xlabel("Energy [keV]")
    ax2.set_ylabel("Relative error")
    if title:
        ax1.set_title(title)
    
    return fig, ax1, ax2


Let's plot some examples: you can re-run the cell below multiple times to look at different random samples from the training and validation set:

In [None]:
# evaluate model
linear_model.eval()

# we don't need gradients for prediction, so switch off gradient tracking
with torch.no_grad():
    yhat_train = linear_model(torch.from_numpy(Xtr)).numpy()
    yhat_val  = linear_model(torch.from_numpy(Xva)).numpy()

# number of training and validation examples
ntrain = Xtr.shape[0]
nval = Xva.shape[0]

# pick random indices for plotting
idx_train = np.random.randint(0, ntrain)
idx_val = np.random.randint(0, nval)

# invert scaling to counts space
yhat_train_counts = y_scaler.inverse_transform(yhat_train)[idx_train]
ytrain_counts     = Y_train[idx_train]

yhat_val_counts = y_scaler.inverse_transform(yhat_val)[idx_val]
yval_counts     = Y_val[idx_val]


plot_example(energy, np.exp(ytrain_counts), np.exp(yhat_train_counts), title="Training example: target vs prediction")
plot_example(energy, np.exp(yval_counts),   np.exp(yhat_val_counts),   title="Validation example: target vs prediction")

plt.show()

Okay, so that looks surprisingly good! Even a linear model can apparently approximate the power law pretty well. 

## Nonlinear surrogate: add a hidden layer

To emulate more complex forward models, you generally need nonlinearity.
A minimal neural network for regression is:

`(inputs) -> Linear -> Activation function -> Linear -> (outputs)`

The intermediate values between the two linear layers form what is called a "hidden layer", and the more of these you add, the easier it will be to model non-linearity. Every time you add a layer, you also add parameters, so be careful: your model can very quickly become very big (however, this is often not an issue in machine learning, as long as you have enough training data and enough computational time).
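To get a feeling for how quickly the number of parameters grows, here is a small sketch that counts weights and biases for a fully connected network. The helper `count_mlp_parameters` is hypothetical, and the 60 output bins are only an assumed example size.

```python
def count_mlp_parameters(layer_sizes):
    """
    Count the weights and biases of a fully connected network.

    `layer_sizes` lists the number of nodes per layer, including the
    input and output layers, e.g. [2, 128, 60] for 2 inputs, one hidden
    layer with 128 nodes, and 60 outputs.
    """
    total = 0
    for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:]):
        total += n_in * n_out  # weight matrix
        total += n_out         # bias vector
    return total

linear_only = count_mlp_parameters([2, 60])
one_hidden = count_mlp_parameters([2, 128, 60])
three_hidden = count_mlp_parameters([2, 128, 128, 128, 60])

print(linear_only)   # 180
print(one_hidden)    # 8124
print(three_hidden)  # 41148
```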

A neural network of this type is also traditionally called a **Multi-Layer Perceptron** (MLP). If you add more than, say, 2-3 layers, this is what people generally call **deep learning**. 

In the context of making our model non-linear, we need to talk about another concept: **activation functions**. Remember that our linear model from earlier was of the type:

$$
\hat{\mathbf{y}} = W\mathbf{x} + \mathbf{b}
$$

What we're going to do now is actually stack two of those together, such that:

$$
\mathbf{a} = \sigma(W_1\mathbf{x} + \mathbf{b_1})
$$

and 

$$
\hat{\mathbf{y}} = W_2\mathbf{a} + \mathbf{b_2}
$$

Here we've done two things. First, we've stacked two linear models together, such that the first layer produces intermediate values $\mathbf{a}$ and the second layer relates these intermediate values $\mathbf{a}$ to the outputs we want to predict, $\hat{\mathbf{y}}$. You'll notice that there's this function $\sigma()$ in there as well. This is the activation function. It takes our linear model and makes it non-linear, by explicitly squashing the outputs of the linear layer through a non-linearity. A function that people were using early on (and that's still very useful in some cases) is the **sigmoid function**: $\sigma(z) = \frac{1}{1 + e^{-z}}$.

The sigmoid function can cause some issues in deep networks, so there is a wide range of alternatives to use. A typical one is called the **Rectified Linear Unit (ReLU)**, which is simply zero below zero and a linear function above, $\sigma(z) = \mathrm{max}(0, z)$.
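
Why do we need the activation function at all? The `numpy` sketch below (with made-up random weights, not a trained model) shows that two stacked linear layers *without* an activation collapse into a single linear layer, so nothing would be gained by stacking them.

```python
import numpy as np

def sigmoid(z):
    """Sigmoid activation: squashes values into the range (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))

def relu(z):
    """Rectified Linear Unit: zero below zero, identity above."""
    return np.maximum(0.0, z)

rng = np.random.default_rng(1)
n_in, n_hidden, n_out = 2, 4, 3

# random weights and biases for two layers
W1, b1 = rng.normal(size=(n_hidden, n_in)), rng.normal(size=n_hidden)
W2, b2 = rng.normal(size=(n_out, n_hidden)), rng.normal(size=n_out)
x = rng.normal(size=n_in)

# two linear layers WITHOUT an activation function ...
y_stacked = W2 @ (W1 @ x + b1) + b2

# ... are identical to ONE linear layer with combined weights and biases
W_eff = W2 @ W1
b_eff = W2 @ b1 + b2
y_single = W_eff @ x + b_eff
print("stacked == single linear layer:", np.allclose(y_stacked, y_single))

# with an activation in between, this shortcut no longer exists
y_relu = W2 @ relu(W1 @ x + b1) + b2
y_sigmoid = W2 @ sigmoid(W1 @ x + b1) + b2
print("with ReLU:   ", y_relu)
print("with sigmoid:", y_sigmoid)
```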


Let's write our model into a class in typical PyTorch syntax:

In [None]:
# The class can be any name, but should inherit from `nn.Module`
class MLP(nn.Module):
    # all classes need an `__init__` method, here this allows you to 
    # set the dimension of the outputs and the number of nodes in the hidden layer
    def __init__(self, in_dim=2, out_dim=60, hidden=64):
        """
        Initialize neural network class. 

        Parameters
        ----------
        in_dim : int
            The number of input dimensions (i.e. number of 
            parameters in the physics model)

        out_dim : int
            The number of output dimensions (i.e. the number 
            of fluxes, corresponding to the number of energy 
            bins in the spectrum)

        hidden : int
            The number of nodes in the hidden layer
        """
        super().__init__()

        # nn.Sequential is a container for a sequence of modules, 
        # here linear layers and activation functions
        self.net = nn.Sequential(
            # weights going from inputs to hidden layer
            nn.Linear(in_dim, hidden),
            # activation on hidden nodes
            nn.ReLU(),
            # weights going from hidden layer to outputs
            nn.Linear(hidden, out_dim),
        )
    def forward(self, x):
        """
        Forward pass through the neural network.
        Means you stick in parameters and get a predicted
        spectrum out.

        Parameters
        ----------
        x : array
            The model inputs (i.e. the parameters for which you want
            a model)

        """
        return self.net(x)


We can now initialize and train it the same way we did our previous network:

In [None]:
# our MLP with one hidden layer
mlp = MLP(in_dim=2, out_dim=nE, hidden=128)

# set the learning rate for the optimizer
lr = 1e-3

# same optimizer as previously
opt = torch.optim.Adam(mlp.parameters(), lr=lr)

# same loss function as previously
loss_fn = nn.MSELoss()

#let's train!
hist_mlp = train_epochs(mlp, train_loader, val_loader, epochs=3000, patience=40)


In [None]:
# plot the loss functions for training and validation data sets:
plt.figure(figsize=(6,4))
plt.plot(hist_mlp["train"], label="train")
plt.plot(hist_mlp["val"], label="val")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("MSE (scaled log-count space)")
plt.title("Learning curves: MLP surrogate")
plt.legend()
plt.tight_layout()
plt.show()

Let's look at some examples again:

In [None]:
# put the MLP (not the linear model) into evaluation mode
mlp.eval()

with torch.no_grad():
    yhat_train = mlp(torch.from_numpy(Xtr)).numpy()
    yhat_val  = mlp(torch.from_numpy(Xva)).numpy()


ntrain = Xtr.shape[0]
nval = Xva.shape[0]

idx_train = np.random.randint(0, ntrain)
idx_val = np.random.randint(0, nval)

# invert scaling to counts space
yhat_train_counts = y_scaler.inverse_transform(yhat_train)[idx_train]
ytrain_counts     = np.exp(Y_train[idx_train])

yhat_val_counts = y_scaler.inverse_transform(yhat_val)[idx_val]
yval_counts     = np.exp(Y_val[idx_val])

plot_example(energy, ytrain_counts, np.exp(yhat_train_counts), title="Training example: target vs prediction")
plot_example(energy, yval_counts,   np.exp(yhat_val_counts),   title="Validation example: target vs prediction")


That looks very good, but not perfect. Depending on your example, you might get a prediction that's systematically a few percent off. 

### How do biases affect physical inference?

Ultimately, what we need to know (and what we want to evaluate performance on) is whether a biased surrogate model leads to biased inferences.

Let's give this a try, and run a quick bit of MCMC. First, let's simulate a dataset:

In [None]:
# as a reminder, these are the ranges within which we 
# simulated training data
# !! IMPORTANT: do not try to infer parameters outside this range !!
amp_range = (5e-4, 5e-2) 
gamma_range = (1.0, 3.0)

# true parameters we're going to simulate a dataset from
# we're going to work with log-amplitude rather than amplitude
# for easier sampling
true_pars = [np.log(1e-3), 1.2]
exposure = 25000.0

# need to crank up exposure time otherwise the counts will be zero
flux = powerlaw_counts(energy, np.exp(true_pars[0]), true_pars[1], e0=1.0, exposure=exposure)
counts = np.random.poisson(flux)

fig, ax = plt.subplots(1, 1, figsize=(8,4))
ax.plot(energy, flux, color="black", label="Flux without noise")
ax.errorbar(energy, counts, yerr=np.sqrt(counts), marker="o", ls="", markersize=3, label="counts")
ax.set_yscale("log")
ax.legend()
ax.set_xlabel("Energy [keV]")
ax.set_ylabel("Counts")

Now we have to define a likelihood and some priors to actually do the sampling:
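
For reference, the Poisson log-likelihood can be written as follows. The notation is introduced here for convenience: $c_i$ are the observed counts in energy bin $i$, $m_i(\theta)$ is the number of counts the model predicts in that bin for parameters $\theta$ (here the surrogate's prediction multiplied by the exposure), and $N$ is the number of energy bins.

$$
\log \mathcal{L}(\theta) = \sum_{i=1}^{N} \left[ -m_i(\theta) + c_i \log m_i(\theta) - \log(c_i!) \right]
$$

The last term does not depend on the parameters, so it doesn't affect the shape of the posterior. Numerically, $\log(c_i!)$ is best computed with a log-gamma function, since $\log(c_i!) = \log\Gamma(c_i + 1)$, where $\Gamma$ here denotes the gamma function and not the photon index.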

In [None]:
def logprior(pars):
    """
    log-prior for the parameters. We're going to 
    pick flat priors within the bounds within which we've
    simulated our training data

    Parameters
    ----------
    pars : iterable
        a set of parameters, of the form [log_amp, gamma]

    Returns
    -------
    log_prior : float
        The logarithm of the prior probability distribution
        for the parameters in `pars`. Outside of prior bounds, 
        will return `-np.inf`    
    """
    log_amp = pars[0]
    gamma = pars[1]

    if log_amp < np.log(5e-4) or log_amp > np.log(5e-2) or gamma < 1.0 or gamma > 3.0:
        return -np.inf

    else:
        # flat (unnormalized) prior: constant log-probability inside the bounds
        return 0.0

def loglikelihood(pars, energy, counts):
    """
    Poisson log-likelihood for a power law model 
    with parameters `log_amp` and `gamma`.

    Uses the neural network emulator `mlp`. 
    This function will first transform the parameters 
    in `pars` to the scaled space in which the neural network
    operates, predict the model spectrum given those parameters 
    and transform that model spectrum from the neural-network 
    space into counts space, using the globally defined `exposure`.

    
    Parameters
    ----------
    pars : iterable
        a set of parameters, of the form [log_amp, gamma]

    energy : array
        The array with photon energy bins

    counts : array
        The array with (poisson-distributed) photon counts 
        for each of the bins in `energy`

    Returns
    -------
    log_likelihood : float
        The logarithm of the Poisson likelihood for 
        the power law model, the parameters in `pars` 
        and the counts in `counts`
    """
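    # convert to an array first, so that plain lists (like `true_pars`) work, too
    X_pars = x_scaler.transform(np.asarray(pars, dtype=float).reshape(1, -1))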
    X_pars = torch.from_numpy(np.array(X_pars, dtype=np.float32))
    
    mpred = mlp(X_pars).detach().numpy().reshape(1, -1)
    model_flux = np.exp(y_scaler.inverse_transform(mpred)) * exposure

    loglike = np.sum(-model_flux + counts * np.log(model_flux) - scipy_gammaln(counts + 1.0))

    if not np.isfinite(loglike):
        return -np.inf
    else:
        return loglike

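def logposterior(pars, energy, counts):
    """
    log-posterior function: the sum of log-prior and log-likelihood
    """
    lp = logprior(pars)

    # outside the prior bounds the surrogate would have to extrapolate,
    # so don't evaluate the neural network there at all
    if not np.isfinite(lp):
        return -np.inf

    return lp + loglikelihood(pars, energy, counts)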

Let's try it out:

In [None]:
# log prior with the true parameters
print(logprior(true_pars))

# log prior with bad parameters
print(logprior([-10, 5]))


Let's sample this:

In [None]:
ndim, nwalkers = 2, 100

# start the walkers in a small ball around the true parameters
true_arr = np.asarray(true_pars)
p0 = true_arr + np.random.randn(nwalkers, ndim) * 0.01 * true_arr

sampler = emcee.EnsembleSampler(nwalkers, ndim, logposterior, args=[energy, counts])
_, _, _ = sampler.run_mcmc(p0, 1000)


Let's make some trace plots:

In [None]:
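# flatten the chains of all walkers into one array of shape (nsteps * nwalkers, ndim)
# note: we don't discard any burn-in here
samples = sampler.get_chain(flat=True)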

In [None]:
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 6), sharex=True)

ax1.plot(samples[:,0], lw=1, color="black")
ax1.axhline(true_pars[0], lw=2, color="red")
ax1.set_ylabel("log(amplitude)")

ax2.plot(samples[:,1], lw=1, color="black")
ax2.axhline(true_pars[1], lw=2, color="red")
ax2.set_ylabel("photon index")
ax2.set_xlabel("sample number")


Let's make a corner plot:

In [None]:
_ = corner.corner(samples, truths=true_pars)

So just from those plots, it looks like it might be a little bit biased, but we can't necessarily tell from a single example. The posterior for a given parameter set and realization *should* randomly vary around the true value.
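
One way to check this more systematically is a coverage test: repeat the experiment many times and count how often the true value falls inside, say, the central 68% credible interval. For a well-behaved, unbiased pipeline, that should happen in roughly 68% of the cases. The sketch below is only an illustration under strong assumptions: instead of real MCMC output it uses made-up Gaussian "posteriors", and the function `credible_interval_coverage` is hypothetical.

```python
import numpy as np

def credible_interval_coverage(posterior_draws, true_values, level=0.68):
    """
    Fraction of experiments whose central credible interval contains the truth.

    posterior_draws : array, shape (n_experiments, n_draws)
    true_values     : array, shape (n_experiments,)
    """
    lo_q = 100.0 * (1.0 - level) / 2.0
    hi_q = 100.0 - lo_q
    lo = np.percentile(posterior_draws, lo_q, axis=1)
    hi = np.percentile(posterior_draws, hi_q, axis=1)
    inside = (true_values >= lo) & (true_values <= hi)
    return inside.mean()

rng = np.random.default_rng(3)
n_exp, n_draws, sigma = 500, 2000, 0.1
truth = np.full(n_exp, 1.2)

# stand-in for real MCMC output: each "posterior" is a Gaussian centred on
# a noisy estimate of the truth, with the same width as that noise
centres = truth + sigma * rng.normal(size=n_exp)
unbiased = centres[:, None] + sigma * rng.normal(size=(n_exp, n_draws))

# a systematically shifted version, mimicking a biased surrogate
biased = unbiased + 1.5 * sigma

cov_unbiased = credible_interval_coverage(unbiased, truth)
cov_biased = credible_interval_coverage(biased, truth)

print(f"coverage, unbiased posteriors: {cov_unbiased:.2f}")
print(f"coverage, biased posteriors:   {cov_biased:.2f}")
```

With a real surrogate, each row of `posterior_draws` would come from a separate MCMC run on a freshly simulated data set, which is exactly the kind of repeated evaluation where a fast surrogate pays off.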

And posterior predictive plots:

In [None]:
samples.shape

In [None]:
def posterior_predictive(samples, energy, counts, nsamples=20, ax = None, exposure=1):
    """
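    Make a posterior predictive plot: compare the observed counts 
    to data sets simulated from random posterior samples, using 
    both the neural network and the original model.

    Parameters
    ----------
    samples : array, shape (n, 2)
        The posterior samples, of the form [log_amp, gamma]

    energy : array
        The array with photon energy bins

    counts : array
        The observed counts in each energy bin

    nsamples : int, default 20
        The number of posterior samples to draw

    ax : plt.Axes object, optional
        The axes to plot into; if None, a new figure is created

    exposure : float, default 1
        The exposure used to convert model fluxes into counts

    Returns
    -------
    ax : plt.Axes object
        The axes with the plot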
    """
    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=(8,5))

    ax.errorbar(energy, counts, yerr=np.sqrt(counts), marker="o", ls="",
                color="black", markersize=3, label="counts")


    idx_all = np.random.choice(np.arange(0, samples.shape[0], 1, dtype=int), size=nsamples, replace=False)

    for i,idx in enumerate(idx_all):
        # get a single set of parameters from posterior
        p = samples[idx]

        # cast parameters into scaled version for neural network
        X_pars = x_scaler.transform(p.reshape(1, -1))
        X_pars = torch.from_numpy(np.array(X_pars, dtype=np.float32))

        # predict model flux, then rescale to original flux units
        mpred = mlp(X_pars).detach().numpy().reshape(1, -1)
        model_flux = np.exp(y_scaler.inverse_transform(mpred)) * exposure

        # random draws from Poisson distribution to simulate data sets from posterior
        model_counts = np.random.poisson(model_flux)
        if i == 0:
            m1_legend = "draws from original model"
            m2_legend = "draws from neural network"
        else:
            m1_legend = ""
            m2_legend = ""

        ax.scatter(energy, model_counts, color=sns.color_palette()[1], alpha=0.3, label=m2_legend)

        # generate the model spectrum with the original "physics" model
        mf_physics = powerlaw_counts(energy, np.exp(p[0]), p[1], exposure=exposure)
        # include Poisson data
        mc_physics = np.random.poisson(mf_physics)
        ax.scatter(energy, mc_physics, color="purple", alpha=0.1, label=m1_legend)


    ax.set_xlabel("Energy [keV]")
    ax.set_yscale("log")
    ax.set_ylabel("Counts")
    ax.legend()
    return ax

In [None]:
posterior_predictive(samples, energy, counts, exposure=exposure)

Okay, great! That seems to work! Hooray!

### Evaluate on the test set

Now we use the test set **once** to quantify generalization.
We'll compute:
- MSE in counts space (i.e. after undoing the standard scaling and the log transform)
- and a more interpretable metric: median absolute fractional error in counts space
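
To see how such metrics behave, here is a small self-contained sketch with made-up spectra. The function `surrogate_metrics` is hypothetical, and the "prediction" is simply the truth scaled by 2%.

```python
import numpy as np

def surrogate_metrics(y_true_counts, y_pred_counts):
    """
    Return the MSE, the median absolute fractional error and the
    95th percentile of the absolute fractional error.
    """
    mse = np.mean((y_pred_counts - y_true_counts) ** 2)
    frac_err = np.abs(y_pred_counts - y_true_counts) / y_true_counts
    return mse, np.median(frac_err), np.percentile(frac_err, 95)

# two toy power-law spectra on a logarithmic energy grid
energy_toy = np.logspace(np.log10(0.3), np.log10(10.0), 60)
gammas = np.array([[1.5], [2.0]])
y_true = 1e-2 * energy_toy[None, :] ** (-gammas)

# a fake "prediction" that is off by 2% everywhere
y_pred = 1.02 * y_true

mse, med_frac, p95_frac = surrogate_metrics(y_true, y_pred)
print("MSE:", mse)
print("median |fractional error|:", med_frac)
print("95th percentile |fractional error|:", p95_frac)
```

The MSE in counts space is dominated by the brightest energy bins, whereas the fractional error weights all bins equally. Looking at a high percentile in addition to the median helps you spot rare but large failures.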

In [None]:
def predict_counts(model, X_scaled):
    model.eval()
    with torch.no_grad():
        yhat_scaled = model(torch.from_numpy(X_scaled)).numpy()
    yhat_log = y_scaler.inverse_transform(yhat_scaled)
    yhat_counts = inverse_transform_targets(yhat_log)
    return yhat_counts

# Predict on test set
yhat_counts_lin = predict_counts(linear_model, Xte)

# True counts for test (unscaled) - recover from Y_test (log space)
ytrue_counts = inverse_transform_targets(Y_test)

mse_counts = np.mean((yhat_counts_lin - ytrue_counts)**2)
frac_err = np.abs(yhat_counts_lin - ytrue_counts) / ytrue_counts
med_frac = np.median(frac_err)

print("Test MSE in counts space:", mse_counts)
print("Median |fractional error|:", med_frac)

# Plot a few random test cases
idx = rng.choice(len(X_test), size=3, replace=False)
for j,i in enumerate(idx):
    plot_spectrum(energy, ytrue_counts[i], yhat_counts_lin[i],
                  title=f"Linear surrogate on test example {j}")
plt.show()

## Broken Power Laws

A slightly more difficult (toy) model: a broken power law.

We choose:
- energy grid: 0.3–10 keV
- pivot energy: 1 keV (just for context)
- **broken power-law shape parameters**: low-energy photon index $\Gamma_1$, high-energy photon index $\Gamma_2$, and break energy $E_b$

We will *learn* the mapping:

$$
(\Gamma_1, \Gamma_2, E_b) \rightarrow \text{spectrum}(E)
$$

**Important:** we do **not** train on an overall normalization/amplitude. Instead, we fix the normalization of every simulated spectrum to one at the break energy (`amp=1.0`). That means the network learns only the **shape**, not the amplitude.

In real fitting, you can often recover amplitude later by fitting a single scalar multiplier at the same time as the shape parameters.
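
For a Poisson likelihood and a fixed spectral shape $s_i$, the best-fit multiplier even has a closed form: setting the derivative of the log-likelihood with respect to the amplitude $A$ to zero gives $\hat{A} = \sum_i c_i / \sum_i s_i$, where $c_i$ are the observed counts. Here is a minimal sketch with made-up numbers, where a plain power law stands in for a surrogate's predicted shape:

```python
import numpy as np

rng = np.random.default_rng(7)

# toy setup: a fixed spectral shape and a "true" amplitude
energy_toy = np.logspace(np.log10(0.3), np.log10(10.0), 60)
shape = energy_toy ** -1.7   # stand-in for the surrogate's predicted shape
true_amp = 250.0

# simulate Poisson-distributed counts
counts_toy = rng.poisson(true_amp * shape)

# maximum-likelihood amplitude for a fixed shape
amp_hat = counts_toy.sum() / shape.sum()

print(f"true amplitude:      {true_amp:.1f}")
print(f"recovered amplitude: {amp_hat:.1f}")
```

In a full fit, the shape parameters change at every step, so you would either recompute $\hat{A}$ for each proposed shape or sample the amplitude as an additional free parameter.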

In [None]:
def broken_powerlaw_counts(energy, gamma1, gamma2, eb, amp=1.0, exposure=1.0):
    """Deterministic toy counts spectrum for a *broken* power law (no noise).

    The photon flux is continuous at the break energy Eb by construction:

        F(E) = Ab * (E/Eb)^(-Gamma1)   for E < Eb
             = Ab * (E/Eb)^(-Gamma2)   for E >= Eb

    Parameters
    ----------
    energy : array, shape (nE,)
        Energy grid in keV.
    gamma1, gamma2 : float
        Photon indices below/above the break.
    eb : float
        Break energy in keV.
    amp : float
        Normalization defined at E = Eb (kept fixed here; we do NOT train on amplitude).
    exposure : float
        Exposure time in arbitrary units.

    Returns
    -------
    counts : array, shape (nE,)
        Expected counts per energy bin (arbitrary units).
    """
    flux = np.where(energy < eb, (energy/eb)**(-gamma1), (energy/eb)**(-gamma2))
    photon_flux = amp * flux
    counts = photon_flux * exposure
    return counts

### Simulating the Data

In [None]:
def simulate_bpl_dataset_lhc(n_samples, energy,
                         gamma1_range=(0.0, 1.5),
                         gamma2_range=(1.5, 5.0),
                         eb_range=(0.5, 5.0),
                         exposure=1.0,
                         noisy=False):

    """
    Simulate a dataset of (parameters -> spectrum) for the 
    broken power law, using Latin Hypercube Sampling.

    Parameters
    ----------
    n_samples : int
        The total number of parameter sets (and hence spectra) 
        to draw

    energy : array
        The energy grid to simulate the data over

    gamma1_range : (float, float)
        The lower and upper boundary between which to sample 
        the photon index below the break

    gamma2_range : (float, float)
        The lower and upper boundary between which to sample 
        the photon index above the break

    eb_range : (float, float)
        The lower and upper boundary (in keV) between which 
        to sample the break energy

    exposure : float
        The exposure of the observation

    noisy : bool, default False
        If True, add Poisson noise to the spectrum

    Returns
    -------
    X : array of shape (n_samples, 3)
        The array containing all the sets of parameters, 
        in the order (gamma1, gamma2, eb)

    Y : array of shape (n_samples, len(energy))
        The array with all simulated spectra 
    """
    # set up Latin Hypercube Sampling
    sampler = qmc.LatinHypercube(d=3)

    # sample randomly using LHS
    sample = sampler.random(n=n_samples)

    # set array of lower and upper bounds
    l_bounds = [gamma1_range[0], gamma2_range[0], eb_range[0]]
    u_bounds = [gamma1_range[1], gamma2_range[1], eb_range[1]]

    # scale sample to bounds
    X = qmc.scale(sample, l_bounds, u_bounds)

    # simulate one spectrum per parameter set, using the requested exposure
    Y = np.stack([broken_powerlaw_counts(energy, gamma1_i, gamma2_i, eb_i, amp=1.0, exposure=exposure) for gamma1_i, gamma2_i, eb_i in X]).astype(np.float32)

    if noisy:
        # Poisson noise around the expected counts
        Y = rng.poisson(Y).astype(np.float32)

    return X, Y
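
What makes Latin Hypercube Sampling attractive for building training sets is its stratification: if you draw $n$ points and divide each parameter axis into $n$ equal-width bins, every bin contains exactly one point. The short sketch below checks this property for a small sample; the sample size and the bounds are only illustrative.

```python
import numpy as np
from scipy.stats import qmc

n, d = 10, 3

# draw n points in the d-dimensional unit cube
sample = qmc.LatinHypercube(d=d).random(n=n)

# for each dimension: which of the n equal-width bins does each point fall into?
bins = np.floor(sample * n).astype(int)

# stratification: every bin is hit exactly once in every dimension
one_per_bin = all(sorted(bins[:, j]) == list(range(n)) for j in range(d))
print("exactly one point per bin in each dimension:", one_per_bin)

# rescale from the unit cube to physical parameter ranges
l_bounds = [0.0, 1.5, 0.5]
u_bounds = [1.5, 5.0, 5.0]
scaled = qmc.scale(sample, l_bounds, u_bounds)
print(scaled.min(axis=0), scaled.max(axis=0))
```

Compared to purely random sampling, this avoids large empty regions along any single parameter axis, which matters when every training example is expensive to simulate.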

In [None]:
# make some simulated data 
X_lhs, Y_lhs = simulate_bpl_dataset_lhc(5000, energy)
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))

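# the first two columns of X_lhs are the two photon indices
ax1.scatter(X_lhs[:,0], X_lhs[:,1], s=2)
ax1.set_xlabel(r"$\Gamma_1$")
ax1.set_ylabel(r"$\Gamma_2$")

for i in range(3):
    plot_spectrum(energy, Y_lhs[i], ax=ax2,
                  title=f"Example spectrum {i}  (Γ1={X_lhs[i,0]:.2f}, Γ2={X_lhs[i,1]:.2f}, Eb={X_lhs[i,2]:.2f} keV)")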
plt.show()

### Pre-processing

Okay, we'll have to do similar preprocessing as earlier:

In [None]:
# transform targets to log flux
Y = transform_targets(Y_lhs).astype(np.float32)


Next, the train-validation-test split:

In [None]:
# Train/val/test split: 70/15/15
X_train, X_tmp, Y_train, Y_tmp = train_test_split(X_lhs, Y, test_size=0.30, random_state=123)
X_val, X_test, Y_val, Y_test = train_test_split(X_tmp, Y_tmp, test_size=0.50, random_state=123)

print(X_train.shape, X_val.shape, X_test.shape)

Now we're applying our standard scaler again:

In [None]:
# scaler for the inputs
x_scaler = StandardScaler().fit(X_train)

# scaler for the outputs
y_scaler = StandardScaler().fit(Y_train)

# transform train/val/test features
Xtr = x_scaler.transform(X_train).astype(np.float32)
Xva = x_scaler.transform(X_val).astype(np.float32)
Xte = x_scaler.transform(X_test).astype(np.float32)

# transform train/val/test targets
Ytr = y_scaler.transform(Y_train).astype(np.float32)
Yva = y_scaler.transform(Y_val).astype(np.float32)
Yte = y_scaler.transform(Y_test).astype(np.float32)

nE = Ytr.shape[1]
print("nE:", nE)

We have to build our `DataLoader` objects again:

In [None]:
train_loader = make_dataloader(Xtr, Ytr, batch_size=256, shuffle=True)
# no need to shuffle the validation data: we only evaluate on it
val_loader   = make_dataloader(Xva, Yva, batch_size=256, shuffle=False)

### Training the One-Layer MLP

Let's train our simple one-layer MLP:

In [None]:
# our MLP with one hidden layer, now with three input parameters
mlp = MLP(in_dim=3, out_dim=nE, hidden=128)

# set the learning rate for the optimizer
lr = 1e-3

# same optimizer as previously
opt = torch.optim.Adam(mlp.parameters(), lr=lr)

# same loss function as previously
loss_fn = nn.MSELoss()

#let's train!
hist_mlp = train_epochs(mlp, train_loader, val_loader, epochs=3000, patience=40)

plt.figure(figsize=(6,4))
plt.plot(hist_mlp["train"], label="train")
plt.plot(hist_mlp["val"], label="val")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("MSE (scaled log-count space)")
plt.title("Learning curves: one-hidden-layer MLP (broken power law)")
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
mlp.eval()

with torch.no_grad():
    yhat_train = mlp(torch.from_numpy(Xtr)).numpy()
    yhat_val  = mlp(torch.from_numpy(Xva)).numpy()


ntrain = Xtr.shape[0]
nval = Xva.shape[0]

idx_train = np.random.randint(0, ntrain)
idx_val = np.random.randint(0, nval)

# invert scaling to counts space
yhat_train_counts = y_scaler.inverse_transform(yhat_train)[idx_train]
ytrain_counts     = np.exp(Y_train[idx_train])

yhat_val_counts = y_scaler.inverse_transform(yhat_val)[idx_val]
yval_counts     = np.exp(Y_val[idx_val])

plot_example(energy, ytrain_counts, np.exp(yhat_train_counts), title="Training example: target vs prediction")
plot_example(energy, yval_counts,   np.exp(yhat_val_counts),   title="Validation example: target vs prediction")


### Adding More Layers

So, that's a bit harder for the model to learn, because of the sharp break in the spectrum (the spectrum itself is continuous, but its slope changes abruptly at the break energy). Let's try to make the model more complex. We're going to simply add more hidden layers, and we'll implement it in a way that lets you decide afterwards how many layers you want, and how many nodes each of them should have.

In [None]:

class MLP(nn.Module):
    """
    Dynamically constructed multi-layer perceptron.

    Architecture:
        Linear -> ReLU -> Linear -> ReLU -> ... -> Linear (output)
    i.e. ReLU after every hidden layer, no activation on the output layer.

    Parameters
    ----------
    in_dim : int
        Input feature dimension.
    out_dim : int
        Output dimension.
    hidden_dims : Sequence[int]
        Hidden layer widths, e.g. (64, 64, 32).
    """
    def __init__(self, in_dim, out_dim, hidden_dims):
        super().__init__()

        # total dimensions includes input and output dimensions
        dims = [in_dim, *hidden_dims, out_dim]

        # generate list of layers
        layers = []
        for i in range(len(dims) - 1):
            layers.append(nn.Linear(dims[i], dims[i + 1]))
            # ReLU after all layers except the last (output) layer
            if i < len(dims) - 2:
                layers.append(nn.ReLU())

        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)

In [None]:
in_dim = 3
out_dim = Y_train.shape[1]

# PICK A NUMBER OF HIDDEN LAYERS, AND A NUMBER OF 
# NODES PER LAYER
# EXAMPLE:
# hidden_layers = [64, 128, 256, 64]
# produces 4 hidden layers, the first with 64 nodes,
# the second with 128 nodes, and so on
hidden_layers = [] # ADD NUMBER OF LAYERS AND NODES IN THIS LIST

mlp = MLP(in_dim, out_dim, hidden_layers)

# set the learning rate for the optimizer
lr = 1e-4

# same optimizer as previously
opt = torch.optim.Adam(mlp.parameters(), lr=lr)

# same loss function as previously
loss_fn = nn.MSELoss()

#let's train!
hist_mlp = train_epochs(mlp, train_loader, val_loader, epochs=3000, patience=40)

plt.figure(figsize=(6,4))
plt.plot(hist_mlp["train"], label="train")
plt.plot(hist_mlp["val"], label="val")
plt.yscale("log")
plt.xlabel("epoch")
plt.ylabel("MSE (scaled log-count space)")
plt.title("Learning curves: multi-layer MLP surrogate")
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
mlp.eval()

with torch.no_grad():
    yhat_train = mlp(torch.from_numpy(Xtr)).numpy()
    yhat_val  = mlp(torch.from_numpy(Xva)).numpy()


ntrain = Xtr.shape[0]
nval = Xva.shape[0]

idx_train = np.random.randint(0, ntrain)
idx_val = np.random.randint(0, nval)

# invert scaling to counts space
yhat_train_counts = y_scaler.inverse_transform(yhat_train)[idx_train]
ytrain_counts     = np.exp(Y_train[idx_train])

yhat_val_counts = y_scaler.inverse_transform(yhat_val)[idx_val]
yval_counts     = np.exp(Y_val[idx_val])

plot_example(energy, ytrain_counts, np.exp(yhat_train_counts), title="Training example: target vs prediction")
plot_example(energy, yval_counts,   np.exp(yhat_val_counts),   title="Validation example: target vs prediction")


At this point, you can experiment with
* the number of hidden layers
* the number of nodes per layer
* the activation function (see [here](https://docs.pytorch.org/docs/stable/nn.html))
* the learning rate and learning rate schedulers
* the batch size
* regularization methods like dropout
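
A few of the items above can be sketched in one place. The block below is only an illustration under stated assumptions, not the notebook's own implementation: `FlexibleMLP` is a hypothetical variant of the `MLP` class above that takes the activation class and a dropout probability as arguments, the output dimension of 50 is made up, and the learning-rate scheduler (PyTorch's `ReduceLROnPlateau`) is fed invented validation losses so that the example runs on its own.

```python
import torch
from torch import nn


class FlexibleMLP(nn.Module):
    """Hypothetical MLP variant with a configurable activation and optional dropout."""

    def __init__(self, in_dim, out_dim, hidden_dims, activation=nn.ReLU, p_drop=0.0):
        super().__init__()
        dims = [in_dim, *hidden_dims]
        layers = []
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(d_in, d_out))
            layers.append(activation())           # e.g. nn.ReLU, nn.GELU, nn.Tanh
            if p_drop > 0:
                layers.append(nn.Dropout(p_drop))  # only active in training mode
        # output layer: no activation, no dropout
        layers.append(nn.Linear(dims[-1], out_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        return self.net(x)


torch.manual_seed(0)

# made-up sizes: 3 input parameters, 50 energy bins
demo_model = FlexibleMLP(3, 50, [64, 64], activation=nn.GELU, p_drop=0.1)
demo_opt = torch.optim.Adam(demo_model.parameters(), lr=1e-3)

# halve the learning rate once the validation loss has stalled for more than 2 epochs
demo_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(
    demo_opt, mode="min", factor=0.5, patience=2
)

# invented validation losses that never improve, to trigger the scheduler
for fake_val_loss in [1.0, 1.0, 1.0, 1.0]:
    demo_sched.step(fake_val_loss)

demo_lr = demo_opt.param_groups[0]["lr"]
print("learning rate after plateau:", demo_lr)
```

In a real training run, `demo_sched.step(...)` would be called once per epoch with the actual validation loss, which means the training loop itself has to know about the scheduler. Remember to call `.eval()` before making predictions, so that dropout is switched off.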


### How complexity and data size interact

Two knobs control generalization:
- **data size** (more examples reduce variance)
- **model capacity** (more expressive models reduce bias but can increase variance)
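
To get a rough feel for the "model capacity" knob, it can help to count trainable parameters. The helper below is a small sketch (the name `count_mlp_parameters` and the output dimension of 100 are assumptions for illustration) that counts weights and biases for a fully connected network of the kind built by the `MLP` class above. Parameter count is only a crude proxy for capacity, but it makes comparisons between architectures concrete.

```python
def count_mlp_parameters(in_dim, hidden_dims, out_dim):
    """Number of weights and biases in a fully connected network."""
    dims = [in_dim, *hidden_dims, out_dim]
    total = 0
    for d_in, d_out in zip(dims[:-1], dims[1:]):
        total += d_in * d_out  # weight matrix
        total += d_out         # bias vector
    return total


# made-up output dimension of 100 energy bins
n_linear = count_mlp_parameters(3, [], 100)
n_deep = count_mlp_parameters(3, [64, 64, 64, 64], 100)

print("no hidden layers:    ", n_linear)
print("four layers, 64 wide:", n_deep)
```

For a PyTorch model, the same number can be obtained with `sum(p.numel() for p in model.parameters())`.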

We'll run a small experiment:
- train the *same* MLP architecture on increasing numbers of training examples
- record validation error

This is a **learning curve**: it tells you whether more data is likely to help.

Let's simulate a lot more data:

In [None]:
# make some simulated data 
X_lhs, Y_lhs = simulate_bpl_dataset_lhc(100000, energy)

# transform targets to log flux
Y = transform_targets(Y_lhs).astype(np.float32)

# Train/val/test split: 90/5/5
X_train, X_tmp, Y_train, Y_tmp = train_test_split(X_lhs, Y, test_size=0.1, random_state=123)
X_val, X_test, Y_val, Y_test = train_test_split(X_tmp, Y_tmp, test_size=0.50, random_state=123)

print(X_train.shape, X_val.shape, X_test.shape)
# scaler for the inputs
x_scaler = StandardScaler().fit(X_train)

# scaler for the outputs
y_scaler = StandardScaler().fit(Y_train)

# transform train/val/test features
Xtr = x_scaler.transform(X_train).astype(np.float32)
Xva = x_scaler.transform(X_val).astype(np.float32)
Xte = x_scaler.transform(X_test).astype(np.float32)

# transform train/val/test targets
Ytr = y_scaler.transform(Y_train).astype(np.float32)
Yva = y_scaler.transform(Y_val).astype(np.float32)
Yte = y_scaler.transform(Y_test).astype(np.float32)

nE = Ytr.shape[1]
print("nE:", nE)

train_loader = make_dataloader(Xtr, Ytr, batch_size=256, shuffle=True)
# the validation data does not need to be shuffled
val_loader   = make_dataloader(Xva, Yva, batch_size=256, shuffle=False)

In [None]:
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10,5))

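# parameter order assumed to be (Γ1, Γ2, E_b), as in the residual plots further down
ax1.scatter(X_lhs[:,0], X_lhs[:,1], s=2, alpha=0.1)
ax1.set_xlabel(r"$\Gamma_1$")
ax1.set_ylabel(r"$\Gamma_2$")

for i in range(3):
    plot_spectrum(energy, Y_lhs[i], ax=ax2,
                  title=f"Example spectrum {i}  (Γ1={X_lhs[i,0]:.2f}, Γ2={X_lhs[i,1]:.2f}, Eb={X_lhs[i,2]:.2f})")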
plt.show()

Let's loop over different numbers of training examples and see how the performance changes:

In [None]:
n_train_all = [10, 100, 1000, 10000, 90000]

hist_train_all, hist_val_all = [], []
for n_train in n_train_all:
    print("="*30 + "\n")
    print(f"Training with {n_train} training samples")

    # sample subset from the existing training set
    idx = rng.choice(len(Xtr), size=n_train, replace=False)
    Xn = Xtr[idx]; Yn = Ytr[idx]
    
    small_train_loader = make_dataloader(Xn, Yn, batch_size=256, shuffle=True)
    
    in_dim = 3
    out_dim = Y_train.shape[1]
    
    # four hidden layers!!!
    hidden_layers = [64, 64, 64, 64]
    
    mlp = MLP(in_dim, out_dim, hidden_layers)
    
    # set the learning rate for the optimizer
    lr = 1e-4
    
    # same optimizer as previously
    opt = torch.optim.Adam(mlp.parameters(), lr=lr)
    
    # same loss function as previously
    loss_fn = nn.MSELoss()
    
    #let's train!
    hist_mlp = train_epochs(mlp, small_train_loader, val_loader, epochs=3000, patience=40)

    hist_val_all.append(np.min(hist_mlp["val"]))
    hist_train_all.append(np.min(hist_mlp["train"]))
    
    plt.figure(figsize=(6,4))
    plt.plot(hist_mlp["train"], label="train")
    plt.plot(hist_mlp["val"], label="val")
    plt.yscale("log")
    plt.xlabel("epoch")
    plt.ylabel("MSE (scaled log-count space)")
    plt.title(f"Learning curves: MLP surrogate, n={n_train} training examples")
    plt.legend()
    plt.tight_layout()
    plt.show()

One thing you might notice is that for few training examples, the training and validation scores scissor apart very quickly, and while the performance on the training data gets better and better, this is not true for the validation data. This is typical of **overfitting**: we have a very flexible model, and few training examples, so the model massively overfits on the few training examples it has, which means it generalizes really badly to new, unseen examples. 

As we add more training examples, training and validation performance become more and more similar. This could be a sign that adding more complexity to the model may be helpful in allowing the model to learn more intricate features.
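
One way to put a number on how far the two curves have scissored apart is to compare validation and training loss at the best validation epoch and at the end of training. The sketch below assumes only that the history is a dictionary with `"train"` and `"val"` lists, as returned by `train_epochs` above; the helper name `summarize_history` and the toy numbers are made up for illustration.

```python
def summarize_history(hist):
    """Summarize a training history dict with 'train' and 'val' loss lists."""
    train = list(hist["train"])
    val = list(hist["val"])

    # epoch with the lowest validation loss
    best_epoch = min(range(len(val)), key=val.__getitem__)

    return {
        "best_epoch": best_epoch,
        "best_val": val[best_epoch],
        # ratio of validation to training loss: values much larger than 1
        # suggest that the model fits the training set far better than new data
        "gap_at_best": val[best_epoch] / train[best_epoch],
        "gap_at_end": val[-1] / train[-1],
    }


# toy history: training loss keeps falling, validation loss turns around
toy_hist = {
    "train": [1.0, 0.5, 0.2, 0.1, 0.05],
    "val":   [1.1, 0.6, 0.4, 0.45, 0.5],
}

summary = summarize_history(toy_hist)
for key, value in summary.items():
    print(f"{key}: {value}")
```

A gap ratio that keeps growing after the best epoch is the numerical counterpart of the overfitting pattern in the plots above.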

Let's plot the best training and validation losses as a function of the number of training examples:

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(8,4))

ax.plot(n_train_all, hist_train_all, marker="o", lw=1, label="best training loss")
ax.plot(n_train_all, hist_val_all, marker="o", lw=1, label="best validation loss")

ax.set_xlabel("Number of training examples")
ax.set_ylabel("Best MSE loss")
# the training-set sizes span several orders of magnitude, so use log axes
ax.set_xscale("log")
ax.set_yscale("log")
ax.legend()


Let's look at some examples:

In [None]:
mlp.eval()

with torch.no_grad():
    yhat_train = mlp(torch.from_numpy(Xtr)).numpy()
    yhat_val  = mlp(torch.from_numpy(Xva)).numpy()


ntrain = Xtr.shape[0]
nval = Xva.shape[0]

idx_train = np.random.randint(0, ntrain)
idx_val = np.random.randint(0, nval)

# invert scaling to counts space
yhat_train_counts = y_scaler.inverse_transform(yhat_train)[idx_train]
ytrain_counts     = np.exp(Y_train[idx_train])

yhat_val_counts = y_scaler.inverse_transform(yhat_val)[idx_val]
yval_counts     = np.exp(Y_val[idx_val])

plot_example(energy, ytrain_counts, np.exp(yhat_train_counts), title="Training example: target vs prediction")
plot_example(energy, yval_counts,   np.exp(yhat_val_counts),   title="Validation example: target vs prediction")


A useful way to visualize results is to look at the residuals as a function of input parameter. This allows you to see whether there are regions in energy where the model is doing worse, and also whether there are parts of parameter space where your model may be biased.

In [None]:
mlp.eval()

# use the full training set, so that the rows line up with res_train below
pars_train = x_scaler.inverse_transform(Xtr)
pars_val = x_scaler.inverse_transform(Xva)

with torch.no_grad():
    yhat_train = mlp(torch.from_numpy(Xtr)).numpy()
    yhat_val  = mlp(torch.from_numpy(Xva)).numpy()


ntrain = Xtr.shape[0]
nval = Xva.shape[0]

# invert scaling to counts space
yhat_train_counts = np.exp(y_scaler.inverse_transform(yhat_train))
ytrain_counts     = np.exp(Y_train)

yhat_val_counts = np.exp(y_scaler.inverse_transform(yhat_val))
yval_counts     = np.exp(Y_val)

res_train = np.abs(ytrain_counts - yhat_train_counts) / ytrain_counts
res_val = np.abs(yval_counts - yhat_val_counts) / yval_counts


In [None]:
idx_train_gamma1 = np.argsort(pars_train[:,0])
idx_train_gamma2 = np.argsort(pars_train[:,1])
idx_train_eb = np.argsort(pars_train[:,2])

idx_val_gamma1 = np.argsort(pars_val[:,0])
idx_val_gamma2 = np.argsort(pars_val[:,1])
idx_val_eb = np.argsort(pars_val[:,2])

In [None]:
fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(10,7), sharey=True, sharex=True)
pcm1 = ax1.pcolormesh(res_train[idx_train_gamma1], vmax=0.02)
#fig.colorbar(pcm1, ax=ax1)
ax1.set_xlabel("Energy [keV]")
ax1.set_ylabel(r"$\Gamma_1$")
pcm2 = ax2.pcolormesh(res_train[idx_train_gamma2], vmax=0.02)
#fig.colorbar(pcm2, ax=ax2)
ax2.set_xlabel("Energy [keV]")
ax2.set_ylabel(r"$\Gamma_2$")

pcm3 = ax3.pcolormesh(res_train[idx_train_eb], vmax=0.02)
fig.colorbar(pcm3, ax=ax3)
ax3.set_xlabel("Energy [keV]")
ax3.set_ylabel(r"$E_b$")

fig.tight_layout()

I haven't had a chance to fix the tick labels for the axes, so they're not correct, but hopefully you get the idea! Ben has code to do this in a much nicer way! :) 
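
One possible way to get physically meaningful axes is to average the residuals in bins of a parameter instead of sorting the rows. The sketch below is only one option under stated assumptions: the helper `binned_residual_map` is hypothetical, and the demo data are synthetic (residuals that grow with the parameter), so the block runs without the notebook's variables.

```python
import numpy as np


def binned_residual_map(param, residuals, n_bins=20):
    """
    Average residual spectra in equal-width bins of one parameter.

    param     : array of shape (n_examples,)
    residuals : array of shape (n_examples, n_energies)
    Returns the bin edges (n_bins + 1,) and the mean residuals (n_bins, n_energies).
    Bins without any examples are filled with NaN.
    """
    param = np.asarray(param)
    residuals = np.asarray(residuals)

    edges = np.linspace(param.min(), param.max(), n_bins + 1)
    # bin index of each example; clip so the maximum value lands in the last bin
    which = np.clip(np.digitize(param, edges) - 1, 0, n_bins - 1)

    binned = np.full((n_bins, residuals.shape[1]), np.nan)
    for b in range(n_bins):
        in_bin = which == b
        if in_bin.any():
            binned[b] = residuals[in_bin].mean(axis=0)
    return edges, binned


# synthetic demo: 500 examples, 4 energy bins, residuals proportional to the parameter
rng_demo = np.random.default_rng(0)
demo_param = rng_demo.uniform(1.0, 3.0, size=500)
demo_res = 0.01 * np.tile(demo_param[:, None], (1, 4))

demo_edges, demo_binned = binned_residual_map(demo_param, demo_res, n_bins=5)
print(demo_edges)
print(demo_binned[:, 0])
```

With the notebook's arrays this might look like `edges, binned = binned_residual_map(pars_train[:, 0], res_train)`, followed by `ax.pcolormesh(energy, 0.5 * (edges[:-1] + edges[1:]), binned, shading="nearest", vmax=0.02)`. That call assumes `energy` holds one value per energy bin and that `pars_train` and `res_train` have their rows in the same order.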

## Hyperparameter optimization

Run a grid of models with different numbers of layers and nodes per layer. Figure out what combination of nodes and layers does best.
Use ~10000 training examples, otherwise it'll never finish. :) 
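
A grid search like this can be organized as a loop over all combinations of depth and width. The sketch below shows only the bookkeeping: `grid_search` is a hypothetical helper, and `toy_fit_and_score` is a stand-in with no physical meaning that returns a made-up "validation loss" so the block runs instantly.

```python
import itertools


def grid_search(depths, widths, fit_and_score):
    """
    Evaluate every (depth, width) combination.

    fit_and_score : callable that takes a list of hidden-layer widths,
                    trains a model, and returns its best validation loss.
    Returns a list of result dicts sorted from best to worst.
    """
    results = []
    for depth, width in itertools.product(depths, widths):
        hidden_layers = [width] * depth
        val_loss = fit_and_score(hidden_layers)
        results.append({
            "depth": depth,
            "width": width,
            "hidden_layers": hidden_layers,
            "val_loss": val_loss,
        })
    results.sort(key=lambda r: r["val_loss"])
    return results


def toy_fit_and_score(hidden_layers):
    """Stand-in for real training: an invented loss with a minimum at 2 layers of 64 nodes."""
    depth = len(hidden_layers)
    width = hidden_layers[0]
    return abs(depth - 2) * 0.1 + abs(width - 64) / 640 + 0.01


grid_results = grid_search([1, 2, 4], [32, 64, 128], toy_fit_and_score)
best = grid_results[0]

for r in grid_results:
    print(r["hidden_layers"], f"{r['val_loss']:.3f}")
```

In the notebook, `fit_and_score` would instead build `MLP(in_dim, out_dim, hidden_layers)`, create a fresh optimizer, call `train_epochs` on the ~10000-example subset, and return `np.min(hist["val"])`. Since every model starts from a different random initialization, small differences between neighbouring grid points should not be over-interpreted.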


## Final model and test evaluation (once)

Pick the architecture (number of layers and nodes per layer) with the best validation MSE, retrain it on **train+val**, and evaluate it once on the held-out **test** set.
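
The data handling for this step can be sketched as follows. This is a NumPy-only illustration with synthetic arrays and hypothetical helper names (`refit_standardization`, `mse`): the standardization is re-derived from train+val together, the test set is only transformed with those statistics, and the test error is computed a single time at the very end. The mean and standard deviation computed here correspond to what `StandardScaler` would compute by default.

```python
import numpy as np


def refit_standardization(X_train, X_val):
    """Mean and standard deviation computed from train+val together (never from test)."""
    X_fit = np.concatenate([X_train, X_val], axis=0)
    mean = X_fit.mean(axis=0)
    std = X_fit.std(axis=0)
    std = np.where(std == 0, 1.0, std)  # guard against constant columns
    return X_fit, mean, std


def mse(y_true, y_pred):
    """Mean squared error over all examples and energy bins."""
    return float(np.mean((np.asarray(y_true) - np.asarray(y_pred)) ** 2))


# synthetic stand-ins for the notebook's parameter arrays
rng_demo = np.random.default_rng(42)
X_train_demo = rng_demo.normal(2.0, 0.5, size=(900, 3))
X_val_demo = rng_demo.normal(2.0, 0.5, size=(50, 3))
X_test_demo = rng_demo.normal(2.0, 0.5, size=(50, 3))

X_fit_demo, mean_demo, std_demo = refit_standardization(X_train_demo, X_val_demo)

# the same statistics are applied to the training and the test data
X_fit_scaled = (X_fit_demo - mean_demo) / std_demo
X_test_scaled = (X_test_demo - mean_demo) / std_demo

print(X_fit_scaled.shape, X_test_scaled.shape)
print("example MSE:", mse([1.0, 2.0], [1.0, 4.0]))
```

One thing to decide when retraining on train+val is when to stop, since there is no validation set left for early stopping. A common choice is to train for roughly the number of epochs at which the selected model reached its best validation loss.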

Further exploring:
* For the network above, the energy grid is hardcoded into the surrogate model. This is not ideal if you want to use the surrogate model for more than one instrument. There are newer neural network architectures like [Neural Operators](https://pytorch.org/blog/neuraloperatorjoins-the-pytorch-ecosystem/), [DeepONets](https://arxiv.org/abs/1910.03193) and [Neural Implicit Representations](https://medium.com/@nathaliemariehager/an-introduction-to-neural-implicit-representations-with-use-cases-ad331ca12907) that can abstract away from the grid and learn a model that predicts the output for any input (within the range covered by the training data).
* The [Universal Approximation Theorem](https://en.wikipedia.org/wiki/Universal_approximation_theorem) guarantees that under certain conditions, a neural network can in principle approximate any continuous function to any desired degree of accuracy. This is great! The major problem: "in principle" might mean, in practice, more training data than you're willing (or able) to simulate. This is where you have to be clever, and impose structure on your neural network to help it learn. What that looks like depends on the specific problem in question, but [convolutional neural networks](https://en.wikipedia.org/wiki/Convolutional_neural_network) may be a starting point, as could be the now-ubiquitous [Transformers](https://en.wikipedia.org/wiki/Transformer_(deep_learning)), or one could explore [Physics-Informed Machine Learning](https://www.nature.com/articles/s42254-021-00314-5) to directly impose physics-based constraints onto the model.
* There are a lot of other things we haven't talked about that are now standard as part of machine learning development, such as [Dropout](https://towardsdatascience.com/dropout-in-neural-networks-47a162d621d9/) and [Batch Normalization](https://en.wikipedia.org/wiki/Batch_normalization).
* We wrote [a paper](https://ui.adsabs.harvard.edu/abs/2023arXiv231012528H/abstract) with best practices for machine learning in astronomy projects! Go and read it! 


And finally, experiment, experiment, experiment! For many machine learning problems in astronomy, there isn't a good default answer!