In [None]:
import numpy as np
import pywt
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from skimage import io

%matplotlib inline

def read_img(path):
    """Load an image from ``path`` as a single-channel float64 array.

    Multi-channel images are reduced to one channel:
    - 2 channels (presumably grayscale + alpha): the first (intensity) channel.
    - 3 or more channels (RGB / RGBA): the green channel (index 1).

    Parameters
    ----------
    path : str
        Path of the image file, read with ``skimage.io.imread``.

    Returns
    -------
    numpy.ndarray
        2-D float64 array.
    """
    img = io.imread(path)

    if img.ndim > 2:
        # For a 2-channel image index 1 would be the alpha plane, not image data,
        # so keep the intensity channel there; otherwise take green as before.
        channel = 0 if img.shape[2] == 2 else 1
        img = img[:, :, channel]

    # np.float was only an alias for float64 and was removed in NumPy 1.24.
    # Converting after slicing avoids converting the unused channels.
    return img.astype(np.float64)

def _show_band(axis, band):
    # Draw one wavelet sub-band in grayscale with the tick marks removed.
    axis.imshow(band, cmap=plt.cm.gray)
    axis.set_xticks([])
    axis.set_yticks([])

def print_iters(LL, n_iters, base='db1'):
    """Plot a multi-level 2-D DWT of ``LL`` in the classic nested-quadrant layout.

    At each level the current approximation is decomposed with ``pywt.dwt2``.
    The three detail bands are placed as follows within the area still
    available:
    - HL in the top-right quadrant;
    - LH in the bottom-left quadrant;
    - HH in the bottom-right quadrant.
    Decomposition continues inside the top-left quadrant, and the final
    approximation fills the last top-left tile.

    Parameters
    ----------
    LL : 2-D array
        Image (or approximation) to decompose.
    n_iters : int
        Number of decomposition levels. Values <= 0 just display ``LL``.
    base : str, optional
        Wavelet name passed to ``pywt.dwt2`` (default 'db1').
    """
    fig_inches = 16  # width and height of the figure

    levels = max(n_iters, 0)
    # Use a 2**levels square grid so every tile span stays an integer >= 1
    # at any depth. The old fixed 16x16 grid hit span 0 beyond 4 levels.
    grid = 2 ** levels
    shape = (grid, grid)
    span = grid

    # Open a fresh figure instead of reusing figure number 0, which could
    # stack new axes on top of a previous call's plot.
    fig = plt.figure(figsize=(fig_inches, fig_inches))

    for _ in range(levels):
        # Integer division: subplot2grid requires integer positions and spans.
        span //= 2

        LL, (LH, HL, HH) = pywt.dwt2(LL, wavelet=base, mode='symmetric')

        _show_band(plt.subplot2grid(shape, (span, span), rowspan=span, colspan=span, fig=fig), HH)
        _show_band(plt.subplot2grid(shape, (0, span), rowspan=span, colspan=span, fig=fig), HL)
        _show_band(plt.subplot2grid(shape, (span, 0), rowspan=span, colspan=span, fig=fig), LH)

    # The remaining top-left tile holds the coarsest approximation.
    _show_band(plt.subplot2grid(shape, (0, 0), rowspan=span, colspan=span, fig=fig), LL)

    plt.show()

In [None]:
# One decomposition level of the ball image with the default wavelet.
img = read_img('images/ball_100.png')
print_iters(img, n_iters=1)

In [None]:
# Two decomposition levels of the ball image with the default wavelet.
img = read_img('images/ball_100.png')
print_iters(img, n_iters=2)

In [None]:
# Three decomposition levels of the ball image with the default wavelet.
img = read_img('images/ball_100.png')
print_iters(img, n_iters=3)

In [None]:
# Four decomposition levels of the ball image with the default wavelet.
img = read_img('images/ball_100.png')
print_iters(img, n_iters=4)

In [None]:
# Load the demonstration image, then list the wavelet families PyWavelets
# provides (full names) and the wavelet names it reports for the 'bior' family.
img = read_img('images/wavelet_demonstration.png')
print(pywt.families(short=False))
print(pywt.wavelist('bior'))

# Two decomposition levels using the Haar wavelet.
print_iters(img, n_iters=2, base='haar')

In [None]:
# Four decomposition levels of the zebra photo with the default wavelet.
img = read_img('images/zebra.jpg')
print_iters(img, n_iters=4)

In [None]:
# Three decomposition levels of the Lichtenstein image with the default wavelet.
img = read_img('images/lichtenstein.png')
print_iters(img, n_iters=3)