In [None]:
%matplotlib widget

import numpy as np
import matplotlib.pyplot as plt
from numpy.fft import fft2, ifft2, fftshift, fftfreq
from skimage.restoration import unwrap_phase

data = np.load('corn_fluorescence.npy')

# Normalise to [0, 1] over the full dynamic range (max - min); dividing by the
# max alone only reaches 1 when the minimum is 0. A constant image maps to 0.
data_range = np.max(data) - np.min(data)
if data_range != 0:
    data = (data - np.min(data)) / data_range
else:
    data = np.zeros_like(data, dtype=float)

# Stretch [0, 1] to [-pi, 3*pi] so the synthetic phase spans two full cycles.
wrap_factor = 4*np.pi
data = wrap_factor*data - np.pi

# Wrap into the principal interval via the complex exponential.
wrapped_data = np.angle(np.exp(1j*data))

def PHI(P,X,Y):
    """Divide a spectrum by the eigenvalues of the discrete 5-point Laplacian.

    Computes ``P / (2*cos(pi*Y/M) + 2*cos(pi*X/N) - 4)`` element-wise, with
    ``M = Y.shape[0]`` and ``N = X.shape[1]``. For integer index grids
    X = 0..N-1, Y = 0..M-1 this is the DCT-domain Poisson solution used in
    unweighted least-squares phase unwrapping.

    Parameters
    ----------
    P : 2-D array
        Transform-domain right-hand side.
    X, Y : 2-D arrays
        Column / row frequency-index grids (e.g. from ``np.meshgrid``).

    Returns
    -------
    numpy.ndarray
        The divided spectrum. The zero-frequency entry (X == 0 and Y == 0),
        where the denominator vanishes, is set to an arbitrary additive
        constant (0).
    """
    M = Y.shape[0]
    N = X.shape[1]
    arbitrary_additive_constant = 0
    denominator = 2*np.cos(np.pi*Y/M) + 2*np.cos(np.pi*X/N) - 4
    # Only the DC term is singular. `X + Y != 0` would also discard every
    # non-singular entry on the anti-diagonal X == -Y of signed grids.
    not_dc = (X != 0) | (Y != 0)
    # np.where evaluates both branches, so silence the (discarded) 0/0 at DC.
    with np.errstate(divide='ignore', invalid='ignore'):
        solution = np.where(not_dc, P / denominator, arbitrary_additive_constant)
    return solution


def mirror_reflect_data(wrapped_data):
    """Mirror-extend an array along its first two axes without repeating edge samples.

    For an input of shape (M, N) the result has shape (2M-1, 2N-1):
    the columns are extended to the right with the input mirrored about its
    last column, then the rows are extended upwards with that result mirrored
    about its first row. The original data therefore occupies rows ``M-1:``
    and columns ``:N`` of the output.
    """
    source = np.asarray(wrapped_data)
    n_rows, n_cols = source.shape[0], source.shape[1]
    # Widths are clamped at 0 so empty axes are left as they are.
    pad_widths = [(max(n_rows - 1, 0), 0), (0, max(n_cols - 1, 0))]
    pad_widths += [(0, 0)] * (source.ndim - 2)
    # 'reflect' mirrors about the edge sample itself, so the edge is not duplicated.
    return np.pad(source, pad_widths, mode='reflect')

