<center><strong><font size=+3>Applications of robust 2D median estimators to HERA data</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

In [None]:
import itertools
import os
import textwrap

import matplotlib as mpl
import numpy as np
import seaborn as sns
from matplotlib import gridspec
from matplotlib import pyplot as plt
from mpl_toolkits import mplot3d
from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes
from scipy import signal
from scipy.stats import chi2, shapiro
from scipy.stats.mstats import gmean as geometric_mean
from statsmodels.nonparametric.kernel_regression import KernelReg

from hera_cal.io import HERAData
from hera_cal.redcal import get_reds

from robstat.hera_vis import agg_tint_rephase
from robstat.ml import extrem_nans, nan_interp2d
from robstat.plotting import grid_heatmaps, SeabornFig2Grid, row_heatmaps
from robstat.robstat import c_mardia_median, geometric_median, mardia_median, mv_median, \
mv_normality, mv_outlier, tukey_median
from robstat.stdstat import mad_clip, rsc_mean
from robstat.utils import DATAPATH, flt_nan

In [None]:
plt.rcParams['figure.figsize'] = (12, 8)
%matplotlib inline

In [None]:
# Parallelise the per-frequency location estimates further down. The `multiprocess`
# package is imported under the stdlib name (presumably so that functions defined in
# the notebook can be pickled for the worker pool).
mp = True
import multiprocess as multiprocessing

# LaTeX-rendered serif text in every figure
from matplotlib import rc
rc('font', family='serif', serif=['cm'])
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

# high-resolution output only when figures are being produced for export
plot_figs = False
if plot_figs:
    mpl.rcParams.update({'figure.dpi': 300})

In [None]:
# Analysis parameters
mad_sigma = 5.0  # sigma threshold for MAD-clipping (default: 5)
no_bins_agg = 2  # number of consecutive time bins averaged together in LST averaging
                 # (2 by default, as in the HERA analysis pipeline)

### Load HERA visibility data

In [None]:
sample_data = os.path.join(DATAPATH, 'zen.2458098.43869.HH.OCRSA.uvh5')

hd = HERAData(sample_data)
data, flags, _ = hd.read()

# redundant baseline groups, keeping only the groups whose baselines are all in the file
reds = get_reds(hd.antpos, pols=hd.pols)
flat_bls = [bl for grp in reds for bl in grp if bl in data.keys()]
reds = [grp for grp in reds if set(grp).issubset(flat_bls)]
bl_dict = {k: i for i, k in enumerate(flat_bls)}

# flagged samples become NaN once the masked arrays are filled
data = {k: np.ma.array(v, mask=flags[k], fill_value=np.nan) for k, v in data.items()}

# Size the baseline axis by the baselines actually stored (flat_bls), not hd.Nbls.
# Any file baseline not in flat_bls (e.g. ones get_reds does not return) would otherwise
# leave uninitialised np.ma.empty columns that leak into later NaN checks on `data`.
# If flat_bls is longer than hd.Nbls (e.g. several pols), the loop would raise IndexError.
mdata = np.ma.empty((hd.Nfreqs, hd.Ntimes, len(flat_bls)), fill_value=np.nan,
                    dtype=complex)
for i, bl in enumerate(flat_bls):
    mdata[..., i] = data[bl].transpose()

data = mdata.filled() # dimensions (freqs, times, bls)
flags = mdata.mask

### Redundant averaging

In [None]:
# first redundant group and its columns in the (freqs, times, bls) arrays
slct_bls = reds[0]
slct_bl_idxs = np.array([bl_dict[bl] for bl in slct_bls])
slct_data, slct_flags = data[..., slct_bl_idxs], flags[..., slct_bl_idxs]

# flagged samples were filled with NaN when the masked array was filled, so counts must agree
assert slct_flags.sum() == np.isnan(slct_data).sum()
print('Looking at baselines redundant to {}'.format(slct_bls[0]))

In [None]:
# Pick the (freq, time) sample whose redundant visibilities have the largest spread
bl_spread = np.nanstd(slct_data, axis=-1)
idxs = np.unravel_index(np.nanargmax(bl_spread), slct_data.shape[:2])
print('Selecting freq / time slice {}'.format(idxs))
slct_data_slice = slct_data[idxs[0], idxs[1], :]


def marg_med(x):
    """Marginal median of complex samples: NaN-ignoring medians of the real and imaginary
    parts taken independently and recombined."""
    return np.nanmedian(x.real) + np.nanmedian(x.imag)*1j


# location estimates of the selected slice
sample_gmean = geometric_mean(flt_nan(slct_data_slice))
sample_gmed = geometric_median(slct_data_slice, options=dict(keep_res=True))
sample_tmed = tukey_median(slct_data_slice)['barycenter']
sample_mmed = c_mardia_median(slct_data_slice)
sample_bmed = marg_med(slct_data_slice)
sample_hmean = rsc_mean(slct_data_slice, sigma=mad_sigma)

In [None]:
# (estimate, estimator name) pairs; the names are reused when tabulating per-group results
med_ests = list(zip([sample_gmean, sample_gmed, sample_tmed, sample_mmed, sample_bmed, sample_hmean],
               ['Geometric Mean', 'Geometric Median', 'Tukey Median', 'Mardia Median',
                'Marginal Median', 'HERA Mean']))
for est, name in med_ests:
    # '.4f' = 4 decimal places; the previous '4f' only set a minimum field width of 4
    print('{:17s}: {:.4f}'.format(name, est))

In [None]:
# axis labels for the real / imaginary parts of a visibility V
_vis_part_label = r'$\mathfrak{%s}(V)$'
re_label = _vis_part_label % 'Re'
im_label = _vis_part_label % 'Im'

In [None]:
fig, ax = plt.subplots(figsize=(7, 7))

# raw redundant visibilities of the selected sample
ax.scatter(slct_data_slice.real, slct_data_slice.imag, alpha=0.5)

# overlay every location estimate with its own marker
sample_markers = [
    (sample_gmean, 'co', 'Geometric Mean'),
    (sample_gmed, 'ro', 'Geometric Median'),
    (sample_tmed, 'yo', 'Tukey Median'),
    (sample_mmed, 'ko', 'Mardia Median'),
    (sample_bmed, 'bo', 'Marginal Median'),
    (sample_hmean, 'go', 'HERA Mean'),
]
for est, fmt, label in sample_markers:
    ax.plot(est.real, est.imag, fmt, label=label)

freq_mhz = hd.freqs[idxs[0]]/1e6
lst_hr = hd.lsts[idxs[1]]*12/np.pi
ax.annotate(slct_bls[0], xy=(0.05, 0.05), xycoords='axes fraction',
            bbox=dict(boxstyle='round', facecolor='white'), size=12)
ax.annotate('Freq: {:.2f} MHz, LST: {:.3f}'.format(freq_mhz, lst_hr),
            xy=(0.05, 0.95), xycoords='axes fraction',
            bbox=dict(boxstyle='round', facecolor='white'), size=12)

ax.set_xlabel(re_label)
ax.set_ylabel(im_label)

ax.legend(loc='lower right', prop={'size': 10})
fig.tight_layout()
plt.show()

In [None]:
res_dir = os.path.join(DATAPATH, 'loc_res')
if not os.path.exists(res_dir):
    os.mkdir(res_dir)

# results are cached in res_dir under a name derived from the input file
red_res_fn = os.path.join(res_dir, os.path.basename(sample_data).replace('.uvh5', '.res.npz'))
res_names = ('gmean_res', 'gmed_res', 'tmed_res', 'mmed_res', 'bmed_res', 'hmean_res')

if os.path.exists(red_res_fn):
    red_res = np.load(red_res_fn)
    gmean_res, gmed_res, tmed_res, mmed_res, bmed_res, hmean_res = (red_res[name] for name in res_names)
else:
    # analyse only the first time integration that is not entirely NaN
    time_int = np.where(~np.isnan(data).all(axis=(0, 2)))[0][0]

    # one complex estimate per (frequency, redundant group)
    gmean_res = np.empty((hd.Nfreqs, len(reds)), dtype=complex)
    gmed_res, tmed_res, mmed_res, bmed_res, hmean_res = [np.empty_like(gmean_res) for _ in range(5)]
    res_arrs = (gmean_res, gmed_res, tmed_res, mmed_res, bmed_res, hmean_res)

    # each geometric median seeds the next solve; the seed is carried across groups too
    gmed_bf_init = None
    for bl, bl_grp in enumerate(reds):
        slct_bl_idxs = np.array([bl_dict[red_bl] for red_bl in bl_grp])
        for f, frow in enumerate(data[:, time_int, slct_bl_idxs]):
            if np.isnan(frow).all():
                ests = [np.nan + 1j*np.nan] * len(res_arrs)
            else:
                gmean_bf = geometric_mean(flt_nan(frow))
                gmed_bf = geometric_median(frow, init_guess=gmed_bf_init, options=dict(keep_res=True))
                gmed_bf_init = gmed_bf
                ests = [gmean_bf, gmed_bf, tukey_median(frow)['barycenter'],
                        c_mardia_median(frow, init_guess=None), marg_med(frow),
                        rsc_mean(frow, sigma=mad_sigma)]
            for res_arr, est in zip(res_arrs, ests):
                res_arr[f, bl] = est

    np.savez(red_res_fn, **dict(zip(res_names, res_arrs)))

# (estimator name, per-(freq, group) results) pairs
med_est_res = list(zip([i[1] for i in med_ests],
                       [gmean_res, gmed_res, tmed_res, mmed_res, bmed_res, hmean_res]))

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(10, 20), dpi=100)
spec = gridspec.GridSpec(nrows=2*len(med_ests), figure=fig, ncols=2)

# per estimator: one tall Cartesian panel (left) and stacked |V| / phase panels (right)
axes = []
for i in range(len(med_ests)):
    ax1 = fig.add_subplot(spec[i*2:2+i*2, 0])
    ax2 = fig.add_subplot(spec[i*2, 1])
    ax3 = fig.add_subplot(spec[i*2+1, 1])
    axes.append([ax1, ax2, ax3])

for m, (est_name, est_res) in enumerate(med_est_res):
    ax_cart, ax_abs, ax_phase = axes[m]
    for i, red_grp in enumerate(reds):
        bl_label = '{}'.format(red_grp[0])
        # Let the Cartesian axis take the next colour from its own property cycle, then
        # reuse it explicitly for the other components of the same group. Explicit colours
        # do not advance the cycle, so groups get successive default colours. This replaces
        # the private axes._get_lines.prop_cycler, which was removed in matplotlib 3.8.
        line, = ax_cart.plot(hd.freqs, est_res[:, i].real, label=bl_label + re_label)
        c = line.get_color()
        ax_cart.plot(hd.freqs, est_res[:, i].imag, color=c, label=bl_label + im_label, ls='--')
        ax_abs.plot(hd.freqs, np.abs(est_res[:, i]), color=c, label=bl_label)
        ax_phase.plot(hd.freqs, np.angle(est_res[:, i]), color=c, label=bl_label, ls='--')
    # estimator name, drawn once per row (it was previously redrawn for every group)
    ax_cart.text(x=0.05, y=0.5, s=est_name, transform=ax_cart.transAxes,
                 fontsize=10, style='normal', weight='light')

for ax in axes:
    ax[0].set_ylabel(r'$V$')
    ax[1].set_ylabel(r'$|V|$')
    ax[2].set_ylabel(r'$\varphi$')

for ax in axes[-1]:
    ax.set_xlabel(r'$\nu$')

axes[0][0].set_title('Cartesian')
axes[0][1].set_title('Polar')

for ax in axes[0]:
    ax.legend(framealpha=0.5, loc=1)

plt.suptitle('Median estimates for 14-m EW baselines')
plt.show()

### LST averaging

In [None]:
# Multi-day ("xd") visibility file. Alternative input, an LST-binned file without averaging:
#   os.path.join(DATAPATH, 'lstb_no_avg/idr2_lstb_14m_ee_1.40949.npz')
xd_vis_file = os.path.join(DATAPATH, 'xd_vis_extd_rph.npz')
sample_xd_data = np.load(xd_vis_file)

In [None]:
xd_data = sample_xd_data['data']
xd_redg = sample_xd_data['redg']
xd_pol = sample_xd_data['pol'].item()
xd_rad_lsts = sample_xd_data['lsts']
xd_hr_lsts = xd_rad_lsts*12/np.pi # in hours
JDs = sample_xd_data['JDs']
no_days = JDs.size

