In [None]:
import os
import numpy as np
import random as rng
import cv2
import json
import math
import pywt
from scipy.misc import face
from scipy.signal.signaltools import wiener
import matplotlib.pyplot as plt

In [None]:
# Root directory of the strip dataset. Set the STRIPS_DATADIR environment
# variable to point elsewhere instead of editing this absolute path.
out_datadir = os.environ.get("STRIPS_DATADIR", "/tf/studprojskrabec/images/strips_socrates")

In [None]:
# Load the dataset description. transform_dataset() iterates it as groups of
# records whose [1] element is a list of strip image paths (inferred from that
# function's indexing -- TODO confirm against the JSON schema).
# JSON is UTF-8 by specification; do not rely on the locale's default encoding.
with open(os.path.join(out_datadir, 'dataset_info.json'), encoding='utf-8') as f:
    info_data = json.load(f)

In [None]:
# STRIP_SIZE: expected strip side length (presumably pixels); get_residual uses
#   ceil(log2(STRIP_SIZE)) to choose how many wavelet levels to decompose.
# DENOISE_SIGMA: handed to scipy.signal.wiener as its second positional
#   argument, which is the filter window size (`mysize`), not a noise sigma.
STRIP_SIZE, DENOISE_SIGMA = 256, 5

In [None]:
def denoise_coefficient_list(coefficient_list, sigma):
    """Wiener-filter every detail subband of a 2-D wavelet decomposition.

    Parameters
    ----------
    coefficient_list : list
        Coefficients as produced by ``pywt.wavedec2``:
        ``[cA, (cH, cV, cD), ...]``. The first (approximation) entry is passed
        through untouched.
    sigma : int or sequence of int
        Forwarded as the second positional argument of
        ``scipy.signal.wiener``, i.e. its ``mysize`` window size. Despite the
        name, it is not a noise standard deviation.

    Returns
    -------
    list
        ``[cA, [dH, dV, dD], ...]``. Each detail subband is converted to
        float64 and filtered. The layout is what ``pywt.waverec2`` consumes.
    """
    approximation = coefficient_list[0]
    denoised_bands = [approximation]
    for subband_coefficients in coefficient_list[1:]:
        # np.float was just an alias of the builtin float (float64) and was
        # removed in NumPy 1.24; np.float64 keeps the same precision.
        denoised_bands.append([wiener(s.astype(np.float64), sigma)
                               for s in subband_coefficients])
    return denoised_bands


In [None]:
def get_residual(grayscale_matrix):
    """Split a 2-D image into a wavelet-denoised estimate and its residual.

    The image is decomposed with a periodized ``db8`` wavelet transform. The
    detail subbands are Wiener-filtered by ``denoise_coefficient_list``. The
    result is then reconstructed and clipped to [0, 255].

    Parameters
    ----------
    grayscale_matrix : numpy.ndarray
        2-D single-channel image, e.g. one channel of a ``cv2.imread`` result.

    Returns
    -------
    tuple of numpy.ndarray
        ``(denoised_tile, residual)``, where
        ``residual = grayscale_matrix - denoised_tile``. Both have the shape
        of ``grayscale_matrix``.
    """
    # Exact integer ceil(log2(STRIP_SIZE)). math.log(x, 2) can land slightly
    # above the true integer for some exact powers of two, and ceil() then
    # overshoots by one level.
    dyad_length = (math.ceil(STRIP_SIZE) - 1).bit_length()
    # Decompose until the coarsest approximation of a STRIP_SIZE-sized input
    # is about 2**ll_levels samples per side.
    ll_levels = 5
    wavelet_levels = dyad_length - ll_levels
    coefficient_list = pywt.wavedec2(grayscale_matrix,
                                     'db8',
                                     level=int(wavelet_levels),
                                     mode='per')
    coefficient_list = denoise_coefficient_list(coefficient_list,
                                                DENOISE_SIGMA)
    denoised_tile = pywt.waverec2(coefficient_list,
                                  'db8',
                                  mode='per')
    # Periodization reconstructs to even lengths at every level. Trim back to
    # the input size so the subtraction below also works for odd-sized input.
    rows, cols = np.shape(grayscale_matrix)[:2]
    denoised_tile = denoised_tile[:rows, :cols]
    np.clip(denoised_tile, 0.0, 255.0, out=denoised_tile)
    return (denoised_tile, grayscale_matrix - denoised_tile)


