<center><strong><font size=+3>Wavelet transforms</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

The Fourier transform (and hence the power spectrum) works very well for transforming a signal from the time domain to the frequency domain when the frequency content is stationary, i.e. when it does not evolve in time.

The more non-stationary (dynamic) a signal is, the worse the results will be, and most signals we see in real life are non-stationary. In 21 cm cosmology we compute power spectra over considerable frequency bandwidths, and since frequency maps to redshift, the Universe can evolve across such a band. A much better approach for analyzing dynamic signals is to use the wavelet transform instead of the Fourier transform.

Furthermore, if spurious modes exist in the data, they will be localized to a region of the joint frequency–delay space, where the wavelet transform can find them.

In [None]:
import copy
import functools
import glob
import itertools
import os
import warnings

import h5py
import numpy as np
import scipy
from matplotlib import pyplot as plt
from matplotlib.colors import LogNorm
from mpl_toolkits.axes_grid1 import make_axes_locatable
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from scipy import signal
from scipy.interpolate import griddata

from hera_cal.io import HERAData
from hera_cal.redcal import get_reds

import pywt
import scaleogram as scg

In [None]:
%matplotlib inline

In [None]:
from matplotlib import rc

# Render all figure text with LaTeX using a serif font list of ['cm'], loading the
# AMS packages needed for \mathfrak, \widetilde, etc. in the labels.
rc('font', family='serif', serif=['cm'])
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

In [None]:
# Antennas excluded from the analysis.
bad_ants = [0, 2, 11, 24, 50, 53, 54, 67, 69, 98, 122, 136, 139]

# LST ranges (hours) of the three fields.
field_1 = [1.25, 2.70]
field_2 = [4.50, 6.50]
field_3 = [8.50, 10.75]

# First and last channel index of each spectral band.
band_1 = [175, 334]
band_2 = [515, 694]

## Load final OCRSLPXTK visibility product

In [None]:
# Directory containing the visibility product loaded below.
analysis_dir = '/lustre/aoc/projects/hera/mmolnar/wavelets'

In [None]:
# Path of the saved OCRSLP2XTK visibility product (.npz; the array is read from key 'arr_0').
hr_full_fn = os.path.join(analysis_dir, 'h1c_idr2.OCRSLP2XTK.npz')

In [None]:
# LST-binned uvh5 files; sorted so the first file (used for metadata) and the
# concatenated LST axis follow filename order.
lstb_dir = '/lustre/aoc/projects/hera/H1C_IDR2/IDR2_2_pspec/v2/one_group/data'
final_files = sorted(glob.iglob(os.path.join(lstb_dir, 'zen.grp1.of1.LST.*XTK.uvh5')))

In [None]:
def fltBad(bll, badl, minbl=1):
    """Drop baselines touching bad antennas, then discard groups that became too small.

    Parameters
    ----------
    bll : iterable of iterables
        Groups of baselines; each baseline is indexable with its two antenna
        numbers at positions 0 and 1 (e.g. ``(ant1, ant2, pol)``).
    badl : container
        Antenna numbers to exclude.
    minbl : int, optional
        Minimum number of surviving baselines for a group to be kept.

    Returns
    -------
    list of list
        The surviving groups, each a list of the original baseline objects.
    """
    kept = []
    for group in bll:
        good_bls = [bl for bl in group if bl[0] not in badl and bl[1] not in badl]
        if len(good_bls) >= minbl:
            kept.append(good_bls)
    return kept

def groupBls(bll):
    """Flatten baseline groups into an ``(N, 3)`` array of ``(group, ant1, ant2)`` rows.

    Parameters
    ----------
    bll : iterable of iterables
        Groups of ``(ant1, ant2, pol)`` baselines; the group index is the
        position of the group in ``bll``. The polarization is discarded.

    Returns
    -------
    numpy.ndarray
        One row per baseline. When there are no baselines at all, an empty
        ``(0, 3)`` integer array is returned so column indexing such as
        ``redg[:, 0]`` still works (``np.array([])`` would be 1-D float).
    """
    rows = [(g, i, j) for g, bl in enumerate(bll) for (i, j, _pol) in bl]
    if not rows:
        return np.empty((0, 3), dtype=int)
    return np.array(rows)

In [None]:
# Metadata from the first LST-binned file, and redundant baseline groups ('ee' pol)
# with bad antennas removed; redg rows are (group, ant1, ant2).
hd = HERAData(final_files[0])
reds = get_reds(hd.antpos, pols=['ee'])
reds = fltBad(reds, bad_ants)
redg = groupBls(reds)
f_res = hd.channel_width

# Read the visibility product and close the .npz archive right away instead of
# leaving the file handle open for the rest of the session.
with np.load(hr_full_fn) as hr_npz:
    data = hr_npz['arr_0']

In [None]:
# get LSTs
# Collect the sorted unique LSTs of every LST-binned file (and the frequency axis from
# the first file), then convert the concatenated LSTs from radians to hours.
if not final_files:
    raise FileNotFoundError(f'No LST-binned files found in {lstb_dir}')

lsts_per_file = []
with warnings.catch_warnings():
    warnings.filterwarnings('ignore', message='antenna_diameters is not set. Using known values for HERA.')

    for i, f_file in enumerate(final_files):
        # the context manager closes the HDF5 file even if a read fails
        with h5py.File(f_file, 'r') as h5f:
            header = h5f['Header']
            lsts_per_file.append(np.sort(np.unique(header['lst_array'][()])))

            if i == 0:
                freqs = np.squeeze(header['freq_array'][()])
                chans = np.arange(freqs.size)