# files from the lstb_no_avg directory use a different layout (see dimensions below)
lstb_format = 'lstb_no_avg' in xd_vis_file

if lstb_format:
    band_1 = [175, 334]
    band_2 = [515, 694]

    band_i = band_2 # select band here
    chans = np.arange(band_i[0], band_i[1]+1)
    plt_chans = chans

    # data dimensions (2xdays, freqs, times, bls)
    xd_data = xd_data[:, chans, ...]

    xd_flags = np.isnan(xd_data)
    no_chans = xd_data.shape[1]
    freqs = np.linspace(1e8, 2e8, 1025)[:-1][chans]
    new_no_tints = xd_data.shape[2]
    # no aggregation over consecutive time bins for this format; note that this overrides
    # the no_bins_agg parameter for all following cells
    no_bins_agg = 1
    avg_hr_lsts = xd_hr_lsts

else:
    # data dimensions (days, freqs, times, bls)
    xd_flags = sample_xd_data['flags']
    xd_data[xd_flags] *= np.nan

    # no_bins_agg comes from the parameters cell and is NOT reassigned here. The averaged
    # LSTs, the number of output time bins and the rephasing below must all use the same
    # value (it used to be reset to 2 after new_no_tints had already been computed).
    xd_times = sample_xd_data['times']
    avg_hr_lsts = np.mean(xd_hr_lsts.reshape(-1, no_bins_agg), axis=1)

    freqs = sample_xd_data['freqs']
    chans = sample_xd_data['chans']
    if chans[-1]%100 == 99:
        plt_chans = np.append(chans, chans[-1]+1)
    else:
        plt_chans = chans

    no_chans = chans.size
    no_tints = xd_times.size
    new_no_tints = int(np.ceil(no_tints/no_bins_agg))

    # rephase if averaging over consecutive time bins
    if 'rph' in os.path.basename(xd_vis_file) and no_bins_agg > 1:
        print('Rephasing visibilities such that every {} rows in time have the same phase centre.'.format(no_bins_agg))
        xd_antpos = np.load(xd_vis_file, allow_pickle=True)['antpos'].item()
        xd_data = agg_tint_rephase(xd_data, xd_redg, freqs, xd_pol, xd_rad_lsts, xd_antpos,
                                   no_bins_agg=no_bins_agg)

In [None]:
def myround(x, base=25):
    """Round x to the nearest multiple of base, but never below base itself."""
    return base * max(1, round(x/base))


# heatmap tick spacings, aiming for roughly 5 ticks along time (x) and frequency (y)
tbase = myround(new_no_tints/5, base=10)
fbase = myround(no_chans/5)

In [None]:
bl_grp = 0 # only look at 0th baseline group

# column 0 of xd_redg is the redundant-group id; the remaining columns identify the baseline
slct_bl_idxs = np.flatnonzero(xd_redg[:, 0] == bl_grp)
flags = xd_flags[..., slct_bl_idxs]
slct_red_bl = xd_redg[slct_bl_idxs[0], 1:]
xd_data_bls = xd_data[..., slct_bl_idxs]
no_bls = len(slct_bl_idxs)
print("Looking at baselines redundant to ({}, {}, '{}')".format(*slct_red_bl, xd_pol))

In [None]:
# just pick the first four baselines from the selected baseline group
# (fewer if the group is smaller; a fixed 4 raised an IndexError in that case)
slct_no_bls = min(4, no_bls)

# NOTE(review): the cache name does not encode no_bins_agg / mad_sigma / slct_no_bls;
# delete the file after changing them.
lst_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.lst_res.npz'))

if not os.path.exists(lst_res_fn):

    def freq_iter(freq):
        """Per-baseline location estimates for one frequency channel.

        For each of the first `slct_no_bls` baselines and each output time bin, pools the
        visibilities over the first data axis and, unless lstb_format, over `no_bins_agg`
        consecutive integrations. From the pooled sample it computes the geometric median,
        Tukey median, marginal median and HERA mean (rsc_mean).

        Returns a tuple of 4 complex arrays, each of shape (1, new_no_tints, slct_no_bls).
        """
        xd_gmed_res_t_f = np.empty((1, new_no_tints, slct_no_bls), dtype=complex)
        xd_tmed_res_t_f, xd_bmed_res_t_f, xd_hmean_res_t_f = \
            [np.empty_like(xd_gmed_res_t_f) for _ in range(3)]

        gmed_ft_init = None  # warm start for successive geometric-median solves
        for bl in range(slct_no_bls):
            for tint in range(new_no_tints):
                if lstb_format:
                    xd_data_bft = xd_data_bls[:, freq, tint, bl].flatten()
                else:
                    # use no_bins_agg time integrations for each median
                    # (2 consecutive ones are used in HERA LST-binning)
                    xd_data_bft = xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg,
                                              bl].flatten()

                if np.isnan(xd_data_bft).all():
                    gmed_ft = tmed_ft = bmed_ft = hmean_ft = np.nan + 1j*np.nan
                else:
                    gmed_ft = geometric_median(xd_data_bft, init_guess=gmed_ft_init,
                                               options=dict(keep_res=True))
                    gmed_ft_init = gmed_ft
                    tmed_ft = tukey_median(xd_data_bft)['barycenter']
                    bmed_ft = marg_med(xd_data_bft)
                    hmean_ft = rsc_mean(xd_data_bft, sigma=mad_sigma)

                xd_gmed_res_t_f[:, tint, bl] = gmed_ft
                xd_tmed_res_t_f[:, tint, bl] = tmed_ft
                xd_bmed_res_t_f[:, tint, bl] = bmed_ft
                xd_hmean_res_t_f[:, tint, bl] = hmean_ft

        return xd_gmed_res_t_f, xd_tmed_res_t_f, xd_bmed_res_t_f, xd_hmean_res_t_f

    if mp:
        m_pool = multiprocessing.Pool(multiprocessing.cpu_count())
        pool_res = m_pool.map(freq_iter, range(no_chans))
        m_pool.close()
        m_pool.join()
    else:
        pool_res = list(map(freq_iter, range(no_chans)))

    # each per-channel tuple stacks to (4, 1, T, B); concatenating over channels -> (4, F, T, B)
    loc_res = np.concatenate(pool_res, axis=1)
    xd_gmed_res_t = loc_res[0, ...]
    xd_tmed_res_t = loc_res[1, ...]
    xd_bmed_res_t = loc_res[2, ...]
    xd_hmean_res_t = loc_res[3, ...]

    np.savez(lst_res_fn, xd_gmed_res_t=xd_gmed_res_t, xd_tmed_res_t=xd_tmed_res_t,
             xd_bmed_res_t=xd_bmed_res_t, xd_hmean_res_t=xd_hmean_res_t)

else:
    lst_res = np.load(lst_res_fn)
    xd_gmed_res_t = lst_res['xd_gmed_res_t']
    xd_tmed_res_t = lst_res['xd_tmed_res_t']
    xd_bmed_res_t = lst_res['xd_bmed_res_t']
    xd_hmean_res_t = lst_res['xd_hmean_res_t']

In [None]:
arrs = [xd_gmed_res_t, xd_tmed_res_t, xd_bmed_res_t, xd_hmean_res_t]
# drop baselines (last axis) whose results are NaN for every frequency and time bin
flt_arrs = [arr[..., ~np.isnan(arr).all(axis=(0, 1))] for arr in arrs]

In [None]:
grid_arrs = [[arr[..., i] for i in range(flt_arrs[0].shape[-1])] for arr in flt_arrs]
titles = ['Geometric Median', 'Tukey Median', 'Marginal Median', 'HERA Mean']

# Row labels: the baselines that were actually LST-averaged. They are taken from this
# dataset's redundancy table xd_redg (column 0 = group, remaining columns = baseline, as
# used for slct_red_bl), not from `reds`, which belongs to the single-file sample loaded
# earlier. Baselines dropped as all-NaN from the geometric-median results (flt_arrs[0])
# are skipped so the labels stay aligned with the plotted rows.
plotted_bl_idxs = slct_bl_idxs[:slct_no_bls]
plotted_bl_idxs = plotted_bl_idxs[~np.isnan(xd_gmed_res_t).all(axis=(0, 1))]
ylabels = [str(tuple(xd_redg[idx, 1:].tolist()) + (str(xd_pol),)) + '\n\nFrequency Channel'
           for idx in plotted_bl_idxs]

grid_heatmaps(grid_arrs, apply_np_fn='abs', titles=titles, xbase=tbase, ybase=fbase,
              xlabels='Time bin', ylabels=ylabels, clip_pctile=1, yticklabels=plt_chans,
              figsize=(12, 10))

In [None]:
# phase of every estimator / baseline result (no clip_pctile, unlike the other components)
grid_heatmaps(grid_arrs, apply_np_fn='angle', titles=titles,
              xlabels='Time bin', ylabels=ylabels,
              xbase=tbase, ybase=fbase, yticklabels=plt_chans,
              figsize=(12, 10))

In [None]:
# real part of every estimator / baseline result
grid_heatmaps(grid_arrs, apply_np_fn='real', titles=titles,
              xlabels='Time bin', ylabels=ylabels,
              xbase=tbase, ybase=fbase, yticklabels=plt_chans,
              clip_pctile=1, figsize=(12, 10))

In [None]:
# imaginary part of every estimator / baseline result
grid_heatmaps(grid_arrs, apply_np_fn='imag', titles=titles,
              xlabels='Time bin', ylabels=ylabels,
              xbase=tbase, ybase=fbase, yticklabels=plt_chans,
              clip_pctile=1, figsize=(12, 10))

### LST + redundant averaging

In [None]:
# Look at no_bins_agg consecutive time integrations / 1 frequency slice with high variance:
# std over days and baselines for every (freq, time) from which a full window of
# no_bins_agg integrations can start.
xd_std = np.nanstd(xd_data_bls[..., :xd_data_bls.shape[2]-no_bins_agg+1, :], axis=(0, -1))
# Unravel with the shape of the array the argmax was taken over. Its time axis is
# no_bins_agg-1 shorter than the data's, so using the data shape gave a wrong time index.
idxs = np.asarray(np.unravel_index(np.nanargmax(xd_std), xd_std.shape))

# The slice below was hard-coded in the original analysis and overrides the search above;
# set idxs_override = None to use the highest-variance slice instead.
idxs_override = [0, 57]
if idxs_override is not None:
    idxs = list(idxs_override)

# align the window start with the no_bins_agg time-bin boundaries
t_adj = idxs[1] % no_bins_agg
idxs[1] -= t_adj

print('Selecting freq / time slice: ({}, {}-{})'.format(idxs[0], idxs[1], idxs[1]+no_bins_agg-1))

# Have visibilities across days for the same baseline (no_bins_agg time bins):
# flatten the data array and perform statistics on the whole dataset
if lstb_format:
    data_slice = xd_data_bls[:, idxs[0], idxs[1], :].flatten()
else:
    data_slice = xd_data_bls[:, idxs[0], idxs[1]:idxs[1]+no_bins_agg, :].flatten()

xd_sample_gmean = geometric_mean(flt_nan(data_slice))
xd_sample_gmed = geometric_median(data_slice, options=dict(keep_res=True))
xd_sample_tmed = tukey_median(data_slice)['barycenter']
xd_sample_mmed = c_mardia_median(data_slice)
xd_sample_bmed = marg_med(data_slice)
xd_sample_hmean = rsc_mean(data_slice, sigma=mad_sigma)

Alternatively, we could take the median of the visibility amplitude and the Mardia median of the phase. While this is an improvement on taking medians of the Cartesian components separately, it still does not treat the complex data as a whole. The geometric median or the Tukey median would be preferable methods.

In [None]:
# (estimate, estimator name, matplotlib marker format) triples for the plots below
med_ests = list(zip([xd_sample_gmean, xd_sample_gmed, xd_sample_tmed, xd_sample_mmed,
                     xd_sample_bmed, xd_sample_hmean],
               ['Geometric Mean', 'Geometric Median', 'Tukey Median', 'Mardia Median',
                'Marginal Median', 'HERA Mean'],
               ['co', 'ro', 'yo', 'ko', 'bo', 'go']))
