<center><strong><font size=+3>Applications of robust 2D median estimators to HERA data</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

In [None]:
import os
import sys

import numpy as np
import seaborn as sns
from matplotlib import gridspec
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1.inset_locator import mark_inset, zoomed_inset_axes
from scipy.stats.mstats import gmean as geometric_mean

from hera_cal.io import HERAData
from hera_cal.redcal import get_reds

from robstat.plotting import row_heatmaps
from robstat.robstat import Cmardia_median, geometric_median, mardia_median, mv_median, \
tukey_median
from robstat.stdstat import rsc_mean
from robstat.utils import DATAPATH

In [None]:
plt.rcParams['figure.figsize'] = (12, 8)
%matplotlib inline

In [None]:
# Switch for high-resolution (300 dpi) figure rendering.
# NOTE(review): figures created below with an explicit dpi= argument
# (e.g. plt.subplots(..., dpi=100)) keep that dpi regardless of this switch.
plot_figs = False
if plot_figs:
    import matplotlib as mpl
    mpl.rcParams.update({'figure.dpi': 300})

### Load HERA visibility data

In [None]:
sample_data = os.path.join(DATAPATH, 'zen.2458098.43869.HH.OCRSA.uvh5')

hd = HERAData(sample_data)
data, flags, _ = hd.read()

# Redundant baseline groups: flat_bls holds every grouped baseline present in
# the file, groups with any missing baseline are dropped, and bl_dict maps each
# baseline to its column index in the stacked array below.
reds = get_reds(hd.antpos, pols=hd.pols)
flat_bls = [bl for grp in reds for bl in grp if bl in data.keys()]
reds = [grp for grp in reds if set(grp).issubset(flat_bls)]
bl_dict = {k: i for i, k in enumerate(flat_bls)}

data = {k: np.ma.array(v, mask=flags[k], fill_value=np.nan) for k, v
        in data.items()}
# Size the baseline axis by flat_bls, i.e. exactly the columns filled below.
# hd.Nbls need not match: baselines that get_reds does not return (e.g.
# autocorrelations - TODO confirm for this file) would otherwise leave
# never-written columns, and with several polarisations flat_bls can be longer
# than hd.Nbls. masked_all starts fully masked, so no uninitialised memory can
# leak into `data`.
mdata = np.ma.masked_all((hd.Nfreqs, hd.Ntimes, len(flat_bls)), dtype=complex)
for i, bl in enumerate(flat_bls):
    mdata[..., i] = data[bl].transpose()  # (times, freqs) -> (freqs, times)

data = mdata.filled(np.nan)  # dimensions (freqs, times, bls); flagged -> NaN
flags = np.ma.getmaskarray(mdata)  # always a full boolean array

### Redundant averaging

In [None]:
# First redundant group and its columns in the (freqs, times, bls) arrays
slct_bls = reds[0]
slct_bl_idxs = np.array([bl_dict[bl] for bl in slct_bls])
slct_data, slct_flags = data[..., slct_bl_idxs], flags[..., slct_bl_idxs]
# Sanity check: flagged samples are exactly the NaN samples
assert slct_flags.sum() == np.isnan(slct_data).sum()
print('Looking at baselines redundant to {}'.format(slct_bls[0]))

In [None]:
from robstat.utils import decomposeCArray

In [None]:
# Look at one time integration / frequency slice with high variance:
# the (freq, time) sample whose spread across the redundant baselines is largest
bl_spread = np.nanstd(slct_data, axis=-1)  # dimensions (freqs, times)
idxs = np.unravel_index(np.nanargmax(bl_spread), slct_data.shape[:2])
print('Selecting freq / time slice {}'.format(idxs))
slct_data_slice = slct_data[idxs + (slice(None),)]


def flt_nan(x):
    """Return the entries of ``x`` that are not NaN (via boolean indexing)."""
    return x[~np.isnan(x)]


def bad_med(x):
    """Naive complex 'median': nan-medians of the real and imaginary parts taken separately."""
    return np.nanmedian(x.real) + np.nanmedian(x.imag)*1j