# concatenate once rather than re-copying the growing array on every file
lsts = np.concatenate(lsts_per_file) * 12 / np.pi  # convert to hours

In [None]:
# imshow extent (left, right, bottom, top): frequency on x, LST on y with the first
# integration at the top.
extent = [hd.freqs[0], hd.freqs[-1], lsts[-1], lsts[0]]
# x limits: first frequency to the last frequency rounded to the nearest 1e6 (MHz).
xlim = np.array([hd.freqs[0], round(hd.freqs[-1], -6)])
# y limits rounded outward to whole hours (assumes LSTs increase along the array),
# ordered bottom > top so the first LST is drawn at the top.
ylim = np.array([np.ceil(lsts[-1]), np.floor(lsts[0])])

In [None]:
# Integration indices whose LST lies strictly inside each field's LST range.
f1 = np.where((lsts > field_1[0]) & (lsts < field_1[1]))[0]
f2 = np.where((lsts > field_2[0]) & (lsts < field_2[1]))[0]
f3 = np.where((lsts > field_3[0]) & (lsts < field_3[1]))[0]

# Channel indices of each band. band_* give the first and last channel inclusively
# (they are sliced as band[0]:band[1]+1 elsewhere), so use >= / <= to select the same
# channels; strict inequalities would silently drop both edge channels.
b1 = np.where((chans >= band_1[0]) & (chans <= band_1[1]))[0]
b2 = np.where((chans >= band_2[0]) & (chans <= band_2[1]))[0]

In [None]:
# Keep only the integrations inside the three fields and the channels inside the two
# bands: an outer-product selection on the first two axes, last axis kept whole.
f_all = np.concatenate((f1, f2, f3))
b_all = np.concatenate((b1, b2))
datag = data[np.ix_(f_all, b_all)]

In [None]:
# Two-panel waterfall (LST vs frequency) of baseline index 40: amplitude on top, phase
# below, with each analysed (band, field) window outlined and labelled B<band>F<field>.
fig, axes = plt.subplots(nrows=2, figsize=(7, 9), dpi=600, sharex=True)

rasterized = True  # rasterize the meshes (presumably to keep vector/PDF output small)

vmax = 400 # np.nanpercentile(np.abs(data[..., 40]), 95)
# vmax = round(np.nanmax(np.abs(datag[..., 40])), -2)
freqsm = hd.freqs/1e6  # frequencies in MHz
pm1 = axes[0].pcolormesh(freqsm, lsts, np.abs(data[..., 40]), vmax=vmax, rasterized=rasterized)
# axes[0].set_xlabel('Frequency [MHz]')
axes[0].set_ylabel('LST [h]')
axes[0].set_xlim(xlim/1e6)  # xlim is in Hz
axes[0].set_ylim(ylim)

divider = make_axes_locatable(axes[0])
cax1 = divider.append_axes('right', size='2.5%', pad=0.1)
plt.colorbar(pm1, cax=cax1, extend='max', label=r'$|V|$')

pm2 = axes[1].pcolormesh(freqsm, lsts, np.angle(data[..., 40]), cmap='PiYG', rasterized=rasterized)
axes[1].set_xlabel('Frequency [MHz]')
axes[1].set_ylabel('LST [h]')
axes[1].set_xlim(xlim/1e6)
axes[1].set_ylim(ylim)

divider = make_axes_locatable(axes[1])
cax2 = divider.append_axes('right', size='2.5%', pad=0.1)
plt.colorbar(pm2, cax=cax2, label=r'$\varphi$')

# Window definitions: LST ranges (hours) and channel-index ranges, the latter mapped
# to MHz through freqsm. Label text is white on the amplitude panel, black on phase.
lstcuts = [field_1, field_2, field_3]
chancuts = [band_1, band_2]
tc = ['w', 'k']

for i, ax in enumerate(axes):

    for f, lc in enumerate(lstcuts):

        for b, cc in enumerate(chancuts):

            r = plt.Rectangle((freqsm[cc[0]], lc[0]), freqsm[cc[1]]-freqsm[cc[0]], lc[1]-lc[0], \
                              fc='None', lw=1.5, ec='orange')
            ax.add_patch(r)

            ax.text((freqsm[cc[0]]+freqsm[cc[1]])/2, (lc[1]+lc[0])/2, f'B{b+1}F{f+1}', \
                    fontsize=8, c=tc[i], ha='center', va='center')

fig.tight_layout()

# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'pI_vis.pdf'), bbox_inches='tight')

plt.show()

In [None]:
# Amplitude waterfall of baseline index 40 in (frequency, LST) coordinates.
fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
ax.imshow(np.abs(data[..., 40]), aspect='auto', interpolation='None', extent=extent)
ax.set(xlim=xlim, ylim=ylim, xlabel='Frequency', ylabel='LST')
plt.show()

In [None]:
# Phase waterfall of baseline index 40 in (frequency, LST) coordinates.
fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
ax.imshow(np.angle(data[..., 40]), aspect='auto', interpolation='None', extent=extent)
ax.set(xlim=xlim, ylim=ylim, xlabel='Frequency', ylabel='LST')
plt.show()

## Running the wavelet transform

In [None]:
# Pick one spectrum to analyse: a single integration and baseline over `band`
# (both band edges included by the +1 in the slice).
band = band_1

