<center><strong><font size=+3>Wavelet Power Spectrum Analysis in Napari</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

### View CWT products in napari

In [None]:
import os

import numpy as np
from astropy.stats import mad_std, sigma_clip
from matplotlib import pyplot as plt
from matplotlib.colors import ListedColormap, LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable

from scipy import signal

import pywt
import scaleogram as scg

from hera_cal.io import HERAData

In [None]:
# Typeset all figure text with LaTeX in a Computer Modern serif font; the AMS
# packages provide commands such as \mathfrak used in axis labels further down.
from matplotlib import rc

rc('font', family='serif', serif=['cm'])
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

In [None]:
# render matplotlib figures inline in the notebook output
%matplotlib inline

In [None]:
# CWT power products of the two frequency bands (Band 1 and Band 2)
npz_f1, npz_f2 = (np.load(fn) for fn in ('cwt_power_b1.npz', 'cwt_power_b2.npz'))
data1, data2 = npz_f1['power'], npz_f2['power']

In [None]:
# Metadata shared by both bands is read from the Band 1 file
lsts, redg, scales = (npz_f1[key] for key in ('lsts', 'redg', 'scales'))
# the wavelet is stored as a 0-d array; unwrap it to a plain Python scalar
wavelet = npz_f1['wavelet'].item()

# Stored under the key 'freqs', but used below as channel indices into `freqs`
chans1, chans2 = npz_f1['freqs'], npz_f2['freqs']

In [None]:
# Directory of the LST-binned HERA data. The HERA_LSTB_DIR environment variable
# overrides the defaults (cluster path first, then a local fallback), so the
# notebook can be run elsewhere without editing hard-coded absolute paths.
lstb_dir = os.environ.get('HERA_LSTB_DIR')
if lstb_dir is None:
    lstb_dir = '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2/LSTBIN/one_group/grp1'
    if not os.path.exists(lstb_dir):
        lstb_dir = '/Users/matyasmolnar/Downloads/HERA_Data/sample_data/'

zen_lstb = os.path.join(lstb_dir, 'zen.grp1.of1.LST.1.31552.HH.OCRSL.uvh5')
if not os.path.exists(zen_lstb):
    raise FileNotFoundError(f'LST-binned data file not found: {zen_lstb} '
                            '(set HERA_LSTB_DIR to the directory containing it)')

hd = HERAData(zen_lstb)

In [None]:
# Channel frequencies and channel width taken from the data file header
# (without the file, they were previously approximated by
# np.linspace(1e8, 2e8, 1024+1)[:-1] and the median of its np.ediff1d).
freq_resolution = hd.channel_width
freqs = hd.freqs

In [None]:
# Optional (disabled): keep only baselines 100-199 (last axis) to make the
# arrays more manageable in memory / in napari.
# data1 = data1[..., 100:200]
# data2 = data2[..., 100:200]

In [None]:
# Channel index ranges of the two sub-bands (both ends inclusive, see the
# band_x[1]+1 slicing below)
band_1, band_2 = [175, 334], [515, 694]

# LST ranges [h] of the three fields (bounds excluded in the selection below)
field_1, field_2, field_3 = [1.25, 2.70], [4.50, 6.50], [8.50, 10.75]

In [None]:
# Time-integration indices whose LST lies strictly inside each field's range
f1, f2, f3 = (np.where((lsts > lo) & (lsts < hi))[0]
              for lo, hi in (field_1, field_2, field_3))

In [None]:
# show the time indices of the three fields, separated by blank lines
print('\n\n'.join(map(str, (f1, f2, f3))))

In [None]:
# Per-axis absolute separation |antpos[ant2] - antpos[ant1]| of every baseline,
# keyed by the antenna-pair columns (row[1:]) of the redundant-group table.
antsep = {}
for row in redg:
    antsep[tuple(row[1:])] = np.abs(hd.antpos[row[2]] - hd.antpos[row[1]])

proj_ew = 14  # [m] minimum projected East-West baseline length
# Projected EW length (first antpos component) for every row of `redg`, i.e. in
# the order of the baseline axis of the CWT arrays. Indexing by `redg` row
# (rather than enumerating `antsep`, which merges repeated antenna pairs) keeps
# the resulting indices aligned with the data.
ew_len = np.array([antsep[tuple(row[1:])][0] for row in redg])
# baselines with projected EW length < 14 m
nan_bls = np.flatnonzero(ew_len < proj_ew).tolist()
# the complement (>= 14 m); with '>' a baseline of exactly 14 m was in neither list
ok_bls = np.flatnonzero(ew_len >= proj_ew).tolist()

