## dimensionality reduction and unsupervised clustering for EELS-SI
Miyoung Kim (corresponding author, mkim@snu.ac.kr)<br>
Jinseok Ryu (jinseuk56@gmail.com), Hyeohn Kim, Ryeong Myeong Kim, Sungtae Kim, Ki Tae Nam, Young-Chang Joo<br>
Dept. of Materials Science and Engineering, Seoul National University<br>
https://doi.org/10.1016/j.ultramic.2021.113314<br>
### last update 20210719

In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.cm as cm
import matplotlib.colors as mcolors
from matplotlib.colors import ListedColormap
import hyperspy.api as hys
from sklearn.decomposition import NMF
from sklearn.manifold import TSNE
from sklearn.cluster import OPTICS
import tkinter.filedialog as tkf
import tifffile
import ipywidgets as pyw
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
# create a customized colorbar
# color_rep[k] is used for cluster label k-1: index 0 (black) for OPTICS noise
# (label -1), index 1 (red) for label 0, and so on
color_rep = ["black", "red", "green", "blue", "orange", "purple", "yellow", "lime", 
             "cyan", "magenta", "lightgray", "peru", "springgreen", "deepskyblue", 
             "hotpink", "darkgray"]
print(len(color_rep))
custom_cmap = mcolors.ListedColormap(color_rep)
# integer boundaries -1, 0, ..., len(color_rep)-1 give one color bin per
# integer label, so labels -1 .. len(color_rep)-2 each get their own color
bounds = np.arange(-1, len(color_rep))
norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=len(color_rep))
# mappable without data, usable for a standalone colorbar (fig.colorbar(sm))
sm = cm.ScalarMappable(cmap=custom_cmap, norm=norm)
sm.set_array([])

# names of sequential matplotlib colormaps (not referenced in the cells shown here)
cm_rep = ["gray", "Reds", "Greens", "Blues", "Oranges", "Purples"]
print(len(cm_rep))

In [None]:
def zero_one_rescale(spectrum):
    """
    Min-max normalize one spectrum onto the range 0.0 .. 1.0.

    Negative intensities are clipped to 0.0 first.  The clipped spectrum is
    shifted so that its minimum becomes 0 and is then divided by its new
    maximum.  When that maximum is 0 (a flat spectrum) the shifted spectrum,
    i.e. all zeros, is returned without dividing.
    """
    clipped = spectrum.clip(min=0.0)
    shifted = clipped - np.min(clipped)
    peak = np.max(shifted)
    if peak == 0:
        return shifted
    return shifted / peak

In [None]:
def threed_roll_axis(img):
    """
    Move axis 2 of `img` to the front, e.g. (A, B, C) -> (C, A, B).

    The remaining axes keep their relative order; the result is a view.
    """
    return np.moveaxis(img, 2, 0)

In [None]:
def data_load(adr, rescale=False, crop=None, roll_axis=False):
    """
    load spectrum images with hyperspy

    Parameters
    ----------
    adr : iterable of str
        file paths passed to ``hys.load``
    rescale : bool
        if True, every spectrum ``temp[j, k]`` is min-max normalized with
        ``zero_one_rescale``; non-float data is converted to float64 first so
        the normalized values are not truncated to 0/1
    crop : sequence or None
        if given, ``.isig[crop[0]:crop[1]]`` is applied to the loaded signal
        before its raw array is taken (``crop[2]``, if present, is not used here)
    roll_axis : bool
        if True, axis 2 of the raw array is moved to the front with
        ``threed_roll_axis`` -- applied with or without ``crop``

    Returns
    -------
    storage : list of ndarray
        one raw data array per file
    shape : ndarray
        the ``.shape`` of every array, one row per file
    """
    storage = []
    shape = []
    # `path`, not `adr`: the old loop variable shadowed the parameter
    for path in adr:
        signal = hys.load(path)
        if crop:
            print(signal.axes_manager[2])
            temp = signal.isig[crop[0]:crop[1]].data
        else:
            temp = signal.data
        # previously skipped whenever `crop` was given
        if roll_axis:
            temp = threed_roll_axis(temp)
        print(temp.shape)

        if rescale:
            # zero_one_rescale returns floats; assigning them back into an
            # integer array would truncate every value to 0 or 1
            if not np.issubdtype(temp.dtype, np.floating):
                temp = temp.astype(np.float64)
            for j in range(temp.shape[0]):
                for k in range(temp.shape[1]):
                    temp[j, k] = zero_one_rescale(temp[j, k])
        shape.append(temp.shape)
        storage.append(temp)

    shape = np.asarray(shape)
    return storage, shape

