In [None]:
import glob
import proplot as pplt
import numpy as np

In [None]:
%matplotlib inline

In [None]:
from des_y6utils.piff import (
    measure_t_grid_for_piff_model, 
    nanmad, 
    make_good_regions_for_piff_model_star_and_gal_grid,
)

In [None]:
from pizza_cutter.des_pizza_cutter._piff_tools import compute_piff_flags

import yaml
import os
from esutil.pbar import PBar
import pprint
import piff
import json
import joblib
import fitsio


def _run_file(fname):
    """Evaluate one Piff model file on the star/galaxy T grids.

    Parameters
    ----------
    fname : str
        Path to the Piff file. If the path contains "_z_" the model is
        evaluated with ``IZ_COLOR=0.34``; otherwise ``GI_COLOR=0.61`` is used.

    Returns
    -------
    tuple
        ``(fname, flags, bad_frac, maxdev)`` where ``flags`` is the "flags"
        entry of the grid result, ``bad_frac`` is the fraction of NaN entries
        in the "t_gal" grid, and ``maxdev`` is the largest absolute deviation
        of ``t_gal - t_star`` from its median, divided by ``nanmad`` of the
        same differences.
    """
    color_kwargs = {"IZ_COLOR": 0.34} if "_z_" in fname else {"GI_COLOR": 0.61}

    grid = make_good_regions_for_piff_model_star_and_gal_grid(
        piff_mod=piff.read(fname),
        piff_kwargs=color_kwargs,
        any_bad_thresh=5,
        flag_bad_thresh=2,
    )

    t_gal = grid["t_gal"]
    bad_frac = np.mean(np.isnan(t_gal))

    # flattened galaxy-minus-star T differences
    t_diff = (t_gal - grid["t_star"]).ravel()
    centered = np.abs(t_diff - np.nanmedian(t_diff))
    max_dev = np.nanmax(centered) / nanmad(t_diff)

    # only report files that have at least one NaN grid point
    if bad_frac > 0:
        print(fname, bad_frac, max_dev, flush=True)

    return fname, grid["flags"], bad_frac, max_dev


def _run_jobs(jobs, results):
    if len(jobs) > 0:        
        with joblib.Parallel(n_jobs=2, backend="loky", verbose=100) as par:
            outputs = par(jobs)

        for fname, flags, bf, mdev in outputs:
            results[tname][band][os.path.basename(fname)] = {
                "bad_frac": bf,
                "flags": flags,
                "maxdev": mdev,
            }

        with open("data.json", "w") as fp:
            json.dump(results, fp)
    

# Tiles and bands to process.
TNAMES = ["DES0229-0416", "DES0137-3749", "DES0131-3206", "DES0221-0750"]
BANDS = ["i", "z", "r", "g"]

# Start from a clean checkpoint; _run_jobs rewrites data.json after each batch.
os.system("rm -f ./data.json")

results = {}

for tname in TNAMES:
    results[tname] = {}
    for band in BANDS:
        print("%s-%s" % (tname, band), flush=True)

        results[tname][band] = {}

        with open(
            "/Users/beckermr/MEDS_DIR/des-pizza-slices-y6-test/pizza_cutter_info/"
            "%s_%s_pizza_cutter_info.yaml" % (tname, band)
        ) as fp:
            yml = yaml.safe_load(fp.read())

        jobs = []
        nmax = 24  # batch size: jobs are flushed once this many are queued
        for i, src in PBar(enumerate(yml["src_info"]), total=len(yml["src_info"])):
            # Flush a full batch first, then still consider the current
            # source. (Previously the flush and the append were in an
            # if/else, so the source seen on a flush iteration was skipped.)
            if len(jobs) >= nmax:
                _run_jobs(jobs, results)
                jobs = []

            # only queue sources whose Piff summary statistics pass the cuts
            if (
                compute_piff_flags(
                    piff_info=src["piff_info"],
                    max_fwhm_cen=3.6,
                    min_nstar=30,
                    max_exp_T_mean_fac=4,
                    max_ccd_T_std_fac=0.3,
                ) == 0
            ):
                jobs.append(joblib.delayed(_run_file)(src["piff_path"]))

        # Run whatever is left for THIS band before `jobs` is reset for the
        # next one. (Previously this call sat outside the band loop, so the
        # leftover jobs of every band except the last were dropped.)
        # _run_jobs reads the module-level `tname` and `band`, which are the
        # current loop values here.
        _run_jobs(jobs, results)