In [None]:
# Blank out the short projected-EW baselines (last axis) in both bands:
# multiplying by NaN turns every selected entry into NaN.
for band_power in (data1, data2):
    band_power[..., nan_bls] *= np.nan

In [None]:
# Interactive viewer of the log CWT power (disabled; needs napari and a GUI).
# Change the condition to True to browse both bands side by side.
if False:
    
    import napari

    # get two bands side by side
    # careful because B1 and B2 will have different scales..
    # (concatenated along axis 1, the frequency axis of the power arrays)
    data_m = np.concatenate((data1, data2), axis=1)
    # data_m = data2

    # only look at times from Field 2
    # NOTE(review): the slice below keeps all times; index the time axis
    # (e.g. data_m[:, :, f2, :]) if only Field 2 is wanted.
    # in napari - set auto-contrast to "once", and adjust contrast limits & gamma
    # found that FPS 8 works well, with play mode "back and forth" (for time axis)

    viewer = napari.view_image(np.log(data_m[:, :, :, :]), colormap='turbo', ndisplay=2, order=(2, 3, 0, 1), \
                               gamma=1, interpolation='nearest', scale=(8, 1, 1, 1))

### Automatic detection

In [None]:
# TODO
# use the same colour limits (vmin/vmax) for all saved figures

In [None]:
# Band analysed in the rest of the notebook: 'Band 1' or 'Band 2'
SELECT_BAND = 'Band 2'

In [None]:
# Pick the CWT power, channel frequencies and MAD clipping threshold of the
# selected band. `sb` is the band index (0 = Band 1, 1 = Band 2).
# The thresholds are per-band constants applied to the MAD map below
# (presumably tuned by eye on the MAD plots).
if SELECT_BAND == 'Band 1':
    sb = 0
    b_freqs = freqs[chans1]
    cwt_data = data1
    mad_clip_thresh = 2e-1
elif SELECT_BAND == 'Band 2':
    sb = 1
    b_freqs = freqs[chans2]
    cwt_data = data2
    mad_clip_thresh = 7e-2
else:
    # fail loudly instead of leaving b_freqs / cwt_data undefined (or stale)
    raise ValueError(f"SELECT_BAND must be 'Band 1' or 'Band 2', got {SELECT_BAND!r}")

In [None]:
# Modified Z-score of every CWT power pixel relative to all time integrations
# and baselines at the same (scale, frequency):
#     modz = (power - median) / mad_std
# where mad_std is astropy's MAD-based robust standard deviation.
axis = (2, 3)  # times and baselines
med = np.nanmedian(cwt_data, axis=axis)  # shape (scales, freqs)
mad = mad_std(cwt_data, axis=axis, ignore_nan=True)  # shape (scales, freqs)

# Restore the reduced axes as length-1 dimensions so that med and mad broadcast
# against cwt_data; this gives the same result as tiling them to full size but
# without allocating two extra full-size arrays.
modz = (cwt_data - np.expand_dims(med, axis=axis)) / np.expand_dims(mad, axis=axis)

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=125, sharey=True)

# imshow extent: frequency [MHz] along x, scale rows 1..N (top to bottom) along y
extent = [b_freqs[0]/1e6, b_freqs[-1]/1e6, med.shape[0]+0.5, 0.5]

# one LogNorm instance shared by both panels, so they use one colour normalization
norm = LogNorm()
for panel, stat, panel_title in zip(axes, (med, mad), ('Median', 'MAD')):
    mappable = panel.imshow(stat, aspect='auto', interpolation='none', cmap='jet', norm=norm,
                            extent=extent)
    colorbar_ax = make_axes_locatable(panel).append_axes('right', size='5%', pad=0.05)
    plt.colorbar(mappable, cax=colorbar_ax)
    panel.set_xlabel('Frequency [MHz]')
    panel.set_title(panel_title)

axes[0].set_ylabel('Wavelet scale')

plt.tight_layout()
plt.show()

In [None]:
fig, axes = plt.subplots(ncols=2, figsize=(8, 4), dpi=125)

