In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
import os
import time
import tkinter.filedialog as tkf

import ipywidgets as pyw
import matplotlib.patches as pch
import matplotlib.pyplot as plt
import numpy as np
import tifffile
from scipy import ndimage
from scipy import optimize
from scipy import stats
from skimage.metrics import mean_squared_error
from skimage.metrics import structural_similarity as ssim
plt.rcParams['font.family'] = 'Times New Roman'

In [None]:
# refer to "github.com/mkolopanis/python/blob/master/radialProfile.py"
def radial_stats(image, center=None, var=True):
    """Azimuthally average an image (and optionally its variance) around a centre.

    Every pixel is assigned to a one-pixel-wide ring according to its rounded
    distance from ``center``. Ring sums are obtained from cumulative sums of
    the radius-sorted pixel values.

    Parameters
    ----------
    image : ndarray
        2D image, indexed as ``[y, x]``.
    center : sequence of two numbers, optional
        ``(y, x)`` position of the ring centre. ``None`` (or an empty
        sequence) selects the geometric centre of the image.
    var : bool, optional
        If True (default) the variance profiles are returned as well.

    Returns
    -------
    r : ndarray
        Unrounded distance of every pixel from ``center`` (same shape as ``image``).
    radial_avg : ndarray
        Mean intensity per ring.
    radial_var : ndarray
        Only if ``var``: ring variance divided by the squared ring mean
        (rings whose mean is exactly zero are divided by 1 instead).
    raw_var : ndarray
        Only if ``var``: un-normalised ring variance, ``<I^2> - <I>^2``.

    Notes
    -----
    The cumulative-sum differences use the end of each ring as the baseline
    of the next one, so the innermost ring (rounded radius 0) and the
    outermost ring are not reported: entry ``k`` of a profile belongs to
    rounded radius ``k + 1`` (assuming every integer radius occurs).
    """
    image = np.asarray(image)
    y, x = np.indices(image.shape)
    # `not center` is ambiguous for NumPy arrays, so test for "no centre" explicitly.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])

    # Get sorted radii
    ind = np.argsort(r.flat)
    r_sorted = r.flat[ind]
    # Work in float so that squaring integer images (e.g. uint8/uint16) cannot wrap around.
    i_sorted = image.flat[ind].astype(float)

    # Get the integer part of the radii (bin size = 1)
    r_int = np.around(r_sorted)

    # Find all pixels that fall within each radial bin.
    deltar = r_int[1:] - r_int[:-1]  # Assumes all radii represented
    rind = np.where(deltar)[0]       # last sorted index of every ring but the outermost
    nr = rind[1:] - rind[:-1]        # number of pixels in each reported ring

    csim = np.cumsum(i_sorted, dtype=float)
    radial_avg = (csim[rind[1:]] - csim[rind[:-1]]) / nr

    if not var:
        return r, radial_avg

    sq_csim = np.cumsum(np.square(i_sorted), dtype=float)
    avg_square = np.square(radial_avg)
    square_avg = (sq_csim[rind[1:]] - sq_csim[rind[:-1]]) / nr
    raw_var = square_avg - avg_square

    # Avoid dividing by zero where the ring mean vanishes.
    mask = avg_square.copy()
    mask[avg_square == 0] = 1.0
    radial_var = raw_var / mask
    return r, radial_avg, radial_var, raw_var
    
def load_binary_4D_stack(img_adr, datatype, original_shape, final_shape, log_scale=False):
    """Read a raw binary file as a 4D stack and crop it.

    The file at ``img_adr`` is read with ``np.fromfile`` using ``datatype``,
    reshaped to ``original_shape`` and cropped to the leading
    ``final_shape[k]`` entries along each of the first four axes. With
    ``log_scale=True`` the natural logarithm of the cropped data is returned.
    The shape before and after cropping is printed.
    """
    raw = np.fromfile(img_adr, dtype=datatype).reshape(original_shape)
    print(raw.shape)

    # One slice per axis, equivalent to raw[:f0, :f1, :f2, :f3].
    crop = tuple(slice(None, final_shape[axis]) for axis in range(4))
    cropped = raw[crop]
    stack = np.log(cropped) if log_scale else cropped

    print(stack.shape)
    return stack

