# Dependencies

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import scipy as sp

# local dependencies
from utils.dct import (
    cosine_basis_1d,
    dct,
    idct,
    cosine_basis_2d,
    dct2,
    idct2
)

from utils.filters import (
    triangle,
    rectangle,
    block
)

# Load an image

In [4]:
# Read the cameraman test image from the assets folder; `cm` is the
# array that every later image-processing cell works on.
cameraman_path = "../assets/images/dip_3rd/CH02_Fig0222(b)(cameraman).tif"
cm = plt.imread(cameraman_path)

# Discrete Cosine Transform (DCT)

In [None]:
# Build the 1D cosine basis for a signal of this length.
signal_length = 8
dct_basis_vectors = cosine_basis_1d(signal_length)

# Shared gray scale: every basis vector is drawn against the global
# min/max so intensities are comparable across panels.
basis_min = dct_basis_vectors.min()
basis_max = dct_basis_vectors.max()

# plot: one panel per basis vector, each shown as a 1-row image
fig, axs = plt.subplots(nrows=1, ncols=signal_length, figsize=(16, 2), layout='compressed')
fig.suptitle("1D DCT basis vectors")

for i, ax in enumerate(axs):
    basis_row = dct_basis_vectors[i].reshape(1, -1)
    ax.imshow(basis_row, cmap='gray', vmin=basis_min, vmax=basis_max)
    ax.set_title(f"basis {i}")
    ax.set_xticks(range(signal_length))
    ax.set_yticks([])

plt.show()

In [None]:
# Build the 2D cosine basis; the helper is called with the first
# dimension only, and the grid below has image_length[0] x image_length[1] panels.
image_length = (8, 8)
dct2_basis_images = cosine_basis_2d(image_length[0])

# Shared gray scale across all basis images.
basis_min = dct2_basis_images.min()
basis_max = dct2_basis_images.max()

# plot: panel (i, j) shows basis image (i, j)
fig, axs = plt.subplots(nrows=image_length[0], ncols=image_length[1], figsize=(16, 18), layout='compressed')
fig.suptitle("2D DCT basis Images")

for i, axes_row in enumerate(axs):
    for j, ax in enumerate(axes_row):
        ax.imshow(dct2_basis_images[i, j], cmap='gray', vmin=basis_min, vmax=basis_max)
        ax.set_title(f"basis: {i},{j}")
        ax.set_yticks([])
        ax.set_xticks([])

plt.show()

## Reconstructing a 1D signal

In [None]:
# 1D test signal of length N = 5
arr1 = np.array([3, 0, 2, 1, 5])

# DCT : transform signal from spatial domain to frequency domain
dct_arr1 = dct(arr1)

# IDCT: transform signal from frequency domain to spatial domain
idct_arr1 = idct(dct_arr1)

# clip the signal values to [0, 5] (the known value range of the original signal)
idct_arr1 = idct_arr1.clip(0, 5)

# round, then cast from float to int (original signal had integer values)
idct_arr1 = np.round(idct_arr1).astype(np.int32)

# plot
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 2), layout='compressed')
fig.suptitle("Reconstructing 1D signal")

axs[0].imshow(arr1.reshape(1, -1), cmap='gray')
axs[0].set_title("Original signal")
# The panel is titled "Magnitude", so show |coefficients| rather than the
# signed values (same convention as the 2D reconstruction cell).
axs[1].imshow(np.abs(dct_arr1).reshape(1, -1), cmap='gray')
axs[1].set_title("Magnitude in frequency domain")
axs[2].imshow(idct_arr1.reshape(1, -1), cmap='gray')
axs[2].set_title("Reconstructed signal")

# each panel is a single-row image, so the y axis carries no information
for ax in fig.axes:
    ax.set_yticks([])

plt.show()

## Reconstructing a 2D signal

In [None]:
# small 3x3 test image
arr2 = np.array([[1, 1, 3], [2, 1, 2], [2, 3, 3]])

# forward transform (spatial -> frequency), then inverse (frequency -> spatial)
dct_arr2 = dct2(arr2)
idct_arr2 = idct2(dct_arr2)

# keep values inside the known range [1, 3] of the original signal,
# then round and cast back to integers like the input
idct_arr2 = np.round(idct_arr2.clip(1, 3)).astype(np.int32)

# plot: original | coefficient magnitudes | reconstruction
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(12, 5), layout='compressed')
fig.suptitle("Reconstructing 2D signal")

