In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import hyperspy.api as hys
import tkinter.filedialog as tkf
from tabulate import tabulate
from sklearn.decomposition import NMF, PCA, FastICA, KernelPCA
#import malspy

In [None]:
def zero_one_rescale(spectrum):
    """
    Clip negative values of one spectrum to zero, then rescale it linearly
    so that its minimum becomes 0.0 and its maximum becomes 1.0.

    A spectrum that is constant after clipping is returned as all zeros
    (the division is skipped when the shifted maximum is zero).
    """
    non_negative = spectrum.clip(min=0.0)
    shifted = non_negative - np.min(non_negative)
    peak = np.max(shifted)

    return shifted if peak == 0 else shifted / peak

def reshape_coeff(coeffs, new_shape):
    """
    Split a stacked coefficient matrix back into one map per scan.

    Parameters
    ----------
    coeffs : ndarray, shape (n_total, n_features)
        Rows of all scans stacked in scan order.
    new_shape : array-like, shape (n_scans, 2)
        (rows, cols) of each scan; a list of pairs or a 2D array both work.

    Returns
    -------
    list of ndarray
        One array of shape (rows, cols, n_features) per scan, in the same
        order as `new_shape`.
    """
    coeffs = np.asarray(coeffs)
    scan_shapes = np.asarray(new_shape).reshape(-1, 2)

    coeff_reshape = []
    start = 0
    # walk through the rows with a running offset instead of deleting the
    # consumed rows, which would copy the whole remaining matrix every time
    for n_rows, n_cols in scan_shapes:
        n_rows, n_cols = int(n_rows), int(n_cols)
        stop = start + n_rows * n_cols
        coeff_reshape.append(np.reshape(coeffs[start:stop, :], (n_rows, n_cols, -1)))
        start = stop

    return coeff_reshape

def indices_at_r(shape, radius, center=None):
    """
    Find the pixels of a 2D grid that lie on a ring of a given radius,
    ordered by their azimuthal angle.

    Parameters
    ----------
    shape : tuple of int
        (rows, cols) of the grid.
    radius : number
        Ring radius; a pixel belongs to the ring when its distance to the
        center, rounded to the nearest integer, equals `radius`.
    center : sequence of two numbers or ndarray, optional
        (row, col) of the center. Defaults to the geometric middle of the grid.

    Returns
    -------
    r_sort : tuple (col_indices, row_indices)
        Ring pixel coordinates sorted by angle. Note the (x, y) order.
    a_sort : ndarray
        The matching angles in whole degrees, shifted by +180, in ascending order.
    """
    y, x = np.indices(shape)
    # `is None` rather than truthiness: an ndarray center has no single truth value
    if center is None:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    dy = y - center[0]
    dx = x - center[1]
    r = np.around(np.hypot(dy, dx))

    ri = np.where(r == radius)

    # angle of (dx + i*dy) in degrees, computed for the whole grid at once
    angle_arr = np.arctan2(dy, dx) * (180 / np.pi)
    angle_arr = np.around(angle_arr + 180)

    ring_angles = angle_arr[ri]
    ai = np.argsort(ring_angles)
    r_sort = (ri[1][ai], ri[0][ai])
    a_sort = ring_angles[ai]

    return r_sort, a_sort

def circle_flatten(f_stack, radial_range, c_pos):
    """
    Sample the last two axes of a 4D stack along concentric rings and
    return the samples as one flat feature axis.

    `radial_range` is (start, stop, step) for the ring radii, passed to
    `range`, so `stop` is exclusive. `c_pos` is the ring center handed to
    `indices_at_r`. The result has shape
    (f_stack.shape[0], f_stack.shape[1], n_ring_pixels), with rings stored
    one after another and each ring ordered by angle.
    """
    detector_shape = f_stack.shape[2:]
    ring_cols, ring_rows = [], []

    for radius in range(radial_range[0], radial_range[1], radial_range[2]):
        # indices_at_r returns (column indices, row indices); angles are not needed here
        ring_idx, _ = indices_at_r(detector_shape, radius, c_pos)
        ring_cols += ring_idx[0].tolist()
        ring_rows += ring_idx[1].tolist()

    return f_stack[:, :, np.asarray(ring_rows), np.asarray(ring_cols)]