for est, name, _ in med_ests:
    # '.4f' = 4 decimal places; the previous '4f' only set a minimum field width of 4
    print('{:17s}: {:.4f}'.format(name, est))

In [None]:
fig, ax = plt.subplots(figsize=(6, 6))

valid_slice = flt_nan(data_slice)
ax.scatter(valid_slice.real, valid_slice.imag, alpha=0.5)
for est, name, fmt in med_ests:
    ax.plot(est.real, est.imag, fmt, label=name)

# zoomed-in inset showing only the estimates
axins = zoomed_inset_axes(ax, zoom=6, loc=4)
for est, _, fmt in med_ests:
    axins.plot(est.real, est.imag, fmt)

# inset limits: integer-rounded box around all estimates except the last two
# (Marginal Median and HERA Mean)
core_ests = [est for est, _, _ in med_ests[:-2]]
x1 = np.floor(np.min([est.real for est in core_ests]))
x2 = np.ceil(np.max([est.real for est in core_ests]))
y1 = np.floor(np.min([est.imag for est in core_ests]))
y2 = np.ceil(np.max([est.imag for est in core_ests]))
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

axins.tick_params(axis='x', direction='in', pad=-15)
mark_inset(ax, axins, loc1=1, loc2=3, fc='none', ec='0.5')
axins.patch.set_alpha(0.5)

mean_lst = np.mean(xd_hr_lsts[[idxs[1], idxs[1]+no_bins_agg-1]])
ax.annotate(tuple(slct_red_bl) + (str(xd_pol),), xy=(0.03, 0.04),
            xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'))
ax.annotate('Chan: {}, LST: {:.3f}'.format(chans[idxs[0]], mean_lst), xy=(0.03, 0.95),
            xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'))

ax.set_xlabel(re_label)
ax.set_ylabel(im_label)
plot_title = ('Bivariate location estimators for redundant '
              'visibilities aggregated across JDs')
ax.set_title(textwrap.fill(plot_title, 60))

ax.legend(loc=1, prop={'size': 8})
plt.show()

In [None]:
# density of the pooled visibilities with the location estimates overlaid
valid_slice = flt_nan(data_slice)
g = sns.jointplot(x=valid_slice.real, y=valid_slice.imag,
                  kind='kde', height=8, cmap='Blues', fill=True, space=0)
g.set_axis_labels(re_label, im_label, size=14)
for est, name, fmt in med_ests:
    g.ax_joint.plot(est.real, est.imag, fmt, label=name)
legend_properties = {'size': 10}
g.ax_joint.legend(prop=legend_properties, loc='upper right')

mean_lst = np.mean(xd_hr_lsts[[idxs[1], idxs[1]+no_bins_agg-1]])
g.ax_joint.annotate(tuple(slct_red_bl) + (str(xd_pol),), xy=(0.03, 0.04),
                    xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'),
                    size=12)
g.ax_joint.annotate('Chan: {}, LST: {:.3f}'.format(chans[idxs[0]], mean_lst),
                    xy=(0.03, 0.95), xycoords='axes fraction',
                    bbox=dict(boxstyle='round', facecolor='white'), size=12)
plt.tight_layout()
# to export: plt.savefig(os.path.join(<figure dir>, 'density_big_diff_geo_hera.pdf'), bbox_inches='tight')
plt.show()

In [None]:
# LST + redundant averaging: one estimate per (frequency, output time bin), pooling every
# baseline of the selected group. Results are cached to lst_red_res_fn.
# NOTE(review): the cache name does not encode no_bins_agg / mad_sigma; delete the file
# after changing them.
lst_red_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.lst_red_res.npz'))

if not os.path.exists(lst_red_res_fn):

    def freq_iter(freq):
        """Location estimates for one frequency channel, pooled over baselines.

        For each output time bin, pools xd_data_bls over its first axis, all baselines of
        the group and (unless lstb_format) no_bins_agg consecutive integrations. It then
        computes the geometric median, Tukey median, marginal median and HERA mean
        (rsc_mean). All-NaN samples give complex NaN.

        Returns a tuple of 4 complex arrays, each of shape (1, new_no_tints).
        """
        xd_gmed_res_f = np.empty((1, new_no_tints), dtype=complex)
        xd_tmed_res_f, xd_bmed_res_f, xd_hmean_res_f = [np.empty_like(xd_gmed_res_f) for \
                                                        _ in range(3)]

        # previous geometric median seeds the next one (warm start along time)
        gmed_ft_init = None
        for tint in range(new_no_tints):
            if lstb_format:
                xd_data_bft = xd_data_bls[:, freq, tint, :].flatten()
            else:
                xd_data_bft = xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, \
                                          :].flatten()

            if np.isnan(xd_data_bft).all():
                gmed_ft = tmed_ft = bmed_ft = hmean_ft = np.nan + 1j*np.nan
            else:
                gmed_ft = geometric_median(xd_data_bft, init_guess=gmed_ft_init, \
                                           options=dict(keep_res=True))
                gmed_ft_init = gmed_ft
                tmed_ft = tukey_median(xd_data_bft)['barycenter']
                bmed_ft = marg_med(xd_data_bft)
                hmean_ft = rsc_mean(xd_data_bft, sigma=mad_sigma)

            xd_gmed_res_f[:, tint] = gmed_ft
            xd_tmed_res_f[:, tint] = tmed_ft
            xd_bmed_res_f[:, tint] = bmed_ft
            xd_hmean_res_f[:, tint] = hmean_ft

        return xd_gmed_res_f, xd_tmed_res_f, xd_bmed_res_f, xd_hmean_res_f

    # one task per frequency channel, optionally spread over all CPUs
    if mp:
        m_pool = multiprocessing.Pool(multiprocessing.cpu_count())
        pool_res = m_pool.map(freq_iter, range(no_chans))
        m_pool.close()
        m_pool.join()
    else:
        pool_res = list(map(freq_iter, range(no_chans)))

    # each per-channel tuple stacks to (4, 1, T); concatenating over channels -> (4, F, T),
    # so every xd_*_res below has shape (no_chans, new_no_tints)
    loc_res = np.concatenate(pool_res, axis=1)
    xd_gmed_res = loc_res[0, ...]
    xd_tmed_res = loc_res[1, ...]
    xd_bmed_res = loc_res[2, ...]
    xd_hmean_res = loc_res[3, ...]

    np.savez(lst_red_res_fn, xd_gmed_res=xd_gmed_res, xd_tmed_res=xd_tmed_res, \
             xd_bmed_res=xd_bmed_res, xd_hmean_res=xd_hmean_res)

else:
    # reuse cached results from a previous run
    lst_red_res = np.load(lst_red_res_fn)
    xd_gmed_res = lst_red_res['xd_gmed_res']
    xd_tmed_res = lst_red_res['xd_tmed_res']
    xd_bmed_res = lst_red_res['xd_bmed_res']
    xd_hmean_res = lst_red_res['xd_hmean_res']

In [None]:
arrs = [xd_gmed_res, xd_tmed_res, xd_bmed_res, xd_hmean_res]


def tr_arrs(x, np_fn):
    """Apply the numpy function named np_fn (e.g. 'abs') to every array in x."""
    return [getattr(np, np_fn)(i) for i in x]


# garrs[i] = [|V|, phase, Re, Im] of estimator i
np_fns = ['abs', 'angle', 'real', 'imag']
per_fn = [tr_arrs(arrs, fn) for fn in np_fns]
garrs = [list(est_parts) for est_parts in zip(*per_fn)]

titles = ['Geometric Median', 'Tukey Median', 'Marginal Median', 'HERA Mean']
ylabels = [ylab + '\n\nFrequency Channel'
           for ylab in (r'$|V|$', r'$\varphi$', re_label, im_label)]

grid_heatmaps(garrs, titles=titles, figsize=(12, 10), xlabels='Time Bin', ylabels=ylabels,
              xbase=tbase, ybase=fbase, yticklabels=plt_chans, clip_pctile=0)

In [None]:
# 4x4 summary figure of the LST + redundantly averaged results (LST-binned input only).
# Columns are the location estimators (titles); rows are the components of garrs:
# |V|, phase, Re and Im.
# NOTE(review): the hard-coded y-ticks below match band_2 (channels 515-694) and the
# x-ticks assume about 50 time bins; confirm before using another band or file.
if lstb_format:

    from mpl_toolkits.axes_grid1 import AxesGrid

    fig = plt.figure(figsize=(10, 8), dpi=600)

    # cbar_mode='edge' with cbar_location='right': one colour bar per row (cbar_axes[row])
    grid = AxesGrid(fig, 111, nrows_ncols=(4, 4), axes_pad=0.1, share_all=True, cbar_location='right', \
                    cbar_mode='edge', cbar_size=0.1, cbar_pad=0.15, direction='row', aspect=False)

    ylabels = [r'$|V|$', r'$\varphi$', re_label, im_label]

    for col, arr in enumerate(garrs):
        for row, a in enumerate(arr):
            # row-major position in the grid (direction='row')
            idx = row*4 + col
            ax = grid[idx]

            # phase row: diverging colormap with fixed [-pi, pi] limits
            if row == 1:
                cmap = 'PiYG'
                vmin = -np.pi
                vmax = np.pi
            else:
                cmap = 'viridis'
                vmin = None
                vmax = None

            im = ax.pcolormesh(np.arange(xd_data.shape[2]), plt_chans, a, rasterized=True, 
                vmin=vmin, vmax=vmax, cmap=cmap)

            # attach the row's colour bar once, from the first column
            if idx % 4 == 0:
                grid.cbar_axes[row].colorbar(im, label=ylabels[row])

            if row == 0:
                ax.set_title(titles[col])

            if row == 3:
                ax.set_xlabel('Time Bin')
                ax.set_xticks([0, 10, 20, 30, 40, 50])

            if col == 0:
                ax.set_ylabel('Frequency Channel')
                ax.set_yticks([520, 560, 600, 640, 680])

    # called on the last panel only; with share_all=True this presumably flips every panel
    ax.invert_yaxis()

    plt.show()

#### Smoothness of median results

Calculate the standard deviation of the distances between successive points, in either frequency or time, to gauge the smoothness of the location estimates.

##### Standard deviation of absolute distances

In [None]:
# in time: per frequency channel, the std of |step| between consecutive time bins,
# then averaged (NaN-ignoring) over channels
t_smoothness = [
    np.nanmean([np.nanstd(np.abs(np.ediff1d(chan_row))) for chan_row in arr])
    for arr in arrs
]
print('Smoothness in time: \n{}\n{}\n'.format(titles, t_smoothness))

# in frequency: the same, stepping along frequency for every time bin
f_smoothness = [
    np.nanmean([np.nanstd(np.abs(np.ediff1d(tint_col))) for tint_col in arr.T])
    for arr in arrs
]
print('Smoothness in frequency: \n{}\n{}'.format(titles, f_smoothness))

##### Standard deviation of complex differences

In [None]:
# in time: per frequency channel, the std of the complex steps between consecutive
# time bins, then averaged (NaN-ignoring) over channels
t_smoothness = [
    np.nanmean([np.nanstd(np.ediff1d(chan_row)) for chan_row in arr])
    for arr in arrs
]
print('Smoothness in time: \n{}\n{}\n'.format(titles, t_smoothness))

# in frequency: the same, stepping along frequency for every time bin
f_smoothness = [
    np.nanmean([np.nanstd(np.ediff1d(tint_col)) for tint_col in arr.T])
    for arr in arrs
]
print('Smoothness in frequency: \n{}\n{}'.format(titles, f_smoothness))

#### Biggest difference between the geometric median and the HERA mean

In [None]:
# Fraction of flagged samples per (freq, output time bin): averaged over the first and last
# axes of `flags`, then over the no_bins_agg raw integrations making up each time bin.
# Keep only slices where fewer than 50% of the samples are flagged.
ok_slice = (flags.mean(axis=(0, 3))
                 .reshape(no_chans, new_no_tints, -1)
                 .mean(axis=-1)
            < 0.5)

In [None]:
# among well-populated slices, find where the geometric median and HERA mean differ most
gm_hm_diff = np.abs(xd_gmed_res[ok_slice] - xd_hmean_res[ok_slice])
ok_argmax = np.nanargmax(gm_hm_diff)
ok_slice_idxs = np.where(ok_slice)
bd_idx = (ok_slice_idxs[0][ok_argmax], ok_slice_idxs[1][ok_argmax])

print('Frequency/time slice {} shows strong deviation between the geometric median and the '
      'HERA mean.'.format(bd_idx))

# pooled visibilities of that slice (a window of no_bins_agg integrations unless lstb_format)
if lstb_format:
    bd_tsel = bd_idx[1]
else:
    bd_tsel = slice(no_bins_agg*bd_idx[1], no_bins_agg*bd_idx[1]+no_bins_agg)
bd_data = xd_data_bls[:, bd_idx[0], bd_tsel, :].flatten()

bd_med_ests = [(xd_gmed_res[bd_idx], 'Geometric Median', 'ro'),
               (xd_hmean_res[bd_idx], 'HERA Mean', 'go')]

valid_bd = flt_nan(bd_data)
g = sns.jointplot(x=valid_bd.real, y=valid_bd.imag,
                  kind='kde', height=8, cmap='Blues', fill=True, space=0)
g.set_axis_labels(re_label, im_label, size=14)
for est, name, fmt in bd_med_ests:
    g.ax_joint.plot(est.real, est.imag, fmt, label=name)
legend_properties = {'size': 10}
g.ax_joint.legend(prop=legend_properties, loc='upper right')
g.ax_joint.annotate(tuple(slct_red_bl) + (str(xd_pol),), xy=(0.05, 0.05),
                    xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'),
                    size=12)
g.ax_joint.annotate('Chan: {}, Tint: {}'.format(chans[bd_idx[0]], bd_idx[1]),
                    xy=(0.05, 0.95), xycoords='axes fraction',
                    bbox=dict(boxstyle='round', facecolor='white'), size=12)
plt.tight_layout()
plt.show()

#### More density plots

To help visualize the data.

In [None]:
# Pick good (low flagging) data slices to plot as examples of redundant visibility
# distributions. A seeded generator makes the same slices come up on every re-run
# (the unseeded np.random.choice gave different figures each time).
rng = np.random.default_rng(seed=0)
ok_idxs = np.where(ok_slice)
index = rng.choice(ok_idxs[0].shape[0], 2, replace=False)
slice1 = ok_idxs[0][index[0]], ok_idxs[1][index[0]]
slice2 = ok_idxs[0][index[1]], ok_idxs[1][index[1]]

print('Random slices: {} & {}'.format(slice1, slice2))

for rnd_slice in (slice1, slice2):

    if lstb_format:
        bd_data = xd_data_bls[:, rnd_slice[0], rnd_slice[1], :].flatten()
    else:
        bd_data = xd_data_bls[:, rnd_slice[0], no_bins_agg*rnd_slice[1]:no_bins_agg*rnd_slice[1]
                              +no_bins_agg, :].flatten()

    rnd_med_ests = list(zip([xd_gmed_res[rnd_slice], xd_hmean_res[rnd_slice]],
                            ['Geometric Median', 'HERA Mean'],
                            ['ro', 'go']))

    g = sns.jointplot(x=flt_nan(bd_data).real, y=flt_nan(bd_data).imag,
        kind='kde', height=8, cmap='Blues', fill=True, space=0)
    g.set_axis_labels(re_label, im_label, size=14)
    for med_est in rnd_med_ests:
        g.ax_joint.plot(med_est[0].real, med_est[0].imag, med_est[2], label=med_est[1])
    legend_properties = {'size': 10}
    g.ax_joint.legend(prop=legend_properties, loc='upper right')
    g.ax_joint.annotate(tuple(slct_red_bl) + (str(xd_pol),), xy=(0.05, 0.05),
        xycoords='axes fraction', bbox= dict(boxstyle='round', facecolor='white'),
        size=12)
    g.ax_joint.annotate('Chan: {}, Tint: {}'.format(chans[rnd_slice[0]], rnd_slice[1]),
        xy=(0.05, 0.95), xycoords='axes fraction',
        bbox= dict(boxstyle='round', facecolor='white'), size=12)
    plt.tight_layout()
    plt.show()

In [None]:
# KDE plots across the H1C_IDR2 JDs for the same baseline group: one jointplot per JD at a
# single (freq, time) slice, assembled afterwards into one no_rows x no_cols figure
no_cols = 4
# equivalent to ceil(JDs.size / no_cols)
no_rows = int(no_cols*np.ceil(JDs.size/no_cols) / no_cols)

# select indices that have the lowest number of flags:
# NaN count summed over the first and last axes, then over each group of no_bins_agg
# consecutive time integrations (the reshape requires the time axis length to be a
# multiple of no_bins_agg)
ft_flag_no = np.isnan(xd_data_bls).sum((0, 3)).reshape((no_chans, -1, no_bins_agg)).sum(axis=-1)
slct_ft_idxs = np.array(np.unravel_index(np.argmin(ft_flag_no), ft_flag_no.shape))
# convert the aggregated time-bin index back to the index of its first raw integration
slct_ft_idxs[1] *= no_bins_agg
slct_ft_idxs = tuple(slct_ft_idxs)

# Shared axis limits for every panel: the (100-pctc)th to pctc-th percentiles of the pooled
# data at the selected slice, widened by `pad` (only the lower bound is also floored).
pctc = 99
pad = 5

if lstb_format:
    clip_data = xd_data_bls[:, slct_ft_idxs[0], slct_ft_idxs[1], :].flatten()
else:
    clip_data = xd_data_bls[:, slct_ft_idxs[0], slct_ft_idxs[1]:slct_ft_idxs[1]+no_bins_agg, :]

re_lim = (np.floor(np.nanpercentile(clip_data.real, 100-pctc)) - pad, \
          np.nanpercentile(clip_data.real, pctc) + pad)
im_lim = (np.floor(np.nanpercentile(clip_data.imag, 100-pctc)) - pad, \
          np.nanpercentile(clip_data.imag, pctc) + pad)

gplots = []
count = 0  # index along the first data axis, i.e. the JD being plotted
lcount = 0  # panels with estimator markers so far; only the first one gets a legend
for row in range(no_rows):
    for col in range(no_cols):
        if (row*no_cols)+col <= JDs.size-1:
            # NOTE(review): in lstb_format the first data axis is documented as 2x days,
            # so data index `count` may not correspond to JDs[count] there -- confirm.
            if lstb_format:
                jd_data = xd_data_bls[count, slct_ft_idxs[0], slct_ft_idxs[1], :].flatten()
            else:
                jd_data = xd_data_bls[count, slct_ft_idxs[0], slct_ft_idxs[1]:slct_ft_idxs[1]\
                                      +no_bins_agg, :].flatten()
            g = sns.jointplot(x=flt_nan(jd_data).real, y=flt_nan(jd_data).imag, \
                              kind='kde', height=8, cmap='Blues', fill=True, space=0, \
                              xlim=re_lim, ylim=im_lim)

            # overlay the geometric median and HERA mean of this JD's data (if any is unflagged)
            if not np.isnan(jd_data).all():
                jd_gmed = geometric_median(jd_data)
                jd_hmean = rsc_mean(jd_data, sigma=mad_sigma)
                g.ax_joint.plot(jd_gmed.real, jd_gmed.imag, 'ro', label='Geometric Median')
                g.ax_joint.plot(jd_hmean.real, jd_hmean.imag, 'go', label='HERA Mean')
                if lcount == 0:
                    g.ax_joint.legend(prop={'size': 7}, loc='upper right')
                lcount += 1

            # baseline / slice annotations only on the first panel
            if count == 0:
                g.ax_joint.annotate(tuple(slct_red_bl) + (str(xd_pol),), xy=(0.05, 0.05), \
                    xycoords='axes fraction', bbox= dict(boxstyle='round', facecolor='white'), \
                    size=8)
                g.ax_joint.annotate('Chan: {}, LST: {:.3f}'.format(chans[slct_ft_idxs[0]], \
                    np.mean(xd_hr_lsts[[slct_ft_idxs[1], slct_ft_idxs[1]+no_bins_agg-1]])), \
                    xy=(0.05, 0.95), xycoords='axes fraction', bbox= dict(boxstyle='round', facecolor='white'), \
                    size=8)
            g.ax_joint.annotate(str(JDs[count]), xy=(0.8, 0.05), \
                xycoords='axes fraction', bbox= dict(boxstyle='round', facecolor='white'), \
                size=8)

            g.set_axis_labels(re_label, im_label, size=8, labelpad=2)

            gplots.append(g)
            count += 1
            plt.close() # suppress individual plots from showing in notebook

# combine the individual jointplots into one figure, one grid cell per JD
fig = plt.figure(figsize=(16, 20))
gs = gridspec.GridSpec(no_rows, no_cols)

for i, gplot in enumerate(gplots):
    _ = SeabornFig2Grid(gplot, fig, gs[i])

gs.tight_layout(fig)
plt.show()

### Test of normality

#### Shapiro-Wilk test

We test the aggregated visibility data (over days, redundant baselines and consecutive time integrations) for normality using the Shapiro-Wilk test, applied to the $\mathfrak{Re}$ and $\mathfrak{Im}$ components separately. If both components are Gaussian distributed, this justifies the use of the mean.

In [None]:
# Shapiro-Wilk test of the Re and Im components separately for each (frequency, time bin) slice,
# aggregating over days, redundant baselines and (when not LST-binned) no_bins_agg consecutive
# integrations. Slices with fewer than 3 unflagged samples (the minimum accepted by
# scipy.stats.shapiro) are set to nan instead of reusing the previous slice's result.
shapiro_w_re = np.empty_like(xd_gmed_res, dtype=float)
shapiro_w_im, shapiro_p_re, shapiro_p_im = [np.empty_like(shapiro_w_re) for _ in range(3)]
for freq in range(no_chans):
    for tint in range(new_no_tints):
        if lstb_format:
            xd_data_bft = flt_nan(xd_data_bls[:, freq, tint, :].flatten())
        else:
            xd_data_bft = flt_nan(xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, \
                                              :].flatten())
        if xd_data_bft.size < 3 or np.isnan(xd_data_bft).all():
            # real nan: the output arrays are float, so a complex nan cannot be stored in them
            re_shapiro_stat = im_shapiro_stat = re_shapiro_pval = im_shapiro_pval = np.nan
        else:
            re_shapiro = shapiro(xd_data_bft.real)
            im_shapiro = shapiro(xd_data_bft.imag)
            re_shapiro_stat = re_shapiro.statistic
            im_shapiro_stat = im_shapiro.statistic
            re_shapiro_pval = re_shapiro.pvalue
            im_shapiro_pval = im_shapiro.pvalue

        # write the per-slice values (nan for insufficient data), never a stale result object
        shapiro_w_re[freq, tint] = re_shapiro_stat
        shapiro_w_im[freq, tint] = im_shapiro_stat
        shapiro_p_re[freq, tint] = re_shapiro_pval
        shapiro_p_im[freq, tint] = im_shapiro_pval

