<center><strong><font size=+3>Wavelet transforms</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

The Fourier transform (and hence the power spectrum) works very well for transforming a signal from the time domain to the frequency domain when the frequency spectrum is stationary, i.e. does not evolve in time.

The more non-stationary (dynamic) a signal is, the worse the results will be, and most signals we see in real life are non-stationary. In 21 cm cosmology, we compute power spectra over considerable frequency bandwidths: the Universe can change over such scales, since frequency maps to redshift. A much better approach for analyzing dynamic signals is to use the wavelet transform instead of the Fourier transform.

Furthermore, if erroneous modes exist in the data, they can be localized to a point in the dual frequency-delay space using the wavelet transform.

In [None]:
import copy
import functools
import glob
import itertools
import os
import warnings

import h5py
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
from scipy import signal
from scipy.interpolate import griddata

from hera_cal.io import HERAData
from hera_cal.redcal import get_reds
import hera_pspec as hp

import pywt
import scaleogram as scg

In [None]:
%matplotlib inline

In [None]:
from matplotlib import rc

# Serif fonts (the 'cm' entry) with all text rendered through LaTeX;
# the preamble loads the AMS packages needed by the plot labels.
rc('font', family='serif', serif=['cm'])
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

In [None]:
# [first, last] channel indices of each band; the last index is inclusive
# (bands are sliced as band[0]:band[1]+1 further down)
band_1 = [175, 334]
band_2 = [515, 694]

# presumably [start, end] LST ranges of the observed fields, in hours -- they
# are not used in the cells visible here; TODO confirm
field_1 = [1.25, 2.70]
field_2 = [4.50, 6.50]
field_3 = [8.50, 10.75]

# antennas to exclude: any baseline containing one of these is removed from
# the redundant groups (see fltBad)
bad_ants = [0, 2, 11, 24, 50, 53, 54, 67, 69, 98, 122, 136, 139]

## Load final OCRSLPXTK visibility product

In [None]:
# Directory under which the cached visibility array file is looked up
analysis_dir = '/lustre/aoc/projects/hera/mmolnar/wavelets'

In [None]:
# npz file with the full visibility array (its 'arr_0' entry is loaded below)
hr_full_fn = os.path.join(analysis_dir, 'h1c_idr2.OCRSLP2XTK.npz')

In [None]:
# uvh5 files matching the pattern below, in sorted filename order; the first
# one supplies the metadata and all of them supply LSTs
lstb_dir = '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2_pspec/v2/one_group/data'
final_files = sorted(glob.glob(os.path.join(lstb_dir, 'zen.grp1.of1.LST.*XTK.uvh5')))

In [None]:
def fltBad(bll, badl, minbl=1):
    """Drop baselines that involve a bad antenna, then drop short groups.

    Parameters
    ----------
    bll : iterable of iterables
        Baseline groups; the first two items of every baseline are the
        antennas that are checked against ``badl``.
    badl : container
        Antennas to reject (tested with ``in``).
    minbl : int, optional
        Minimum number of surviving baselines for a group to be kept.

    Returns
    -------
    list of list
        The filtered groups, in their original order.
    """
    kept_groups = []
    for group in bll:
        survivors = [bl for bl in group
                     if bl[0] not in badl and bl[1] not in badl]
        if len(survivors) >= minbl:
            kept_groups.append(survivors)
    return kept_groups

def groupBls(bll):
    """Flatten baseline groups into an array of (group index, ant i, ant j).

    Every baseline in ``bll`` must unpack into three items; the third one
    is discarded. Rows follow the iteration order of the groups and of the
    baselines within each group.
    """
    rows = []
    for group_idx, group in enumerate(bll):
        for ant_i, ant_j, _pol in group:
            rows.append((group_idx, ant_i, ant_j))
    return np.array(rows)

In [None]:
# The first file supplies the metadata used below (antenna positions,
# channel width, frequencies)
hd = HERAData(final_files[0])
# Redundant baseline groups for the 'ee' polarisation, with every baseline
# that involves a bad antenna removed
reds = get_reds(hd.antpos, pols=['ee'])
reds = fltBad(reds, bad_ants)
# One row per baseline: (group index, antenna i, antenna j)
redg = groupBls(reds)
f_res = hd.channel_width