sample_gmean = geometric_mean(flt_nan(slct_data_slice))
sample_gmed = geometric_median(slct_data_slice, weights=None)
sample_tmed = tukey_median(slct_data_slice)['barycenter']
sample_mmed = Cmardia_median(slct_data_slice)
sample_bmed = bad_med(slct_data_slice)
sample_hmean = rsc_mean(slct_data_slice)

In [None]:
# (estimate, display name) pairs for the selected slice
med_ests = list(zip([sample_gmean, sample_gmed, sample_tmed, sample_mmed, sample_bmed, sample_hmean],
                    ['Geometric Mean', 'Geometric Median', 'Tukey Median', 'Mardia Median',
                     'Separate Median', 'HERA Mean']))
# '.4f' = four decimal places ('4f' would only set a minimum field width of 4)
for est, name in med_ests:
    print('{:17s}: {:.4f}'.format(name, est))

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

ax.scatter(slct_data_slice.real, slct_data_slice.imag, alpha=0.5)
# (estimate, marker format, legend label) for each estimator
for est, fmt, label in [(sample_gmean, 'ro', 'Geo mean'),
                        (sample_gmed, 'co', 'Geo med'),
                        (sample_tmed, 'yo', 'Tukey'),
                        (sample_mmed, 'ko', 'Mardia'),
                        (sample_bmed, 'bo', 'Separate'),
                        (sample_hmean, 'go', 'HERA')]:
    ax.plot(est.real, est.imag, fmt, label=label)

# Format the (ant1, ant2, pol) baseline explicitly instead of relying on
# str(tuple), which renders NumPy integer antenna numbers as 'np.int64(...)'
# under NumPy >= 2
ax.annotate('({}, {}, \'{}\')'.format(*slct_bls[0]), xy=(0.05, 0.05),
            xycoords='axes fraction')
ax.set_xlabel(r'$\mathfrak{Re} \; (V)$')
ax.set_ylabel(r'$\mathfrak{Im}(V)$')

ax.legend()
plt.show()

In [None]:
# First time integration with at least one non-NaN sample
# (perhaps find the index with the fewest NaNs instead?)
time_int = np.where(~np.isnan(data).all(axis=(0, 2)))[0][0]
gmean_res = np.empty((data.shape[0], len(reds)), dtype=complex)  # dimensions (freqs, red groups)
gmed_res, tmed_res, mmed_res, bmed_res, hmean_res = [np.empty_like(gmean_res) for _ in range(5)]

for i, bl_grp in enumerate(reds):
    slct_bl_idxs = np.array([bl_dict[slct_bl] for slct_bl in bl_grp])
    # The iterative estimators are warm-started from the previous channel of
    # the same group. The guesses are reset for every group and only updated
    # from finite results: a fully flagged channel yields NaN, which must not
    # be passed on as init_guess.
    gmed_guess, mmed_guess = None, None
    for j, row in enumerate(data[:, time_int, slct_bl_idxs]):
        if np.isnan(row).all():
            gmean_ij = gmed_ij = tmed_ij = mmed_ij = bmed_ij = hmean_ij = np.nan
        else:
            gmean_ij = geometric_mean(flt_nan(row))
            gmed_ij = geometric_median(row, weights=None, init_guess=gmed_guess)
            tmed_ij = tukey_median(row)['barycenter']
            mmed_ij = Cmardia_median(row, init_guess=mmed_guess)
            bmed_ij = bad_med(row)
            hmean_ij = rsc_mean(row)
            if np.all(np.isfinite(gmed_ij)):
                gmed_guess = gmed_ij
            if np.all(np.isfinite(mmed_ij)):
                mmed_guess = mmed_ij
        gmean_res[j, i] = gmean_ij
        gmed_res[j, i] = gmed_ij
        tmed_res[j, i] = tmed_ij
        mmed_res[j, i] = mmed_ij
        bmed_res[j, i] = bmed_ij
        hmean_res[j, i] = hmean_ij

# (estimator name, per-channel results) pairs, in the order of med_ests
med_est_res = list(zip([i[1] for i in med_ests],
                       [gmean_res, gmed_res, tmed_res, mmed_res, bmed_res, hmean_res]))