In [None]:
# 2x2 heatmap grid: rows are the Re / Im components, columns the W statistic and p value
titles = [[r'$W \; \mathrm{statistic} \; - \; \mathfrak{' + comp + r'}(V)$',
           r'$p \; \mathrm{value} \; - \; \mathfrak{' + comp + r'}(V)$'] for comp in ('Re', 'Im')]
shapiro_panels = [[shapiro_w_re, shapiro_p_re], [shapiro_w_im, shapiro_p_im]]
grid_heatmaps(shapiro_panels, titles=titles, figsize=(14, 7), xbase=tbase, ybase=fbase,
              share_cbar=True, clip_pctile=1, xlabels='Time bin', yticklabels=plt_chans,
              ylabels='Frequency Channel')

In [None]:
# Example histograms of the aggregated visibility slice with the lowest Shapiro-Wilk p value for
# the Re component, i.e. the strongest evidence against normality of Re(V).
re_shap_min = np.unravel_index(np.nanargmin(shapiro_p_re), shapiro_p_re.shape)
print('Slice {} has Shapiro-Wilk test p value {:.5f} for the Re component.\n'\
      .format(re_shap_min, shapiro_p_re[re_shap_min]))

print('If the p value < the chosen alpha level (usually taken to be 0.05), then the null hypothesis '\
      'is rejected and there is evidence that the data tested are not normally distributed')

# Histogram exactly the samples that entered the test for this slice: in non-LST-binned data
# each time bin spans no_bins_agg consecutive integrations.
if lstb_format:
    hist_data = flt_nan(xd_data_bls[:, re_shap_min[0], re_shap_min[1], :].flatten())
