In [None]:
import os
import numpy as np
import shutil
import matplotlib
import matplotlib.pyplot as plt
from astropy import table
from astropy.io import ascii
import cmasher as cmr
import dynamite as dyn
from plotbin import display_pixels

In [None]:
# Read the DYNAMITE configuration (keeping any existing output) and pick out the
# stellar component whose kinematic data sets are inspected below.
fname = 'F2_config11_with_pm.yaml'
config_options = dict(reset_logging=True,
                      user_logfile='test_nnls',
                      reset_existing_output=False)
c = dyn.config_reader.Configuration(fname, **config_options)
stars = c.system.get_unique_triaxial_visible_component()
# Instantiate the model iterator for this configuration; the returned object is not needed afterwards.
_ = dyn.model_iterator.ModelIterator(config=c)

In [None]:
p = dyn.plotter.Plotter(c)  # DYNAMITE plotter for configuration `c`; its hist2d_plot method is used in the cells below

In [None]:
# Locate the proper-motion kinematic set of the stars and inspect the
# structure of its input data.
for i, k in enumerate(stars.kinematic_data):
    if isinstance(k, dyn.kinematics.ProperMotions):
        i_pm = i  # index of the PM set; later cells use it for orblib.vel_histograms too
        pm_data = k.data
        pm_input = k
        break
else:
    # Fail here with a clear message rather than later with a NameError, or
    # worse, silently reusing i_pm/pm_data left over from a previous kernel run.
    raise ValueError('No ProperMotions kinematics found in stars.kinematic_data.')
print('Shapes of proper motion input data:')
for k in sorted(pm_data):
    print(f"{k}: {pm_data[k].shape}")
print(f"{pm_data['vxrange'] = }, {pm_data['vyrange'] = }")
h2d = pm_input.as_histogram2d()
print(f"{pm_input.hist_width = }, {pm_input.hist_bins = }")
print(f"{len(h2d.xedg[0]) = }, 2dhist vx-range: {h2d.xedg[0][0]}, {h2d.xedg[0][-1]}")
print(f"{len(h2d.x[0]) = }, {h2d.x[0][0] = }, {h2d.x[0][-1] = }")
print(f"h2d dv: {h2d.dx[0][0]}, {h2d.dx[0][-1]}")

In [None]:
# Sanity check on every 10th spatial bin: compare the star count per bin
# ('nstarbin') with the total count in that bin's 2D proper-motion histogram.
# Observed: nstarbin > sum of all PM_2dhist entries -- TODO: find out why.
n_spatial_bins = pm_data['PM_2dhist'].shape[0]
for i in range(0, n_spatial_bins, 10):  # don't print everything...
    print(f"Spatial bin {i}:\t{pm_data['nstarbin'][i]=},\t{np.sum(pm_data['PM_2dhist'][i])=}")

In [None]:
# Plot the proper-motion input data: first a few individual spatial bins,
# then the histogram summed over all spatial bins.
h2d = pm_input.as_histogram2d()
example_sp_bins = (11, 48, 75)
for sp_bin_idx in example_sp_bins:
    print(f'{sp_bin_idx=}')
    _ = p.hist2d_plot(h2d, orb_idx=0, sp_bin_idx=sp_bin_idx, show_1d=True, empty_bins=True)

# Sum over the spatial-bin axis (axis 3) but keep it as a length-1 axis, so the
# summed array keeps the same four axes as h2d.y.
summed_counts = np.sum(h2d.y, axis=3, keepdims=True)
h2d_global = dyn.kinematics.Histogram2D(xedg=h2d.xedg,
                                        y=summed_counts,
                                        normalise=False)
_ = p.hist2d_plot(h2d_global, orb_idx=0, sp_bin_idx=0, show_1d=True, empty_bins=True)

In [None]:
# Experimental: quiver plot of the mean proper motion in each spatial bin.
# TODO: the result looks suspicious -- problem in get_mean or in the plotting?
pm_mean_vel = pm_input.as_histogram2d().get_mean()
# Per the original note, get_mean() is indexed [component][orb_idx], component 0 = vx, 1 = vy.
mean_vx, mean_vy = pm_mean_vel[0][0], pm_mean_vel[1][0]
plt.quiver(pm_data['xbin'], pm_data['ybin'], mean_vx, mean_vy, color='g')
plt.axis('equal')

In [None]:
# Load the best-fitting model, its orbit library and the orbit weights.
all_models = c.all_models
best_model_idx = all_models.get_best_n_models_idx(n=1)[0]
model = all_models.get_model_from_row(best_model_idx)
orblib = model.get_orblib()
print(f'{orblib.parset=}')
_ = model.get_weights(orblib)
orb_weights = model.weights  # read only after get_weights(orblib) has been called
# Orbit-library velocity histograms of the proper-motion kinematics
# (indexed like stars.kinematic_data).
hist2d_orblib = orblib.vel_histograms[i_pm]

