<center><strong><font size=+3>Applications of robust 2D median estimators to HERA data</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

Reduced version of the [hera_application](https://github.com/matyasmolnar/robstat/blob/main/notebooks/hera_application.ipynb) notebook that looks at visibility and power spectrum results from the geometric median and the MAD-clipped mean estimator (labelled "HERA Mean" below). No R functions are called, and larger datasets are considered.

In [None]:
import itertools
import os

import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from scipy import signal

from robstat.hera_vis import agg_tint_rephase
from robstat.ml import extrem_nans, nan_interp2d
from robstat.plotting import grid_heatmaps, row_heatmaps
from robstat.robstat import geometric_median, mv_normality, mv_outlier
from robstat.stdstat import mad_clip, rsc_mean
from robstat.utils import DATAPATH, flt_nan

In [None]:
# render matplotlib figures inline below each cell
%matplotlib inline

In [None]:
# Figure defaults: large, 300-dpi figures when plot_figs is True;
# small, 125-dpi figures otherwise.
plot_figs = False
fig_defaults = {
    True: {'figure.figsize': (12, 8), 'figure.dpi': 300},
    False: {'figure.figsize': (5, 3), 'figure.dpi': 125},
}
mpl.rcParams.update(fig_defaults[bool(plot_figs)])

In [None]:
# analysis parameters
mad_sigma = 4.0  # sigma threshold for MAD-clipping, default is 4
# number of consecutive time bins averaged together in LST averaging
# (2 by default, like in HERA analysis pipeline)
no_bins_agg = 2

### Load HERA visibility data

In [None]:
# load the visibility sample from the data directory
xd_vis_fname = 'xd_vis_extd_rph.npz'
xd_vis_file = os.path.join(DATAPATH, xd_vis_fname)
sample_xd_data = np.load(xd_vis_file)  # NpzFile: arrays are read on key access

In [None]:
# visibilities with dimensions (days, freqs, times, bls)
xd_data = sample_xd_data['data']
xd_flags = sample_xd_data['flags']
xd_data[xd_flags] = np.nan  # flagged samples are marked as nan

xd_redg = sample_xd_data['redg']
xd_times = sample_xd_data['times']
xd_pol = sample_xd_data['pol'].item()
xd_rad_lsts = sample_xd_data['lsts']
xd_hr_lsts = xd_rad_lsts*12/np.pi  # radians -> hours
JDs = sample_xd_data['JDs']

freqs = sample_xd_data['freqs']
chans = sample_xd_data['chans']
# append one extra channel label when the last channel ends in ...99
# (presumably so that tick labels land on round channel numbers)
if chans[-1] % 100 == 99:
    plt_chans = np.append(chans, chans[-1]+1)
else:
    plt_chans = chans

no_chans = chans.size
no_days = JDs.size
no_tints = xd_times.size
# the last aggregated bin is partially filled when no_tints is not a multiple of no_bins_agg
new_no_tints = int(np.ceil(no_tints/no_bins_agg))

# Mean LST (hours) of each group of no_bins_agg consecutive integrations. Unlike
# reshape(-1, no_bins_agg), which raises for a non-multiple length, reduceat also
# averages a final, partially-filled group over its actual size.
bin_starts = np.arange(0, xd_hr_lsts.size, no_bins_agg)
bin_sizes = np.diff(np.append(bin_starts, xd_hr_lsts.size))
avg_hr_lsts = np.add.reduceat(xd_hr_lsts, bin_starts) / bin_sizes

In [None]:
# rephase if averaging over consecutive time bins
if 'rph' in os.path.basename(xd_vis_file) and no_bins_agg > 1:
    print('Rephasing visibilities such that every {} rows in time have the same phase centre.'.format(no_bins_agg))
    xd_antpos = np.load(xd_vis_file, allow_pickle=True)['antpos'].item()
    xd_data = agg_tint_rephase(xd_data, xd_redg, freqs, xd_pol, xd_rad_lsts, xd_antpos, \
                               no_bins_agg=no_bins_agg)

In [None]:
bl_grp = 0 # only look at 0th baseline group

# rows of xd_redg whose first column equals bl_grp belong to the selected redundant group
slct_bl_idxs = np.where(xd_redg[:, 0] == bl_grp)[0]
if slct_bl_idxs.size == 0:
    raise ValueError('No baselines found in redundant group {}'.format(bl_grp))

xd_data_bls = xd_data[..., slct_bl_idxs]
data = xd_data_bls  # alias rather than a second fancy-indexed copy of the same selection
flags = xd_flags[..., slct_bl_idxs]
# remaining columns of the group's first row, printed below as the reference baseline
slct_red_bl = xd_redg[slct_bl_idxs[0], :][1:]
no_bls = slct_bl_idxs.size
print('Looking at baselines redundant to ({}, {}, \'{}\')'.format(*slct_red_bl, xd_pol))

### LST + redundant averaging

In [None]:
res_dir = os.path.join(DATAPATH, 'loc_res_nrao')
# makedirs(exist_ok=True) also creates missing parents and is a no-op if res_dir exists
os.makedirs(res_dir, exist_ok=True)

lst_red_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.lst_red_res.npz'))
# Parameters the cached results depend on. They are saved next to the results so that a
# cache computed with different settings is recomputed rather than silently reused.
lst_red_params = {'bl_grp': bl_grp, 'no_bins_agg': no_bins_agg, 'mad_sigma': mad_sigma}

lst_red_cached = False
if os.path.exists(lst_red_res_fn):
    with np.load(lst_red_res_fn) as lst_red_res:
        # caches written before the parameters were stored lack these keys and are trusted as-is
        lst_red_cached = all(lst_red_res[k].item() == v for k, v in lst_red_params.items()
                             if k in lst_red_res.files)
        if lst_red_cached:
            xd_gmed_res = lst_red_res['xd_gmed_res']
            xd_hmean_res = lst_red_res['xd_hmean_res']
        else:
            print('Parameters of cached results in {} differ from the current ones; recomputing.'.format(lst_red_res_fn))

if not lst_red_cached:
    # one location estimate per (frequency channel, aggregated time bin), pooling days,
    # baselines and the no_bins_agg integrations of each time bin
    xd_gmed_res = np.empty((no_chans, new_no_tints), dtype=complex)
    xd_hmean_res = np.empty_like(xd_gmed_res)

    # NOTE(review): gmed_ij is never updated, so init_guess=None is passed on every call
    # (no warm start) - confirm this is intended
    gmed_ij = None
    for freq in range(no_chans):
        for tint in range(new_no_tints):
            # use no_bins_agg time integrations for each median (2 consecutive ones are used in HERA LST-binning)
            xd_data_bft = xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, :].flatten()
            if np.isnan(xd_data_bft).all():
                gmed_ft = hmean_ft = np.nan
            else:
                gmed_ft = geometric_median(xd_data_bft, init_guess=gmed_ij, keep_res=True)
                hmean_ft = rsc_mean(xd_data_bft, sigma=mad_sigma)
            xd_gmed_res[freq, tint] = gmed_ft
            xd_hmean_res[freq, tint] = hmean_ft

    np.savez(lst_red_res_fn, xd_gmed_res=xd_gmed_res, xd_hmean_res=xd_hmean_res, **lst_red_params)

