In [None]:
import copy
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import yaml
from astropy.visualization import simple_norm
from jolideco.core import MAPDeconvolver
from jolideco.models import FluxComponents

from utils import read_datasets, stack_datasets

# Notebook-wide generator with a fixed seed so the random flux init drawn in
# get_flux_init is reproducible from run to run.
RANDOM_STATE = np.random.RandomState(seed=7362)

In [None]:
# Data directory two levels up from the notebook; its parent also holds "config".
PATH = Path("..", "..", "data")

In [None]:
# Simulation selection: these values pick both the input FITS files and the
# YAML config used below.
instrument, scenario, bkg_level = "chandra", "aster1", "bg3"

In [None]:
# Per-scenario / per-background / per-instrument run configuration.
filename_config = PATH.parent.joinpath("config", f"{scenario}/{bkg_level}/{instrument}.yaml")

In [None]:
def prepare_datasets_jolideco(datasets):
    """Prepare datasets for jolideco.

    Returns a deep copy of ``datasets`` in which each dataset's ``"psf"``
    entry is wrapped as ``{"flux": psf}``, i.e. keyed by the ``"flux"``
    model component. The input mapping is not modified.

    PSFs that are already dicts are left unchanged. This makes the function
    idempotent, so re-running a cell that does
    ``datasets = prepare_datasets_jolideco(datasets=datasets)`` does not
    produce ``{"flux": {"flux": psf}}``.

    Parameters
    ----------
    datasets : dict
        Mapping of dataset name -> dict with at least a ``"psf"`` key.

    Returns
    -------
    dict
        A new mapping with the PSFs wrapped per component.
    """
    datasets = copy.deepcopy(datasets)

    for dataset in datasets.values():
        psf = dataset["psf"]
        # Already in the per-component layout: wrapping again would nest it.
        if not isinstance(psf, dict):
            dataset["psf"] = {"flux": psf}

    return datasets


def get_flux_init(datasets, oversample=10.0, random_state=None):
    """Draw a random, non-negative initial flux image from the stacked data.

    The datasets are stacked and the background-subtracted,
    exposure-corrected flux is reduced to one mean level (negative values
    clipped to zero). Every pixel of the init is then drawn from a gamma
    distribution whose mean is that level.

    Parameters
    ----------
    datasets : dict
        Datasets passed on to ``stack_datasets``; the stacked result must
        provide ``"counts"``, ``"background"`` and ``"exposure"`` arrays.
    oversample : float
        Concentration of the gamma draw. The shape is
        ``oversample * flux_mean`` and the sample is divided by
        ``oversample``, so the mean is ``flux_mean`` and the variance is
        ``flux_mean / oversample``. Larger values give a smoother image.
    random_state : numpy.random.RandomState, optional
        Generator used for the draw. Defaults to the notebook-wide
        ``RANDOM_STATE``.

    Returns
    -------
    numpy.ndarray
        ``float32`` array with the shape of the stacked flux.

    Raises
    ------
    ValueError
        If no pixel has a finite flux (e.g. zero exposure everywhere).
    """
    if random_state is None:
        random_state = RANDOM_STATE

    stacked = stack_datasets(datasets=datasets)

    # Zero-exposure pixels produce inf/nan; silence the division warnings
    # and exclude those pixels below. Keeping +inf (as nanmean would) makes
    # the mean -- and therefore the gamma shape -- infinite.
    with np.errstate(divide="ignore", invalid="ignore"):
        flux = (stacked["counts"] - stacked["background"]) / stacked["exposure"]

    valid = np.isfinite(flux)
    if not np.any(valid):
        raise ValueError("Cannot initialise flux: no pixel has a finite flux value.")

    # Background-dominated (negative) pixels contribute zero to the level.
    flux_mean = np.clip(flux[valid], 0, np.inf).mean()

    flux_init = random_state.gamma(oversample * flux_mean, size=flux.shape) / oversample
    return flux_init.astype(np.float32)

In [None]:
pattern_counts = f"{instrument}_gauss_fwhm4710_128x128_sim00_{bkg_level}_{scenario}_iter*.fits"

# Path.glob yields files in directory-listing order, which is file-system
# dependent; sort so the dataset order is reproducible across machines.
filenames_counts = sorted(PATH.glob(pattern_counts))

# Fail early with a clear message instead of deep inside stacking/deconvolution.
if not filenames_counts:
    raise FileNotFoundError(f"No counts files matching {pattern_counts!r} found in {PATH}")

In [None]:
# A single PSF file is shared by every counts file.
filename_psf = PATH / f"{instrument}_gauss_fwhm4710_128x128_psf_33x33.fits"
filenames_psf = [filename_psf for _ in filenames_counts]

In [None]:
# Pair each counts file with its PSF file (lists are index-aligned).
datasets = read_datasets(filenames_counts=filenames_counts, filenames_psf=filenames_psf)

In [None]:
# The config file holds a list of runs; this notebook uses the second-to-last.
# NOTE(review): the choice of index -2 is not explained here -- confirm.
config_all = yaml.safe_load(filename_config.read_text())
config = config_all["runs"][-2]

In [None]:
flux_init = get_flux_init(datasets=datasets)

# The flux component is initialised on its upsampled grid: repeat every pixel
# by the configured factor along both axes (previously hard-coded to 2, which
# produced a wrongly sized init for any other factor).
upsampling_factor = int(config["components"]["flux"].get("upsampling_factor", 1))
if upsampling_factor > 1:
    flux_init = flux_init.repeat(upsampling_factor, axis=0).repeat(upsampling_factor, axis=1)

config["components"]["flux"]["flux_upsampled"] = flux_init
components = FluxComponents.from_dict(config["components"])

deconvolver = MAPDeconvolver(**config["deconvolver"])
# Set after construction, so this takes precedence over config["deconvolver"].
deconvolver.n_epochs = 500

# Rebinds `datasets` to the per-component PSF layout used by the run cell.
datasets = prepare_datasets_jolideco(datasets=datasets)

In [None]:
%%time
# Run the MAP deconvolution on the prepared datasets from the initial components.
result = deconvolver.run(datasets=datasets, components=components)

In [None]:
flux = result.components.flux_upsampled_total_numpy

# asinh stretch over [0, 10] compresses the bright end so faint structure
# stays visible; limits are hand-picked for this scenario.
norm = simple_norm(flux, stretch="asinh", min_cut=0, max_cut=10, asinh_a=0.01)

fig, ax = plt.subplots()
image = ax.imshow(flux, origin="lower", norm=norm)
fig.colorbar(image, ax=ax, label="flux")
ax.set(
    title=f"Deconvolved flux ({instrument}, {scenario}, {bkg_level})",
    xlabel="x (pixel)",
    ylabel="y (pixel)",
);

In [None]:
result.plot_trace_loss()
# Hand-picked y-range; NOTE(review): may clip the trace for other runs/configs.
plt.ylim(bottom=6, top=7)

In [None]:

# Text summary of the component objects used for the run.
components_summary = str(components)
print(components_summary)