In [None]:
import ngmix
import fitsio
import proplot as pplt
import numpy as np
import piff
import galsim
import yaml

In [None]:
%matplotlib inline

In [None]:
import glob
from esutil.pbar import PBar
import joblib

# Each "data_chunk*.fits" file has a sibling "fnames_chunk*.fits" file that
# lists the PIFF model file names whose values are stored in that chunk.
arr_files = sorted(glob.glob("data/data_chunk*.fits"))

arrs = []
fnames = []
bands = []
for chunk_path in PBar(arr_files):
    names_path = chunk_path.replace("/data_", "/fnames_")
    arrs.append(fitsio.read(chunk_path))
    chunk_names = fitsio.read(names_path)["fnames"]
    fnames.append(chunk_names)
    # the band is the second "_"-separated token of each model file name
    bands.append([name.split("_")[1] for name in chunk_names])

arr, fnames, bands = np.hstack(arrs), np.hstack(fnames), np.hstack(bands)

print("any NaNs:", np.any(np.isnan(arr)))

In [None]:
def _plot_vals(arr, ax, title=None, msk=None):
    cols = pplt.get_colors("default")
    
    arr_bad = []
    arr_good = []
    arr_all = []
    isnan = []

    naninds = []
    cutinds = []
    oddinds = []

    def _stat(a):
        return np.nanmax(np.abs(a-np.nanmedian(a)))


    ng = 512
    for i in range(len(arr)//ng):
        if msk is not None and not msk[i]:
            continue
        start = i * ng
        stop = start + ng
        a = arr[start:stop]
        sa = _stat(a)
        if sa == 0:
            print("zero stat:", i)
            continue
        
        arr_all.append(sa)
        if np.any(np.isnan(a)):
            isnan.append(True)
            arr_bad.append(sa)
            naninds.append(i)
            print("really bad:", i, arr_bad[-1], np.any(np.isnan(a)))
        else:
            isnan.append(False)
            arr_good.append(sa)

            if sa > 0.14 and sa <= 0.15:
                oddinds.append(i)
                print("kind of odd:", i, sa)
            elif sa > 0.15:
                cutinds.append(i)

    cut = 0.15
    isnan = np.array(isnan)
    arr_all = np.array(arr_all)
    print("fraction cut:", np.mean(isnan | (arr_all > cut)))

    if len(arr_bad) > 0:
        arr_bad = np.log10(np.hstack(arr_bad))
    else:
        arr_bad = None
    arr_good = np.log10(np.hstack(arr_good))

    if arr_bad is not None:
        h, _, _ = ax.hist(
            [arr_bad, arr_good], 
            bins=75, 
            log=True, 
            density=False, 
            labels=["NaNs", "no NaNs"],
            colors=[cols[1], cols[0]],
            stacked=False,
            histtype="stepfilled",
            alpha=0.75,
        )
    else:
        h, _, _ = ax.hist(arr_good, bins=75, log=True, density=True, label="no NaNs", color=cols[0])
    
    ax.vlines(np.log10(cut), min(0.02, h.min()/2), max(h.max()*2, 1200), color='k')
    ax.legend()
    ax.set_xlabel("log10[max(|Tgal - median(Tgal)|)]")
    if title is not None:
        ax.set_title(title)
    ax.grid(False)

In [None]:
# Histogram of the per-model T statistic for all bands together.
fig, axs = pplt.subplots(refwidth=4)
_plot_vals(arr, axs)

In [None]:
fig, axs = pplt.subplots(refwidth=4, ncols=4)

# One panel per band; `msk` keeps only the models observed in that band.
for i, band in enumerate("griz"):
    msk = bands == band
    _plot_vals(arr, axs[i], msk=msk, title=band)

In [None]:
import os
import glob
from des_y6utils.piff import (
    measure_star_t_for_piff_model,
    map_star_t_to_grid,
    measure_t_grid_for_piff_model,
    make_good_regions_for_piff_model_gal_grid,
    make_good_regions_for_piff_model_star_and_gal_grid,
    nanmad,
)

# Locate model file number `sind` under $DESDATA and show its T grid.
sind = 3008
found = glob.glob(
    os.path.join(os.environ["DESDATA"], "**", fnames[sind]),
    recursive=True,
)
assert len(found) == 1

piff_mod = piff.read(found[0])
print(len(piff_mod.stars))

# z-band models are evaluated at an i-z color; all others at a g-i color
piff_kwargs = (
    {"IZ_COLOR": 0.34} if "_z_" in fnames[sind] else {"GI_COLOR": 0.61}
)

t_arr = measure_t_grid_for_piff_model(piff_mod, piff_kwargs, seed=14354)

fig, axs = pplt.subplots(ncols=1, share=0)
axs.imshow(t_arr, cmap="rocket")
axs.grid(False)
axs.set_title("grid of shapes at gal color")

In [None]:
import os
import glob
from des_y6utils.piff import (
    measure_star_t_for_piff_model,
    map_star_t_to_grid,
    measure_t_grid_for_piff_model,
    make_good_regions_for_piff_model_gal_grid,
    make_good_regions_for_piff_model_star_and_gal_grid,
    nanmad,
)

# Compare star T (mapped to a grid) with the galaxy-color T grid for one model
# and overlay the good-region boxes found by the two region finders.
sind = 2516
found = glob.glob(os.path.join(os.environ["DESDATA"], "**", fnames[sind]), recursive=True)
assert len(found) == 1, found

piff_mod = piff.read(found[0])
print(len(piff_mod.stars))

if "_z_" in fnames[sind]:
    piff_kwargs = {"IZ_COLOR": 0.34}
else:
    piff_kwargs = {"GI_COLOR": 0.61}

# Thresholds passed to the star+gal region finder. The histogram below marks
# median +/- flag_bad_thresh * nanmad of (gal T - star T) so the two agree.
flag_bad_thresh = 2
any_bad_thresh = 5
# Pixels per T-grid cell; used to map bbox coordinates onto the grid image
# (presumably the grid spacing used by des_y6utils -- TODO confirm).
grid_pix = 128

t_arr = measure_t_grid_for_piff_model(piff_mod, piff_kwargs, seed=14354)
data = measure_star_t_for_piff_model(piff_mod, piff_prop=list(piff_kwargs.keys())[0])
data = data[np.isfinite(data["t"])]
ts_arr = map_star_t_to_grid(data)

res_g = make_good_regions_for_piff_model_gal_grid(piff_mod, piff_kwargs=piff_kwargs, seed=11, verbose=False)
res_sg = make_good_regions_for_piff_model_star_and_gal_grid(
    piff_mod, piff_kwargs=piff_kwargs, seed=11, verbose=False,
    flag_bad_thresh=flag_bad_thresh, any_bad_thresh=any_bad_thresh,
)

# shared color scale for the two T images
vmin = min(np.nanmin(ts_arr), np.nanmin(t_arr))
vmax = max(np.nanmax(ts_arr), np.nanmax(t_arr))

fig, axs = pplt.subplots(ncols=4, share=0)
h0 = axs[0].imshow(ts_arr, vmin=vmin, vmax=vmax, cmap="rocket")
axs[0].grid(False)
axs[0].set_title("stars on grid w/ 2d poly")
axs[0].colorbar(h0, loc='l')

axs[1].imshow(t_arr, vmin=vmin, vmax=vmax, cmap="rocket")
axs[1].grid(False)
axs[1].set_title("grid of shapes at gal color")

# outline each region bbox on the grid image: red = gal grid, blue = star+gal
for region, color in [(res_g, "red"), (res_sg, "blue")]:
    b = region["bbox"]
    x0, x1 = b["xmin"] / grid_pix, b["xmax"] / grid_pix - 1
    y0, y1 = b["ymin"] / grid_pix, b["ymax"] / grid_pix - 1
    axs[1].plot([x0, x1, x1, x0, x0], [y0, y0, y1, y1, y0], color=color)

axs[2].imshow(res_sg["bad_msk"], cmap="rocket")
axs[2].grid(False)
axs[2].set_title("star+gal bad mask")

harr = (t_arr - ts_arr).ravel()
half_width = nanmad(harr) * flag_bad_thresh
hmed = np.nanmedian(harr)
h = axs[3].hist(harr, bins=50)
axs[3].vlines([hmed - half_width, hmed + half_width], 0, np.max(h[0]), color="k")
axs[3].set_xlabel("gal T - star T")

In [None]:
import os

TNAMES = ["DES0131-3206", "DES0137-3749", "DES0221-0750", "DES0229-0416"]
BANDS = ["g", "r", "i", "z"]
# Directory of pizza-cutter info files; override with $PIZZA_CUTTER_INFO_DIR
# instead of editing a hard-coded local path.
PIZZA_INFO_DIR = os.environ.get(
    "PIZZA_CUTTER_INFO_DIR",
    "/Users/beckermr/MEDS_DIR/des-pizza-slices-y6-test/pizza_cutter_info",
)

with open(
    os.path.join(PIZZA_INFO_DIR, "%s_%s_pizza_cutter_info.yaml" % (TNAMES[1], BANDS[0]))
) as fp:
    yml = yaml.safe_load(fp)

In [None]:
from des_y6utils.piff import (
    measure_star_t_for_piff_model,
    map_star_t_to_grid,
    measure_t_grid_for_piff_model,
    make_good_regions_for_piff_model_gal_grid,
    make_good_regions_for_piff_model_star_and_gal_grid,
    nanmad,
)

In [None]:
# find a bad model: the first source whose galaxy-color T grid contains NaNs
import tqdm

for sind, src in tqdm.tqdm(enumerate(yml["src_info"]), total=len(yml["src_info"])):
    fname = src["piff_path"]

    pmod = piff.read(fname)
    if "_z_" in fname:
        piff_kwargs = {"IZ_COLOR": 0.34}
    else:
        piff_kwargs = {"GI_COLOR": 0.61}
    # measure the model just read, not the stale global `piff_mod`
    t_gal = measure_t_grid_for_piff_model(pmod, piff_kwargs, seed=14354)
    bf = np.mean(np.isnan(t_gal))

    if bf > 0:
        print(sind)
        break
else:
    print("no model with NaNs in its T grid")

In [None]:
from esutil.pbar import prange
from pizza_cutter.des_pizza_cutter._piff_tools import compute_piff_flags

import os

# Fraction of sources, over all tiles x bands, that compute_piff_flags flags.
tot = 0
flagged = 0

TNAMES = ["DES0131-3206", "DES0137-3749", "DES0221-0750", "DES0229-0416"]
BANDS = ["g", "r", "i", "z"]
# Directory of pizza-cutter info files; override with $PIZZA_CUTTER_INFO_DIR.
PIZZA_INFO_DIR = os.environ.get(
    "PIZZA_CUTTER_INFO_DIR",
    "/Users/beckermr/MEDS_DIR/des-pizza-slices-y6-test/pizza_cutter_info",
)
# NOTE(review): another cell in this notebook calls compute_piff_flags with
# min_nstar=25 rather than 35 -- confirm which cut is intended.
PIFF_FLAG_KWARGS = dict(
    max_fwhm_cen=3.6,
    min_nstar=35,
    max_exp_T_mean_fac=4,
    max_ccd_T_std_fac=0.3,
)

for i in prange(len(TNAMES)):
    for band in BANDS:
        path = os.path.join(
            PIZZA_INFO_DIR, "%s_%s_pizza_cutter_info.yaml" % (TNAMES[i], band)
        )
        # local name so the global `yml` used by other cells is not clobbered
        with open(path) as fp:
            info = yaml.safe_load(fp)

        for src in info["src_info"]:
            tot += 1
            if compute_piff_flags(piff_info=src["piff_info"], **PIFF_FLAG_KWARGS) != 0:
                flagged += 1

print(flagged, flagged / tot if tot else float("nan"))

In [None]:
from pizza_cutter.des_pizza_cutter._piff_tools import compute_piff_flags

import os

TNAMES = [
    "DES0131-3206",
    "DES0137-3749",
    "DES0221-0750",
    "DES0229-0416",
]
# Directory of pizza-cutter info files; override with $PIZZA_CUTTER_INFO_DIR.
PIZZA_INFO_DIR = os.environ.get(
    "PIZZA_CUTTER_INFO_DIR",
    "/Users/beckermr/MEDS_DIR/des-pizza-slices-y6-test/pizza_cutter_info",
)

tfind = "D00372620_i_c62_r5702p01_piff-model.fits"
# only search the band encoded in the file name (second "_"-separated token)
BANDS = [tfind.split("_")[1]]


def _find_piff_src(tfind, tnames, bands):
    """Return ``(sind, src, info)`` for the first source whose piff_path contains ``tfind``.

    The info file of every (band, tile) pair is searched in order.
    ``(None, None, None)`` is returned when no source matches.
    """
    for band in bands:
        for tname in tnames:
            path = os.path.join(
                PIZZA_INFO_DIR, "%s_%s_pizza_cutter_info.yaml" % (tname, band)
            )
            with open(path) as fp:
                info = yaml.safe_load(fp)
            for idx, src in enumerate(info["src_info"]):
                if tfind in src["piff_path"]:
                    return idx, src, info
    return None, None, None


sind, src, yml = _find_piff_src(tfind, TNAMES, BANDS)
if sind is None:
    print("not found:", tfind)
else:
    # NOTE(review): min_nstar=25 here, while the flagged-fraction cell uses 35
    print(
        "sind|flags|tname:",
        sind,
        compute_piff_flags(
            piff_info=src["piff_info"],
            max_fwhm_cen=3.6,
            min_nstar=25,
            max_exp_T_mean_fac=4,
            max_ccd_T_std_fac=0.3,
        ),
        yml["tilename"],
    )


In [None]:
# largest |gal T - star T| deviation from the median, in units of nanmad
np.nanmax(np.abs(harr - np.nanmedian(harr)))/nanmad(harr)

In [None]:
# mean of the star+gal bad mask, i.e. the flagged fraction if it is boolean (TODO confirm)
np.mean(res_sg["bad_msk"])

In [None]:
# Check the NaN fraction after injecting a NaN. Work on a copy so the measured
# grid `t_gal` is not mutated in place (re-running would otherwise compound).
t_gal_nan = t_gal.copy()
t_gal_nan[0, 0] = np.nan
np.mean(np.isnan(t_gal_nan))

In [None]:
# name of the color property in the current piff_kwargs ("GI_COLOR" or "IZ_COLOR")
list(piff_kwargs.keys())[0]

In [None]:
import ngmix

def get_star_stamp_pos(s, img, wgt, bbox=17):
    """Cut a ``bbox`` x ``bbox`` stamp around star ``s`` from ``img`` and ``wgt``.

    The star position is rounded to the nearest pixel. On each axis the stamp
    spans ``[pos - (bbox - 1)//2, pos - (bbox - 1)//2 + bbox)``, which is
    centred on that pixel for odd ``bbox``.

    Any part of the stamp that falls off the image is zero-filled in both the
    image and weight stamps. The stamp therefore always has shape
    ``(bbox, bbox)`` and off-image pixels carry zero weight. Plain slicing
    would silently wrap around for negative starts and truncate past the
    far edge.

    Parameters
    ----------
    s : object
        Star with ``x`` and ``y`` attributes. ``- 1`` is applied to get
        0-offset array coordinates (presumably ``s.x``/``s.y`` are 1-offset
        image positions -- TODO confirm).
    img, wgt : 2d ndarray
        Image and weight map, indexed ``[y, x]``.
    bbox : int, optional
        Stamp size in pixels (default 17).

    Returns
    -------
    dict
        Keys:

        - ``img``, ``wgt``: the stamps.
        - ``xstart``, ``ystart``: array coordinates of the stamp's first
          pixel; these may be negative near an edge.
        - ``dim``: equal to ``bbox``.
        - ``x``, ``y``: the 0-offset star position.
    """
    xint = int(np.floor(s.x - 1 + 0.5))
    yint = int(np.floor(s.y - 1 + 0.5))
    half = (bbox - 1) // 2
    xstart = xint - half
    ystart = yint - half

    img_stamp = np.zeros((bbox, bbox), dtype=img.dtype)
    wgt_stamp = np.zeros((bbox, bbox), dtype=wgt.dtype)

    # overlap of the stamp with the image, in image array coordinates
    ny, nx = img.shape
    y0, y1 = max(ystart, 0), min(ystart + bbox, ny)
    x0, x1 = max(xstart, 0), min(xstart + bbox, nx)
    if y0 < y1 and x0 < x1:
        dst = (slice(y0 - ystart, y1 - ystart), slice(x0 - xstart, x1 - xstart))
        img_stamp[dst] = img[y0:y1, x0:x1]
        wgt_stamp[dst] = wgt[y0:y1, x0:x1]

    return dict(
        img=img_stamp,
        wgt=wgt_stamp,
        xstart=xstart,
        ystart=ystart,
        dim=bbox,
        x=s.x - 1,
        y=s.y - 1,
    )

def get_star_piff_obs(piff_mod, s, img, wgt, color_key="GI_COLOR"):
    """Build ngmix observations of the PIFF model and of the data at star ``s``.

    Parameters
    ----------
    piff_mod : piff model
        PSF model to draw. Only the first entry of ``piff_mod.wcs`` (its WCS
        and chip number) is used.
    s : piff star
        Star. This reads ``s.x``, ``s.y`` and ``s.data.properties[color_key]``.
    img, wgt : 2d ndarray
        Image and weight map the star stamp is cut from.
    color_key : str, optional
        Color property forwarded to ``piff_mod.draw`` (default
        ``"GI_COLOR"``). Other cells evaluate z-band models with
        ``"IZ_COLOR"``, so pass that for z-band models.

    Returns
    -------
    model_obs : ngmix.Observation
        The model drawn on the stamp grid (no weight map).
    star_obs : ngmix.Observation
        The data stamp with its weight map.
    sres : dict
        Output of ``get_star_stamp_pos``.
    """
    sres = get_star_stamp_pos(s, img, wgt)

    # undo the "- 1" applied by get_star_stamp_pos for the piff/galsim calls
    xv = sres["x"] + 1
    yv = sres["y"] + 1
    chipnum, chip_wcs = next(iter(piff_mod.wcs.items()))
    wcs = chip_wcs.local(image_pos=galsim.PositionD(x=xv, y=yv)).jacobian()

    dim = sres["dim"]
    # star position relative to the stamp; "+ 1" because galsim image
    # bounds start at 1 by default
    cen = (
        sres["x"] - sres["xstart"] + 1,
        sres["y"] - sres["ystart"] + 1,
    )
    model_im = piff_mod.draw(
        x=xv, y=yv, chipnum=chipnum,
        image=galsim.ImageD(dim, dim, wcs=wcs), center=cen,
        **{color_key: s.data.properties[color_key]},
    )

    def _jacobian():
        # ngmix jacobian centre is in 0-offset stamp coordinates
        return ngmix.Jacobian(y=cen[1] - 1, x=cen[0] - 1, wcs=wcs)

    model_obs = ngmix.Observation(image=model_im.array, jacobian=_jacobian())
    star_obs = ngmix.Observation(
        image=sres["img"],
        weight=sres["wgt"],
        jacobian=_jacobian(),
    )
    return model_obs, star_obs, sres

In [None]:

# Adaptive-moments T of the PIFF model drawn at each star's position (only the
# model observation is measured; the star observation is unused here).
# NOTE(review): `img` and `wgt` are not defined anywhere in this notebook;
# they must come from a since-deleted cell -- TODO define them explicitly.
x = []
y = []
t = []

for s in piff_mod.stars:
    mobs, sobs, sres = get_star_piff_obs(piff_mod, s, img, wgt)

    # named `admom_res` so the global `res` used by other cells is not clobbered
    admom_res = ngmix.admom.AdmomFitter(
        rng=np.random.RandomState(seed=10)
    ).go(mobs, ngmix.moments.fwhm_to_T(1))
    # failed fits carry a placeholder T; keep them out of the polynomial fit
    if admom_res["flags"] != 0:
        continue
    t.append(admom_res["T"])
    x.append(sres["x"])
    y.append(sres["y"])
    

In [None]:
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import make_pipeline
from sklearn.linear_model import LinearRegression

# Quadratic 2D polynomial model of T as a function of (x, y) position.
degree = 2
features = np.column_stack([x, y])
polyreg = make_pipeline(PolynomialFeatures(degree), LinearRegression())
polyreg.fit(features, np.array(t))

In [None]:
# Evaluate the polynomial fit at the centres of 128-pixel cells covering a
# 2048 (x) by 4096 (y) pixel area (presumably the CCD -- TODO confirm).
# Grid coordinates get their own names so the star positions `x`, `y` used to
# fit `polyreg` are not overwritten; the reshape happens in the same cell so
# the cell is idempotent.
ygrid, xgrid = np.mgrid[0:4096:128, 0:2048:128] + 64
tg = polyreg.predict(np.array([xgrid.ravel(), ygrid.ravel()]).T).reshape(xgrid.shape)

In [None]:
# NOTE(review): `res` must be the output of make_good_regions_for_piff_model,
# which is computed in a later cell -- run that cell first.
fig, axs = pplt.subplots(ncols=2)
axs[0].imshow(tg)
axs[0].set_title("poly fit to T at star positions")
axs[1].imshow(res["t_arr"])
axs[1].set_title("T grid (make_good_regions_for_piff_model)");

In [None]:
def _nanmad(x, axis=None):
    """
    median absolute deviation - scaled like a standard deviation

        mad = 1.4826*median(|x-median(x)|)

    Parameters
    ----------
    x: array-like
        array to take MAD of
    axis : {int, sequence of int, None}, optional
        `axis` keyword for

    Returns
    -------
    mad: float
        MAD of array x
    """
    return 1.4826*np.nanmedian(np.abs(x - np.nanmedian(x, axis=axis)), axis=axis)


In [None]:
# scaled MAD of the T values at the star positions vs. of the poly fit on the grid
print(_nanmad(t), _nanmad(tg))

In [None]:
from des_y6utils.piff import make_good_regions_for_piff_model

In [None]:
# Run the region finder on the current `piff_mod`; this (re)defines `res`.
# NOTE(review): always evaluated at GI_COLOR=0.61, even for z-band models
# (other cells use IZ_COLOR for "_z_" files) -- confirm intended.
res = make_good_regions_for_piff_model(piff_mod, piff_kwargs={"GI_COLOR": 0.61}, seed=10, verbose=False)

In [None]:
# "t_std" entry of the make_good_regions_for_piff_model result
res["t_std"]

In [None]:
fig, axs = pplt.subplots()

# residual between the region-finder T grid and the polynomial fit on the grid
axs.hist((res["t_arr"] - tg).ravel(), bins=50)
axs.set_xlabel("grid T - poly-fit T");

In [None]:
# 5x the spread of the grid-minus-poly-fit residuals. nanstd, because the T
# grids can contain NaNs (other cells test for them), which would make np.std NaN.
np.nanstd((res["t_arr"] - tg).ravel()) * 5

In [None]:
# largest absolute deviation of the measured T values from their median
# (the original parenthesization reduced to max(t) - median(t))
np.max(np.abs(np.asarray(t) - np.median(t)))

In [None]:
# scaled MAD of the T values measured at the star positions
_nanmad(t)

In [None]:
# scratch: a Gaussian profile with FWHM 0.5, dilated by a factor of 1.1
g = galsim.Gaussian(fwhm=0.5).dilate(1.1)

In [None]:
# display the profile's repr
g

In [None]:
import fitsio

In [None]:
# interactive only: IPython `?` shows fitsio.write's signature and docstring
fitsio.write?