## Imports

In [39]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from skimage import io as skio
from skimage import color
from skimage import transform

# For PSNR and SSIM
import cv2
from skimage.metrics import structural_similarity as ssim

## Implementation

In [40]:
def brightest(dark_channel, percentage=0.1, show=False):
  """Return the flat indices of the brightest pixels of a dark channel.

  Args:
    dark_channel: 2-D dark-channel image (rows, cols).
    percentage: percentage (0-100) of pixels to select. The count is truncated
      to an integer and then clamped to [1, number of pixels].
    show: if True, plot the dark channel with the selected pixels in red.

  Returns:
    1-D integer array of indices into the flattened dark channel. The indices
    come from np.argpartition, so they are not sorted. An empty image yields
    an empty array.
  """
  im_rows, im_cols = np.shape(dark_channel)[:2]
  n_pixels = im_rows * im_cols
  flat_dc = np.reshape(dark_channel, -1)
  if n_pixels == 0:
    return np.empty(0, dtype=np.intp)
  # int() truncates, so small images/percentages can give 0, and
  # argpartition(...)[-0:] would then return *every* index. Keep at least one
  # pixel. Also cap at n_pixels so that percentage > 100 does not make
  # argpartition raise.
  n_brightest = min(max(int(n_pixels * (percentage / 100)), 1), n_pixels)
  idx = np.argpartition(flat_dc, -n_brightest)[-n_brightest:]
  if show:
    # Grayscale background from the dark-channel values (cast to uint8 as
    # before), with the selected pixels painted red. Vectorised instead of
    # testing `i in idx` for every pixel.
    highlighted_dark_channel = np.zeros((n_pixels, 3), dtype=np.uint8)
    highlighted_dark_channel[:] = flat_dc[:, np.newaxis]
    highlighted_dark_channel[idx] = [255, 0, 0]  # Red color
    highlighted_dark_channel = highlighted_dark_channel.reshape((im_rows, im_cols, 3))
    plt.figure()
    plt.title('Brightest pixels in dark channel')
    plt.imshow(highlighted_dark_channel)
    plt.show()
  return idx

def calc_dark_channel_patch(patch):
  """Return the dark-channel value of a single local patch.

  This is the smallest value found in any of the first three channels of
  `patch` (indexed as patch[row, col, channel]). Any further channels, such as
  alpha, are ignored.
  """
  per_channel_minima = [np.min(patch[:, :, channel]) for channel in range(3)]
  return np.min(per_channel_minima)

def calc_dark_channel_im(im, patch_size_x=15, patch_size_y=15, show=False):
  """Compute the dark channel of an image.

  The dark channel at (i, j) is the minimum pixel value over:
    - the first three channels, and
    - rows i - patch_size_y//2 .. i + patch_size_y//2 and columns
      j - patch_size_x//2 .. j + patch_size_x//2, clipped to the image.
  An even patch size therefore gives a window one pixel larger than requested.

  Args:
    im: array indexed as im[row, col, channel] with at least 3 channels.
    patch_size_x, patch_size_y: window width and height.
    show: if True, display the dark channel with matplotlib.

  Returns:
    (rows, cols) array with the same dtype as `im`.
  """
  half_patch_x = patch_size_x // 2
  half_patch_y = patch_size_y // 2
  # Per-pixel minimum over the colour channels.
  channel_min = np.min(im[:, :, :3], axis=2)
  im_rows, im_cols = channel_min.shape
  if channel_min.size == 0:
    dark_ch = channel_min.copy()
  else:
    # A min over a rectangular window is separable: take the min over the
    # rows first, then over the columns. Edge padding only replicates border
    # pixels that already lie inside every clipped window, so the result
    # equals the min over the clipped window. This replaces a per-pixel
    # Python loop.
    padded = np.pad(channel_min,
                    ((half_patch_y, half_patch_y), (half_patch_x, half_patch_x)),
                    mode='edge')
    rows_min = padded[:im_rows]
    for dy in range(1, 2 * half_patch_y + 1):
      rows_min = np.minimum(rows_min, padded[dy:dy + im_rows])
    dark_ch = rows_min[:, :im_cols]
    for dx in range(1, 2 * half_patch_x + 1):
      dark_ch = np.minimum(dark_ch, rows_min[:, dx:dx + im_cols])
  # Optional:
  if show:
    plt.figure()
    plt.title('Dark channel')
    plt.imshow(dark_ch, cmap='gray')
    plt.show()
  return dark_ch

def calc_atm_light(im, patch_size_x=15, patch_size_y=15, perc=0.1, show=False):
  """Estimate the atmospheric light A of a hazy image.

  Selection steps:
    1. Take the `perc` % of pixels with the brightest dark channel.
    2. Among them, pick the pixel whose largest colour-channel value is highest.
       Ties go to the first candidate in the order returned by `brightest`.
    3. Return that pixel's first three channels.

  Args:
    im: image indexed as im[row, col, channel] with at least 3 channels.
    patch_size_x, patch_size_y: dark-channel window size.
    perc: percentage of brightest dark-channel pixels to consider.
    show: if True, plot the intermediate results and the chosen colour.

  Returns:
    float64 array of shape (3,). It is a copy, not a view into `im`. It is all
    zeros when no candidate has a channel value > 0, the same fallback as the
    previous loop, which started from [0, 0, 0].
  """
  dark_channel = calc_dark_channel_im(im, patch_size_x, patch_size_y, show=show)
  indexes_brightest = brightest(dark_channel, perc, show=show)
  # Only the first three channels are used, as in calc_dark_channel_patch;
  # reshaping the full image to (N, 3) fails for images with an alpha channel.
  flat_im = np.reshape(np.asarray(im)[:, :, :3], (-1, 3))
  candidates = flat_im[indexes_brightest]
  atm_light = np.zeros(3, dtype=np.float64)
  if len(candidates) > 0:
    brightness = candidates.max(axis=1)
    best = int(np.argmax(brightness))  # first maximum, as the strict `>` loop did
    if brightness[best] > 0:
      atm_light = candidates[best].astype(np.float64)
  if show:
    plt.figure()
    plt.title('Estimated atmospheric light')
    plt.imshow(np.full((100, 100, 3), atm_light, dtype=np.uint8))
    plt.show()
  return atm_light

def calc_transmission(im, patch_size_x=15, patch_size_y=15, perc=0.1, omega=0.95, show=False):
  """Estimate the per-pixel transmission map with the dark channel prior.

  t(x) = 1 - omega * min over the local window and the three colour channels
  of im_c(y) / A_c, where A is the atmospheric light from calc_atm_light.

  Args:
    im: image indexed as im[row, col, channel] with at least 3 channels.
    patch_size_x, patch_size_y: local window size.
    perc: percentage of brightest dark-channel pixels used for A.
    omega: factor applied to the normalised dark-channel term.
    show: if True, plot the intermediate results and the transmission map.

  Returns:
    (transmission, atm_light): a float64 (rows, cols) map and the atmospheric
    light exactly as returned by calc_atm_light.
  """
  atm_light = calc_atm_light(im, patch_size_x, patch_size_y, perc, show=show)
  # A zero atmospheric-light channel used to give inf or NaN (0/0). The guided
  # filter then spreads those values over whole windows. Clamp the divisor to
  # a tiny positive value instead.
  divisor = np.maximum(np.asarray(atm_light, dtype=np.float64)[:3],
                       np.finfo(np.float64).eps)
  # min(x) / A == min(x / A) for A > 0, so the windowed min of the normalised
  # image is exactly the dark channel of im / A.
  normalized = np.asarray(im, dtype=np.float64)[:, :, :3] / divisor
  normalized_dark = calc_dark_channel_im(normalized, patch_size_x, patch_size_y, show=False)
  transmission = 1.0 - omega * np.asarray(normalized_dark, dtype=np.float64)
  if show:
    plt.figure()
    plt.title('Estimated transmission')
    plt.imshow(transmission*255, cmap='gray')
    plt.show()
  return transmission, atm_light

def _box_mean(img, r):
  """Mean of 2-D `img` over the (2r+1)x(2r+1) window around each pixel, with the window clipped at the border."""
  img = np.asarray(img, dtype=np.float64)
  rows, cols = img.shape
  # Summed-area table with a leading zero row/column, so that
  # integral[y, x] == img[:y, :x].sum().
  integral = np.zeros((rows + 1, cols + 1), dtype=np.float64)
  integral[1:, 1:] = img.cumsum(axis=0).cumsum(axis=1)
  # Clipped window bounds [y0, y1) x [x0, x1) for every row / column.
  y0 = np.maximum(np.arange(rows) - r, 0)
  y1 = np.minimum(np.arange(rows) + r + 1, rows)
  x0 = np.maximum(np.arange(cols) - r, 0)
  x1 = np.minimum(np.arange(cols) + r + 1, cols)
  window_sum = (integral[np.ix_(y1, x1)] - integral[np.ix_(y0, x1)]
                - integral[np.ix_(y1, x0)] + integral[np.ix_(y0, x0)])
  # Divide by the number of pixels actually inside each clipped window.
  window_count = np.outer(y1 - y0, x1 - x0)
  return window_sum / window_count

def guided_filter(input_im, r=20, epsilon=0.001, guidance_im=None, show=False):
  """Edge-preserving smoothing of `input_im` steered by `guidance_im`.

  Per window:
    a = cov(I, p) / (var(I) + epsilon)
    b = mean(p) - a * mean(I)
  Output: q = mean_window(a) * I + mean_window(b).
  All window statistics use clipped (2r+1)x(2r+1) windows. They are computed
  with box filters in O(pixels), independent of `r`.

  Args:
    input_im: 2-D array to filter (p), e.g. the transmission map.
    r: window radius; must be >= 0.
    epsilon: regularisation added to the window variance of the guide.
    guidance_im: 2-D guide (I) with the same shape; defaults to `input_im`.
    show: if True, plot the filtered result.

  Returns:
    float64 array with the shape of `input_im`.

  Raises:
    ValueError: if r is negative.
  """
  if r < 0:
    raise ValueError(f'r must be non-negative, got {r}')
  p = np.asarray(input_im, dtype=np.float64)
  guide = p if guidance_im is None else np.asarray(guidance_im, dtype=np.float64)
  mean_I = _box_mean(guide, r)
  mean_p = _box_mean(p, r)
  corr_I = _box_mean(guide * guide, r)
  corr_Ip = _box_mean(guide * p, r)
  var_I = corr_I - mean_I * mean_I
  cov_Ip = corr_Ip - mean_I * mean_p
  a = cov_Ip / (var_I + epsilon)
  b = mean_p - a * mean_I
  output_im = _box_mean(a, r) * guide + _box_mean(b, r)
  if show:
    plt.figure()
    plt.title('Soft-matted transmission')
    plt.imshow(output_im*255, cmap='gray')
    plt.show()
  return output_im

def remove_haze(path_input, path_output, max_size=600, patch_size_x=15, patch_size_y=15,
                perc=0.1, omega=0.95,
                r=20, epsilon=0.001, t0=0.1, show=False, contrast=1):
  """Dehaze an image file with the dark channel prior and save the result.

  Steps:
    1. Load the image, optionally downscaling it.
    2. Estimate the transmission map and the atmospheric light A.
    3. Refine the transmission with a guided filter guided by the grayscale
       image, then bound it below by `t0`.
    4. Recover the scene radiance J = (I - A) / t + A.
    5. Min-max normalise J and write it as an 8-bit image.

  Args:
    path_input: image to dehaze; it is only read, never modified.
    path_output: where the normalised 8-bit result is written.
    max_size: if not None, an image whose longer side exceeds it is resized
      so that side equals `max_size` (aspect ratio kept).
    patch_size_x, patch_size_y: dark-channel window size.
    perc: percentage of brightest dark-channel pixels used to pick A.
    omega: factor applied to the dark-channel term of the transmission.
    r, epsilon: guided-filter window radius and regularisation.
    t0: lower bound applied to the refined transmission.
    show: if True, plot every intermediate result.
    contrast: currently unused; kept so existing calls keep working.

  Returns:
    The recovered radiance, before normalisation, with the same shape as the
    (possibly resized) input.
  """
  # Open the input image
  im = np.float64(skio.imread(path_input))
  # Calculate new dimensions while maintaining the aspect ratio.
  # Test for None first: `max(...) > None` raises TypeError in Python 3.
  height, width = im.shape[:2]
  if max_size is not None and max(height, width) > max_size:
    if height > width:
      new_height = max_size
      new_width = int(width * (max_size / height))
    else:
      new_width = max_size
      new_height = int(height * (max_size / width))
    # preserve_range makes explicit that the 0..255 scale is kept. Later steps
    # assume it, e.g. the /255 in the input plot. The resized image is
    # deliberately NOT written back to `path_input`.
    im = transform.resize(im, (new_height, new_width), preserve_range=True)

  if show:
    plt.figure()
    plt.title('Input image')
    plt.imshow(np.divide(im,255))
    plt.show()

  # Convert im to grayscale for use in soft-matting
  bw_im = color.rgb2gray(im)

  # Compute input image transmission
  transmission, atm_light = calc_transmission(im, patch_size_x=patch_size_x,
                                              patch_size_y=patch_size_y,
                                              perc=perc, omega=omega, show=show)

  # Compute soft-matting for the transmission
  soft_matted_tr = guided_filter(transmission, r=r,
                                 epsilon=epsilon, guidance_im=bw_im, show=show)

  # Bound t(x) with t0
  t = np.clip(soft_matted_tr, a_min=t0, a_max=None)
  if show:
    plt.figure()
    plt.title('Transmission after clipping with t0')
    plt.imshow(t*255, cmap='gray')
    plt.show()

  # Recover scene radiance J = (I - A) / t + A for all pixels at once;
  # t is broadcast over the colour channels.
  output = (im - atm_light) / t[:, :, np.newaxis] + atm_light

  # Min-max normalise to [0, 1]. The previous (x + |min|) / (max + |min|) was
  # only correct when min < 0, and divided by zero for an all-zero result.
  out_min = output.min()
  out_range = output.max() - out_min
  if out_range > 0:
    normalized_output = (output - out_min) / out_range
  else:
    normalized_output = np.zeros_like(output)
  skio.imsave(path_output, (normalized_output*255).astype(np.uint8))
  if show:
    plt.figure()
    plt.title('Output image')
    plt.imshow(normalized_output)
    plt.show()
  return output

## Experiments

Dehaze a single image:

In [None]:
# Dehaze one image and plot every intermediate step.
output = remove_haze(
  './bad-2.jpg', './output.jpg',
  max_size=800, omega=0.75, perc=0.1, contrast=1.1, r=50, t0=0.2,
  show=True,
)

Dehaze a set of images:

In [None]:
# Batch run: dehaze ./0.jpg ... ./9.jpg, writing ./<i>_output.jpg for each one.
for i in range(10):
  print(i)
  output = remove_haze(
    f'./{i}.jpg', f'./{i}_output.jpg',
    max_size=1200, omega=0.75, r=50, contrast=1.2, show=False,
  )

Plot histograms of the dark channel of an image:

In [None]:
def plot_dc_histogram(image_path):
    """Plot an image and its dark channel, then the dark channel's histogram and cumulative histogram.

    Args:
        image_path: path of a colour image readable by skimage.io.imread.
            The dark channel needs at least 3 channels.
    """
    # Read the colour image (the dark channel is computed over its first three channels)
    im = skio.imread(image_path)

    # Calculate the dc
    dc = calc_dark_channel_im(im, patch_size_x=15, patch_size_y=15, show=False)

    # For integer images use one unit-width bin per intensity level.
    # 256 equal-width bins over [min, max] are narrower than 1, which leaves
    # empty bins between integer levels and makes the line plot zig-zag to 0.
    if np.issubdtype(dc.dtype, np.integer):
        bins = np.arange(int(dc.min()), int(dc.max()) + 2)
    else:
        bins = 256
    histogram, bin_edges = np.histogram(dc, bins=bins)

    # Calculate the cumulative histogram
    cumulative_histogram = np.cumsum(histogram)

    # Set up subplots
    sns.set(style="ticks", context="paper")
    fig_images, axs = plt.subplots(nrows=2, figsize=(8, 8))

    # Plot the original image and the dark channel
    axs[0].imshow(im, cmap='gray')
    axs[0].set_title("Original Image")
    axs[0].axis('off')

    axs[1].imshow(dc, cmap='gray')
    axs[1].set_title("Dark Channel")
    axs[1].axis('off')
    fig_images.tight_layout()

    # Set up subplots for histograms
    fig_hist, axs_hist = plt.subplots(nrows=2, figsize=(4, 6))

    # Plot the histogram
    sns.lineplot(x=bin_edges[0:-1], y=histogram, ax=axs_hist[0])
    axs_hist[0].set(xlabel="Intensity value", ylabel="Frequency")
    axs_hist[0].set_title("Histogram")

    # Plot the cumulative histogram
    sns.lineplot(x=bin_edges[0:-1], y=cumulative_histogram, ax=axs_hist[1], color='orange')
    axs_hist[1].set(xlabel="Intensity value", ylabel="Cumulative Frequency")
    axs_hist[1].set_title("Cumulative Histogram")
    sns.despine(fig=fig_hist)

    fig_hist.tight_layout()
    plt.show()

# Example: dark-channel histogram of a single test image.
image_path = './test.jpg'
plot_dc_histogram(image_path=image_path)

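Calculate PSNR and SSIM against the hazeless ground truth for a sweep of `omega` values (the commented-out lines sweep the guided-filter radius `r` instead):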

In [None]:
omegas = [0.55, 0.65, 0.7, 0.75, 0.8, 0.9]
#rs = [20, 30, 40, 50, 60, 70]

# The hazeless reference is the same for every run, so read it once.
# Avoid the name `input`, which shadows the builtin. cv2 loads BGR; the
# dehazed results below are also read with cv2, so the channel order matches.
ground_truth = cv2.imread('./0.jpg')
if ground_truth is None:
  raise FileNotFoundError("cannot read ground-truth image './0.jpg'")

psnr_array = []
ssim_array = []
for o in omegas:
  # Dehaze
  print(f'omega = {o}...')
  #print(f'r = {r}...')

  remove_haze('./0-haze.jpg', f'./o-{o}-output.jpg', max_size=600, omega=o, perc=0.1, contrast=1.1, r=50, show=False)
  #remove_haze('./0-haze.jpg', f'./r-{r}-output.jpg', max_size=600, omega=0.75, perc=0.1, contrast=1.1, r=r, show=False)

  # Compare the dehazed result, as saved to disk (8-bit), with the hazeless original
  dehazed = cv2.imread(f'./o-{o}-output.jpg')
  #dehazed = cv2.imread(f'./r-{r}-output.jpg')
  if dehazed is None:
    raise FileNotFoundError(f"cannot read dehazed image './o-{o}-output.jpg'")
  if dehazed.shape != ground_truth.shape:
    # remove_haze downscales inputs larger than max_size, so the shapes can differ.
    raise ValueError(f'shape mismatch: ground truth {ground_truth.shape} vs dehazed {dehazed.shape}; '
                     'use a max_size that keeps the original resolution')

  # Calc PSNR
  p = cv2.PSNR(ground_truth, dehazed)
  psnr_array.append(p)

  # Calc SSIM
  s = ssim(ground_truth, dehazed, channel_axis=2)
  ssim_array.append(s)

Plot PSNR against `omega` (uncomment the second `lineplot` to add SSIM):

In [None]:
sns.set(style="ticks", context="paper")
fig, ax = plt.subplots()

sns.lineplot(x=omegas, y=psnr_array, label="PSNR", ax=ax)
# SSIM lies in [-1, 1], while PSNR is in dB (typically tens), so on a shared
# y-axis the SSIM line would look flat; consider ax.twinx() if both are needed.
#sns.lineplot(x=omegas, y=ssim_array, label="SSIM", ax=ax)

ax.set(xlabel="omega", ylabel="Score", title="Dehazing quality vs. omega")
ax.legend()
sns.despine()

plt.show()