# In [None]:
# Remove sinusoidal noise from image

# Steps
# Step 1: Read the image
# Step 2: Compute the Fourier Transform of the image
# Step 3: Get the centered Fourier Transform spectrum and display
# Step 4: Spot the periodic noise and pattern in the FT image
# Step 5: Block the periodic noise
# Step 6: Undo the centering shift of the filtered spectrum
# Step 7: Compute the inverse transform to return to the spatial domain
# Step 8: Display

# This code was run on Google Colab

import cv2 as cv
import numpy as np
import matplotlib.pyplot as plt
import math
from scipy import ndimage as ndi
from skimage.feature import peak_local_max
from skimage import data, img_as_float

import matplotlib.image as pltImg

from google.colab.patches import cv2_imshow as imshow

from google.colab import drive
# Mount Google Drive at /content/drive so the image paths below resolve.
drive.mount('/content/drive')


# Path of the noisy input image that the script reads.
sinusoidalImagePath = '/content/drive/MyDrive/sinusoidalImage.png'
# NOTE(review): presumably the destination for the filtered result; it is not
# referenced anywhere in the code visible here — confirm whether a save step
# is missing.
imageFiltered = '/content/drive/MyDrive/filteredImage.png'

def readImage(imageUrl):
  """Load an image file as a single-channel grayscale array.

  Args:
    imageUrl: Path of the image file to read.

  Returns:
    The array produced by ``cv.imread`` in grayscale mode.

  Raises:
    OSError: If ``cv.imread`` could not load the file. OpenCV reports a
      missing or undecodable file by returning ``None`` instead of raising,
      which would otherwise surface later as an unrelated error.
  """
  # IMREAD_GRAYSCALE is the named form of the flag value 0.
  img = cv.imread(imageUrl, cv.IMREAD_GRAYSCALE)
  if img is None:
    raise OSError(
        f'cv.imread could not read image (missing file or unsupported '
        f'format): {imageUrl}')
  return img

def showImage(img):
  """Display ``img`` in the notebook output using Colab's ``cv2_imshow``."""
  imshow(img)

def computeFourierTransform(img):
  """Return the 2-D discrete Fourier transform of ``img``.

  Args:
    img: 2-D array-like holding the spatial-domain image.

  Returns:
    The complex-valued result of ``np.fft.fft2``; no centering shift is
    applied here.
  """
  frequencyDomain = np.fft.fft2(img)
  return frequencyDomain

def computerCenteredFFT(fftImg):
  """Return ``fftImg`` with the zero-frequency term moved to the centre.

  Args:
    fftImg: Array of Fourier coefficients, e.g. the output of
      ``computeFourierTransform``.

  Returns:
    The array rearranged by ``np.fft.fftshift`` over all of its axes.
  """
  centered = np.fft.fftshift(fftImg)
  return centered

def plotMagnitueSpectrum(magnitude_spectrum):
  """Show ``magnitude_spectrum`` as a large grayscale figure without ticks.

  Args:
    magnitude_spectrum: 2-D array to render with matplotlib.
  """
  # A 30x30 inch canvas; the single axes it creates receives the image.
  _, axes = plt.subplots(figsize=(30, 30))
  axes.imshow(magnitude_spectrum, cmap='gray')
  axes.set_title('Magnitude Spectrum')
  # Hide tick marks on both axes.
  axes.set_xticks([])
  axes.set_yticks([])
  plt.show()


# Step 1: load the noisy image as grayscale and display it.
sinusoidalImage = readImage(sinusoidalImagePath)
showImage(sinusoidalImage)

# Step 2: move to the frequency domain.
# NOTE(review): the FFT output is complex-valued, so it is unclear what the
# raw image display renders for it; the magnitude plot below is the readable
# view. TODO confirm these two raw displays are still wanted.
fftImage = computeFourierTransform(sinusoidalImage)
showImage(fftImage)

# Step 3: shift the zero-frequency term to the centre of the array.
centeredFftImage = computerCenteredFFT(fftImage)
showImage(centeredFftImage)


# Log-scaled magnitude used for display and for the peak search further down.
# log1p (= log(1 + x)) replaces log so that a coefficient with zero magnitude
# maps to 0 instead of -inf, which would otherwise propagate through the
# Gaussian blur applied to this spectrum.
spectrum = 20*np.log1p(np.abs(centeredFftImage))

# plotting
plotMagnitueSpectrum(spectrum)


rows, cols = sinusoidalImage.shape
# np.fft.fftshift puts the zero-frequency (DC) term at index n // 2 on each
# axis. ceil(n / 2) only matches that for even sizes; for odd sizes it lands
# one past the DC term, so floor division is used instead.
crow, ccol = rows // 2, cols // 2

## TEST

# Smooth the spectrum first, then look for local maxima in the blurred result.
spectrumBlurred = cv.GaussianBlur(spectrum, (19,19), 3)
imAsFloat = img_as_float(spectrumBlurred)
# Each entry of `coordinates` is indexed below as (row, column).
coordinates = peak_local_max(imAsFloat, min_distance=4)

# Overlay of the detected peaks: column index on the x axis, row on the y axis.
plt.plot(coordinates[:, 1], coordinates[:, 0], 'r.')
plt.show()

# Step 5: zero out every detected peak except those around the spectrum centre.
for pair in coordinates:
  print(pair)
  # The first entry indexes the first array axis (row), the second the column.
  row, col = int(pair[0]), int(pair[1])
  # Leave the 3x3 neighbourhood of (crow, ccol) untouched.
  if abs(row - crow) <= 1 and abs(col - ccol) <= 1:
    continue
  # Zero a plus-shaped notch: the peak and its four direct neighbours.
  # Slices with a clamped lower bound are used so that a peak on the first
  # row/column cannot wrap around to index -1, and one on the last row/column
  # cannot raise IndexError.
  centeredFftImage[row, max(col - 1, 0):col + 2] = 0
  centeredFftImage[max(row - 1, 0):row + 2, col] = 0

# Manual alternative kept for reference: notch fixed rectangles instead of
# the automatically detected peaks.
# centeredFftImage[crow-3:crow+3, ccol-11:ccol-8] = 0
# centeredFftImage[crow-3:crow+3, ccol+6:ccol+9] = 0


# Spectrum after filtering. The coefficients zeroed above have magnitude 0, so
# plain log would yield -inf (and a divide-by-zero warning); log1p maps them
# to 0 instead.
tunnedSpectrum = 20*np.log1p(np.abs(centeredFftImage))
plotMagnitueSpectrum(tunnedSpectrum)

showImage(centeredFftImage)

# Step 6: undo the centering shift before inverting the transform.
f_ishift = np.fft.ifftshift(centeredFftImage)

# Step 7: inverse fft to get the image back
img_back = np.fft.ifft2(f_ishift)

# The inverse transform is complex-valued; the magnitude is used as the image.
# img_back = np.real(img_back)
img_back = np.abs(img_back)

# Increase contrast: X' = a*X + b, where a (alpha) controls contrast and
# b (beta) controls brightness
# (alpha = contrast/127 + 1, beta = brightness - contrast).
adjusted = cv.convertScaleAbs(img_back, alpha=1.5, beta=-63)
# Step 8: display the result.
# NOTE(review): `imageFiltered` is defined earlier in the file but the result
# is not written to it in the code visible here — confirm whether a save
# step is intended.
showImage(adjusted)