<center><strong><font size=+3>Introduction to high-pass filtering of HERA data</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

The aim of the game in this notebook is to apply a high-pass filter in delay space on HERA visibilities, and to transform back into the visibility domain for further averaging with robust statistical techniques. This will remove low-order modes (consisting mainly of foregrounds) and may better reconcile redundant and same-LST visibilities such that robust averaging may, in consequence, be more effective.

The transformation to delay space is done on the individually calibrated visibilities, where no averaging across days or baselines has been done. For this reason, the data has a substantial number of flags, and there are no frequency bands in which every channel is unflagged. Either an unevenly sampled Fourier transform is required, or the data needs to be interpolated first.

We explore all of these considerations from first principles, and show some illustrative results.

In [None]:
import itertools
import os

import matplotlib as mpl
import numpy as np
from astropy.timeseries import LombScargle
from matplotlib import pyplot as plt
from scipy import signal
from scipy.fft import fft, fftfreq, fftshift, ifft, ifftshift
from scipy.signal import butter, convolve, hann, sosfilt, sosfreqz

from robstat.ml import extrem_nans, nan_interp2d
from robstat.robstat import geometric_median
from robstat.stdstat import rsc_mean
from robstat.utils import DATAPATH, flt_nan

In [None]:
%matplotlib inline

In [None]:
# Default figure settings for every plot in this notebook: high DPI, small canvas.
mpl.rcParams['figure.dpi'] = 175
mpl.rcParams['figure.figsize'] = (5, 3)

### Load HERA dataset

In [None]:
# Load the sample visibility archive shipped under the package data directory.
xd_vis_file = os.path.join(DATAPATH, 'xd_vis_extd_rph.npz')
sample_xd_data = np.load(xd_vis_file)

In [None]:
# Unpack the archive into individual arrays.
xd_data = sample_xd_data['data'] # dimensions (days, freqs, times, bls)
xd_flags = sample_xd_data['flags']
# Multiplying by NaN turns every flagged sample into NaN, so that flags can be
# recovered with np.isnan from here on.
xd_data[xd_flags] *= np.nan

xd_redg = sample_xd_data['redg']
xd_times = sample_xd_data['times']
xd_pol = sample_xd_data['pol'].item()
JDs = sample_xd_data['JDs']

freqs = sample_xd_data['freqs']
chans = sample_xd_data['chans']

# Channel width, taken as the median spacing between consecutive frequencies.
f_resolution = np.median(np.ediff1d(freqs))
no_chans = chans.size

In [None]:
bl_grp = 0 # only look at 0th baseline group

# Column 0 of xd_redg is compared with the group number; the remaining columns
# of the first matching row are printed below (presumably the antenna pair of
# the group's reference baseline - TODO confirm).
slct_bl_idxs = np.where(xd_redg[:, 0] == bl_grp)[0]
flags = xd_flags[..., slct_bl_idxs]
slct_red_bl = xd_redg[slct_bl_idxs[0], :][1:]
# Keep only the baselines (last axis) that belong to the selected group.
xd_data_bls = xd_data[..., slct_bl_idxs]
no_bls = slct_bl_idxs.size
print('Looking at baselines redundant to ({}, {}, \'{}\')'.\
      format(*slct_red_bl, xd_pol))

### Data preprocessing

In [None]:
# Remove baselines that contain only NaN entries.
nan_bls = np.where(np.isnan(xd_data_bls).all(axis=(0, 1, 2)))[0]
flt_no_bls = no_bls - nan_bls.size
flt_data = np.delete(xd_data_bls, nan_bls, axis=3)

# Remove channels at the band extremities that contain only NaN entries.
# The filtered channel/frequency arrays default to the full arrays so that
# they are always defined, even when no edge channel needs to be dropped
# (later cells use flt_chans and flt_freqs unconditionally).
flt_chans = chans
flt_freqs = freqs
nan_chans = extrem_nans(np.isnan(flt_data).all(axis=(0, 2, 3)))
if nan_chans.size != 0:
    flt_chans = np.delete(chans, nan_chans)
    flt_freqs = np.delete(freqs, nan_chans)
    flt_data = np.delete(flt_data, nan_chans, axis=1)

