In [None]:
from glob import glob
import pickle

def load_results(path):
    """Unpickle and return the object stored at ``path``.

    Parameters
    ----------
    path : str or os.PathLike
        Location of a pickle file written by this project's pipeline.

    Notes
    -----
    ``pickle.load`` can execute arbitrary code while deserializing, so only
    call this on files this project produced itself, never on untrusted input.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)

# Pickled fit results produced by the pipeline, relative to this notebook.
RESULTS_PATTERN = '../outputs/*.pkl'

paths = sorted(glob(RESULTS_PATTERN))
if not paths:
    # Fail with an actionable message instead of an opaque IndexError below.
    raise FileNotFoundError(
        f"No pickled results match {RESULTS_PATTERN!r}; run the pipeline first."
    )
# Lexicographically last file; assumes filenames sort chronologically
# (e.g. timestamped) so this is the most recent run — TODO confirm.
path = paths[-1]

In [None]:
import os


# Resolve the results file to an absolute path *before* changing directory, so
# loading no longer relies on slicing a hard-coded '../' prefix off `path`.
results_path = os.path.abspath(path)

# Work from the directory one level above the notebook; later cells use paths
# relative to it (e.g. 'plots/'). NOTE: os.chdir is not idempotent — re-running
# this cell without restarting the kernel moves up another directory.
os.chdir('../')
result = load_results(results_path)
result.oversample = 12  # presumably the PSF oversampling factor for the grid model — TODO confirm
result.construct_image(overwrite=False)
result.construct_grid_model(overwrite=False)
result.fit_psf()

In [None]:
from astropy.modeling.fitting import LevMarLSQFitter
from astropy.nddata import CCDData, StdDevUncertainty
from astropy.table import Table
import astropy.units as u

from photutils.psf import SourceGrouper, PSFPhotometry
from photutils.background import LocalBackground
from photutils.detection import DAOStarFinder


# Annulus radii (pixels) for the per-source local background estimate.
BKG_INNER_RADIUS = 10
BKG_OUTER_RADIUS = 30
# Aperture radius (pixels) passed to PSFPhotometry; photutils uses it to
# estimate initial fluxes when no flux guess is supplied.
APERTURE_RADIUS = 50

# Group sources whose separation is below the result's critical separation.
grouper = SourceGrouper(min_separation=result.crit_separation)
# Levenberg-Marquardt least squares, also reporting parameter uncertainties.
fitter = LevMarLSQFitter(calc_uncertainties=True)
lb = LocalBackground(inner_radius=BKG_INNER_RADIUS, outer_radius=BKG_OUTER_RADIUS)

photometry = PSFPhotometry(
    psf_model=result.grid_model,
    fit_shape=result.fit_shape,
    grouper=grouper,
    fitter=fitter,
    localbkg_estimator=lb,
    aperture_radius=APERTURE_RADIUS,
    progress_bar=False,
)

In [None]:
import roman_datamodels.datamodels as rdd
import matplotlib.pyplot as plt
from astropy.visualization import simple_norm
import numpy as np

# Open the image file referenced by the fit result as a Roman datamodel.
dm = rdd.open(result.path_image)
# Science array and its per-pixel uncertainty array.
data, error = dm.data, dm.err
# Boolean mask: True for every pixel whose data-quality value is nonzero
# (presumably any DQ flag marks the pixel as unusable — TODO confirm).
mask = dm.dq != 0

In [None]:
# Seed a re-fit with the brightest sources from the earlier PSF fit.
N_BRIGHTEST = 3
# Factor applied to the pipeline uncertainties before fitting (carried over
# from the original analysis; rationale not recorded here — TODO document).
ERROR_SCALE = 10

# Masked entries become NaN so they are excluded together with failed fits.
fit_flux = np.asarray(np.ma.filled(result.fit_results['flux_fit'], np.nan), dtype=float)
# np.argsort sorts NaN to the end, so without this filter a failed (NaN-flux)
# fit would be chosen as one of the "brightest" sources.
finite_rows = np.flatnonzero(np.isfinite(fit_flux))
flux_order = finite_rows[np.argsort(fit_flux[finite_rows])]

# Keep only the position/flux columns and rename them to PSFPhotometry's
# init_params names.
guesses = result.fit_results[flux_order[-N_BRIGHTEST:]][['x_fit', 'y_fit', 'flux_fit']].copy()
guesses.rename_columns(['x_fit', 'y_fit', 'flux_fit'], ['x_init', 'y_init', 'flux_init'])
phot_result = photometry(
    data=data.value,
    error=ERROR_SCALE * error.value,
    init_params=guesses,
    mask=mask,
)

In [None]:
# Render the fitted model on the same (ny, nx) grid as the data. The original
# passed `data.value.T.shape`, which only matches for square images; for a
# non-square image the residual subtraction below would fail to broadcast.
model_image = photometry.make_model_image(data.value.shape, result.grid_model.data.shape[1:])
residual_image = data.value - model_image

# (y, x) pixel position of each source; one row per source.
max_indices = np.column_stack([phot_result['y_init'], phot_result['x_init']])
if len(max_indices) == 0:
    raise ValueError("phot_result contains no sources to plot.")

CUTOUT_HALF_WIDTH = 20  # half-width of each zoomed cutout, in pixels
# A single shared stretch so model, data and residual are directly comparable.
norm = simple_norm(model_image, 'asinh', min_cut=-10, max_cut=300)
panels = [
    (model_image, 'model (WebbPSF+photutils)'),
    (data.value, 'obs (romanisim)'),
    (residual_image, 'residual'),
]

# One row per source rather than a hard-coded 3. A different source count
# then neither raises IndexError nor leaves blank rows.
n_rows = len(max_indices)
fig, ax = plt.subplots(
    n_rows, len(panels), figsize=(10, 10 * n_rows / 3), squeeze=False
)

for row, (y_c, x_c) in enumerate(max_indices):
    for col, (image, title) in enumerate(panels):
        ax[row, col].imshow(image, norm=norm)
        ax[row, col].set(
            title=title,
            # Inverted y-limits keep imshow's default y-down orientation.
            ylim=[y_c + CUTOUT_HALF_WIDTH, y_c - CUTOUT_HALF_WIDTH],
            xlim=[x_c - CUTOUT_HALF_WIDTH, x_c + CUTOUT_HALF_WIDTH],
        )

for axis in ax[-1, :]:
    axis.set_xlabel('x [pix]')
for axis in ax[:, 0]:
    axis.set_ylabel('y [pix]')

fig.tight_layout()
os.makedirs('plots', exist_ok=True)  # savefig does not create missing directories
fig.savefig('plots/psf_residuals.png', bbox_inches='tight', dpi=200)
plt.show()