else:
    hist_tslice = slice(no_bins_agg*re_shap_min[1], no_bins_agg*re_shap_min[1] + no_bins_agg)
    hist_data = flt_nan(xd_data_bls[:, re_shap_min[0], hist_tslice, :].flatten())

fig, ax = plt.subplots(ncols=2, figsize=(14, 7))

sns.histplot(hist_data.real, ax=ax[0], binwidth=2.5, kde=True)
sns.histplot(hist_data.imag, ax=ax[1], binwidth=2.5, kde=True)

ax[0].set_xlabel(re_label)
ax[1].set_xlabel(im_label)

plt.show()

#### Henze-Zirkler multivariate normality test

We use the HZ test as it considers the entirety of the (bivariate) data. Note that many alternative tests also exist, and that a single statistic cannot definitively establish whether multivariate data are normally distributed.

##### LST + red aggregation

In [None]:
# MAD-clip the Re and Im components independently (as in the HERA pipeline) over days (axis 0)
# and baselines (axis 3); element [1] of mad_clip's return value is the clip mask.
nan_flags = np.isnan(xd_data_bls)
clip_kwargs = dict(sigma=mad_sigma, axis=(0, 3), flags=nan_flags, verbose=True)
re_clip_f = mad_clip(xd_data_bls.real, **clip_kwargs)[1]
im_clip_f = mad_clip(xd_data_bls.imag, **clip_kwargs)[1]

# copy of the data with samples clipped in either component set to nan
clipped_either = re_clip_f + im_clip_f
xd_data_bls_c = xd_data_bls.copy()
xd_data_bls_c[clipped_either] *= np.nan

In [None]:
# Henze-Zirkler (HZ) multivariate normality test of the complex visibilities in each
# (frequency, time bin) slice, aggregated over days and redundant baselines, before (hz_*) and
# after (hz_*_c) MAD-clipping. Results are cached to an .npz file in res_dir.
mv_nrm_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.mv_nrm_res.npz'))

if not os.path.exists(mv_nrm_res_fn):

    hz_r = np.empty_like(shapiro_w_re)  # HZ statistic
    hz_p = np.empty_like(hz_r)  # p value
    hz_n = np.empty_like(hz_r, dtype=bool)  # True if the test reports normality

    hz_r_c = np.empty_like(hz_r)
    hz_p_c = np.empty_like(hz_r)
    hz_n_c = np.empty_like(hz_n)

    for freq in range(no_chans):
        for tint in range(new_no_tints):
            if lstb_format:
                xd_data_bft = flt_nan(xd_data_bls[:, freq, tint, :].flatten())
                xd_data_bcft = flt_nan(xd_data_bls_c[:, freq, tint, :].flatten())
            else:
                tslc = slice(no_bins_agg*tint, no_bins_agg*tint + no_bins_agg)
                xd_data_bft = flt_nan(xd_data_bls[:, freq, tslc, :].flatten())
                xd_data_bcft = flt_nan(xd_data_bls_c[:, freq, tslc, :].flatten())

            # The 'MVN' verdict is 'YES'/'NO' (or presumably nan when the test is not computed);
            # anything other than 'YES' counts as not normal. A dict keyed on np.nan only
            # matches that very nan object and would raise KeyError for any other nan.
            hz_res = mv_normality(xd_data_bft, method='hz')
            hz_r[freq, tint] = hz_res['HZ']
            hz_p[freq, tint] = hz_res['p value']
            hz_n[freq, tint] = hz_res['MVN'] == 'YES'

            hz_res_c = mv_normality(xd_data_bcft, method='hz')
            hz_r_c[freq, tint] = hz_res_c['HZ']
            hz_p_c[freq, tint] = hz_res_c['p value']
            hz_n_c[freq, tint] = hz_res_c['MVN'] == 'YES'

    np.savez(mv_nrm_res_fn, hz_r=hz_r, hz_p=hz_p, hz_n=hz_n, hz_r_c=hz_r_c, hz_p_c=hz_p_c, \
             hz_n_c=hz_n_c)

else:
    # context manager closes the underlying npz file once the arrays are read
    with np.load(mv_nrm_res_fn) as mv_nrm_res:
        hz_r = mv_nrm_res['hz_r']
        hz_p = mv_nrm_res['hz_p']
        hz_n = mv_nrm_res['hz_n']
        hz_r_c = mv_nrm_res['hz_r_c']
        hz_p_c = mv_nrm_res['hz_p_c']
        hz_n_c = mv_nrm_res['hz_n_c']

In [None]:
# HZ statistic, p value and normality verdict for the unclipped data
titles = [r'$HZ$ Statistic', r'$p$-value', 'Normality']
hz_plot_kw = dict(titles=titles, figsize=(14, 7), share_cbar=False, cbar_loc=None,
                  clip_pctile=1, xlabels='Time Bin', ylabel='Frequency Channel',
                  yticklabels=plt_chans, xbase=tbase, ybase=fbase)
row_heatmaps([hz_r, hz_p, hz_n], **hz_plot_kw)

In [None]:
# MAD-clipped data: same three panels as for the unclipped data
titles = [r'$HZ$ Statistic', r'$p$-value', 'Normality']
hz_plot_kw = dict(titles=titles, figsize=(14, 7), share_cbar=False, cbar_loc=None,
                  clip_pctile=1, xlabels='Time Bin', ylabel='Frequency Channel',
                  yticklabels=plt_chans, xbase=tbase, ybase=fbase)
row_heatmaps([hz_r_c, hz_p_c, hz_n_c], **hz_plot_kw)

In [None]:
# better grid plot of visibilities
# 2x3 figure of the HZ results (LST-binned data only): rows are pre- and post-MAD-clipping,
# columns are the HZ statistic, p value and the boolean normality verdict.
if lstb_format:

    from mpl_toolkits.axes_grid1 import AxesGrid
    from matplotlib.colors import LinearSegmentedColormap

    titles = [r'$HZ$ Statistic', r'$p$-value', 'Normality']

    fig = plt.figure(figsize=(8, 10), dpi=600)

    # one colour bar per column along the bottom edge (cbar_mode='edge')
    grid = AxesGrid(fig, 111, nrows_ncols=(2, 3), axes_pad=0.15, share_all=True, cbar_location='bottom', \
                    cbar_mode='edge', cbar_size=0.15, cbar_pad=0.5, direction='row', aspect=False)

    garrs = [[hz_r, hz_p, hz_n], [hz_r_c, hz_p_c, hz_n_c]]

    for row, arr in enumerate(garrs):
        for col, a in enumerate(arr):
            idx = row*3 + col
            ax = grid[idx]

            # hand-chosen colour ranges per column
            if col == 0:
                vmin = 0
                vmax = 8
                extend = 'max'
                cmap = 'RdPu'
            elif col == 1:
                vmin = 0
                vmax = 0.6
                extend = 'max'
                cmap = 'RdPu'
            elif col == 2:
                vmin = 0
                vmax = 1
                extend = None
                # two-colour map for the verdict: red for False (0), blue for True (1)
                bool_colors = ((1.0, 0.0, 0.0), (0.0, 0.0, 1.0))
                cmap = LinearSegmentedColormap.from_list('Custom', bool_colors, len(bool_colors))

                # label each row on its normality panel
                if row == 0:
                    ax.annotate('pre-MAD', xy=(0.65, 0.94), xycoords='axes fraction', \
                        bbox=dict(boxstyle='round', facecolor='white', alpha=1))
                elif row == 1:
                    ax.annotate('post-MAD', xy=(0.62, 0.94), xycoords='axes fraction', \
                        bbox=dict(boxstyle='round', facecolor='white', alpha=1))

            # NOTE(review): x coordinates are taken from xd_data.shape[2], assumed to equal the
            # number of time bins of the HZ arrays for LST-binned data; confirm
            im = ax.pcolormesh(np.arange(xd_data.shape[2]), plt_chans, a, rasterized=True, 
                vmin=vmin, vmax=vmax, cmap=cmap)

            if row == 0:
                ax.set_title(titles[col])

            # x labels and colour bars only on the bottom row
            if row == 1:
                ax.set_xlabel('Time Bin')
                ax.set_xticks([0, 10, 20, 30, 40, 50])
                cbar = grid.cbar_axes[col].colorbar(im, extend=extend)
                if col == 2:
                    # ticks at the centres of the two colour bins
                    cbar.set_ticks([0.25,0.75])
                    cbar.set_ticklabels(['False', 'True'])

            # y labels only on the left column; the tick values are specific to this channel range
            if col == 0:
                ax.set_ylabel('Frequency Channel')
                ax.set_yticks([520, 560, 600, 640, 680])            

    # axes are shared (share_all=True), so inverting one flips the y axis of every panel
    ax.invert_yaxis()

    plt.show()

In [None]:
# Slice (frequency, time bin) with the smallest HZ p value, i.e. the strongest evidence against
# bivariate normality, shown as a joint KDE with the two location estimates overlaid.
hz_p_min = np.unravel_index(np.nanargmin(hz_p), hz_p.shape)
print('Slice {}: Chan {} & LST {:.4f} has HZ test p value {:.5f}.\n'\
      .format(hz_p_min, chans[hz_p_min[0]], avg_hr_lsts[hz_p_min[1]], hz_p[hz_p_min]))

print('If the p value < the chosen alpha level (usually taken to be 0.05), then the null hypothesis '\
      'is rejected and there is evidence that the data tested are not normally distributed')

hz_f_idx, hz_t_idx = hz_p_min
if not lstb_format:
    hz_t_start = no_bins_agg*hz_t_idx
    hz_data = flt_nan(xd_data_bls[:, hz_f_idx, hz_t_start:hz_t_start+no_bins_agg, :].flatten())
else:
    hz_data = flt_nan(xd_data_bls[:, hz_f_idx, hz_t_idx, :].flatten())

# (estimate, label, matplotlib format) for each location estimator
bhz_med_ests = [(xd_gmed_res[hz_p_min], 'Geometric Median', 'ro'),
                (xd_hmean_res[hz_p_min], 'HERA Mean', 'go')]

g = sns.jointplot(x=hz_data.real, y=hz_data.imag, kind='kde', height=8, cmap='Blues',
                  fill=True, space=0)
for est_value, est_label, est_fmt in bhz_med_ests:
    g.ax_joint.plot(est_value.real, est_value.imag, est_fmt, label=est_label)
g.set_axis_labels(re_label, im_label, size=14)
g.ax_joint.legend(prop={'size': 10}, loc='upper right')
g.ax_joint.annotate(tuple(slct_red_bl) + (str(xd_pol),), xy=(0.05, 0.05),
                    xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'),
                    size=12)
g.ax_joint.annotate('Chan: {}, LST: {:.3f}'.format(chans[hz_f_idx], avg_hr_lsts[hz_t_idx]),
                    xy=(0.05, 0.95), xycoords='axes fraction',
                    bbox=dict(boxstyle='round', facecolor='white'), size=12)
plt.tight_layout()
plt.show()

##### LST

We look at the multivariate normality of data aggregated over days only - we then average this statistic over baselines to get a more complete picture of the dataset.

In [None]:
# HZ multivariate normality test per baseline: each (frequency, time bin, baseline) slice is
# aggregated over days only. Results are cached to an .npz file in res_dir.
mv_nrm_lst_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).
                                 replace('.npz', '.mv_nrm_lst_res.npz'))

if not os.path.exists(mv_nrm_lst_res_fn):

    hz_lst_r = np.empty((no_chans, new_no_tints, no_bls), dtype=float)
    hz_lst_p = np.empty_like(hz_lst_r)
    hz_lst_n = np.empty_like(hz_lst_r, dtype=bool)

    for bl in range(no_bls):
        for freq in range(no_chans):
            for tint in range(new_no_tints):
                if lstb_format:
                    xd_data_bft = flt_nan(xd_data_bls[:, freq, tint, bl].flatten())
                else:
                    xd_data_bft = flt_nan(xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, \
                                                      bl].flatten())

                hz_res = mv_normality(xd_data_bft, method='hz')
                hz_lst_r[freq, tint, bl] = hz_res['HZ']
                hz_lst_p[freq, tint, bl] = hz_res['p value']
                # only an explicit 'YES' counts as normal; 'NO' or any nan gives False (a dict
                # keyed on np.nan would raise KeyError for a different nan object)
                hz_lst_n[freq, tint, bl] = hz_res['MVN'] == 'YES'

    np.savez(mv_nrm_lst_res_fn, hz_lst_r=hz_lst_r, hz_lst_p=hz_lst_p, hz_lst_n=hz_lst_n)