In [None]:
# estimator results to compare, in the same order as the titles below
arrs = [xd_gmed_res, xd_hmean_res]

# for each estimator: amplitude, phase, real part and imaginary part
garrs = [[np.abs(arr), np.angle(arr), np.real(arr), np.imag(arr)] for arr in arrs]

titles = ['Geometric Median', 'HERA Mean']
ylabels = [ylab + '\n\nFrequency channel'
           for ylab in ('Amp', 'Phase', r'$\mathfrak{Re}$', r'$\mathfrak{Im}$')]

grid_heatmaps(garrs, titles=titles, figsize=(8, 8), ybase=25, clip_pctile=1,
              xlabels='Time bin', yticklabels=plt_chans, ylabels=ylabels)

#### Smoothness of median results

Calculate the standard deviation of the distances between successive points in either frequency or time, to get an idea of the smoothness of the location results.

##### Standard deviation of absolute distances

In [None]:
# For every row (or column), take the magnitudes of the differences between neighbouring
# entries and their nan-aware standard deviation; then average over rows (or columns).

# in time: differences along axis 1, one std per frequency channel
t_smoothness = [np.nanmean([np.nanstd(np.abs(np.ediff1d(chan_row))) for chan_row in arr])
                for arr in arrs]
print('Smoothness in time: \n{}\n{}\n'.format(titles, t_smoothness))