In [None]:
def binning_SI(si, bin_y, bin_x, str_y, str_x, offset, depth, rescale=True):
    """
    re-bin a spectrum image by averaging the spectra inside a sliding window

    Parameters
    ----------
    si : array-like, indexed as si[row, column, channel]
    bin_y, bin_x : window height and width in pixels
    str_y, str_x : step between neighbouring windows along rows / columns
        (windows overlap when the stride is smaller than the bin size)
    offset : index of the first spectral channel kept
    depth : number of spectral channels kept, starting at ``offset``
    rescale : if True, every averaged spectrum is passed through
        ``zero_one_rescale``

    Returns
    -------
    ndarray of shape (n_rows, n_cols, depth) with one averaged spectrum per
    window; only windows lying completely inside the image are used

    Raises
    ------
    ValueError
        if ``offset + depth`` exceeds the number of spectral channels
    """
    si = np.asarray(si)
    n_channels = si.shape[2]
    if offset + depth > n_channels:
        # the channel slice would come out shorter than `depth` and the final
        # reshape would fail with an unrelated-looking message
        raise ValueError(
            "offset + depth = %d exceeds the %d spectral channels of the input"
            % (offset + depth, n_channels))

    rows = range(0, si.shape[0]-bin_y+1, str_y)
    cols = range(0, si.shape[1]-bin_x+1, str_x)
    channels = slice(offset, offset+depth)

    # row-major window order, matching the final reshape
    binned = [
        np.mean(si[i:i+bin_y, j:j+bin_x, channels], axis=(0, 1))
        for i in rows
        for j in cols
    ]
    if rescale:
        binned = [zero_one_rescale(spectrum) for spectrum in binned]

    return np.asarray(binned).reshape(len(rows), len(cols), depth)

In [None]:
def reshape_coeff(coeffs, new_shape):
    """
    reshape a coefficient matrix to restore the original scanning shapes.

    Parameters
    ----------
    coeffs : array of shape (n_pixels, ...) holding the rows of all images
        concatenated in image order
    new_shape : one row per image whose first two entries are
        (height, width); further entries are ignored

    Returns
    -------
    list with one array of shape (height, width, -1) per image; the arrays
    are views into `coeffs`, and rows beyond the last image are ignored

    Raises
    ------
    ValueError
        if `coeffs` has fewer rows than the images need in total
    """
    coeffs = np.asarray(coeffs)
    coeff_reshape = []
    # walk through the rows with a running offset instead of repeatedly
    # np.delete-ing the consumed rows (which copied the rest every time)
    start = 0
    for row in np.asarray(new_shape):
        height, width = int(row[0]), int(row[1])
        stop = start + height * width
        if stop > coeffs.shape[0]:
            raise ValueError(
                "coefficient matrix has %d rows, but the images need at least %d"
                % (coeffs.shape[0], stop))
        coeff_reshape.append(coeffs[start:stop].reshape(height, width, -1))
        start = stop

    return coeff_reshape

In [None]:
def label_arrangement(label_arr, new_shape):
    """
    reshape a clustering result obtained by performing OPTICS

    Parameters
    ----------
    label_arr : 1-D array with one cluster label per pixel, the pixels of all
        images concatenated in image order
    new_shape : one row per image starting with (height, width)

    Returns
    -------
    label_reshape : list of 2-D label maps, one per image
    selected : selected[i][j] is a 0/1 map (same dtype as the labels) of the
        pixels of image j carrying the i-th label in sorted order
    hist : number of pixels carrying each label, in the same sorted order
    """
    # np.unique gives the sorted labels and their exact counts; an equal-width
    # np.histogram with len(labels) bins only agrees with this when the labels
    # are consecutive integers
    label_sort, hist = np.unique(label_arr, return_counts=True)
    label_reshape = reshape_coeff(np.asarray(label_arr).reshape(-1, 1), new_shape)

    # drop only the trailing singleton axis; np.squeeze would also remove a
    # height or width of 1
    label_reshape = [lr[:, :, 0] for lr in label_reshape]

    selected = [
        [(lr == label).astype(lr.dtype) for lr in label_reshape]
        for label in label_sort
    ]

    return label_reshape, selected, hist

### load data

In [None]:
# list of selected spectrum-image file paths (the next cell appends to it)
file_adr = []

In [None]:
# open a file dialog and append every selected path; re-run to add more files
file_adr.extend(tkf.askopenfilenames())
print(len(file_adr))
print(*file_adr, sep="\n")

In [None]:
# number of spectrum images to process
num_img = len(file_adr)
print(num_img)

In [None]:
# load spectrum images
# cr_range = [start, end, step]: start/end are handed to hyperspy's .isig
# (presumably in the calibrated signal units, eV -- confirm), step is only used
# to build the matching energy axis below
# wide window -> "reference" data, used for the full-range spectra at the end
cr_range = [0.1, 5.0, 0.01] # reference
data_original, shape_original = data_load(file_adr, rescale=False, crop=cr_range, roll_axis=False)
print(len(data_original))
print(shape_original)

