<center><strong><font size=+3>Wavelet illustrations</font></strong></center>
<br><br>
<center><strong><font size=+2>Matyas Molnar and Bojan Nikolic</font><br></strong></center>
<br><center><strong><font size=+1>Astrophysics Group, Cavendish Laboratory, University of Cambridge</font></strong></center>

In [None]:
import os

import numpy as np
import scipy
from matplotlib import pyplot as plt

import pywt

In [None]:
# Render matplotlib figures inline beneath each cell.
%matplotlib inline

In [None]:
from matplotlib import rc

# Route all figure text through LaTeX using a serif font family (preferring
# 'cm'); the AMS packages are loaded so maths in labels can use their macros.
rc('font', family='serif', serif=['cm'])
rc('text', usetex=True)
rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath}')

In [None]:
# Output directory for figures (used by the commented-out savefig calls).
# Set the WAVELET_FIG_DIR environment variable to choose it explicitly;
# otherwise use the cluster path when it exists, else the local thesis folder.
save_fig_dir = os.environ.get('WAVELET_FIG_DIR')
if save_fig_dir is None:
    save_fig_dir = '/lustre/aoc/projects/hera/mmolnar/figs'
    if not os.path.exists(save_fig_dir):
        save_fig_dir = '/Users/matyasmolnar/Desktop/Thesis/CHAP-5/FIGS'

### Heisenberg boxes

In [None]:
# choose wavelet: PyWavelets complex Morlet, named 'cmorB-C' with
# bandwidth B = 1.5 and centre frequency C = 1.0
wavelet = 'cmor1.5-1.0'

In [None]:
fig, ax = plt.subplots(figsize=(5, 5), dpi=125)

c1 = 'darkorchid'  # colour of the 1st (unscaled) wavelet and its spectrum
c2 = 'darkorange'  # colour of the 2nd (compressed) wavelet and its spectrum

# Mother wavelet, densely sampled (1e5 points).
wav = pywt.ContinuousWavelet(wavelet)
m_fun_wav, m_time = wav.wavefun(length=int(1e5))
fun_wav_c = m_fun_wav.copy()  # unnormalised copy, kept for the FFTs below

# Scale so the envelope peaks at 0.5. Use the peak *magnitude*: ndarray.max()
# on complex data compares lexicographically, and dividing by a complex scalar
# would also rotate the phase of the plotted real part.
m_fun_wav /= np.abs(m_fun_wav).max() * 2
# Blank out the negligible tails so only the wavelet's support is drawn.
m_fun_wav[np.abs(m_fun_wav) < 1e-3] = np.nan


# Zero-pad the FFT far beyond the 1e5 wavelet samples so the spectrum is finely sampled.
n_fft = int(1e6)
dt = np.median(np.ediff1d(m_time))  # time step of the sampled mother wavelet


def _normalised_spectrum(signal, dt, n_fft, floor=1e-2):
    """Return the centred, peak-normalised FFT of ``signal`` and its frequencies.

    Parameters
    ----------
    signal : array_like
        Uniformly sampled (possibly complex) signal.
    dt : float
        Sample spacing of ``signal``.
    n_fft : int
        FFT length; ``signal`` is zero-padded to this length.
    floor : float, optional
        Bins whose normalised magnitude is below ``floor`` are set to NaN so
        that matplotlib leaves them undrawn.

    Returns
    -------
    spec : ndarray
        ``fftshift``-ed spectrum divided by its peak magnitude.
    frqs : ndarray
        Frequencies matching ``spec``, computed for the padded length ``n_fft``.
    """
    spec = scipy.fft.fftshift(scipy.fft.fft(signal, n=n_fft))
    # Normalise by the peak magnitude; ndarray.max() on complex data orders
    # lexicographically by real part, which is not the spectral peak.
    spec = spec / np.abs(spec).max()
    spec[np.abs(spec) < floor] = np.nan
    # The frequency grid must correspond to the padded FFT length.
    frqs = scipy.fft.fftshift(scipy.fft.fftfreq(n_fft, dt))
    return spec, frqs


# 1st wavelet: mother wavelet translated to t = shift1
shift1 = 10
time1 = m_time + shift1
ax.plot(time1, m_fun_wav.real, color=c1)