# in frequency: differences along axis 0, one std per time bin
f_smoothness = [np.nanmean([np.nanstd(np.abs(np.ediff1d(tint_col))) for tint_col in arr.T])
                for arr in arrs]
print('Smoothness in frequency: \n{}\n{}'.format(titles, f_smoothness))

##### Standard deviation of complex differences

In [None]:
# For every row (or column), take the nan-aware standard deviation of the complex differences
# between neighbouring entries; then average over rows (or columns).

# in time: differences along axis 1, one std per frequency channel
t_smoothness = [np.nanmean([np.nanstd(np.ediff1d(chan_row)) for chan_row in arr])
                for arr in arrs]
print('Smoothness in time: \n{}\n{}\n'.format(titles, t_smoothness))

# in frequency: differences along axis 0, one std per time bin
f_smoothness = [np.nanmean([np.nanstd(np.ediff1d(tint_col)) for tint_col in arr.T])
                for arr in arrs]
print('Smoothness in frequency: \n{}\n{}'.format(titles, f_smoothness))

### Test of normality

#### Henze-Zirkler multivariate normality test

We use the HZ test as it considers the entirety of the data. Note that many alternative tests also exist, and that a single statistic does not definitively establish whether the multivariate data is normally distributed.

In [None]:
# MAD-clipping about Re and Im separately, like HERA
nan_flags = np.isnan(xd_data_bls)
re_clip_f = mad_clip(xd_data_bls.real, axis=(0, 3), flags=nan_flags, verbose=True)[1]
im_clip_f = mad_clip(xd_data_bls.imag, axis=(0, 3), flags=nan_flags, verbose=True)[1]

xd_data_bls_c = xd_data_bls.copy()
xd_data_bls_c[re_clip_f + im_clip_f] *= np.nan

In [None]:
mv_nrm_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.mv_nrm_res.npz'))
# Parameters the cached test results may depend on (sample selection, time aggregation,
# clipping threshold). They are stored with the results so that a cache computed with
# different settings is recomputed rather than silently reused.
mv_nrm_params = {'bl_grp': bl_grp, 'no_bins_agg': no_bins_agg, 'mad_sigma': mad_sigma}

mv_nrm_cached = False
if os.path.exists(mv_nrm_res_fn):
    with np.load(mv_nrm_res_fn) as mv_nrm_res:
        # caches written before the parameters were stored lack these keys and are trusted as-is
        mv_nrm_cached = all(mv_nrm_res[k].item() == v for k, v in mv_nrm_params.items()
                            if k in mv_nrm_res.files)
        if mv_nrm_cached:
            hz_r = mv_nrm_res['hz_r']
            hz_p = mv_nrm_res['hz_p']
            hz_n = mv_nrm_res['hz_n']
            hz_r_c = mv_nrm_res['hz_r_c']
            hz_p_c = mv_nrm_res['hz_p_c']
            hz_n_c = mv_nrm_res['hz_n_c']
        else:
            print('Parameters of cached results in {} differ from the current ones; recomputing.'.format(mv_nrm_res_fn))

