In [1]:
import jaxtromet
import jax
import jax.numpy as jnp
from astropy.time import Time
import numpy as np
from jax.config import config
# Make JAX raise as soon as a computation produces a NaN, and switch JAX to
# 64-bit floats (the recorded outputs below are float64).
# NOTE: `config` here is jax's config object; the same name is rebound later
# by `from scanninglaw.config import config`, so do not call `config.update`
# for JAX settings in any cell after that import.
config.update("jax_debug_nans", True)
config.update("jax_enable_x64", True)



In [22]:
# Target field and model constants shared by the cells below.
SOURCE_ID: int = 2_467_955_656_448_455_040  # presumably the Gaia source_id of the target
RA: float = 27.998863499    # right ascension, deg
DEC: float = -5.161573806   # declination, deg
# Einstein-radius constant: thetaE = sqrt(KAPPA * mass[M_sun] * parallax[mas]) in mas.
# NOTE(review): matches the usual kappa = 4G/(c^2 au) ~ 8.144 mas/M_sun — confirm.
KAPPA: float = 8.144
BLENDING_G: int = 1         # not referenced in the cells shown here
BASELINE_G: float = 20.41   # magnitude passed to jaxtromet.sigma_ast

In [23]:
import scanninglaw.times
from scanninglaw.source import Source
from scanninglaw.config import config

In [24]:
# Gaia DR3 nominal scanning-law model from scanninglaw; generate_times queries
# it for the target field. Loading its auxiliary data took ~33 s in the
# recorded run.
SCANNING_TIMES = scanninglaw.times.Times(version='dr3_nominal')

Loading auxilliary data ...
t = 32.986 s
  auxilliary:  22.812 s
          sf:   2.756 s
interpolator:   7.419 s


In [31]:
def define_lensed_quasar(ra, dec, mass_l, pi_lens,
                         pmrac_lens, pmdec_lens, t_0, beta_0,
                         epoch: float = 2016.):
    """
    Build an astromet/jaxtromet parameter dict for a microlensed quasar.

    The source (quasar) has zero parallax and zero proper motion. The lens is
    dark (``blendl`` = 0) and has the given parallax and proper motion. The
    lens offset at ``epoch`` is chosen so that, from proper motion alone, the
    lens is closest to the source at ``t_0`` with separation ``beta_0``.

    Args:
        - ra            float - right ascension, deg
        - dec           float - declination, deg
        - mass_l        float - lens mass in solar masses
        - pi_lens       float - lens parallax in mas
        - pmrac_lens    float - lens proper motion in rac direction in mas/yr
        - pmdec_lens    float - lens proper motion in dec direction in mas/yr
        - t_0           float - closest approach time as JD - 2450000
                                (converted to a decimal year internally)
        - beta_0        float - closest approach distance on the sky in mas
        - epoch         float - reference epoch (decimal year) at which
                                positions are defined; 2016.0 for DR3,
                                2015.5 for DR2, 2015.0 for DR1
    Returns:
        - params        dict of astromet parameters. The binary parameters
                        are inert defaults (a = 0); 'totalmass' and 'Delta'
                        are -1 placeholders meant to be derived from the
                        other parameters.

    Note: if the relative proper motion is zero, the closest-approach
    direction is undefined. The normalisation below then divides by zero
    and yields NaN offsets, which jax_debug_nans turns into an error.
    """
    # The source is at rest.
    source_pmrac = source_pmdec = 0.

    # Einstein radius; with zero source parallax the relative parallax is pi_lens.
    theta_e = jnp.sqrt(KAPPA*mass_l*pi_lens)

    # Closest-approach time as a decimal year.
    t0_year = Time(t_0+2450000, format='jd').decimalyear

    # Source-minus-lens proper motion, mas/yr.
    mu_rel = jnp.array([source_pmrac - pmrac_lens,
                        source_pmdec - pmdec_lens])

    # Lens offset at `epoch`: the relative drift up to t_0, minus a vector of
    # length beta_0 perpendicular to the relative motion (separation at t_0).
    drift = mu_rel*(t0_year-epoch)
    normal = jnp.array([mu_rel[1], -mu_rel[0]])
    impact = normal/jnp.linalg.norm(normal) * beta_0
    blend_offset = drift - impact

    return {
        'drac': 0,  # mas
        'ddec': 0,  # mas
        # binary parameters, left at defaults for a non-binary system
        'period': 1,  # year
        'a': 0,  # AU
        'e': 0.1,
        'q': 0,
        'l': 0,
        'vtheta': jnp.pi / 4,
        'vphi': jnp.pi / 4,
        'vomega': 0,
        'tperi': 0,  # jyear
        # lens (blend) offset at epoch, mas
        'blenddrac': blend_offset[0],
        'blendddec': blend_offset[1],
        # placeholders derived from other params (not user-specified)
        'totalmass': -1,  # solar mass
        'Delta': -1,
        'epoch': epoch,
        'ra': ra,
        'dec': dec,
        # source motion
        'parallax': 0.,
        'pmrac': source_pmrac,
        'pmdec': source_pmdec,
        # lens motion
        'blendparallax': pi_lens,
        'blendpmrac': pmrac_lens,
        'blendpmdec': pmdec_lens,
        # lensing event
        'thetaE': theta_e,
        'blendl': 0.,
    }

