In [3]:
# Using the PyWavelets module, available at 
# https://pywavelets.readthedocs.io/en/latest/install.html

from matplotlib.image import imread
import numpy as np
import matplotlib.pyplot as plt
import os
import pywt
plt.rcParams['figure.figsize'] = [8, 8]
plt.rcParams.update({'font.size': 18})

A = imread(os.path.join('..','DATA','dog.jpg'))
B = np.mean(A, -1); # Convert RGB to grayscale

In [2]:
import torch
from pytorch_wavelets import DWTForward, DWTInverse # (or import DWT, IDWT)
xfm = DWTForward(J=3, mode='zero', wave='db3')  # Accepts all wave types available to PyWavelets
ifm = DWTInverse(mode='zero', wave='db3')
X = torch.randn(10,5,64,64)
Yl, Yh = xfm(X)
print(Yl.shape)
#>>> torch.Size([10, 5, 12, 12])
print(Yh[0].shape)
#>>> torch.Size([10, 5, 3, 34, 34])
print(Yh[1].shape)
#>>> torch.Size([10, 5, 3, 19, 19])
print(Yh[2].shape)
#>>> torch.Size([10, 5, 3, 12, 12])
Y = ifm((Yl, Yh))
import numpy as np
np.testing.assert_array_almost_equal(Y.cpu().numpy(), X.cpu().numpy())

torch.Size([10, 5, 12, 12])
torch.Size([10, 5, 3, 34, 34])
torch.Size([10, 5, 3, 19, 19])
torch.Size([10, 5, 3, 12, 12])


In [14]:
def to_image(t):
    return t[0].permute(1, 2, 0).byte().numpy()

def to_tensor(img):
    return torch.tensor(A).float().permute(2, 0, 1)[None]

image_t =  to_tensor(B)
image_t.shape, to_image(image_t).shape

(torch.Size([1, 3, 2000, 1500]), (2000, 1500, 3))

In [15]:
level, wave = 4, 'db3'

xfm = DWTForward(J=level, mode='zero', wave=wave)
ifm = DWTInverse(mode='zero', wave=wave)
xfm, ifm

(DWTForward(), DWTInverse())

In [17]:
Yl, Yh = xfm(image_t)
Yl.shape#, Yh

torch.Size([1, 3, 129, 98])

In [20]:
[l.shape for l in  Yh]

[torch.Size([1, 3, 3, 1002, 752]),
 torch.Size([1, 3, 3, 503, 378]),
 torch.Size([1, 3, 3, 254, 191]),
 torch.Size([1, 3, 3, 129, 98])]

In [1]:
## Wavelet Compression
n = 4
w = 'db20'
coeffs = pywt.wavedec2(B,wavelet=w,level=n)

coeff_arr, coeff_slices = pywt.coeffs_to_array(coeffs)
#print([f.shape for f in coeff_slices])
Csort = np.sort(np.abs(coeff_arr.reshape(-1)))

for keep in (1, 0.1, 0.05, 0.01, 0.005):
    thresh = Csort[int(np.floor((1-keep)*len(Csort)))]    
    ind = np.abs(coeff_arr) > thresh
    Cfilt = coeff_arr * ind # Threshold small indices
    
    coeffs_filt = pywt.array_to_coeffs(Cfilt,coeff_slices,
        output_format='wavedec2')
    
    # Plot reconstruction
    Arecon = pywt.waverec2(coeffs_filt,wavelet=w)
    plt.figure()
    plt.imshow(Arecon.astype('uint8'),cmap='gray')
    plt.axis('off')
    plt.title('keep = ' + str(keep))

NameError: name 'pywt' is not defined

In [22]:
coeff_slices


[(slice(None, 161, None), slice(None, 130, None)),
 {'ad': (slice(None, 161, None), slice(130, 260, None)),
  'da': (slice(161, 322, None), slice(None, 130, None)),
  'dd': (slice(161, 322, None), slice(130, 260, None))},
 {'ad': (slice(None, 284, None), slice(260, 481, None)),
  'da': (slice(322, 606, None), slice(None, 221, None)),
  'dd': (slice(322, 606, None), slice(260, 481, None))},
 {'ad': (slice(None, 529, None), slice(481, 885, None)),
  'da': (slice(606, 1135, None), slice(None, 404, None)),
  'dd': (slice(606, 1135, None), slice(481, 885, None))},
 {'ad': (slice(None, 1019, None), slice(885, 1654, None)),
  'da': (slice(1135, 2154, None), slice(None, 769, None)),
  'dd': (slice(1135, 2154, None), slice(885, 1654, None))}]