# Visibility array; it is indexed further down as
# [time, frequency channel, baseline] -- TODO confirm the axis order
data = np.load(hr_full_fn)['arr_0']

In [None]:
# Collect the LSTs of every file, plus the frequency axis of the first file
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', message='antenna_diameters is not set. Using known values for HERA.')

    lst_chunks = []
    for i, f_file in enumerate(final_files):
        # The context manager releases the HDF5 handle even if a read fails
        with h5py.File(f_file, 'r') as file:
            header = file['Header']
            # np.unique already returns its result sorted
            lsts_i = np.unique(header['lst_array'])
            lst_chunks.append(lsts_i)

            if i == 0:
                freqs = np.squeeze(header['freq_array'])
                chans = np.arange(freqs.size)

    # Join once at the end instead of re-copying the array on every iteration
    lsts = np.concatenate(lst_chunks) * 12 / np.pi  # convert to hours

In [None]:
# Image extent passed to imshow: frequency range along x, LST range along y
# (the last LST is listed first)
extent = [hd.freqs[0], hd.freqs[-1], lsts[-1], lsts[0]]
# Upper frequency limit rounded to the nearest multiple of 1e6
xlim = [hd.freqs[0], round(hd.freqs[-1], -6)]
# LST limits: last LST rounded up, first LST rounded down
ylim = [np.ceil(lsts[-1]), np.floor(lsts[0])]

In [None]:
# Amplitude of the visibilities at index 0 of the last axis, over the
# frequency/LST extent defined above
fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
ax.imshow(np.abs(data[..., 0]), aspect='auto', interpolation='None', extent=extent)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xlabel('Frequency')
ax.set_ylabel('LST')
plt.show()

In [None]:
# Phase (np.angle) of the same visibilities, same axes as the amplitude plot
fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
ax.imshow(np.angle(data[..., 0]), aspect='auto', interpolation='None', extent=extent)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xlabel('Frequency')
ax.set_ylabel('LST')
plt.show()

## Running the wavelet transform

In [None]:
# Band analysed in the cells below
band = band_1

# Index 50 on the first axis, index 0 on the last axis, and the channels of
# the chosen band (upper edge inclusive)
sample_data = data[50, band[0]:band[1]+1, 0]  # pick one time integration
# Frequencies of the selected channels; note this rebinds freqs to the band
freqs = hd.freqs[band[0]:band[1]+1]

In [None]:
# Real and imaginary parts of the sample visibility against frequency
# (frequencies divided by 1e6 for the MHz axis)
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.plot(freqs/1e6, sample_data.real, label=r'$\mathfrak{Re}(V)$')
ax.plot(freqs/1e6, sample_data.imag, label=r'$\mathfrak{Im}(V)$')
ax.set_xlabel('Frequency [MHz]')
ax.legend(loc='best')
plt.tight_layout()
plt.show()

In [None]:
# Continuous wavelet used below. PyWavelets names complex Morlet wavelets
# 'cmorB-C', with B the bandwidth and C the centre frequency.
wavelet = 'cmor1.5-1.0'

In [None]:
# Draw the chosen wavelet with scaleogram's helper on two side-by-side axes
fig, axes = plt.subplots(ncols=2, figsize=(6, 4), dpi=125)
ax = scg.plot_wav(wavelet, axes=axes)
plt.tight_layout()
plt.show()

### Continuous wavelet transform

In [None]:
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
# Scaleogram of the sample visibility. The call returns the axes plus two
# further objects (bound to qmesh and v); v is imaged in the next cell.
ax, qmesh, v = scg.cws(
    freqs, sample_data,
    wavelet=wavelet, cwt_fun='pywt', spectrum='power', coi=True,
    yaxis='frequency', yscale='log', cscale='log',
    xlabel='Frequency', ylabel='Delay',
    ax=ax,
)
plt.tight_layout()
plt.show()

In [None]:
# log10 of v (third value returned by scg.cws above) on a plain pixel grid
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(v), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
# Integer scales from 1 up to (but excluding) 2**(floor(log2(N)) - 2),
# where N is the number of channels in the band
scales = np.arange(1, 2**(np.floor(np.log2(freqs.size))-2), dtype=int)