# 10 evenly spaced frequencies [Hz] between the band edges rounded to the
# nearest MHz; the 8 interior ones, as integer MHz, are used as tick positions.
xtk = np.linspace(round(b_freqs[0], -6), round(b_freqs[-1], -6), 10)
intticks = (xtk[1:-1]//1e6).astype(int)

for panel, stat, panel_title in zip(axes, (med, mad), ('Median', 'MAD')):
    mesh = panel.pcolormesh(b_freqs/1e6, scales+0.5, stat, norm=LogNorm(), cmap='jet')
    panel.invert_yaxis()
    panel.set_yscale('log')
    panel.set_xticks(intticks)
    panel.set_xticklabels(intticks)

    colorbar_ax = make_axes_locatable(panel).append_axes('right', size='5%', pad=0.1)
    plt.colorbar(mesh, cax=colorbar_ax)

    panel.set_title(panel_title)
    panel.set_xlabel('Frequency [MHz]')
    panel.set_ylabel('Scale')

# keep the right panel's y label for a symmetric layout, but make it invisible
axes[1].yaxis.label.set_color('white')

# axes[0].plot(scales*hd.channel_width*np.sqrt(2)/1e6+b_freqs[0]/1e6, scales, c='black')

fig.tight_layout()

# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'med_mad_b2_2.pdf'), bbox_inches='tight')

plt.show()

In [None]:
# Boolean mask of (scale, frequency) pixels whose MAD exceeds the band's
# clipping threshold; these pixels are flagged before searching for outliers.
hp = mad > mad_clip_thresh
cmap_bool = ListedColormap(['green', 'red'])  # False -> green, True -> red

# (a single panel, so sharing the y axis has no effect and is omitted)
fig, ax = plt.subplots(figsize=(5, 4), dpi=125)

im = ax.imshow(hp, aspect='auto', interpolation='None', cmap=cmap_bool, vmin=0, vmax=1,
               extent=extent)

cbar = plt.colorbar(im, cax=make_axes_locatable(ax).append_axes('right', size='5%', pad=0.1))
# tick labels at the centres of the two colour bins
cbar.set_ticks([0.25, 0.75])
cbar.set_ticklabels(['False', 'True'])

ax.set_xlabel('Frequency [MHz]')
ax.set_ylabel('Wavelet scale')

plt.tight_layout()
plt.show()

In [None]:
# Baselines (last axis) and time integrations (third axis) whose CWT power is
# NaN everywhere else, i.e. that are entirely flagged
nan_mask = np.isnan(cwt_data)
flged_bls = nan_mask.all(axis=(0, 1, 2))
flged_tints = nan_mask.all(axis=(0, 1, 3))
del nan_mask  # full-size boolean array, no longer needed

In [None]:
# Flag (set to NaN) modified Z-scores that should not be used for detection
# (scale, freq) pixels whose MAD exceeds the threshold - assumed to be dominated
# by edge effects within the CWT cone of influence
modz[hp, ...] = np.nan
# fully flagged baselines (from calibration & from proj EW < 14 m)
modz[..., flged_bls] = np.nan
# fully flagged time integrations
modz[..., flged_tints, :] = np.nan

# Keep only the time integrations that fall within one of the three fields
in_field = np.zeros(modz.shape[2], dtype=bool)
in_field[np.concatenate((f1, f2, f3))] = True
modz[..., ~in_field, :] = np.nan

# Outlier detection below works on the magnitude of the modified Z-score
abs_modz = np.abs(modz)

In [None]:
# |modified Z| scaleogram of one example slice: the 11th time integration of
# Field 2 and baseline index 150.
sample_abs_modz = abs_modz[..., f2[10], 150]

# np.nonzero counts NaN as nonzero, so the previous nonzero() test never caught
# a fully flagged (all-NaN) slice; require finite, non-zero values instead.
if np.any(np.isfinite(sample_abs_modz) & (sample_abs_modz != 0)):
    fig, ax = plt.subplots(figsize=(4, 4), dpi=125)

    norm = None  # LogNorm()
    im = ax.imshow(sample_abs_modz, aspect='auto', interpolation='None', norm=norm, extent=extent)

    divider = make_axes_locatable(ax)
    cax = divider.append_axes('right', size='5%', pad=0.05)

    ax.set_title(r'$|\mathrm{mod} Z|$')
    ax.set_ylabel('Wavelet scale')
    ax.set_xlabel('Frequency [MHz]')

    plt.colorbar(im, cax=cax)
    plt.tight_layout()
    plt.show()

else:
    print('Flagged baseline or time.')

In [None]:
# Distribution of |modified Z| over all unflagged pixels. Only finite values are
# histogrammed: NaN marks flagged pixels, and +inf (from a zero MAD in the
# division) would make the automatically determined bin range infinite.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)

ax.hist(abs_modz[np.isfinite(abs_modz)], bins=100, density=False, log=True)

ax.set_xlabel(r'Absolute modified $Z$-score')
ax.set_ylabel('Number of pixels')

plt.tight_layout()
plt.show()

In [None]:
# Locate the pixel with the largest |modified Z| (NaNs ignored) and report it
worst_idx = np.unravel_index(np.nanargmax(abs_modz), abs_modz.shape)
worst_tint, worst_bl = worst_idx[2], worst_idx[3]
print(f'Worst slice is for LST {lsts[worst_tint]:.2f} h and baseline {redg[worst_bl][1:]}')
print(f'Modified Z-score of worst slice is {modz[worst_idx]:.2f}')

In [None]:
# Flat indices of all pixels ordered from largest to smallest |modified Z|;
# flagged (NaN) pixels are treated as 0 so that they sort to the end.
sorted_modz_idx = np.argsort(np.where(np.isnan(abs_modz), 0, abs_modz).ravel())[::-1]

In [None]:
# Inspect the pixel with the `offender`-th largest |modified Z| (0 = worst):
# the |modified Z| map (left) and CWT power (right) of its time integration and
# baseline, with the offending pixel circled.
offender = 0

bad_idx = np.unravel_index(sorted_modz_idx[offender], abs_modz.shape)
sample_abs_modz = abs_modz[..., bad_idx[-2], bad_idx[-1]]

abs_modz_vmin = 1e-1  # lower colour limit of the log-scaled |modified Z| maps

print(f'Bad slice is for LST {lsts[bad_idx[2]]:.2f} h and baseline {redg[bad_idx[3]][1:]}')
print(f'Modified Z-score of bad slice is {modz[bad_idx]:.2f}')

# np.nonzero counts NaN as nonzero, so require finite, non-zero values to
# detect a fully flagged (all-NaN) slice.
if np.any(np.isfinite(sample_abs_modz) & (sample_abs_modz != 0)):

    fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=125)

    # offending pixel: frequency [MHz] and 1-based scale row (as in `extent`)
    mark_x, mark_y = b_freqs[bad_idx[1]]/1e6, bad_idx[0]+1
    panels = [(sample_abs_modz, LogNorm(vmin=abs_modz_vmin), r'Modified $Z$-score'),
              (cwt_data[..., bad_idx[-2], bad_idx[-1]], LogNorm(), 'Scaleogram')]

    for panel, (image, image_norm, panel_title) in zip(axes, panels):
        im = panel.imshow(image, aspect='auto', interpolation='None', norm=image_norm,
                          cmap='jet', extent=extent)
        panel.scatter(mark_x, mark_y, s=1000, fc='None', edgecolors='cyan', lw=2, ls='--')
        cax = make_axes_locatable(panel).append_axes('right', size='5%', pad=0.05)
        plt.colorbar(im, cax=cax)
        panel.set_title(panel_title)
        panel.set_xlabel('Frequency [MHz]')

    axes[0].set_ylabel('Wavelet scale')

    plt.tight_layout()
    plt.show()

