In [1]:
# Automatically reload edited local modules (e.g. the admom_* modules imported below).
%load_ext autoreload
%autoreload 2

import jax
# Enable 64-bit floats in JAX; this has to be set at startup, before any JAX
# arrays are created, so it sits ahead of the remaining imports.
jax.config.update("jax_enable_x64", True)

# These imports deliberately follow the config statement above, hence E402.
import jax.numpy as jnp  # noqa: E402
import jax.random as jrandom  # noqa: E402

import galsim  # noqa: E402

In [48]:
# Single-galaxy test images: a sheared DeVaucouleurs galaxy, a Gaussian PSF,
# and the galaxy offset from the stamp center.
gal_size = 0.5  # DeVaucouleurs half-light radius (NOT a FWHM)
psf_fwhm = 0.9  # Gaussian PSF FWHM

# Stamp geometry shared by every image drawn in this cell.
pixel_scale = 0.2
stamp_size = 51  # odd, so the stamp has a well-defined central pixel
gal_offset = (-0.1, 0.1)  # (dx, dy) shift of the galaxy from the stamp center

# Swap in `galsim.Gaussian(fwhm=gal_size)` here for a Gaussian test galaxy.
gal = galsim.DeVaucouleurs(half_light_radius=gal_size).shear(e1=-0.2, e2=0.1)
psf = galsim.Gaussian(fwhm=psf_fwhm)

# Observed image: galaxy convolved with the PSF, drawn with drawImage's default
# method (which includes the pixel response).
img = galsim.Convolve(gal, psf).shift(*gal_offset).drawImage(
    scale=pixel_scale, nx=stamp_size, ny=stamp_size
)

# Pre-seeing galaxy alone, drawn without pixel convolution; fitted with
# psf_T=0 in the next cell as a reference for the PSF-corrected result.
img_gal_only = gal.shift(*gal_offset).drawImage(
    scale=pixel_scale, nx=stamp_size, ny=stamp_size, method="no_pixel"
)

# PSF alone, centered (no offset), used to measure the PSF moments.
img_psf_only = psf.drawImage(scale=pixel_scale, nx=stamp_size, ny=stamp_size)

In [49]:
from ngmix.flags import get_flags_str
from admom_jax import admom
from admom_core import Obs, fwhm_to_T, gen_guess_admom

cenonly = False
# Center of the (odd-sized) stamp in pixel coordinates.
cen = (img.array.shape[0] - 1) / 2

# Jacobian shared by every observation in this cell: 0.2 per pixel on both axes
# (matching drawImage(scale=0.2)), no rotation, origin at the stamp center.
jac_kws = dict(cen_x=cen, cen_y=cen, dudx=0.2, dudy=0.0, dvdx=0.0, dvdy=0.2)


def _fit_admom(observations, maxitr=200):
    """Run ``admom`` on ``observations`` from the fixed initial guess used here.

    The guess uses key 0, guess_T=0.7 and jac_scale=rng_scale=0 (presumably no
    random perturbation); solver settings are unroll=10 and the module-level
    ``cenonly``. Returns whatever ``admom`` returns (parameters, flags, ...).
    """
    guess = gen_guess_admom(jrandom.key(0), guess_T=0.7, jac_scale=0, rng_scale=0)
    return admom(observations, guess, unroll=10, cenonly=cenonly, maxitr=maxitr)


# NOTE(review): `gal_size` is a DeVaucouleurs half-light radius in the previous
# cell, while `fwhm_to_T` presumably expects a FWHM, so this is only a rough
# reference value for the galaxy T.
print("gal T:", fwhm_to_T(gal_size))

# 1) PSF-only image: measure the PSF moments (no PSF correction).
obs_psf = Obs(
    image=img_psf_only.array,
    weight=jnp.ones_like(img_psf_only.array),
    **jac_kws,
)
res_psf = _fit_admom([obs_psf])
psf_T = res_psf[0][4]  # element 4 of the parameter vector is T
print("psf T trueT:", psf_T, fwhm_to_T(psf_fwhm))

# 2) Galaxy convolved with the PSF, corrected using the measured PSF T.
obs = Obs(
    image=img.array,
    weight=jnp.ones_like(img.array),
    psf_T=psf_T,
    **jac_kws,
)
res = _fit_admom([obs])

# 3) Pre-seeing galaxy alone: no PSF correction (psf_T=0).
obs_gal_only = Obs(
    image=img_gal_only.array,
    weight=jnp.ones_like(img_gal_only.array),
    psf_T=0.0,
    **jac_kws,
)
res_gal_only = _fit_admom([obs_gal_only])

# res[0] is the parameter vector [v, u, e1, e2, T, flux] (the centroid entries
# track the (dy, dx) shift), res[1] the flag word. Each line reports the flags
# of its own fit.
print("gal+psf  [v,u,e1,e2,T,f] flags:", res[0], f"'{get_flags_str(res[1])}'")
print("gal only [v,u,e1,e2,T,f] flags:", res_gal_only[0], f"'{get_flags_str(res_gal_only[1])}'")

