In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import hyperspy.api as hys
import tkinter.filedialog as tkf
from tabulate import tabulate
from sklearn.decomposition import NMF
import malspy

In [None]:
# Discrete colour scheme: one named colour per entry of the colour bar.
color_rep = [
    "black", "red", "green", "blue",
    "orange", "purple", "yellow", "lime",
    "cyan", "magenta", "lightgray", "peru",
    "springgreen", "deepskyblue", "hotpink", "darkgray",
]
print(len(color_rep))

custom_cmap = mcolors.ListedColormap(color_rep)
# Integer edges -1, 0, ..., len(color_rep)-1: one unit-wide bin per colour.
bounds = np.arange(-1, len(color_rep))
norm = mcolors.BoundaryNorm(bounds, len(color_rep))
# Mappable with an empty data array, so it can feed a colour bar without an image.
sm = cm.ScalarMappable(norm=norm, cmap=custom_cmap)
sm.set_array([])

# Names of built-in matplotlib colour maps.
cm_rep = [
    "gray",
    "Reds",
    "Greens",
    "Blues",
    "Oranges",
    "Purples",
]
print(len(cm_rep))

In [None]:
def zero_one_rescale(spectrum):
    """
    Rescale a single spectrum to the range 0.0 - 1.0.

    Negative values are first clipped to 0.0, the minimum of the clipped
    spectrum is subtracted, and the result is divided by its maximum.
    A spectrum that is constant after clipping is returned as all zeros
    (no division is performed).
    """
    non_negative = spectrum.clip(min=0.0)
    shifted = non_negative - np.min(non_negative)

    peak = np.max(shifted)
    if peak == 0:
        # flat spectrum: nothing to scale, avoid dividing by zero
        return shifted

    return shifted / peak

In [None]:
def indices_at_r(shape, radius, center=None):
    """
    Find the pixels on a ring of a given radius and order them by angle.

    Parameters
    ----------
    shape : tuple of int
        (rows, columns) of the 2D grid that is searched.
    radius : int or float
        Ring radius in pixels. A pixel belongs to the ring when its distance
        to the centre, rounded with ``np.around``, equals this value.
    center : sequence of two numbers or numpy.ndarray, optional
        (row, column) position of the ring centre. When it is None (or an
        empty sequence) the geometric centre of the grid is used.

    Returns
    -------
    r_sort : tuple of two numpy.ndarray
        ``(column indices, row indices)`` of the ring pixels, ordered by
        increasing angle. Note the (x, y) order of the tuple.
    a_sort : numpy.ndarray
        The corresponding angles in degrees, shifted by +180 and rounded to
        whole degrees, in ascending order.
    """
    y, x = np.indices(shape)
    # "is None" instead of truthiness: "not center" raises for a NumPy array.
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    dy = y - center[0]
    dx = x - center[1]

    r = np.around(np.hypot(dy, dx))
    ri = np.where(r == radius)

    # Vectorised form of np.angle(complex(dx, dy), deg=True) for every pixel:
    # arctan2 in radians, converted to degrees, then shifted and rounded.
    angle_arr = np.around(np.arctan2(dy, dx) * (180.0 / np.pi) + 180)

    ring_angles = angle_arr[ri]
    ai = np.argsort(ring_angles)
    r_sort = (ri[1][ai], ri[0][ai])
    a_sort = ring_angles[ai]

    return r_sort, a_sort

In [None]:
def circle_flatten(f_stack, radial_range, c_pos):
    """
    Extract the pixels on a set of concentric rings from every frame of a stack.

    Parameters
    ----------
    f_stack : numpy.ndarray
        Stack whose last two axes (``f_stack.shape[2:]``) form the 2D frame
        that the rings are drawn on; the first two axes are kept as they are.
    radial_range : sequence of three ints
        ``[start, stop, step]`` passed to ``range`` to generate the ring radii.
    c_pos : sequence of two numbers
        (row, column) ring centre handed to ``indices_at_r``.

    Returns
    -------
    numpy.ndarray
        ``f_stack[:, :, rows, cols]`` for all ring pixels, rings concatenated
        in order of increasing radius. The shape is also printed.
    """
    frame_shape = f_stack.shape[2:]
    ring_pixels = [
        indices_at_r(frame_shape, radius, c_pos)[0]
        for radius in range(radial_range[0], radial_range[1], radial_range[2])
    ]

    # indices_at_r returns (column indices, row indices) for each ring
    k_indx = np.asarray([col for ring in ring_pixels for col in ring[0].tolist()])
    k_indy = np.asarray([row for ring in ring_pixels for row in ring[1].tolist()])

    flat_data = f_stack[:, :, k_indy, k_indx]
    print(flat_data.shape)

    return flat_data