slct_time = 35  # integration index into lsts / first axis of data
slct_bl = 40    # baseline index into the last axis of data and the rows of redg

sample_data = data[slct_time, band[0]:band[1]+1, slct_bl]  # pick 1 tint in F1, Band 1, 1st 14m EW bl
freqsb = hd.freqs[band[0]:band[1]+1]

# NOTE(review): the comment above assumes lsts[35] lies in field 1 and that baseline 40
# is a 14 m EW baseline; the print below shows both -- confirm.
print(f'LST {lsts[slct_time]:.2f} h and baseline {redg[slct_bl][1:]}')

In [None]:
# Real and imaginary parts of the selected visibility spectrum against frequency (MHz).
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.plot(freqsb/1e6, sample_data.real, label=r'$\mathfrak{Re}(V)$')
ax.plot(freqsb/1e6, sample_data.imag, label=r'$\mathfrak{Im}(V)$')
ax.set_xlabel('Frequency [MHz]')
ax.legend(loc='best')
plt.tight_layout()
plt.show()

In [None]:
# pywt.wavelist(kind='continuous')

In [None]:
# Choose wavelet: pywt complex Morlet 'cmor<B>-<C>' with bandwidth B=1.5 and
# centre frequency C=1.0.
wavelet = 'cmor1.5-1.0'

In [None]:
# Mother wavelet (left) and its frequency support (right) via scaleogram's plot_wav.
fig, axes = plt.subplots(ncols=2, figsize=(8, 4), dpi=125)

# use the wavelet chosen above rather than a hard-coded copy of its name
axes = scg.plot_wav(wavelet, axes=axes)

axes[0].set_title('Wavelet Function')
axes[1].set_title('Frequency Support')

# B (bandwidth) and C (centre frequency) are read from the wavelet itself so the
# title stays correct if `wavelet` is changed
plt.suptitle(rf'Complex Morlet Wavelet with $B={pywt.ContinuousWavelet(wavelet).bandwidth_frequency:g}$ '
             rf'and $C={pywt.ContinuousWavelet(wavelet).center_frequency:g}$')
fig.tight_layout()

# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs/CHAP-5/FIGS'
# plt.savefig(os.path.join(save_fig_dir, 'morlet_wavelet.pdf'), bbox_inches='tight')

plt.show()

### Continuous wavelet transform

In [None]:
# Scaleogram (CWT power vs frequency and delay, log scales, cone of influence shown).
# NOTE(review): this call unpacks three return values and passes cwt_fun=...; confirm
# against the installed scaleogram version.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax, qmesh, v = scg.cws(freqsb, sample_data, wavelet=wavelet, cscale='log', coi=True, \
                    ax=ax, spectrum='power', yaxis='frequency', \
                    xlabel='Frequency', ylabel='Delay', yscale='log', cwt_fun='pywt')
plt.tight_layout()
plt.show()

In [None]:
# log10 of the values array returned by scg.cws above, drawn in array-index coordinates.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(v), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
# pywt CWT of the sample spectrum over integer scales; with sampling_period set to the
# channel width, the returned pseudo-frequencies are in inverse channel-width units (delay).
scales = np.arange(1, 2**(np.floor(np.log2(freqsb.size))-2), dtype=int)

cfs, delays = pywt.cwt(sample_data, scales, wavelet, sampling_period=hd.channel_width)
power = np.abs(cfs)**2

fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
im = ax.contourf(freqsb, delays, power, levels=None, extend='both', norm=LogNorm())
ax.set(title='Wavelet Power Spectrum', xlabel='Frequency', ylabel='Delay', yscale='log')
plt.colorbar(im, format='%.0e')

plt.tight_layout()
plt.show()

### Discrete wavelet decomposition

In [None]:
# pywt.wavelist(kind='discrete')

In [None]:
disc_wavelet = 'db2'  # pywt Daubechies wavelet used for the discrete decomposition

In [None]:
# calculate approximations of scaling function (phi) and wavelet function (psi) at the given level of refinement.
# Three outputs are unpacked: phi, psi and the grid x on which they are sampled.
phi, psi, x = pywt.Wavelet(disc_wavelet).wavefun()

In [None]:
# Scaling and wavelet functions of disc_wavelet.
# NOTE(review): plotted against sample index; the grid `x` from wavefun() is not used.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.plot(phi, label='Scaling function')
ax.plot(psi, label='Wavelet function')
ax.legend(loc='best')
plt.tight_layout()
plt.show()

In [None]:
# Multilevel DWT of the sample spectrum. wavedec returns [cA_n, cD_n, cD_{n-1}, ..., cD_1],
# so cDs runs from the coarsest detail level (n) to the finest (1).
coeffs = pywt.wavedec(sample_data, wavelet=disc_wavelet)
cA = coeffs[0]
cDs = coeffs[1:]
levels = len(cDs)
lengths = [len(cD) for cD in cDs]
col = int(np.max(lengths))

# Stretch every level's detail coefficients onto the longest level's sample grid
# (nearest neighbour) so all levels can be stacked as rows of one image.
cc = np.empty((levels, col), dtype=complex)

for level in range(levels):
    y = cDs[level]
    if lengths[level] < col:
        x = np.arange(0.5, len(y)+0.5) * col/len(y)
        xi = np.arange(col)
        yi = griddata(points=x, values=y, xi=xi, method='nearest')
    else:
        yi = y

    cc[level, :] = yi

