# Image Compression!

Welcome to the image-compression notebook

Let's get started by importing some libraries (these are all standard libraries that should be installed already...) and defining some helper functions:

- `getRGB()` will get the pixel data of an image as a list-of-lists-of-lists
- `set_size()` is a helper function for the next function
- `show_image()` uses the output from `getRGB()` to show an image
- `saveRGB()` also uses the output of `getRGB()` to save a new image
- `greyscaleHelp()` is a helper function for the next function
- `greyscale()` converts an image to greyscale and returns a list-of-lists of pixel data

In [None]:
# libraries!
import time

import matplotlib.pylab as pylab
import numpy as np
import scipy
import scipy.fft
from matplotlib import pyplot as plt
from matplotlib.pyplot import imread
from numpy import pi
from numpy import r_
from numpy import sin
from numpy import zeros
from PIL import Image
from scipy import signal

In [None]:
def getRGB(filename):
    """ reads a png or jpg file like 'pitzer_grounds.jpg' (a string)
        returns the pixels as a list of rows, where each row is a list of
        (r, g, b) tuples -- so result[row][col] is one pixel
        this is accessible, but not fast: Use small images!

        images that are not stored as RGB (greyscale, palette, RGBA, ...)
        are converted to RGB first, so every pixel always has three channels
    """
    # the with-block closes the image file as soon as we are done with it
    with Image.open(filename) as original:
        print(f"Reading image from '{filename}':")
        print(f"  Format: {original.format}\n  Original Size: {original.size}\n  Mode: {original.mode}")
        WIDTH, HEIGHT = original.size
        # without this conversion a greyscale or palette pixel is a single
        #   number, and slicing it into channels would crash
        rgb_image = original.convert("RGB")
        px = rgb_image.load()
        # PIL indexes pixels as [x, y], i.e. [column, row]
        PIXEL_LIST = [ [ px[c,r] for c in range(WIDTH) ] for r in range(HEIGHT) ]
    return PIXEL_LIST

In [None]:
def set_size(width, height, ax=None):
    """Resizes a figure so that its plotting area is width x height inches

       width, height: desired size of the plotting area, in inches
       ax: the axes whose figure gets resized
           (when left out, the current pyplot axes are used)
    """
    target_axes = ax or plt.gca()   # fall back to the current axes
    figure = target_axes.figure
    margins = figure.subplotpars

    # the subplot parameters are positions given as fractions of the figure,
    #   so dividing by the covered fraction gives the full figure size
    fig_width = float(width) / (margins.right - margins.left)
    fig_height = float(height) / (margins.top - margins.bottom)
    figure.set_size_inches(fig_width, fig_height)

In [None]:
# wrapper for matplotlib's imshow function
def show_image( rgbdata, hgt_in=5.42, wid_in=5.42 ):
    """ shows an image whose pixels are in rgbdata 
        note:  rgbdata is a list-of-rows-of-pixels, _not_ a filename!
            each pixel is either a group of rgb values (a color image)
            or a single number (a greyscale image)
            use getRGB() or greyscale() to get this data!
        hgt_in is the desired height (in inches)
        wid_in is the desired width (in inches)
            use set_size() with these parameters
        _but_ the library will not change the aspect ratio (takes the smaller)
        by default, the hgt_in and wid_in are 5.42 in.
        (and feel free to change these!)
    """
    fig, ax = plt.subplots()               # obtains the figure and axes objects

    # a greyscale pixel is one plain number (zero dimensions);
    #   a color pixel is a list/tuple/array of channel values
    first_pixel = rgbdata[0][0]
    if np.ndim(first_pixel) == 0:
        ax.imshow(rgbdata, cmap="gray")    # single numbers need a colormap to become shades of grey
    else:
        ax.imshow(rgbdata)                 # this is matplotlib's call to show an image

    set_size(width=wid_in, height=hgt_in, ax=ax)  # matplotlib will maintain the image's aspect ratio
    ax.axis('off')                         # turns off the axes (in units of pixels)
    plt.show()                             # show the image

In [None]:
def saveRGB( PX, filename ):
    """ saves a list-of-lists-of-lists of rgb pixels (PX) as filename where
        len(PX) == the # of rows
        len(PX[0]) == the # of columns
        len(PX[0][0]) should be 3 (rgb)

        the image type is chosen by PIL from the filename's extension
        prints a message before and after saving; returns nothing
    """
    print( 'Starting to save', filename, '...' )
    H = len(PX)
    W = len(PX[0])
    im = Image.new("RGB", (W, H), "black")
    px = im.load()
    for r, row in enumerate(PX):
        for c in range(W):
            # PIL indexes pixels as [x, y], i.e. [column, row],
            #   and wants each pixel as a tuple
            px[c,r] = tuple(row[c])
    im.save( filename )
    time.sleep(0.42)   # give the filesystem some time...
    print( filename, "saved." )    