def fourd_roll_axis(stack):
    """Move axes 2 and 3 of ``stack`` to the front.

    An array of shape ``(a, b, c, d)`` becomes ``(c, d, a, b)``; the order of
    the remaining axes is kept.
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

def remove_center_beam(image, center=None, cb_rad=0):
    """Zero all pixels within ``cb_rad`` (rounded distance) of the centre.

    Parameters
    ----------
    image : ndarray
        2D image, indexed as ``[y, x]``. It is modified IN PLACE.
    center : sequence of two numbers, optional
        ``(y, x)`` position of the centre. ``None`` (or an empty sequence)
        selects the geometric centre of the image.
    cb_rad : number, optional
        Pixels whose distance from the centre, rounded to the nearest
        integer, is ``<= cb_rad`` are set to 0.

    Returns
    -------
    ndarray
        The same ``image`` object that was passed in.
    """
    y, x = np.indices(image.shape)
    # `not center` is ambiguous for NumPy arrays, so test for "no centre" explicitly.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.around(np.hypot(y - center[0], x - center[1]))
    image[r <= cb_rad] = 0

    return image

def radial_indices(shape, radial_range, center=None):
    """Build a 0/1 mask that selects an annulus or a single ring.

    Parameters
    ----------
    shape : tuple of int
        Shape ``(rows, cols)`` of the mask.
    radial_range : sequence of numbers
        If it contains more than one distinct value, pixels with
        ``radial_range[0] < r <= radial_range[1]`` are selected (``r`` is the
        unrounded distance from the centre). Otherwise the ring whose
        rounded distance equals ``round(radial_range[0])`` is selected.
    center : sequence of two numbers, optional
        ``(y, x)`` position of the centre. ``None`` (or an empty sequence)
        selects the geometric centre.

    Returns
    -------
    ndarray
        Float array of ``shape`` holding 1.0 for selected pixels and 0.0 elsewhere.
    """
    y, x = np.indices(shape)
    # `not center` is ambiguous for NumPy arrays, so test for "no centre" explicitly.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        # Annulus: open at the inner edge, closed at the outer edge.
        ri[r <= radial_range[0]] = 0
        ri[r > radial_range[1]] = 0
    else:
        # Single ring: compare rounded distances.
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

def indices_at_r(shape, radius, center=None):
    """Return the pixels on one ring, ordered by azimuthal angle.

    Parameters
    ----------
    shape : tuple of int
        Shape ``(rows, cols)`` of the image grid.
    radius : number
        Ring to select: pixels whose distance from the centre, rounded to the
        nearest integer, equals ``radius``.
    center : sequence of two numbers, optional
        ``(y, x)`` position of the centre. ``None`` (or an empty sequence)
        selects the geometric centre.

    Returns
    -------
    r_sort : tuple of two ndarrays
        ``(column indices, row indices)`` of the ring pixels, sorted by angle.
        NOTE(review): the order is (x, y), i.e. swapped with respect to NumPy
        indexing - confirm this is what the callers expect.
    a_sort : ndarray
        Sorted angles of the ring pixels in degrees, rounded to integers and
        shifted by +180 so that they lie in (0, 360].
    """
    y, x = np.indices(shape)
    # `not center` is ambiguous for NumPy arrays, so test for "no centre" explicitly.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    dy = y - center[0]
    dx = x - center[1]
    r = np.around(np.hypot(dy, dx))

    ri = np.where(r == radius)

    # Vectorised form of np.angle(complex(dx, dy), deg=True) for every pixel.
    angle_arr = np.arctan2(dy, dx) * (180 / np.pi)
    angle_arr = np.around(angle_arr + 180)

    ring_angles = angle_arr[ri]
    ai = np.argsort(ring_angles)
    r_sort = (ri[1][ai], ri[0][ai])
    a_sort = np.sort(ring_angles)

    return r_sort, a_sort

def local_var_similarity(var_map, w_size, stride):
    """Sliding-window statistics of a 2D map.

    A ``w_size`` x ``w_size`` window is moved over ``var_map`` in steps of
    ``stride``. Each window is flattened and divided by its maximum (or set
    to zero when the maximum is 0); then its mean, its standard deviation and
    the summed squared difference to the element at flat index
    ``int(w_size**2/2)``, divided by ``w_size**2 - 1``, are recorded.

    Returns ``(surr_avg, surr_std, surr_dif, new_shape)`` where the three
    arrays have shape ``new_shape`` (number of window positions per axis).
    """
    var_map = np.asarray(var_map)
    row_starts = range(0, var_map.shape[0] - w_size + 1, stride)
    col_starts = range(0, var_map.shape[1] - w_size + 1, stride)
    new_shape = (len(row_starts), len(col_starts))

    n_pix = w_size**2
    ref_idx = int(n_pix/2)

    def _window_stats(top, left):
        # Statistics of the single window whose upper-left corner is (top, left).
        window = var_map[top:top+w_size, left:left+w_size].flatten()
        peak = np.max(window)
        window = window / peak if peak != 0.0 else window * 0.0
        deviation = window - window[ref_idx]
        return (np.mean(window),
                np.std(window),
                np.sum(np.square(deviation))/(n_pix-1))

    per_window = [_window_stats(i, j) for i in row_starts for j in col_starts]

    surr_avg = np.asarray([entry[0] for entry in per_window]).reshape(new_shape)
    surr_std = np.asarray([entry[1] for entry in per_window]).reshape(new_shape)
    surr_dif = np.asarray([entry[2] for entry in per_window]).reshape(new_shape)

    return surr_avg, surr_std, surr_dif, new_shape

def local_DP_similarity(f_flat, w_size, stride):
    """Compare each window's reference pattern with its neighbours.

    A ``w_size`` x ``w_size`` window is moved over the first two axes of
    ``f_flat`` in steps of ``stride``. Inside a window every entry is
    flattened to 1D; the entry at flat index ``int(w_size**2/2)`` is the
    reference. The reference and every other entry are divided by their own
    maximum and compared with ``mean_squared_error`` and ``ssim``; the mean
    of each metric over the neighbours is stored per window.

    Returns ``(dp_mse, dp_ssim, new_shape)``.

    NOTE(review): an entry whose maximum is 0 is divided by zero here -
    confirm the input never contains empty patterns.
    """
    f_flat = np.asarray(f_flat)
    row_starts = range(0, f_flat.shape[0]-w_size+1, stride)
    col_starts = range(0, f_flat.shape[1]-w_size+1, stride)
    new_shape = (len(row_starts), len(col_starts))

    n_dp = w_size**2
    ref_idx = int(n_dp/2)

    dp_mse, dp_ssim = [], []
    for top in row_starts:
        for left in col_starts:
            window = f_flat[top:top+w_size, left:left+w_size].reshape(n_dp, -1)
            # The normalised reference is the same for every neighbour of this window.
            ref_norm = window[ref_idx] / np.max(window[ref_idx])
            neighbours = np.delete(window, ref_idx, axis=0)

            mse_vals = []
            ssim_vals = []
            for fdp in neighbours:
                fdp_norm = fdp / np.max(fdp)
                mse_vals.append(mean_squared_error(ref_norm, fdp_norm))
                ssim_vals.append(ssim(ref_norm, fdp_norm))

            dp_mse.append(np.mean(mse_vals))
            dp_ssim.append(np.mean(ssim_vals))

    dp_mse = np.asarray(dp_mse).reshape(new_shape)
    dp_ssim = np.asarray(dp_ssim).reshape(new_shape)

    return dp_mse, dp_ssim, new_shape

In [None]:
# Open a file-selection dialog (several files may be chosen); the selected paths
# are iterated over by the `for adr in raw_adr` processing cell.
raw_adr = tkf.askopenfilenames()
print(raw_adr)

In [None]:
# Batch-process every selected file: normalise the stack, locate the centre of
# the mean pattern, then compute and save the radial average and the radial
# variance profiles next to the input file.
# NOTE: every name assigned inside the loop stays in the kernel namespace, so
# the cells that follow operate on the LAST file in `raw_adr`.
for adr in raw_adr:

    # Load a data set (DM)
    stack_4d = tifffile.imread(adr)
    print(stack_4d.shape)
    f_shape = stack_4d.shape

    # np.max propagates NaN, so NaNs have to be replaced BEFORE normalising;
    # dividing by a NaN maximum would turn the whole stack into NaN (and then 0).
    stack_max = np.max(stack_4d)
    if np.isnan(stack_max):
        print("NaN exists")
        stack_4d = np.nan_to_num(stack_4d)
        stack_max = np.max(stack_4d)
    # An all-zero stack is left as it is instead of being divided by zero.
    stack_4d = stack_4d / stack_max if stack_max != 0 else stack_4d.astype(float)
    stack_4d = stack_4d.clip(min=0.0)
    print(np.max(stack_4d))
    print(np.min(stack_4d))
    print(np.mean(stack_4d))

    stack_4d_original = stack_4d.copy()

    # find the center position (center of mass) inside a box around the image centre
    mean_dp = np.mean(stack_4d, axis=(0, 1))
    cbox_edge = 15
    cbox_outy = max(int(mean_dp.shape[0]/2 - cbox_edge/2), 0)
    cbox_outx = max(int(mean_dp.shape[1]/2 - cbox_edge/2), 0)
    # Explicit end indices: a "-0" end index would give an empty box for small patterns.
    center_box = mean_dp[cbox_outy:mean_dp.shape[0]-cbox_outy,
                         cbox_outx:mean_dp.shape[1]-cbox_outx]
    Y, X = np.indices(center_box.shape)
    com_y = np.sum(center_box * Y) / np.sum(center_box)
    com_x = np.sum(center_box * X) / np.sum(center_box)
    c_pos_original = [com_y+cbox_outy, com_x+cbox_outx]
    c_pos = [com_y+cbox_outy, com_x+cbox_outx]
    print(c_pos)

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(mean_dp, cmap="gray")
    ax.axis("off")
    ax.scatter(c_pos[1], c_pos[0], s=15, c="r")
    # Outline the pixels that were actually used for the centre of mass
    # (imshow draws pixel k from k-0.5 to k+0.5).
    ax.add_patch(pch.Rectangle((cbox_outx-0.5, cbox_outy-0.5),
                               center_box.shape[1], center_box.shape[0],
                               linewidth=1, edgecolor="r", facecolor="none"))
    fig.tight_layout()
    plt.show()

    # radial average (intensity directly, not variance; RDF?) and radial
    # variance over all angles at a certain k, for every DP.
    # radial_stats(..., var=True) returns both profiles, so one pass is enough.
    radial_avg_stack = []
    radial_var_stack = []
    len_profile = []
    for i in range(f_shape[0]):
        for j in range(f_shape[1]):
            _, avg_temp, var_temp, _ = radial_stats(stack_4d[i, j], center=c_pos, var=True)
            len_profile.append(len(avg_temp))
            radial_avg_stack.append(avg_temp)
            radial_var_stack.append(var_temp)

    # Truncate to the shortest profile if the lengths ever differ.
    if len(np.unique(len_profile)) > 1:
        print(np.unique(len_profile))
        shortest = np.min(len_profile)
        radial_avg_stack = [profile[:shortest] for profile in radial_avg_stack]
        radial_var_stack = [profile[:shortest] for profile in radial_var_stack]

    radial_avg_stack = np.asarray(radial_avg_stack).reshape(f_shape[0], f_shape[1], -1)
    print(radial_avg_stack.shape)

    radial_avg_sum = np.sum(radial_avg_stack, axis=(0, 1))
    print(radial_avg_sum.shape)

    # Strip the real extension (".tif", ".tiff", ...) instead of assuming 4 characters.
    base_adr = os.path.splitext(adr)[0]

    # save (radial average, 3D)
    tifffile.imwrite(base_adr+"_radial_avg.tif", radial_avg_stack)

    radial_var_stack = np.asarray(radial_var_stack).reshape(f_shape[0], f_shape[1], -1)
    print(radial_var_stack.shape)

    radial_var_sum = np.sum(radial_var_stack, axis=(0, 1))
    print(radial_var_sum.shape)

    # save (radial variance, 3D)
    tifffile.imwrite(base_adr+"_radial_var.tif", radial_var_stack)

In [None]:
# Switch matplotlib to the interactive widget backend for the cells run after this one.
%matplotlib widget

In [None]:
# Switch matplotlib back to static inline figures for the cells run after this one.
%matplotlib inline

In [None]:
# maximum intensity distribution
# Maximum intensity of every pattern (reduction over the last two axes of
# stack_4d): summary numbers, map and histogram.
max_int = stack_4d.max(axis=(2, 3))
print(max_int.shape)
print(np.max(max_int), np.min(max_int), np.mean(max_int), np.median(max_int), sep="\n")

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# left panel: map of the maxima
ax[0].imshow(max_int, cmap="viridis")
ax[0].axis("off")
# right panel: histogram with one bin per row of the map
ax[1].hist(max_int.ravel(), bins=max_int.shape[0])
ax[1].grid()
fig.tight_layout()
plt.show()

In [None]:
# total intensity distribution
# total intensity distribution: sum of every pattern over the last two axes of
# stack_4d, shown as a map and as a histogram.
tot_int = np.sum(stack_4d, axis=(2, 3))
# Report and bin by tot_int itself; the copied cell referred to max_int, which
# made this cell fail unless the maximum-intensity cell had been run first.
print(tot_int.shape)
print(np.max(tot_int))
print(np.min(tot_int))
print(np.mean(tot_int))
print(np.median(tot_int))

fig, ax = plt.subplots(1, 2, figsize=(10, 5))
ax[0].imshow(tot_int, cmap="viridis")
ax[0].axis("off")
ax[1].hist(tot_int.flatten(), bins=len(tot_int))
ax[1].grid()
fig.tight_layout()
plt.show()