def unwrap_LS(wrapped_data):
    """Unweighted least-squares 2-D phase unwrapping with an FFT Poisson solver.

    The wrapped phase is mirror-extended so that its periodic extension is
    symmetric (zero-gradient boundaries). Its forward differences are wrapped
    back into the principal interval, and their divergence ``rho`` is formed.
    The periodic Poisson equation ``Laplacian(phi) = rho`` is then solved in
    the Fourier domain.

    Parameters
    ----------
    wrapped_data : 2-D array_like
        Wrapped phase.

    Returns
    -------
    numpy.ndarray
        Unwrapped phase with the same shape as ``wrapped_data``, defined up to
        an additive constant (the zero-frequency term of the extended solution
        is set to 0).
    """
    wrapped_data = np.asarray(wrapped_data, dtype=float)
    n_rows, n_cols = wrapped_data.shape

    # Mirror-extend without repeating edge samples -> shape (2M-1, 2N-1) with
    # the original at rows M-1: and columns :N (same layout as
    # mirror_reflect_data).
    extended_data = np.pad(wrapped_data, ((n_rows - 1, 0), (0, n_cols - 1)), mode='reflect')

    # Wrapped forward differences on the periodic grid: W(psi[k+1] - psi[k]).
    # Without the wrapping the solve would just reproduce the wrapped input.
    dx = np.angle(np.exp(1j * (np.roll(extended_data, -1, axis=1) - extended_data)))
    dy = np.angle(np.exp(1j * (np.roll(extended_data, -1, axis=0) - extended_data)))

    # Backward difference of the forward differences: 5-point Laplacian
    # centred on (i, j) in both directions.
    rho = (dx - np.roll(dx, 1, axis=1)) + (dy - np.roll(dy, 1, axis=0))

    # Eigenvalues of the periodic 5-point Laplacian in unshifted FFT order.
    # No fftshift: for the odd extended sizes it is not its own inverse.
    freq_y = fftfreq(extended_data.shape[0])[:, np.newaxis]
    freq_x = fftfreq(extended_data.shape[1])[np.newaxis, :]
    eigenvalues = 2*np.cos(2*np.pi*freq_y) + 2*np.cos(2*np.pi*freq_x) - 4
    eigenvalues[0, 0] = 1.0  # the only zero eigenvalue; its coefficient is set below

    phi_hat = fft2(rho) / eigenvalues
    phi_hat[0, 0] = 0.0  # arbitrary additive constant
    # rho is real, so the imaginary part is round-off; take .real, not abs().
    phi_full = ifft2(phi_hat).real

    return phi_full[n_rows - 1:, :n_cols]

    
def calculate_phase_gradient(phase,direction):
    """Forward difference of ``phase`` along a numpy axis, with periodic wrap-around.

    Returns ``phase[k+1] - phase[k]`` along axis ``direction`` (0 = rows,
    1 = columns). Both axes now use the same forward convention; previously
    axis 1 used ``roll(+1)``, i.e. a negated backward difference.
    Because ``np.roll`` is used, the last index along that axis holds the
    wrap-around term ``phase[0] - phase[-1]``. Callers are expected to zero it,
    matching the "0 otherwise" boundary case of Eqs. (2)-(3) in
    'Minimum Lp-norm two-dimensional phase unwrapping'
    (https://doi.org/10.1364/JOSAA.13.001999).

    Raises
    ------
    ValueError
        If ``direction`` is not 0 or 1 (previously None was returned silently).
    """
    if direction not in (0, 1):
        raise ValueError(f"direction must be 0 or 1, got {direction!r}")
    return np.roll(phase, -1, axis=direction) - phase
        
def calculate_normalized_weight(phase,delta,eps0=0.01,p=0,direction=0):
    """
    Normalized Lp-norm weights, Equation (38) and (39) of 'Minimum Lp-norm two-dimensional phase unwrapping' (https://doi.org/10.1364/JOSAA.13.001999)

        weight = eps0 / (|grad(phase) - delta| ** (2 - p) + eps0)

    where ``grad`` is ``calculate_phase_gradient(phase, direction)``.

    Parameters
    ----------
    phase : 2-D array
        Current phase estimate.
    delta : array broadcastable to ``phase``
        Wrapped phase differences along the same axis.
    eps0 : float
        Regularization constant of the normalized weight.
    p : float
        Lp-norm exponent.
    direction : {0, 1}
        Numpy axis of the gradient.

    Returns
    -------
    numpy.ndarray
        The weights. The last index along ``direction`` (the gradient's
        wrap-around entry) is set to 0.

    Raises
    ------
    ValueError
        If ``direction`` is not 0 or 1 (previously None was returned silently).
    """
    if direction not in (0, 1):
        raise ValueError(f"direction must be 0 or 1, got {direction!r}")
    phase_gradient = calculate_phase_gradient(phase, direction)
    weight = eps0 / (np.abs(phase_gradient - delta)**(2 - p) + eps0)
    # Zero the boundary row/column that has no forward neighbour.
    if direction == 0:
        weight[-1, :] = 0
    else:
        weight[:, -1] = 0
    return weight