In [None]:
fig = plt.figure(constrained_layout=True, figsize=(10, 20), dpi=100)
spec = gridspec.GridSpec(nrows=2*len(med_ests), figure=fig, ncols=2)

# Two grid rows per estimator: Re/Im on the left (spanning both rows),
# amplitude (top) and phase (bottom) on the right
axes = []
for i in range(len(med_ests)):
    ax1 = fig.add_subplot(spec[i*2:2+i*2, 0])
    ax2 = fig.add_subplot(spec[i*2, 1])
    ax3 = fig.add_subplot(spec[i*2+1, 1])
    axes.append([ax1, ax2, ax3])

# One colour per redundant group, taken from the public property cycle so a
# group has the same colour in every panel (rather than advancing the private
# Axes._get_lines.prop_cycler, which no longer exists in Matplotlib >= 3.8)
cycle_colors = plt.rcParams['axes.prop_cycle'].by_key()['color']
for m, (est_name, est_res) in enumerate(med_est_res):
    ax_cart, ax_amp, ax_phase = axes[m]
    for i, grp in enumerate(reds):
        c = cycle_colors[i % len(cycle_colors)]
        grp_label = '{}'.format(grp[0])
        ax_cart.plot(hd.freqs, est_res[:, i].real, color=c,
                     label=grp_label + r' $\mathfrak{Re}$')
        ax_cart.plot(hd.freqs, est_res[:, i].imag, color=c,
                     label=grp_label + r' $\mathfrak{Im}$', ls='--')
        ax_amp.plot(hd.freqs, np.abs(est_res[:, i]), color=c,
                    label=grp_label + r' $|V|$')
        ax_phase.plot(hd.freqs, np.angle(est_res[:, i]), color=c,
                      label=grp_label + r' $\varphi$', ls='--')
    # Estimator name: drawn once per panel, outside the group loop
    ax_cart.text(x=0.05, y=0.5, s=est_name, transform=ax_cart.transAxes,
                 fontsize=10, style='normal', weight='light')

for ax in axes:
    ax[0].set_ylabel(r'$V$')

# Frequency label only on the bottom-most panels (the last amplitude panel
# sits above the last phase panel)
for ax in (axes[-1][0], axes[-1][2]):
    ax.set_xlabel(r'$\nu$')

axes[0][0].set_title('Cartesian')
axes[0][1].set_title('Polar')

for ax in axes[0]:
    ax.legend(framealpha=0.5, loc=1)

plt.suptitle('Median estimates for 14-m EW baselines')
plt.show()

### LST + redundant averaging

In [None]:
# Multi-day visibility sample; np.load returns an NpzFile whose arrays are read on access
xd_path = os.path.join(DATAPATH, 'xd_vis.npz')
sample_xd_data = np.load(xd_path)

In [None]:
xd_data = sample_xd_data['data']  # dimensions (days, freqs, times, bls)
xd_flags = sample_xd_data['flags']
# NaN out flagged samples so the nan-aware statistics below skip them
xd_data[xd_flags] = np.nan

# Baseline/group table, axis metadata and polarisation of the multi-day sample
xd_redg, xd_times, xd_freqs, xd_days = (sample_xd_data[key]
                                        for key in ('redg', 'times', 'chans', 'days'))
xd_pol = sample_xd_data['pol'].item()

In [None]:
# Baselines whose xd_redg row starts with 0 (presumably the redundant-group id;
# the remaining columns are used as the antenna pair).
# NOTE: this rebinds `data` / `flags` to 4D (days, freqs, times, bls) arrays.
slct_bl_idxs = np.flatnonzero(xd_redg[:, 0] == 0)
data, flags = xd_data[..., slct_bl_idxs], xd_flags[..., slct_bl_idxs]
slct_red_bl = xd_redg[slct_bl_idxs[0], 1:]
print('Looking at baselines redundant to ({}, {}, \'{}\')'.format(*slct_red_bl, xd_pol))