# The fourth positional argument of pywt.cwt is the sampling period, so the
# second return value (named delays here) is expressed in units of
# 1 / channel width
cfs, delays = pywt.cwt(sample_data, scales, wavelet, hd.channel_width)
power = np.abs(cfs)**2

fig, ax = plt.subplots(figsize=(6, 4), dpi=125)

# Filled contours of the power on a logarithmic colour scale
im = ax.contourf(freqs, delays, power, levels=None, extend='both', norm = LogNorm())

ax.set_title('Wavelet Power Spectrum')
ax.set_xlabel('Frequency')
ax.set_ylabel('Delay')
ax.set_yscale('log')

plt.colorbar(im, format='%.0e')

plt.tight_layout()
plt.show()

### Discrete wavelet decomposition

In [None]:
# pywt.wavelist(kind='discrete')

In [None]:
# Discrete wavelet used for the decomposition below ('db2' is a Daubechies
# wavelet in PyWavelets)
disc_wavelet = 'db2'

In [None]:
# Calculate approximations of the scaling function (phi) and wavelet function
# (psi) at the default level of refinement; x is the grid they are sampled on.
phi, psi, x = pywt.Wavelet(disc_wavelet).wavefun()

In [None]:
# Both functions are drawn against sample index; the x grid returned by
# wavefun is not used for the abscissa here
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.plot(phi, label='Scaling function')
ax.plot(psi, label='Wavelet function')
ax.legend(loc='best')
plt.tight_layout()
plt.show()

In [None]:
# Multilevel discrete wavelet decomposition of the sample visibility.
# pywt.wavedec returns [cA_n, cD_n, cD_n-1, ..., cD_1]: the approximation
# first, then the detail coefficients from the coarsest to the finest level.
coeffs = pywt.wavedec(sample_data, wavelet=disc_wavelet)
cA = coeffs[0]
cDs = coeffs[1:]
levels = len(cDs)
lengths = [len(cD) for cD in cDs]
col = int(np.max(lengths))

# One row per entry of cDs, every row stretched to the longest length
cc = np.empty((levels, col), dtype=complex)

for level in range(levels):
    y = cDs[level]
    if lengths[level] < col:
        # Place the len(y) coefficients at the centres of equal-width bins
        # spanning [0, col], then fill the integer grid 0..col-1 with the
        # nearest coefficient
        x = np.arange(0.5, len(y)+0.5) * col/len(y)
        xi = np.arange(col)
        yi = griddata(points=x, values=y, xi=xi, method='nearest')
    else:
        yi = y
    
    cc[level, :] = yi
    
# NOTE(review): scale2frequency is evaluated at scales 1..levels, although
# DWT level j is normally associated with scale 2**j, and row 0 of cc holds
# the coarsest detail level -- confirm this is the intended delay mapping
delays = pywt.scale2frequency(disc_wavelet, np.arange(1, levels+1)) / hd.channel_width

In [None]:
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)

# Power of the detail coefficients on a logarithmic colour scale: one image
# row per entry of cDs, frequencies divided by 1e6 along x
im = ax.imshow(
    np.abs(cc)**2,
    extent=[freqs[0]/1e6, freqs[-1]/1e6, levels+0.5, 0.5],
    norm=LogNorm(),
    aspect='auto',
    interpolation='None',
)
cbar = plt.colorbar(im)

ax.set_xlabel('Frequency')
ax.set_ylabel('Level')

ax.invert_yaxis()

plt.tight_layout()
plt.show()

### Look at different CWT implementations

In [None]:
# scales = np.arange(1, 2**(np.floor(np.log2(freqs.size))-2), dtype=int)
# Integer scales from 1 up to (but excluding) min(N / 10, 100), with N the
# number of channels in the band
scales = np.arange(1, min(len(freqs)/10, 100), dtype=int)

# Reference CWT from PyWavelets; the two cells that follow repeat the
# transform with other implementations for comparison
cfs, delays = pywt.cwt(sample_data, scales, 'cmor1.5-1.0', sampling_period=hd.channel_width)
power = np.abs(cfs)**2