panels = (
    (arr2, "Original signal"),
    (np.abs(dct_arr2), "Magnitude in frequency domain"),
    (idct_arr2, "Reconstructed signal"),
)

for ax, (panel, title) in zip(axs, panels):
    ax.imshow(panel, cmap='gray')
    ax.set_title(title)
    # one tick per pixel of the 3x3 arrays
    ax.set_xticks(range(3))
    ax.set_yticks(range(3))

plt.show()

# Compression effect [DFT vs DCT]

In [None]:
# Split the image into non-overlapping block_size x block_size tiles.
# Only whole tiles are processed: if an image dimension is not a multiple of
# block_size, the leftover border is never written and stays 0 in the outputs.
block_size = 8
cm_total_chunks = (cm.shape[0] // block_size, cm.shape[1] // block_size)

# Mask applied to every block's coefficients (the "compression" step).
rec_mask = rectangle(block_size, 2)

# Masked per-block coefficients: complex for the DFT, real for the DCT.
dft_coef_per_block = np.zeros(shape=(*cm_total_chunks, block_size, block_size), dtype=np.complex128)
dct_coef_per_block = np.zeros(shape=(*cm_total_chunks, block_size, block_size))

# Reconstructed images (float buffers holding integer-valued pixels in [0, 255]).
idft_cm = np.zeros(shape=cm.shape)
idct_cm = np.zeros(shape=cm.shape)

for i in range(cm_total_chunks[0]):
    rows = slice(i * block_size, (i + 1) * block_size)
    for j in range(cm_total_chunks[1]):
        cols = slice(j * block_size, (j + 1) * block_size)
        cm_block = cm[rows, cols]

        # dft & dct [plus compression]
        dft_coef_per_block[i, j] = np.fft.fft2(cm_block, norm='ortho') * rec_mask
        dct_coef_per_block[i, j] = sp.fftpack.dctn(cm_block, norm='ortho') * rec_mask

        # idft & idct; the masked DFT is generally not conjugate-symmetric,
        # so only the real part of its inverse is kept
        idft_block = np.fft.ifft2(dft_coef_per_block[i, j], norm='ortho').real
        idct_block = sp.fftpack.idctn(dct_coef_per_block[i, j], norm='ortho')

        # Quantise to the 8-bit range by rounding to the nearest integer.
        # A bare astype(np.uint8) truncates toward zero (e.g. 127.99 -> 127),
        # which biases the reconstruction darker.
        idft_cm[rows, cols] = np.round(idft_block).clip(0, 255)
        idct_cm[rows, cols] = np.round(idct_block).clip(0, 255)

# plot: the same zoomed crop of the original and of both reconstructions
fig, axs = plt.subplots(nrows=1, ncols=3, figsize=(16, 5), layout='compressed')

axs[0].imshow(cm[20:120, 70:180], cmap='gray')
axs[0].set_title("Original image [zoomed]")
axs[1].imshow(idft_cm[20:120, 70:180], cmap='gray')
axs[1].set_title("Reconstructed dft [zoomed]")
axs[2].imshow(idct_cm[20:120, 70:180], cmap='gray')
axs[2].set_title("Reconstructed dct [zoomed]")

for ax in fig.axes:
    ax.axis('off')

plt.show()

# Frequency modification [Zonal Masking]

In [None]:
# 2D DCT of the whole image (orthonormal scaling)
dct_cm = sp.fftpack.dctn(cm, norm='ortho')

# three zonal masks, each sized from the number of rows of the spectrum
tri_mask = triangle(dct_cm.shape[0], 1.5)
rec_mask = rectangle(dct_cm.shape[0], 50)
blk_mask = block(dct_cm.shape[0], 25)

# masked spectra: element-wise product of the coefficients with each mask
dct_cm_1, dct_cm_2, dct_cm_3 = (
    dct_cm * zone_mask for zone_mask in (tri_mask, rec_mask, blk_mask)
)

# unmasked spectrum first, then the three masked variants
spectra = (dct_cm, dct_cm_1, dct_cm_2, dct_cm_3)

# coefficient magnitudes
abs_dct_cm, abs_dct_cm_1, abs_dct_cm_2, abs_dct_cm_3 = (
    np.abs(spectrum) for spectrum in spectra
)

# inverse DCT of each spectrum, clipped to the 8-bit intensity range
idct_dct_cm, idct_dct_cm_1, idct_dct_cm_2, idct_dct_cm_3 = (
    sp.fftpack.idctn(spectrum, norm='ortho').clip(0, 255) for spectrum in spectra
)
reconstructions = (idct_dct_cm, idct_dct_cm_1, idct_dct_cm_2, idct_dct_cm_3)

# pixel-wise difference between the original image and each reconstruction
diff_idct_dct_cm, diff_idct_dct_cm_1, diff_idct_dct_cm_2, diff_idct_dct_cm_3 = (
    cm - restored for restored in reconstructions
)
differences = (diff_idct_dct_cm, diff_idct_dct_cm_1, diff_idct_dct_cm_2, diff_idct_dct_cm_3)

# plot: one column per spectrum; rows are magnitude / reconstruction / difference
fig, axs = plt.subplots(nrows=3, ncols=4, figsize=(16, 12), layout='compressed')

magnitude_panels = (
    ('abs_dct_cm', abs_dct_cm),
    ('abs_dct_cm_1', abs_dct_cm_1),
    ('abs_dct_cm_2', abs_dct_cm_2),
    ('abs_dct_cm_3', abs_dct_cm_3),
)

# log scale (+1 keeps zeroed coefficients finite) so small coefficients stay visible
for ax, (label, magnitude) in zip(axs[0], magnitude_panels):
    ax.imshow(np.log2(magnitude + 1), cmap='gray')
    ax.set(title=f'Magnitude [{label}]')

for ax, restored in zip(axs[1], reconstructions):
    ax.imshow(restored, cmap='gray')
    ax.set(title='Reconstruct')

for ax, residual in zip(axs[2], differences):
    ax.imshow(residual, cmap='gray')
    ax.set(title='Difference from original')

for ax in axs.ravel():
    ax.axis('off')

plt.show()

In [None]:
# 2D DCT of the whole image (orthonormal scaling)
dct_cm = sp.fftpack.dctn(cm, norm='ortho')

# half-plane masks: mask_1 is 1 on the first half of the rows,
# mask_2 is 1 on the first half of the columns, 0 elsewhere
half_rows = dct_cm.shape[0] // 2
half_cols = dct_cm.shape[1] // 2

mask_1 = np.zeros(shape=dct_cm.shape)
mask_1[:half_rows, :] = 1
mask_2 = np.zeros(shape=dct_cm.shape)
mask_2[:, :half_cols] = 1

# masked spectra
dct_cm_mask_1 = dct_cm * mask_1
dct_cm_mask_2 = dct_cm * mask_2

# unmasked spectrum first, then the two masked variants
spectra = (dct_cm, dct_cm_mask_1, dct_cm_mask_2)
labels = ('dct_cm', 'dct_cm_mask_1', 'dct_cm_mask_2')

# coefficient magnitudes
abs_dct_cm, abs_dct_cm_mask_1, abs_dct_cm_mask_2 = (
    np.abs(spectrum) for spectrum in spectra
)
magnitudes = (abs_dct_cm, abs_dct_cm_mask_1, abs_dct_cm_mask_2)

# inverse DCT of each spectrum, clipped to the 8-bit intensity range
idct_dct_cm, idct_dct_cm_mask_1, idct_dct_cm_mask_2 = (
    sp.fftpack.idctn(spectrum, norm='ortho').clip(0, 255) for spectrum in spectra
)
reconstructions = (idct_dct_cm, idct_dct_cm_mask_1, idct_dct_cm_mask_2)

# pixel-wise difference between the original image and each reconstruction
diff_idct_dct_cm, diff_idct_dct_cm_mask_1, diff_idct_dct_cm_mask_2 = (
    cm - restored for restored in reconstructions
)
differences = (diff_idct_dct_cm, diff_idct_dct_cm_mask_1, diff_idct_dct_cm_mask_2)

# plot: one row per spectrum; columns are magnitude / reconstruction / difference
fig, axs = plt.subplots(nrows=3, ncols=3, figsize=(10, 10), layout='compressed')

# log scale (+1 keeps zeroed coefficients finite) so small coefficients stay visible
for ax, label, magnitude in zip(axs[:, 0], labels, magnitudes):
    ax.imshow(np.log2(magnitude + 1), cmap='gray')
    ax.set(title=f'Magnitude [abs_{label}]')

for ax, label, restored in zip(axs[:, 1], labels, reconstructions):
    ax.imshow(restored, cmap='gray')
    ax.set(title=f'Reconstruct [idct_{label}]')

for ax, residual in zip(axs[:, 2], differences):
    ax.imshow(residual, cmap='gray')
    ax.set(title='Difference from original')

for ax in axs.ravel():
    ax.axis('off')

plt.show()