In [None]:
# Plot the orbit-weighted proper-motion histogram of each spatial bin.
# Set `bin_step` > 1 to plot only every n-th spatial bin.
bin_step = 1
n_bins = hist2d_orblib.y.shape[-1]

vx_range = [hist2d_orblib.x[0].min(), hist2d_orblib.x[0].max()]
vy_range = [hist2d_orblib.x[1].min(), hist2d_orblib.x[1].max()]
print('Ranges are different from input data due to velocity scaling:')
print(f'{vx_range=}, {vy_range=}')
print(f"vx-range from xedg: {hist2d_orblib.xedg[0][0]}, {hist2d_orblib.xedg[0][-1]}")
print(f"vy-range from xedg: {hist2d_orblib.xedg[1][0]}, {hist2d_orblib.xedg[1][-1]}")

# imshow's `extent` is the outer boundary of the image, i.e. the outermost bin
# *edges*. Using the bin centres (vx_range/vy_range) would shift and shrink the
# velocity axes by half a bin.
vel_extent = [hist2d_orblib.xedg[0][0], hist2d_orblib.xedg[0][-1],
              hist2d_orblib.xedg[1][0], hist2d_orblib.xedg[1][-1]]
ratio = (vel_extent[3] - vel_extent[2]) / (vel_extent[1] - vel_extent[0])

# Weighted sum over orbits: axis 0 of y is the orbit index
# (a single orbit is hist2d_orblib.y[orb_idx, :, :, bin_idx]).
data = np.tensordot(orb_weights, hist2d_orblib.y, axes=1)
vmax = data.max()  # common colour scale for all spatial bins