In [None]:
def reshape_coeff(coeffs, new_shape):
    """
    Reshape a coefficient matrix to restore the original scanning shapes.

    Parameters
    ----------
    coeffs : numpy.ndarray
        2D matrix of shape (total number of scan positions, n_components).
        The rows of all images are stacked one image after another.
    new_shape : array-like
        One ``(rows, columns)`` pair per image, e.g. an (n_images, 2) array
        or a list of pairs. A single pair is also accepted.

    Returns
    -------
    list of numpy.ndarray
        One array of shape ``(rows, columns, n_components)`` per image, in the
        order given by ``new_shape``. The arrays may share memory with
        ``coeffs``. Rows of ``coeffs`` beyond the last image are ignored.
    """
    shapes = np.asarray(new_shape)
    if shapes.size == 0:
        return []
    shapes = np.atleast_2d(shapes)

    coeff_reshape = []
    start = 0
    for n_rows, n_cols in shapes[:, :2]:
        n_rows, n_cols = int(n_rows), int(n_cols)
        # Walk through the matrix with a running offset instead of deleting
        # the consumed rows, which would copy the remainder on every pass.
        stop = start + n_rows * n_cols
        coeff_reshape.append(np.reshape(coeffs[start:stop, :], (n_rows, n_cols, -1)))
        start = stop

    return coeff_reshape