if not mv_nrm_cached:

    def _hz_test(samples):
        """Henze-Zirkler test of samples; return (HZ statistic, p value, normality verdict).

        The verdict is True only when mv_normality reports 'YES'; 'NO' and nan give False.
        Comparing against 'YES' is used instead of a dict keyed on np.nan because nan != nan:
        such a lookup only matches that exact nan object and raises KeyError for any other nan.
        """
        hz_res = mv_normality(samples, method='hz')
        return hz_res['HZ'], hz_res['p value'], hz_res['MVN'] == 'YES'

    # statistic, p value and verdict per (frequency channel, aggregated time bin) for the
    # raw samples (hz_*) and for the MAD-clipped samples (hz_*_c)
    hz_r = np.empty_like(xd_gmed_res, dtype=float)
    hz_p = np.empty_like(hz_r)
    hz_n = np.empty_like(hz_r, dtype=bool)

    hz_r_c = np.empty_like(hz_r)
    hz_p_c = np.empty_like(hz_r)
    hz_n_c = np.empty_like(hz_n)

    for freq in range(no_chans):
        for tint in range(new_no_tints):
            tint_slice = slice(no_bins_agg*tint, no_bins_agg*tint + no_bins_agg)
            # flt_nan presumably drops nan (flagged / clipped) entries - confirm
            xd_data_bft = flt_nan(xd_data_bls[:, freq, tint_slice, :].flatten())
            xd_data_bcft = flt_nan(xd_data_bls_c[:, freq, tint_slice, :].flatten())

            hz_r[freq, tint], hz_p[freq, tint], hz_n[freq, tint] = _hz_test(xd_data_bft)
            hz_r_c[freq, tint], hz_p_c[freq, tint], hz_n_c[freq, tint] = _hz_test(xd_data_bcft)

    np.savez(mv_nrm_res_fn, hz_r=hz_r, hz_p=hz_p, hz_n=hz_n, hz_r_c=hz_r_c, hz_p_c=hz_p_c,
             hz_n_c=hz_n_c, **mv_nrm_params)

In [None]:
# HZ test results for the unclipped samples
titles = [r'$HZ \; \mathrm{statistic}$', r'$p \; \mathrm{value}$', 'Normality']
hz_panels = [hz_r, hz_p, hz_n]
row_heatmaps(hz_panels, titles=titles, figsize=(8, 4), share_cbar=False,
             cbar_loc=None, clip_pctile=1, xlabels='Time bin',
             ylabel='Frequency channel', yticklabels=plt_chans)

In [None]:
# HZ test results for the MAD-clipped data
titles = [r'$HZ \; \mathrm{statistic}$', r'$p \; \mathrm{value}$', 'Normality']
hz_c_panels = [hz_r_c, hz_p_c, hz_n_c]
row_heatmaps(hz_c_panels, titles=titles, figsize=(8, 4), share_cbar=False,
             cbar_loc=None, clip_pctile=1, xlabels='Time bin',
             ylabel='Frequency channel', yticklabels=plt_chans)

### Interpolation

#### Visualize data

In [None]:
# Amplitude (top row) and phase (bottom row) against frequency channel, one line per time
# bin, for each location estimator (columns).
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7, 5), sharey='row', sharex='col')

for col, (res, title) in enumerate([(xd_gmed_res, 'Geometric Median'),
                                    (xd_hmean_res, 'HERA Mean')]):
    axes[0][col].plot(np.abs(res), alpha=0.7)
    axes[1][col].plot(np.angle(res), alpha=0.7)
    axes[0][col].set_title(title)
    axes[1][col].set_xlabel('Frequency channel')

axes[0][0].set_ylabel(r'$|V|$')
axes[1][0].set_ylabel(r'$\varphi$')

for ax in axes.flat:
    ax.set_xticks(np.arange(plt_chans.size)[::25])
    ax.set_xticklabels(plt_chans[::25])

plt.tight_layout()
plt.show()

#### Fill in gaps

In [None]:
# grid interpolation to replace nan values
gmed_interp2 = nan_interp2d(xd_gmed_res)
hmean_interp2 = nan_interp2d(xd_hmean_res)

In [None]:
# Channels that extrem_nans picks out from the all-nan channel mask (presumably those at
# the band edges), and the channel list with them removed.
nan_chans = extrem_nans(np.isnan(xd_gmed_res).all(axis=1))
flt_chans = np.delete(chans, nan_chans, axis=0) if nan_chans.size else chans.copy()

### Power spectrum

#### All time integrations

In [None]:
f_resolution = np.median(np.ediff1d(freqs))  # typical channel width


def _sorted_periodogram(vis):
    """Two-sided, Hann-windowed periodogram of vis along the frequency axis (axis 0).

    Returns (delays, power) with delays sorted ascending and the power rows reordered to match.
    """
    delays, power = signal.periodogram(vis, fs=1/f_resolution, window='hann',
                                       scaling='spectrum', nfft=None, detrend=False,
                                       return_onesided=False, axis=0)
    order = np.argsort(delays)
    return delays[order], power[order, :]


gmed_delay, gmed_pspec = _sorted_periodogram(gmed_interp2)
hmean_delay, hmean_pspec = _sorted_periodogram(hmean_interp2)

