## Image compression and the Fast Fourier Transform

This lesson is based on the image compression YouTube lessons from Steve Brunton:
- [Image compression and FFT](https://www.youtube.com/watch?v=gGEBUdM0PVc&ab_channel=SteveBrunton)
- [Image compression and FFT (Example in Python)](https://www.youtube.com/watch?v=uB3v6n8t2dQ&t=301s&ab_channel=SteveBrunton)


The Fourier transform, as we saw in past lessons, decomposes a function in terms of sines and cosines, and we can compute the coefficients efficiently using the "fast Fourier transform" algorithm. Luckily, this algorithm is already a built-in function in `numpy`, so we can go ahead and use it.

One of the most common applications of the Fast Fourier Transform (FFT) is compression: compression of audio, video, images, etc. In this notebook we will see an example of image compression, for which we will use a 2D fast Fourier transform, `fft2`. The idea behind a two-dimensional FFT is that we apply the FFT to all the rows and then to all the columns of our 2D image. Fortunately, `numpy` already has a 2D FFT implementation, `numpy.fft.fft2`, that we will apply to our pixels to get the Fourier coefficients of our image. Most of these coefficients will be very small, so we can discard them and keep only the largest ones. We will see that by keeping just a small percentage of the coefficients we can recover a good approximation of the original image by applying the inverse FFT. This idea is behind many types of compression. Real applications often use wavelet transforms, since they give better results, but the process is the same. Let's take a look at an actual example.
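
As a quick check of the "rows, then columns" idea, the minimal sketch below applies the 1D `numpy.fft.fft` along each axis of a small random array and compares the result with `numpy.fft.fft2`. The array `image` and its size are made up for illustration; they are not the image used in this lesson.

```python
import numpy

# A small made-up "image" (illustrative only, not the lesson's picture)
rng = numpy.random.default_rng(0)
image = rng.random((4, 6))

# 1D FFT of every row (axis=1), then 1D FFT of every column (axis=0)
rows_then_cols = numpy.fft.fft(numpy.fft.fft(image, axis=1), axis=0)

# Direct 2D FFT
direct = numpy.fft.fft2(image)

# Both routes should agree up to round-off
same = numpy.allclose(rows_then_cols, direct)
print(same)
```

The order of the two 1D transforms does not matter; transforming the columns first gives the same coefficients.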

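When we save a JPG image on our phones or computers, we store only the small fraction of transform coefficients needed to recover the image (JPEG uses the discrete cosine transform, a close relative of the FFT), and the image can be displayed very quickly thanks to the fast inverse transform.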


In [None]:
from matplotlib import pyplot
from matplotlib.image import imread
import numpy

In [None]:
pyplot.rcParams['figure.figsize'] = [9.8, 6.53]

We can "read" an image as a 3D `numpy` array using `imread` (one matrix per RGB channel) and display it using `pyplot.imshow`.

In [None]:
A = imread('../images/blm.jpg')

In [None]:
pyplot.imshow(A)
pyplot.axis('off');

If we check the shape of our array, we see three dimensions. We have three 653x980 matrices, each corresponding to one RGB (red, green, blue) channel.

In [None]:
numpy.shape(A)

For simplicity, we will convert the RGB image to grayscale and work with only one matrix. We do this by averaging along the third axis (index 2 in Python). For each pixel we take the average of its values in the red, green and blue channels.
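
To see what averaging along the third axis does, here is a minimal sketch on a hand-made 2x2 RGB array. The pixel values in `rgb` are invented for illustration only.

```python
import numpy

# A made-up 2x2 image with 3 channels per pixel (illustrative values)
rgb = numpy.array([[[30, 60, 90], [0, 0, 255]],
                   [[255, 255, 255], [10, 20, 30]]], dtype=numpy.uint8)

# Average the three channel values of each pixel
gray = numpy.mean(rgb, axis=2)

print(rgb.shape, gray.shape)  # the channel axis disappears
print(gray)
```

Each entry of `gray` is the mean of the three channel values of that pixel, for example (30 + 60 + 90) / 3 = 60 for the top-left pixel.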

In [None]:
Ags = numpy.mean(A, axis=2)

In [None]:
numpy.shape(Ags)

In [None]:
pyplot.imshow(Ags, cmap='gray')
pyplot.axis('off');

We compute the 2D Fourier transform, and we can plot the magnitude of each coefficient by taking the absolute value of the complex array. We plot it on a log scale so that the differences are more visible.

In [None]:
from matplotlib.colors import LogNorm

In [None]:
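# 2D FFT of the grayscale image (complex array with the same shape as Ags)
FF = numpy.fft.fft2(Ags)
fig = pyplot.figure(figsize=(6,6))
# Plot the magnitude of each coefficient, with a logarithmic color scale
pyplot.imshow(abs(FF), cmap ='gray', norm=LogNorm(vmin=numpy.min(abs(FF)), vmax=numpy.max(abs(FF))));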

#pyplot.imshow(numpy.log(abs(FF)), cmap ='gray');


In [None]:
numpy.min(abs(FF))

Let's display the zeroth coefficient in the center. To do that we use the `fftshift` function, and we also add a colorbar to the plot.
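
The effect of `fftshift` is easier to see on a tiny array. The sketch below uses a made-up 4x4 `image`: it checks that the zeroth coefficient is the sum of all pixels, and that after the shift it sits at the center index instead of the corner.

```python
import numpy

# A made-up 4x4 "image" (illustrative only)
image = numpy.arange(16, dtype=float).reshape(4, 4)
F = numpy.fft.fft2(image)

# The zeroth (0,0) coefficient is the sum of all pixel values
dc_is_sum = numpy.isclose(F[0, 0], image.sum())

# After fftshift, the (0,0) coefficient moves to the center of the array
F_s = numpy.fft.fftshift(F)
dc_in_center = numpy.isclose(F_s[2, 2], F[0, 0])

# The same reordering applied to the frequency indices of one axis:
# unshifted order is 0, 1, -2, -1; shifted order is -2, -1, 0, 1
freqs_shifted = numpy.fft.fftshift(numpy.fft.fftfreq(4) * 4)

print(dc_is_sum, dc_in_center)
print(freqs_shifted)
```

Because the zeroth coefficient adds up every pixel, it is usually much larger than the rest, which is one reason the log scale helps in these plots.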

In [None]:
# It is more convenient to display the (0,0) frequency in the center
FF_s = numpy.fft.fftshift(FF)
fig = pyplot.figure(figsize=(6,6))
ax = pyplot.axes()
im = pyplot.imshow(abs(FF_s), cmap ='gray', norm=LogNorm(vmin=numpy.min(abs(FF_s)), vmax=numpy.max(abs(FF_s))))

#this is to get color bar matching side of plot. 
cax = fig.add_axes([ax.get_position().x1+0.01,ax.get_position().y0,0.02,ax.get_position().height])

pyplot.colorbar(im, cax=cax);
#pyplot.imshow(numpy.log(abs(FF_s)), cmap ='gray');

Now let's see what happens if we keep only a small percentage of the coefficients. To do this, we reshape the 2D array into one long 1D array, sort it by magnitude, and then zero out everything but the top percentage.


In [None]:
# numpy.reshape(FF, -1) flattens the 2D array; the -1 tells numpy to infer
# the length of the 1D array from the total number of elements.

# Sort the magnitudes, then reverse the order so the largest values come first
Bt = numpy.sort(numpy.abs(numpy.reshape(FF, -1)))[::-1]


keep = 0.03  # fraction of coefficients to keep (3%)
# int(keep * len(Bt)) is the index where the top 3% ends;
# the magnitude stored at that index is our threshold
thresh = Bt[int(keep * len(Bt))]

# Boolean mask: True where the coefficient magnitude is above the threshold
ind = numpy.abs(FF) > thresh

# Multiplying the coefficients by the mask keeps the top values
# and zeroes every coefficient at or below the threshold
Atlow = FF * ind

# Now we plot the magnitude, shifted to the center, on a log scale.
# We add 1e-12 so that we never take the log of 0.

Flow = numpy.log(numpy.abs(numpy.fft.fftshift(Atlow)) + 1e-12)

fig = pyplot.figure(figsize=(6,6))
pyplot.imshow(Flow,cmap='gray')
#pyplot.axis('off')
#pyplot.show()

In [None]:
numpy.shape(Atlow)

In [None]:
thresh
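
The sort-and-threshold step is easier to follow on a handful of numbers. Below is a minimal sketch with a made-up 2x4 array `coeffs` standing in for the Fourier coefficients, and `keep = 0.25` so that 2 of the 8 values survive. The names `sorted_desc`, `mask` and `compressed` are illustrative, not part of the lesson's code.

```python
import numpy

# Made-up "coefficients" (real numbers here, to keep the example readable)
coeffs = numpy.array([[9.0, 1.0, 4.0, 2.0],
                      [3.0, 8.0, 0.5, 6.0]])
keep = 0.25  # keep the top 25%, that is 2 of the 8 values

# Flatten, take magnitudes, sort, and reverse: largest first
sorted_desc = numpy.sort(numpy.abs(coeffs).reshape(-1))[::-1]

# The value at index int(keep * n) is the first one we do NOT keep
thresh = sorted_desc[int(keep * len(sorted_desc))]

# Strictly-greater comparison keeps only the values above the threshold
mask = numpy.abs(coeffs) > thresh
compressed = coeffs * mask

print(thresh)
print(compressed)
```

Here the sorted magnitudes are 9, 8, 6, 4, 3, 2, 1, 0.5, so the threshold is 6 and only 9 and 8 remain; everything else becomes zero.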

Now, if we apply the inverse fast Fourier transform, we should recover a version of the image that is coarser than the original. The question is: do we have enough information to recover a similar image?

Keep in mind that `ifft2` returns a complex array, but in this case the imaginary part should be nearly zero. To plot the image, we take the real part of the array returned by `ifft2`.

In [None]:
Alow = numpy.fft.ifft2(Atlow)
pyplot.imshow(Alow.real, cmap='gray')
pyplot.axis('off');

We can check that the imaginary values are close to zero by checking the imaginary part of the `Alow` array. 

In [None]:
Alow.imag

In [None]:
# Largest imaginary value in magnitude (negative values count too)
numpy.max(numpy.abs(Alow.imag))
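
Why should the imaginary part be so small? For a real-valued image, the coefficient at frequency (-k1, -k2) is the complex conjugate of the one at (k1, k2), so both have the same magnitude and a magnitude threshold will normally keep or drop them together. The minimal sketch below checks this symmetry on a made-up random array; `image`, `F` and `F_neg` are illustrative names.

```python
import numpy

# A made-up real-valued "image" (illustrative only)
rng = numpy.random.default_rng(1)
image = rng.random((6, 8))
F = numpy.fft.fft2(image)

# Build the array whose [k1, k2] entry is F[-k1, -k2]:
# reverse both axes, then roll by one so index 0 stays at index 0
F_neg = numpy.roll(F[::-1, ::-1], shift=(1, 1), axis=(0, 1))

# Conjugate symmetry for real input: F[-k1, -k2] == conj(F[k1, k2])
symmetric = numpy.allclose(F_neg, numpy.conj(F))

# With the full, symmetric spectrum the inverse is real up to round-off
max_imag = numpy.max(numpy.abs(numpy.fft.ifft2(F).imag))

print(symmetric, max_imag)
```

Round-off can occasionally split a conjugate pair right at the threshold, which is one more reason to look at the size of the imaginary part rather than assume it is exactly zero.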

### Different `keep` values

Repeat the compression for several values of `keep`, for example with a `for` loop, and compare the reconstructed images.
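
One possible way to try several `keep` values is sketched below. It wraps the steps of this lesson in a hypothetical helper, `compress_fft` (not part of `numpy`), and runs it on a synthetic test image so that the block is self-contained; with the lesson's image you would pass `Ags` instead. The relative error is just one simple way to quantify how close each reconstruction is.

```python
import numpy

def compress_fft(image, keep):
    """Hypothetical helper: keep roughly the largest `keep` fraction
    (0 <= keep < 1) of the 2D FFT coefficients and invert the transform."""
    coeffs = numpy.fft.fft2(image)
    magnitudes = numpy.sort(numpy.abs(coeffs).reshape(-1))[::-1]
    thresh = magnitudes[int(keep * len(magnitudes))]
    mask = numpy.abs(coeffs) > thresh
    # Return the real part of the reconstruction and the fraction actually kept
    return numpy.fft.ifft2(coeffs * mask).real, mask.mean()

# Synthetic test image (an assumption for this sketch): smooth waves plus noise
rng = numpy.random.default_rng(0)
y, x = numpy.mgrid[0:64, 0:96]
image = numpy.sin(x / 7.0) + numpy.cos(y / 5.0) + 0.1 * rng.standard_normal((64, 96))

errors = {}
for keep in [0.01, 0.05, 0.2]:
    approx, kept = compress_fft(image, keep)
    # Relative error between the original and the reconstruction
    errors[keep] = numpy.linalg.norm(image - approx) / numpy.linalg.norm(image)
    print(f"keep={keep:.2f}  kept fraction={kept:.3f}  relative error={errors[keep]:.3f}")
```

To compare the results visually, each `approx` can be shown with `pyplot.imshow(approx, cmap='gray')` inside the loop.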


We can also view the image as a surface plot of pixel intensities. The brightest pixels are the highest points and the lightest in color.

In [None]:
from mpl_toolkits.mplot3d import Axes3D


In [None]:
# Rotate the plot down and play around

In [None]:
%matplotlib notebook
pyplot.rcParams['figure.figsize'] = [9.8, 6.53]

fig = pyplot.figure()
ax = fig.add_subplot(111, projection='3d')

X,Y = numpy.meshgrid(numpy.arange(0, numpy.shape(Ags)[1]), numpy.arange(0, numpy.shape(Ags)[0]))
ax.plot_surface(X[0::5, 0::5], Y[0::5, 0::5], Ags[0::5, 0::5], cmap='viridis') # plot every 5th pixel to avoid overloading the plot
#ax.view_init(-88,-90)
#ax.view_init(0,0)