In [None]:
def fourd_roll_axis(stack):
    """
    Bring axes 2 and 3 of ``stack`` to the front (positions 0 and 1).

    For a 4D array of shape (a, b, c, d) the result has shape (c, d, a, b).
    The relative order of the remaining axes is unchanged.
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

In [None]:
# List of selected file paths. It is created empty here and filled by the
# file-dialog cell; running this cell again clears the selection.
file_adr = []

In [None]:
# Open a file-selection dialog and append every chosen path to file_adr.
# The list is not cleared here, so re-running this cell adds to the
# existing selection.
file_adr.extend(tkf.askopenfilenames())
print(len(file_adr))
print(*file_adr, sep="\n")

In [None]:
# Number of selected files; used as the loop bound by the cells that follow.
num_img = len(file_adr)
print(num_img)

In [None]:
# Load every selected file, bring axes 2 and 3 to the front with
# fourd_roll_axis, and remember the size of the first two axes of each stack.
# NOTE(review): after the reordering the first two axes are presumably the
# scan dimensions - confirm against the file format.
data_original = []
scan_dims = []
for i in range(num_img):
    stack = fourd_roll_axis(hys.load(file_adr[i]).data)
    print(stack.shape)
    data_original.append(stack)
    scan_dims.append(list(stack.shape[:2]))

# one (rows, columns) pair per loaded stack
data_shape = np.asarray(scan_dims)

In [None]:
# Estimate the pattern centre of every stack as the centre of mass of the
# mean pattern inside a central box of roughly cbox_edge x cbox_edge pixels.
center_pos = []
cbox_edge = 100
for i in range(num_img):
    # average over the first two axes -> one mean 2D pattern per stack
    mean_dp = np.mean(data_original[i], axis=(0, 1))
    n_rows, n_cols = mean_dp.shape[0], mean_dp.shape[1]

    # Margin between the pattern edge and the central box. Clamp at zero so
    # that patterns not larger than cbox_edge use the whole frame.
    cbox_outy = max(int(n_rows/2 - cbox_edge/2), 0)
    cbox_outx = max(int(n_cols/2 - cbox_edge/2), 0)
    # Explicit end indices: "a:-a" would give an empty slice for a == 0.
    center_box = mean_dp[cbox_outy:n_rows-cbox_outy, cbox_outx:n_cols-cbox_outx]

    Y, X = np.indices(center_box.shape)
    box_sum = np.sum(center_box)
    if box_sum == 0 or not np.isfinite(box_sum):
        # A zero or non-finite total would give a NaN centre that only fails
        # much later, when no ring pixel can be found.
        raise ValueError(
            "cannot compute the centre of mass of image %d: "
            "total intensity in the central box is %r" % (i, box_sum)
        )
    com_y = np.sum(center_box * Y) / box_sum
    com_x = np.sum(center_box * X) / box_sum

    # shift back from box coordinates to full-pattern coordinates
    c_pos = [np.around(com_y+cbox_outy), np.around(com_x+cbox_outx)]
    center_pos.append(c_pos)
print(center_pos)

In [None]:
# [start, stop, step] of the ring radii, as passed to range()
radial_range = [0, 30, 1]
r_start, r_stop, r_step = radial_range

# Template grid of side 2*stop with the centre at (stop, stop); the ring
# pixels found here are used to place flattened vectors back into an image.
template_shape = (2 * r_stop, 2 * r_stop)
template_center = (r_stop, r_stop)

k_indx, k_indy, a_ind = [], [], []
for r in range(r_start, r_stop, r_step):
    # indices_at_r returns ((column indices, row indices), angles)
    (ring_x, ring_y), ring_angles = indices_at_r(template_shape, r, template_center)
    k_indx += ring_x.tolist()
    k_indy += ring_y.tolist()
    a_ind += ring_angles.tolist()

for index_list in (k_indx, k_indy, a_ind):
    print(len(index_list))

# number of ring pixels = length of one flattened pattern
s_length = len(k_indx)

In [None]:
# Flatten every stack onto the ring pixels and stack all scan positions of
# all images into one 2D array (one row per scan position).
# The arrays are concatenated directly instead of going through nested Python
# lists, which is far slower and needs much more memory. Flattening each image
# first also allows images with different scan sizes, matching the per-image
# shapes that reshape_coeff restores later.
flattened = []
for i in range(num_img):
    tmp = circle_flatten(data_original[i], radial_range, center_pos[i])
    # (scan rows, scan columns, ring pixels) -> (scan positions, ring pixels)
    flattened.append(tmp.reshape(-1, tmp.shape[-1]))

# float64 matches what the former list round trip produced for floating-point data
dataset = np.concatenate(flattened, axis=0).astype(np.float64, copy=False)
print(dataset.shape)

In [None]:
# create the input dataset
# Collapse all leading axes so that every row holds one flattened pattern of
# s_length values.
dataset_flat = dataset.reshape(-1, s_length)
print(dataset_flat.shape)

In [None]:
total_num = len(dataset_flat)

# Shuffle the rows before the decomposition. The seed is fixed so that the
# row order, and with it the NMF results, can be reproduced after a kernel
# restart.
shuffle_seed = 0
rng = np.random.default_rng(shuffle_seed)
# ri[k] is the original row index of shuffled row k; it is needed later to
# put the coefficients back into scan order.
ri = rng.permutation(total_num)

dataset_input = dataset_flat[ri]

In [None]:
# Number of components requested from each NMF model below; it is also the
# number of component images and coefficient maps that get plotted.
num_comp = 5

In [None]:
# Render figures inline in the notebook output.
%matplotlib inline

In [None]:
# Switch to the Qt backend (figures open in separate interactive windows).
# NOTE(review): with "Run All" this cell overrides the inline backend chosen
# in the cell before; presumably only one of the two is meant to be run - confirm.
%matplotlib qt

In [None]:
# https://github.com/MotokiShiga/malspy
# NMF with automatic relevance determination and soft orthogonality penalty
ardso_settings = dict(n_components=num_comp, wo=0.1, reps=5, max_itr=100)
model_nmf_ardso = malspy.NMF_ARD_SO(**ardso_settings)
model_nmf_ardso.fit(dataset_input)
model_nmf_ardso.plot_spectra(figsize=(6, 3), normalize=False)

# coefficient matrix of the fitted model (rows follow the shuffled input)
ardso_coeffs = model_nmf_ardso.C_
print(ardso_coeffs.shape)

# component matrix with its two axes swapped, so that indexing the first
# axis selects one component vector
ardso_comp_vectors = np.rollaxis(model_nmf_ardso.S_, 1, 0)
print(ardso_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# Row k of the fitted coefficients belongs to original row ri[k], so writing
# through ri restores the scan order before cutting into per-image maps.
coeffs = np.zeros_like(ardso_coeffs)
coeffs[ri] = ardso_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# Show every ARD-SO component vector as an image: the values are written
# back onto the ring pixels of the (2*stop, 2*stop) template grid.
# NOTE(review): assumes the model returned at least num_comp component vectors - confirm.
side = 2 * radial_range[1]
for i in range(num_comp):
    comp_img = np.zeros((side, side))
    comp_img[k_indy, k_indx] = ardso_comp_vectors[i]

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(comp_img, cmap="viridis")
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
# visualize the coefficient maps
# One figure per component, with one panel per loaded image.
for i in range(num_comp):
    # squeeze=False always returns a 2D array of Axes, so a single image
    # needs no separate branch.
    fig, axes = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, ax in enumerate(axes[0]):
        ax.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        ax.set_title("loading vector %d map"%(i+1), fontsize=10)
        ax.axis("off")
    plt.show()

In [None]:
# https://github.com/MotokiShiga/malspy
# NMF with soft orthogonality penalty
so_settings = dict(n_components=num_comp, wo=0.10, reps=5, max_itr=100)
model_nmf_so = malspy.NMF_SO(**so_settings)
model_nmf_so.fit(dataset_input)

# coefficient matrix of the fitted model (rows follow the shuffled input)
so_coeffs = model_nmf_so.C_
print(so_coeffs.shape)

# component matrix with its two axes swapped, so that indexing the first
# axis selects one component vector
so_comp_vectors = np.rollaxis(model_nmf_so.S_, 1, 0)
print(so_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# Row k of the fitted coefficients belongs to original row ri[k], so writing
# through ri restores the scan order before cutting into per-image maps.
coeffs = np.zeros_like(so_coeffs)
coeffs[ri] = so_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# Show every SO component vector as an image: the values are written back
# onto the ring pixels of the (2*stop, 2*stop) template grid.
side = 2 * radial_range[1]
for i in range(num_comp):
    comp_img = np.zeros((side, side))
    comp_img[k_indy, k_indx] = so_comp_vectors[i]

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(comp_img, cmap="viridis")
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
# visualize the coefficient maps
# One figure per component, with one panel per loaded image.
for i in range(num_comp):
    # squeeze=False always returns a 2D array of Axes, so a single image
    # needs no separate branch.
    fig, axes = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, ax in enumerate(axes[0]):
        ax.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        ax.set_title("loading vector %d map"%(i+1), fontsize=10)
        ax.axis("off")
    plt.show()

In [None]:
# NMF decomposition (linear dimensionality reduction)
# please visit the below link for detailed information on NMF
# https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html?highlight=nmf#sklearn.decomposition.NMF

# The former "alpha=0.0" argument is left out: the keyword was removed from
# scikit-learn (replaced by alpha_W / alpha_H) and raises a TypeError there.
# Regularisation is off by default, so the model is unchanged.
# random_state is pinned so that any randomised step is repeatable.
skl_nmf = NMF(n_components=num_comp, init="nndsvda", solver="mu", max_iter=2000,
              verbose=True, beta_loss="frobenius", l1_ratio=0.0, random_state=0)

# coefficients: (n_samples, n_components); components: (n_components, n_features)
skl_coeffs = skl_nmf.fit_transform(dataset_input)
skl_comp_vectors = skl_nmf.components_
print(skl_coeffs.shape)
print(skl_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# Row k of the fitted coefficients belongs to original row ri[k], so writing
# through ri restores the scan order before cutting into per-image maps.
coeffs = np.zeros_like(skl_coeffs)
coeffs[ri] = skl_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# Show every scikit-learn NMF component vector as an image: the values are
# written back onto the ring pixels of the (2*stop, 2*stop) template grid.
side = 2 * radial_range[1]
for i in range(num_comp):
    comp_img = np.zeros((side, side))
    comp_img[k_indy, k_indx] = skl_comp_vectors[i]

    fig, ax = plt.subplots(figsize=(5, 5))
    ax.imshow(comp_img, cmap="viridis")
    ax.set_axis_off()
    fig.tight_layout()
    plt.show()

In [None]:
# visualize the coefficient maps
# One figure per component, with one panel per loaded image.
for i in range(num_comp):
    # squeeze=False always returns a 2D array of Axes, so a single image
    # needs no separate branch.
    fig, axes = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, ax in enumerate(axes[0]):
        ax.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        ax.set_title("loading vector %d map"%(i+1), fontsize=10)
        ax.axis("off")
    plt.show()