def flattening(fdata, flat_option="box", crop_dist=None, c_pos=None):
    """
    Flatten the two detector axes (axes 2 and 3) of a 4D dataset into a
    single feature axis.

    Parameters
    ----------
    fdata : ndarray, 4D
        Data whose first two axes are kept and whose last two are flattened.
    flat_option : {"box", "radial"}
        "box" flattens a square window (or the whole pattern when
        `crop_dist` is falsy); "radial" samples concentric rings via
        `circle_flatten`.
    crop_dist : number or sequence of 3 ints, optional
        For "box": half-width of the window around `c_pos`.
        For "radial": (start, stop, step) of the ring radii.
    c_pos : sequence of two numbers, optional
        (row, col) center; required when cropping or flattening radially.

    Returns
    -------
    ndarray of shape (fdata.shape[0], fdata.shape[1], n_features), or None
    when `flat_option` is not recognized (a warning is printed).

    Side effect: when a box crop is applied, the log of the mean cropped
    pattern is displayed with matplotlib.
    """
    fdata_shape = fdata.shape
    if flat_option == "box":
        if crop_dist:
            box_size = np.array([crop_dist, crop_dist])

            # The window depends only on c_pos and crop_dist, so it is computed
            # once here rather than inside a loop over the notebook-global
            # `num_img`, which this function has no business reading.
            h_si = np.floor(c_pos[0]-box_size[0]).astype(int)
            h_fi = np.ceil(c_pos[0]+box_size[0]).astype(int)
            w_si = np.floor(c_pos[1]-box_size[1]).astype(int)
            w_fi = np.ceil(c_pos[1]+box_size[1]).astype(int)

            tmp = fdata[:, :, h_si:h_fi, w_si:w_fi]

            # quick visual check of the cropped region
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.imshow(np.log(np.mean(tmp, axis=(0, 1))), cmap="viridis")
            ax.axis("off")
            plt.show()

            tmp = tmp.reshape(fdata_shape[0], fdata_shape[1], -1)
            return tmp

        else:
            tmp = fdata.reshape(fdata_shape[0], fdata_shape[1], -1)
            return tmp

    elif flat_option == "radial":
        if len(crop_dist) != 3:
            print("Warning! 'crop_dist' must be a list containing 3 elements")

        tmp = circle_flatten(fdata, crop_dist, c_pos)
        return tmp

    else:
        print("Warning! Wrong option ('flat_option')")
        return