bins_to_plot = range(0, n_bins, bin_step)
n_cols = 4
n_rows = -(-len(bins_to_plot) // n_cols)  # ceiling division
# Each column is 20/4 = 5 inches wide, so give each row 5 * ratio inches.
fig = plt.figure(figsize=(20, 5 * max(n_rows, 1) * ratio))
for panel, bin_idx in enumerate(bins_to_plot, start=1):
    ax = plt.subplot(n_rows, n_cols, panel)
    # NOTE(review): imshow puts axis 0 of the slice on the vertical axis. If axis 1
    # of y is vx (matching xedg[0]), the slice needs a .T to match the vx/vy
    # extent -- confirm against Histogram2D's layout.
    im = ax.imshow(data[:, :, bin_idx], aspect='equal', interpolation='none',
                   origin='lower', extent=vel_extent,
                   vmax=vmax, vmin=0)
    fig.colorbar(im, ax=ax, shrink=1)
    ax.set_title(f'Bin {bin_idx}')
plt.show()

In [None]:
# Plot the proper-motion histograms of the individual orbits in one spatial bin.
# Set `orb_step` > 1 to plot only every n-th orbit.
bin_idx = 31
orb_step = 1

# Use the proper-motion index determined from stars.kinematic_data, as for
# hist2d_orblib, instead of a hard-coded histogram index.
hist2d = orblib.vel_histograms[i_pm]
n_orbits = hist2d.y.shape[0]

vx_range = [hist2d.x[0].min(), hist2d.x[0].max()]
vy_range = [hist2d.x[1].min(), hist2d.x[1].max()]
print(f'{vx_range=}, {vy_range=}')

# imshow's `extent` spans the outermost bin edges, not the bin centres.
vel_extent = [hist2d.xedg[0][0], hist2d.xedg[0][-1],
              hist2d.xedg[1][0], hist2d.xedg[1][-1]]
ratio = (vel_extent[3] - vel_extent[2]) / (vel_extent[1] - vel_extent[0])

orbs_to_plot = range(0, n_orbits, orb_step)
n_cols = 4
n_rows = -(-len(orbs_to_plot) // n_cols)  # ceiling division
# Each column is 20/4 = 5 inches wide, so give each row 5 * ratio inches.
fig = plt.figure(figsize=(20, 5 * max(n_rows, 1) * ratio))
for panel, orb_idx in enumerate(orbs_to_plot, start=1):
    data = hist2d.y[orb_idx, :, :, bin_idx]
    ax = plt.subplot(n_rows, n_cols, panel)
    # NOTE(review): see the all-bins cell -- the slice may need a .T if axis 1 is vx.
    im = ax.imshow(data, aspect='equal', interpolation='none',
                   origin='lower', extent=vel_extent,
                   vmax=data.max(), vmin=0)  # per-orbit colour scale
    fig.colorbar(im, ax=ax)
    ax.set_title(f'Bin {bin_idx}, orbit {orb_idx}')

plt.show()

In [None]:
# Plot the mean velocities of every orb_skip-th orbit as maps over the spatial
# bins: LOS velocity (if a 1D LOSVD set exists), pm x and pm y.
# mostly from plotter.py
orb_skip = 50


def create_plot(dp_args, data):
    """Draw one value per spatial bin as a pixel map on the current axes.

    Parameters
    ----------
    dp_args : dict
        Display arguments of a kinematic set. The keys 'x', 'y', 'dx',
        'idx_bin_to_pix' and 'angle' are read.
    data : numpy.ndarray
        Values indexed by spatial-bin number, looked up through
        ``dp_args['idx_bin_to_pix']``.
    """
    grid = dp_args['idx_bin_to_pix']
    # Only select the pixels that have a bin associated with them.
    s = np.ravel(np.where(grid >= 0))
    c_c = data[grid[s]]
    # The galaxy has NOT already been rotated with PA to align the major axis with x.
    map1 = cmr.get_sub_cmap('twilight_shifted', 0.05, 0.6)
    kw_display_pixels1 = dict(pixelsize=dp_args['dx'],
                              angle=dp_args['angle'],
                              colorbar=True,
                              nticks=7,
                              cmap=map1)
    display_pixels.display_pixels(dp_args['x'], dp_args['y'], c_c,
                                  vmin=np.min(c_c), vmax=np.max(c_c),
                                  label='velocity',
                                  **kw_display_pixels1)


orb_indices = range(0, orblib.n_orbs, orb_skip)
# One row per plotted orbit; n_orbs // orb_skip + 1 gave an extra empty row
# whenever n_orbs was a multiple of orb_skip.
n_rows, n_cols = len(orb_indices), 3
# assuming that 1d histograms are first, if existing
has_losvd = len(orblib.vel_histograms) > 1
idx = 1 if has_losvd else 0
# The means do not depend on the orbit, so compute them once, not per panel.
losvd_mean = orblib.vel_histograms[0].get_mean() if has_losvd else None
pm_mean = orblib.vel_histograms[idx].get_mean()  # indexed [pm x / pm y][orb_idx]

# A single figure for all orbits. Creating it inside the loop left each orbit's
# row alone at row i of an otherwise empty figure.
fig = plt.figure(figsize=(20, 4 * max(n_rows, 1)))
for i, orb_idx in enumerate(orb_indices):
    if has_losvd:
        ax = plt.subplot(n_rows, n_cols, n_cols * i + 1)
        ax.set_title(f'Orbit {orb_idx}, losvd')
        create_plot(stars.kinematic_data[0].dp_args, losvd_mean[orb_idx])

    ax = plt.subplot(n_rows, n_cols, n_cols * i + 2)
    ax.set_title(f'Orbit {orb_idx}, pm x')
    create_plot(stars.kinematic_data[idx].dp_args, pm_mean[0][orb_idx])

    ax = plt.subplot(n_rows, n_cols, n_cols * i + 3)
    ax.set_title(f'Orbit {orb_idx}, pm y')
    create_plot(stars.kinematic_data[idx].dp_args, pm_mean[1][orb_idx])
plt.show()

In [None]:
# Compare the orbit-weighted orblib proper-motion histograms with the input
# data for a few spatial bins.
sp_bin_idx_list = [25, 48, 70]

# Weighted sum over the orbit axis (axis 0) of the orblib histograms.
data = np.dot(hist2d_orblib.y.T, orb_weights).T
print(f'{data.shape = }')
print(f'{np.max(hist2d_orblib.y)=}')
# Neither histogram depends on the spatial bin, so build both once, outside the
# loop (the same Histogram2D is likewise reused across bins in the input-plot cell).
hist_2d = dyn.kinematics.Histogram2D(xedg=hist2d_orblib.xedg,
                                     y=data[np.newaxis, :, :, :],
                                     normalise=False)
hist_2d_input = pm_input.as_histogram2d()
for sp_bin_idx in sp_bin_idx_list:
    print(f'{sp_bin_idx = }')
    _ = p.hist2d_plot(hist_2d, orb_idx=0, sp_bin_idx=sp_bin_idx, show_1d=True, empty_bins=True)
    _ = p.hist2d_plot(hist_2d_input, orb_idx=0, sp_bin_idx=sp_bin_idx, show_1d=True, empty_bins=True)

In [None]:
import datetime

# Show the current local date and time as this cell's output.
datetime.datetime.today()

In [None]:
import numpy as np

In [None]:
# Load a reference kinematics file. np.load on an .npz returns an NpzFile that
# keeps the file open, so copy the arrays into a dict inside a `with` block to
# close the handle right away; f['...'] lookups work as before.
with np.load('Francisco/rot_mod_d5kpc_i00deg/kinematics_hist2d.npz') as npz_file:
    f = dict(npz_file)

In [None]:
# PSF weights stored in the reference kinematics file loaded above.
psf_weight = f['PSF_weight']
psf_weight

In [None]:
# Load the proper-motion kinematics input file and show its PSF parameters.
# The arrays are copied into a dict inside a `with` block so the .npz file
# handle that np.load keeps open is closed immediately.
with np.load('dynamite_input/PM_kinfile_dv100.npz') as npz_file:
    a = dict(npz_file)
a['PSF_sigma'], a['PSF_weight']