wt_nrm, wt_frqs1 = _normalised_spectrum(fun_wav_c, dt, n_fft)
ax.plot(np.abs(wt_nrm), wt_frqs1, color=c1)

# TODO: draw the Heisenberg boxes themselves, e.g. rectangles of width sigma_t and
# height sigma_f from the second moments of |psi(t)|^2 and |Psi(f)|^2
# (would need matplotlib.patches).

# 2nd wavelet: compressed in time by `scale` and translated to t = shift2
shift2 = 4
scale = 1/4
time2 = scale*m_time + shift2
ax.plot(time2, m_fun_wav.real, color=c2)

# Sampling the same wavelet with spacing scale*dt leaves the normalised FFT
# unchanged and divides every frequency by `scale`, so reuse the spectrum.
wt_frqs = wt_frqs1 / scale
ax.plot(np.abs(wt_nrm), wt_frqs, color=c2)


# Label the axes, then replace the default frame with hand-drawn time and
# frequency axes plus dashed guides to each wavelet's (time, frequency) centre.
ax.set_xlabel('Time')
ax.set_ylabel('Frequency')

ax.spines['bottom'].set_position('zero')

t_max, f_max = 14, 6.5  # extent of the hand-drawn axes
guide_kw = dict(color='grey', zorder=-1, lw=1, ls='--')

ax.hlines(y=0, xmin=0, xmax=t_max, color='k', zorder=-1, lw=1)
ax.vlines(x=0, ymin=0, ymax=f_max, color='k', zorder=-1, lw=1)

for t_centre, f_centre in ((shift1, 1), (shift2, 1/scale)):
    ax.hlines(y=f_centre, xmin=0, xmax=t_centre, **guide_kw)
    ax.vlines(x=t_centre, ymin=0, ymax=f_centre, **guide_kw)

for side in ('left', 'bottom', 'top', 'right'):
    ax.spines[side].set_color('none')

ax.set_ylim(-0.5, f_max)
ax.set_xlim(-0.5, t_max)

ax.xaxis.set_ticks([])
ax.yaxis.set_ticks([])

fig.tight_layout()
# plt.savefig(os.path.join(save_fig_dir, 'heisinberg_boxes.pdf'), bbox_inches='tight')
plt.show()

### FT, STFT, CWT illustration

In [None]:
# Time-frequency tilings of the Fourier transform, short-time Fourier transform
# and dyadic wavelet transform, each drawn on the unit square.
n_tiles = 16  # number of tiles along the finest-resolved axis (power of 2)

fig, axes = plt.subplots(ncols=3, figsize=(8, 4), dpi=150, sharey=True)

for ax in axes:
    ax.set_xlabel('Time')
    ax.xaxis.set_ticks([])
    ax.yaxis.set_ticks([])
    ax.set_aspect('equal')
    # Pin the limits to the unit square. Otherwise the ax*line calls below
    # autoscale (with margins) to roughly [0.02, 0.98], so axvline's
    # ymin/ymax (axes fractions) no longer meet axhline's y (data coords)
    # and the outermost tiles come out narrower than the rest.
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)


# FT: frequency bands only, no time localisation.
for f in np.linspace(0, 1, n_tiles + 1)[1:-1]:
    axes[0].axhline(f)


# STFT: fixed window, hence a uniform grid of equal tiles.
for f in np.linspace(0, 1, n_tiles // 2 + 1)[1:-1]:
    axes[1].axhline(f)
    axes[1].axvline(f)


# WT: dyadic tiling. Each octave down halves the band height and doubles the
# tile width; octave `level` spans frequencies [2**-(level+1), 2**-level].
edges = np.linspace(0, 1, n_tiles + 1)[1:-1]
n_levels = int(np.log2(n_tiles))
for level in range(n_levels):
    step = 2**level
    top = 1 / step
    bottom = top / 2
    for t in edges[step - 1::step]:
        axes[2].axvline(t, ymin=bottom, ymax=top)
    axes[2].axhline(bottom)


axes[0].set_ylabel('Frequency')
axes[0].set_title('FT')
axes[1].set_title('STFT')
axes[2].set_title('WT')

fig.tight_layout()
# plt.savefig(os.path.join(save_fig_dir, 'wt_illustration.pdf'), bbox_inches='tight')
plt.show()