else:
    print('Flagged baseline or time.')

In [None]:
abs_modz_thresh = 40  # |modified Z| above which a pixel counts as an outlier

# sorted_modz_idx is in descending |modified Z| order, so its first n_bad
# entries are exactly the pixels above threshold (NaN compares False, so
# flagged pixels are not counted; the NaN comparison warning is silenced).
with np.errstate(invalid='ignore'):
    n_bad = int(np.sum(abs_modz > abs_modz_thresh))
bad_modz_rav_idxs = sorted_modz_idx[:n_bad]
bad_modz_idxs = [np.unravel_index(i, abs_modz.shape) for i in bad_modz_rav_idxs]
bad_modz_t_bl_idxs = [i[2:] for i in bad_modz_idxs]

# unique (time, baseline) pairs, keeping first-seen (i.e. worst-first) order
bad_modz_t_bl = list(dict.fromkeys(bad_modz_t_bl_idxs))

In [None]:
# number of distinct (time, baseline) slices containing at least one outlier pixel
n_bad_slices = len(bad_modz_t_bl)
n_bad_slices

In [None]:
# Inspect one flagged (time, baseline) slice: its |modified Z| map (left) and
# CWT power (right), with the slice's worst outlier pixel circled.
bad_slice = 0  # index into bad_modz_t_bl (ordered worst-first)

bad_tint, bad_bl = bad_modz_t_bl[bad_slice]
sample_abs_modz = abs_modz[..., bad_tint, bad_bl]

blst = lsts[bad_tint]
bbl = redg[bad_bl][1:]
print(f'Bad slice is for LST {blst:.2f} h and baseline {bbl}')