In [None]:
# NOTE(review): this cell looks stale. In the visible part of the notebook
# `outputs` is only assigned as a local inside `_run_jobs`, and `_run_file`
# returns 4-tuples rather than the 3-tuples unpacked here. `tname` and `band`
# are whatever the driver loop left behind. Confirm where `outputs` is meant
# to come from before running this.
for fname, flags, b in outputs:
    if flags == 0:
        # fraction of the image area outside the bounding box `b`
        # (assumes a 4096 x 2048 pixel image -- TODO confirm)
        bf = 1.0 - ((b["xmax"] - b["xmin"])*(b["ymax"] - b["ymin"]))/(4096*2048)
    else:
        # any non-zero flag is treated as having no usable box at all
        bf = 1.0

    results[tname][band][os.path.basename(fname)] = {
        "no_box_frac": bf,
        "flags": flags,
        "bbox": b,
    }

# overwrite data.json with the updated results dict
with open("data.json", "w") as fp:
    json.dump(results, fp)

In [None]:
# Make a diagnostic figure for every source whose T grid has a strong outlier.
#
# NOTE(review): `BAND` is not defined in the visible part of the notebook,
# `yml` is whatever file the driver loop loaded last, and
# `make_good_regions_for_piff_model` is not among the visible imports --
# confirm these are in scope before running.

# Create the output directory without going through a shell.
os.makedirs("piff_plots_%s" % BAND, exist_ok=True)

# bbox values are divided by this before being drawn on the mask image
# (presumably the pixel spacing of the T grid -- TODO confirm)
grid_step = 128

for i, src in PBar(enumerate(yml["src_info"]), total=len(yml["src_info"])):
    pmod = piff.read(src["piff_path"])
    print(src["image_path"])
    res = make_good_regions_for_piff_model(pmod, piff_kwargs={"GI_COLOR": 0.61}, seed=10, verbose=True)
    print("    " + pprint.pformat(src["piff_info"]))

    t_arr = res["t_arr"]
    t_mn = res["t_mn"]
    t_std = res["t_std"]
    flags = res["flags"]
    b = res["bbox"]

    # a figure is only made if some T is non-finite or a > 25 sigma outlier
    msk = (~np.isfinite(t_arr)) | (np.abs(t_arr - t_mn) > 25 * t_std)

    if np.any(msk):

        # the mask that gets displayed uses the looser 15 sigma cut
        msk = (~np.isfinite(t_arr)) | (np.abs(t_arr - t_mn) > 15 * t_std)
        fig, axs = pplt.subplots(nrows=1, ncols=2)
        axs[0].imshow(np.log10(t_arr/t_std), cmap="rocket", origin="lower")
        axs[0].grid(False)
        axs[0].set_title("T")

        axs[1].imshow(msk, cmap="rocket", origin="lower")
        axs[1].grid(False)
        axs[1].set_title("|T - <T>| > 15 sigma(T)")

        # outline the bounding box in grid units: left, right, bottom, top
        x_lo = b["xmin"]/grid_step
        x_hi = b["xmax"]/grid_step - 1
        y_lo = b["ymin"]/grid_step
        y_hi = b["ymax"]/grid_step - 1
        for xs, ys in (
            ([x_lo, x_lo], [y_lo, y_hi]),
            ([x_hi, x_hi], [y_lo, y_hi]),
            ([x_lo, x_hi], [y_lo, y_lo]),
            ([x_lo, x_hi], [y_hi, y_hi]),
        ):
            axs[1].plot(xs, ys, color="red")

        fig.savefig("piff_plots_%s/psf_%d.png" % (BAND, i))

In [None]:
import meds

# Open one pizza-cutter slices MEDS file per band; mfiles follows the order
# of `bands`.
bands = ["g", "r", "i", "z"]
mfiles = []
for band in bands:
    meds_fname = "DES0221-0750_r5592p01_%s_pizza-cutter-slices.fits.fz" % band
    mfiles.append(meds.MEDS(meds_fname))

In [None]:
# start_obj = 9008
# nrows = 16

# fig, axs = pplt.subplots(nrows=nrows, ncols=4)

# for row in range(nrows):
#     obj = start_obj + row
#     for col in range(4):
#         psf = mfiles[col].get_psf(obj, 0)
#         axs[row, col].imshow(np.arcsinh(psf/np.std(psf[20, :])), origin="lower", cmap="rocket")
#         axs[row, col].grid(False)