else:
    # context manager closes the underlying npz file once the arrays are read
    with np.load(mv_nrm_lst_res_fn) as mv_nrm_lst_res:
        hz_lst_r = mv_nrm_lst_res['hz_lst_r']
        hz_lst_p = mv_nrm_lst_res['hz_lst_p']
        hz_lst_n = mv_nrm_lst_res['hz_lst_n']

In [None]:
titles = [r'$HZ \; \mathrm{statistic}$', r'$p \; \mathrm{value}$', 'Normality']
# mean of each per-baseline result across the trailing baseline axis
hz_lst_res = [np.nanmean(res_arr, axis=-1)
              for res_arr in (hz_lst_r, hz_lst_p, hz_lst_n.astype(float))]

row_heatmaps(hz_lst_res, titles=titles, figsize=(14, 7), share_cbar=False, cbar_loc=None,
             clip_pctile=1, xlabels='Time Bin', ylabel='Frequency Channel',
             yticklabels=plt_chans, xbase=tbase, ybase=fbase)

##### Redundant baselines

In [None]:
# HZ multivariate normality test per day: each (day, frequency, time bin) slice is aggregated
# over redundant baselines only. Results are cached to an .npz file in res_dir.
mv_nrm_red_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).
                                 replace('.npz', '.mv_nrm_red_res.npz'))

if not os.path.exists(mv_nrm_red_res_fn):

    hz_red_r = np.empty((no_days, no_chans, new_no_tints), dtype=float)
    hz_red_p = np.empty_like(hz_red_r)
    hz_red_n = np.empty_like(hz_red_r, dtype=bool)

    for day in range(no_days):
        for freq in range(no_chans):
            for tint in range(new_no_tints):
                if lstb_format:
                    xd_data_bft = flt_nan(xd_data_bls[day, freq, tint, :].flatten())
                else:
                    xd_data_bft = flt_nan(xd_data_bls[day, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, \
                                                      :].flatten())

                hz_res = mv_normality(xd_data_bft, method='hz')
                hz_red_r[day, freq, tint] = hz_res['HZ']
                hz_red_p[day, freq, tint] = hz_res['p value']
                # only an explicit 'YES' counts as normal; 'NO' or any nan gives False (a dict
                # keyed on np.nan would raise KeyError for a different nan object)
                hz_red_n[day, freq, tint] = hz_res['MVN'] == 'YES'

    np.savez(mv_nrm_red_res_fn, hz_red_r=hz_red_r, hz_red_p=hz_red_p, hz_red_n=hz_red_n)

else:
    # context manager closes the underlying npz file once the arrays are read
    with np.load(mv_nrm_red_res_fn) as mv_nrm_red_res:
        hz_red_r = mv_nrm_red_res['hz_red_r']
        hz_red_p = mv_nrm_red_res['hz_red_p']
        hz_red_n = mv_nrm_red_res['hz_red_n']

In [None]:
titles = [r'$HZ \; \mathrm{statistic}$', r'$p \; \mathrm{value}$', 'Normality']
# mean of each per-day result across the leading day axis
hz_red_res = [np.nanmean(res_arr, axis=0)
              for res_arr in (hz_red_r, hz_red_p, hz_red_n.astype(float))]

row_heatmaps(hz_red_res, titles=titles, figsize=(14, 7), share_cbar=False, cbar_loc=None,
             clip_pctile=1, xlabels='Time Bin', ylabel='Frequency Channel',
             yticklabels=plt_chans, xbase=tbase, ybase=fbase)

### Multivariate outlier detection

We use the robust Mahalanobis distance to detect outliers in the complex HERA data, as opposed to performing MAD-clipping on the $\mathfrak{Re}$ and $\mathfrak{Im}$ components separately.

#### Slice with worst HZ statistic

In [None]:
# robust-Mahalanobis outlier table for the worst-HZ slice; display its first rows
mvo_res = mv_outlier(hz_data)
mvo_res.head(n=5)

In [None]:
# MAD-clip Re and Im of the worst-HZ slice separately, for comparison with the Mahalanobis
# outliers; a sample is flagged when clipped in either component
re_clip_f_mvo, im_clip_f_mvo = [mad_clip(component, sigma=mad_sigma, verbose=True)[1]
                                for component in (hz_data.real, hz_data.imag)]

mvo_res['MAD-clip'] = re_clip_f_mvo + im_clip_f_mvo

In [None]:
# Outlier identification on the worst-HZ slice, coloured by: robust Mahalanobis distance (left),
# the resulting outlier classification (middle) and Re/Im MAD-clipping (right). The red dot is
# the geometric median of the slice.
hue_cols = ['RS Mahalanobis Distance', 'Outlier', 'MAD-clip']

fig, axes = plt.subplots(ncols=3, figsize=(14, 6), sharey=True)
for ax, hue_col in zip(axes, hue_cols):
    sns.scatterplot(x=hz_data.real, y=hz_data.imag, hue=mvo_res[hue_col], ax=ax)
    ax.set_xlabel(re_label)
    ax.plot(bhz_med_ests[0][0].real, bhz_med_ests[0][0].imag, bhz_med_ests[0][2])
    # title the legend with the column actually used for the hue, rather than relying on the
    # column order of mvo_res
    ax.legend(loc='lower right', title=hue_col)
axes[0].set_ylabel(r'$\mathfrak{Im} \; (V)$')
axes[-1].annotate('Geometric Median', xy =(0.65, 0.95), xycoords='axes fraction', color='r', \
                  bbox= dict(boxstyle='round', facecolor='white'))
axes[0].annotate('Chan: {}, LST: {:.3f}'.format(chans[hz_p_min[0]], avg_hr_lsts[hz_p_min[1]]), \
    xy=(0.05, 0.95), xycoords='axes fraction', bbox=dict(boxstyle='round', facecolor='white'), size=12, \
    alpha=0.8)
plt.tight_layout()
plt.show()

In [None]:
# The 97.5% quantile of the chi-square distribution is classically taken as the outlier
# threshold; here we look at a stricter one. Two degrees of freedom: Re and Im.
# NOTE(review): the 'RS Mahalanobis Distance' column is compared directly with a chi2 quantile,
# which assumes it holds the *squared* distance; confirm against robstat.mv_outlier.
chi2_quantile = 0.99
chi2_thresh = chi2.ppf(chi2_quantile, 2)
strct_outliers = np.where(mvo_res['RS Mahalanobis Distance'].values > chi2_thresh)[0]
print('Outliers when taking the chi-square quantile to be {}% are:'.format(chi2_quantile*100))

# five values per row
strct_vals = np.around(hz_data[strct_outliers], decimals=5)
for i, s_outlier in enumerate(strct_vals):
    print('{:19.5f}'.format(s_outlier), end='  ' if (i+1) % 5 else '\n')
if strct_vals.size == 0:
    print('None')
elif strct_vals.size % 5:
    print()  # terminate a partially filled final row

#### Sifting through the entire dataset

In [None]:
# Robust-Mahalanobis outlier flags for every (frequency, time bin) slice of
# sample_xd_data['data'] (selected baselines); mah_outliers[f, t, k] is True when sample k of
# that slice is flagged as an outlier. Results are cached to an .npz file in res_dir.
# NOTE(review): no_dp includes the no_bins_agg factor, while an LST-binned slice below only has
# shape[0]*no_bls samples; presumably no_bins_agg is 1 for LST-binned data. Confirm.
no_dp = xd_data_bls.shape[0]*no_bls*no_bins_agg

mah_out_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.mah_out_res.npz'))

if not os.path.exists(mah_out_res_fn):

    mah_outliers = np.empty((no_chans, new_no_tints, no_dp), dtype=bool)

    uf_xd_data = sample_xd_data['data'][..., slct_bl_idxs]

    for freq in range(no_chans):
        for tint in range(new_no_tints):
            if lstb_format:
                # NOTE(review): the channel axis is indexed by chans[freq] here but by freq in
                # the other branch; confirm the channel layout of the data in each format
                xd_data_bft = uf_xd_data[:, chans[freq], tint, :].flatten()
            else:
                xd_data_bft = uf_xd_data[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, \
                                         :].flatten()

            if np.isnan(xd_data_bft).all():
                # Fully flagged slice: nothing to classify. Store False explicitly; casting a
                # nan array into the bool output would give True and mark every sample an outlier.
                out_ft = np.zeros(no_dp, dtype=bool)
            else:
                out_ft = mv_outlier(xd_data_bft)['Outlier']

            mah_outliers[freq, tint, :] = out_ft

    np.savez(mah_out_res_fn, mah_outliers=mah_outliers)

else:
    # NB: caches written before the nan -> False fix mark fully flagged slices as all-outlier;
    # delete the .mah_out_res.npz file to recompute them.
    with np.load(mah_out_res_fn) as mah_out_res:
        mah_outliers = mah_out_res['mah_outliers']

In [None]:
# percentage of samples per (frequency, time bin) slice flagged by the robust Mahalanobis test
no_outliers = mah_outliers.sum(axis=-1)/no_dp*100
row_heatmaps(no_outliers, clip_pctile=2, xlabels='Time Bin', ylabel='Frequency Channel', \
             titles=['Percentage of outliers found with the robust Mahalanobis distance '\
             'technique'], yticklabels=plt_chans, xbase=tbase, ybase=fbase)

In [None]:
# Count flagged samples from calibration, MAD-clipping and both combined: first summed over
# days (axis 0) and baselines (last axis), then over the consecutive integrations of each time
# bin via the reshape (assumes the integration axis is a multiple of new_no_tints).
mad_mask = re_clip_f + im_clip_f
cal_flags_xdbl, mad_flags_xdbl, comb_flags_xdbl = [
    mask.sum(axis=(0, -1)) for mask in (flags, mad_mask, flags + re_clip_f + im_clip_f)]

cal_flags, mad_flags, comb_flags = [
    per_int.reshape((no_chans, new_no_tints, -1)).sum(axis=-1)
    for per_int in (cal_flags_xdbl, mad_flags_xdbl, comb_flags_xdbl)]

# as percentages of the samples per slice
cal_f_pct, mad_f_pct, comb_f_pct = [count_arr / no_dp*100
                                    for count_arr in (cal_flags, mad_flags, comb_flags)]

In [None]:
# flagged-data percentages: calibration flags, MAD-clipping and their union
titles=['Percentage of flagged data from calibration', \
        'Percentage of flagged data from MAD-clipping', \
        'Percentage of flagged data from calibration + MAD-clipping']
titles = [textwrap.fill(t, 40) for t in titles]

row_heatmaps([cal_f_pct, mad_f_pct, comb_f_pct], clip_pctile=2, figsize=(14, 6), \
             titles=titles, xlabels='Time Bin', ylabel='Frequency Channel', yticklabels=plt_chans, \
             xbase=tbase, ybase=fbase)

### Statistical properties of location estimates

#### Visualize data

In [None]:
# amplitude (top row) and phase (bottom row) of each location estimator (columns) vs channel
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 8), sharey='row', sharex='col')

for col_idx, loc_est in enumerate((xd_gmed_res, xd_hmean_res)):
    axes[0][col_idx].plot(np.abs(loc_est), alpha=0.7)
    axes[1][col_idx].plot(np.angle(loc_est), alpha=0.7)

axes[0][0].set_ylabel(r'$|V|$')
axes[1][0].set_ylabel(r'$\varphi$')

for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('Frequency Channel')

for top_ax, est_title in zip(axes[0], ('Geometric Median', 'HERA Mean')):
    top_ax.set_title(est_title)

tick_pos = np.arange(plt_chans.size)[::fbase]
tick_lbl = plt_chans[::fbase]
for axc in axes.flat:
    axc.set_xticks(tick_pos)
    axc.set_xticklabels(tick_lbl)

plt.tight_layout()
plt.show()

##### Fill in gaps

In [None]:
# 2D grid interpolation to replace the nan values of each location-estimate array
gmed_interp2, hmean_interp2 = (nan_interp2d(loc_est) for loc_est in (xd_gmed_res, xd_hmean_res))