# Pseudo-delay of each row of cc. Decomposition level j corresponds to the dyadic scale
# 2**j, and row k of cc holds level `levels - k`, so evaluate at 2**levels, ..., 2**1.
delays = pywt.scale2frequency(disc_wavelet, 2.0 ** np.arange(levels, 0, -1)) / hd.channel_width

In [None]:
# Power of the DWT detail coefficients vs frequency (MHz) and decomposition level.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)

# Row k of cc holds pywt decomposition level `levels - k` (row 0 is the coarsest level),
# so the extent runs from 0.5 at the bottom to levels+0.5 at the top: with imshow's
# default origin='upper' this puts row 0 at y = levels, matching pywt's level numbering.
im = ax.imshow(np.abs(cc)**2, aspect='auto', extent=[freqsb[0]/1e6, freqsb[-1]/1e6, 0.5, levels+0.5], \
               norm=LogNorm(), interpolation='None')
cbar = plt.colorbar(im)

ax.set_xlabel('Frequency')
ax.set_ylabel('Level')

# finest level (1) at the top, coarsest at the bottom
ax.invert_yaxis()

plt.tight_layout()
plt.show()

### Look at different CWT implementations

In [None]:
# scales = np.arange(1, 2**(np.floor(np.log2(freqsb.size))-2), dtype=int)
# integer scales 1 .. min(N/10, 100) (exclusive) for an N-channel spectrum
scales = np.arange(1, min(len(freqsb)/10, 100), dtype=int)

# use the wavelet chosen earlier rather than a hard-coded copy of its name
cfs, delays = pywt.cwt(sample_data, scales, wavelet, sampling_period=hd.channel_width)
power = np.abs(cfs)**2

In [None]:
# log10 of the pywt CWT power (rows = scales, columns = channels), index coordinates.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
# Same CWT via scaleogram's fastcwt, using the wavelet chosen earlier, for comparison.
cfs, delays = scg.wfun.fastcwt(sample_data, scales, wavelet, sampling_period=hd.channel_width)
power = np.abs(cfs)**2

In [None]:
# log10 of the fastcwt power (rows = scales, columns = channels), index coordinates.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

In [None]:
# SciPy CWT with the Morlet2 wavelet (w=5) for comparison. scipy.signal.cwt/morlet2 were
# deprecated and later removed from SciPy, so fall back to an equivalent computation:
# for each width s, convolve ('same') with the conjugated, time-reversed morlet2 of
# min(10*s, N) points, exactly as scipy.signal.cwt did.
if hasattr(signal, 'cwt') and hasattr(signal, 'morlet2'):
    scipy_cwt = signal.cwt(sample_data, signal.morlet2, widths=scales)
else:
    scipy_cwt = np.empty((len(scales), len(sample_data)), dtype=complex)
    for k, width in enumerate(scales):
        n_pts = min(10 * width, len(sample_data))
        xw = (np.arange(n_pts) - (n_pts - 1) / 2) / width
        morlet = np.sqrt(1 / width) * np.exp(1j * 5.0 * xw) * np.exp(-0.5 * xw**2) * np.pi**(-0.25)
        scipy_cwt[k] = signal.convolve(sample_data, np.conj(morlet[::-1]), mode='same')
power = np.abs(scipy_cwt)**2

In [None]:
# log10 of the SciPy-style CWT power (rows = widths, columns = channels), index coordinates.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

### Create CWT hypercube

In [None]:
no_red = len(set(redg[:, 0]))  # number of redundant-baseline groups
# Allocate only the (..., no_red) output. np.empty_like(data)[..., :no_red] allocated a
# full data-sized buffer just to keep a view of its first no_red baselines.
red_data = np.empty(data.shape[:-1] + (no_red,), dtype=data.dtype)

In [None]:
# Average each redundant group over its member baselines (last axis), ignoring NaNs.
for red in range(no_red):
    red_idxs = np.flatnonzero(redg[:, 0] == red)
    red_data[..., red] = np.nanmean(data[..., red_idxs], axis=-1)

In [None]:
# Amplitude waterfall of the redundantly averaged group with index 2, clipped at its
# 95th percentile.
fig, ax = plt.subplots(figsize=(7, 5), dpi=150)
slct_red = 2
vmax = np.nanpercentile(np.abs(red_data[..., slct_red]), 95)
ax.imshow(np.abs(red_data[..., slct_red]), aspect='auto', interpolation='None', extent=extent, vmax=vmax)
ax.set_xlim(xlim)
ax.set_ylim(ylim)
ax.set_xlabel('Frequency')
ax.set_ylabel('LST')
plt.show()

In [None]:
# Scales for the CWT hypercube. The generation loop below is commented out (presumably
# because one pywt.cwt per integration and baseline is expensive); its saved output is
# loaded from an .npz file in the following cells.
# # scales = np.arange(1, min(len(freqsb)/10, 100), dtype=int)
scales = np.arange(1, 18, dtype=int)
# data_ = data

# # hypercube with dims scales, freqsb, times, bls
# power_arr = np.zeros((scales.size, band[1]+1 - band[0], data_.shape[0], data_.shape[2]))

# for tint in range(data_.shape[0]):
#     if tint % 20 == 0:
#         print(tint)
#     for bl in range(data_.shape[2]):
#         cfs, delays = pywt.cwt(data_[tint, band[0]:band[1]+1, bl], scales, 'cmor1.5-1.0', \
#                                sampling_period=hd.channel_width)
#         power = np.abs(cfs)**2