gal T: 0.09016844005556021
psf T trueT: 0.2988577575396447 0.2921457457800151
gal+psf  [v,y,e1,e2,T,f] flags: [ 0.09972017 -0.0997425  -0.14851261  0.07426018  0.23077588  0.37410313] ''
gal only [v,y,e1,e2,T,f] flags: [ 0.09999653 -0.09999723 -0.13440801  0.13037816  0.08198719  0.21868884] ''


In [125]:
# Multi-band test: one sheared Exponential galaxy seen through three Gaussian
# PSFs of different FWHM, with the galaxy flux scaled by (i + 1) in band i.
# `cen` (stamp center) and `cenonly` come from the previous cell.

# Swap in `galsim.Gaussian(fwhm=0.5)` here for a Gaussian test galaxy.
gal = galsim.Exponential(half_light_radius=0.5).shear(e1=-0.2, e2=0.1)

# Jacobian for every stamp in this cell: 0.2 per pixel (matching
# drawImage(scale=0.2)), no rotation, origin at the stamp center.
jac_kws = dict(cen_x=cen, cen_y=cen, dudx=0.2, dudy=0.0, dvdx=0.0, dvdy=0.2)

obs_list = []
for i, psf_fwhm in enumerate([0.9, 0.7, 1.1]):
    psf = galsim.Gaussian(fwhm=psf_fwhm)
    img = galsim.Convolve(gal, psf).shift(-0.1, 0.1).drawImage(scale=0.2, nx=51, ny=51)

    # Measure this band's PSF moments from a centered PSF image. The very
    # large weight presumably stands for a tiny noise variance — TODO confirm.
    img_psf_only = psf.drawImage(scale=0.2, nx=51, ny=51)
    obs_psf = Obs(
        image=img_psf_only.array,
        weight=jnp.ones_like(img_psf_only.array) * 1e12,
        **jac_kws,
    )
    res_psf = admom(
        [obs_psf],
        gen_guess_admom(jrandom.key(0), guess_T=0.7, jac_scale=0, rng_scale=0),
        unroll=10,
        cenonly=cenonly,
        maxitr=200,
    )

    # Galaxy observation for band i; the (i + 1) flux scaling makes the
    # per-band fluxes of the joint fit distinguishable.
    obs_list.append(
        Obs(
            image=img.array * (i + 1),
            weight=jnp.ones_like(img.array) * 1e8,
            psf_T=res_psf[0][4],  # measured PSF T (parameter element 4)
            psf=obs_psf,
            **jac_kws,
        )
    )

# Joint fit: shared centroid, shape and T across bands plus one flux per band
# (hence n_obs in the initial guess).
# NOTE(review): maxitr=20 here, but 200 for every other fit in this notebook —
# confirm this is intentional.
res = admom(
    obs_list,
    gen_guess_admom(jrandom.key(0), n_obs=len(obs_list), guess_T=0.7, jac_scale=0, rng_scale=0),
    unroll=10,
    cenonly=cenonly,
    maxitr=20,
)

# res[0] is [v, u, e1, e2, T, f0, ..., fn]; res[1] is the flag word.
print("gal+psf  [v,u,e1,e2,T,f0,...,fn] flags:", res[0], f"'{get_flags_str(res[1])}'")

gal+psf  [v,y,e1,e2,T,f0,...,fn] flags: [ 0.09873036 -0.09888167 -0.18533163  0.09266735  0.36865253  0.47311857
  0.9480117   1.42151347] ''


In [123]:
import numpy as np

import ngmix

from admom_core import obs_to_ngmix_obs
from metadetect.fitting import fit_mbobs_gauss
from metadetect.procflags import get_procflags_str

# Cross-check the joint admom fit against metadetect's Gaussian fitter on the
# same observations. Requires `obs_list` from the multi-band admom cell; run
# that cell first (this cell's execution count predates it).

# One single-epoch ObsList per band.
mbobs = ngmix.MultiBandObsList()
for obs in obs_list:
    ol = ngmix.ObsList()
    ol.append(obs_to_ngmix_obs(obs))
    mbobs.append(ol)

res = fit_mbobs_gauss(mbobs=mbobs, bmask_flags=0, rng=np.random.RandomState(42))

# The fitter reports shape as g; convert to e1/e2 for comparison with admom.
# Each result field is indexed with [0] (one entry here), including the flags,
# so get_procflags_str receives a scalar rather than a length-1 array.
print(
    "mbobs    [e1,e2] T flags:",
    np.array(ngmix.shape.g1g2_to_e1e2(*res["gauss_g"][0])),
    res["gauss_T"][0],
    f"'{get_procflags_str(res['gauss_flags'][0])}'",
)


mbobs    [e1,e2] T flags: [-0.18780845  0.09387851] 0.36384801361922803 ''
