<center><strong><font size=+3>High-pass filtering of HERA data with hera_cal</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

This notebook demonstrates high-pass filtering of HERA visibilities using the functions in [hera_cal](https://github.com/HERA-Team/hera_cal) and [uvtools](https://github.com/HERA-Team/uvtools). We use either the DAYENU or the CLEAN filter, which removes smooth foregrounds for intensity mapping power spectra ([Ewall-Wice et al. 2021](https://ui.adsabs.harvard.edu/abs/2021MNRAS.500.5195E/abstract)).

In [None]:
import multiprocess as multiprocessing
import os

import matplotlib as mpl
import numpy as np
from matplotlib import pyplot as plt
from scipy import fft, signal

# hera_cal and uvtools are hard requirements of this notebook. Only ImportError
# is intercepted so that unrelated failures (e.g. KeyboardInterrupt, or an
# error raised while one of the packages initialises) are not masked, and the
# original exception is chained for easier debugging.
try:
    import hera_cal
    import uvtools
except ImportError as err:
    raise ImportError('Notebook requires the hera_cal and uvtools packages.') from err

from robstat.ml import extrem_nans
from robstat.utils import DATAPATH

In [None]:
# IPython magic: render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# Figure defaults: high-resolution, compact canvases
mpl.rcParams.update({
    'figure.dpi': 175,
    'figure.figsize': (5, 3),
})

# Text is typeset by LaTeX in a serif ('cm') face; amssymb and amsmath are
# loaded in the preamble for the math used in plot labels
font_opts = {'family': 'serif', 'serif': ['cm']}
mpl.rc('font', **font_opts)
mpl.rc('text', usetex=True)
mpl.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

### Load visibility data

In [None]:
# Dataset to filter, given relative to the robstat data directory.
# Alternative dataset: 'xd_vis_extd_rph.npz'
vis_rel_path = 'lstb_no_avg/idr2_lstb_14m_ee_1.40949.npz'
xd_vis_file = os.path.join(DATAPATH, vis_rel_path)

# np.load on an .npz returns a lazy archive; arrays are read when indexed
sample_xd_data = np.load(xd_vis_file)

In [None]:
# Visibility cube with dimensions (days, freqs, times, bls) and its bookkeeping
xd_data = sample_xd_data['data']
xd_redg = sample_xd_data['redg']
xd_pol = sample_xd_data['pol'].item()
JDs = sample_xd_data['JDs']

is_lstb_no_avg = 'lstb_no_avg' in xd_vis_file
if not is_lstb_no_avg:
    # flags and the frequency/channel axes are read from the archive
    xd_flags = sample_xd_data['flags']
    freqs = sample_xd_data['freqs']
    chans = sample_xd_data['chans']
else:
    # for these files NaN samples are treated as flagged, channels are simply
    # enumerated, and the frequency axis is rebuilt as 1024 evenly spaced
    # points covering [1e8, 2e8) (presumably Hz - TODO confirm)
    xd_flags = np.isnan(xd_data)
    chans = np.arange(xd_data.shape[1])
    freqs = np.linspace(1e8, 2e8, 1025)[:-1]

# channel width, taken as the median spacing of the frequency axis
f_resolution = np.median(np.ediff1d(freqs))

no_days, no_tints = xd_data.shape[0], xd_data.shape[2]
no_chans = chans.size

In [None]:
bl_grp = 0  # redundant-baseline group to examine (0th group only)

# Column 0 of xd_redg holds the group index; the remaining columns describe the
# baseline (presumably the antenna pair - see the printout below)
in_group = xd_redg[:, 0] == bl_grp
slct_bl_idxs = np.flatnonzero(in_group)
no_bls = slct_bl_idxs.size

# restrict data and flags to the baselines of the chosen group
xd_data_bls = xd_data[..., slct_bl_idxs]
flags = xd_flags[..., slct_bl_idxs]

# the first baseline of the group is used as its representative
slct_red_bl = xd_redg[slct_bl_idxs[0], 1:]
print("Looking at baselines redundant to ({}, {}, '{}')".format(*slct_red_bl, xd_pol))

### Example on test data

#### Format and select test data

In [None]:
# First day and first baseline of the selected group: (freqs, times) arrays.
# The data are copied so that later in-place edits leave xd_data_bls intact.
day0_bl0_data = xd_data_bls[0, ..., 0].copy()
day0_bl0_flags = flags[0, ..., 0]

# keep the time integration that has the fewest flagged channels
flag_count_per_tint = day0_bl0_flags.astype(float).sum(axis=0)
min_nan_idx = np.argmin(flag_count_per_tint)
test_data = day0_bl0_data[:, min_nan_idx]
test_flags = day0_bl0_flags[:, min_nan_idx]

# boolean mask of valid (unflagged) channels
v = np.logical_not(test_flags)

In [None]:
# Legend labels for the real and imaginary visibility components
real_lab = r'$\mathfrak{Re} \; (V)$'
imag_lab = r'$\mathfrak{Im} \; (V)$'

fig, ax = plt.subplots()
ax.scatter(freqs[v], test_data.real[v], s=0.5, alpha=0.5, label=real_lab)
ax.scatter(freqs[v], test_data.imag[v], s=0.5, alpha=0.5, label=imag_lab)

# Mark every flagged channel with a vertical line; only the first line carries
# a legend entry. freqs is indexed with the mask directly (freqs and chans share
# the same channel axis) so that axvline receives a scalar position rather than
# the length-1 array produced by an np.where lookup.
for i, nan_freq in enumerate(freqs[~v]):
    ax.axvline(nan_freq, lw=0.1, ls='--', color='red', alpha=0.5,
               label='NaN chan' if i == 0 else None)

ax.legend(loc='upper right', prop={'size': 6})
ax.set_xlabel('Frequency')
plt.tight_layout()
plt.show()

#### Apply high-pass Fourier filter

In [None]:
# parameters
filter_centers = [0.] # center of rectangular fourier regions to filter
filter_half_widths = [1e-6] # half-width of rectangular fourier regions to filter
mode = 'clean'

In [None]:
# The filter input must not contain NaNs: zero them in place and give the
# flagged channels zero weight instead
nan_samples = np.isnan(test_data)
test_data[nan_samples] = 0.
wgts = np.logical_not(test_flags).astype(float)

# every mode other than 'clean' is also passed max_contiguous_edge_flags
filter_kwargs = {} if mode == 'clean' else {'max_contiguous_edge_flags': no_chans}

# returns the model, the residual (high-pass filtered data) and an info object
d_mdl, d_res, info = uvtools.dspec.fourier_filter(
    freqs, test_data, wgts, filter_centers, filter_half_widths, mode,
    filter_dims=1, skip_wgt=0., zero_residual_flags=True, **filter_kwargs)

In [None]:
# NaN-masked copies of the filter outputs, so flagged channels show as gaps.
# Multiplying by NaN (rather than assigning NaN) turns both the real and the
# imaginary part of a complex entry into NaN.
d_mdl_n = d_mdl.copy()
d_res_n = d_res.copy()
d_mdl_n[~v] *= np.nan
d_res_n[~v] *= np.nan

fig, ax = plt.subplots()
ax.scatter(freqs[v], test_data[v].real, s=0.5, label=real_lab, alpha=0.5)
ax.scatter(freqs[v], test_data[v].imag, s=0.5, label=imag_lab, alpha=0.5)
ax.plot(freqs, d_mdl_n.real, lw=1)
ax.plot(freqs, d_mdl_n.imag, lw=1)

# Mark every flagged channel with a vertical line; only the first line carries
# a legend entry. freqs is indexed with the mask directly (freqs and chans share
# the same channel axis) so that axvline receives a scalar position rather than
# the length-1 array produced by an np.where lookup.
for i, nan_freq in enumerate(freqs[~v]):
    ax.axvline(nan_freq, lw=0.1, ls='--', color='red', alpha=0.5,
               label='NaN chan' if i == 0 else None)

ax.legend(loc='upper right', prop={'size': 6})
ax.set_xlabel('Frequency')
ax.set_title('Model visibilities', size=8)
plt.tight_layout()
plt.show()

In [None]:
# Residual (high-pass filtered) visibilities, with flagged channels masked out
fig, ax = plt.subplots()
for component, comp_label in ((d_res_n.real, real_lab), (d_res_n.imag, imag_lab)):
    ax.plot(freqs, component, label=comp_label, alpha=0.7, lw=1)
ax.set_xlabel('Frequency')
ax.set_title('HPF visibilities', size=8)
ax.legend(loc='upper right', prop={'size': 6})
plt.tight_layout()
plt.show()

#### Choosing a gap free band

In [None]:
# Split the indices of the valid channels into runs of consecutive integers
valid_idxs = np.flatnonzero(v)
run_breaks = np.flatnonzero(np.diff(valid_idxs) != 1) + 1
gc = np.split(valid_idxs, run_breaks)

# longest contiguous run (first one on ties), with its first channel dropped
lgap = max(gc, key=len)[1:]
band_chans = chans[lgap]
print('Looking at channels {}-{}'.format(band_chans[0], band_chans[-1]))

In [None]:
# Model versus data, restricted to the gap-free band
band_freqs = freqs[lgap]
band_mdl = d_mdl[lgap]
band_data = test_data[lgap]

fig, ax = plt.subplots()
for part in ('real', 'imag'):
    ax.plot(band_freqs, getattr(band_mdl, part))
for part, part_label in (('real', real_lab), ('imag', imag_lab)):
    ax.scatter(band_freqs, getattr(band_data, part), s=1, alpha=0.7, label=part_label)
ax.set_title('Model visibilities - selected range', size=8)
ax.set_xlabel('Frequency')
ax.legend(loc='best', prop={'size': 6})
plt.tight_layout()
plt.show()

In [None]:
# FFT amplitude of the data and of the filtered residual over the gap-free
# band, on a shifted (monotonically increasing) delay axis
dlys = fft.fftshift(fft.fftfreq(lgap.size, f_resolution))

fig, ax = plt.subplots()
for series, series_label in ((test_data[lgap], 'Data'), (d_res[lgap], 'HPF')):
    spectrum = fft.fftshift(fft.fft(series))
    ax.plot(dlys, np.abs(spectrum), alpha=0.8, label=series_label)
ax.set_yscale('log')
ax.set_xlabel('Delay')
ax.set_ylabel('FFT')
ax.legend(loc='best', prop={'size': 6})
plt.tight_layout()
plt.show()

We note that at low delays, FFTs have a high variance - power spectra are better estimators of power.

In [None]:
# Two-sided Hann-windowed periodograms of the data and of the filtered
# residual over the gap-free band, each re-ordered by increasing delay
periodogram_kwargs = dict(fs=1/f_resolution, window='hann', scaling='spectrum',
                          nfft=None, detrend=False, return_onesided=False)

sorted_pspecs = []
for series in (test_data[lgap], d_res[lgap]):
    dlys, pspec = signal.periodogram(series, **periodogram_kwargs)
    delay_sort = np.argsort(dlys)
    dlys = dlys[delay_sort]
    sorted_pspecs.append(pspec[delay_sort])

td_pspec, dr_pspec = sorted_pspecs

In [None]:
# Power spectra before and after high-pass filtering
fig, ax = plt.subplots()
for spectrum, spectrum_label in ((td_pspec, 'Data'), (dr_pspec, 'HPF')):
    ax.plot(dlys, spectrum, alpha=0.8, label=spectrum_label)
ax.set_yscale('log')
ax.set_xlabel('Delay')
ax.set_ylabel('Power spectrum')
ax.legend(loc='best', prop={'size': 6})
plt.tight_layout()
plt.show()

### HPF 2D array

In [None]:
# with trimming of flagged edges

sidxs = (0, 0) # sample indices for example case: (day, baseline)

data_2d = xd_data[sidxs[0], ..., sidxs[1]]

# extrem_nans receives the mask of channels that are NaN at every time; its
# output is split wherever consecutive entries jump by more than 1, giving a
# leading and a trailing run of indices.
# NOTE(review): the two-way unpacking assumes fully flagged channels exist at
# both band edges - confirm against extrem_nans.
ex_nans = extrem_nans(np.isnan(data_2d).all(axis=1))
s_idxs, e_idxs = np.split(ex_nans, np.where(np.ediff1d(ex_nans) > 1)[0]+1)
s = s_idxs.max() + 1 # first channel after the leading run
e = e_idxs.min() # first channel of the trailing run (exclusive end)

# trimmed inputs: flagged samples are zeroed and carry zero weight.
# (weights are built once, from the trimmed flags only; the former
# whole-cube weights array was overwritten before ever being used)
data_2d_tr = data_2d[s:e, :].copy()
flags_2d = xd_flags[sidxs[0], s:e, :, sidxs[1]]
data_2d_tr[flags_2d] = 0.
wgts = np.logical_not(flags_2d).astype(float)
freqs_tr = freqs[s:e]

# filter along axis 0 (frequency) of the (freqs, times) array
d_mdl_tr, d_res_tr, info = uvtools.dspec.fourier_filter(freqs_tr, data_2d_tr, wgts, \
    filter_centers, filter_half_widths, mode, filter_dims=0, skip_wgt=0., \
    zero_residual_flags=True, **filter_kwargs)

# blank the flagged samples in the outputs
d_mdl_tr[flags_2d] *= np.nan
d_res_tr[flags_2d] *= np.nan

# embed the trimmed results in all-NaN arrays of the original shape
d_mdl = np.empty_like(data_2d)*np.nan
d_res = d_mdl.copy()
d_mdl[s:e, :] = d_mdl_tr
d_res[s:e, :] = d_res_tr

In [None]:
# Amplitude of the data, the filter model and the residual, side by side
fig, ax = plt.subplots(ncols=3)
for panel, arr in zip(ax, (data_2d, d_mdl, d_res)):
    panel.imshow(np.abs(arr), aspect='auto', interpolation='none')
# the panels share the same vertical axis, so only the first keeps tick labels
for panel in ax[1:]:
    panel.yaxis.set_ticklabels([])
plt.tight_layout()
plt.show()

### High-pass filter entire dataset & save

In [None]:
mp = True # turn on multiprocessing

# Output file sits next to the input, with an '_hpf' suffix before the extension.
# xd_vis_file already contains DATAPATH, so it is not joined again (doing so
# would duplicate the prefix whenever DATAPATH is a relative path), and only
# the trailing extension is rewritten.
hpf_vis_file = os.path.splitext(xd_vis_file)[0] + '_hpf.npz'

# The filtered cube is cached on disk: compute and save it only if missing
if not os.path.exists(hpf_vis_file):

    def bl_iter(bl):
        """High-pass filter every day of one baseline of the selected group.

        Parameters
        ----------
        bl : int
            Index along the last axis of ``xd_data_bls`` / ``flags``.

        Returns
        -------
        ndarray of complex, shape (no_days, no_chans, no_tints, 1)
            Filter residuals; NaN for fully flagged days, for the trimmed
            band edges and for flagged samples.
        """
        hpf_data_d = np.empty((no_days, no_chans, no_tints), dtype=complex)
        for day in range(no_days):
            data = xd_data_bls[day, ..., bl]
            flgs = flags[day, ..., bl]

            if flgs.all():
                # nothing to filter on a fully flagged day
                d_res_d = np.empty_like(data) * np.nan
            else:
                # trim the runs of fully-NaN channels at the band edges.
                # NOTE(review): the two-way unpacking assumes such channels
                # exist at both edges - confirm against extrem_nans.
                ex_nans = extrem_nans(np.isnan(data).all(axis=1))
                s_idxs, e_idxs = np.split(ex_nans, np.where(np.ediff1d(ex_nans) > 1)[0]+1)
                s = s_idxs.max() + 1
                e = e_idxs.min()

                # flagged samples are zeroed and carry zero weight
                data_tr = data[s:e, :].copy()
                flgs_tr = flgs[s:e, :]
                data_tr[flgs_tr] = 0.
                wgts = np.logical_not(flgs_tr).astype(float)
                freqs_tr = freqs[s:e]

                _, d_res_tr, info = uvtools.dspec.fourier_filter(freqs_tr, data_tr, wgts, filter_centers, \
                    filter_half_widths, mode, filter_dims=0, skip_wgt=0., zero_residual_flags=True, \
                    **filter_kwargs)

                d_res_tr[flgs_tr] *= np.nan

                # embed the trimmed residual in an all-NaN array of full size
                d_res_d = np.empty_like(data)*np.nan
                d_res_d[s:e, :] = d_res_tr

            hpf_data_d[day, ...] = d_res_d

        # trailing axis lets the per-baseline results be concatenated
        return hpf_data_d[..., np.newaxis]

    if mp:
        # context manager guarantees the workers are cleaned up even if map raises
        with multiprocessing.Pool(min(multiprocessing.cpu_count(), no_bls)) as m_pool:
            pool_res = m_pool.map(bl_iter, range(no_bls))
    else:
        pool_res = list(map(bl_iter, range(no_bls)))

    hpf_data = np.concatenate(pool_res, axis=3)

    # make sure every flagged sample is NaN in the output
    hpf_data[flags] *= np.nan

    # carry over every other array of the input archive as metadata
    keys = list(sample_xd_data.keys())
    keys.remove('data')
    antpos_in = 'antpos' in keys
    if antpos_in:
        keys.remove('antpos')
    metadata = {k: sample_xd_data[k] for k in keys}
    if antpos_in:
        # 'antpos' is re-read with pickling enabled and unwrapped with .item().
        # NOTE: allow_pickle=True can execute arbitrary code - only use with
        # trusted data files.
        metadata['antpos'] = np.load(xd_vis_file, allow_pickle=True)['antpos'].item()

    np.savez(hpf_vis_file, data=hpf_data, **metadata)

else:
    hpf_data = np.load(hpf_vis_file)['data']

In [None]:
# look at the PS of some sample HPF data:
# first day and first baseline, restricted to the gap-free channels
d = hpf_data[0, lgap, :, 0]
nan_entries = np.isnan(d)
d[nan_entries] = 0  # the periodogram input is made NaN-free

# two-sided Hann-windowed periodogram along the frequency axis (axis 0)
periodogram_kwargs = dict(fs=1/f_resolution, window='hann', scaling='spectrum',
                          nfft=None, detrend=False, return_onesided=False, axis=0)
dlys, pspec = signal.periodogram(d, **periodogram_kwargs)

# re-order by increasing delay
delay_sort = np.argsort(dlys)
dlys, dr_pspec = dlys[delay_sort], pspec[delay_sort]

In [None]:
# Power spectrum of the sample high-pass filtered data
fig, ax = plt.subplots()
ax.plot(dlys, np.abs(dr_pspec), alpha=0.8, label='HPF')
ax.set_yscale('log')
ax.set(xlabel='Delay', ylabel='Power spectrum')
plt.tight_layout()
plt.show()