In [2]:
# Re-import edited Python modules automatically before each cell executes,
# so changes to the local admom modules are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

import jax
# Enable 64-bit floats. JAX expects this flag to be set at startup, before any
# arrays are created, which is why it runs immediately after `import jax`.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.random as jrandom

import galsim

In [3]:
# Single-epoch scene: one sheared galaxy, drawn both convolved with a Gaussian
# PSF of FWHM `psf_fwhm` and on its own (no PSF) as a reference.
gal_profile = galsim.Gaussian(fwhm=0.5)
# gal_profile = galsim.Exponential(half_light_radius=0.5)  # alternative profile
gal = gal_profile.shear(e1=-0.2, e2=0.1)
psf_fwhm = 0.9

# Both stamps must share offset, pixel scale and size so that their moment
# fits are directly comparable; define them once instead of repeating literals.
shift_xy = (-0.1, 0.1)  # (dx, dy), in the same angular units as `scale`
stamp_kwargs = dict(scale=0.2, nx=51, ny=51)

img = (
    galsim.Convolve(gal, galsim.Gaussian(fwhm=psf_fwhm))
    .shift(*shift_xy)
    .drawImage(**stamp_kwargs)
)

img_gal_only = gal.shift(*shift_xy).drawImage(**stamp_kwargs)

In [4]:
from ngmix.flags import get_flags_str
from admom_jax import admom
from admom_core import Obs, fwhm_to_T, gen_guess_admom, obs_to_admom_obs

# Passed to admom as `cenonly`; presumably True restricts the fit to the
# centroid only (TODO confirm against admom_jax).
cenonly = False

# Zero-based index of the stamp centre. The stamps are drawn square
# (nx == ny), so this single value is used for both x and y.
assert img.array.shape[0] == img.array.shape[1], (
    f"expected a square stamp, got shape {img.array.shape}"
)
cen = (img.array.shape[0] - 1) / 2


def make_obs(image, psf_T, cen_x=cen, cen_y=cen):
    """Wrap a drawn galsim image as a unit-weight admom ``Obs``.

    The diagonal Jacobian is taken from ``image.scale``, so it always matches
    the pixel scale the stamp was drawn with. The fitted centre and size are
    therefore in the same angular units as that scale.

    Parameters
    ----------
    image : galsim.Image
        Stamp to fit.
    psf_T : float
        PSF size forwarded to ``Obs`` (the galaxy-only fit below uses 0.0).
    cen_x, cen_y : float
        Centre forwarded to ``Obs``; defaults to the stamp centre ``cen``.
    """
    return Obs(
        image=image.array,
        weight=jnp.ones_like(image.array),
        cen_x=cen_x,
        cen_y=cen_y,
        dudx=image.scale,
        dudy=0.0,
        dvdx=0.0,
        dvdy=image.scale,
        psf_T=psf_T,
    )


def run_admom(obs_list, maxitr=200):
    """Fit ``obs_list`` with admom from a fixed guess (T=0.7, PRNG key 0).

    Returns admom's result; ``[0]`` holds the fitted parameters and ``[1]``
    the flags (as printed below).
    """
    # jac_scale=0 / rng_scale=0: presumably disables random perturbation of
    # the guess so the fit is deterministic (TODO confirm in admom_core).
    guess = gen_guess_admom(jrandom.key(0), guess_T=0.7, jac_scale=0, rng_scale=0)
    return admom(
        obs_list,
        guess,
        unroll=10,
        cenonly=cenonly,
        maxitr=maxitr,
    )


# Galaxy convolved with the PSF; psf_T presumably lets admom correct for it.
obs = make_obs(img, fwhm_to_T(psf_fwhm))
res = run_admom([obs])

# Reference fit on the un-convolved galaxy, hence no PSF term.
obs_gal_only = make_obs(img_gal_only, psf_T=0.0)
res_gal_only = run_admom([obs_gal_only])

# Parameter order is [v, u, e1, e2, T, flux]: v pairs with y (dvdy) and
# u with x (dudx) in the Jacobian above.
print("gal+psf  [v,u,e1,e2,T,f] flags:", res[0], get_flags_str(res[1]))
print("gal only [v,u,e1,e2,T,f] flags:", res_gal_only[0], get_flags_str(res_gal_only[1]))

gal+psf  [v,y,e1,e2,T,f] flags: [ 0.09281231 -0.09329191 -0.18647416  0.09324537  0.09921244  0.49975894] 
gal only [v,y,e1,e2,T,f] flags: [ 0.09832833 -0.0988329  -0.18438229  0.09331477  0.09946821  0.50055251] 


In [9]:
# Multi-epoch fit: the same sheared galaxy seen through three Gaussian PSFs.
# Epoch i has its flux scaled by (i + 1) and its weight by 10**(i - 1)
# (i.e. 0.1, 1, 10), so one joint fit returns one flux per epoch.
gal = galsim.Gaussian(fwhm=0.5)
# gal = galsim.Exponential(half_light_radius=0.5)
gal = gal.shear(e1=-0.2, e2=0.1)
obs_list = []

# Loop names are distinct from `img` / `psf_fwhm` so this cell does not
# overwrite the single-epoch inputs; re-running the earlier fit cell after
# this one then still fits the original stamp.
for i, epoch_psf_fwhm in enumerate([0.9, 0.7, 1.1]):
    epoch_img = galsim.Convolve(
        gal,
        galsim.Gaussian(fwhm=epoch_psf_fwhm),
    ).shift(-0.1, 0.1).drawImage(scale=0.2, nx=51, ny=51)

    # Centre and Jacobian come from this stamp itself rather than from
    # values computed in another cell for a different image.
    # x indexes columns and y rows (galsim convention: array[y, x]).
    nrow, ncol = epoch_img.array.shape
    obs_list.append(Obs(
        image=epoch_img.array * (i + 1),
        weight=jnp.ones_like(epoch_img.array) * 10 ** (i - 1),
        cen_x=(ncol - 1) / 2,
        cen_y=(nrow - 1) / 2,
        dudx=epoch_img.scale,
        dudy=0.0,
        dvdx=0.0,
        dvdy=epoch_img.scale,
        psf_T=fwhm_to_T(epoch_psf_fwhm),
    ))

# NOTE(review): maxitr is 20 here versus 200 in the single-epoch fits.
res = admom(
    obs_list,
    gen_guess_admom(jrandom.key(0), n_obs=len(obs_list), guess_T=0.7, jac_scale=0, rng_scale=0),
    unroll=10,
    cenonly=cenonly,
    maxitr=20,
)

# Parameter order is [v, u, e1, e2, T, flux_0, ..., flux_n]: v pairs with y
# (dvdy) and u with x (dudx).
print("gal+psf  [v,u,e1,e2,T,f0,...,fn] flags:", res[0], get_flags_str(res[1]))

gal+psf  [v,y,e1,e2,T,f0,...,fn] flags: [ 0.09208565 -0.09248379 -0.18649318  0.09325204  0.0992054   0.49969545
  0.99911226  1.49934406] 
