## JAX Memo Code

Example code in §2.3 of [HERA Memorandum #84](http://reionization.org/wp-content/uploads/2013/03/HERA084__A_Generalized_Approach_to_Redundant_Calibration_with_JAX.pdf), which uses the [JAX](https://github.com/google/jax) library.

In [None]:
from jax import numpy as jnp
from simpleredcal.red_likelihood import doRelCal, group_data, relabelAnts
from simpleredcal.red_utils import find_flag_file, find_zen_file, get_bad_ants

In [None]:
# Choose the dataset (a single JD) to calibrate and locate its inputs
JD = 2458098.43869
zen_fn = find_zen_file(JD)  # path to the uvh5 file for this JD
bad_ants = get_bad_ants(zen_fn)  # bad antennas from commissioning
flags_fn = find_flag_file(JD, 'first')  # flags produced by firstcal

# Load ee-polarization visibilities for channel 605 and time integration 0
# into a masked array, excluding bad antennas and applying the flags
hdraw, RedG, cMData = group_data(
    zen_fn,
    pol='ee',
    chans=605,
    tints=0,
    bad_ants=bad_ants,
    flag_path=flags_fn,
)
# 0 out of 741 data points flagged for visibility dataset
# zen.2458098.43869.HH.uvh5

# Masked (flagged) entries are replaced by cMData's fill value.
# NOTE(review): intended to be NaN — confirm group_data sets fill_value=nan.
cData = jnp.squeeze(cMData.filled())
# Columns 1: of RedG hold antenna numbers; column 0 labels the redundant group
no_ants = jnp.unique(RedG[:, 1:]).size
no_unq_bls = jnp.unique(RedG[:, 0]).size
# Relabelled copy of the grouping with consecutively numbered antennas
cRedG = relabelAnts(RedG)

In [None]:
# Relative redundant calibration with a Cauchy noise model, parameters in
# Cartesian coordinates, no parameter bounds, and norm_gains enabled
res_rel = doRelCal(
    cRedG,
    cData,
    no_unq_bls,
    no_ants,
    distribution='cauchy',
    coords='cartesian',
    bounded=False,
    norm_gains=True,
)
# Optimization terminated successfully.

In [None]:
# Illustrative snippet (not executed): JIT-compiling the relative-calibration
# log-likelihood with JAX.
# NOTE(review): relative_logLkl, credg, distribution and obsvis are not defined
# in this notebook — presumably they refer to names inside simpleredcal's
# calibration code; `functools` would also need to be imported to run this.
# from jax import jit

# ff = jit(functools.partial(relative_logLkl, credg, distribution, obsvis, \
#                            no_unq_bls, coords))

In [None]:
# Illustrative snippet (not executed): JAX automatic differentiation provides
# the Jacobian and Hessian of the jitted likelihood `ff` from the previous
# cell, which are then passed to the optimiser.
# NOTE(review): minimize, initp, bounds, method and max_nit are not defined in
# this notebook — `minimize` is presumably scipy.optimize.minimize; confirm.
# from jax import jacfwd, jacrev

# jac = jacrev(ff) # Jacobian; reverse-mode faster for fewer outputs than inputs
# hess = jacfwd(jacrev(ff)) # Hessian; forward-over-reverse is more efficient

# res = minimize(ff, initp, bounds=bounds, method=method, \
#                jac=jac, hess=hess, options={'maxiter':max_nit})

## Degenerate Translation Memo Code

Fully worked example in §3.1 of [HERA Memorandum #94](http://reionization.org/manual_uploads/HERA094__Comparing_Visibility_Solutions_from_Relative_Redundant_Calibration_by_Degenerate_Translation.pdf).

In [None]:
from hera_cal.io import HERAData
from jax import numpy as jnp
from simpleredcal.red_likelihood import doDegVisVis, doRelCal, group_data, red_ant_sep, \
relabelAnts, split_rel_results
from simpleredcal.red_utils import find_flag_file, find_nearest, find_zen_file, get_bad_ants, \
match_lst

In [None]:
# Settings for the 1st dataset to relatively calibrate
JD1 = 2458098.43869
chan = 605  # frequency channel
time_int1 = 0  # time integration within the 1st dataset
noise_dist = 'gaussian'  # noise distribution assumed by the likelihood
coords = 'cartesian'  # coordinate system of the fitted parameters

# Locate the data file, its commissioning bad antennas and the firstcal flags
zen_fn1 = find_zen_file(JD1)
bad_ants1 = get_bad_ants(zen_fn1)
flags_fn1 = find_flag_file(JD1, 'first')
print(f'Bad antennas for JD {JD1} are: {bad_ants1}')

In [None]:
# Read the 1st dataset (ee polarization) into a masked array, excluding bad
# antennas and applying the firstcal flags
hdraw1, RedG1, cMData1 = group_data(
    zen_fn1,
    pol='ee',
    chans=chan,
    tints=time_int1,
    bad_ants=bad_ants1,
    flag_path=flags_fn1,
)
# 0 out of 741 data points flagged for visibility dataset
# zen.2458098.43869.HH.uvh5

# Flagged entries take the masked array's fill value (intended to be NaN)
cData1 = jnp.squeeze(cMData1.filled())
ants = jnp.unique(RedG1[:, 1:])  # distinct antenna numbers (columns 1 onward)
# Antenna count, and number of redundant groups from column 0
no_ants, no_unq_bls = ants.size, jnp.unique(RedG1[:, 0]).size
cRedG1 = relabelAnts(RedG1)  # grouping with consecutively renumbered antennas

In [None]:
# The 2nd dataset is the JD on night 2458099 whose LAST matches that of
# time integration time_int1 of the 1st dataset
JD2 = match_lst(JD1, 2458099, tint=time_int1)
zen_fn2 = find_zen_file(JD2)
bad_ants2 = get_bad_ants(zen_fn2)
flags_fn2 = find_flag_file(JD2, 'first')

# Choose the time integration of dataset 2 whose LST is closest to the LST of
# time integration time_int1 in dataset 1; element [1] of find_nearest's
# result is taken as that index
hdraw2 = HERAData(zen_fn2)
nearest = find_nearest(hdraw2.lsts, hdraw1.lsts[time_int1])
time_int2 = int(nearest[1])

In [None]:
# Load the 2nd dataset (ee polarization) as a masked array with flagging
# applied; its HERAData object is not needed, so it is discarded
_, RedG2, cMData2 = group_data(zen_fn2, pol='ee', chans=chan, tints=time_int2,
                               bad_ants=bad_ants2, flag_path=flags_fn2)
# 0 out of 741 data points flagged for visibility dataset
# zen.2458098.43869.HH.uvh5
# NOTE(review): the file name above looks copied from dataset 1's output; for
# JD2 it is presumably zen.2458099.43124.HH.uvh5 — confirm by re-running.

cData2 = jnp.squeeze(cMData2.filled())  # flagged entries take the fill value

# Check that both datasets can share one redundant grouping (cRedG1 is reused
# when calibrating dataset 2). array_equal returns False on a shape mismatch
# and accepts plain lists. An elementwise `==` followed by `.all()` instead
# raises when the bad-antenna lists or groupings differ in length, which is
# exactly the mismatch this check is meant to report.
same_bad_ants = bool(jnp.array_equal(bad_ants1, bad_ants2))
same_grouping = bool(jnp.array_equal(RedG1, RedG2))
print('Do the visibilities for JDs {} and {} have:\nthe same bad antennas? {}\n'
      'the same redundant grouping? {}'.format(JD1, JD2, same_bad_ants,
                                               same_grouping))

# Do the visibilities for JDs 2458098.43869 and 2458099.43124 have:
# the same bad antennas? True
# the same redundant grouping? True

In [None]:
# Relatively calibrate the 1st dataset. With return_initp=True, doRelCal also
# returns initp (presumably the optimiser's starting parameters), which is
# passed on to the 2nd dataset's calibration.
res_rel1, initp = doRelCal(
    cRedG1, cData1, no_unq_bls, no_ants,
    distribution=noise_dist,
    coords=coords,
    norm_gains=True,
    return_initp=True,
)
# Optimization terminated successfully.

In [None]:
# Relatively calibrate the 2nd dataset. The 1st dataset's relabelled grouping
# cRedG1 is reused, since both datasets were checked to share the same
# redundant grouping, and the 1st dataset's initp is passed in as well.
res_rel2 = doRelCal(
    cRedG1, cData2, no_unq_bls, no_ants,
    distribution=noise_dist,
    coords=coords,
    norm_gains=True,
    initp=initp,
    phase_reg_initp=True,
)
# Optimization terminated successfully.

In [None]:
# Split each dataset's relative calibration solution vector into its
# visibility and gain parts. Each dataset keeps its own suffixed variables so
# that dataset 1's gains are not overwritten by dataset 2's.
res_rel_vis1, res_rel_gains1 = split_rel_results(res_rel1['x'], no_unq_bls, \
                                                 coords=coords)
res_rel_vis2, res_rel_gains2 = split_rel_results(res_rel2['x'], no_unq_bls, \
                                                 coords=coords)

In [None]:
# Fit the degenerate parameters that translate between the two relatively
# calibrated visibility sets; ant_sep is derived from the 1st dataset's
# redundant grouping and antenna positions
ant_sep = red_ant_sep(RedG1, hdraw1.antpos)
res_deg = doDegVisVis(ant_sep, res_rel_vis1, res_rel_vis2,
                      distribution=noise_dist)
# Optimization terminated successfully.
deg_params = res_deg['x']
summary = ('Degenerate parameters are:\n'
           'Amplitude = {}\n'
           'Phase gradient in x = {:e}\n'
           'Phase gradient in y = {:e}')
print(summary.format(deg_params[0], deg_params[1], deg_params[2]))
# Degenerate parameters are:
# Amplitude = 0.9920773888258738
# Phase gradient in x = -6.344390e-06
# Phase gradient in y = 8.818280e-05

In [None]:
# Recorded output of the degenerate-parameter fit when the pipeline is run with
# a Cauchy noise model instead of a Gaussian one (presumably by setting
# noise_dist = 'cauchy' — confirm), for comparison with the results above.
# # Cauchy
# Degenerate parameters are:
# Amplitude = 0.992077030076839
# Phase gradient in x = -6.347452e-06
# Phase gradient in y = 8.818999e-05

In [None]:
# Command-line counterpart of this workflow using the rel_cal.py and
# deg_cal.py scripts (presumably shipped with simpleredcal — confirm paths).
# The commands first relatively calibrate the JD 2458099 files, then calibrate
# JD 2458098.43869 with --initp_jd 2458099, and finally run the degenerate fit
# for JD 2458098.43869 with target JD 2458099 (--tgt_jd).
# python rel_cal.py '2458099.43124' --pol 'ee' --flag_type 'first' --dist 'gaussian'

# python rel_cal.py '2458099.43869' --pol 'ee' --flag_type 'first' --dist 'gaussian'

# python rel_cal.py '2458098.43869' --pol 'ee' --flag_type 'first' --dist 'gaussian' \
# --initp_jd 2458099

# python deg_cal.py '2458098.43869' --deg_dim 'jd' --pol 'ee' --dist 'gaussian' \
# --tgt_jd 2458099