In [32]:
params = define_lensed_quasar(
    RA, DEC,
    mass_l=1e5,     # solar masses
    pi_lens=0.125,  # mas
    pmrac_lens=-1,  # mas/yr
    pmdec_lens=1.,  # mas/yr
    t_0=7200.,      # JD - 2450000
    beta_0=1.,      # mas
)
params

{'drac': 0,
 'ddec': 0,
 'period': 1,
 'a': 0,
 'e': 0.1,
 'q': 0,
 'l': 0,
 'vtheta': 0.7853981633974483,
 'vphi': 0.7853981633974483,
 'vomega': 0,
 'tperi': 0,
 'blenddrac': DeviceArray(0.19066843, dtype=float64),
 'blendddec': DeviceArray(1.22354514, dtype=float64),
 'totalmass': -1,
 'Delta': -1,
 'epoch': 2016.0,
 'ra': 27.998863499,
 'dec': -5.161573806,
 'parallax': 0.0,
 'pmrac': 0.0,
 'pmdec': 0.0,
 'blendparallax': 0.125,
 'blendpmrac': -1,
 'blendpmdec': 1.0,
 'thetaE': DeviceArray(319.06112267, dtype=float64, weak_type=True),
 'blendl': 0.0}

In [33]:
def generate_times(ra=None, dec=None, scanning_times=None):
    """
    Return Gaia scan epochs and scan angles for a sky position, sorted by time.

    Args:
        - ra              float - right ascension, deg (default: module-level RA)
        - dec             float - declination, deg (default: module-level DEC)
        - scanning_times  scanninglaw Times model (default: SCANNING_TIMES)
    Returns:
        - ts    1-D float64 array - 2010 + (scanninglaw time) / 365.25, sorted
                ascending (presumably scanninglaw returns days since 2010;
                TODO confirm)
        - phis  1-D float64 array - scan angles as returned by scanninglaw,
                reordered to match ts element-wise
    """
    # Resolve defaults at call time so later edits to the globals are honoured,
    # exactly as when this function read the globals directly.
    ra = RA if ra is None else ra
    dec = DEC if dec is None else dec
    scanning_times = SCANNING_TIMES if scanning_times is None else scanning_times

    # get Gaia scanning pattern for the field
    c = Source(ra, dec, unit='deg')
    sl = scanning_times(c, return_times=True, return_angles=True)

    # ravel (not squeeze): a field with a single scan must still give a 1-D
    # array; squeeze would give a 0-d array that argsort/indexing can't handle.
    ts = 2010 + np.ravel(np.hstack(sl['times'])) / 365.25
    angles = np.ravel(np.hstack(sl['angles']))

    order = np.argsort(ts)
    return np.double(ts[order]), np.double(angles[order])