# np.nonzero counts NaN as nonzero, so require finite, non-zero values to
# detect a fully flagged (all-NaN) slice.
if np.any(np.isfinite(sample_abs_modz) & (sample_abs_modz != 0)):

    fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=125)

    # worst outlier pixel of this slice (bad_modz_idxs is ordered worst-first)
    sf_idx = next(i for i in bad_modz_idxs if i[2:] == (bad_tint, bad_bl))

    im1 = axes[0].imshow(sample_abs_modz, aspect='auto', interpolation='None',
                         norm=LogNorm(vmin=abs_modz_vmin), cmap='jet', extent=extent)
    # circle worst pixel in scaleogram
    axes[0].scatter(b_freqs[sf_idx[1]]/1e6, sf_idx[0]+1, s=1000, fc='None', edgecolors='cyan',
                    lw=2, ls='--')
    divider = make_axes_locatable(axes[0])
    cax1 = divider.append_axes('right', size='5%', pad=0.05)
    plt.colorbar(im1, cax=cax1)
    axes[0].set_title(r'Modified $Z$-score')

    im2 = axes[1].imshow(cwt_data[..., bad_tint, bad_bl],
                         aspect='auto', interpolation='None', norm=LogNorm(), cmap='jet', extent=extent)
    axes[1].scatter(b_freqs[sf_idx[1]]/1e6, sf_idx[0]+1, s=1000, fc='None', edgecolors='cyan',
                    lw=2, ls='--')
    divider = make_axes_locatable(axes[1])
    cax2 = divider.append_axes('right', size='5%', pad=0.05)
    plt.colorbar(im2, cax=cax2)
    axes[1].set_title('Scaleogram')

    axes[0].annotate(SELECT_BAND + f'\n LST: {blst:.2f} h \n bl:{bbl} \n mod-Z:{modz[sf_idx]:.1f}',
                     xycoords='axes fraction', xy=(0.5, 0.03), ha='center', va='bottom', fontsize=8,
                     bbox=dict(facecolor='white', edgecolor='black', boxstyle='round, pad=0.3', alpha=0.5))

    axes[0].set_ylabel('Wavelet scale')
    axes[0].set_xlabel('Frequency [MHz]')
    axes[1].set_xlabel('Frequency [MHz]')

    plt.tight_layout()
    plt.show()

else:
    print('Flagged baseline or time.')

In [None]:
# Disabled by default: save the |modified Z| / scaleogram figure of every
# flagged (time, baseline) slice as a PNG, using a process pool.
if False:

    import multiprocess as multiprocessing

    save_dir = os.path.join('/lustre/aoc/projects/hera/mmolnar/wavelets/figures/scg_modz',
                            SELECT_BAND.replace(' ', '_').lower())
    # savefig does not create missing directories
    os.makedirs(save_dir, exist_ok=True)

    def bl_iter(bad_slice):
        """Plot and save the figure of flagged slice ``bad_modz_t_bl[bad_slice]``."""
        tint, bl = bad_modz_t_bl[bad_slice]
        sample_abs_modz = abs_modz[..., tint, bl]

        blst = lsts[tint]
        bbl = redg[bl][1:]

        fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=300)

        # worst outlier pixel of this slice (bad_modz_idxs is ordered worst-first)
        sf_idx = next(i for i in bad_modz_idxs if i[2:] == (tint, bl))

        # same colour floor as the interactive |modified Z| plots
        im1 = axes[0].imshow(sample_abs_modz, aspect='auto', interpolation='None',
                             norm=LogNorm(vmin=abs_modz_vmin), cmap='jet', extent=extent)
        # circle worst pixel in scaleogram
        axes[0].scatter(b_freqs[sf_idx[1]]/1e6, sf_idx[0]+1, s=1000, fc='None', edgecolors='cyan',
                        lw=2, ls='--')
        divider = make_axes_locatable(axes[0])
        cax1 = divider.append_axes('right', size='5%', pad=0.05)
        plt.colorbar(im1, cax=cax1)
        axes[0].set_title(r'Modified $Z$-score')

        im2 = axes[1].imshow(cwt_data[..., tint, bl],
                             aspect='auto', interpolation='None', norm=LogNorm(), cmap='jet', extent=extent)
        axes[1].scatter(b_freqs[sf_idx[1]]/1e6, sf_idx[0]+1, s=1000, fc='None', edgecolors='cyan',
                        lw=2, ls='--')
        divider = make_axes_locatable(axes[1])
        cax2 = divider.append_axes('right', size='5%', pad=0.05)
        plt.colorbar(im2, cax=cax2)
        axes[1].set_title('Scaleogram')

        axes[0].annotate(SELECT_BAND + f'\n LST: {blst:.2f} h \n bl:{bbl} \n mod-Z:{modz[sf_idx]:.1f}',
                         xycoords='axes fraction', xy=(0.5, 0.03), ha='center', va='bottom', fontsize=8,
                         bbox=dict(facecolor='white', edgecolor='black', boxstyle='round, pad=0.3', alpha=0.5))

        axes[0].set_ylabel('Wavelet scale')
        axes[0].set_xlabel('Frequency [MHz]')
        axes[1].set_xlabel('Frequency [MHz]')

        fig.tight_layout()

        # build the name from plain ints: str() of a tuple of NumPy ints embeds
        # e.g. 'np.int64(12)' with NumPy >= 2
        save_fn = f'scg_modz_({int(tint)}_{int(bl)}).png'
        fig.savefig(os.path.join(save_dir, save_fn), bbox_inches='tight')

        # close this specific figure; each worker creates many of them
        plt.close(fig)

    with multiprocessing.Pool(multiprocessing.cpu_count()) as m_pool:
        _ = m_pool.map(bl_iter, range(len(bad_modz_t_bl)))