In [None]:
def get_final(denoised, residual):
    """Return ``(denoised * residual) / (denoised * denoised)`` elementwise.

    Where the squared denoised value is exactly zero, the entry is 0 instead
    of an inf/nan from dividing by zero.
    """
    product = denoised * residual
    energy = denoised * denoised
    nonzero = energy != 0
    return np.divide(product, energy,
                     out=np.zeros_like(product),
                     where=nonzero)

In [None]:
def get_final_from_path(img_path):
    """Load an image as grayscale and return its ``get_final`` noise map.

    Parameters
    ----------
    img_path : str
        Path readable by ``cv2.imread``.

    Returns
    -------
    numpy.ndarray
        ``get_final(denoised, residual)`` for the grayscale conversion of the
        image.

    Raises
    ------
    OSError
        If ``cv2.imread`` cannot read the file. It returns None in that case,
        which would otherwise fail later inside ``cvtColor`` with an opaque
        cv2 error.
    """
    img = cv2.imread(img_path)
    if img is None:
        raise OSError(f"Error reading image: {img_path}")
    gimg = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    den, res = get_residual(gimg)
    return get_final(den, res)

In [None]:
def normalize_median_dev_cut(mat, low, high, dev_cut):
    """Clamp outliers around the median, then min-max rescale, in place.

    Values farther than ``dev_cut`` median absolute deviations from the
    median of ``mat`` are clamped to that band. The array is then rescaled
    with ``cv2.normalize(..., cv2.NORM_MINMAX)`` to the range
    [``low``, ``high``]. ``mat`` is modified in place; nothing is returned.
    """
    center = np.median(mat)
    band = np.median(np.abs(mat - center)) * dev_cut
    upper = center + band
    lower = center - band
    # Bounds are fixed before any mutation, so clamping order is irrelevant
    # for the band edges themselves.
    mat[mat > upper] = upper
    mat[mat < lower] = lower
    cv2.normalize(mat, mat, low, high, cv2.NORM_MINMAX)
                  

In [None]:
def get_final_from_path_channel(img_path):
    """Replace every colour channel of an image with its normalized noise map.

    Each channel is processed independently:

    - ``get_residual`` splits it into a denoised estimate and a residual.
    - ``get_final`` combines the two.
    - ``normalize_median_dev_cut`` clamps values beyond 5 median absolute
      deviations and rescales to [0, 255].

    Parameters
    ----------
    img_path : str
        Path readable by ``cv2.imread``.

    Returns
    -------
    numpy.ndarray or None
        The array loaded by ``cv2.imread`` (same shape and dtype), with each
        channel overwritten. Returns ``None``, after printing a message, if
        the image could not be read.
    """
    img = cv2.imread(img_path)
    if img is None:
        print("Error reading image:", img_path)
        return
    _, _, n_channels = img.shape
    for channel in range(n_channels):
        den, res = get_residual(img[:, :, channel])
        final = get_final(den, res)
        normalize_median_dev_cut(final, 0, 255, 5)
        # img holds integer pixels (uint8 with cv2.imread's default flags).
        # Round instead of letting the assignment truncate, which biased
        # every value down and only mapped an exact 255.0 to 255.
        img[:, :, channel] = np.clip(np.rint(final), 0, 255)
    return img

In [None]:
def transform_dataset(dataset=None):
    """Overwrite every strip image listed in the dataset with its noise map.

    Parameters
    ----------
    dataset : list, optional
        Defaults to the module-level ``info_data`` loaded from
        ``dataset_info.json``. Each entry is iterated as a group of records:

        - ``record[1]`` is a list of strip image paths.
        - ``group[0][0]`` is printed as a progress label (presumably an
          identifier; confirm against the JSON schema).

    Notes
    -----
    Each strip file is rewritten in place with ``cv2.imwrite``, so running
    this twice re-processes already-transformed images. Strips that cannot be
    read or written are reported and skipped instead of aborting the run.
    """
    if dataset is None:
        dataset = info_data
    total = len(dataset)
    for i, std in enumerate(dataset):
        print(i, " of ", total, std[0][0])
        for img in std:
            for strip in img[1]:
                fin = get_final_from_path_channel(strip)
                if fin is None:
                    # get_final_from_path_channel has already printed why.
                    continue
                # fin keeps cv2.imread's integer dtype, so it can be written
                # as-is (the old astype(int) round-trip only added a copy).
                if not cv2.imwrite(strip, fin):
                    print("Error writing image:", strip)

In [None]:
# Run the whole pipeline: every strip image listed in dataset_info.json is
# overwritten in place with its noise map. Not idempotent -- re-running this
# cell processes the already-transformed files again.
transform_dataset()