In [56]:
""" ADMM color test. """
import numpy as np
import cv2
from scipy.fftpack import fft2, ifft2
import time
from utils.utils import D
from utils.utils import DT

In [58]:
def DT(grad):
    if len(grad.shape)==3:
        # grayscale image
        n1, n2, _ = grad.shape
        shift = np.roll(grad[:, :, 0], (1, 0), axis=(0, 1))
        div1 = grad[:, :, 0] - shift
        div1[0, :] = grad[0, :, 0]
        div1[n1-1, :] = -shift[n1-1, :]

        shift = np.roll(grad[:, :, 1], (0, 1), axis=(0, 1))
        div2 = grad[:, :, 1] - shift
        div2[:, 0] = grad[:, 0, 1]
        div2[:, n2-1] = -shift[:, n2-1]

        div = div1 + div2 
        return div
    else:
        # color image
        assert len(grad.shape)==4, "gradient dimension error: dimension not accepted."
        n1, n2, _, _ = grad.shape
        div = np.zeros((n1, n2, 3))

        for k in range(3):
            for i in range(2):
                shift = np.roll(grad[:, :, k, i].squeeze(), shift=(1-i, -i), axis=(0, 1))
                div1 = grad[:, :, k, i].squeeze() - shift
                div1[0, :] = grad[0, :, k, i].squeeze()
                div1[n1 - 1, :] = -shift[n1 - 1, :]

                shift = np.roll(grad[:, :, k, 1-i].squeeze(), shift=(-i, 1-i), axis=(0, 1))
                div2 = grad[:, :, k, 1-i].squeeze() - shift
                div2[:, 0] = grad[:, 0, k, 1-i].squeeze()
                div2[:, n2 - 1] = -shift[:, n2 - 1]

                div[:, :, k] += div1 + div2
        return div


In [59]:
print(DT(np.ones((10,10,3,2))))

