In [None]:
from dtcwt.coeffs import biort, qshift
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

from mpl_styles import list_sizes
import add_labels
import helpers

# DTCWT biort filters

In [None]:
# For every biorthogonal wavelet, draw one figure with one row per kept filter:
# left column = filter coefficients, right column = helpers.fourier_image output.
wavelets = ('antonini', 'legall', 'near_sym_a', 'near_sym_b')
descriptions = (
    "Antonini 9,7 tap filters",
    "LeGall 5,3 tap filters",
    "Near-Symmetric 5,7 tap filters",
    "Near-Symmetric 13,19 tap filters",
    # NOTE(review): there is no fifth entry in `wavelets`, so zip() below never
    # reaches this description (presumably meant for 'near_sym_b_bp' — confirm).
    "Near-Symmetric 13,19 tap filters + BP filter",
)
# Names are matched positionally against the tuple returned by biort().
filter_names = ('h0o', 'g0o', 'h1o', 'g1o')
# Subplot titles, aligned with filter_names; only the 'h*' entries are used.
filt_notations = ('h0R', '', 'h1R', '')
for wavelet, desc in zip(wavelets, descriptions):
    # Keep only the filters whose name starts with 'h'.
    filts = [
        (f, n, m)
        for f, n, m in zip(biort(wavelet), filter_names, filt_notations)
        if n.startswith('h')
    ]
    # One row per filter, two columns. The grid used to be created as
    # (2, len(filts)) and transposed, which only lined up because exactly two
    # filters are kept; squeeze=False keeps `axs` 2-D for any filter count.
    fig, axs = plt.subplots(len(filts), 2, squeeze=False)
    fig.suptitle(f'{wavelet} - {desc}')
    for (ax, ax2), (filt, filt_name, mnot) in zip(axs, filts):
        # Use column 0 of the coefficient array as the 1-D tap sequence.
        filt = filt[:, 0]

        ax.set_title(mnot)
        ax.plot(np.arange(len(filt)), filt)
        add_labels.sample(ax)

        # 'АЧХ' is the Russian abbreviation for amplitude-frequency response.
        ax2.set_title(f'{mnot} - АЧХ')
        ax2.plot(*helpers.fourier_image(filt, 1/len(filt)))
        add_labels.fourier(ax2)

    fig.tight_layout()
    fig.savefig(f'out/dtcwt_filters/{wavelet}.png')

# DTCWT qshift filters

In [None]:
# For every Q-shift wavelet, draw one A4 figure with one row per kept filter:
# left column = filter coefficients, right column = helpers.fourier_image output.
# Names are matched positionally against the tuple returned by qshift().
filter_names = ('h0Re', 'h0Im', 'g0a', 'g0b', 'h1Re', 'h1Im', 'g1a', 'g1b')
# 'qshift_b' used to be listed twice. That shifted every later description by
# one (qshift_b.png was overwritten with the 16,16 caption, qshift_c got the
# 18,18 caption) and zip() dropped 'qshift_d' entirely.
wavelet_names = ('qshift_06', 'qshift_a', 'qshift_b', 'qshift_c', 'qshift_d')
wavelet_descriptions = (
    "Quarter Sample Shift Orthogonal (Q-Shift) 10,10 tap filters, (only 6,6 non-zero taps).",
    "Q-shift 10,10 tap filters, (with 10,10 non-zero taps, unlike qshift_06).",
    "Q-Shift 14,14 tap filters.",
    "Q-Shift 16,16 tap filters.",
    "Q-Shift 18,18 tap filters.",
)
# zip() truncates silently, so fail loudly if the two tuples drift apart again.
assert len(wavelet_names) == len(wavelet_descriptions), \
    "wavelet_names and wavelet_descriptions must pair up one-to-one"

for wavelet, desc in zip(wavelet_names, wavelet_descriptions):
    # Keep only the filters whose name starts with 'h'.
    filts = [
        (f, n)
        for f, n in zip(qshift(wavelet), filter_names)
        if n.startswith('h')
    ]
    # One row per filter, two columns; squeeze=False keeps `axs` 2-D.
    fig, axs = plt.subplots(len(filts), 2, figsize=list_sizes['A4'], squeeze=False)
    fig.suptitle(f'{wavelet} - {desc}')
    for (ax, ax2), (filt, filt_name) in zip(axs, filts):
        # Use column 0 of the coefficient array as the 1-D tap sequence.
        filt = filt[:, 0]
        ax.set_title(filt_name)
        ax.plot(np.arange(len(filt)), filt)
        add_labels.sample(ax)

        # 'АЧХ' is the Russian abbreviation for amplitude-frequency response.
        ax2.set_title(f'{filt_name} - АЧХ')
        ax2.plot(*helpers.fourier_image(filt, 1/len(filt)))
        add_labels.fourier(ax2)

    fig.tight_layout()
    fig.savefig(f'out/dtcwt_filters/{wavelet}.png')

# Sigmoid

In [None]:
# Sigmoid curve over [-6, 6], drawn at the 'A4/2' size with both dimensions halved.
# `x` stays a notebook-level name: the ReLU cell reads it as well.
x = np.linspace(-6, 6, 1000)
half_scale = np.array([1/2, 1/2])
sigmoid_fig, sigmoid_ax = plt.subplots(figsize=list_sizes['A4/2']*half_scale)
sigmoid_ax.plot(x, tf.math.sigmoid(x))
sigmoid_ax.set_xlabel('x')
sigmoid_ax.set_ylabel('sigmoid(x)')
sigmoid_fig.tight_layout()
# plt.savefig(OUT_PATH / 'sigmoid.png')

# ReLU

In [None]:
# ReLU curve at the 'A4/2' size with both dimensions halved.
# Relies on `x` created in the sigmoid cell, so that cell must run first.
half_scale = np.array([1/2, 1/2])
relu_fig, relu_ax = plt.subplots(figsize=list_sizes['A4/2']*half_scale)
relu_ax.plot(x, tf.nn.relu(x))
relu_ax.set_xlabel('x')
relu_ax.set_ylabel('relu(x)')
relu_fig.tight_layout()
# plt.savefig(OUT_PATH / 'relu.png')