In [1]:
# Uncomment to pick up edits to the (presumably local) `admom_jax` module
# imported below without restarting the kernel.
# %load_ext autoreload
# %autoreload 2

import jax
# Switch JAX to 64-bit floats before any JAX arrays are created; presumably so
# the JAX fits run at the same double precision as the numpy/ngmix reference.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jrandom

import galsim

In [2]:
# Test galaxy: a sheared exponential. It is drawn twice with identical drawing
# settings: once convolved with a Gaussian PSF (`img`) and once without (`img_gal_only`).
# Alternative test profile: gal = galsim.Gaussian(fwhm=0.5)
gal = galsim.Exponential(half_light_radius=0.1)
gal = gal.shear(e1=-0.2, e2=0.1)
psf_fwhm = 0.4

# Drawing settings shared by both stamps, so the two images cannot drift apart.
# NOTE(review): half_light_radius (0.1) is half a pixel (0.2), so the PSF-free
# stamp is strongly undersampled; its measured moments are likely dominated by
# the pixel response.
pixel_scale = 0.2
stamp_size = 51
shift_dx, shift_dy = -0.1, 0.1  # galsim world-coordinate offset (dx, dy)
draw_kwargs = dict(scale=pixel_scale, nx=stamp_size, ny=stamp_size)

img = (
    galsim.Convolve(gal, galsim.Gaussian(fwhm=psf_fwhm))
    .shift(shift_dx, shift_dy)
    .drawImage(**draw_kwargs)
)

img_gal_only = gal.shift(shift_dx, shift_dy).drawImage(**draw_kwargs)

In [3]:
from ngmix.flags import get_flags_str
from admom_jax import admom, _fwhm_to_T, gen_guess_admom

# Fit mode for both the JAX fits here and the ngmix reference in the next cell:
# False fits the full moments, True only the centroid.
cenonly = False

# Fit settings shared by both JAX fits (the ngmix cell uses the same values).
guess_T = 0.7
maxitr = 200
unroll = 10

# Centre of the stamp in 0-indexed pixel coordinates; also used by the ngmix
# Jacobian in the next cell.
cen = (img.array.shape[0] - 1) / 2


def _run_admom_jax(image, psf_T):
    """Run the JAX adaptive-moments fitter on a galsim image with unit weights.

    Arguments are passed to ``admom`` positionally as: image, weight,
    centre (x), centre (y), the four Jacobian terms (dudx, dudy, dvdx, dvdy),
    the initial guess, then the PSF T.

    Parameters
    ----------
    image : galsim.Image
        Square stamp drawn with a simple pixel-scale WCS (``drawImage(scale=...)``).
        The diagonal Jacobian uses ``image.scale``, so it always matches the drawing.
    psf_T : float
        PSF size handed to ``admom`` (here the output of ``_fwhm_to_T``).

    Returns
    -------
    Whatever ``admom`` returns; this notebook indexes it as ``(pars, flags)``.

    Raises
    ------
    ValueError
        If the stamp is not square. A single centre is passed for both axes,
        which is only the true image centre when nrow == ncol.
    """
    nrow, ncol = image.array.shape
    if nrow != ncol:
        raise ValueError(
            f"expected a square stamp (one centre is used for both axes), got {nrow}x{ncol}"
        )
    stamp_cen = (nrow - 1) / 2
    scale = image.scale
    # NOTE(review): jac_scale=0 / rng_scale=0 presumably disable random
    # perturbation of the guess, making it deterministic; confirm in admom_jax.
    guess = gen_guess_admom(jrandom.key(0), guess_T=guess_T, jac_scale=0, rng_scale=0)
    return admom(
        jnp.array(image.array),
        jnp.ones_like(image.array),
        stamp_cen,
        stamp_cen,
        scale, 0.0, 0.0, scale,
        guess,
        psf_T,
        unroll=unroll,
        cenonly=cenonly,
        maxitr=maxitr,
    )


res = _run_admom_jax(img, _fwhm_to_T(psf_fwhm))
# The PSF-free stamp is fit with a zero-width PSF.
res_gal_only = _run_admom_jax(img_gal_only, _fwhm_to_T(0.0))

# The first two parameters are the centroid offsets (v, u), i.e. the galsim
# shift (dy, dx) drawn in the previous cell.
print("gal+psf  [v,u,e1,e2,T,f] flags:", res[0], get_flags_str(res[1]))
print("gal only [v,u,e1,e2,T,f] flags:", res_gal_only[0], get_flags_str(res_gal_only[1]))

gal+psf  [v,y,e1,e2,T,f] flags: [ 0.09872239 -0.09882295 -0.13427564  0.06917057  0.02587661  0.49674432] 
gal only [v,y,e1,e2,T,f] flags: [ 0.1        -0.1        -0.02652265  0.11469315  0.04306725  0.56283145] 


In [4]:
import numpy as np
import ngmix

from ngmix.admom import run_admom, find_cen_admom

# Reference measurement with ngmix on the same PSF-free stamp as
# `res_gal_only`, using the same 0.7 starting value, 200-iteration cap and
# `cenonly` mode as the JAX fits in the previous cell.
obs = ngmix.Observation(
    image=np.array(img_gal_only.array),
    weight=np.ones_like(img_gal_only.array),
    # Take the pixel scale from the drawn image so the Jacobian cannot drift
    # from the drawImage(scale=...) setting.
    jacobian=ngmix.DiagonalJacobian(scale=img_gal_only.scale, row=cen, col=cen),
)

if cenonly:
    # NOTE(review): 0.7 is the T guess used for run_admom and the JAX fits;
    # confirm find_cen_admom interprets its second argument the same way
    # (it may expect a FWHM instead).
    ngres = find_cen_admom(obs, 0.7, maxiter=200)
else:
    ngres = run_admom(obs, 0.7, cenonly=cenonly, maxiter=200)
print(ngres["flux"], ngres["flags"], ngres["numiter"])

0.5628027012923615 0 46