In [None]:
# Pick the (day, time, baseline) slice whose spectrum has the fewest NaNs.
nan_counts = np.isnan(flt_data).sum(axis=1)  # NaN count along the freq axis
min_nan = nan_counts.argmin()
ok = np.unravel_index(min_nan, nan_counts.shape)
test_data = flt_data[ok[0], :, ok[1], ok[2]]

# Boolean NaN mask of the selected spectrum and the matching channel indices.
nans = np.isnan(test_data)
nan_chans = nans.nonzero()[0]

In [None]:
real_lab = r'$\mathfrak{Re} \; (V)$'
imag_lab = r'$\mathfrak{Im} \; (V)$'

# Real and imaginary parts of the selected spectrum, with every NaN channel
# marked by a dashed vertical line (only the first line gets a legend entry).
fig, ax = plt.subplots()
ax.scatter(flt_freqs, test_data.real, s=2, label=real_lab)
ax.scatter(flt_freqs, test_data.imag, s=2, label=imag_lab)
for i, nan_chan in enumerate(nan_chans):
    label = 'NaN chan' if i == 0 else None
    ax.axvline(flt_freqs[nan_chan], lw=1, ls='--', color='red', alpha=0.5,
               label=label)
ax.legend(prop={'size': 6})
ax.set_xlabel('Frequency')
plt.tight_layout()
plt.show()

### Unevenly sampled Fourier transform

#### Lomb-Scargle periodogram

The Lomb-Scargle periodogram can be used for unevenly spaced observations, but only works for real-valued series. Here is an example of the returned power spectrum using visibility amplitudes for some sample data.

In [None]:
# Largest delay to evaluate: half the inverse of the channel spacing.
dly_lim = 1/f_resolution/2
# Lomb-Scargle periodogram of the visibility amplitudes, using unflagged channels only.
frequency, power = LombScargle(flt_freqs[~nans], np.abs(test_data[~nans])).autopower(maximum_frequency=dly_lim)

In [None]:
# Lomb-Scargle power against delay, on a logarithmic power axis.
plt.plot(frequency, power)
plt.yscale('log')
plt.xlabel('Delay')
plt.ylabel('Lomb-Scargle power')
plt.tight_layout()
plt.show()

#### Interpolation across freqs and time

In [None]:
# Count NaNs over the freq and time axes, then take the (day, baseline) pair
# with the fewest; test_data_i is the corresponding (freqs, times) slice.
sum_nans = np.isnan(flt_data).sum(axis=(1, 2))
ok = np.unravel_index(sum_nans.argmin(), sum_nans.shape)
test_data_i = flt_data[ok[0], ..., ok[1]]

In [None]:
# Share of NaN (flagged) samples in the filtered data, in percent.
n_flagged = np.isnan(flt_data).sum()
pct_flagged = round(n_flagged / flt_data.size * 100, 3)
print(f'{pct_flagged}% of the considered data is flagged.')

In [None]:
# Cubic interpolation over the (freqs, times) slice. nan_interp2d also returns
# channel and time indices, which are removed from the result below
# (presumably rows/columns that could not be interpolated - TODO confirm).
interp_data, nan_c_idxs, nan_t_idxs = nan_interp2d(test_data_i, kind='cubic', \
                                                   rtn_nan_idxs=True)
interp_data = np.delete(interp_data, nan_c_idxs, axis=0)
interp_data = np.delete(interp_data, nan_t_idxs, axis=1)

# Channel, frequency and time-integration labels matching the trimmed array.
flt_chans_i = np.delete(flt_chans, nan_c_idxs)
flt_freqs_i = np.delete(flt_freqs, nan_c_idxs)
flt_tints_i = np.delete(np.arange(xd_times.size), nan_t_idxs)