In [None]:
# TODO: convolve the abs_modz images, since we want to detect bad regions,
# not just single bad pixels

### Notes v1

**These notes were taken when analysing all baselines that do not contain flagged antennas. It was then noted that baselines with a projected EW length < 14 m are discarded in the power spectrum computation, so the analysis was repeated, with comments written in Notes v2.**

Below, by "delay" we mean the delay of the wavelet: a small wavelet scale means a compressed wavelet, which picks up rapidly changing details and therefore corresponds to a higher delay (since the signal starts off in frequency space).

Redundant analysis:
 - B1 baseline group 1, 3 power at high-ish delays
 - B2F2 baseline group 7 power at low delays
 - B1 & B2 baseline groups 12 and 13 have more power at mid and high delays
 - B1 & B2 baseline group 20, 30, 31, 45 higher power at mid delays
 - B1 baseline group 67, 68, 72, 74, 82, 83, 88, 101 localized power at mid delays
 - B1 & B2 baseline group 77, 78, 81, 99, 105, 106 localized power at mid delays
 
All baselines analysis:

 - B1 baseline 12, 16, 69, 74 (bad), 212, 230, 324, 540, 655, 657, 660, 731 power at mid delays
 - B2 baseline 23, 27, 71, 196, 198 power at mid delays
 - B1 & B2 baseline 28, 29, 30, 33 (v bad), 37 (bad), 75, 77, 200, 201 (bad), 202, 203, 206, 208 (bad), 209, 210, 218 (bad), 220, 221, 227, 228, 315 (bad), 319, 321 (bad), 322, 323, 423, 426, 430, 439, 543, 544, 546, 659  power at mid delays
 - B1 baseline 1, 6, 8, 9, 20, 21, 25, 31, 34, 73, 79, 216, 223, 224, 225, 424, 440, 662, 696 localized power at mid delays
 - B2 baseline 3, 19, 35, 195, 212, 313, 432 localized power at mid delays
 - B1 & B2 baseline 14, 17, 204, 217, 219, 222, 316, 317, 433, 434, 436 localized power at mid delays
 
 
Other notes:
 - For B1 get localized power at mid delays at higher end of frequency band - recurring spot for a few baselines
 - Bls 33, 37, 74 bad, with lots of power at mid delays, especially in Band 1
 - Features do not appear transient in time - high power seems to be present across times for specific baselines
 - B1 seems worse than B2 for Field 2? Looking at the H1C limits, B1 is expected to be worse, so effects will be more noticeable, especially when comparing on the same scale

### Notes v2

**Repeating the notes while looking at Band 2 alone (so that the colour scales are not distorted by Band 1) and excluding baselines with a projected EW length < 14 m**

### Look at some example slices where CWT looks bad

In [None]:
# Full visibility array; presumably (time, channel, baseline) given the
# indexing below - TODO confirm
vis_data = np.load('h1c_idr2.OCRSLP2XTK.npz')['arr_0']

In [None]:
bad_slice = 0  # which flagged (time, baseline) slice to examine

sample_tint, sample_bl = bad_modz_t_bl[bad_slice]
print(f'Examining baseline {redg[sample_bl, :][1:]} at LST {lsts[sample_tint]:.3f}')

In [None]:
# Stored log10 CWT power of the inspected slice in both bands, in pixel units
# (x: channel within the band, y: scale row); axis labels added so the figure
# stands on its own.
fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=125, sharey=True)
for panel, band_power, panel_title in ((axes[0], data1, 'Band 1'), (axes[1], data2, 'Band 2')):
    panel.imshow(np.log10(band_power[..., sample_tint, sample_bl]), aspect='auto',
                 interpolation='none', cmap='jet')
    panel.set_title(panel_title)
    panel.set_xlabel('Channel index (within band)')
axes[0].set_ylabel('Scale index')
plt.tight_layout()
plt.show()