#         power_arr[..., tint, bl] = power

# np.savez('cwt_power_b1.npz', power=power_arr, scales=scales, wavelet='cmor1.5-1.0', \
#          delays=delays, chans=np.arange(band[0], band[1]+1), freqs=freqsb, lsts=lsts, redg=redg)

In [None]:
# Load a precomputed CWT power hypercube.
# NOTE(review): this loads the *band 2* file, while the commented-out generator above
# writes 'cwt_power_b1.npz' and `band` is currently band_1 -- confirm this is intended.
p_npz = np.load('cwt_power_b2.npz')

In [None]:
# dims (scales, channels, times, baselines), per the comment in the generator cell
power_arr = p_npz['power']

In [None]:
# log10 CWT power (scales x channels) at time index 180 and baseline index 500.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)
ax.imshow(np.log10(power_arr[..., 180, 500]), aspect='auto', interpolation='none', cmap='jet')
plt.tight_layout()
plt.show()

### Plot CWT, signal and FT

In [None]:
# Combined figure: visibility spectrum (top), its scaleogram (middle left) and the
# magnitude of its windowed Fourier transform vs delay (middle right), with the
# scaleogram's horizontal colorbar underneath.
fig = plt.figure(figsize=(8, 8), constrained_layout=True, dpi=125)

gs = fig.add_gridspec(3, 2, height_ratios=[1, 2.5, 0.1], width_ratios=[2.5, 1])


# 1) Scaleogram
ax1 = plt.subplot(gs[1, 0])
ax1_title = ''#'CWT PS'
coikw = {'alpha':0.5, 'hatch':'/'}
ax1, qmesh, values = scg.cws(freqsb, sample_data, wavelet=wavelet, scales=np.arange(2, 60), \
    cscale='log', coi=True, title=ax1_title, ax=ax1, spectrum='power', yaxis='frequency', \
    cbar=False, xlabel='Frequency [MHz]', ylabel='Delay [s]', yscale='log', cwt_fun='pywt', \
    cbarkw={'aspect':40, 'pad':0.12, 'fraction':0.05}, coikw=coikw)

# The x axis is in Hz; label the ticks in MHz. Round (instead of floor-dividing) so a tick
# lying a hair below a whole MHz is not labelled one MHz too low.
xtk = np.linspace(round(freqsb[10], -6), round(freqsb[-10], -6), 8)
intticks = np.rint(xtk / 1e6).astype(int)
ax1.set_xticks(xtk)
ax1.set_xticklabels(intticks)

cax1 = plt.subplot(gs[2, 0])
# raw string: '\m' inside a plain string literal is an invalid escape sequence
plt.colorbar(qmesh, cax=cax1, orientation='horizontal', label=r'$|\mathrm{CWT}|^2$')


# 2) Visibility Signal
ax0 = plt.subplot(gs[0, 0])
ax0.plot(freqsb, sample_data.real, label=r'$\mathfrak{Re}$')
ax0.plot(freqsb, sample_data.imag, label=r'$\mathfrak{Im}$')
# ax0.set_title('Visibility Signal')
ax0.set_ylabel(r'$V$ [Jy]')
ax0.legend(loc='lower left')
ax0.set_xlim(*ax1.get_xlim())
ax0.tick_params(labelbottom=False)


# 3) PS or FT
# # 3a) PS
# delay, pspec = signal.periodogram(sample_data, fs=1/ hd.channel_width, \
#     window='blackmanharris', scaling='spectrum', nfft=sample_data.size, detrend=False, \
#     return_onesided=False)
# delay_sort = np.argsort(delay)
# delay = delay[delay_sort]
# pspec = pspec[delay_sort]
# pspec[np.abs(delay) < ax1.get_ylim()[0] - np.ediff1d(delay).mean()] *= np.nan

# ax2 = plt.subplot(gs[1, 1])
# z_idx = np.where(delay == 0)[0][0]
# ax2.plot(pspec[z_idx:], delay[z_idx:], label=r'$+$', c='deeppink', alpha=0.8)
# ax2.plot(pspec[:z_idx+1], -delay[:z_idx+1], label=r'$-$', c='purple', alpha=0.8)
# ax2.set_ylim(*ax1.get_ylim())
# ax2.set_xscale('log')
# ax2.set_yscale('log')

# ax2.legend(loc='best')
# # ax2.set_title('Power Spectrum')
# ax2.set_xlabel('Power Spectrum')
# ax2.tick_params(labelleft=False)

# 3b) FT
# Blackman-Harris-windowed FFT; signal.windows is the supported namespace (the top-level
# signal.blackmanharris alias is deprecated and removed in recent SciPy releases).
vft = scipy.fft.fft(sample_data*signal.windows.blackmanharris(sample_data.size))
dly = scipy.fft.fftfreq(sample_data.size, hd.channel_width)

dly_sort = np.argsort(dly)
dly = dly[dly_sort]
vft = vft[dly_sort]
# blank delays below the scaleogram's lowest delay so both panels cover the same range
vft[np.abs(dly) < ax1.get_ylim()[0] - np.ediff1d(dly).mean()] *= np.nan

