In [None]:
import numpy as np
import scipy
import matplotlib.pyplot as plt

from skimage import io
from scipy import signal
from scipy import fftpack

def JPEG(img, factor):
    """Simulate JPEG-style lossy compression of a 2-D grayscale image.

    The image is processed in 8x8 tiles. Each tile goes through four steps:

    1. Transform it with an orthonormal separable 2-D DCT-II.
    2. Divide the coefficients element-wise by the JPEG luminance
       quantisation table multiplied by ``factor``.
    3. Round the result to integers.
    4. Reconstruct the tile by multiplying back by the same step sizes and
       applying the inverse DCT.

    When the image height/width is not a multiple of 8, the right/bottom
    edge tiles are smaller than 8x8. They are quantised with the matching
    top-left corner of the table instead of failing to broadcast.

    Parameters
    ----------
    img : numpy.ndarray, 2-D
        Input image. It is not modified.
    factor : float
        Positive multiplier applied to the quantisation table. Larger values
        give coarser quantisation (more loss).

    Returns
    -------
    img : numpy.ndarray
        The input ``img`` object, returned unchanged.
    freq : numpy.ndarray of float64
        Unquantised DCT coefficients, tile by tile.
    quant : numpy.ndarray of float64
        Rounded quantised coefficients (integer-valued).
    rev : numpy.ndarray of float64
        Image reconstructed from the quantised coefficients.

    Raises
    ------
    ValueError
        If ``factor`` is not strictly positive (the quantisation step would
        be zero or negative).
    """
    if not factor > 0:
        raise ValueError('factor must be > 0, got {!r}'.format(factor))

    # JPEG luminance quantisation table.
    q_matrix = np.array([[16, 11, 10, 16, 24, 40, 51, 61],
                         [12, 12, 14, 19, 26, 58, 60, 55],
                         [14, 13, 16, 24, 40, 57, 69, 56],
                         [14, 17, 22, 29, 51, 87, 80, 62],
                         [18, 22, 37, 56, 68, 109, 103, 77],
                         [24, 35, 55, 64, 81, 104, 113, 92],
                         [49, 64, 78, 87, 103, 121, 120, 101],
                         [72, 92, 95, 98, 112, 100, 103, 99]],
                        dtype=np.float64)

    # Quantisation step sizes do not change between tiles; compute them once.
    q_step = q_matrix * factor

    # Work in float64. fftpack would convert integer input to float64 anyway.
    img_f = np.asarray(img, dtype=np.float64)
    freq = np.empty_like(img_f)
    quant = np.empty_like(img_f)
    rev = np.empty_like(img_f)

    rows, cols = img_f.shape[:2]
    for i in range(0, rows, 8):
        for j in range(0, cols, 8):
            tile = (slice(i, i + 8), slice(j, j + 8))
            i_patch = img_f[tile]

            # Edge tiles may be smaller than 8x8; crop the step table to match.
            h, w = i_patch.shape[:2]
            step = q_step[:h, :w]

            # Forward DCT of spatial signal (separable: columns, then rows).
            f_patch = fftpack.dct(i_patch, axis=0, norm='ortho')
            f_patch = fftpack.dct(f_patch, axis=1, norm='ortho')
            freq[tile] = f_patch

            q_patch = np.round(f_patch / step).astype(np.int64)
            quant[tile] = q_patch

            # Inverse DCT of the de-quantised frequency signal.
            r_patch = fftpack.idct(q_patch * step, axis=0, norm='ortho')
            r_patch = fftpack.idct(r_patch, axis=1, norm='ortho')
            rev[tile] = r_patch

    return img, freq, quant, rev

# Read the test image and define the quantisation scale factors to sweep.
img = io.imread('/home/hugo/PDI_Jupyter/images/lichtenstein.png')
factors = [0.05, 0.5, 1.0, 2.0, 5.0, 10.0, 50.0]

# Multi-channel images: keep only channel index 1 (the green channel for
# RGB/RGBA images) so JPEG() receives a 2-D array, as its 8x8 quantisation
# table requires.
if img.ndim > 2:
    img = img[:, :, 1]

for factor in factors:

    print(f'factor: {factor}')

    # JPEG round trip. The first return value is ``img`` itself, so it is
    # discarded instead of rebinding ``img`` on every iteration.
    _, freq, quant, rev = JPEG(img, factor)

    # Plotting: original vs. reconstructed image.
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    ax1.imshow(img, cmap=plt.cm.gray)
    ax1.set_title('Original Image')
    ax1.set_xticks([])
    ax1.set_yticks([])
    ax2.imshow(rev, cmap=plt.cm.gray)
    ax2.set_title(f'Reconstructed with factor {factor}')
    ax2.set_xticks([])
    ax2.set_yticks([])

    plt.show()

    # Plotting: histograms of raw vs. quantised DCT coefficients near zero.
    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8), sharex=True, sharey=True)

    ax1.hist(freq.ravel(), bins=1024, range=(-5, 5), fc='k', ec='k')
    ax1.set_title('Original DCT Histogram')
    ax2.hist(quant.ravel(), bins=1024, range=(-5, 5), fc='k', ec='k')
    ax2.set_title(f'Quantized Histogram with factor {factor}')

    plt.show()