In [None]:
# Visibility amplitudes before (top) and after (bottom) interpolation.
fig, ax = plt.subplots(nrows=2, figsize=(5, 5), sharex=True)
ax[0].plot(flt_freqs, np.abs(test_data_i), lw=1, alpha=0.5)
ax[1].plot(flt_freqs_i, np.abs(interp_data), lw=1, alpha=0.5)
ax[1].set_xlabel('Frequency')
plt.tight_layout()
plt.show()

In [None]:
# Two-sided Hann-windowed periodogram along the frequency axis (axis 0) of the
# interpolated, unfiltered data.
delay, pspec = signal.periodogram(interp_data, fs=1/f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    return_onesided=False, axis=0)

# Reorder so that the delays are monotonically increasing.
delay_sort = np.argsort(delay)
delay = delay[delay_sort]
pspec = pspec[delay_sort]

# Mean over axis 1 (time integrations) of the unfiltered ("uf") power spectra.
uf_mean_pspec = pspec.mean(axis=1)

In [None]:
# Individual power spectra with their mean overplotted in blue.
plt.figure()
plt.plot(delay, pspec, lw=1, alpha=0.5)
plt.plot(delay, uf_mean_pspec, lw=1, color='blue')
plt.yscale('log')
plt.xlabel('Delay')
plt.tight_layout()
plt.show()

#### High-pass filter with a transfer window in delay space

Here we build the filter from scratch. This is done on the evenly sampled (interpolated) data so that FFTs can be used.

##### "Ideal" high-pass filter

In [None]:
delay_cut = 1.5e-6 # to avoid 1 us hump + other foregrounds & low-order Fourier effects
# Delay axis of the FFT of the interpolated data, ordered from negative to positive.
dlys = fftshift(fftfreq(interp_data.shape[0], f_resolution))

# rectangular filter: 1 everywhere except for |delay| <= delay_cut, where it is 0
inv_rect = np.ones(interp_data.shape[0])
zero_filt = np.where(np.abs(dlys) <= delay_cut)
inv_rect[zero_filt] = 0

plt.figure()
plt.plot(dlys, inv_rect)
plt.xlabel('Delay')
plt.tight_layout()
plt.show()

In [None]:
# Inverse FFT of the rectangular window; ifftshift undoes the fftshift ordering
# of the delay axis before transforming.
ift_inv_rect = ifft(ifftshift(inv_rect))

plt.figure()
plt.plot(flt_freqs_i, ift_inv_rect.real, label=r'$\mathfrak{Re}$')
plt.plot(flt_freqs_i, ift_inv_rect.imag, label=r'$\mathfrak{Im}$')
plt.xlabel('Frequency')
plt.title('IFFT of rectangular window function', size=8)
plt.legend(prop={'size':6})
plt.tight_layout()
plt.show()

In [None]:
# High-pass filter with the rectangular window: FFT along the frequency axis,
# multiply by the window, and transform back.
ft = fftshift(fft(interp_data, axis=0), axes=0) # delay transform the data
ft_flt = ft*inv_rect[:, np.newaxis] # apply the window function to zero the low order modes
mod_data = ifft(ifftshift(ft_flt, axes=0), axis=0) # transform back to visibility

In [None]:
# Amplitude of the delay transform of the data, one curve per time integration.
plt.figure()
plt.plot(dlys, np.abs(ft), lw=1, alpha=0.5)
plt.yscale('log')
plt.xlabel('Delay')
plt.title('FFT of data', size=8)
plt.tight_layout()
plt.show()

In [None]:
# Amplitude of the delay transform after the rectangular window has been
# applied. An explicit figure is opened, as in the other plotting cells, so
# that the curves are never drawn onto a previously active figure.
plt.figure()
plt.plot(dlys, np.abs(ft_flt), lw=1, alpha=0.5)
plt.yscale('log')
plt.xlabel('Delay')
plt.title('FFT of data with window applied', size=8)
plt.tight_layout()
plt.show()