# plot positive and (mirrored) negative delays separately on the shared log delay axis
ax2 = plt.subplot(gs[1, 1])
z_idx = np.where(dly == 0)[0][0]
ax2.plot(np.abs(vft[z_idx:]), dly[z_idx:], label=r'$+$', c='deeppink', alpha=0.8)
ax2.plot(np.abs(vft[:z_idx+1]), -dly[:z_idx+1], label=r'$-$', c='purple', alpha=0.8)
ax2.set_ylim(*ax1.get_ylim())
ax2.set_xscale('log')
ax2.set_yscale('log')
# ax2.set_xlim((0.15, 150))

ax2.legend(loc='best')
# ax2.set_title('Power Spectrum')
ax2.set_xlabel(r'$|\widetilde{V}|$ [Jy Hz]')
ax2.tick_params(labelleft=False)


# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'plot_wav_ft.pdf'), bbox_inches='tight')

plt.show()

### Blackman-Harris window

In [None]:
# Spectra of every baseline in the redundant group with index 2, at the middle
# integration of field 1, split into in-band and out-of-band copies for plotting.
blg = 2
ew_short_bls = np.flatnonzero(redg[:, 0] == blg)
d = data[f1[f1.size//2], :, ew_short_bls]

d_plt = np.abs(d.T)
mask = np.ones(d_plt.shape, dtype=bool)
mask[np.concatenate((b1, b2)), :] = False    # False marks channels inside a band
d_plt_flg = np.where(mask, d_plt, np.nan)    # out-of-band channels only
d_plt = np.where(mask, np.nan, d_plt)        # in-band channels only

print(f'Looking at baselines redundant to {redg[ew_short_bls[0], 1:]}')

In [None]:
# fig, axes = plt.subplots(nrows=2, figsize=(6, 8), dpi=125)
# Figure: visibility amplitudes with the Blackman-Harris band windows (top), the
# windowed delay spectra of both bands (bottom), and the window's response (inset).
fig = plt.figure(figsize=(8, 7), constrained_layout=True, dpi=125)

gs = fig.add_gridspec(2, 2, height_ratios=[1, 1], width_ratios=[1, 1])

# Final pI visibility amplitudes of 14 m EW baselines with BH windows
ax1 = plt.subplot(gs[0, :])
ax1.plot(freqsm, d_plt, alpha=0.5, zorder=2)
ax1.plot(freqsm, d_plt_flg, alpha=0.2, c='grey', zorder=2)

# ax.set_yscale('log')
ax1.set_ylim(0, 60)
ax1.set_xlabel('Frequency [MHz]')
ax1.set_ylabel(r'$|V|$ [Jy]')

# Windows live on a twin axis. signal.windows is the supported namespace (the top-level
# signal.blackmanharris alias is deprecated and removed in recent SciPy releases).
ax1_b = ax1.twinx()
w1 = signal.windows.blackmanharris(b1.size)
# ax1_b.plot(freqsm[b1], w1, c='forestgreen', zorder=10)
ax1_b.fill_between(freqsm[b1], w1, color='orange', alpha=0.5)
w2 = signal.windows.blackmanharris(b2.size)
# ax1_b.plot(freqsm[b2], w2, c='orange', zorder=10)
ax1_b.fill_between(freqsm[b2], w2, color='green', alpha=0.5)
ax1_b.set_ylim(0, 1.15)

ax1_b.text(freqsm[b1].mean(), 1.05, 'Band 1', fontsize=12, ha='center', color='orange')
ax1_b.text(freqsm[b2].mean(), 1.05, 'Band 2', fontsize=12, ha='center', color='green')
ax1_b.set_ylabel('Window amplitude')

# draw the visibility axis above the twin window axis, with a transparent background
ax1.set_zorder(2)
ax1_b.set_zorder(1)
ax1.patch.set_visible(False)


# Band 1 FFT
ax2 = plt.subplot(gs[1, 0])

vft1 = scipy.fft.fft(d[:, b1]*w1)
dly1 = scipy.fft.fftfreq(b1.size, hd.channel_width)

dly_sort1 = np.argsort(dly1)
dly1 = dly1[dly_sort1]
vft1 = vft1[:, dly_sort1]

ax2.plot(dly1*1e6, np.abs(vft1.T), alpha=0.5)
ax2.set_yscale('log')
ax2.set_xlabel(r'Delay [$\mu$s]')
ax2.set_ylabel(r'$|\widetilde{V}|$ [Jy Hz]')

ax2.text(0.075, 0.9, 'Band 1', fontsize=12, ha='left', color='orange', transform=ax2.transAxes)


# Band 2 FFT
ax3 = plt.subplot(gs[1, 1], sharey=ax2)

vft2 = scipy.fft.fft(d[:, b2]*w2)
dly2 = scipy.fft.fftfreq(b2.size, hd.channel_width)

dly_sort2 = np.argsort(dly2)
dly2 = dly2[dly_sort2]
vft2 = vft2[:, dly_sort2]

ax3.plot(dly2*1e6, np.abs(vft2.T), alpha=0.5)
ax3.get_yaxis().set_visible(False)
ax3.set_xlabel(r'Delay [$\mu$s]')

ax3.text(0.075, 0.9, 'Band 2', fontsize=12, ha='left', color='green', transform=ax3.transAxes)


# FFT of BH window
iax = inset_axes(ax3, width='30%', height='35%', loc=1)
iax.tick_params(axis='both', labelsize=6)

w = signal.windows.blackmanharris(b1.size)
wft = scipy.fft.fft(w, 2**12)  # increase sampling on FFT end
response = np.abs(scipy.fft.fftshift(wft / np.abs(wft).max()))  # normalize
response = 20 * np.log10(response)  # as amplitude^2 propto power and dB measures power
freq = np.linspace(-b1.size/2, b1.size/2, len(wft))  # convert to frequency bins

iax.plot(freq, response, lw=1)

iax.text(0.1, 0.8, r'$|\widetilde{W}|$', fontsize=8, ha='left', transform=iax.transAxes)
iax.set_ylabel('dB', fontsize=8)
iax.set_xlabel('Delay Bin', fontsize=8)
iax.set_xlim(-15, 15)
iax.set_ylim(-130, 5)


# save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
# plt.savefig(os.path.join(save_fig_dir, 'BH_window.pdf'), bbox_inches='tight')


plt.show()

In [None]:
# Normalized response (dB) of a Blackman-Harris window of band-1 length, zero-padded to
# 2**13 points and plotted against delay bins of the unpadded length.
fig, ax = plt.subplots(figsize=(6, 4), dpi=125)

w = signal.windows.blackmanharris(b1.size)
wft = scipy.fft.fft(w, n=2**13)  # increase sampling on FFT end
response = np.abs(scipy.fft.fftshift(wft / np.abs(wft).max()))  # normalize
response = 20 * np.log10(response)  # as amplitude^2 propto power and dB measures power
freq = np.linspace(-b1.size/2, b1.size/2, len(wft))  # convert to frequency bins

ax.plot(freq, response)

ax.set_ylabel('dB')
ax.set_xlabel('Delay Bin')
ax.set_xlim(-40, 40)
ax.set_ylim(-130, 5)

plt.tight_layout()
plt.show()

### Morlet wavefunction

In [None]:
# Complex Morlet mother wavelet (left) and the normalized magnitude of its Fourier
# transform (right), with the wavelet's centre frequency marked.
wav = pywt.ContinuousWavelet(wavelet)  #'cmor1.5-1.0'
fun_wav, time = wav.wavefun(length=int(1e5))

fig, axes = plt.subplots(ncols=2, figsize=(8, 4), dpi=125)

axes[0].set_title('Wavelet Function')
axes[1].set_title('Frequency Support')
# B and C are read from the wavelet so the title follows any change of `wavelet`
plt.suptitle(rf'Complex Morlet Wavelet with $B={wav.bandwidth_frequency:g}$ and $C={wav.center_frequency:g}$')


# Wavelet function
axes[0].plot(time, fun_wav.real, label=r"$\mathfrak{Re}$")
axes[0].plot(time, fun_wav.imag, 'r-', label=r"$\mathfrak{Im}$")
axes[0].set_xlabel('Time [s]')
axes[0].legend(loc='best')
axes[0].set_xlim(-4, 4)


# Frequency support
axes[1].set_xlabel('Frequency [Hz]')

wt = scipy.fft.fftshift(scipy.fft.fft(fun_wav, n=int(1e7)))  # increase sampling on FFT end
# Normalize by the peak *magnitude*: .max() on a complex array orders by real part
# first, so it does not in general return the largest |wt|.
nrm = np.abs(wt).max()
wt /= nrm
df = np.median(np.ediff1d(time))  # sample spacing of the wavelet's time grid
# frequencies of the zero-padded FFT follow directly from its length and the sample spacing
wt_frqs = scipy.fft.fftshift(scipy.fft.fftfreq(wt.size, df))
axes[1].plot(wt_frqs, np.abs(wt))
axes[1].set_xlim(0, 2)

axes[1].axvline(wav.center_frequency, color='orange')
# # Here tried to calculate FWHM but issues with scaling?
# xstd = np.sqrt(1/wav.bandwidth_frequency)# * fun_wav.size / wt.size
# fwhm = 2* (2 * np.log(2))**0.5 * xstd
# axes[1].arrow(wav.center_frequency-xstd/2, 0.5, xstd, 0)

fig.tight_layout()

# save_fig_dir = '/Users/matyasmolnar/Desktop/Thesis/CHAP-5/FIGS'
# plt.savefig(os.path.join(save_fig_dir, 'morlet_wavelet2.pdf'), bbox_inches='tight')

plt.show()

### Heisenberg boxes

In [None]:
# Sketch of Heisenberg (time-frequency) boxes: the real part of the mother wavelet
# centred at t=10 and a copy compressed in time by 1/4 centred at t=4, each drawn with
# its normalized spectral magnitude along the frequency axis.
fig, ax = plt.subplots(figsize=(5, 5), dpi=125, sharey=True)

c1 = 'darkorchid'
c2 = 'darkorange'

# Mother wavelet
wav = pywt.ContinuousWavelet(wavelet)  # 'cmor1.5-1.0'
m_fun_wav, m_time = wav.wavefun(length=int(1e5))

# Spectrum of the mother wavelet, computed here from the unmasked samples (the
# NaN-masked copy below cannot be Fourier transformed). Normalize by the peak magnitude:
# .max() on a complex array orders by real part first and need not give the largest |wt|.
wt = scipy.fft.fftshift(scipy.fft.fft(m_fun_wav, n=int(1e6)))  # increase sampling on FFT end
nrm = np.abs(wt).max()
wt_nrm = wt/nrm
wt_nrm[np.abs(wt_nrm) < 1e-2] *= np.nan  # hide the negligible spectral tails

# Scale the time-domain wavelet to a peak magnitude of 1/2 and hide its negligible tails.
m_fun_wav /= np.abs(m_fun_wav).max() * 2
m_fun_wav[np.abs(m_fun_wav) < 1e-3] *= np.nan


# 1st wavelet
shift1 = 10
fun_wav1 = m_fun_wav.copy()
time1 = m_time + shift1
ax.plot(time1, fun_wav1.real, label=r'', color=c1)

# frequencies of the zero-padded FFT follow from its length and the time-sample spacing
df = np.median(np.ediff1d(m_time))
wt_frqs = scipy.fft.fftshift(scipy.fft.fftfreq(wt.size, df))
ax.plot(np.abs(wt_nrm), wt_frqs, color=c1)

# idxs1 = np.isnan(m_fun_wav)
# ifw = m_fun_wav[~idxs1]
# it = m_time[~idxs1]
# idxs2 = np.isnan(wt_nrm)
# iwt = wt_nrm[~idxs2]
# ifr = wt_frqs[~idxs2]

# stdt1 = np.sqrt(scipy.integrate.simpson(np.abs(ifw*np.conj(ifw)) * it**2, it))
# stdf1 = np.sqrt(scipy.integrate.simpson(np.abs(iwt*np.conj(iwt)) * ifr**2, ifr))

# rect = patches.Rectangle((10-stdt1/2, 1-stdf1/2), stdt1, stdf1, linewidth=1, edgecolor='r', facecolor='none')
# ax.add_patch(rect)


# 2nd wavelet
shift2 = 4
scale = 1/4
fun_wav2 = m_fun_wav.copy()
time2 = (scale*m_time + shift2)
ax.plot(time2, fun_wav2.real, label=r'', color=c2)

# Compressing the wavelet in time by `scale` keeps the shape of its normalized spectrum
# but stretches the frequency axis by 1/scale (time-sample spacing scale*df).
df = np.median(np.ediff1d(scale*m_time))
wt_frqs = scipy.fft.fftshift(scipy.fft.fftfreq(wt.size, df))
ax.plot(np.abs(wt_nrm), wt_frqs, color=c2)


ax.set_xlabel('Time')
ax.set_ylabel('Frequency')

# ax.spines['left'].set_position('zero')
ax.spines['bottom'].set_position('zero')

# hand-drawn time/frequency axes and dashed guides to each wavelet's (shift, 1/scale)
ax.hlines(y=0, xmin=0, xmax=14, color='k', zorder=-1, lw=1)
ax.vlines(x=0, ymin=0, ymax=6.5, color='k', zorder=-1, lw=1)

ax.hlines(y=1, xmin=0, xmax=shift1, color='grey', zorder=-1, lw=1, ls='--')
ax.vlines(x=shift1, ymin=0, ymax=1, color='grey', zorder=-1, lw=1, ls='--')

ax.hlines(y=1/scale, xmin=0, xmax=shift2, color='grey', zorder=-1, lw=1, ls='--')
ax.vlines(x=shift2, ymin=0, ymax=1/scale, color='grey', zorder=-1, lw=1, ls='--')

ax.spines['left'].set_color('none')
ax.spines['bottom'].set_color('none')
ax.spines['top'].set_color('none')
ax.spines['right'].set_color('none')

ax.set_ylim(-0.5, 6.5)
ax.set_xlim(-0.5, 14)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])