def fourd_roll_axis(stack):
    """
    Move axes 2 and 3 of `stack` to the front, i.e. reorder a 4D array
    from (a, b, c, d) to (c, d, a, b).
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

def radial_indices(shape, radial_range, center=None):
    """
    Build a 0/1 mask selecting an annulus (or a single ring) on a 2D grid.

    Parameters
    ----------
    shape : tuple of int
        (rows, cols) of the mask.
    radial_range : sequence of numbers
        With two distinct values (inner, outer): pixels with
        inner < r <= outer are kept. With a single distinct value: pixels
        whose rounded distance equals that (rounded) value are kept.
    center : sequence of two numbers or ndarray, optional
        (row, col) of the center. Defaults to the geometric middle of the grid.

    Returns
    -------
    ndarray of float, shape `shape`
        1.0 inside the selection, 0.0 elsewhere.
    """
    y, x = np.indices(shape)
    # `is None` rather than truthiness: an ndarray center has no single truth value
    if center is None:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        ri[(r <= radial_range[0]) | (r > radial_range[1])] = 0

    else:
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

def label_arrangement(label_arr, new_shape):
    """
    Reshape a flat clustering result (e.g. obtained by performing OPTICS)
    into per-scan label maps and per-label masks.

    Parameters
    ----------
    label_arr : ndarray, 1D
        One cluster label per pixel, all scans stacked in scan order.
    new_shape : array-like, shape (n_scans, 2)
        (rows, cols) of each scan, forwarded to `reshape_coeff`.

    Returns
    -------
    label_reshape : list of ndarray
        One squeezed label map per scan.
    selected : list of list of ndarray
        `selected[k][j]` is a mask (1 where scan j carries the k-th unique
        label, 0 elsewhere); labels are in ascending order.
    hist : ndarray of int
        Number of pixels carrying each unique label, in the same order.
    """
    # Count every distinct label directly. Equal-width histogram bins only
    # line up with the labels when they are evenly spaced, which is not
    # guaranteed for cluster labels.
    label_sort, hist = np.unique(label_arr, return_counts=True)

    label_reshape = [np.squeeze(label_map)
                     for label_map in reshape_coeff(label_arr.reshape(-1, 1), new_shape)]

    selected = []
    for label in label_sort:
        per_scan = []
        for label_map in label_reshape:
            img_temp = np.zeros_like(label_map)
            img_temp[label_map == label] = 1.0
            per_scan.append(img_temp)
        selected.append(per_scan)

    return label_reshape, selected, hist

In [None]:
# list of data file paths to analyze; re-running this cell empties the selection
file_adr = []

In [None]:
# open a file-selection dialog and append the chosen paths to the list
# (running this cell again adds more files instead of replacing them)
file_adr.extend(tkf.askopenfilenames())
print(len(file_adr))
print(*file_adr, sep="\n")

In [None]:
# number of selected files; every later per-dataset loop iterates over range(num_img)
num_img = len(file_adr)
print(num_img)

In [None]:
# load 4D-STEM data
# data_original[i] : 4D array of file i
# data_shape[i]    : its first two axis lengths (used later to rebuild the maps)
data_original = []
data_shape = []
for i in range(num_img):
    tmp = hys.load(file_adr[i]).data
    # DM3/DM4 files get their last two axes moved to the front;
    # the check is case-insensitive so that e.g. "scan.DM4" is handled too
    if file_adr[i].lower().endswith(("dm3", "dm4")):
        tmp = fourd_roll_axis(tmp)
    print(tmp.shape)
    data_shape.append(list(tmp.shape[:2]))
    data_original.append(tmp)

data_shape = np.asarray(data_shape)

In [None]:
# find the center position
# The center of each dataset is the intensity center of mass of its mean
# pattern, evaluated inside a central box of edge length cbox_edge.
center_pos = []
cbox_edge = 150
center_removed_ = False
for i in range(num_img):
    mean_dp = np.mean(data_original[i], axis=(0, 1))
    # margin between the pattern border and the central box; clamped at 0 so
    # that patterns not larger than cbox_edge use the whole pattern instead
    # of an empty slice (which would give a NaN center)
    cbox_outy = max(int(mean_dp.shape[0]/2 - cbox_edge/2), 0)
    cbox_outx = max(int(mean_dp.shape[1]/2 - cbox_edge/2), 0)
    # explicit end indices: a "-0" end index would select nothing
    center_box = mean_dp[cbox_outy:mean_dp.shape[0]-cbox_outy,
                         cbox_outx:mean_dp.shape[1]-cbox_outx]
    Y, X = np.indices(center_box.shape)
    com_y = np.sum(center_box * Y) / np.sum(center_box)
    com_x = np.sum(center_box * X) / np.sum(center_box)
    # shift back from box coordinates to full-pattern coordinates
    c_pos = [np.around(com_y+cbox_outy), np.around(com_x+cbox_outx)]
    center_pos.append(c_pos)
print(center_pos)

In [None]:
# show the log of each mean pattern with the detected center marked in red
np.seterr(divide='ignore')  # silence divide-by-zero warnings raised by log(0)
for i in range(num_img):
    cy, cx = center_pos[i][0], center_pos[i][1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(np.log(np.mean(data_original[i], axis=(0, 1))), cmap="viridis")
    ax.scatter(cx, cy, c="r", s=10)
    ax.axis("off")
    plt.show()

In [None]:
# get rid of the center beam (optional)
# Only the annulus center_radius < r <= outer_radius around the detected
# center is kept, so everything beyond outer_radius is zeroed as well.
center_removed_ = True
center_radius = 10
outer_radius = 100
data_cr = []
for i in range(num_img):
    # named ring_mask (not `ri`) so that re-running this cell cannot overwrite
    # the shuffle permutation `ri` used to un-shuffle the coefficients later
    ring_mask = radial_indices(data_original[i].shape[2:], [center_radius, outer_radius], center=center_pos[i])
    data_cr.append(np.multiply(data_original[i], ring_mask))

In [None]:
# show the masked mean patterns; nothing is drawn unless the center beam was removed
for i in range(num_img if center_removed_ else 0):
    cy, cx = center_pos[i][0], center_pos[i][1]
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    ax.imshow(np.log(np.mean(data_cr[i], axis=(0, 1))), cmap="viridis")
    ax.scatter(cx, cy, c="r", s=10)
    ax.axis("off")
    plt.show()

In [None]:
# 2D diffraction pattern -> 1D data
# option 1 : flatten a box
# (a square window of half-width w_size around each detected center)
radial_flat_ = False

w_size = 100
dataset = []
for i in range(num_img):
    # use the center-removed data when that optional step was run
    source = data_cr[i] if center_removed_ else data_original[i]
    flattened = flattening(source, flat_option="box", crop_dist=w_size, c_pos=center_pos[i])
    dataset.append(flattened)

# number of features per pattern: the window is (2*w_size) pixels on each side
s_length = (w_size*2)**2

In [None]:
# 2D diffraction pattern -> 1D data
# option 2 : flatten radially
# (sample rings with radii range(start, stop, step) around each detected center)
radial_flat_ = True

dataset = []
radial_range = [45, 60, 1]
k_indx, k_indy, a_ind = [], [], []

# Reference ring coordinates on a square canvas of side 2*stop centered at
# (stop, stop); the visualization cells use them to paint loading vectors
# back into 2D.
canvas_shape = (radial_range[1]*2, radial_range[1]*2)
canvas_center = (radial_range[1], radial_range[1])
for r in range(radial_range[0], radial_range[1], radial_range[2]):
    tmp_k, tmp_a = indices_at_r(canvas_shape, r, canvas_center)
    k_indx += tmp_k[0].tolist()
    k_indy += tmp_k[1].tolist()
    a_ind += tmp_a.tolist()

# number of features per pattern
# NOTE(review): assumes the rings on the real detector are not clipped by
# its edges, otherwise the per-pattern length differs from this value
s_length = len(k_indx)

for i in range(num_img):
    # use the center-removed data when that optional step was run
    source = data_cr[i] if center_removed_ else data_original[i]
    flattened = circle_flatten(source, radial_range, center_pos[i])
    dataset.append(flattened)

In [None]:
# create the input dataset
# Stack all patterns of all datasets into one (n_patterns, s_length) matrix
# and clip negative values to zero.
flat_blocks = []
for i in range(num_img):
    print(dataset[i].shape)
    flat_blocks.append(dataset[i].reshape(-1, s_length))

# concatenate whole blocks instead of collecting every row in a Python list
if flat_blocks:
    dataset_flat = np.concatenate(flat_blocks, axis=0).clip(min=0.0)
else:
    dataset_flat = np.empty((0, s_length))
print(dataset_flat.shape)

In [None]:
# convert values into log scale (optional)
# zeros are set to 1.0 first so that they map to log(1) = 0 instead of -inf
# NOTE: not idempotent - running this cell twice takes the log twice
dataset_flat[dataset_flat == 0.0] = 1.0
dataset_flat = np.log(dataset_flat)

In [None]:
# max-normalize each flattened diffraction pattern (optional)
row_max = np.max(dataset_flat, axis=1)[:, np.newaxis]
# Rows without a positive maximum (e.g. fully masked patterns) are divided
# by 1 instead: dividing by zero would put NaN/inf into the matrix, which
# the decompositions below reject.
row_max = np.where(row_max > 0, row_max, 1.0)
dataset_flat = (dataset_flat / row_max).clip(min=0.0)
print(dataset_flat.shape)

In [None]:
# shuffle the patterns before decomposition
# `ri` is the random permutation; it is kept so that the coefficient rows
# can be put back into scan order afterwards
total_num = len(dataset_flat)
ri = np.random.choice(total_num, total_num, replace=False)

dataset_input = dataset_flat[ri].astype(np.float32)

In [None]:
# number of components requested from the NMF / FastICA / malspy models and
# number of loading vectors and maps drawn by the visualization cells
num_comp = 4

In [None]:
# matplotlib backend: draw figures inside the notebook
%matplotlib inline

In [None]:
# matplotlib backend: draw figures in separate Qt windows
# NOTE(review): this runs right after the "inline" cell, so on "Run All" the
# Qt backend wins - run only the one you want
%matplotlib qt

In [None]:
# NMF decomposition (linear dimensionality reduction)
# please visit the below link for detailed information on NMF
# https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html?highlight=nmf#sklearn.decomposition.NMF

# The `alpha` keyword is no longer passed: it was deprecated in
# scikit-learn 1.0 and later removed, and the regularization strength
# defaults to zero anyway, so the fit is unchanged on older versions.
skl_nmf = NMF(n_components=num_comp, init="nndsvda", solver="mu", max_iter=2000,
              verbose=True, beta_loss="frobenius", l1_ratio=0.0)

# rows of skl_coeffs follow the shuffled order of dataset_input
skl_coeffs = skl_nmf.fit_transform(dataset_input)
skl_comp_vectors = skl_nmf.components_
print(skl_coeffs.shape)
print(skl_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# undo the shuffling: row k of the result belongs to original pattern ri[k]
coeffs = np.zeros_like(skl_coeffs)
coeffs[ri] = skl_coeffs  # the indexed assignment already copies the values
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# visualize loading vectors
if radial_flat_:
    # radial flattening: paint each vector back onto the ring coordinates
    # (`num_comp` replaces the undefined name `nmf_num_comp`)
    for i in range(num_comp):
        tmp = np.zeros((radial_range[1]*2, radial_range[1]*2))
        tmp[k_indy, k_indx] = skl_comp_vectors[i]

        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(tmp, cmap="viridis")
        ax.axis("off")
        fig.tight_layout()
        plt.show()

else:
    # box flattening: each vector is a flattened (2*w_size, 2*w_size) window
    for i in range(num_comp):
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(skl_comp_vectors[i].reshape((w_size*2, w_size*2)), cmap="viridis")
        ax.axis("off")
        fig.tight_layout()
        plt.show()

In [None]:
# visualize the coefficient maps
# one figure per component, one panel per dataset; squeeze=False always
# yields a 2D axes array, so a single dataset needs no separate branch
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, panel in enumerate(ax[0]):
        tmp = panel.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        panel.set_title("loading vector %d map"%(i+1), fontsize=10)
        panel.axis("off")
        #fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()

In [None]:
# PCA decomposition with 10 whitened components
# NOTE: the model is stored under the name `skl_nmf` although it is a PCA;
# the following cells read it under that name
skl_nmf = PCA(n_components=10, whiten=True, svd_solver="randomized")

# rows of skl_coeffs follow the shuffled order of dataset_input
skl_coeffs = skl_nmf.fit_transform(dataset_input)
skl_comp_vectors = skl_nmf.components_
for fitted in (skl_coeffs, skl_comp_vectors):
    print(fitted.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# undo the shuffling: row k of the result belongs to original pattern ri[k]
coeffs = np.zeros_like(skl_coeffs)
coeffs[ri] = skl_coeffs  # the indexed assignment already copies the values
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# visualize loading vectors
# (only the first num_comp components are drawn)
if radial_flat_:
    # radial flattening: paint each vector back onto the ring coordinates
    # (`num_comp` replaces the undefined name `nmf_num_comp`)
    for i in range(num_comp):
        tmp = np.zeros((radial_range[1]*2, radial_range[1]*2))
        tmp[k_indy, k_indx] = skl_comp_vectors[i]

        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(tmp, cmap="viridis")
        ax.axis("off")
        fig.tight_layout()
        plt.show()

else:
    # box flattening: each vector is a flattened (2*w_size, 2*w_size) window
    for i in range(num_comp):
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(skl_comp_vectors[i].reshape((w_size*2, w_size*2)), cmap="viridis")
        ax.axis("off")
        fig.tight_layout()
        plt.show()

In [None]:
# visualize the coefficient maps
# one figure per component, one panel per dataset; squeeze=False always
# yields a 2D axes array, so a single dataset needs no separate branch
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, panel in enumerate(ax[0]):
        tmp = panel.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        panel.set_title("loading vector %d map"%(i+1), fontsize=10)
        panel.axis("off")
        #fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()

In [None]:
# explained variance per component, its ratio, and the total ratio captured
# (`skl_nmf` holds the PCA model here)
variance_ratio = skl_nmf.explained_variance_ratio_
print(skl_nmf.explained_variance_)
print(variance_ratio)
print(np.sum(variance_ratio))

In [None]:
# cumulative explained-variance ratio: black line with red '+' markers
cumulative_ratio = np.cumsum(skl_nmf.explained_variance_ratio_)
fig, ax = plt.subplots(1, 1, figsize=(15, 8))
for line_style in ('k-', 'r+'):
    ax.plot(cumulative_ratio, line_style)
plt.show()

In [None]:
# FastICA decomposition
# NOTE: the model is stored under the name `skl_nmf` although it is a FastICA
skl_nmf = FastICA(n_components=num_comp)

# rows of skl_coeffs follow the shuffled order of dataset_input
skl_coeffs = skl_nmf.fit_transform(dataset_input)
skl_comp_vectors = skl_nmf.components_
for fitted in (skl_coeffs, skl_comp_vectors):
    print(fitted.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# undo the shuffling: row k of the result belongs to original pattern ri[k]
coeffs = np.zeros_like(skl_coeffs)
coeffs[ri] = skl_coeffs  # the indexed assignment already copies the values
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# visualize loading vectors
if radial_flat_:
    # radial flattening: paint each vector back onto the ring coordinates
    # (`num_comp` replaces the undefined name `nmf_num_comp`)
    for i in range(num_comp):
        tmp = np.zeros((radial_range[1]*2, radial_range[1]*2))
        tmp[k_indy, k_indx] = skl_comp_vectors[i]

        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(tmp, cmap="viridis")
        ax.axis("off")
        fig.tight_layout()
        plt.show()

else:
    # box flattening: each vector is a flattened (2*w_size, 2*w_size) window
    for i in range(num_comp):
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(skl_comp_vectors[i].reshape((w_size*2, w_size*2)), cmap="viridis")
        ax.axis("off")
        fig.tight_layout()
        plt.show()

In [None]:
# visualize the coefficient maps
# one figure per component, one panel per dataset; squeeze=False always
# yields a 2D axes array, so a single dataset needs no separate branch
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, panel in enumerate(ax[0]):
        tmp = panel.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        panel.set_title("loading vector %d map"%(i+1), fontsize=10)
        panel.axis("off")
        #fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()

In [None]:
# https://github.com/MotokiShiga/malspy
# NMF with automatic relevance determination and soft orthogonality penalty

# malspy is imported here because its import at the top of the notebook is
# commented out; without it this cell stops with a NameError
import malspy

model_nmf_ardso = malspy.NMF_ARD_SO(n_components=num_comp, wo=0.1, reps=5, max_itr=100)
model_nmf_ardso.fit(dataset_input)
model_nmf_ardso.plot_spectra(figsize=(6, 3), normalize=False)
ardso_coeffs = model_nmf_ardso.C_
print(ardso_coeffs.shape)
ardso_comp_vectors = model_nmf_ardso.S_
# swap the first two axes so that ardso_comp_vectors[i] is the i-th vector
ardso_comp_vectors = np.rollaxis(ardso_comp_vectors, 1, 0)
print(ardso_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# undo the shuffling: row k of the result belongs to original pattern ri[k]
coeffs = np.zeros_like(ardso_coeffs)
coeffs[ri] = ardso_coeffs  # the indexed assignment already copies the values
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# visualize loading vectors
# `box_size` only exists inside flattening(), so the 2D shape is rebuilt
# from the flattening settings, as in the other loading-vector cells
for i in range(num_comp):
    if radial_flat_:
        # radial flattening: paint the vector back onto the ring coordinates
        tmp = np.zeros((radial_range[1]*2, radial_range[1]*2))
        tmp[k_indy, k_indx] = ardso_comp_vectors[i]
    else:
        # box flattening: the vector is a flattened (2*w_size, 2*w_size) window
        tmp = ardso_comp_vectors[i].reshape((w_size*2, w_size*2))

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(tmp, cmap="viridis")
    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# visualize the coefficient maps
# one figure per component, one panel per dataset; squeeze=False always
# yields a 2D axes array, so a single dataset needs no separate branch
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, panel in enumerate(ax[0]):
        tmp = panel.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        panel.set_title("loading vector %d map"%(i+1), fontsize=10)
        panel.axis("off")
        #fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()

In [None]:
# https://github.com/MotokiShiga/malspy
# NMF with soft orthogonality penalty

# malspy is imported here because its import at the top of the notebook is
# commented out; without it this cell stops with a NameError
import malspy

model_nmf_so = malspy.NMF_SO(n_components=num_comp, wo=0.10, reps=5, max_itr=100)
model_nmf_so.fit(dataset_input)
so_coeffs = model_nmf_so.C_
print(so_coeffs.shape)
so_comp_vectors = model_nmf_so.S_
# swap the first two axes so that so_comp_vectors[i] is the i-th vector
so_comp_vectors = np.rollaxis(so_comp_vectors, 1, 0)
print(so_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
# undo the shuffling: row k of the result belongs to original pattern ri[k]
coeffs = np.zeros_like(so_coeffs)
coeffs[ri] = so_coeffs  # the indexed assignment already copies the values
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# visualize loading vectors
# `box_size` only exists inside flattening(), so the 2D shape is rebuilt
# from the flattening settings, as in the other loading-vector cells
for i in range(num_comp):
    if radial_flat_:
        # radial flattening: paint the vector back onto the ring coordinates
        tmp = np.zeros((radial_range[1]*2, radial_range[1]*2))
        tmp[k_indy, k_indx] = so_comp_vectors[i]
    else:
        # box flattening: the vector is a flattened (2*w_size, 2*w_size) window
        tmp = so_comp_vectors[i].reshape((w_size*2, w_size*2))

    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(tmp, cmap="viridis")
    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# visualize the coefficient maps
# global coefficient range; computed here but not used by the plots below
min_val, max_val = np.min(coeffs), np.max(coeffs)
# one figure per component, one panel per dataset; squeeze=False always
# yields a 2D axes array, so a single dataset needs no separate branch
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j, panel in enumerate(ax[0]):
        tmp = panel.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        panel.set_title("loading vector %d map"%(i+1), fontsize=10)
        panel.axis("off")
        #fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()