In [None]:
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from astropy import table
import cmasher as cmr
import dynamite as dyn
from plotbin import display_pixels

In [None]:
# Build and run the models described in the config file, then check their chi2
# values against the reference values stored for this orbit-library resolution.
fname = 'user_test_config_ml.yaml'
c = dyn.config_reader.Configuration(
    fname,
    reset_logging=True,
    user_logfile='test_nnls',
    # 'False' saves time, use only if orblibs have been created with the current orblib_new_mirror
    reset_existing_output=True,
)
_ = dyn.model_iterator.ModelIterator(c)
c.all_models.table.pprint_all()

which_chi2 = c.settings.parameter_space_settings['which_chi2']
# Reference file name encodes the orbit-library sizes nE, nI2, nI3 (concatenated).
compare_file = ('data/chi2_compare_ml_'
                + ''.join(str(c.settings.orblib_settings[key]) for key in ('nE', 'nI2', 'nI3'))
                + '.dat')
chi2_compare = table.Table.read(compare_file, format='ascii')

# Diagnostic plot: reference values as large open black circles, computed
# values as red crosses (x position = row index in the models table).
fig, ax = plt.subplots()
ax.scatter(chi2_compare['model_id'], chi2_compare[which_chi2],
           s=2000, facecolors='none', edgecolors='black')
ax.plot(range(len(c.all_models.table)), c.all_models.table[which_chi2], 'rx')
ax.set_title(f'calculated {which_chi2} (red) vs should-be range (black circles)')
ax.set_xlabel('model_id')
ax.set_xticks(range(len(c.all_models.table)))
ax.set_ylabel(which_chi2)
plt.show()

In [None]:
# Look at the populations data; for now only its apertures and binning are used.
stars = c.system.get_unique_triaxial_visible_component()
pops = stars.population_data
print(f'{len(stars.kinematic_data)=}, {len(stars.population_data)=}')
# One pops dataset reuses the kinematics' apertures, the other has its own
# (see the config file).
for pop in pops:
    print(f'pop dataset {pop.name}: {pop.n_spatial_bins=}')

In [None]:
# Take the first entry returned by get_best_n_models_idx(n=1), load that model's
# orbit library and read its histograms including the populations (pops=True),
# which provides orblib.pops_projected_masses used below.
best_model_row = c.all_models.get_best_n_models_idx(n=1)[0]
model = c.all_models.get_model_from_row(best_model_row)
orblib = model.get_orblib()
orblib.read_losvd_histograms(pops=True)
for pop_idx, pop_masses in enumerate(orblib.pops_projected_masses):
    print(f'{pops[pop_idx].name} number of orbits and spatial bins: {pop_masses.shape}')

In [None]:
# Obtain the orbit weights of the best model. The return value is discarded;
# the next line reads model.weights, which get_weights is expected to populate.
_ = model.get_weights(orblib)
print(model.weights.shape)  # presumably one weight per orbit — TODO confirm

In [None]:
# Model projected mass per spatial bin for each pops dataset: weight each orbit's
# projected masses by its orbit weight and sum over orbits (a dot product).
# Each entry of pops_projected_masses is labelled "number of orbits and spatial
# bins" when printed above, so p.T @ weights yields one value per spatial bin.
model_proj_masses = [np.dot(p.T, model.weights) for p in orblib.pops_projected_masses]

In [None]:
# Plot the best model's surface brightness on the apertures of each pops dataset.
# The pixel/bin bookkeeping follows plotter.py.

def _log_surface_brightness(bin_flux, bin_of_pixel):
    """Convert per-bin model flux into normalized log10 surface brightness per pixel.

    Parameters
    ----------
    bin_flux : array-like, shape (n_bins,)
        Model flux in each spatial bin (one entry of ``model_proj_masses``).
    bin_of_pixel : array-like of int
        Bin index of every pixel (``dp_args['idx_bin_to_pix']``); negative
        entries mark pixels that belong to no bin.

    Returns
    -------
    numpy.ndarray
        ``log10(pixel flux / max pixel flux)`` for every pixel that has a bin,
        in pixel order.
    """
    bin_flux = np.asarray(bin_flux, dtype=float)
    bin_of_pixel = np.asarray(bin_of_pixel)
    # Only select the pixels that have a bin associated with them.
    pix_bins = bin_of_pixel[bin_of_pixel >= 0]
    # Pixels per bin, indexed directly by bin id. np.histogram(..., bins=n_bins)
    # would spread its edges over only the ids present and misalign the counts
    # whenever the highest-numbered bins have no pixels.
    n_pix = np.bincount(pix_bins, minlength=len(bin_flux))
    # Spread each bin's flux evenly over its pixels; bins without pixels get 0
    # rather than inf/nan, which would otherwise corrupt the normalization.
    flux_per_pix = np.divide(bin_flux, n_pix, out=np.zeros_like(bin_flux), where=n_pix > 0)
    pix_flux = flux_per_pix[pix_bins]
    return np.log10(pix_flux / pix_flux.max())


# The colormap is the same for every panel, so build it once.
sb_cmap = cmr.get_sub_cmap('twilight_shifted', 0.05, 0.6)

for pop_idx, proj_mass in enumerate(model_proj_masses):
    dp_args = pops[pop_idx].dp_args
    log_sb = _log_surface_brightness(proj_mass, dp_args['idx_bin_to_pix'])
    # Colour limits from finite values only: log10 of a zero-flux pixel is -inf.
    finite_sb = log_sb[np.isfinite(log_sb)]
    # The galaxy has NOT been rotated by the PA to align its major axis with x,
    # so the aperture angle is passed on to display_pixels.
    plt.figure()
    display_pixels.display_pixels(dp_args['x'], dp_args['y'], log_sb,
                                  vmin=finite_sb.min(), vmax=finite_sb.max(),
                                  label='surface brightness (log)',
                                  pixelsize=dp_args['dx'],
                                  angle=dp_args['angle'],
                                  colorbar=True,
                                  nticks=7,
                                  cmap=sb_cmap)
    plt.gca().set_title(pops[pop_idx].name)
    plt.show()