def calculate_wrapped_phase_difference(wrapped_phase,direction=0):
    """
    Equation (2) and (3) of 'Minimum Lp-norm two-dimensional phase unwrapping' (https://doi.org/10.1364/JOSAA.13.001999)

    Returns ``W(psi[k+1] - psi[k])`` along numpy axis ``direction``, where
    ``W`` wraps into the principal interval via ``angle(exp(1j*x))``. The last
    index along that axis (no forward neighbour) is set to 0.

    Raises
    ------
    ValueError
        If ``direction`` is not 0 or 1.
    """
    if direction not in (0, 1):
        raise ValueError(f"direction must be 0 or 1, got {direction!r}")
    difference = calculate_phase_gradient(wrapped_phase, direction)
    # The wrapping operator W{.} of Eqs. (2)-(3); without it the 2*pi jumps
    # of the wrapped phase would leak into the differences.
    difference = np.angle(np.exp(1j * difference))
    if direction == 0:
        difference[-1, :] = 0  # axis-0 differences: last row is the boundary
    else:
        difference[:, -1] = 0  # axis-1 differences: last column is the boundary
    return difference

    
def modified_Laplacian_weighted_wrapped_phase_gradient(phase,wrapped_phase):
    """
    Equation (37) of 'Minimum Lp-norm two-dimensional phase unwrapping' (https://doi.org/10.1364/JOSAA.13.001999)

    Weighted divergence of the wrapped phase differences (right-hand side ``c``):

        c[i,j] = U[i,j]*dx[i,j] - U[i-1,j]*dx[i-1,j] + V[i,j]*dy[i,j] - V[i,j-1]*dy[i,j-1]

    ``dx``/``dy`` are the wrapped forward differences of ``wrapped_phase``
    along axis 0/1. ``U``/``V`` are the normalized weights computed from the
    current estimate ``phase``.
    """
    delta_x = calculate_wrapped_phase_difference(wrapped_phase, 0)
    U = calculate_normalized_weight(phase, delta_x, direction=0)

    delta_y = calculate_wrapped_phase_difference(wrapped_phase, 1)
    V = calculate_normalized_weight(phase, delta_y, direction=1)

    flux_x = U * delta_x
    flux_y = V * delta_y
    # Rolling by +1 gives the (i-1, j) / (i, j-1) neighbours. The weights'
    # last row (U) / last column (V) are zero, so the wrapped-in entry at
    # index 0 is 0, i.e. the out-of-grid neighbour term is dropped.
    return (flux_x - np.roll(flux_x, 1, axis=0)) + (flux_y - np.roll(flux_y, 1, axis=1))
    
def modified_Laplacian_phase_gradient(phase,wrapped_phase):
    """
    Equation (37) of 'Minimum Lp-norm two-dimensional phase unwrapping' (https://doi.org/10.1364/JOSAA.13.001999)

    Weighted Laplacian ``Q(phase)``:

        Q[i,j] = U[i,j]*gx[i,j] - U[i-1,j]*gx[i-1,j] + V[i,j]*gy[i,j] - V[i,j-1]*gy[i,j-1]

    ``gx``/``gy`` are the plain (unwrapped) forward differences of ``phase``
    along axis 0/1. ``U``/``V`` are the normalized weights from ``phase`` and
    the wrapped differences of ``wrapped_phase``.
    """
    delta_x = calculate_wrapped_phase_difference(wrapped_phase, 0)
    U = calculate_normalized_weight(phase, delta_x, direction=0)

    delta_y = calculate_wrapped_phase_difference(wrapped_phase, 1)
    V = calculate_normalized_weight(phase, delta_y, direction=1)

    # The phase differences must NOT be wrapped here. Their wrap-around
    # boundary entries are cancelled by the zero last row of U / last column of V.
    flux_x = U * calculate_phase_gradient(phase, 0)
    flux_y = V * calculate_phase_gradient(phase, 1)
    # Rolling by +1 gives the (i-1, j) / (i, j-1) neighbour terms; the zero
    # boundary of the fluxes drops the out-of-grid neighbours.
    return (flux_x - np.roll(flux_x, 1, axis=0)) + (flux_y - np.roll(flux_y, 1, axis=1))