In [None]:
# Channels that are entirely nan at the band edges (as found by extrem_nans) are dropped from the
# channel axis, presumably to match the interpolated arrays.
nan_chans = extrem_nans(np.isnan(xd_gmed_res).all(axis=1))
flt_chans = np.delete(chans, nan_chans, axis=0) if nan_chans.size else chans.copy()

### Nonparametric kernel regression

In [None]:
# 1D local-linear ('ll') kernel regression of the visibility amplitudes along frequency, fitted
# independently for each time bin with a fixed bandwidth of 3 channels.
kde_abs_gmed = np.empty_like(gmed_interp2, dtype=float)
kde_abs_hmean = np.empty_like(kde_abs_gmed)

kreg_common = dict(exog=flt_chans, reg_type='ll', var_type='c')
n_btints = gmed_interp2.shape[1]
for btint in range(n_btints):
    kde_gmed = KernelReg(endog=np.abs(gmed_interp2[:, btint]), bw=[3], **kreg_common)
    kde_hmean = KernelReg(endog=np.abs(hmean_interp2[:, btint]), bw=[3], **kreg_common)

    kde_abs_gmed[:, btint] = kde_gmed.fit(flt_chans)[0]
    kde_abs_hmean[:, btint] = kde_hmean.fit(flt_chans)[0]

In [None]:
# residuals of |V| about the kernel-regression fits (also used by the Allan deviation below)
gmed_residual = np.abs(gmed_interp2) - kde_abs_gmed
hmean_residual = np.abs(hmean_interp2) - kde_abs_hmean

# top row: kernel-regression fits; bottom row: residuals (thin) and their mean over time (bold)
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 6), sharex='col', sharey='row')

panel_data = ((kde_abs_gmed, gmed_residual), (kde_abs_hmean, hmean_residual))
for col_idx, (kde_fit, resid) in enumerate(panel_data):
    axes[0][col_idx].plot(flt_chans, kde_fit, alpha=0.7)
    axes[1][col_idx].plot(flt_chans, resid, alpha=0.3)
    axes[1][col_idx].plot(flt_chans, resid.mean(axis=1))

for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('Frequency Channel')
axes[0][0].set_ylabel(r'$|V|$')
axes[1][0].set_ylabel('Residual')

for top_ax, fit_title in zip(axes[0], ('Geometric Median KDE', 'HERA Mean KDE')):
    top_ax.set_title(fit_title)

for axc in axes.flat:
    axc.set_xticks(flt_chans[::fbase])

plt.tight_layout()
plt.show()

In [None]:
# 3D scatter of the interpolated geometric-median visibilities (taken across days and baselines):
# Re and Im vs frequency channel, coloured by time bin
slct_btint = 0  # NB: not used in this cell
fig = plt.figure(figsize=(8, 6))
ax = plt.axes(projection='3d')
time_colours = np.arange(new_no_tints)
for chan in range(gmed_interp2.shape[0]):
    chan_vis = gmed_interp2[chan, :]
    ax.scatter(chan_vis.real, chan_vis.imag, np.repeat(flt_chans[chan], new_no_tints),
               c=time_colours, cmap='Greens', alpha=0.5, s=7)
ax.set_xlabel(re_label)
ax.set_ylabel(im_label)
ax.set_zlabel('Frequency Channel')
plt.tight_layout()
plt.show()

#### Allan deviation

In [None]:
f_resolution = np.median(np.ediff1d(freqs))

# Overlapping Allan deviation of the |V| residuals (about the kernel-regression fits) along the
# frequency axis, one curve per time bin. Needs the optional AllanTools package: get it at
# https://github.com/aewallin/allantools or do pip install allantools. Only the import is guarded,
# so errors in the computation itself are not mistaken for a missing package.
try:
    import allantools
except ImportError:
    allantools = None
    print('AllanTools package not installed - skipping.')

if allantools is not None:
    t_resolution = np.median(hd.integration_time)  # NB: not used below

    rate_exp = np.log10(f_resolution)
    tau_min = np.ceil(np.abs(rate_exp))*np.sign(rate_exp)

    taus = np.logspace(tau_min, tau_min+np.ceil(np.log10(gmed_interp2.shape[0])), 1000)

    # One deviation curve per time bin, stacked column-wise afterwards so that the result follows
    # however many taus oadev actually returns (rather than a guessed preallocated size).
    gmed_ad_list = []
    hmean_ad_list = []

    for btint in range(gmed_interp2.shape[1]):
        # do OADEV on residuals rather than on signal with structure
        gmed_taus2, gmed_ad, gmed_ade, gmed_ns = allantools.oadev(gmed_residual[:, btint], \
            rate=1/f_resolution, data_type='freq', taus=taus)

        hmean_taus2, hmean_ad, hmean_ade, hmean_ns = allantools.oadev(hmean_residual[:, btint], \
            rate=1/f_resolution, data_type='freq', taus=taus)

        gmed_ad_list.append(gmed_ad)
        hmean_ad_list.append(hmean_ad)

    gmed_ads = np.column_stack(gmed_ad_list)
    hmean_ads = np.column_stack(hmean_ad_list)

    fig, ax = plt.subplots(ncols=2, figsize=(10, 6), sharey=True)

    ax[0].loglog(gmed_taus2, gmed_ads, alpha=0.5)
    ax[1].loglog(hmean_taus2, hmean_ads, alpha=0.5)

    ax[0].set_title('Geometric Median')
    ax[1].set_title('HERA Mean')

    ax[0].set_ylabel('Overlapping Allan deviation')
    ax[0].set_xlabel(r'$\tau$')
    ax[1].set_xlabel(r'$\tau$')

    plt.suptitle('Allan deviation')

    plt.show()

#### Power spectrum

##### Single time integration

In [None]:
# two-sided power spectra (Hann window) along frequency for the first time bin
pgram_kw = dict(fs=1/f_resolution, window='hann', scaling='spectrum', nfft=None,
                detrend=False, return_onesided=False)

gmed_delay, gmed_pspec = signal.periodogram(gmed_interp2[:, 0], **pgram_kw)
hmean_delay, hmean_pspec = signal.periodogram(hmean_interp2[:, 0], **pgram_kw)

# periodogram returns the delays in FFT order; sort them into ascending order
delay_sort = np.argsort(gmed_delay)
gmed_delay, gmed_pspec = gmed_delay[delay_sort], gmed_pspec[delay_sort]

delay_sort = np.argsort(hmean_delay)
hmean_delay, hmean_pspec = hmean_delay[delay_sort], hmean_pspec[delay_sort]

In [None]:
fig, ax = plt.subplots(figsize=(10, 6))

spectra = ((gmed_delay, gmed_pspec, 'Geometric Median'),
           (hmean_delay, hmean_pspec, 'HERA Mean'))
for delay_vals, pspec_vals, est_label in spectra:
    ax.plot(delay_vals, pspec_vals, label=est_label, alpha=0.8)

ax.set_yscale('log')
ax.set_ylabel('Power spectrum')
ax.set_xlabel('Delay')
ax.legend(loc='upper right')

plt.show()

##### All time integrations

In [None]:
# two-sided power spectra along the frequency axis (axis 0) for every time bin at once
pgram_kw = dict(fs=1/f_resolution, window='hann', scaling='spectrum', nfft=None,
                detrend=False, return_onesided=False, axis=0)

gmed_delay, gmed_pspec = signal.periodogram(gmed_interp2, **pgram_kw)
hmean_delay, hmean_pspec = signal.periodogram(hmean_interp2, **pgram_kw)

# reorder delays (and the matching spectrum rows) from FFT order into ascending order
delay_sort = np.argsort(gmed_delay)
gmed_delay, gmed_pspec = gmed_delay[delay_sort], gmed_pspec[delay_sort, :]

delay_sort = np.argsort(hmean_delay)
hmean_delay, hmean_pspec = hmean_delay[delay_sort], hmean_pspec[delay_sort, :]

In [None]:
# Power spectra of every time bin (thin lines) and their mean over time (bold) for each estimator,
# plus a direct comparison of the two means. Delays are converted from s to microseconds.
fig, axes = plt.subplots(ncols=3, figsize=(7, 5), sharey=True)

axes[0].plot(gmed_delay*1e6, gmed_pspec, alpha=0.3)
axes[0].plot(gmed_delay*1e6, gmed_pspec.mean(axis=1), alpha=1, color='orange')
axes[0].set_ylabel(r'Power Spectrum [Jy$^2$ Hz$^2$]')

axes[1].plot(hmean_delay*1e6, hmean_pspec, alpha=0.3)
axes[1].plot(hmean_delay*1e6, hmean_pspec.mean(axis=1), alpha=1, color='purple')

axes[2].plot(gmed_delay*1e6, gmed_pspec.mean(axis=1), alpha=0.7, color='orange', label='Geometric Median')
axes[2].plot(hmean_delay*1e6, hmean_pspec.mean(axis=1), alpha=0.7, color='purple', label='HERA Mean')

for ax in axes:
    ax.set_yscale('log')
    # raw string: '\m' is an invalid escape sequence in a normal string literal
    ax.set_xlabel(r'Delay [$\mu$s]')
    ax.set_xticks([-5, -2.5, 0, 2.5, 5])

axes[0].set_title('Geometric Median')
axes[1].set_title('HERA Mean')
axes[2].set_title('Comparison')
axes[2].legend(loc='lower center')

fig.tight_layout()
plt.show()

##### Cross-power spectrum between neighbouring time bins

In [None]:
# Cross-power spectra between neighbouring time bins: even-indexed bins (0, 2, 4, ...) against the
# following odd-indexed bins (1, 3, 5, ...). With an odd number of time bins the last bin has no
# partner, so it is dropped; otherwise the two inputs differ in shape and csd fails.
gmed_n_pairs = gmed_interp2.shape[1] // 2
gmed_interp2_1 = gmed_interp2[:, 0:2*gmed_n_pairs:2]
gmed_interp2_2 = gmed_interp2[:, 1:2*gmed_n_pairs:2]

hmean_n_pairs = hmean_interp2.shape[1] // 2
hmean_interp2_1 = hmean_interp2[:, 0:2*hmean_n_pairs:2]
hmean_interp2_2 = hmean_interp2[:, 1:2*hmean_n_pairs:2]

# a single segment spanning all channels (nperseg = number of channels), two-sided
gmed_delay, gmed_pspec = signal.csd(gmed_interp2_1, gmed_interp2_2, fs=1/f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    nperseg=gmed_interp2_1.shape[0], return_onesided=False, axis=0)

# sort delays from FFT order into ascending order
delay_sort = np.argsort(gmed_delay)
gmed_delay = gmed_delay[delay_sort]
gmed_pspec = gmed_pspec[delay_sort, :]

hmean_delay, hmean_pspec = signal.csd(hmean_interp2_1, hmean_interp2_2, fs=1/f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    nperseg=hmean_interp2_1.shape[0], return_onesided=False, axis=0)

delay_sort = np.argsort(hmean_delay)
hmean_delay = hmean_delay[delay_sort]
hmean_pspec = hmean_pspec[delay_sort, :]

In [None]:
# Magnitudes of the neighbouring-bin cross-power spectra (thin lines) and of their time average
# (bold) for each estimator, plus a comparison of the two averages. Delays are in microseconds.
fig, axes = plt.subplots(ncols=3, figsize=(7, 5), sharey=True)

axes[0].plot(gmed_delay*1e6, np.abs(gmed_pspec), alpha=0.3)
axes[0].plot(gmed_delay*1e6, np.abs(gmed_pspec.mean(axis=1)), alpha=1, color='orange')
axes[0].set_ylabel(r'Cross Power Spectrum [Jy$^2$ Hz$^2$]')

axes[1].plot(hmean_delay*1e6, np.abs(hmean_pspec), alpha=0.3)
axes[1].plot(hmean_delay*1e6, np.abs(hmean_pspec.mean(axis=1)), alpha=1, color='purple')

axes[2].plot(gmed_delay*1e6, np.abs(gmed_pspec.mean(axis=1)), alpha=0.8, color='orange', label='Geometric Median')
axes[2].plot(hmean_delay*1e6, np.abs(hmean_pspec.mean(axis=1)), alpha=0.8, color='purple', label='HERA Mean')