In [34]:
# Scanning pattern of the target field, stored as module-level globals that
# mock_observations reads directly.
# TODO: wrap this set-up in utility functions instead of relying on globals.
times, phis = (jnp.array(arr) for arr in generate_times())
barypos = jaxtromet.barycentricPosition(times)

In [35]:
def mock_observations(mass_lens: float,
                      pi_lens: float,
                      pmrac_lens: float,
                      pmdec_lens: float,
                      t_0: float,
                      beta_0: float,
                      al_error_scale: float,
                      key):
    """
    Simulate Gaia observations of a microlensed quasar and fit them.

    Reads the module-level scanning pattern (``times``, ``phis``, ``barypos``)
    and the constants ``RA``, ``DEC`` and ``BASELINE_G``.

    Args:
        - mass_lens       float - lens mass in solar masses
        - pi_lens         float - lens parallax in mas
        - pmrac_lens      float - lens proper motion in rac direction in mas/yr
        - pmdec_lens      float - lens proper motion in dec direction in mas/yr
        - t_0             float - closest approach time as JD - 2450000
        - beta_0          float - closest approach distance on the sky in mas
        - al_error_scale  float - currently unused.
                                  NOTE(review): presumably meant to scale the
                                  along-scan error; confirm and wire it in.
        - key                   - JAX PRNG key; currently unused. It is not
                                  passed to jaxtromet.mock_obs, so it does not
                                  seed the mock observations (TODO confirm how
                                  mock_obs draws its noise).
    Returns:
        - the result dict of ``jaxtromet.fit`` on the mock observations
    """
    params = define_lensed_quasar(RA, DEC, mass_lens, pi_lens,
                                  pmrac_lens, pmdec_lens, t_0, beta_0)

    # Apparent track of the lensed source at the scanning epochs.
    ras, decs, _mag_diff = jaxtromet.track(times, barypos, params)

    # Single-epoch astrometric error at the baseline magnitude; used both to
    # generate and to fit the mock observations, so compute it once.
    ast_error = jaxtromet.sigma_ast(BASELINE_G)

    t_obs, x_obs, phi_obs, _rac_obs, _dec_obs = jaxtromet.mock_obs(
        times, phis, ras, decs, err=ast_error)

    # Single fit of the standard astrometric model to the mock data.
    return jaxtromet.fit(t_obs, jaxtromet.barycentricPosition(t_obs),
                         x_obs, phi_obs, ast_error, RA, DEC)

In [36]:
key, _ = jax.random.split(jax.random.PRNGKey(4201231), num=2)
mock_observations(mass_lens=100., pi_lens=0.125,
                  pmrac_lens=1., pmdec_lens=1.,
                  t_0=7200., beta_0=1.,
                  al_error_scale=1., key=key)

{'UWE': DeviceArray(0.97978907, dtype=float64),
 'chi2': DeviceArray(608.63152123, dtype=float64),
 'ddec': DeviceArray(-0.02952076, dtype=float64),
 'ddec_error': DeviceArray(0.42573744, dtype=float64),
 'ddec_parallax_corr': DeviceArray(-0.05771815, dtype=float64),
 'ddec_pmdec_corr': DeviceArray(-0.27142438, dtype=float64),
 'ddec_pmrac_corr': DeviceArray(0.32423594, dtype=float64),
 'dec_ref': DeviceArray(-5.16157381, dtype=float64, weak_type=True),
 'drac': DeviceArray(-0.4680944, dtype=float64),
 'drac_ddec_corr': DeviceArray(0.358099, dtype=float64),
 'drac_error': DeviceArray(0.56510641, dtype=float64),
 'drac_parallax_corr': DeviceArray(0.07066067, dtype=float64),
 'drac_pmdec_corr': DeviceArray(0.2080313, dtype=float64),
 'drac_pmrac_corr': DeviceArray(-0.07655212, dtype=float64),
 'excess_noise': DeviceArray(0., dtype=float64),
 'n_good_obs': DeviceArray(639, dtype=int64),
 'n_obs': DeviceArray(639, dtype=int64, weak_type=True),
 'parallax': DeviceArray(0.82601988, dtype=flo