# energy axis of the reference data
# NOTE(review): np.arange with a float step can be one element longer/shorter
# than the channel count .isig returns; check len(e_range_original) against
# shape_original[:, 2] (it is used as the binning depth)
e_range_original = np.arange(cr_range[0], cr_range[1], cr_range[2])
print(len(e_range_original))

# load spectrum images
# narrower window -> the data actually fed to NMF / t-SNE / OPTICS
cr_range = [0.5, 3.5, 0.01] # actual input
data_storage, data_shape = data_load(file_adr, rescale=False, crop=cr_range, roll_axis=False)
print(len(data_storage))
print(data_shape)

# energy axis of the actual input (same NOTE(review) as above applies)
e_range = np.arange(cr_range[0], cr_range[1], cr_range[2])
print(len(e_range))

In [None]:
# re-bin spectrum images in order to mitigate noise
# (with all four values equal to 1 every pixel is kept as it is)
bin_y = 1 # binning size (height)
bin_x = 1 # binning size (width)
str_y = 1 # stride height-direction
str_x = 1 # stride width-direction

# reference: binned like the input but NOT rescaled, so its spectra keep
# their raw intensities
dataset_original = []
shape_new_original = []

offset = 0
depth_original = len(e_range_original)
for img in data_original:
    print(img.shape)
    processed = binning_SI(img, bin_y, bin_x, str_y, str_x, offset, depth_original, rescale=False)
    print(processed.shape)
    shape_new_original.append(processed.shape)
    dataset_original.append(processed)
    
shape_new_original = np.asarray(shape_new_original)
print(shape_new_original)

# actual input: every binned spectrum is min-max rescaled to 0..1
dataset = []
data_shape_new = []
offset = 0
depth = len(e_range)
for img in data_storage:
    print(img.shape)
    processed = binning_SI(img, bin_y, bin_x, str_y, str_x, offset, depth, rescale=True) # include the step for re-scaling the actual input
    print(processed.shape)
    data_shape_new.append(processed.shape)
    dataset.append(processed)
    
# one (rows, cols, depth) row per image; used to restore maps from flat rows
data_shape_new = np.asarray(data_shape_new)
print(data_shape_new)

In [None]:
# flatten every binned spectrum image into a (pixels, channels) matrix and
# stack the images in file order; negative intensities are clipped to 0.
# np.concatenate replaces extend(...tolist()), which built a Python list of
# every single intensity; .astype(np.float64) keeps the dtype that list-based
# version produced
dataset_original_flat = np.concatenate(
    [dataset_original[i].clip(min=0.0).reshape(-1, depth_original) for i in range(num_img)]
).astype(np.float64)
print(dataset_original_flat.shape)

# create the input dataset
dataset_flat = np.concatenate(
    [dataset[i].clip(min=0.0).reshape(-1, depth) for i in range(num_img)]
).astype(np.float64)
print(dataset_flat.shape)

In [None]:
# shuffle the pixel (row) order before the decomposition; row k of
# dataset_input is row ri[k] of dataset_flat, which later cells use to undo
# the shuffle
total_num = len(dataset_flat)
ri = np.random.permutation(total_num)
dataset_input = dataset_flat[ri]

### dimensionality reduction (NMF)

In [None]:
%matplotlib inline

In [None]:
%matplotlib qt

In [None]:
# number of NMF components (loading vectors) to extract
nmf_num_comp = 5

In [None]:
# NMF decomposition (linear dimensionality reduction)
# please visit the below link for detailed information on NMF
# https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.NMF.html?highlight=nmf#sklearn.decomposition.NMF
# dataset_input (shuffled spectra) ~= skl_coeffs @ skl_comp_vectors
# `alpha` was deprecated in scikit-learn 1.0 and removed in 1.2 (passing it
# raises TypeError there); the default already applies no regularization,
# so leaving it out does not change the model
skl_nmf = NMF(n_components=nmf_num_comp, init="nndsvda", solver="mu", max_iter=2000, 
              verbose=True, beta_loss="frobenius", l1_ratio=0.0)

skl_coeffs = skl_nmf.fit_transform(dataset_input)
skl_comp_vectors = skl_nmf.components_
print(skl_coeffs.shape)
print(skl_comp_vectors.shape)

In [None]:
# convert the coefficient matrix into coefficient maps
num_comp = nmf_num_comp
# row k of skl_coeffs belongs to pixel ri[k]; indexing with the inverse
# permutation (argsort of ri) puts the rows back into pixel order
coeffs = skl_coeffs[np.argsort(ri)]
coeffs_reshape = reshape_coeff(coeffs, data_shape_new)