In [None]:
def greyscaleHelp( rgbpixel ):
    """ Helper Function
        rgbpixel should be in the form [r,g,b]
        returns a single number: the pixel's brightness (luminance),
        which is a weighted average of its red, green, and blue values
    """
    [r,g,b] = rgbpixel
    # a generic formula to convert rgb to greyscale: 21% red, 72% green, 7% blue
    #   add the weighted values first and divide once -- dividing each part
    #   separately rounds down three times (pure white would come out as 253)
    lum = (21 * r + 72 * g + 7 * b) // 100
    return lum   # returns a single number

In [None]:
def greyscale( image ):
    """Makes an image grayscale

       image is a list of rows of rgb pixels (like the output of getRGB())
       returns a new list of rows with each pixel replaced by the single
       number that greyscaleHelp() computes for it
    """
    new_image = []
    for row in image:
        grey_row = []
        for pix in row:
            grey_row.append( greyscaleHelp(pix) )   # one brightness value per pixel
        new_image.append( grey_row )
    return new_image

## Singular Value Decomposition

Grayscale image data is essentially a large matrix (it's a list-of-lists!). As you may learn in linear algebra, SVD is a special way to factor matrices using so-called "singular values." It turns out that omitting data associated with the less-important singular values has very little impact on the quality of an image. 

Below you can choose an image (file) to compress. You can also choose how many singular values to keep. 

For example, the matrix with pixel data for "dodds.jpg" can be factored into three matrices (called U, $\Sigma$, and V). The $\Sigma$ matrix holds the singular values (in this case there are 682 of them). We can compress the image by slicing the matrices to use only the first n singular values, with very little quality loss. Try it out with six different values of n (held in the list `comps`).

In [None]:
# choose an image to SVD-ify
file = "dodds.jpg"

# plot images with different number of components
#   (feel free to edit this list -- it can hold any number of values)
comps = [600, 400, 200, 100, 20, 5]

# get rgb data
rgb = getRGB(file)
gray = np.asarray(greyscale(rgb))

# obtain svd
#   S holds the singular values, largest first
#   numpy hands back the third factor already transposed, so its *rows*
#   (not its columns) pair up with the singular values
U, S, V = np.linalg.svd(gray)   # isn't numpy amazing?

# inspect shapes of the matrices
print(f"\nThe U matrix has shape {U.shape}")
print(f"The Sigma matrix has shape {S.shape}")
print(f"The V matrix has shape {V.shape}")
print("Make sure the largest value in the component list is < the smallest value you see above!")

# lay the pictures out three per row, using as many rows as the list needs
n_cols = 3
n_rows = max(1, -(-len(comps) // n_cols))   # ceiling division
plt.figure(figsize = (12, 4 * n_rows))
for i, requested in enumerate(comps):
    # there are only len(S) singular values, so asking for more cannot add anything
    k = min(requested, len(S))
    # rebuild the image from just the first k singular values
    low_rank = U[:, :k] @ np.diag(S[:k]) @ V[:k, :]   # @ is used in numpy for matrix multiplication
    plt.subplot(n_rows, n_cols, i+1)
    plt.imshow(low_rank, cmap = 'gray')
    plt.axis('off')
    plt.title("components = " + str(k))
plt.show()

# DCT

The Discrete Cosine Transform (DCT) is a bit complicated. You can see the details on its Wikipedia page!

https://en.wikipedia.org/wiki/Discrete_cosine_transform

In short, DCT compresses an image by transforming a matrix into useful parts and not-so-useful parts (somewhat similar to SVD). Here, we will apply DCT to 8x8 blocks of pixels throughout the whole image.

First let's read in an image

In [None]:
# Load the image, convert it to greyscale, and display it
filename = "dodds.jpg"   # image to work with

rgb = getRGB(filename)             # rows of rgb pixels
im = np.asarray(greyscale(rgb))    # one brightness number per pixel, as a numpy array
f = plt.figure(figsize=(8,8))      # figure size is given in inches
plt.imshow(im, cmap="gray")        # single numbers need a colormap to become shades of grey
plt.title("Our favorite Prof")
plt.axis('off')                    # hide the pixel-coordinate axes
plt.show()

### Helper Functions

Although these functions are very short, they are the ones that do all of the work! We use the scipy library to apply the transformation to the data. 

Run the cells below to apply DCT to an image and view what it does to one 8x8 block of pixels

In [None]:
def dct2(a):
    """ applies the discrete cosine transform along the rows and the columns
        of a (a 2d array, such as an 8x8 block of pixels)
        returns an integer array of the same shape holding the DCT
        coefficients, each rounded to the nearest whole number
    """
    coefficients = scipy.fft.dctn( a, axes=(0, 1), norm='ortho' )
    # round before converting: astype(int) on its own just chops off the
    #   decimals, so a value like 799.9999 (800 with a tiny float error) would become 799
    return np.rint( coefficients ).astype(int)

def idct2(a):
    """ applies the inverse discrete cosine transform along the rows and the
        columns of a (a 2d array of DCT coefficients), undoing dct2()
        returns an integer array of the same shape holding the pixel values,
        each rounded to the nearest whole number
    """
    pixels = scipy.fft.idctn( a, axes=(0, 1), norm='ortho' )
    # round before converting: astype(int) on its own just chops off the
    #   decimals, so a value like 99.9999 (100 with a tiny float error) would become 99
    return np.rint( pixels ).astype(int)

In [None]:
imsize = im.shape
dct = np.zeros(imsize)   # will hold the DCT coefficients, same size as the image

# Walk over the image in 8x8 tiles and store each tile's DCT in the
#   matching spot of the new array (tiles at the edges may be smaller)
for top in range(0, imsize[0], 8):
    for left in range(0, imsize[1], 8):
        tile = (slice(top, top + 8), slice(left, left + 8))
        dct[tile] = dct2( im[tile] )

In [None]:
# Show one 8x8 block of the image next to the DCT of that same block
pos = 128   # row and column where the block starts (its top-left corner)

# Extract a block from image
fig, ax = plt.subplots(1,2, figsize=(10,10))   # two pictures side by side

ax[0].imshow(im[pos:pos+8,pos:pos+8],cmap='gray')
ax[0].set_title( "An 8x8 Image block")
ax[0].axis('off')

# Display the dct of that block
#   the color scale runs from 0 up to 1% of the largest coefficient in the whole image
#   extent sets the picture's coordinates to run from 0 to pi (the axes are hidden below)
ax[1].imshow(dct[pos:pos+8,pos:pos+8], cmap='gray', vmax = np.max(dct)*0.01,vmin = 0, extent=[0,pi,pi,0])
ax[1].set_title( "An 8x8 DCT block")
ax[1].axis('off')

plt.show()

In [None]:
# Show the DCT coefficients of every 8x8 block, arranged like the original image
plt.figure(figsize=(10,10))
# the color scale runs from 0 up to 1% of the largest coefficient
plt.imshow(dct,cmap='gray',vmax = np.max(dct)*0.01,vmin = 0)
plt.title( "8x8 DCTs of the image")
plt.axis('off')   # hide the pixel-coordinate axes
plt.show()

### Compression Time

The cell below will keep the DCT coefficients whose size is above a certain threshold and set everything else to zero. As you might see, a significant amount of the data is deleted...

In [None]:
# Threshold: a coefficient is kept only if its size (ignoring sign) is more
#   than this fraction of the largest coefficient
thresh = 0.012

# multiplying by the True/False comparison keeps (x1) or clears (x0) each coefficient
dct_thresh = dct * (np.abs(dct) > (np.max(dct) * thresh))

# what share of the coefficients survived? (a fraction between 0 and 1)
percent_nonzeros = np.sum( dct_thresh != 0.0 ) / float(imsize[0] * imsize[1])
print(f"Keeping only {percent_nonzeros*100.0}% of the DCT coefficients")

plt.figure(figsize=(10,10))
plt.imshow(dct_thresh, cmap='gray', vmin = 0, vmax = np.max(dct)*0.01)
plt.title( "Thresholded 8x8 DCTs of the image")
plt.show()

### Inverse

We can now perform the inverse of DCT to get a normal-looking image back!

How does it look compared to the original?

In [None]:
im_dct = np.zeros(imsize)   # will hold the image rebuilt from the thresholded coefficients

# Undo the DCT one 8x8 block at a time (the same blocks used for the forward transform)
for i in range(0, imsize[0], 8):
    for j in range(0, imsize[1], 8):
        im_dct[i:(i+8),j:(j+8)] = idct2( dct_thresh[i:(i+8),j:(j+8)] )

fig, ax = plt.subplots(1,2, figsize=(10,10))

# draw both pictures with the same brightness scale (the original's darkest
#   and brightest values) so that they can be compared fairly
shade_min, shade_max = im.min(), im.max()

# original
ax[0].imshow(im, cmap='gray', vmin=shade_min, vmax=shade_max)
ax[0].set_title( "Original")
ax[0].axis('off')

# compressed
ax[1].imshow(im_dct.astype(int), cmap='gray', vmin=shade_min, vmax=shade_max)
ax[1].set_title( "DCT compressed")
ax[1].axis('off')

print("Here are the first 20 pixels in the original image")
print(im[0][:20])
print("\nHere are the first 20 pixels in the compressed image")
print(im_dct.astype(int)[0][:20])
print("\nNotice how values are repeated more frequently in DCT!")

plt.show()