for ax in axes:
    ax.set_yscale('log')
    # raw string: '\m' is an invalid escape sequence in a normal string literal
    ax.set_xlabel(r'Delay [$\mu$s]')

axes[0].set_title('Geometric Median')
axes[1].set_title('HERA Mean')
axes[2].set_title('Comparison')
axes[2].legend(loc='lower center')

fig.tight_layout()
plt.show()

### Only average visibilities across days

The visibilities are then further averaged across baselines after the power spectrum computation, by computing the cross-power spectra between all baseline permutations.

In [None]:
# Location estimates per baseline, aggregating only over days (and, if not LST-binned, over
# no_bins_agg consecutive integrations). The first slct_no_bls baselines come from the earlier
# xd_gmed_res_t / xd_hmean_res_t; the remaining ones are computed here, one frequency channel
# per task (optionally in parallel). Results are cached to an .npz file in res_dir.
lst_full_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.lst_full_res.npz'))

if not os.path.exists(lst_full_res_fn):

    bls_to_do = np.arange(no_bls)[slct_no_bls:]

    def freq_iter(freq):
        """Geometric median and HERA mean of every (time bin, baseline) slice of one channel.

        Returns a pair of complex arrays, each of shape (1, new_no_tints, bls_to_do.size).
        Fully flagged slices are set to complex nan; each geometric median is used as the
        initial guess for the next one.
        """
        xd_gmed_res_bl_f = np.empty((1, new_no_tints, bls_to_do.size), dtype=complex)
        xd_hmean_res_bl_f = np.empty_like(xd_gmed_res_bl_f)

        gmed_ft_init = None
        for b, bl in enumerate(bls_to_do):
            for tint in range(new_no_tints):
                if lstb_format:
                    xd_data_bft = xd_data_bls[:, freq, tint, bl].flatten()
                else:
                    xd_data_bft = xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, \
                                              bl].flatten()

                if np.isnan(xd_data_bft).all():
                    gmed_ft = hmean_ft = np.nan + 1j*np.nan
                else:
                    gmed_ft = geometric_median(xd_data_bft, init_guess=gmed_ft_init, \
                                               options=dict(keep_res=True))
                    gmed_ft_init = gmed_ft
                    hmean_ft = rsc_mean(xd_data_bft, sigma=mad_sigma)

                xd_gmed_res_bl_f[:, tint, b] = gmed_ft
                xd_hmean_res_bl_f[:, tint, b] = hmean_ft

        return xd_gmed_res_bl_f, xd_hmean_res_bl_f

    if mp:
        # the context manager shuts the worker pool down even if map raises
        with multiprocessing.Pool(multiprocessing.cpu_count()) as m_pool:
            pool_res = m_pool.map(freq_iter, range(no_chans))
    else:
        pool_res = list(map(freq_iter, range(no_chans)))

    # the per-channel (gmed, hmean) pairs stack to (2, no_chans, new_no_tints, bls_to_do.size)
    loc_res = np.concatenate(pool_res, axis=1)
    xd_gmed_res_bl = np.concatenate((xd_gmed_res_t, loc_res[0, ...]), axis=-1)
    xd_hmean_res_bl = np.concatenate((xd_hmean_res_t, loc_res[1, ...]), axis=-1)

    np.savez(lst_full_res_fn, xd_gmed_res_bl=xd_gmed_res_bl, xd_hmean_res_bl=xd_hmean_res_bl)

else:
    # context manager closes the underlying npz file once the arrays are read
    with np.load(lst_full_res_fn) as red_res:
        xd_gmed_res_bl = red_res['xd_gmed_res_bl']
        xd_hmean_res_bl = red_res['xd_hmean_res_bl']

# remove baselines with only nan entries
nan_bls = np.where(np.isnan(xd_data_bls).all(axis=(0, 1, 2)))[0]
flt_no_bls = no_bls - nan_bls.size
xd_gmed_res_bl = np.delete(xd_gmed_res_bl, nan_bls, axis=2)
xd_hmean_res_bl = np.delete(xd_hmean_res_bl, nan_bls, axis=2)

In [None]:
# Location estimates of each baseline (one line per baseline) for a selected time bin:
# amplitude (top row) and phase (bottom row) for each estimator (columns).
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(8, 8), sharey='row', sharex='col')

slct_tint = 0

for col_idx, bl_est in enumerate((xd_gmed_res_bl, xd_hmean_res_bl)):
    tint_slice = bl_est[:, slct_tint, :]
    axes[0][col_idx].plot(np.abs(tint_slice), alpha=0.7)
    axes[1][col_idx].plot(np.angle(tint_slice), alpha=0.7)

axes[0][0].set_ylabel(r'$|V|$')
axes[1][0].set_ylabel(r'$\varphi$')

for bottom_ax in axes[1]:
    bottom_ax.set_xlabel('Frequency Channel')

for top_ax, est_title in zip(axes[0], ('Geometric Median', 'HERA Mean')):
    top_ax.set_title(est_title)

tick_pos = np.arange(plt_chans.size)[::fbase]
tick_lbl = plt_chans[::fbase]
for axc in axes.flat:
    axc.set_xticks(tick_pos)
    axc.set_xticklabels(tick_lbl)

plt.tight_layout()
plt.show()

In [None]:
# 3D line plot of the geometric-median amplitudes of every redundant baseline: one line per
# (time bin, baseline) pair, drawn along the frequency axis
fig = plt.figure(figsize=(8, 6))
ax = plt.axes(projection='3d')
n_plot_bls = xd_gmed_res_bl.shape[2]
for tint, bl in itertools.product(range(new_no_tints), range(n_plot_bls)):
    ax.plot(chans, np.repeat(tint, chans.size), np.abs(xd_gmed_res_bl[:, tint, bl]))
ax.set_xlabel('Frequency Channel')
ax.set_ylabel('Time bin')
ax.set_zlabel(r'$|V|$')
plt.tight_layout()
plt.show()

In [None]:
# 2D interpolation for each baseline separately: fill NaN gaps in the
# (frequency, time) location estimates of every baseline with `nan_interp2d`.
# Then remove the union of the frequency / time indices that `nan_interp2d`
# returns with `rtn_nan_idxs=True`, for either estimator on any baseline.
# Presumably these are the channels / time bins it could not fill (TODO
# confirm). Trimming them gives every baseline the same grid, so the results
# can be stacked into one array.


def _union_of_idxs(idx_arrays):
    """Return the sorted, de-duplicated union of a list of index arrays as ints.

    The per-baseline index arrays generally have different lengths, so they are
    concatenated first. Passing the ragged list straight to ``np.unique`` fails
    because NumPy cannot build an array from an inhomogeneous sequence. The
    result is cast to ``int`` so it is always valid for ``np.delete``, even
    when every input is empty.
    """
    if len(idx_arrays) == 0:
        return np.array([], dtype=int)
    return np.unique(np.concatenate([np.ravel(idxs) for idxs in idx_arrays])).astype(int)


def _drop_rows_cols(arr, rows, cols):
    """Delete ``rows`` along axis 0 and ``cols`` along axis 1; skip empty sets."""
    if rows.size:
        arr = np.delete(arr, rows, axis=0)
    if cols.size:
        arr = np.delete(arr, cols, axis=1)
    return arr


gmed_interp_bl_list = []
hmean_interp_bl_list = []
nan_idxs_f = []
nan_idxs_t = []

for bl in range(flt_no_bls):
    gmed_i, gmed_nidxf, gmed_nidxt = nan_interp2d(xd_gmed_res_bl[..., bl], rtn_nan_idxs=True)
    hmean_i, hmean_nidxf, hmean_nidxt = nan_interp2d(xd_hmean_res_bl[..., bl], rtn_nan_idxs=True)
    gmed_interp_bl_list.append(gmed_i)
    hmean_interp_bl_list.append(hmean_i)
    nan_idxs_f.extend([gmed_nidxf, hmean_nidxf])
    nan_idxs_t.extend([gmed_nidxt, hmean_nidxt])

# compute each union once; it is shared by both estimators and all baselines
nan_f_union = _union_of_idxs(nan_idxs_f)
nan_t_union = _union_of_idxs(nan_idxs_t)

gmed_interp_bl_list = [_drop_rows_cols(interp, nan_f_union, nan_t_union) for interp in gmed_interp_bl_list]
hmean_interp_bl_list = [_drop_rows_cols(interp, nan_f_union, nan_t_union) for interp in hmean_interp_bl_list]

# stacking below requires every baseline to sit on the same (frequency, time) grid
assert len({interp.shape for interp in gmed_interp_bl_list + hmean_interp_bl_list}) <= 1, \
    'interpolated baselines do not share a common (frequency, time) grid'

# stack baselines on axis 0, then move them to the last axis: (freq, time, baseline)
gmed_interp2_bl = np.moveaxis(np.array(gmed_interp_bl_list), 0, 2)
hmean_interp2_bl = np.moveaxis(np.array(hmean_interp_bl_list), 0, 2)

In [None]:
# Cross power spectra between all baseline pairs. For every ordered pair of
# distinct baselines, `signal.csd` is taken along axis 0 (frequency -> delay)
# with one Hann-windowed segment spanning all channels. The spectrum is
# two-sided (`return_onesided=False`; the visibilities are complex). Spectra
# are then reordered by ascending delay and averaged over baseline pairs.
bl_pairs = list(itertools.permutations(np.arange(flt_no_bls), r=2))
bls1 = [bl_a for bl_a, _bl_b in bl_pairs]
bls2 = [bl_b for _bl_a, bl_b in bl_pairs]


def _delay_sorted_csd(vis):
    """Cross spectrum of ``vis[..., bls1]`` with ``vis[..., bls2]`` along axis 0.

    Returns ``(delays, pspec, order)``. ``delays`` and ``pspec`` are sorted by
    ascending delay along axis 0, and ``order`` is the permutation applied.
    """
    delays, pspec = signal.csd(
        vis[..., bls1], vis[..., bls2],
        fs=1/f_resolution, window='hann', scaling='spectrum', nfft=None,
        detrend=False, nperseg=vis.shape[0], return_onesided=False, axis=0,
    )
    order = np.argsort(delays)
    return delays[order], pspec[order, :], order


gmed_delay, gmed_pspec = _delay_sorted_csd(gmed_interp2_bl)[:2]
# keep the HERA-mean sort permutation in `delay_sort` for any later cell that reuses it
hmean_delay, hmean_pspec, delay_sort = _delay_sorted_csd(hmean_interp2_bl)

# average over baseline pairs (last axis); presumably leaves (delay, time bin)
gmed_pspec = np.nanmean(gmed_pspec, axis=2)
hmean_pspec = np.nanmean(hmean_pspec, axis=2)

In [None]:
# Baseline-pair-averaged cross power spectra against delay (in microseconds).
# Each estimator panel shows one faint line per time bin, overlaid with the
# time-averaged spectrum. The third panel compares the two time averages.
# NOTE(review): scipy's scaling='spectrum' yields squared input units (Jy^2
# for visibilities in Jy); confirm the Jy^2 Hz^2 label below.
gmed_pspec_tavg = np.abs(gmed_pspec.mean(axis=1))
hmean_pspec_tavg = np.abs(hmean_pspec.mean(axis=1))

fig, axes = plt.subplots(ncols=3, figsize=(10, 6), sharey=True)
comparison_panel = axes[2]

estimator_specs = (
    ('Geometric Median', gmed_delay, gmed_pspec, gmed_pspec_tavg, 'orange'),
    ('HERA Mean', hmean_delay, hmean_pspec, hmean_pspec_tavg, 'purple'),
)
for panel, (est_title, delay, pspec, pspec_tavg, colour) in zip(axes[:2], estimator_specs):
    delay_us = delay*1e6
    panel.plot(delay_us, np.abs(pspec), alpha=0.3)
    panel.plot(delay_us, pspec_tavg, alpha=1, color=colour)
    panel.set_title(est_title)
    # the time-averaged spectrum is repeated on the comparison panel
    comparison_panel.plot(delay_us, pspec_tavg, alpha=0.6, color=colour, label=est_title)

axes[0].set_ylabel(r'Power Spectrum [Jy$^2$ Hz$^2$]')

for panel in axes:
    panel.set_yscale('log')
    panel.set_xlabel(r'Delay [$\mu$s]')

comparison_panel.set_title('Comparison')
comparison_panel.legend(loc='lower center')

fig.tight_layout()
plt.show()