# 8. Coloring - DRAFT -
This notebook demonstrates coloring: assigning stellar population properties, here age and metallicity, to the orbits of a DYNAMITE model. When run for the first time, it creates a model for the example galaxy FCC167. Based on that model and observational data, the notebook then fits the age and metallicity of the populations.

## 8.1. Prerequisites

Import the required modules.

In [None]:
import importlib
import numpy as np
from scipy.interpolate import griddata
import matplotlib.pyplot as plt
import cmasher

import pymc as pm  # only necessary for posterior plots

from vorbin.voronoi_2d_binning import voronoi_2d_binning

import dynamite as dyn

print('DYNAMITE version', dyn.__version__)
# print('    installed at ', dyn.__path__)  # Uncomment to print the complete DYNAMITE installation path

Run the DYNAMITE model. The configuration file `FCC167_config.yaml` fixes all parameters to specific values, resulting in a single model. For performance, the orbit library is rather small, so the calculation should only take a few minutes. Note that the model is not recomputed if it already exists on disk (e.g., when the notebook is re-run).

In [None]:
fname = 'FCC167_config.yaml'
c = dyn.config_reader.Configuration(fname, reset_logging=True)
# c = dyn.config_reader.Configuration(fname, reset_logging=True, reset_existing_output=True)
_ = dyn.model_iterator.ModelIterator(c)

## 8.2. Voronoi phase space binning
The model's orbits are binned in the phase space of circularity $\lambda_z$ versus time-averaged orbital radius $r$ into $N_\mathrm{bundle}$ Voronoi orbit bundles, each of which contains at least a prescribed minimum total orbital weight.

First, we need to choose the underlying resolution in $r$ and $\lambda_z$:

In [None]:
# Number of desired r and lambda_z bins
nr = 20  # 6
nl = 41  # 7
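The resolution `nr` × `nl` defines the regular grid on which orbital weights are collected before the Voronoi binning. A minimal sketch of such a grid, assuming per-orbit arrays `r_orb`, `lz_orb` and `w_orb` (hypothetical names, filled with random toy values) and linear bin edges in both $r$ and $\lambda_z$; DYNAMITE's own grid edges and radial scaling may differ. The output layout mimics the `'in'` array of `phase_space_binning` described below (bin $r$, bin $\lambda_z$, bin total weight):

```python
import numpy as np

def phase_space_grid(r_orb, lz_orb, w_orb, nr, nl, r_max=None):
    """Accumulate orbit weights on an nr x nl grid in (r, lambda_z).

    Returns an array of shape (3, nr*nl): bin-centre r, bin-centre lambda_z
    and total weight per bin. Linear edges are an assumption of this sketch.
    """
    r_max = r_orb.max() if r_max is None else r_max
    r_edges = np.linspace(0.0, r_max, nr + 1)   # linear radial bins (assumption)
    l_edges = np.linspace(-1.0, 1.0, nl + 1)    # circularity lies in [-1, 1]
    weight, _, _ = np.histogram2d(r_orb, lz_orb, bins=[r_edges, l_edges], weights=w_orb)
    r_cen = 0.5 * (r_edges[:-1] + r_edges[1:])
    l_cen = 0.5 * (l_edges[:-1] + l_edges[1:])
    rr, ll = np.meshgrid(r_cen, l_cen, indexing='ij')  # both of shape (nr, nl)
    return np.vstack([rr.ravel(), ll.ravel(), weight.ravel()])

# toy orbits with random radii, circularities and weights
rng = np.random.default_rng(42)
n_orbits = 1000
r_orb = rng.uniform(0.0, 10.0, n_orbits)
lz_orb = rng.uniform(-1.0, 1.0, n_orbits)
w_orb = rng.random(n_orbits)
w_orb /= w_orb.sum()  # normalize the toy weights to a total of 1

grid = phase_space_grid(r_orb, lz_orb, w_orb, nr=20, nl=41)
print(grid.shape)  # (3, 820)
```

With linear edges over $[-1, 1]$ and an odd `nl`, the $\lambda_z$ bins are symmetric about $\lambda_z=0$ and one bin is centred on it.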

Optional step: before binning the orbits, let's have a look at the $r$-$\lambda_z$ phase space. The standard DYNAMITE orbit distribution plot with the parameter `force_lambda_z=True` plots the distribution of all orbits. Note that the plot title is then incorrect: all orbits are included in the plot, not only the short-axis tubes.

In [None]:
plotter = dyn.plotter.Plotter(c)

In [None]:
# NOTE: when using force_lambda_z=True, then the title of the orbit-distribution plot is incorrect: all orbits are shown in this distribution - not only short-axis tubes!
fig2 = plotter.orbit_distribution(model=None, minr=None, maxr=None, r_scale='linear', nr=nr, nl=nl,
                                  orientation='vertical', subset='short', force_lambda_z=True)

The Voronoi binning of orbits in the radius-circularity phase space is done by a method in the `dyn.coloring.Coloring` class:

In [None]:
coloring = dyn.coloring.Coloring(c, nr=nr, nl=nl)

The original `n_orbits` orbit bundles are binned into fewer $N_\mathrm{bundle}=$`n_bundle` Voronoi orbit bundles with each of these Voronoi bundles accounting for a weight of at least `vor_weight`.

In [None]:
vor_weight = 0.01  # 0.05  # define the desired (minimum) total orbital weight in each Voronoi bin

The method `coloring.bin_phase_space()` performs the binning of orbits, based on the best-fitting model so far (in our case this is the only model). The result of the binning is a tuple `(vor_bundle_mapping, phase_space_binning)`:
```
vor_bundle_mapping :    np.array of shape (n_bundle, n_orbits)
                        Mapping between the "original" orbit bundles and the Voronoi
                        orbit bundles: vor_bundle_mapping[i_bundle, i_orbit] is the
                        fraction of i_orbit assigned to i_bundle, multiplied by i_orbit's weight.
phase_space_binning :   dict
                        'in':   np.array of shape (3, nr*nl), the binning input:
                                bin r, bin lambda_z, bin total weight
                        'out':  np.array of shape (3, n_bundle), the Voronoi binning output:
                                weighted Voronoi bin centroid coordinates r_bar, lambda_bar
                                and Voronoi bin total weights
                        'map':  np.array of shape (nr*nl,) the phase space mapping:
                                Voronoi bin numbers for each input bin
```
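For the undithered case (one orbit per original bundle, each falling into exactly one phase-space bin), an array with the `vor_bundle_mapping` layout can be built from each orbit's phase-space bin and the `'map'` array. A toy sketch, where `orbit_bin` (the flattened $r$-$\lambda_z$ bin of each orbit), `w_orb` and the helper `bundle_mapping` are hypothetical and not part of DYNAMITE:

```python
import numpy as np

def bundle_mapping(orbit_bin, w_orb, phase_space_map, n_bundle):
    """Toy vor_bundle_mapping for undithered orbits (one phase-space bin per orbit).

    orbit_bin       : (n_orbits,) flattened r-lambda_z bin index of each orbit
    w_orb           : (n_orbits,) orbital weights
    phase_space_map : (nr*nl,) Voronoi bin number of each phase-space bin
    """
    n_orbits = len(w_orb)
    mapping = np.zeros((n_bundle, n_orbits))
    i_bundle = phase_space_map[orbit_bin]           # Voronoi bundle of each orbit
    mapping[i_bundle, np.arange(n_orbits)] = w_orb  # fraction 1 times the orbit's weight
    return mapping

# 6 phase-space bins grouped into 2 Voronoi bundles, 4 orbits
phase_space_map = np.array([0, 0, 1, 1, 1, 0])
orbit_bin = np.array([0, 2, 3, 5])
w_orb = np.array([0.1, 0.4, 0.3, 0.2])
m = bundle_mapping(orbit_bin, w_orb, phase_space_map, n_bundle=2)
print(m.sum(axis=1))  # total weight per Voronoi bundle: [0.3 0.7]
```

Summing over orbits (`axis=1`) gives each Voronoi bundle's total weight, the quantity listed in the third row of `'out'`; summing over bundles (`axis=0`) recovers each original orbit's weight.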

As the binning can be time-consuming, `coloring.bin_phase_space()` will write the binning result to the model directory so that subsequent calls with the same parameters for the same model will read the existing binning from disk if `use_cache=True`.

In [None]:
vor_bundle_mapping, phase_space_binning = coloring.bin_phase_space(model=None,
                                                                   vor_weight=vor_weight,
                                                                   vor_ignore_zeros=False,
                                                                   make_diagnostic_plots=False,
                                                                   extra_diagnostic_output=True,
                                                                   cvt=False,
                                                                   wvt=False,
                                                                   use_cache=True)
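The caching idea behind `use_cache=True` can be illustrated generically: derive a file name from the parameters, load the file if it exists, otherwise compute and save. This is a sketch of the pattern only, not DYNAMITE's implementation or file layout; `cached_compute` and `slow_binning` are hypothetical helpers:

```python
import hashlib
import json
import tempfile
from pathlib import Path

import numpy as np

def cached_compute(cache_dir, params, compute, use_cache=True):
    """Return compute(**params), reusing a result stored in cache_dir if available.

    The file name is derived from a hash of the (JSON-serializable) parameters,
    so a call with different parameters never picks up a stale result.
    """
    key = hashlib.sha1(json.dumps(params, sort_keys=True).encode()).hexdigest()[:12]
    path = Path(cache_dir) / f'binning_{key}.npy'
    if use_cache and path.exists():
        return np.load(path)  # reuse the earlier result
    result = np.asarray(compute(**params))
    np.save(path, result)     # store it for subsequent calls
    return result

calls = []  # records how often the expensive step actually runs

def slow_binning(vor_weight, nr, nl):
    """Stand-in for an expensive binning step."""
    calls.append(1)
    return np.full(nr * nl, vor_weight)

cache_dir = tempfile.mkdtemp()
params = {'vor_weight': 0.01, 'nr': 20, 'nl': 41}
first = cached_compute(cache_dir, params, slow_binning)
second = cached_compute(cache_dir, params, slow_binning)  # read from disk
print(len(calls))  # 1
```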

Let's visualize the Voronoi bundle mapping, starting with the probability density of stellar orbits in the circularity-radius phase space, overlaid with the Voronoi binning scheme:

In [None]:
coloring.orbit_bundle_plot(phase_space_mapping=phase_space_binning['map']);

Now, let's visualize how much weight each original orbit bundle contributes to each Voronoi orbit bundle.

Note that
- original orbit bundles with zero weight in the specific model will not contribute to any of the Voronoi orbit bundles;
- each original orbit bundle consists of one actual orbit if `dithering=1` in the configuration file's `orblib_settings`, but comprises multiple orbits if `dithering > 1`. In the latter case, one original orbit bundle can be split among neighboring bins in the radius-circularity phase space and can hence contribute to multiple (neighboring) Voronoi orbit bundles.
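With `dithering > 1`, the orbits of one original bundle can fall into different phase-space bins, so the bundle's weight is spread over several Voronoi bundles. A toy sketch of one column of the mapping, assuming the bundle's weight is shared equally among its dithered orbits (an illustrative assumption; `bundle_column` is a hypothetical helper and DYNAMITE may split the weight differently):

```python
import numpy as np

def bundle_column(sub_bins, bundle_weight, phase_space_map, n_bundle):
    """Distribute one original bundle's weight over the Voronoi bundles.

    sub_bins : phase-space bin index of each dithered orbit in the bundle
    Returns the bundle's column of a vor_bundle_mapping-like array.
    """
    col = np.zeros(n_bundle)
    share = bundle_weight / len(sub_bins)             # equal share per orbit (assumption)
    np.add.at(col, phase_space_map[sub_bins], share)  # accumulate shares per Voronoi bundle
    return col

phase_space_map = np.array([0, 0, 1, 1, 2, 2])  # Voronoi bundle of each of 6 phase-space bins
sub_bins = np.array([1, 2, 2, 3])               # phase-space bins of the bundle's 4 orbits
col = bundle_column(sub_bins, 0.2, phase_space_map, n_bundle=3)
print(col)
```

Here one quarter of the weight (0.05) ends up in Voronoi bundle 0 and three quarters (0.15) in bundle 1, while the column still sums to the bundle's total weight of 0.2.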

In [None]:
plt.figure(figsize=(24, 4))
plt.title('Weight that each orbit bundle contributes to Voronoi orbit bundles')
with np.errstate(divide='ignore'):  # zero weights give log10(0) = -inf; pcolormesh masks these (blank cells)
    log_mapping = np.log10(vor_bundle_mapping)
plt.pcolormesh(log_mapping, shading='flat', cmap='Greys')
plt.xlabel('Original orbit bundle id')
plt.ylabel('Voronoi orbit bundle id')
plt.colorbar(label='log Weight');

For the subsequent analysis, we will need to know how much mass (for mass-weighted models) or flux (for light-weighted models) each Voronoi orbit bundle contributes to each spatial bin. As the concept of orbit bundles per se is independent of coloring, the associated method is part of the `dyn.analysis.Analysis` class:

In [None]:
a = dyn.analysis.Analysis(c)  # orbit bundle maps are residing in the Analysis class

The method ``get_orbit_bundle_maps()`` returns an astropy table with one row per spatial bin and $N_\mathrm{bundle}+1$ columns: one column per Voronoi orbit bundle holding that bundle's contribution, followed by a final column 'flux_all' holding the combined contribution of all bundles.

Setting the parameter `create_figure=True` will create a figure of the individual orbit bundles' contributions and return it along with the bundle maps (see the commented-out example below - try it...). Setting `normalize=True` will normalize the mass (flux) contributions so that in each spatial bin the sum of all orbit bundles' contributions is 1. We activate this option because we will need normalized flux in the Bayesian statistical analysis further down.

In [None]:
# bundle_maps, bundle_figure = a.get_orbit_bundle_maps(pop_set=0,
#                                                      bundle_mapping=vor_bundle_mapping,
#                                                      normalize=True,
#                                                      create_figure=True)  # comment-in to view the plots
bundle_maps = a.get_orbit_bundle_maps(pop_set=0,
                                      bundle_mapping=vor_bundle_mapping,
                                      normalize=True,
                                      create_figure=False)  # calculation only, won't display the plots

In [None]:
print(f'{type(bundle_maps) = }\n{len(bundle_maps) = }\n{bundle_maps.colnames = }')

As you can see, `bundle_maps` contains the Voronoi orbit bundles' contributions along with the aggregate map. For the subsequent calculations, we only need the data for the individual Voronoi bundles, which we store in `flux_data_norm`:

In [None]:
# 'col' rather than 'a' as loop variable, to avoid confusion with the Analysis object a
flux_data_norm = np.array([bundle_maps[col] for col in bundle_maps.colnames if col != 'flux_all']).T
print(f'{flux_data_norm.shape = }') # (n_spatial_bins, n_bundle)
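With `normalize=True`, the bundle contributions in each spatial bin sum to 1, so each row of `flux_data_norm` should sum to 1 (`np.allclose(flux_data_norm.sum(axis=1), 1)` is a quick check, barring spatial bins without any flux). The normalization amounts to dividing each row by its sum; a numpy sketch on toy data, where `flux_raw` and `normalize_rows` are hypothetical and not DYNAMITE's code:

```python
import numpy as np

def normalize_rows(flux_raw):
    """Divide each spatial bin's bundle contributions by their sum (rows summing to 0 stay 0)."""
    row_sum = flux_raw.sum(axis=1, keepdims=True)
    return np.divide(flux_raw, row_sum, out=np.zeros_like(flux_raw), where=row_sum > 0)

# toy un-normalized contributions: 2 spatial bins x 3 Voronoi bundles
flux_raw = np.array([[2.0, 1.0, 1.0],
                     [0.0, 3.0, 1.0]])
flux_norm = normalize_rows(flux_raw)
print(flux_norm)
print(flux_norm.sum(axis=1))  # each spatial bin now sums to 1
```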

## 8.3. Bayesian statistical analysis

Fitting age and metallicity essentially follows the procedure described in Zhu et al., 2020, MNRAS, 496, 1579. For brevity, only the priors "R1" for both age and metallicity will be used. In the following sections, equation numbers refer to the corresponding equations in that paper.

The fitting is done via Bayesian inference with the Python package PyMC. The model uses a truncated normal (age) or truncated lognormal (metallicity) prior for the value assigned to each Voronoi orbit bundle, and a Student's t likelihood for the observed data. The likelihood's scale $\sigma$ and degrees of freedom $\nu$ are themselves given priors: a Half-Cauchy distribution with $\beta=5$ for $\sigma$ and an Exponential distribution with parameter $1/30$ for $\nu$. The posterior is sampled with the Markov chain Monte Carlo (MCMC) algorithm NUTS (No-U-Turn Sampler), initialized with ADVI (Automatic Differentiation Variational Inference).
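Conceptually, in the approach of Zhu et al. (2020), the observed value in each spatial bin is modelled as the flux-weighted mean of the Voronoi bundles' values, which is why the normalized `flux_data_norm` is needed. A numpy/scipy sketch of that predicted map and of a Student's t log-likelihood, with toy numbers and fixed $\sigma$ and $\nu$ for illustration (in `fit_bayesian()` both are sampled, and its exact parametrization may differ; the helper names are hypothetical):

```python
import numpy as np
from scipy import stats

def predicted_map(flux_norm, bundle_values):
    """Flux-weighted mean of the bundle values in each spatial bin.

    flux_norm     : (n_spatial_bins, n_bundle), rows summing to 1
    bundle_values : (n_bundle,), e.g. the age t_k of each Voronoi bundle
    """
    return flux_norm @ bundle_values

def student_t_loglike(obs, model, sigma, nu):
    """Sum of Student's t log-densities of the observations around the model."""
    return stats.t.logpdf(obs, df=nu, loc=model, scale=sigma).sum()

flux_norm = np.array([[0.5, 0.25, 0.25],
                      [0.0, 0.75, 0.25]])
ages = np.array([12.0, 8.0, 4.0])       # toy bundle ages [Gyr]
pred = predicted_map(flux_norm, ages)   # [9.0, 7.0]
obs = np.array([9.5, 6.8])              # toy observed ages [Gyr]
loglike = student_t_loglike(obs, pred, sigma=1.0, nu=30.0)
print(pred, loglike)
```

Compared with a Gaussian likelihood, the Student's t distribution has heavier tails, so individual outlying spatial bins pull less strongly on the bundle values.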

There is no need to interact with PyMC directly: DYNAMITE users can call the `coloring.fit_bayesian()` method, which has many settings predefined (see below).

The sampler is initialized with ADVI using 200000 iterations (`advi_init`), followed by 2500 NUTS tuning steps (`n_tune`). The subsequent 500 draws (`n_draws`) are used to compute the mean and standard deviation of the age and metallicity, respectively:

In [None]:
sample = {'n_draws': 500,
          'n_tune': 2500,
          'advi_init': 200000}

Finally, we need to assign some DYNAMITE data structures (note that currently only one population dataset, `stars.population_data[0]`, is supported):

In [None]:
stars = c.system.get_unique_triaxial_visible_component()
pops = stars.population_data[0]
pop_data = pops.get_data()  # read the population data once
age, dage, met, dmet = [pop_data[i] for i in ('age', 'dage', 'met', 'dmet')]  # dage and dmet will not be used

### 8.3.1. Age

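The age prior is a bounded normal distribution $f(t_k|\mu_k,\sigma_k)$ with a lower boundary of 0 and an upper boundary of 20 (Gyr), where $\mu_k=\mathrm{Randn}(\langle t_\mathrm{obs}\rangle,2\sigma(t_\mathrm{obs}))$ and $\sigma_k=2\sigma(t_\mathrm{obs})$, see Eqs. (11)-(13). Here, $\mathrm{Randn}(\mu,\sigma)$ denotes a random draw from a normal distribution with mean $\mu$ and standard deviation $\sigma$, and $\langle t_\mathrm{obs}\rangle$ and $\sigma(t_\mathrm{obs})$ are the mean and standard deviation of the observed ages: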

In [None]:
prior_t = {'mu': np.random.normal(age.mean(), 2 * age.std(), size=len(vor_bundle_mapping)),  # Eq. (12)
           'sigma': 2 * age.std(),                                                           # Eq. (13)
           'lower': 0,
           'upper': 20}
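To see what this bounded normal prior looks like, it can be sampled independently with `scipy.stats.truncnorm`. Because `truncnorm` expects its bounds in units of standard deviations relative to `loc`, they must be standardized first, a common pitfall. This is an illustration with toy observed ages, not the code used inside `fit_bayesian()`; `sample_truncated_normal` is a hypothetical helper:

```python
import numpy as np
from scipy import stats

def sample_truncated_normal(mu, sigma, lower, upper, size, seed=None):
    """Draw from a normal(mu, sigma) distribution truncated to [lower, upper]."""
    a = (lower - mu) / sigma  # truncnorm takes standardized bounds
    b = (upper - mu) / sigma
    return stats.truncnorm.rvs(a, b, loc=mu, scale=sigma, size=size, random_state=seed)

age_obs = np.array([9.0, 10.5, 11.0, 12.5, 8.0])  # toy observed ages [Gyr]
rng = np.random.default_rng(1)
n_bundle = 4
mu_k = rng.normal(age_obs.mean(), 2 * age_obs.std(), size=n_bundle)  # Eq. (12)
sigma_k = 2 * age_obs.std()                                           # Eq. (13)

# 1000 prior draws per Voronoi bundle, all within [0, 20] Gyr
draws = np.array([sample_truncated_normal(m, sigma_k, 0.0, 20.0, size=1000, seed=i)
                  for i, m in enumerate(mu_k)])
print(draws.shape)  # (4, 1000)
```

Note that the prior means $\mu_k$ are random draws; seeding the generator (as above) makes such priors reproducible, whereas the notebook cell uses the unseeded `np.random.normal`.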

In [None]:
model_t, trace_t = coloring.fit_bayesian(prior_dist='normal',
                                         prior_par=prior_t,
                                         flux_data_norm=flux_data_norm,
                                         obs_data=age,
                                         sample=sample)

Optional: inspect the Bayesian model details via

In [None]:
model_t

Note that, by default, PyMC runs as many chains as CPU cores it uses for sampling (the number of available CPUs, capped at four), but at least two; on most machines this means four chains. Each chain performs 3000 steps: 2500 tuning steps followed by the 500 draws used for the results.

The resulting age values are accessible via `trace_t.posterior['qty']`:

In [None]:
trace_t.posterior['qty']

We store it in the variable `age_posterior` for later use. It is an (xarray) array of shape (number of chains, number of draws, $N_\mathrm{bundle}$). Consequently, the mean and error of each Voronoi orbit bundle's age are given by `age_mean = age_posterior.mean(axis=(0,1))` and `age_err = age_posterior.std(axis=(0,1))`, respectively:

In [None]:
age_posterior = trace_t.posterior['qty']
age_mean = age_posterior.mean(axis=(0,1))
age_err = age_posterior.std(axis=(0,1))
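With four chains of 500 draws each, the posterior holds 4 × 500 = 2000 samples per Voronoi bundle, and averaging over the first two axes pools all of them. A numpy sketch on a random stand-in array of that shape (toy values, not actual posterior samples), also showing a percentile-based central 68% interval, which can complement the standard deviation when a posterior is skewed:

```python
import numpy as np

n_chains, n_draws, n_bundle = 4, 500, 3
rng = np.random.default_rng(0)
# stand-in for trace_t.posterior['qty'].values: shape (chains, draws, bundles)
posterior = rng.normal(loc=[5.0, 9.0, 12.0], scale=0.5, size=(n_chains, n_draws, n_bundle))

mean = posterior.mean(axis=(0, 1))  # one value per Voronoi bundle
err = posterior.std(axis=(0, 1))
lo, hi = np.percentile(posterior, [16, 84], axis=(0, 1))  # central 68% interval
print(mean.shape, posterior.reshape(-1, n_bundle).shape)  # (3,) (2000, 3)
```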

In [None]:
age_mean

Optional: display the posterior plot (requires the PyMC module to be imported):

In [None]:
pm.plot_trace(trace_t, combined=True)

### 8.3.2. Metallicity

The metallicity prior is a bounded lognormal distribution $f(Z_k|\mu_k,\sigma_k)$ with a lower boundary of 0, an upper boundary of 10, and $\mu_k=\ln(\mathrm{Randn}(\langle Z_\mathrm{obs}\rangle,\sigma(Z_\mathrm{obs})))$ and $\sigma_k=\sigma(Z_\mathrm{obs})$, see Eqs. (16)-(18):

In [None]:
prior_z = {'mu': np.log(np.random.normal(met.mean(), met.std(), size=len(vor_bundle_mapping))),  # Eq. (17)
           'sigma': met.std(),                                                                    # Eq. (18)
           'lower': 0,
           'upper': 10}
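Two details of this prior are easy to overlook. First, `np.log` of a normal draw is `nan` (with a RuntimeWarning) if the draw is not positive, which can happen when $\sigma(Z_\mathrm{obs})$ is large compared with $\langle Z_\mathrm{obs}\rangle$. Second, since $Z>0$ holds automatically for a lognormal, truncating it to $[0, 10]$ is equivalent to truncating a normal in $\ln Z$ above at $\ln 10$. A scipy sketch with toy metallicities; redrawing non-positive values in the hypothetical helper `positive_normal_draws` is an illustrative choice, not what the notebook cell above does:

```python
import numpy as np
from scipy import stats

def positive_normal_draws(mean, std, size, rng):
    """Normal draws, redrawing any non-positive values so that np.log is defined."""
    x = rng.normal(mean, std, size=size)
    while np.any(x <= 0):
        bad = x <= 0
        x[bad] = rng.normal(mean, std, size=bad.sum())
    return x

met_obs = np.array([0.8, 1.1, 1.4, 0.9, 1.2])  # toy Z/Z_sun values
rng = np.random.default_rng(3)
n_bundle = 4
mu_k = np.log(positive_normal_draws(met_obs.mean(), met_obs.std(), n_bundle, rng))  # Eq. (17)
sigma_k = met_obs.std()                                                              # Eq. (18)

# lognormal truncated to [0, 10] == normal in ln Z truncated to (-inf, ln 10]
b = (np.log(10.0) - mu_k) / sigma_k  # standardized upper bound for truncnorm
ln_z = stats.truncnorm.rvs(-np.inf, b, loc=mu_k, scale=sigma_k, size=n_bundle, random_state=4)
z = np.exp(ln_z)
print(mu_k, z)
```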

In [None]:
model_z, trace_z = coloring.fit_bayesian(prior_dist='lognormal',
                                         prior_par=prior_z,
                                         flux_data_norm=flux_data_norm,
                                         obs_data=met,
                                         sample=sample)

Optional: inspect the Bayesian model details via

In [None]:
model_z

Analogously, the resulting metallicity values are accessible via `trace_z.posterior['qty']`, an (xarray) array of shape (number of chains, number of draws, $N_\mathrm{bundle}$). After storing it in `met_posterior`, the mean and error of each Voronoi orbit bundle's metallicity are given by `met_mean = met_posterior.mean(axis=(0,1))` and `met_err = met_posterior.std(axis=(0,1))`, respectively:

In [None]:
met_posterior = trace_z.posterior['qty']
met_mean = met_posterior.mean(axis=(0,1))
met_err = met_posterior.std(axis=(0,1))

Optional: display the posterior plot (requires the PyMC module to be imported):

In [None]:
pm.plot_trace(trace_z, combined=True)

## 8.4. Results

### 8.4.1. Check how the model matches the data

Here, we plot the observed color maps along with the fitted age and metallicity maps. For the observed data (first row), the error columns show the observational errors read from the input data; for the fitted data (second row), they show the standard deviations of the age and metallicity posteriors, respectively. The color maps are consistent with those used in the kinematic map plots, and the residuals are defined as residual = (model - data) / data_error, also consistent with the kinematic maps.
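The residual definition can be sketched with toy arrays: a model map formed as the flux-weighted mean of per-bundle posterior means, compared with observed values and their errors. The names below are illustrative and not the internals of `color_maps()`, which may compute its model maps differently:

```python
import numpy as np

flux_norm = np.array([[0.5, 0.25, 0.25],  # toy normalized contributions:
                      [0.0, 0.75, 0.25]])  # 2 spatial bins x 3 Voronoi bundles
age_mean = np.array([12.0, 8.0, 4.0])      # toy posterior means per bundle [Gyr]
age_obs = np.array([9.5, 6.8])             # toy observed ages per spatial bin [Gyr]
age_obs_err = np.array([0.5, 0.4])         # toy observational errors [Gyr]

age_model = flux_norm @ age_mean                # model map: [9.0, 7.0]
residual = (age_model - age_obs) / age_obs_err  # (model - data) / data_error
print(residual)
```

A residual of -1 in the first spatial bin means the model lies one observational error below the data there.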

In [None]:
coloring.color_maps(colors={'age': 'Stellar age [Gyr]', 'met': r'Metallicity $Z/Z_\odot$'},  # choose the colors to be plotted
                    model_data=[age_mean, age_err, met_mean, met_err],                       # calculated from the posteriors
                    flux_norm=flux_data_norm,                                                # resulting from get_orbit_bundle_maps() above
                    cbar_lims='data');                                                       # 'data', 'model', or 'auto'

Including the error columns is optional; here is how to plot only the colors:

In [None]:
coloring.color_maps(colors={'age': 'Stellar age [Gyr]', 'met': r'Metallicity $Z/Z_\odot$'},
                    model_data=[age_mean, met_mean],
                    flux_norm=flux_data_norm,
                    cbar_lims='data');

### 8.4.2. Visualize the age-metallicity relation (AMR)

In [None]:
age_posterior.shape

In [None]:
coloring.color_color_plot(age_posterior, met_posterior, phase_space_binning['out'][2],
                          x_label='Age [Gyr]', y_label='$Z/Z_\\odot$',
                          x_scale='linear', y_scale='linear',
                          n_smooth=500);

### 8.4.3. Create an orbital decomposition plot for the colors

The orbital decomposition plot can handle multiple models. In that case, the orbits' probability distribution and the colors in the phase-space bins are averaged over, say, all models within the 1$\sigma$ confidence level of the model hyperparameters. Here, we have only one model, so we demonstrate how to create the orbital decomposition plot using the (only) model in row 0 of the `all_models` table, together with the Voronoi orbit bundles and the age and metallicity estimates computed above:

In [None]:
distr = coloring.get_color_orbital_decomp(models=[c.all_models.get_model_from_row(0)], vor_bundle_mappings=[vor_bundle_mapping], colors=[[age_mean, met_mean]])
distr.shape

The intermediate result ``distr`` is a 4-dimensional numpy array. Its first dimension has three entries: the orbital weight distribution followed by the two color distributions (age and metallicity). The second and third indices label the phase-space bins in $r$ and $\lambda_z$, respectively. The last index enumerates the models for which the weight and color distributions are available (here, only one model).
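A random stand-in array with the layout described above shows how the entries can be unpacked and reduced over the model axis. The plain average over models is an assumption of this sketch; `coloring_decomp_plot()` performs its own averaging, which may be weighted differently:

```python
import numpy as np

nr, nl, n_models = 20, 41, 3
rng = np.random.default_rng(5)
# stand-in for distr: [weight, age, metallicity] x r bins x lambda_z bins x models
distr = rng.random((3, nr, nl, n_models))

# unweighted mean over the last (model) axis, then unpack the three maps
weight_map, age_map, met_map = distr.mean(axis=-1)
print(weight_map.shape)  # (20, 41)
```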

The next step is to average over the models and to plot the data. For this, ``distr`` can directly be passed to the plotting method, along with the desired labels for the individual plots:

In [None]:
coloring.coloring_decomp_plot(distr,
                              plot_labels=[r'Probability density [$M_*$/unit]', 'Stellar age [Gyr]', r'Metallicity $Z/Z_\odot$'],
                              colorbar_scale=['linear','linear','linear']);

### 8.4.4. Create a plot that shows the orbit probability distribution in (age, circularity) and the disk ratio vs age

We use `distr`, the result from the previous section, to extract the weight and age data and plot the orbit distribution in the (age, circularity) phase space, averaged over multiple DYNAMITE models. On top of that plot, we display the disk fraction as a function of age and identify the age at which the cold orbit fraction crosses 50%.
The method ``circularity_color_plot()`` expects the orbit bundle weights as the first argument and the color distribution as the second. We can directly use ``distr[0]`` for the weights and ``distr[1]`` for the age data.

In [None]:
coloring.circularity_color_plot(distr[0], distr[1], c_label='Stellar age [Gyr]');  # distr[1]: age, distr[2]: metallicity
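The disk-fraction curve can be sketched as follows: bin orbits by age, compute in each age bin the fraction of orbital weight on cold orbits, and locate where that fraction crosses 50% by linear interpolation. The cold-orbit cut $\lambda_z > 0.8$ and the helper names are assumptions of this sketch; `circularity_color_plot()` may use a different threshold or binning:

```python
import numpy as np

def cold_fraction_vs_age(age, lz, w, age_edges, lz_cold=0.8):
    """Weight fraction of cold orbits (lambda_z > lz_cold) in each age bin."""
    n_bins = len(age_edges) - 1
    idx = np.digitize(age, age_edges) - 1  # age bin of each orbit
    ok = (idx >= 0) & (idx < n_bins)       # drop orbits outside the age range
    total = np.bincount(idx[ok], weights=w[ok], minlength=n_bins)
    cold = np.bincount(idx[ok], weights=(w * (lz > lz_cold))[ok], minlength=n_bins)
    with np.errstate(invalid='ignore', divide='ignore'):
        return cold / total                # nan for empty age bins

def crossing_age(age_centres, frac, level=0.5):
    """First age at which frac crosses `level` (linear interpolation), or None."""
    d = frac - level
    for i in range(len(d) - 1):
        if d[i] == 0:
            return age_centres[i]
        if d[i] * d[i + 1] < 0:
            t = d[i] / (d[i] - d[i + 1])
            return age_centres[i] + t * (age_centres[i + 1] - age_centres[i])
    return None

# toy orbits: young ones mostly cold, old ones mostly hot
age = np.repeat([1.0, 4.0, 7.0, 10.0], 4)  # [Gyr]
lz = np.array([0.9] * 4 + [0.9, 0.9, 0.9, 0.1] + [0.9, 0.1, 0.1, 0.1] + [0.1] * 4)
w = np.ones_like(age)
edges = np.array([0.0, 3.0, 6.0, 9.0, 12.0])
centres = 0.5 * (edges[:-1] + edges[1:])

frac = cold_fraction_vs_age(age, lz, w, edges)  # [1.0, 0.75, 0.25, 0.0]
t50 = crossing_age(centres, frac)
print(t50)  # 6.0
```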

The method ``circularity_color_plot()`` has quite a few arguments. For details, please see its docstring. As an illustration, here are a few settings to experiment with:

In [None]:
coloring.circularity_color_plot(distr[0], distr[2],
                                c_label=r'Metallicity $Z/Z_\odot$',
                                c_scale='linear',
                                p_scale='linear',
                                n_color_bins=14,
                                interpolation='spline16',  # try 'none', 'spline16',...
                                disk_fraction=True);

## Appendix: experiments

In [None]:
plotter.orbit_plot(model=c.all_models.get_model_from_row(0), Rmax_arcs=316);

In [None]:
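def component_radial_profile_plot(which_mge='mge_lum'):
    """Work in progress: radial profile plot of the stellar MGE (incomplete)."""
    stars = c.system.get_unique_triaxial_visible_component()
    if which_mge == 'mge_lum':
        mge = stars.mge_lum.data
    elif which_mge == 'mge_pot':
        mge = stars.mge_pot.data
    else:
        raise ValueError(f"which_mge must be either 'mge_lum' or 'mge_pot', not '{which_mge}'.")
    fig = plt.figure(figsize=(16, 6))
    plt.subplot(1, 2, 1)
    # plt.get_cmap instead of matplotlib.cm.get_cmap (matplotlib itself was never imported,
    # and matplotlib.cm.get_cmap was removed in Matplotlib 3.9)
    cmap = plt.get_cmap('jet')
    Lk, sigma, q = mge['I'], mge['sigma'], mge['q']  # MGE peak intensities, widths, flattenings
    # sum of the Gaussians' 2*pi*I*sigma^2*q terms, scaled by the factor 0.65 from the draft
    fc = np.sum(2 * np.pi * Lk * sigma ** 2 * q) * 0.65
    pass  # to be continued...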