In [None]:
# Amplitude of the high-pass filtered visibilities (rectangular window), with
# the amplitude of their mean over axis 1 in blue. The abscissa is the
# frequency array flt_freqs_i, not a channel index, so it is labelled
# 'Frequency' like the other frequency-domain plots.
plt.figure()
plt.plot(flt_freqs_i, np.abs(mod_data), lw=1, alpha=0.5)
plt.plot(flt_freqs_i, np.abs(mod_data.mean(axis=1)), lw=1.5, color='blue')
plt.xlabel('Frequency')
plt.title('High-pass filtered visibilities', size=8)
plt.tight_layout()
plt.show()

The rectangular window function used here has sharp edges, which imparts structure back in the visibility domain as seen in the IFFT plot. A smoother function is required.

##### Butterworth transfer window

Also known as the maximally flat magnitude filter - it has an as flat as possible frequency response in the passband.

The $n$th-order Butterworth low-pass filter has gain:

$$ |H(j w)| = \frac{H_0}{\sqrt{1 + \left(\frac{w}{w_c}\right)^{2n}}} $$

where $n$ is the order of the filter, $w_c$ is the cutoff frequency and $H_0$ is the gain at zero frequency. The high-pass filter used here is obtained by replacing $w / w_c$ with $w_c / w$, so that $H_0$ becomes the gain at high frequency. As $n \rightarrow \infty$, the high-pass gain becomes a rectangle function with frequencies below $w_c$ completely suppressed.

In [None]:
# 5th-order digital Butterworth high-pass filter. The cutoff reuses delay_cut
# (the threshold of the rectangular filter) instead of repeating the literal,
# so that both filters always share the same cutoff.
sos = signal.butter(5, delay_cut, 'high', analog=False, output='sos', fs=1/f_resolution)

# Response of the filter at as many points as there are channels in the data.
w, h = signal.sosfreqz(sos, worN=interp_data.shape[0], fs=1/f_resolution)

In [None]:
# Filter response: real/imaginary parts (left), amplitude/phase (right).
fig, ax = plt.subplots(ncols=2)
ax[0].plot(h.real, label=r'$\mathfrak{Re}$')
ax[0].plot(h.imag, label=r'$\mathfrak{Im}$')
ax[1].plot(np.abs(h), label='Amp')
ax[1].plot(np.angle(h), label='Phase')
ax[0].legend(prop={'size':6})
ax[1].legend(prop={'size':6})
plt.tight_layout()
plt.show()

We create a transfer window in both the positive and negative delay regions.

In [None]:
# Mirror the response about zero: the flipped copy (minus its last element,
# i.e. the first element of w/h) is placed in front of the original.
cct_w = np.concatenate((-1*np.flip(w)[:-1], w))
cct_h = np.concatenate((np.flip(h)[:-1], h))

fig, ax = plt.subplots(ncols=2)
ax[0].plot(cct_w, cct_h.real, label=r'$\mathfrak{Re}$')
ax[0].plot(cct_w, cct_h.imag, label=r'$\mathfrak{Im}$')
ax[1].plot(cct_w, np.abs(cct_h), label='Amp')
ax[1].plot(cct_w, np.angle(cct_h), label='Phase')
ax[0].set_xlabel('Delay')
ax[1].set_xlabel('Delay')
ax[0].legend(prop={'size':6})
ax[1].legend(prop={'size':6})
plt.tight_layout()
plt.show()