In [None]:
# Look at 2 consecutive time integrations / 1 frequency slice with high variance.
# The spread over days and baselines is computed for every (freq, time) sample
# except the last integration, so that integration t+1 always exists.
ft_std = np.nanstd(data[..., :-1, :], axis=(0, -1))  # dimensions (freqs, times - 1)
# Unravel against the shape of the array nanargmax searched; unravelling
# against data.shape[1:-1] (one more time sample) maps the flat index to the
# wrong (freq, time) pair.
idxs = np.unravel_index(np.nanargmax(ft_std), ft_std.shape)
print('Selecting freq / time slice: ({}, {}-{})'.format(idxs[0], idxs[1], idxs[1]+1))

# Have visibilities across days for the same baseline - can flatten
# the data array and perform statistics on the whole dataset
data_slice = data[:, idxs[0], idxs[1]:idxs[1]+2, :].flatten()

xd_sample_gmean = geometric_mean(flt_nan(data_slice))
xd_sample_gmed = geometric_median(data_slice, weights=None)
xd_sample_tmed = tukey_median(data_slice)['barycenter']
xd_sample_mmed = Cmardia_median(data_slice)
xd_sample_bmed = bad_med(data_slice)
xd_sample_hmean = rsc_mean(data_slice)

Alternatively, we could take the median of the visibility amplitudes and the Mardia (circular) median of the phases. While this is an improvement on taking the medians of the real and imaginary parts separately, it still does not treat the complex data as a whole. The geometric median or the Tukey median would be preferable.

In [None]:
# (estimate, display name, marker format) triples for the multi-day slice
med_ests = list(zip([xd_sample_gmean, xd_sample_gmed, xd_sample_tmed, xd_sample_mmed,
                     xd_sample_bmed, xd_sample_hmean],
                    ['Geometric Mean', 'Geometric Median', 'Tukey Median', 'Mardia Median',
                     'Separate Median', 'HERA Mean'],
                    ['ro', 'co', 'yo', 'ko', 'bo', 'go']))
# '.4f' = four decimal places ('4f' would only set a minimum field width of 4)
for est, name, _ in med_ests:
    print('{:17s}: {:.4f}'.format(name, est))

In [None]:
fig, ax = plt.subplots(figsize=(6, 6), dpi=100)

finite_vis = flt_nan(data_slice)  # unflagged samples only
ax.scatter(finite_vis.real, finite_vis.imag, alpha=0.5)
for est, name, fmt in med_ests:
    ax.plot(est.real, est.imag, fmt, label=name)

# zoomed in sub region of the original image
axins = zoomed_inset_axes(ax, zoom=6, loc=4)
axins.scatter(finite_vis.real, finite_vis.imag, alpha=0.5)
for est, _, fmt in med_ests:
    axins.plot(est.real, est.imag, fmt)

# Zoom window: integer-rounded bounding box of all estimates except the last
# two ('Separate Median' and 'HERA Mean')
zoom_ests = [est for est, _, _ in med_ests[:-2]]
x1 = np.floor(np.min([est.real for est in zoom_ests]))
x2 = np.ceil(np.max([est.real for est in zoom_ests]))
y1 = np.floor(np.min([est.imag for est in zoom_ests]))
y2 = np.ceil(np.max([est.imag for est in zoom_ests]))
axins.set_xlim(x1, x2)
axins.set_ylim(y1, y2)

axins.tick_params(axis='x', direction='in', pad=-15)
mark_inset(ax, axins, loc1=1, loc2=3, fc='none', ec='0.5')

# Format the baseline explicitly: str() of a tuple of NumPy integers renders
# them as 'np.int64(...)' under NumPy >= 2
ax.annotate('({}, {}, \'{}\')'.format(*slct_red_bl, xd_pol), xy=(0.05, 0.05),
            xycoords='axes fraction')
ax.set_xlabel(r'$\mathfrak{Re} \; (V)$')
ax.set_ylabel(r'$\mathfrak{Im}(V)$')

ax.legend(loc=1, prop={'size': 10})

plt.show()

In [None]:
# KDE of the unflagged samples with every estimate overlaid
finite_vis = flt_nan(data_slice)
g = sns.jointplot(x=finite_vis.real, y=finite_vis.imag, kind='kde', height=9,
                  cmap='Blues', fill=True, space=0)