In [None]:
# visualize loading vectors
# left: all loading vectors; right: the subset listed in sel_nmf_comp
# loading vector i+1 is drawn with color_rep[i+1]

fig, ax = plt.subplots(1, 2, figsize=(10, 4)) # all loading vectors
for i in range(nmf_num_comp):
    ax[0].plot(e_range, skl_comp_vectors[i], "-", c=color_rep[i+1], label="loading vector %d"%(i+1))
ax[0].grid()
ax[0].legend(fontsize="large")
ax[0].set_xlabel("eV", fontsize=10)
ax[0].tick_params(axis="x", labelsize=10)
ax[0].axes.get_yaxis().set_visible(False)

# 1-based component numbers; every entry must be <= nmf_num_comp
sel_nmf_comp = [2, 3, 4, 5] # choose several loading vectors to visualize
for i in sel_nmf_comp:
    ax[1].plot(e_range, skl_comp_vectors[i-1], "-", c=color_rep[i], label="loading vector %d"%(i))
ax[1].grid()
ax[1].legend(fontsize="large")
ax[1].set_xlabel("eV", fontsize=10)
ax[1].tick_params(axis="x", labelsize=10)
ax[1].axes.get_yaxis().set_visible(False)

fig.tight_layout()
plt.show()

In [None]:
%matplotlib inline

In [None]:
%matplotlib qt

In [None]:
# visualize the coefficient maps: one figure per NMF component, one panel per
# image, all panels on the common color scale min_val .. max_val
min_val = np.min(coeffs)
max_val = np.max(coeffs)
for i in range(num_comp):
    # squeeze=False keeps `ax` 2-D, so one image and many images share a path
    fig, ax = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
    for j, axs in enumerate(ax[0]):
        tmp = axs.imshow(coeffs_reshape[j][:, :, i], vmin=min_val, vmax=max_val, cmap="afmhot")
        axs.set_title("loading vector %d map"%(i+1), fontsize=10)
        axs.axis("off")
    # the panels share vmin/vmax, so one colorbar per figure is enough; adding
    # it inside the panel loop stacked num_img colorbar axes on the same spot.
    # x = 0.92 keeps it right of the subplots (default right edge 0.9)
    fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()

In [None]:
# close the figure most recently bound to `fig`
plt.close(fig)

### nonlinear dimensionality reduction (t-SNE)

In [None]:
%matplotlib inline

In [None]:
%matplotlib qt

In [None]:
# t-SNE (non-linear dimensionality reduction)
# apply t-SNE to the coefficient matrix produced by NMF decomposition
# please visit the below link for detailed information on t-SNE
# https://scikit-learn.org/stable/modules/generated/sklearn.manifold.TSNE.html?highlight=tsne#sklearn.manifold.TSNE
# `coeffs` is the un-shuffled NMF coefficient matrix, so embedding row m is pixel m
# init="random" without a random_state -> the embedding differs between runs
# NOTE(review): newer scikit-learn releases renamed TSNE's `n_iter` to
# `max_iter`; adjust the keyword if TSNE rejects it
start = time.time()
#perplex = [30, 35, 40, 45, 50]
#perplex = [5, 10, 15, 20, 25, 30, 35, 40, 45, 50] # try several perplexities
perplex = [50]
embeddings = []  # one (n_pixels, num_comp_vis) embedding per perplexity
num_comp_vis = 2 # number of dimensions of final data before clustering
for order, p in enumerate(perplex):
    tmp_tsne = TSNE(n_components=num_comp_vis, perplexity=p, early_exaggeration=5.0, learning_rate=300.0, 
                init="random", n_iter=1000, verbose=0)
    tmp_tsne.fit_transform(coeffs)
    plt.figure(figsize=(5, 5))
    plt.scatter(tmp_tsne.embedding_[:, 0], tmp_tsne.embedding_[:, 1], s=1, c="black")
    plt.title("perplexity %.1f"%perplex[order])
    plt.grid()
    plt.show()
    embeddings.append(tmp_tsne.embedding_)
    print("%d perplexity %.1f finished"%(order, p))
    print("%.2f min have passed"%((time.time()-start)/60))

In [None]:
# 2D results depending on perplexity: one panel per entry of `perplex`
# (the grid used to be a fixed 2x5, which raised IndexError whenever fewer
# than 10 perplexities were tried)
n_emb = len(embeddings)
if n_emb == 0:
    print("no t-SNE embedding to show; run the t-SNE cell first")