In [None]:
# Wavelet power spectra computed directly from the inspected baseline's
# visibilities in both bands, using the same wavelet and scales as the CWT products.
fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=125, sharey=True)

# # set same values scale
# sdata = np.concatenate((data1[..., sample_tint, sample_bl], data2[..., sample_tint, sample_bl]), axis=1)
# vmin = sdata.min()
# vmax = sdata.max()
# vlims = (vmin, vmax)
vlims = None
coikw = {'alpha':0.1, 'hatch':'/'}  # style of the cone-of-influence overlay

sample_data1 = vis_data[sample_tint, band_1[0]:band_1[1]+1, sample_bl]
sample_data2 = vis_data[sample_tint, band_2[0]:band_2[1]+1, sample_bl]

# yaxis='scale' puts the wavelet scale (not delay) on the y axis, so label it
# as such; the x values are channel frequencies in Hz.
scg.cws(freqs[chans1], sample_data1, scales=scales, wavelet=wavelet, cscale='log', coi=True,
        ax=axes[0], spectrum='power', yaxis='scale', title='WPS B1',
        xlabel='Frequency [Hz]', ylabel='Scale', yscale='log', cwt_fun='pywt', vlims=vlims, coikw=coikw)

scg.cws(freqs[chans2], sample_data2, scales=scales, wavelet=wavelet, cscale='log', coi=True,
        ax=axes[1], spectrum='power', yaxis='scale', title='WPS B2',
        xlabel='Frequency [Hz]', ylabel='Scale', yscale='log', cwt_fun='pywt', vlims=vlims, coikw=coikw)

plt.tight_layout()
plt.show()

In [None]:
# |modified Z| map (left) and wavelet power spectrum (right) of the slice chosen
# above (sample_tint, sample_bl) in the band selected by SELECT_BAND, with the
# slice's worst outlier pixel circled in both panels.
fig, axes = plt.subplots(ncols=2, figsize=(8, 4), dpi=125)

# Visibilities of the selected band, so that they match the b_freqs axis
# (previously Band 2 data were used even with SELECT_BAND = 'Band 1').
band_sample_data = (sample_data1, sample_data2)[sb]
# Derive the |modified Z| map and the worst pixel from (sample_tint, sample_bl)
# instead of relying on values left over from earlier cells.
sample_abs_modz = abs_modz[..., sample_tint, sample_bl]
sf_idx = next(i for i in bad_modz_idxs if i[2:] == (sample_tint, sample_bl))

# WPS Scaleogram (y values are treated as delays in seconds: see the tick
# positions and the Delay [us] label below)
ax, qmesh, values = scg.cws(b_freqs, band_sample_data, scales=scales, wavelet=wavelet, cscale='log',
    coi=True, ax=axes[1], spectrum='power', yaxis='frequency',
    title='WPS', xlabel='Frequency [MHz]', ylabel=r'Delay [$\mu$s]',
    yscale='log', cwt_fun='pywt', vlims=vlims, cbar=False, coikw=coikw)

# worst pixel: pywt pseudo-frequency of its scale divided by the channel width
axes[1].scatter(b_freqs[sf_idx[1]], pywt.scale2frequency(wavelet, sf_idx[0]+1)/hd.channel_width,
                s=1000, fc='None', edgecolors='cyan', lw=2, ls='--')

divider = make_axes_locatable(axes[1])
cax1 = divider.append_axes('right', size='5%', pad=0.1)
plt.colorbar(qmesh, cax=cax1)

# x values are in Hz here: ticks at the Hz positions, labelled in integer MHz
axes[1].set_xticks(xtk[1:-1])
axes[1].set_xticklabels(intticks)

# 1e-6 s and 1e-5 s, labelled in microseconds
axes[1].set_yticks([1e-6, 1e-5])
axes[1].set_yticklabels([r'$10^0$', r'$10^1$'])

# modZ plot
im1 = axes[0].pcolormesh(b_freqs/1e6, scales+0.5, sample_abs_modz,
                         norm=LogNorm(vmin=abs_modz_vmin), cmap='jet')
# circle worst pixel in scaleogram
axes[0].scatter(b_freqs[sf_idx[1]]/1e6, sf_idx[0]+1, s=1000, fc='None', edgecolors='cyan',
                lw=2, ls='--')
axes[0].set_ylim((1, 18))
axes[0].invert_yaxis()
axes[0].set_yscale('log')
axes[0].set_xticks(intticks)
axes[0].set_xticklabels(intticks)

divider = make_axes_locatable(axes[0])
cax1 = divider.append_axes('right', size='5%', pad=0.1)
plt.colorbar(im1, cax=cax1)