In [None]:
# log10 power of the pywt.cwt result (rows follow the scales array)
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
# Same transform through scaleogram's fastcwt, called with the same
# arguments as pywt.cwt above; this overwrites cfs, delays and power
cfs, delays = scg.wfun.fastcwt(sample_data, scales, 'cmor1.5-1.0', sampling_period=hd.channel_width)
power = np.abs(cfs)**2

In [None]:
# log10 power of the fastcwt result, drawn like the pywt.cwt image
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
# SciPy's CWT with the morlet2 wavelet, using the scale values as widths.
# NOTE(review): scipy.signal.cwt and morlet2 are deprecated in recent SciPy
# releases (deprecated in 1.12, removal announced for 1.15) -- confirm the
# SciPy version this notebook is pinned to.
scipy_cwt = signal.cwt(sample_data, signal.morlet2, widths=scales)
power = np.abs(scipy_cwt)**2

In [None]:
# log10 power of the SciPy result, drawn like the two images above
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

### Create CWT hypercube

In [None]:
# Number of distinct redundant groups (first column of redg)
no_red = np.unique(redg[:, 0]).size
# Output buffer with the shape and dtype of data, except that the last axis
# has one entry per redundant group. Allocating it directly avoids creating
# a full-size copy of data and keeping it alive behind a sliced view.
red_data = np.empty(data.shape[:-1] + (no_red,), dtype=data.dtype)

In [None]:
# NaN-ignoring mean over the baselines of each redundant group; the rows of
# redg whose group index equals red select entries on the last axis of data
for red in range(no_red):
    red_idxs = np.flatnonzero(redg[:, 0] == red)
    red_data[..., red] = np.nanmean(data[..., red_idxs], axis=-1)

In [None]:
# Amplitude of the group-averaged visibilities for one redundant group, with
# the colour scale capped at the 95th percentile of the (non-NaN) amplitudes
fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
slct_red = 2
vmax = np.nanpercentile(np.abs(red_data[..., slct_red]), 95)
ax.imshow(np.abs(red_data[..., slct_red]), aspect='auto', interpolation='None', extent=extent, vmax=vmax)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xlabel('Frequency')
ax.set_ylabel('LST')
plt.show()

In [None]:
# Only the scales assignment in this cell is live. The commented-out code is
# the (disabled) generation of the CWT power hypercube: it runs pywt.cwt for
# every time integration and baseline of the selected band and saves the
# result to 'cwt_power_b1.npz'.
# # scales = np.arange(1, min(len(freqs)/10, 100), dtype=int)
scales = np.arange(1, 18, dtype=int)
# data_ = data

# # hypercube with dims scales, freqs, times, bls
# power_arr = np.zeros((scales.size, band[1]+1 - band[0], data_.shape[0], data_.shape[2]))

# for tint in range(data_.shape[0]):
#     if tint % 20 == 0:
#         print(tint)
#     for bl in range(data_.shape[2]):
#         cfs, delays = pywt.cwt(data_[tint, band[0]:band[1]+1, bl], scales, 'cmor1.5-1.0', \
#                                sampling_period=hd.channel_width)
#         power = np.abs(cfs)**2
        
#         power_arr[..., tint, bl] = power
        
# np.savez('cwt_power_b1.npz', power=power_arr, scales=scales, wavelet='cmor1.5-1.0', \
#          delays=delays, freqs=hd.freqs[np.arange(band[0], band[1]+1)], lsts=lsts, redg=redg)

In [None]:
# Load a previously saved CWT power cube from the working directory.
# NOTE(review): this reads 'cwt_power_b2.npz', whereas the commented-out
# generation code saves 'cwt_power_b1.npz' -- confirm the intended band.
p_npz = np.load('cwt_power_b2.npz')

In [None]:
# Power hypercube; the generation code documents its dims as
# (scales, freqs, times, bls)
power_arr = p_npz['power']

In [None]:
# log10 power over the first two axes of the cube, at index 180 and 500 of
# the last two axes (times and bls according to the generation code)
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power_arr[..., 180, 500]), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()