else:
    n_col = min(5, n_emb)
    n_row = -(-n_emb // n_col)  # ceiling division
    fig, ax = plt.subplots(n_row, n_col, figsize=(4*n_col, 4*n_row), squeeze=False)
    for ai, emb, p in zip(ax.ravel(), embeddings, perplex):
        ai.scatter(emb[:, 0], emb[:, 1], s=1, c="black")
        ai.set_title("perplexity %.1f"%p)
        ai.grid()
    # hide the panels left over in an incomplete last row
    for ai in ax.ravel()[n_emb:]:
        ai.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# 2D result selected
# sel_ind picks the embedding (index into `perplex`) used for clustering below
sel_ind = 0
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.scatter(embeddings[sel_ind][:, 0], embeddings[sel_ind][:, 1], s=3, c="black", alpha=0.5)
ax.set_title("perplexity %.1f"%(perplex[sel_ind]))
#ax.grid()
#ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# close the selected-embedding figure
plt.close(fig)

### unsupervised clustering (OPTICS)

In [None]:
# build a data matrix for clustering after performing t-SNE
num_comp = num_comp_vis
coeffs = embeddings[sel_ind].copy()

# columns of the embedding used for clustering
comp_axes = np.arange(num_comp)
#comp_axes = [1, 2]
# the OPTICS cell below can only draw 2-D or 3-D data; fail here rather than
# leave an X from an earlier run in place
if len(comp_axes) not in (2, 3):
    raise ValueError("select 2 or 3 axes for clustering, got %d" % len(comp_axes))
X = np.stack([coeffs[:, axis] for axis in comp_axes], axis=1)
print(X.shape)

In [None]:
# OPTICS (density-based clustering)
# You can adjust the hyperparameters using the interactive widgets
# The figure will be opened in a new window
# please visit the below link for detailed information on OPTICS
# https://scikit-learn.org/stable/modules/generated/sklearn.cluster.OPTICS.html?highlight=optics#sklearn.cluster.OPTICS

%matplotlib qt
# figure layout: top row = reachability plot (ax1); bottom left = points of X
# colored by cluster (ax2, 3-D axes when X has 3 columns); bottom right = label
# map of the selected image (ax3)
fig = plt.figure(figsize=(10, 8))
G = gridspec.GridSpec(2, 4)
ax1 = plt.subplot(G[0, :])

if X.shape[1] == 3:
    ax2 = plt.subplot(G[1, :2], projection="3d")
    
elif X.shape[1] == 2:
    ax2 = plt.subplot(G[1, :2])

ax3 = plt.subplot(G[1, 2:])

# state shared with clustering():
# optics_before - hyperparameters of the last OPTICS run ([-1, -1, -1] never
#                 matches, so the first call always runs OPTICS)
# optics_object - holds the last fitted OPTICS estimator
# label_result  - "label_0": per-pixel labels of the last run
optics_before = [-1, -1, -1]
optics_object = []
label_result = {"label_0":[]}

def _as_optics_size(value):
    """
    return whole-number values above 1 as int, everything else unchanged

    OPTICS takes min_samples / min_cluster_size either as an int > 1 (a count)
    or as a float in (0, 1] (a fraction of the samples); the FloatText widgets
    always deliver floats, so an entry such as 10 arrives as 10.0
    """
    if value > 1 and float(value).is_integer():
        return int(value)
    return value


def clustering(msample, steep, msize, img_sel):
    """
    widget callback: run OPTICS on ``X`` when a hyperparameter changed and
    redraw the three panels of the interactive figure

    Parameters
    ----------
    msample : float
        ``min_samples`` of OPTICS (fraction of the samples if <= 1, else a count)
    steep : float
        ``xi`` (minimum steepness) of OPTICS
    msize : float
        ``min_cluster_size`` of OPTICS (same fraction/count convention)
    img_sel : int
        1-based number of the image whose label map is drawn in ``ax3``

    Reads the globals ``X``, ``fig``, ``ax1``-``ax3``, ``color_rep``,
    ``custom_cmap``, ``norm``, ``data_shape_new`` and updates ``optics_before``,
    ``optics_object`` and ``label_result`` in place.
    """
    start = time.time()
    if msample <= 0:
        print("'min_sample' must be larger than 0")
        return

    if steep <= 0:
        print("'steepness' must be larger than 0")
        return

    if msize <= 0:
        print("'min_cluster_size' must be larger than 0")
        return

    optics_check = [msample, steep, msize]

    # OPTICS only re-runs when a hyperparameter changed; switching img_sel
    # alone just redraws the label map
    if optics_before != optics_check:
        ax1.cla()
        # pop, not del: if an earlier OPTICS call raised, the key is already
        # gone and `del` would fail on every later call
        label_result.pop("label_0", None)
        del optics_object[:]
        print("optics activated")
        clust = OPTICS(min_samples=_as_optics_size(msample), xi=steep,
                       min_cluster_size=_as_optics_size(msize)).fit(X)
        optics_object.append(clust)
        space = np.arange(len(X))
        # reachability plot: samples in OPTICS processing order
        reachability = clust.reachability_[clust.ordering_]
        labels = clust.labels_[clust.ordering_]
        labels_0 = clust.labels_
        label_result["label_0"] = labels_0

        # cluster k is drawn with color_rep[k+1]; black (color_rep[0]) is
        # reserved for noise (-1), so only len(color_rep)-1 clusters are drawn
        for klass, color in enumerate(color_rep[1:]):
            in_klass = labels == klass
            ax1.plot(space[in_klass], reachability[in_klass], color, alpha=0.3)

        ax1.plot(space[labels == -1], reachability[labels == -1], "k.", alpha=0.3)
        ax1.set_ylabel('Reachability-distance')
        ax1.set_title('Reachability-Plot')
        ax1.grid()

        ax2.cla()
        if X.shape[1] in (2, 3):
            # *points.T passes the 2 or 3 coordinate columns positionally,
            # which suits both the 2-D and the 3-D axes
            for klass, color in enumerate(color_rep[1:]):
                ax2.scatter(*X[labels_0 == klass].T, color=color, alpha=0.3, marker='.')
            ax2.plot(*X[labels_0 == -1].T, 'k+', alpha=0.1)
            # label -1 marks noise, not a cluster
            n_clusters = len(np.unique(labels_0[labels_0 != -1]))
            ax2.set_title('Automatic Clustering\nOPTICS(# of clusters=%d)\n(%f, %f, %f)'%(n_clusters, msample, steep, msize))

    ax3.cla()

    label_reshape_0, _, _ = label_arrangement(label_result["label_0"], data_shape_new)

    ax3.imshow(label_reshape_0[img_sel-1], cmap=custom_cmap, norm=norm)
    ax3.set_title("image %d"%(img_sel), fontsize=10)
    ax3.axis("off")

    fig.tight_layout()

    # remember the hyperparameters of the result now shown
    optics_before[:] = optics_check
    print("minimum number of samples in a neighborhood: %f"%msample)
    print("minimum steepness: %f"%steep)
    print("minimum number of samples in a cluster: %f"%msize)
    print("%.2f min have passed"%((time.time()-start)/60))


# input widgets feeding clustering(); pyw.interact calls clustering() once now
# and again whenever one of the widget values changes
st = {"description_width": "initial"}
# values <= 1 are fractions of the number of samples
msample_wg = pyw.FloatText(value=0.05, description="min. # of samples in a neighborhood", style=st)
steep_wg = pyw.FloatText(value=0.001, description="min. steepness", style=st)
msize_wg = pyw.FloatText(value=0.05, description="min. # of samples in a cluster", style=st)
# image numbers are 1-based
img_wg = pyw.Select(options=np.arange(num_img)+1, value=1, description="image selection", style=st)

pyw.interact(clustering, msample=msample_wg, steep=steep_wg, msize=msize_wg,  img_sel=img_wg)
fig.show()

In [None]:
# close the interactive OPTICS figure
plt.close(fig)

### spatial distribution and representative spectra of clusters

In [None]:
# reshape the clustering result
label_selected = label_result["label_0"].copy()
label_sort = np.unique(label_selected)
label_reshape, selected, hist = label_arrangement(label_selected, data_shape_new)
num_label = len(label_sort)
print(label_sort) # label "-1" -> not a cluster
print(hist) # number of data points in each cluster

In [None]:
%matplotlib inline

In [None]:
%matplotlib qt

In [None]:
# clustering result - clusters in the final DR space
# black points -> label "-1"
# draws only the first two columns of X; cluster k uses color_rep[k+1]
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
for klass, color in zip(range(0, len(color_rep)), color_rep[1:]):
    Xo = X[label_selected == klass]
    ax.scatter(Xo[:, 0], Xo[:, 1], color=color, alpha=0.3, marker='.')
ax.plot(X[label_selected == -1, 0], X[label_selected == -1, 1], 'k+', alpha=0.1)
#ax.axis("off")
fig.tight_layout()
plt.show()

In [None]:
# clustering result - spatial distribution of each cluster
# one label map per image, colored with the custom colormap (noise in black)
# NOTE: the figure size is fixed at 7 x 10 regardless of num_img
row_n = 1
col_n = num_img
    
fig, ax = plt.subplots(row_n, col_n, figsize=(7, 10))
if num_img != 1:
    for i, axs in enumerate(ax.flat):
        axs.imshow(label_reshape[i], cmap=custom_cmap, norm=norm)
        axs.set_title("image %d"%(i+1), fontsize=10)
        axs.axis("off")

else:
    # a single subplot is returned as one Axes, not an array
    ax.imshow(label_reshape[0], cmap=custom_cmap, norm=norm)
    ax.set_title("image %d"%(1), fontsize=10)
    ax.axis("off")
#fig.colorbar(sm)
fig.tight_layout()

In [None]:
# save the spatial distribution as a tiff stack
save_result = label_reshape[0].copy()
save_result = save_result[:, :, np.newaxis]
save_result = save_result.astype(np.int16)
print(save_result.shape)

tifffile.imsave(tkf.asksaveasfilename(), save_result)

In [None]:
# clustering result
# one figure per label, one panel per image: 1 where a pixel carries the label
# (titles show label+1, so the noise label -1 appears as "label 0")
for i in range(num_label):
    # squeeze=False keeps `ax` 2-D, so the single-image case gets the same
    # title and layout as the multi-image case
    fig, ax = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
    for j, axs in enumerate(ax[0]):
        axs.imshow(selected[i][j], cmap="afmhot")
        axs.set_title("label %d map"%(label_sort[i]+1), fontsize=10)
        axs.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# close the last label-map figure
plt.close(fig)

In [None]:
%matplotlib inline

In [None]:
%matplotlib qt

In [None]:
# clustering result - representative spectra (cropped)
# average all of the spectra in each cluster
# row i of `lines` is the mean of the (rescaled) input spectra carrying label
# label_sort[i]; label_selected and dataset_flat are both in pixel order
lines = np.zeros((num_label, depth))

for i in range(num_label):
    ind = np.where(label_selected == label_sort[i])
    print("number of pixels in the label %d cluster: %d"%(label_sort[i], hist[i]))
    lines[i] = np.mean(dataset_flat[ind], axis=0)
    
# left column: cropped input range; right column: reference range
# top row: spectra overlaid; bottom row: spectra offset by 0.1 each
fig, ax = plt.subplots(2, 2, figsize=(15, 15))

# normalize representative spectra for comparison
# each spectrum is divided by its maximum over the first 20 channels
# NOTE(review): a zero maximum there yields inf/nan
denominator = np.max(lines[:, :20], axis=1)
lines = lines / denominator[:, np.newaxis]

# label_sort is sorted, so when noise (-1) is present it is row 0: it is
# skipped and rows 1.. are named cluster 1..
if -1 in label_sort:
    for i in range(1, num_label):
        ax[0][0].plot(e_range, (lines[i]), label="cluster %d"%(i), c=color_rep[i])
        ax[1][0].plot(e_range, (lines[i]+(i-1)*0.1), label="cluster %d"%(i), c=color_rep[i])
        
else:
    for i in range(0, num_label):
        ax[0][0].plot(e_range, (lines[i]), label="cluster %d"%(i+1), c=color_rep[i+1])
        ax[1][0].plot(e_range, (lines[i]+i*0.1), label="cluster %d"%(i+1), c=color_rep[i+1])

ax[0][0].grid()
ax[0][0].legend(fontsize="x-large")
ax[0][0].set_xlabel("eV")
ax[1][0].grid()
ax[1][0].legend(fontsize="x-large")
ax[1][0].set_xlabel("eV")

# clustering result - representative spectra (original)
# average all of the spectra in each cluster
# same pixels as above, taken from the un-rescaled reference data
lines_original = np.zeros((num_label, depth_original))

for i in range(num_label):
    ind = np.where(label_selected == label_sort[i])
    #print("number of pixels in the label %d cluster: %d"%(label_sort[i], hist[i]))
    lines_original[i] = np.mean(dataset_original_flat[ind], axis=0)

# normalize representative spectra for comparison
denominator = np.max(lines_original[:, :20], axis=1)
lines_original = lines_original / denominator[:, np.newaxis]


if -1 in label_sort:
    for i in range(1, num_label):
        ax[0][1].plot(e_range_original, (lines_original[i]), label="cluster %d"%(i), c=color_rep[i])
        ax[1][1].plot(e_range_original, (lines_original[i]+(i-1)*0.1), label="cluster %d"%(i), c=color_rep[i])
        
else:
    for i in range(0, num_label):
        ax[0][1].plot(e_range_original, (lines_original[i]), label="cluster %d"%(i+1), c=color_rep[i+1])
        ax[1][1].plot(e_range_original, (lines_original[i]+i*0.1), label="cluster %d"%(i+1), c=color_rep[i+1])

ax[0][1].grid()
ax[0][1].legend(fontsize="x-large")
ax[0][1].set_xlabel("eV")
ax[1][1].grid()
ax[1][1].legend(fontsize="x-large")
ax[1][1].set_xlabel("eV")

fig.tight_layout()
plt.show()

In [None]:
# save the representative spectra as a tiff stack
# save the normalized representative spectra as a tiff file of shape
# (num_label, 1, depth)
save_path = tkf.asksaveasfilename()
# the dialog returns an empty value when it is cancelled
if save_path:
    # imwrite is tifffile's current writer; imsave is its deprecated alias
    tifffile.imwrite(save_path, lines.reshape(lines.shape[0], -1, lines.shape[1]))
else:
    print("no file name given; nothing saved")

In [None]:
# pick the first CUDA device when one is available;
# cuda_device stays None otherwise
if torch.cuda.is_available():
    print("%d gpu available"%(torch.cuda.device_count()))
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
else:
    cuda_device = None

In [None]:
# fixed basis for the fit below: the rows of `lines` listed in cluster_list
# (row indices in label_sort order; when noise is present, row 0 is label -1)
#cluster_list = np.arange(1, len(lines))
cluster_list = [4]
ref_spec = np.asarray([lines[i] for i in cluster_list])

# fall back to the CPU when no CUDA device was found (cuda_device is None);
# `.cuda(None)` fails on machines without a GPU
device = cuda_device if cuda_device is not None else torch.device("cpu")
ref_spec = torch.from_numpy(ref_spec).to(torch.float32)
ref_spec = ref_spec.to(device)
ref_spec.requires_grad_(requires_grad=False)
print(ref_spec.device)
print(ref_spec.shape)

In [None]:
# split the shuffled spectra into consecutive mini-batches
# NOTE: the last mini-batch holds len(dataset_input) % batch_size spectra when
# that remainder is non-zero
batch_size = 1600
mini_batches = [dataset_input[k:k+batch_size] for k in range(0, len(dataset_input), batch_size)]

In [None]:
# trainable coefficients: one row per spectrum of a mini-batch, one column per
# reference spectrum (re-initialized for every mini-batch in the fit)
# fall back to the CPU when no CUDA device was found (cuda_device is None)
device = cuda_device if cuda_device is not None else torch.device("cpu")
coeff_tmp = torch.randn((batch_size, ref_spec.shape[0])).to(torch.float32)
coeff_tmp = coeff_tmp.to(device)
coeff_tmp.requires_grad_(requires_grad=True)
print(coeff_tmp.device)
print(coeff_tmp.shape)

In [None]:
# plain SGD on the coefficient tensor only (ref_spec is not optimized)
l_rate = 0.05
moment = 0.00
optimizer = optim.SGD([coeff_tmp], lr=l_rate, momentum=moment)
n_iter = 200  # gradient steps per mini-batch

In [None]:
# fit every spectrum as a non-negative linear combination of the reference
# spectra: gradient steps on the coefficients, clamped to >= 0 after each step
# fall back to the CPU when no CUDA device was found (cuda_device is None)
device = cuda_device if cuda_device is not None else torch.device("cpu")
coeff_ret_tmp = []
for i, m_batch in enumerate(mini_batches):
    nn.init.xavier_uniform_(coeff_tmp)
    batch = torch.from_numpy(m_batch).to(torch.float32).to(device)
    # the last mini-batch is usually shorter than batch_size: only its first
    # n_rows coefficient rows enter the loss (the other rows get zero
    # gradient) and only those rows are kept
    n_rows = batch.shape[0]
    for n in range(n_iter):
        reconstructed_tmp = torch.matmul(coeff_tmp[:n_rows], ref_spec)
        # Frobenius norm of the residual = least-squares fit of every spectrum;
        # LA.norm(x, 2) on a 2-D tensor is the spectral norm (largest singular
        # value), which does not fit each spectrum
        loss = LA.norm(batch - reconstructed_tmp, "fro")

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # keep the coefficients non-negative
        coeff_tmp.data.clamp_(min=0.0)

    coeff_ret_tmp.extend(coeff_tmp.data[:n_rows].cpu().numpy().tolist())
    if (i+1) % 4 == 0:
        print("%d-th batch of %d batches completed"%(i+1, len(mini_batches)))

In [None]:
# rows of coeff_ret_tmp follow dataset_input order (row k = pixel ri[k]);
# scatter them back into pixel order and split into per-image maps
# NOTE: assumes exactly one coefficient row per spectrum of dataset_input
coeff_ret_tmp = np.asarray(coeff_ret_tmp)
coeff_ret = np.zeros_like(coeff_ret_tmp)
coeff_ret[ri] = coeff_ret_tmp.copy()
coeff_ret_reshape = reshape_coeff(coeff_ret, data_shape_new)

In [None]:
# visualize the coefficient maps of the fitted representative spectra:
# one figure per entry of cluster_list, one panel per image
for i in range(len(cluster_list)):
    # squeeze=False keeps `ax` 2-D, so one image and many images share a path
    fig, ax = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
    for j, axs in enumerate(ax[0]):
        tmp = axs.imshow(coeff_ret_reshape[j][:, :, i], cmap="afmhot")
        axs.set_title("representative spectrum %d map"%(i+1), fontsize=10)
        axs.axis("off")
        # each panel has its own color scale and therefore its own colorbar;
        # a fixed cax added inside this loop stacked them all on one spot
        fig.colorbar(tmp, ax=axs)
    plt.show()