In [None]:
# Power spectrum of every time bin (faint lines) and their mean over time bins (solid line)
# for each estimator, plus a direct comparison of the two means.
fig, axes = plt.subplots(ncols=3, figsize=(8, 5), sharey=True)

estimators = [('Geometric Median', gmed_delay, gmed_pspec, 'orange'),
              ('HERA Mean', hmean_delay, hmean_pspec, 'purple')]
for ax, (label, delay, pspec, colour) in zip(axes, estimators):
    mean_pspec = pspec.mean(axis=1)
    ax.plot(delay, pspec, alpha=0.3)
    ax.plot(delay, mean_pspec, alpha=1, color=colour)
    ax.set_title(label)
    axes[2].plot(delay, mean_pspec, alpha=0.7, color=colour, label=label)
axes[0].set_ylabel('Power spectrum')

for ax in axes:
    ax.set_yscale('log')
    ax.set_xlabel('Delay')

axes[2].set_title('Comparison')
axes[2].legend(loc='best')

plt.suptitle('Power spectra')
plt.show()

#### Cross-power spectrum between neighbouring time bins

In [None]:
def _neighbour_pairs(vis):
    """Split vis (frequency, time) into the even and odd bins of each complete pair (2k, 2k+1).

    With an odd number of time bins the last bin has no partner. It is dropped so that both
    halves have the same shape; signal.csd cannot broadcast e.g. 5 columns against 4.
    """
    no_pairs = vis.shape[1] // 2
    return vis[:, 0:2*no_pairs:2], vis[:, 1:2*no_pairs:2]


def _sorted_csd(vis_1, vis_2):
    """Two-sided, Hann-windowed cross-power spectrum of vis_1 and vis_2 along axis 0, sorted by delay.

    nperseg is set to the full length of the frequency axis so that, as with signal.periodogram,
    a single full-band segment is used. Left at its default, signal.csd uses 256-sample Welch
    segments whenever more than 256 channels are present.
    """
    delays, cross = signal.csd(vis_1, vis_2, fs=1/f_resolution, window='hann',
                               nperseg=vis_1.shape[0], scaling='spectrum', nfft=None,
                               detrend=False, return_onesided=False, axis=0)
    order = np.argsort(delays)
    return delays[order], cross[order, :]


# cross-power spectra between neighbouring time bins (0 with 1, 2 with 3, ...)
gmed_interp2_1, gmed_interp2_2 = _neighbour_pairs(gmed_interp2)
hmean_interp2_1, hmean_interp2_2 = _neighbour_pairs(hmean_interp2)

gmed_delay, gmed_pspec = _sorted_csd(gmed_interp2_1, gmed_interp2_2)
hmean_delay, hmean_pspec = _sorted_csd(hmean_interp2_1, hmean_interp2_2)

In [None]:
# Magnitude of the cross-power spectra of every time-bin pair (faint lines) and of their mean
# over pairs (solid line) for each estimator, plus a direct comparison of the two means.
fig, axes = plt.subplots(ncols=3, figsize=(8, 5), sharey=True)

estimators = [('Geometric Median', gmed_delay, gmed_pspec, 'orange'),
              ('HERA Mean', hmean_delay, hmean_pspec, 'purple')]
for ax, (label, delay, pspec, colour) in zip(axes, estimators):
    mean_pspec_abs = np.abs(pspec.mean(axis=1))
    ax.plot(delay, np.abs(pspec), alpha=0.3)
    ax.plot(delay, mean_pspec_abs, alpha=1, color=colour)
    ax.set_title(label)
    axes[2].plot(delay, mean_pspec_abs, alpha=0.8, color=colour, label=label)
axes[0].set_ylabel('Power spectrum')

for ax in axes:
    ax.set_yscale('log')
    ax.set_xlabel('Delay')

axes[2].set_title('Comparison')
axes[2].legend(loc='best')

plt.suptitle('Power spectra')
plt.show()

#### Only average visibilities across days

The results are then further averaged across baselines after the power spectrum computation, by averaging the cross-power spectra over all baseline permutations.

In [None]:
lst_full_res_fn = os.path.join(res_dir, os.path.basename(xd_vis_file).replace('.npz', '.lst_full_res.npz'))
# Parameters the cached results depend on. They are saved next to the results so that a
# cache computed with different settings is recomputed rather than silently reused.
lst_full_params = {'bl_grp': bl_grp, 'no_bins_agg': no_bins_agg, 'mad_sigma': mad_sigma}