In [None]:
# The transfer window must have the same length as the frequency axis of the
# data and line up bin-for-bin with the (fftshift-ordered) delay axis of its
# FFT. Decimating and mirroring w/h only gives the right length for an odd
# number of channels and shifts each half by one bin towards zero delay, so
# the response is instead evaluated directly at the absolute value of every
# delay bin. As before, the same response is used on the positive and negative
# sides, and the zero-delay bin is explicitly set to zero.
cct_w_mod = fftshift(fftfreq(interp_data.shape[0], f_resolution))
cct_h_mod = signal.sosfreqz(sos, worN=np.abs(cct_w_mod), fs=1/f_resolution)[1]
cct_h_mod[cct_w_mod == 0] = 0

In [None]:
# Inverse FFT of the transfer window. The window is stored in fftshift order
# (zero delay in the centre), so ifftshift - not fftshift - is required to move
# zero delay back to index 0; the two differ for odd lengths. This matches how
# the filtered data are transformed back to the visibility domain.
ift_cct_h_mod = ifft(ifftshift(cct_h_mod))

plt.figure()
plt.plot(flt_freqs_i, ift_cct_h_mod.real, label=r'$\mathfrak{Re}$')
plt.plot(flt_freqs_i, ift_cct_h_mod.imag, label=r'$\mathfrak{Im}$')
plt.xlabel('Frequency')
plt.legend(prop={'size':6})
plt.tight_layout()
plt.show()

In [None]:
# High-pass filter with the Butterworth transfer window: FFT along the
# frequency axis, multiply by the window, and transform back.
ft = fftshift(fft(interp_data, axis=0), axes=0)
ft_flt = ft * cct_h_mod[:, np.newaxis]
mod_data = ifft(ifftshift(ft_flt, axes=0), axis=0)

In [None]:
# Amplitude of the high-pass filtered visibilities (Butterworth window), with
# the amplitude of their mean over axis 1 in blue. The abscissa is the
# frequency array flt_freqs_i, not a channel index, so it is labelled
# 'Frequency' like the other frequency-domain plots.
plt.figure()
plt.plot(flt_freqs_i, np.abs(mod_data), lw=1, alpha=0.5)
plt.plot(flt_freqs_i, np.abs(mod_data.mean(axis=1)), lw=1.5, color='blue')
plt.xlabel('Frequency')
plt.title('High-pass filtered visibilities', size=8)
plt.tight_layout()
plt.show()

In [None]:
# Amplitude of the delay transform after the Butterworth window has been applied.
plt.figure()
plt.plot(dlys, np.abs(ft_flt), lw=1, alpha=0.5)
plt.yscale('log')
plt.xlabel('Delay')
plt.title('FFT of data with window applied', size=8)
plt.tight_layout()
plt.show()

### Robust statistics on delay filtered data

#### Compute location estimates in visibility domain of high-pass filtered data, then take periodogram

In [None]:
flt_no_chans = interp_data.shape[0]

mad_sigma = 4.0 # sigma threshold for MAD-clipping, default is 4

# One complex location estimate per channel for each estimator.
gmed_res = np.empty(flt_no_chans, dtype=complex)
hmean_res = np.empty(flt_no_chans, dtype=complex)

# The geometric median of each channel is passed on as the initial guess for
# the next channel (None for the first one).
gmed_t = None
for chan in range(flt_no_chans):
    data_t = mod_data[chan, :]
    gmed_t = geometric_median(data_t, init_guess=gmed_t, keep_res=True)
    hmean_t = rsc_mean(data_t, sigma=mad_sigma)
    gmed_res[chan], hmean_res[chan] = gmed_t, hmean_t

In [None]:
# Two-sided Hann-windowed periodogram of the geometric median estimates.
gmed_delay, gmed_pspec = signal.periodogram(gmed_res, fs=1/f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    return_onesided=False)

# Reorder so that the delays are monotonically increasing.
delay_sort = np.argsort(gmed_delay)
gmed_delay = gmed_delay[delay_sort]
gmed_pspec = gmed_pspec[delay_sort]

# Same periodogram for the sigma-clipped mean estimates.
hmean_delay, hmean_pspec = signal.periodogram(hmean_res, fs=1./f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    return_onesided=False)