fig.tight_layout()

# save_fig_dir = '/Users/matyasmolnar/Desktop/Thesis/CHAP-5/FIGS'
# plt.savefig(os.path.join(save_fig_dir, 'heisinberg_boxes.pdf'), bbox_inches='tight')

plt.show()

### FT, STFT, CWT illustration

In [None]:
# Schematic tilings of the time-frequency plane in axes-fraction units (0..1):
# FT (horizontal frequency strips only), STFT (uniform grid of boxes) and WT (boxes that
# halve in width each time the frequency band doubles).
fig, axes = plt.subplots(ncols=3, figsize=(8, 4), dpi=150, sharey=True)

for ax in axes:
    ax.set_xlabel('Time')
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
    ax.set_aspect('equal')


# DFT
for f in np.linspace(0, 1, 17)[1:-1]:
    axes[0].axhline(f)


# STFT
for f in np.linspace(0, 1, 9)[1:-1]:
    axes[1].axhline(f)
    axes[1].axvline(f)


# CWT
# Top band (0.5-1) gets all 15 interior time divisions; each lower, halved band keeps
# every 2nd, 4th, 8th division, so boxes widen in time as frequency decreases.
r = np.linspace(0, 1, 17)[1:-1]
for f in r:
    axes[2].axvline(f, ymin=0.5, ymax=1)
axes[2].axhline(0.5)
for f in r[1::2]:
    axes[2].axvline(f, ymin=0.5/2, ymax=1/2)
axes[2].axhline(0.25)
for f in r[3::4]:
    axes[2].axvline(f, ymin=0.5/4, ymax=1/4)
axes[2].axhline(0.125)
for f in r[7::8]:
    axes[2].axvline(f, ymin=0.5/8, ymax=1/8)
axes[2].axhline(0.125/2)


axes[0].set_ylabel('Frequency')
axes[0].set_title('FT')
axes[1].set_title('STFT')
axes[2].set_title('WT')

fig.tight_layout()

# save_fig_dir = '/Users/matyasmolnar/Desktop/Thesis/CHAP-5/FIGS'
# plt.savefig(os.path.join(save_fig_dir, 'wt_illustration.pdf'), bbox_inches='tight')

plt.show()