axes[0].set_title(r'Modified $Z$-score')
axes[0].set_xlabel('Frequency [MHz]')
axes[0].set_ylabel('Scale')

fig.tight_layout()

# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'modz_cwtps.pdf'), bbox_inches='tight')

plt.show()

In [None]:
# Real and imaginary parts of the inspected baseline's visibilities in both bands
fig, axes = plt.subplots(ncols=2, figsize=(7.5, 4), dpi=125)

# Field containing the inspected time integration (same open LST bounds as the
# field selection above); previously the titles were hard-coded to Field 2.
sample_lst = lsts[sample_tint]
sample_field = next((f'Field {n}' for n, (lo, hi) in enumerate((field_1, field_2, field_3), start=1)
                     if lo < sample_lst < hi), 'outside fields')

c = [chans1, chans2]
s = [sample_data1, sample_data2]
t = [f'Band 1 {sample_field}', f'Band 2 {sample_field}']

for i, ax in enumerate(axes):
    ax.plot(freqs[c[i]]/1e6, s[i].real, label=r'$\mathfrak{Re}(V)$')
    ax.plot(freqs[c[i]]/1e6, s[i].imag, label=r'$\mathfrak{Im}(V)$')
    ax.set_xlabel('Frequency [MHz]')
    ax.set_title(t[i])

axes[0].set_ylabel('Visibility')
axes[0].legend(loc='best')

plt.tight_layout()
plt.show()

In [None]:
# Indices (rows of redg, i.e. the baseline axis) of every baseline in the same
# redundant group as the inspected one, to compare their visibilities
bl_grp = redg[sample_bl][0]
red_grp = np.flatnonzero(redg[:, 0] == bl_grp)

In [None]:
# Overlay the inspected baseline (red) on the other members of its redundant
# group (translucent): real parts on top, imaginary parts below, Band 1 on the
# left and Band 2 on the right.
fig, axes = plt.subplots(ncols=2, nrows=2, figsize=(7.5, 7.5), dpi=125, sharex='col')

# (a group mean could be shown instead, e.g.
#  np.nanmean(vis_data[sample_tint, band_1[0]:band_1[1]+1, red_grp], axis=0))

# Field containing the inspected time integration, for the panel titles
sample_lst = lsts[sample_tint]
sample_field = next((f'Field {n}' for n, (lo, hi) in enumerate((field_1, field_2, field_3), start=1)
                     if lo < sample_lst < hi), 'outside fields')

# other members of the redundant group
red_grp_min = [i for i in red_grp if i != sample_bl]

band_setup = ((chans1, slice(band_1[0], band_1[1]+1), f'Band 1 {sample_field}'),
              (chans2, slice(band_2[0], band_2[1]+1), f'Band 2 {sample_field}'))

for col, (band_chans, band_sl, panel_title) in enumerate(band_setup):
    mhz = freqs[band_chans]/1e6
    # int, slice, list indexing puts the list axis first -> transpose so each
    # column is one baseline
    others = vis_data[sample_tint, band_sl, red_grp_min].T
    target = vis_data[sample_tint, band_sl, sample_bl]
    for row, part in enumerate((np.real, np.imag)):
        axes[row][col].plot(mhz, part(others), alpha=0.5)
        axes[row][col].plot(mhz, part(target), lw=1.5, c='red')
    axes[1][col].set_xlabel('Frequency [MHz]')
    # each column gets its own band's title (previously both showed Band 1)
    axes[0][col].set_title(panel_title)

axes[0][0].set_ylabel(r'$\mathfrak{Re}(V)$')
axes[1][0].set_ylabel(r'$\mathfrak{Im}(V)$')

plt.tight_layout()
plt.show()

In [None]:
# Periodogram of the inspected baseline's Band 1 visibilities against delay;
# two-sided because the visibilities are complex.
delay, pspec = signal.periodogram(
    sample_data1,
    fs=1/freq_resolution,
    window='blackmanharris',
    scaling='spectrum',
    nfft=None,
    detrend=False,
    return_onesided=False,
)

# reorder into ascending delay (negative delays first) for plotting
order = np.argsort(delay)
delay, pspec = delay[order], pspec[order]

In [None]:
# Periodogram of the inspected baseline versus delay
fig, ax = plt.subplots(figsize=(6, 6), dpi=125)

ax.plot(delay, pspec, alpha=1)

ax.set_yscale('log')
ax.set_ylabel('Power spectrum')
# the periodogram was sampled at fs = 1/channel width [Hz], so delay is in seconds
ax.set_xlabel('Delay [s]')

plt.tight_layout()
plt.show()