def calculate_beta(r1=None, z1=None, r2=None, z2=None):
    """Conjugate-gradient coefficient beta for the WLS inner loop.

    Returns ``<r1, z1> / <r2, z2>``, where ``(r1, z1)`` are the most recent
    residual and preconditioned residual and ``(r2, z2)`` the previous pair.
    This is the argument order used by the WLS loop, which shifts
    ``r2 = r1; r1 = residual`` and ``z2 = z1; z1 = z``. Arrays are flattened
    for the inner product by ``np.vdot``. A zero previous product yields
    numpy's inf/nan.

    Called with no arguments it returns None, preserving the behaviour of the
    original placeholder.
    """
    if r1 is None and z1 is None and r2 is None and z2 is None:
        return None
    return np.vdot(r1, z1) / np.vdot(r2, z2)

In [None]:
""" Minimun Lp-norm 2D phase unwrapping """
# The initial guess must match the wrapped data's shape: the weighted
# Laplacian operators combine `phi` and `wrapped_data` element-wise.
phi_guess = np.ones(wrapped_data.shape)
original_shape = phi_guess.shape
l_max = 1

phi = phi_guess
for l in range(l_max): # outer loop

    #TODO: compute residues R(i,j)
    #TODO: correct phi according to the wrapped residues

    C = modified_Laplacian_weighted_wrapped_phase_gradient(phi,wrapped_data)
    Qphi = modified_Laplacian_phase_gradient(phi,wrapped_data)

    #TODO: solve Q*phi = c using Weighted Least Squares (WLS)
    residual  = C - Qphi
    # Solve P*z = r (step 2, Algorithm WLS, 'Minimum Lp-norm two-dimensional phase unwrapping' https://doi.org/10.1364/JOSAA.13.001999)
    # NOTE(review): unwrap_LS treats its argument as a wrapped phase and builds
    # its own Laplacian from it, whereas this step needs `residual` itself as
    # the Poisson right-hand side -- confirm / replace with a dedicated solver.
    z = unwrap_LS(residual).flatten()
    residual = residual.flatten()
    for k in range(1): # inner loop: Weighted Least Squares (WLS)
        if k == 0:
            r2 = residual
            z2 = z
            continue
        # NOTE(review): everything below is unreachable while the inner loop
        # runs a single iteration. `calculate_alpha` and `z1` are not defined
        # yet, and `p` is flattened while `phi` is 2-D -- finish before use.
        if k == 1:
            p = z2 # from k==0 conditional
            r1 = r2
        else:
            beta = calculate_beta(r1,z1,r2,z2)
            p = z1 + beta*p

        alpha = calculate_alpha()

        phi = phi + alpha*p

        residual = r1 - alpha*modified_Laplacian_phase_gradient(p,wrapped_data)

        r2 = r1 # save for next iterations
        r1 = residual

        if 0: # calculate error and compare to threshold
            break # if true, end loop
        else:
            z = unwrap_LS(np.reshape(residual,original_shape)).flatten() # solve P*z = r (step 2, Algorithm WLS, 'Minimum Lp-norm two-dimensional phase unwrapping' https://doi.org/10.1364/JOSAA.13.001999)
            z2 = z1
            z1 = z