In [None]:
# Show the PSF image returned by get_psf(9008, 0) for mfiles[1], which is the
# "r" entry when mfiles was built from bands = ["g", "r", "i", "z"].
# (arguments are presumably object index and cutout index -- TODO confirm)
psf = mfiles[1].get_psf(9008, 0)

fig, axs = pplt.subplots()
axs.imshow(psf)

In [None]:
import fitsio
# Read the "epochs_info" extension of the r-band pizza-cutter slices file.
d = fitsio.read("DES0221-0750_r5592p01_%s_pizza-cutter-slices.fits.fz" % "r", ext="epochs_info")

In [None]:
# Boolean mask selecting the epochs_info rows whose "id" is 9008.
msk = d["id"] == 9008

In [None]:
# Display the selected epochs_info rows.
d[msk]

In [None]:
# Read the "image_info" extension of the same r-band slices file.
ii = fitsio.read("DES0221-0750_r5592p01_%s_pizza-cutter-slices.fits.fz" % "r", ext="image_info")

In [None]:
# Image paths for the selected epochs; their "image_id" values are used as
# row indices into image_info.
ii[d[msk]["image_id"]]["image_path"]

In [None]:
from pizza_cutter.des_pizza_cutter._piff_tools import compute_piff_flags

In [None]:
import piff

# For every selected epoch: find its source entry in the pizza-cutter info,
# load the Piff model, draw the PSF at the stamp position and compute the
# Piff flags.

# thresholds handed to compute_piff_flags for each epoch
flag_cuts = dict(
    max_fwhm_cen=3.6,
    min_nstar=25,
    max_exp_T_mean_fac=4,
    max_ccd_T_std_fac=0.3,
)

piffs, psf_imgs, piff_flags = [], [], []
einds = np.where(msk)[0]
for eind in einds:
    arr = d[eind:eind+1]
    image_id = arr["image_id"][0]
    fname = ii[image_id]["image_path"]

    # second "/"-separated component of the image path, used for matching
    path_part = fname.split("/")[1]

    # no early exit: if several entries match, the last one is kept
    src = None
    for _src in yml["src_info"]:
        if _src["image_path"].endswith(path_part):
            src = _src

    assert src is not None, path_part
    print(fname, src["piff_path"])

    pmod = piff.read(src["piff_path"])
    piffs.append(pmod)

    # offset of 13 from the stamp start
    # (presumably the center of the PSF stamp -- TODO confirm)
    row = arr["psf_row_start"][0] + 13
    col = arr["psf_col_start"][0] + 13
    print(row, col, src["piff_info"])

    # use the first chip number listed in the model's wcs mapping
    chipnum = list(pmod.wcs.keys())[0]
    psf_imgs.append(pmod.draw(x=col, y=row, GI_COLOR=0.61, chipnum=chipnum).array)

    piff_flags.append(compute_piff_flags(piff_info=src["piff_info"], **flag_cuts))

In [None]:
# Display the list of Piff flag values collected so far.
piff_flags

In [None]:
# One panel per drawn PSF image; also print each image's pixel sum.
fig, axs = pplt.subplots(nrows=len(psf_imgs), ncols=1)

for ax, psf_img in zip(axs, psf_imgs):
    ax.imshow(psf_img)
    print(psf_img.sum())

In [None]:
%matplotlib inline

In [None]:
# Open the r-band test MEDS file
# (filename suggests it covers slices 9005-9010 -- TODO confirm).
m = meds.MEDS("DES0221-0750_r_des-pizza-slices-y6-test_meds-pizza-slices-range9005-9010.fits.fz")

In [None]:
# Show the PSF image returned by get_psf(9008, 0) for the test MEDS file `m`.
# NOTE(review): if `m` only holds a sub-range of slices, confirm that 9008 is
# the right value to pass here.
psf = m.get_psf(9008, 0)

fig, axs = pplt.subplots()
axs.imshow(psf)

In [None]:
import fitsio
# Read the "epochs_info" extension of the r-band test MEDS file.
ei = fitsio.read("DES0221-0750_r_des-pizza-slices-y6-test_meds-pizza-slices-range9005-9010.fits.fz", ext="epochs_info")

In [None]:
# Display the epochs_info rows whose "id" is 9008.
ei[ei["id"] == 9008]

In [None]:
# Scratch arithmetic: evaluates to 32.
2**5