lst_full_cached = False
if os.path.exists(lst_full_res_fn):
    with np.load(lst_full_res_fn) as red_res:
        # caches written before the parameters were stored lack these keys and are trusted as-is
        lst_full_cached = all(red_res[k].item() == v for k, v in lst_full_params.items()
                              if k in red_res.files)
        if lst_full_cached:
            xd_gmed_res_bl = red_res['xd_gmed_res_bl']
            xd_hmean_res_bl = red_res['xd_hmean_res_bl']
        else:
            print('Parameters of cached results in {} differ from the current ones; recomputing.'.format(lst_full_res_fn))

if not lst_full_cached:
    # one location estimate per (frequency channel, aggregated time bin, baseline), pooling
    # days and the no_bins_agg integrations of each time bin
    xd_gmed_res_bl = np.empty((no_chans, new_no_tints, no_bls), dtype=complex)
    xd_hmean_res_bl = np.empty_like(xd_gmed_res_bl)

    # NOTE(review): gmed_ij is never updated, so init_guess=None is passed on every call
    # (no warm start) - confirm this is intended
    gmed_ij = None
    for bl in range(no_bls):
        for freq in range(no_chans):
            for tint in range(new_no_tints):
                xd_data_bft = xd_data_bls[:, freq, no_bins_agg*tint:no_bins_agg*tint+no_bins_agg, bl].flatten()
                if np.isnan(xd_data_bft).all():
                    gmed_ft = hmean_ft = np.nan
                else:
                    gmed_ft = geometric_median(xd_data_bft, init_guess=gmed_ij, keep_res=True)
                    hmean_ft = rsc_mean(xd_data_bft, sigma=mad_sigma)
                xd_gmed_res_bl[freq, tint, bl] = gmed_ft
                xd_hmean_res_bl[freq, tint, bl] = hmean_ft

    np.savez(lst_full_res_fn, xd_gmed_res_bl=xd_gmed_res_bl, xd_hmean_res_bl=xd_hmean_res_bl,
             **lst_full_params)

# remove baselines with only nan entries
nan_bls = np.where(np.isnan(xd_data_bls).all(axis=(0, 1, 2)))[0]
flt_no_bls = no_bls - nan_bls.size
xd_gmed_res_bl = np.delete(xd_gmed_res_bl, nan_bls, axis=2)
xd_hmean_res_bl = np.delete(xd_hmean_res_bl, nan_bls, axis=2)

In [None]:
# plot the visibility location estimates for a selected time slice
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7, 5), sharey='row', sharex='col')

slct_tint = 0

axes[0][0].plot(np.abs(xd_gmed_res_bl[:, slct_tint, :]), alpha=0.7)
axes[0][1].plot(np.abs(xd_hmean_res_bl[:, slct_tint, :]), alpha=0.7)
axes[1][0].plot(np.angle(xd_gmed_res_bl[:, slct_tint, :]), alpha=0.7)
axes[1][1].plot(np.angle(xd_hmean_res_bl[:, slct_tint, :]), alpha=0.7)

axes[0][0].set_ylabel(r'$|V|$')
axes[1][0].set_ylabel(r'$\varphi$')

axes[1][0].set_xlabel('Frequency channel')
axes[1][1].set_xlabel('Frequency channel')

axes[0][0].set_title('Geometric Median')
axes[0][1].set_title('HERA Mean')

for axr in axes:
    for axc in axr:
        axc.set_xticks(np.arange(plt_chans.size)[::25])
        axc.set_xticklabels(plt_chans[::25])

plt.tight_layout()
plt.show()

In [None]:
# 2D interpolation for each baseline separately
gmed_interp_bl_list = []
hmean_interp_bl_list = []
nan_idxs_f = []
nan_idxs_t = []