delay_sort = np.argsort(hmean_delay)
hmean_delay = hmean_delay[delay_sort]
hmean_pspec = hmean_pspec[delay_sort]

In [None]:
# Power spectra of the two location estimates, each plotted against its own
# delay axis, with the filter cut-off marked by dashed red lines.
plt.figure()
plt.plot(gmed_delay, gmed_pspec, label='Geometric Median', alpha=0.8)
plt.plot(hmean_delay, hmean_pspec, label='HERA Mean', alpha=0.8)
plt.axvline(-delay_cut, ls='--', color='red', alpha=0.5, label='Cut off')
plt.axvline(delay_cut, ls='--', color='red', alpha=0.5)
plt.xlabel('Delay')
plt.yscale('log')
plt.legend(prop={'size': 6})
plt.tight_layout()
plt.show()

In [None]:
# Compare the PS results from the geometric median and HERA mean location
# estimates: their difference is weighted by the amplitude of the transfer
# window, which down-weights the delays suppressed by the filter.
resid = (gmed_pspec - hmean_pspec)*np.abs(cct_h_mod)
# The estimator compared here is the geometric median (not the geometric mean).
print('Mean normalized adjusted residual between geometric median and HERA mean: {}.'.\
      format(round(resid.mean(), 7)))

plt.figure()
plt.plot(gmed_delay, resid)
plt.axvline(-delay_cut, ls='--', color='red', alpha=0.5, label='Cut off')
plt.axvline(delay_cut, ls='--', color='red', alpha=0.5)
plt.xlabel('Delay')
# Show the 'Cut off' label, as in the power spectrum comparison plot.
plt.legend(prop={'size': 6})
plt.tight_layout()
plt.show()

In [None]:
# Delay bins outside the cut-off, i.e. where |delay| > delay_cut.
hpf_region = np.logical_not(np.abs(dlys) <= delay_cut)

# Mean difference, over those bins, between the PS of the geometric median
# estimates (filtered data) and the mean PS of the unfiltered data.
cps_resid = np.mean(gmed_pspec[hpf_region] - uf_mean_pspec[hpf_region])
print('Residual between HPF and UF PS (robust location averaging in visibility domain): '\
      '{:.5e}'.format(cps_resid))

##### Compute CPS across all times for comparison

In [None]:
# All ordered pairs of distinct time-integration indices, split into the list
# of first members and the list of second members.
n_tints = interp_data.shape[1]
tint_pairs = list(itertools.permutations(np.arange(n_tints), r=2))
tint1 = [first for first, _second in tint_pairs]
tint2 = [second for _first, second in tint_pairs]

###### HPF data

In [None]:
# Two-sided Hann-windowed cross power spectrum, along the frequency axis, for
# every ordered pair of time integrations of the high-pass filtered data.
delay, pspec = signal.csd(mod_data[:, tint1], mod_data[:, tint2], fs=1/f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    return_onesided=False, axis=0)

# Reorder so that the delays are monotonically increasing.
delay_sort = np.argsort(delay)
delay = delay[delay_sort]
pspec = pspec[delay_sort, :]

# Amplitude of the cross power spectrum averaged over all pairs.
hpf_mean_cpspec = np.abs(pspec.mean(axis=1))

In [None]:
# y-limits: the decades that bracket the spectrum outside the cut-off, so that
# the suppressed low-delay bins do not set the plotting range.
min_log = np.log10(np.min(hpf_mean_cpspec[hpf_region]))
ymin = 10**(np.floor(min_log))

max_log = np.log10(np.max(hpf_mean_cpspec[hpf_region]))
ymax = 10**(np.ceil(max_log))

plt.figure()
plt.plot(delay, hpf_mean_cpspec, alpha=0.8)
plt.ylim(ymin, ymax)
plt.yscale('log')
plt.xlabel('Delay')
plt.tight_layout()
plt.show()

###### Unfiltered data