g.set_axis_labels(r'$\mathfrak{Re} \; (V)$', r'$\mathfrak{Im}(V)$', size=14)
for est, name, fmt in med_ests:
    g.ax_joint.plot(est.real, est.imag, fmt, label=name)
legend_properties = {'size': 12}
g.ax_joint.legend(prop=legend_properties, loc='upper right')
plt.tight_layout()
plt.show()

In [None]:
# Per-(freq, time) estimates for every redundant group of the multi-day data,
# pooling all days and all baselines of the group at each sample.
# Groups are taken from column 0 of xd_redg (the same key used to select
# group 0 above) rather than from `reds` / `bl_dict`, which describe the
# baseline ordering of the single-night HERAData file, not of xd_vis.npz.
xd_grp_ids = np.unique(xd_redg[:, 0])
xd_gmean_res = np.empty((*xd_data.shape[1:3], len(xd_grp_ids)), dtype=complex)
xd_gmed_res, xd_tmed_res, xd_mmed_res, xd_bmed_res, xd_hmean_res = [np.empty_like(xd_gmean_res)
                                                                    for _ in range(5)]
for i, grp_id in enumerate(xd_grp_ids):
    xd_data_b = xd_data[..., xd_redg[:, 0] == grp_id]  # dimensions (days, freqs, times, grp bls)
    # Warm-start the iterative medians from the previous sample of the same
    # group, and only from a finite value (a fully flagged sample yields NaN,
    # which must not be passed on as init_guess).
    gmed_guess, mmed_guess = None, None
    for freq in range(xd_data_b.shape[1]):
        for tint in range(xd_data_b.shape[2]):
            xd_data_bft = xd_data_b[:, freq, tint, :].flatten()
            if np.isnan(xd_data_bft).all():
                gmean_ij = gmed_ij = tmed_ij = mmed_ij = bmed_ij = hmean_ij = np.nan
            else:
                gmean_ij = geometric_mean(flt_nan(xd_data_bft))
                gmed_ij = geometric_median(xd_data_bft, weights=None, init_guess=gmed_guess)
                tmed_ij = tukey_median(xd_data_bft)['barycenter']
                mmed_ij = Cmardia_median(xd_data_bft, init_guess=mmed_guess)
                bmed_ij = bad_med(xd_data_bft)
                hmean_ij = rsc_mean(xd_data_bft)
                if np.all(np.isfinite(gmed_ij)):
                    gmed_guess = gmed_ij
                if np.all(np.isfinite(mmed_ij)):
                    mmed_guess = mmed_ij
            xd_gmean_res[freq, tint, i] = gmean_ij
            xd_gmed_res[freq, tint, i] = gmed_ij
            xd_tmed_res[freq, tint, i] = tmed_ij
            xd_mmed_res[freq, tint, i] = mmed_ij
            xd_bmed_res[freq, tint, i] = bmed_ij
            xd_hmean_res[freq, tint, i] = hmean_ij

In [None]:
# Estimates of the first group (index 0 of the last axis); xd_gmean_res and
# xd_mmed_res are not included in these heatmaps
bl_grp = 0
arrs = [res[..., bl_grp] for res in (xd_gmed_res, xd_tmed_res, xd_bmed_res, xd_hmean_res)]

row_heatmaps(arrs, clip_pctile=1, apply_np_fn='abs')

In [None]:
# Phase of the estimates; vmin/vmax/center presumably fix the colour scale to [-pi, pi] about 0
row_heatmaps(arrs, apply_np_fn='angle', center=0, vmin=-np.pi, vmax=np.pi)

In [None]:
# Real part of the estimates
np_fn = 'real'
row_heatmaps(arrs, apply_np_fn=np_fn)

In [None]:
# Imaginary part of the estimates
np_fn = 'imag'
row_heatmaps(arrs, apply_np_fn=np_fn)

In [None]:
# TODO
# - find the robust medians (rmeds) for all frequencies and times separately;
#   heatmap over time and frequency? check smoothness
# - recreate the LST-binning averaging process - sigma clipping about the mean?
# - literature on the median of medians?