[[[ 4.  4.  4.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 2.  2.  2.]
  [ 0.  0.  0.]]

 [[ 2.  2.  2.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [-2. -2. -2.]]

 [[ 2.  2.  2.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [-2. -2. -2.]]

 [[ 2.  2.  2.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [-2. -2. -2.]]

 [[ 2.  2.  2.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [-2. -2. -2.]]

 [[ 2.  2.  2.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [ 0.  0.  0.]
  [-2. -2. -2.]]

 [[ 2.  2.  2.]
  [ 0.  0.  

In [111]:
# Auxiliary functions
def z_solver(x, u, lambd, rho, mask, tv_type):
    w = mask * (D(x) + (1 / rho) * u)
    if tv_type == 'anisotropic':
        z = soft_threshold(w, lambd / rho)
    else:
        w_v, w_h = w[:, :, 0], w[:, :, 1]
        t = np.sqrt(w_v**2 + w_h**2)
        z = (np.atleast_3d(soft_threshold(t, lambd / rho) / (t + np.finfo(float).eps)))*w
    return z

def soft_threshold(x, kappa):
    return np.maximum(x - kappa, 0) - np.maximum(-x - kappa, 0)

def x_solver(z, u, u0, rho, deno):
    x = ifft2(fft2(u0 + rho * DT(z) - DT(u)) / deno)
    return x

"""implements ADMM for TV-L^2."""
def ADMM_3D(u0,lambd, N, \
              tv_type = "anisotropic",
                rho = 0.05, mu = 10, tau = 2, ground_truth = None, eps = 1e-10, channel_axis = None):
    # note: when ground truth is nonzero, eps is used to
    # ensure the quality of output image
    m, n, c = u0.shape
    assert c == 3, "color channel mismatch error."
    # Initialization
    u0_R = u0[:,:,0]
    u0_G = u0[:,:,1]
    u0_B = u0[:,:,2]
    x = np.zeros_like(u0)
    x_next = np.zeros_like(u0)
    z = np.zeros((m,n,3,2))
    z_next = np.zeros((m,n,3,2))
    u = np.zeros_like(z)

    # z = D(u_channel)
    mask = np.ones((m,n,2))
    mask[-1, :, 0] = 0
    mask[:, -1, 1] = 0
    dh = np.array([[0, 0, 0], [-1, 1, 0], [0, 0, 0]])
    dh_pad = np.zeros((m, n))
    dh_pad[m//2:m//2+3, n//2:n//2+3] = dh 
    # horizontal difference operator
    dv = np.array([[0, -1, 0], [0, 1, 0], [0, 0, 0]])
    dv_pad = np.zeros((m, n))
    dv_pad[m//2:m//2+3, n//2:n//2+3] = dv
    # vertical difference operator
    fdh = fft2(dh_pad)
    fdv = fft2(dv_pad)
    deno = 1 + rho * np.abs(fdh)**2 + rho * np.abs(fdv)**2
        
    z = D(u0)
    print(z.shape)
    # print("z=",z)
    print("dtz=",DT(z))
    # tic = time.perf_counter()
    # denoise_process = []
    # Main loop
    # np.save("0_temp.npy",u0)
    """To display the denoising process, the intermediate values
    are saved to a temp.npy file. """
    # TODO: find RAM-economical solution to dynamical denoising display
    if ground_truth == None:
        # iterate for a set number of times
        for i in range(N):
            # TODO: set stopping criteria with original image
            x_next[:,:,0] = x_solver(z[:,:,0,:], u[:,:,0,:], u0_R, rho, deno)
            z_next[:,:,0,:] = z_solver(x_next[:,:,0], u[:,:,0,:], lambd, rho, mask, tv_type)
            x_next[:,:,1] = x_solver(z[:,:,1,:], u[:,:,1,:], u0_G, rho, deno)
            z_next[:,:,1,:] = z_solver(x_next[:,:,1], u[:,:,1,:], lambd, rho, mask, tv_type)
            x_next[:,:,2] = x_solver(z[:,:,2,:], u[:,:,2,:], u0_B, rho, deno)
            z_next[:,:,2,:] = z_solver(x_next[:,:,2], u[:,:,2,:], lambd, rho, mask, tv_type)

            u_next = u + rho * (D(x_next) - z_next)
            # print(DT(z))
            # z_next = np.stack([z_next_R, z_next_G, z_next_B], axis = 2)
            # print(DT(z_next))
            s = -rho * (DT(z_next - z))
            r = D(x) - z
            s_norm = np.linalg.norm(s)
            r_norm = np.linalg.norm(r)
            if r_norm > mu * s_norm:
                rho = rho * tau
            elif s_norm > mu * r_norm:
                rho = rho / tau
            print("ite",i)
            # print(np.linalg.norm(x_next-x, axis = 1))
            # TODO: find out why does the usual convergence not work

            x = x_next
            z = z_next
            u = u_next

            x_temp = x.astype(np.uint8)
            os_dir = [str(i),"temp.npy"]
            os_dir = "_".join(os_dir)
            np.save(os_dir, x_temp)
            # denoise_process.append(u)
    else:
        # compare the generated image with ground truth
        # end iteration if difference smaller than eps
        # or when k exceeds maxiter
        k = 0
        while np.linalg.norm(ground_truth, x)/np.linalg.norm(x)>eps:
            x_next = x_solver(z, u, u0, rho, deno)
            z_next = z_solver(x_next, u, lambd, rho, mask, tv_type)
            u_next = u + rho * (D(x_next) - z_next)
            
            s = -rho * (DT(z_next - z))
            r = D(x) - z
            
            s_norm = np.linalg.norm(s)
            r_norm = np.linalg.norm(r)
            
            if r_norm > mu * s_norm:
                rho = rho * tau
            elif s_norm > mu * r_norm:
                rho = rho / tau
            
            x = x_next
            z = z_next
            u = u_next
            x_temp = x.astype(np.uint8)
            os_dir = [str(k),"temp.npy"]
            os_dir = "_".join(os_dir)
            np.save(os_dir, x_temp)
            # denoise_process.append(u)
            if k > N:
                break
            k+=1
    # toc = time.perf_counter()
    # runtime = toc - tic
    
    return x

In [112]:
if __name__ == "__main__":
    fileName = "input.png"
    N = 100
    weight = 5

    u = cv2.imread(r"E:\\2nd-semester\\practical\\test_v1\\reformulate_v1\\input_color.jpg")
    #print(u)
    u = u.astype(np.float32)# / 255.0

    result = ADMM_3D(u, weight, N)
    # result *= 255
    result = result.astype(np.uint8)
    cv2.imwrite("E:\\2nd-semester\\practical\\test_v1\\reformulate_v1\\output_ADMM_color.png", result)

(355, 474, 3, 2)
dtz= [[[  26.   14.   40.]
  [ -12.  -16.   -8.]
  [ -71.  -80.  -80.]
  ...
  [   4.    0.   -5.]
  [  -5.   -9.  -11.]
  [ -16.  -23.  -18.]]

 [[-175. -193. -156.]
  [  14.    5.   20.]
  [ 273.  256.  283.]
  ...
  [  89.   62.   68.]
  [  21.   40.   44.]
  [  29.   31.   48.]]

 [[  99.   99.   92.]
  [-206. -209. -199.]
  [ -24.  -27.  -26.]
  ...
  [ -75.  -66.  -78.]
  [  54.   51.   55.]
  [ -19.  -24.   -6.]]

 ...

 [[   1.    4.    6.]
  [-252. -270. -265.]
  [ 320.  319.  326.]
  ...
  [  86.  121.   91.]
  [ -21.  -23.  -20.]
  [  48.   34.   52.]]

 [[ 106.  117.   98.]
  [-181. -194. -207.]
  [ 195.  211.  208.]
  ...
  [ -15.   51.   -9.]
  [ -43.  -41.  -40.]
  [ -13.  -33.  -10.]]

 [[ 118.  124.  124.]
  [-143. -161. -159.]
  [  98.  108.  106.]
  ...
  [  78.   92.   89.]
  [ -34.  -52.  -28.]
  [ -47.  -62.  -41.]]]


  x_next[:,:,0] = x_solver(z[:,:,0,:], u[:,:,0,:], u0_R, rho, deno)
  x_next[:,:,1] = x_solver(z[:,:,1,:], u[:,:,1,:], u0_G, rho, deno)
  x_next[:,:,2] = x_solver(z[:,:,2,:], u[:,:,2,:], u0_B, rho, deno)


ite 0
ite 1
ite 2
ite 3
ite 4
ite 5
ite 6
ite 7
ite 8
ite 9
ite 10
ite 11
ite 12
ite 13
ite 14
ite 15
ite 16
ite 17
ite 18


  u_next = u + rho * (D(x_next) - z_next)
  s = -rho * (DT(z_next - z))
  r = D(x) - z


ite 19
ite 20
ite 21
ite 22
ite 23
ite 24
ite 25
ite 26
ite 27
ite 28
ite 29
ite 30
ite 31
ite 32
ite 33
ite 34
ite 35
ite 36
ite 37
ite 38
ite 39
ite 40
ite 41
ite 42
ite 43
ite 44
ite 45
ite 46
ite 47
ite 48
ite 49
ite 50
ite 51
ite 52
ite 53
ite 54
ite 55
ite 56
ite 57
ite 58
ite 59
ite 60
ite 61
ite 62
ite 63
ite 64
ite 65
ite 66
ite 67
ite 68
ite 69
ite 70
ite 71
ite 72
ite 73
ite 74
ite 75
ite 76
ite 77
ite 78
ite 79
ite 80
ite 81
ite 82
ite 83
ite 84
ite 85
ite 86
ite 87
ite 88
ite 89
ite 90
ite 91
ite 92
ite 93
ite 94
ite 95
ite 96
ite 97
ite 98
ite 99