In [None]:
# Same cross power spectrum computation as for the filtered data, applied to
# the unfiltered (interpolated) data.
delay, pspec = signal.csd(interp_data[:, tint1], interp_data[:, tint2], fs=1/f_resolution, \
    window='hann', scaling='spectrum', nfft=None, detrend=False, \
    return_onesided=False, axis=0)

# Reorder so that the delays are monotonically increasing.
delay_sort = np.argsort(delay)
delay = delay[delay_sort]
pspec = pspec[delay_sort, :]

# Amplitude of the cross power spectrum averaged over all pairs.
uf_mean_cpspec = np.abs(pspec.mean(axis=1))

In [None]:
# Mean cross power spectrum of the unfiltered data.
plt.figure()
plt.plot(delay, uf_mean_cpspec)
plt.yscale('log')
plt.xlabel('Delay')
plt.tight_layout()
plt.show()

In [None]:
# Mean difference, outside the cut-off, between the filtered and unfiltered mean CPS.
cps_resid = np.mean(hpf_mean_cpspec[hpf_region] - uf_mean_cpspec[hpf_region])
print('Residual between HPF and UF mean CPS: {:.5e}'.format(cps_resid))

In [None]:
# Mean difference, outside the cut-off, between the PS of the geometric median
# estimates and the mean CPS, both computed from the filtered data.
cps_resid = np.mean(gmed_pspec[hpf_region] - hpf_mean_cpspec[hpf_region])
print('Residual between HPF PS (geometric median estimates) and HPF mean CPS: {:.5e}'.format(cps_resid))

From these residuals, we conclude:
 - HPF decreases the PS for high delays
 - CPS approach is better than that which uses location estimates in visibility space

### FFT with unevenly sampled data

#### DFT

In [None]:
# Trim the un-interpolated (freqs, times) slice: first drop the channels, then
# the time integrations, returned by extrem_nans for the all-NaN masks
# (presumably the all-NaN rows/columns at the edges - TODO confirm).
u_test_data = np.delete(test_data_i, extrem_nans(np.isnan(test_data_i).all(axis=1)), axis=0)
u_test_data = np.delete(u_test_data, extrem_nans(np.isnan(test_data_i).all(axis=0)), axis=1)

In [None]:
u_test_data_t = u_test_data[:, 0] # sample tint

# Real and imaginary parts of the sample spectrum; NaN samples leave gaps.
# NOTE(review): assumes flt_freqs_i has the same length as the trimmed channel
# axis of u_test_data - confirm.
plt.figure()
plt.plot(flt_freqs_i, u_test_data_t.real, label=real_lab)
plt.plot(flt_freqs_i, u_test_data_t.imag, label=imag_lab)
plt.xlabel('Frequency')
plt.legend(prop={'size': 6})
plt.tight_layout()
plt.show()