for bl in range(flt_no_bls):
    gmed_i, gmed_nidxf, gmed_nidxt = nan_interp2d(xd_gmed_res_bl[..., bl], rtn_nan_idxs=True)
    hmean_i, hmean_nidxf, hmean_nidxt = nan_interp2d(xd_hmean_res_bl[..., bl], rtn_nan_idxs=True)
    gmed_interp_bl_list.append(gmed_i)
    hmean_interp_bl_list.append(hmean_i)
    nan_idxs_f.append(gmed_nidxf)
    nan_idxs_f.append(hmean_nidxf)
    nan_idxs_t.append(gmed_nidxt)
    nan_idxs_t.append(hmean_nidxt)
    
if np.unique(nan_idxs_f).size != 0:
    gmed_interp_bl_list = [np.delete(gmed_i, np.unique(nan_idxs_f), axis=0) for gmed_i in gmed_interp_bl_list]
    hmean_interp_bl_list = [np.delete(hmean_i, np.unique(nan_idxs_f), axis=0) for hmean_i in hmean_interp_bl_list]
    
if np.unique(nan_idxs_t).size != 0:
    gmed_interp_bl_list = [np.delete(gmed_i, np.unique(nan_idxs_t), axis=1) for gmed_i in gmed_interp_bl_list]
    hmean_interp_bl_list = [np.delete(hmean_i, np.unique(nan_idxs_t), axis=1) for hmean_i in hmean_interp_bl_list]
    
gmed_interp2_bl = np.moveaxis(np.array(gmed_interp_bl_list), 0, 2)
hmean_interp2_bl = np.moveaxis(np.array(hmean_interp_bl_list), 0, 2)

In [None]:
# Cross-power spectra between all ordered baseline pairs (i, j), i != j, averaged over pairs.
# csd(j, i) is the complex conjugate of csd(i, j), so averaging over both orderings gives a
# real-valued mean up to rounding.
if flt_no_bls < 2:
    raise ValueError('At least 2 baselines with data are needed for cross-baseline power '
                     'spectra, got {}'.format(flt_no_bls))

bl_pairs = list(itertools.permutations(np.arange(flt_no_bls), r=2))
bls1 = [i[0] for i in bl_pairs]
bls2 = [i[1] for i in bl_pairs]


def _sorted_bl_csd(vis_bl):
    """Cross-power spectra of vis_bl (frequency, time, baseline) for every pair in bls1/bls2.

    Two-sided, Hann-windowed, along the frequency axis and sorted by delay. nperseg spans the
    full frequency axis so that a single full-band segment is used, as with
    signal.periodogram; by default signal.csd uses 256-sample Welch segments when more than
    256 channels are present.
    """
    delays, cross = signal.csd(vis_bl[..., bls1], vis_bl[..., bls2], fs=1/f_resolution,
                               window='hann', nperseg=vis_bl.shape[0], scaling='spectrum',
                               nfft=None, detrend=False, return_onesided=False, axis=0)
    order = np.argsort(delays)
    return delays[order], cross[order, :]


gmed_delay, gmed_pspec = _sorted_bl_csd(gmed_interp2_bl)
hmean_delay, hmean_pspec = _sorted_bl_csd(hmean_interp2_bl)

# average over baseline pairs -> (delay, time bin)
gmed_pspec = np.nanmean(gmed_pspec, axis=2)
hmean_pspec = np.nanmean(hmean_pspec, axis=2)

In [None]:
# Magnitude of the baseline-averaged cross-power spectra of every time bin (faint lines) and of
# their mean over time bins (solid line) for each estimator, plus a comparison of the means.
fig, axes = plt.subplots(ncols=3, figsize=(8, 5), sharey=True)

estimators = [('Geometric Median', gmed_delay, gmed_pspec, 'orange'),
              ('HERA Mean', hmean_delay, hmean_pspec, 'purple')]
for ax, (label, delay, pspec, colour) in zip(axes, estimators):
    # average over times
    mean_pspec_abs = np.abs(pspec.mean(axis=1))
    ax.plot(delay, np.abs(pspec), alpha=0.3)
    ax.plot(delay, mean_pspec_abs, alpha=1, color=colour)
    ax.set_title(label)
    axes[2].plot(delay, mean_pspec_abs, alpha=0.6, color=colour, label=label)
axes[0].set_ylabel('Power spectrum')

for ax in axes:
    ax.set_yscale('log')
    ax.set_xlabel('Delay')

axes[2].set_title('Comparison')
axes[2].legend(loc='best')

plt.suptitle('Power spectra')
plt.show()