In [None]:
# accelerate with JAX?
def ndft(x, f, N=None, f_res=None):
    """Non-equispaced discrete Fourier transform.

    Parameters
    ----------
    x : ndarray
        1-D sample positions. If any position lies outside [0, 1], all of
        them are linearly rescaled so that they span [-0.5, 0.5].
    f : ndarray
        Sample values at the positions ``x``.
    N : int, optional
        Number of output modes; defaults to the number of samples.
    f_res : float, optional
        Sample spacing used to build the returned frequency axis; defaults
        to 1.

    Returns
    -------
    frqs : ndarray
        Frequency axis ``fftshift(fftfreq(N, f_res))``.
    ndarray
        Sum over samples of ``f * exp(-2j*pi*k*x)`` for the integer modes
        ``k = -(N//2), ..., N - 1 - N//2``.
    """
    lo, hi = x.min(), x.max()
    if lo < 0 or hi > 1:
        x = np.interp(x, (lo, hi), (-0.5, 0.5))

    n_modes = x.size if N is None else N
    spacing = 1 if f_res is None else f_res

    # Integer mode numbers centred on zero, matching the fftshift ordering.
    modes = np.arange(n_modes) - (n_modes // 2)
    kernel = np.exp(-2j * np.pi * modes * x[:, np.newaxis])
    return fftshift(fftfreq(n_modes, spacing)), np.dot(f, kernel)


def nidft(x, f, N=None, f_res=None, f_start=None):
    """Non-equispaced inverse discrete Fourier transform.

    Parameters
    ----------
    x : ndarray
        1-D sample positions. If any position lies outside [0, 1], all of
        them are linearly rescaled so that they span [-0.5, 0.5].
    f : ndarray
        Sample values at the positions ``x``.
    N : int, optional
        Number of output points; defaults to the number of samples.
    f_res : float, optional
        Sample spacing used to build the returned axis; defaults to 1.
    f_start : float, optional
        If given, the returned axis is linearly remapped so that it starts at
        ``f_start`` and ends at ``f_start`` plus twice its original maximum.

    Returns
    -------
    frqs : ndarray
        Output axis, ``fftshift(fftfreq(N, f_res))`` (optionally remapped).
    ndarray
        Sum over samples of ``f * exp(2j*pi*k*x)`` divided by ``N``, for the
        integers ``k = -(N//2), ..., N - 1 - N//2``.
    """
    lo, hi = x.min(), x.max()
    if lo < 0 or hi > 1:
        x = np.interp(x, (lo, hi), (-0.5, 0.5))

    n_modes = x.size if N is None else N
    spacing = 1 if f_res is None else f_res

    frqs = fftshift(fftfreq(n_modes, spacing))
    if f_start is not None:
        # Shift the symmetric axis so that it begins at f_start.
        frqs = np.interp(frqs, (frqs.min(), frqs.max()), (f_start, f_start+frqs.max()*2))

    # Integer mode numbers centred on zero, matching the fftshift ordering.
    modes = np.arange(n_modes) - (n_modes // 2)
    kernel = np.exp(2j * np.pi * modes * x[:, np.newaxis])
    return frqs, np.dot(f, kernel) / n_modes

In [None]:
# NDFT of the sample spectrum using only its non-NaN samples, evaluated on as
# many modes as there are channels (NaN ones included).
# NOTE(review): flt_nan presumably drops the NaN entries so that both inputs
# have the same length - confirm.
ndft_dlys, ndft_t = ndft(flt_freqs_i[~np.isnan(u_test_data_t)], flt_nan(u_test_data_t), \
                         N=u_test_data_t.size, f_res=f_resolution)

In [None]:
# Amplitude of the NDFT against delay.
plt.figure()
plt.plot(ndft_dlys, np.abs(ndft_t))
plt.yscale('log')
plt.xlabel('Delay')
plt.tight_layout()
plt.show()

In [None]:
# Transform back from delay space; the output axis is built from the delay
# spacing and remapped to start at the first frequency of the trimmed band.
nidft_freqs, nidft_t = nidft(ndft_dlys, ndft_t, N=None, \
                             f_res=np.median(np.ediff1d(ndft_dlys)), f_start=flt_freqs_i[0])

In [None]:
# Real and imaginary parts of the NIDFT output. The abscissa is the frequency
# axis returned by nidft, so it is labelled 'Frequency' (not 'Delay').
plt.figure()
plt.plot(nidft_freqs, nidft_t.real, label=real_lab)
plt.plot(nidft_freqs, nidft_t.imag, label=imag_lab)
for i, nan_chan in enumerate(nan_chans):
    # Label only the first line so that the legend has a single 'NaN chan' entry.
    plt.axvline(flt_freqs[nan_chan], lw=1.5, ls='--', color='red', alpha=0.5,
                label='NaN chan' if i == 0 else None)
plt.xlabel('Frequency')
plt.legend(prop={'size': 6})
plt.tight_layout()
plt.show()

In [None]:
# check that NDFT and NIDFT work correctly
# Butterworth transfer window after DFT