In [None]:
import time
import numpy as np
import tifffile
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import hyperspy.api as hys
import hyperspy.learn.mva as mva
import tkinter.filedialog as tkf
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA
# Global font family for every matplotlib figure created below.
# NOTE(review): requires "Times New Roman" to be installed; otherwise matplotlib warns and falls back.
plt.rcParams['font.family'] = 'Times New Roman'

In [None]:
def data_load(adr, rescale=False, crop=None):
    """
    Load one or more spectrum images with HyperSpy.

    Parameters
    ----------
    adr : iterable of str
        File paths passed one by one to ``hys.load``.
    rescale : bool
        If True, divide every spectrum by its own maximum over axis 2, so each
        pixel's spectrum peaks at 1. Pixels whose maximum is exactly 0 are left
        unscaled instead of turning into NaN/inf.
    crop : sequence or None
        ``crop[0]`` and ``crop[1]`` are used as ``signal.isig[crop[0]:crop[1]]``
        (HyperSpy treats float bounds as calibrated axis values). Any further
        entries (e.g. a step) are ignored. ``None`` keeps the full signal range.

    Returns
    -------
    storage : list of ndarray
        One data array per file, in input order.
    shape : ndarray of int
        ``shape[i]`` is ``storage[i].shape``.
    """
    storage = []
    shape = []
    for path in adr:
        signal = hys.load(path)
        if crop is not None:
            signal = signal.isig[crop[0]:crop[1]]
        temp = signal.data
        if rescale:
            # assumes 3-D data (scan_y, scan_x, energy) — axis 2 is the spectral axis
            peak = np.max(temp, axis=2)
            # guard all-zero spectra: dividing by 0 would poison the dataset with NaN/inf
            peak = np.where(peak == 0, 1, peak)
            temp = temp / peak[:, :, np.newaxis]
        print(np.max(temp), np.min(temp))
        print(temp.shape)

        shape.append(temp.shape)
        storage.append(temp)

    shape = np.asarray(shape)
    return storage, shape

In [None]:
def reshape_coeff(coeffs, new_shape):
    """
    Split a stacked per-pixel matrix back into one map per image.

    Rows of ``coeffs`` are consumed in order: the first ``h0*w0`` rows become
    image 0, the next ``h1*w1`` rows image 1, and so on.

    Parameters
    ----------
    coeffs : array_like
        Per-pixel rows for all images stacked along axis 0.
    new_shape : iterable of sequences
        ``new_shape[i][0]`` and ``new_shape[i][1]`` are the scan height and
        width of image ``i`` (further entries are ignored).

    Returns
    -------
    list of ndarray
        Element ``i`` has shape ``(h_i, w_i, -1)``.

    Raises
    ------
    ValueError
        If ``coeffs`` runs out of rows before every image is filled.
    """
    coeffs = np.asarray(coeffs)
    coeff_reshape = []
    offset = 0
    # Slice with a running offset instead of np.delete, which copied the whole
    # remaining array on every iteration (quadratic in the number of images).
    for dims in new_shape:
        rows, cols = int(dims[0]), int(dims[1])
        n_pixels = rows * cols
        block = coeffs[offset:offset + n_pixels]
        if block.shape[0] != n_pixels:
            raise ValueError(
                "expected %d rows for a %dx%d image but only %d remain"
                % (n_pixels, rows, cols, block.shape[0]))
        coeff_reshape.append(np.reshape(block, (rows, cols, -1)))
        offset += n_pixels

    return coeff_reshape

In [None]:
# Accumulator for the spectrum-image file paths chosen in the dialog cell.
file_adr = list()

In [None]:
# Ask for one or more files and append them; running the cell again adds more files.
file_adr.extend(tkf.askopenfilenames())
print(len(file_adr))
print("\n".join(file_adr))

In [None]:
# number of spectrum images that will be loaded
num_img = len(file_adr)
print(num_img)
# Fail here with a clear message instead of with an obscure shape error several cells later.
if num_img == 0:
    raise ValueError("No spectrum images selected; run the file-dialog cell and choose at least one file.")

In [None]:
# load spectrum images
cr_range = [1.0, 3.56, 0.01]  # [start, stop, step] in eV; step is the assumed energy dispersion
data_storage, data_shape = data_load(file_adr, rescale=True, crop=cr_range)
print(len(data_storage))
print(data_shape)

# Take the channel count from the loaded data. The length of
# np.arange(start, stop, step) depends on float rounding and on the files' true
# dispersion, and a mismatch would make every later reshape(-1, depth) silently
# mix channels of neighbouring pixels.
spectral_lengths = np.unique(data_shape[:, -1])
if len(spectral_lengths) != 1:
    raise ValueError("Spectrum images have different numbers of energy channels after cropping: %s" % spectral_lengths)
depth = int(spectral_lengths[0])
e_range = cr_range[0] + cr_range[2] * np.arange(depth)
nominal_len = len(np.arange(cr_range[0], cr_range[1], cr_range[2]))
if nominal_len != depth:
    print("warning: cropped data have %d channels but cr_range implies %d; check the dispersion" % (depth, nominal_len))
print(len(e_range))

In [None]:
# create the input dataset: one row per pixel (images in file order), one column per channel.
# Values are floored at 1e-5 so no channel is zero or negative.
# np.concatenate avoids converting every value to a Python float via .tolist(),
# which was slow and needed several times the memory of the final array.
dataset_flat = np.concatenate(
    [data_storage[i].clip(min=1E-5).reshape(-1, depth) for i in range(num_img)],
    axis=0,
).astype(np.float64, copy=False)
print(dataset_flat.shape)
print(np.median(dataset_flat))

In [None]:
# Shuffle pixel order so each mini-batch mixes pixels from all images.
# Seeded so the shuffle (and therefore the mini-batches) is reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)  # also fixes torch's RNG for model creation when the following cells run in order
total_num = len(dataset_flat)
ri = rng.permutation(total_num)  # ri[k] = row of dataset_flat placed at position k
dataset_input = dataset_flat[ri]

In [None]:
class conv_ae(nn.Module):
    """
    1-D convolutional autoencoder for spectra.

    The encoder is one [Conv1d -> BatchNorm1d -> ReLU -> AvgPool1d] stage per entry
    of ``channels``/``kernels``, followed by Flatten. The decoder is a single bias-free
    Linear layer followed by Hardsigmoid, so its weight columns are the learned
    component spectra.

    Parameters
    ----------
    input_size : int
        Number of samples per input spectrum.
    encoded_dimension : sequence of int
        ``encoded_dimension[1]`` is the decoder's input width. It must equal the
        flattened encoder output size (``channels[-1]`` times the final length).
    channels : sequence of int
        Output channels of each convolution stage.
    kernels : sequence of int
        Kernel size of each convolution stage (same length as ``channels``).
    pooling : sequence of int or None
        AvgPool1d kernel size per stage. ``None`` (default) uses 2 for every stage.
    """
    def __init__(self, input_size, encoded_dimension, channels, kernels, pooling=None):
        super().__init__()

        self.input_size = input_size
        self.encoded_dimension = encoded_dimension

        if pooling is None:
            pooling = [2] * len(channels)
        if not (len(channels) == len(kernels) == len(pooling)):
            raise ValueError("channels, kernels and pooling must have the same length")

        # Layers are appended in the same order as the original hand-written
        # Sequential, so indices (and state_dict keys) are unchanged for 4 stages.
        layers = []
        in_channels = 1
        for out_channels, kernel, pool in zip(channels, kernels, pooling):
            layers += [
                nn.Conv1d(in_channels, out_channels, kernel, bias=True),
                nn.BatchNorm1d(out_channels),
                nn.ReLU(),
                nn.AvgPool1d(pool),
            ]
            in_channels = out_channels
        layers.append(nn.Flatten())
        self.cnn_encoder = nn.Sequential(*layers)

        self.decoder = nn.Sequential(
            nn.Linear(self.encoded_dimension[1], self.input_size, bias=False),
            nn.Hardsigmoid(),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch of spectra of length ``input_size``."""
        # reshape (not view) so non-contiguous inputs are accepted too
        x = x.reshape(-1, 1, self.input_size)
        encoded = self.cnn_encoder(x)
        decoded = self.decoder(encoded)
        return encoded, decoded

In [None]:
# Pick the compute device: the first CUDA GPU when available, otherwise the CPU.
if torch.cuda.is_available():
    print("%d gpu available"%(torch.cuda.device_count()))
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
else:
    # A real device object rather than None, so `.to(cuda_device)` calls are explicit on CPU-only machines.
    print("no gpu available, using the cpu")
    cuda_device = torch.device("cpu")

In [None]:
num_comp = 5
channels = [8, 16, 32, num_comp]
kernels = [64, 32, 16, 0]  # last entry is a placeholder; it is computed below
pooling = [2, 2, 2, 2]

# Sequence length after each Conv1d (stride 1, no padding: L - k + 1)
# followed by AvgPool1d (floor(L / p)), tracked with exact integer arithmetic.
dat_dim = []
tmp_dim = depth
for kernel, pool in zip(kernels[:-1], pooling[:-1]):
    tmp_dim = (tmp_dim - kernel + 1) // pool
    if tmp_dim < 1:
        raise ValueError("kernel %d / pooling %d shrink the spectrum (depth=%d) to nothing" % (kernel, pool, depth))
    dat_dim.append(tmp_dim)

# Choose the last kernel so the final conv outputs exactly pooling[-1] samples,
# which the last pooling reduces to length 1 (flattened size == channels[-1]).
kernels[-1] = dat_dim[-1] - pooling[-1] + 1
if kernels[-1] < 1:
    raise ValueError("length %d before the last stage is smaller than pooling %d" % (dat_dim[-1], pooling[-1]))
dat_dim.append(1)

print(dat_dim)
print(kernels)
print(channels)
print(pooling)

In [None]:
# Build the autoencoder; conv_ae uses the second entry of encoded_dimension (num_comp) as the decoder input width.
model = conv_ae(depth, [dat_dim[-1]*channels[-1], num_comp], channels, kernels)
if cuda_device is not None:
    # .to() works for CPU and CUDA devices alike; .cuda() fails on CPU-only machines
    model = model.to(cuda_device)
for p in model.parameters():
    if p.requires_grad:
        print(p.shape)
train_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(train_params)
print(model)

In [None]:
batch_size = 970
# Consecutive chunks of batch_size shuffled rows; the final chunk may be shorter.
mini_batches = [dataset_input[lo:lo + batch_size]
                for lo in range(0, dataset_input.shape[0], batch_size)]

In [None]:
n_epoch = 50  # number of training epochs (full passes over mini_batches)

In [None]:
l_rate = 0.01
# Orthogonal initialisation of the decoder weight (the loading vectors). It is done
# in place, so the optimizer built next tracks the very same parameter tensor.
# Alternatives tried earlier: optim.SGD and torch.nn.init.xavier_normal_.
torch.nn.init.orthogonal_(model.decoder[0].weight)
optimizer = optim.Adam(model.parameters(), lr=l_rate)
print(optimizer)

In [None]:
start = time.time()
ae_coeffs = []
ae_bias = []
ce_losses = []
mse_losses = []
# Report roughly ten times per run; max(1, ...) avoids a modulo-by-zero when n_epoch < 10.
report_every = max(1, n_epoch // 10)

model.train()  # BatchNorm uses batch statistics while training
for epoch in range(n_epoch):
    tmp_ce = 0  # the BCE loss is disabled, so this stays 0; kept so ce_losses has one entry per epoch
    tmp_mse = 0
    for m_batch in mini_batches:
        x = torch.from_numpy(m_batch).to(torch.float32).to(cuda_device)

        encoded, decoded = model(x)

        mse_loss = F.mse_loss(decoded, x, reduction="mean")
        tmp_mse += mse_loss.item()

        main_loss = mse_loss

        optimizer.zero_grad()
        main_loss.backward()
        optimizer.step()

        # keep the decoder weights (the loading vectors) non-negative after every step
        model.decoder[0].weight.data.clamp_(min=0.0)

    ce_losses.append(tmp_ce)
    mse_losses.append(tmp_mse)

    # memory_summary is CUDA-only; it raises on CPU-only machines
    if epoch == 0 and torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        print(tabulate([
                        ["epoch", epoch+1],
                        ["loss", main_loss.item()],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # squeeze=False keeps ax 2-D so this also works when num_comp == 1
        fig, ax = plt.subplots(1, num_comp, figsize=(5*num_comp, 5), squeeze=False)
        weights = model.decoder[0].weight.data.cpu().numpy()
        for k in range(num_comp):
            ax[0, k].plot(e_range, weights[:, k])
            ax[0, k].grid()
        fig.tight_layout()
        plt.show()

# Compute the coefficients in one pass with the FINAL weights. Previously they were
# collected batch-by-batch during the last epoch, while the weights were still
# changing after every batch, so different pixels were encoded by different models.
model.eval()
with torch.no_grad():
    for m_batch in mini_batches:
        x = torch.from_numpy(m_batch).to(torch.float32).to(cuda_device)
        encoded, _ = model(x)
        ae_coeffs.extend(encoded.cpu().numpy().tolist())
model.train()

print("The training has been finished.")

In [None]:
ae_coeffs = np.asarray(ae_coeffs)
# The decoder weight has shape (input_size, num_comp); transpose so each row is one loading vector.
# .copy() detaches the array from the parameter's storage (on CPU, .numpy() shares memory),
# so re-running training cannot silently change it.
ae_comp_vectors = model.decoder[0].weight.detach().cpu().numpy().T.copy()
print(ae_coeffs.shape)
print(ae_comp_vectors.shape)

if len(ae_coeffs) != len(ri):
    raise ValueError("ae_coeffs has %d rows but %d pixels were shuffled; re-run the training cell" % (len(ae_coeffs), len(ri)))

# Undo the pixel shuffle: row k of ae_coeffs belongs to row ri[k] of dataset_flat.
coeffs = np.zeros_like(ae_coeffs)
coeffs[ri] = ae_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# Channel index of each loading vector's maximum, its energy, and the component order by that energy.
peak_ind = ae_comp_vectors.argmax(axis=1)
peak_pos = e_range[peak_ind]
peak_order = peak_pos.argsort()
print(peak_pos)
print(peak_order)

In [None]:
# Discrete colours for line plots plus a matching set of sequential colormaps for maps.
color_rep = ["black", "red", "green", "blue", "purple"]
print(len(color_rep))
custom_cmap = mcolors.ListedColormap(color_rep)
# len(color_rep) + 1 boundaries -> exactly one colour bin per entry
bounds = np.arange(-1, len(color_rep))
norm = mcolors.BoundaryNorm(bounds, len(color_rep))
sm = cm.ScalarMappable(norm=norm, cmap=custom_cmap)
sm.set_array([])

cm_rep = ["gray", "Reds", "Greens", "Blues", "Purples"]
print(len(cm_rep))

In [None]:
# visualize loading vectors
# All loading vectors on one square axes (linewidth 3). Each line's colour is picked
# by the rank of its peak energy, i.e. the position of i inside peak_order.
fig, ax = plt.subplots(figsize=(10, 10))
for i in range(num_comp):
    colour = color_rep[np.argsort(peak_order)[i]]
    ax.plot(e_range, ae_comp_vectors[i], "-", color=colour,
            label="loading vector %d"%(i+1), lw=3)
ax.set_xlabel("Energy Loss (eV)", fontsize=30)
ax.set_ylabel("Intensity (arb. unit)", fontsize=30)
ax.tick_params(axis="both", labelsize=30)

fig.tight_layout()
plt.show()

In [None]:
# visualize loading vectors
# Same overlay as a square figure but with thicker lines (linewidth 5);
# colours follow each vector's peak-energy rank.
rank_colours = [color_rep[r] for r in np.argsort(peak_order)]
fig, ax = plt.subplots(figsize=(10, 10))
for i in range(num_comp):
    ax.plot(e_range, ae_comp_vectors[i], "-", color=rank_colours[i],
            label="loading vector %d"%(i+1), lw=5)
ax.set_xlabel("Energy Loss (eV)", fontsize=30)
ax.set_ylabel("Intensity (arb. unit)", fontsize=30)
ax.tick_params(axis="both", labelsize=30)

fig.tight_layout()
plt.show()

In [None]:
# visualize loading vectors
# Wide (10 x 7) overlay of all loading vectors, linewidth 3, coloured by peak-energy rank.
fig, ax = plt.subplots(figsize=(10, 7))
for i in range(num_comp):
    rank = np.argsort(peak_order)[i]
    ax.plot(e_range, ae_comp_vectors[i], "-", color=color_rep[rank],
            label="loading vector %d"%(i+1), lw=3)
ax.set_xlabel("Energy Loss (eV)", fontsize=30)
ax.set_ylabel("Intensity (arb. unit)", fontsize=30)
ax.tick_params(axis="both", labelsize=30)

fig.tight_layout()
plt.show()

In [None]:
# visualize loading vectors
# Wide (10 x 7) overlay with thick lines (linewidth 5), coloured by peak-energy rank.
fig, ax = plt.subplots(figsize=(10, 7))
for i in range(num_comp):
    ax.plot(e_range, ae_comp_vectors[i], "-",
            color=color_rep[np.argsort(peak_order)[i]],
            label="loading vector %d"%(i+1), lw=5)
ax.set_xlabel("Energy Loss (eV)", fontsize=30)
ax.set_ylabel("Intensity (arb. unit)", fontsize=30)
ax.tick_params(axis="both", labelsize=30)

fig.tight_layout()
plt.show()

In [None]:
# Render subsequent figures as static images inside the notebook.
%matplotlib inline

In [None]:
# Open subsequent figures in separate Qt windows.
%matplotlib qt

In [None]:
# Stack the per-image coefficient maps into one array when all scans share a shape.
# Images with different scan sizes cannot be stacked (np.asarray raises on ragged
# input in recent NumPy), so they are kept as a list; the plotting cells index
# coeffs_reshape[j] per image and work with either form.
map_shapes = {c.shape for c in coeffs_reshape}
if len(map_shapes) == 1:
    coeffs_reshape = np.asarray(coeffs_reshape)
    print(coeffs_reshape.shape)
else:
    print([c.shape for c in coeffs_reshape])

In [None]:
# visualize coefficient maps: one figure per component, one panel per image,
# all panels of a figure sharing the component's global min/max.
for i in range(num_comp):
    # squeeze=False keeps ax 2-D so a single image also works
    fig, ax = plt.subplots(1, num_img, figsize=(120, 10), squeeze=False)
    # Range of component i over every image. The old `coeffs_reshape[:][:, :, i]`
    # indexed the scan-width axis of the stacked array, not the component axis.
    min_val = min(np.min(c[:, :, i]) for c in coeffs_reshape)
    max_val = max(np.max(c[:, :, i]) for c in coeffs_reshape)
    cmap_name = cm_rep[np.where(peak_order==i)[0][0]]
    for j in range(num_img):
        ax[0, j].imshow(coeffs_reshape[j][:, :, i],
                        vmin=min_val, vmax=max_val, cmap=cmap_name)
        ax[0, j].axis("off")
    fig.tight_layout()
plt.show()

In [None]:
# Standalone colorbars: one figure per component, spanning that component's
# min/max across all images.
for i in range(num_comp):
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    # index the component (last) axis of each map; works for a list or a stacked array
    min_val = min(np.min(c[:, :, i]) for c in coeffs_reshape)
    max_val = max(np.max(c[:, :, i]) for c in coeffs_reshape)
    # placeholder image: only its colour mapping feeds the colorbar; the axes are hidden
    tmp = ax.imshow(np.zeros((10, 10)), vmin=min_val, vmax=max_val,
                    cmap=cm_rep[np.where(peak_order==i)[0][0]])
    ax.axis("off")
    c_bar = fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    c_bar.ax.tick_params(labelsize=30)
plt.show()

In [None]:
# visualize coefficient maps, each panel auto-scaled to its own range
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(120, 10), squeeze=False)
    # colormap chosen by peak-energy rank, like the loading-vector colours
    # (the old cm_rep[i] used the raw component index instead)
    cmap_name = cm_rep[np.where(peak_order==i)[0][0]]
    for j in range(num_img):
        ax[0, j].imshow(coeffs_reshape[j][:, :, i], cmap=cmap_name)
        ax[0, j].axis("off")
    # once per figure (it used to run once per panel)
    fig.tight_layout()
plt.show()

In [None]:
# visualize coefficient maps; values below each image's 30th percentile are clipped to the lowest colour
vmin_pct = 30
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(120, 10), squeeze=False)
    # colormap by peak-energy rank, consistent with the loading-vector colours
    cmap_name = cm_rep[np.where(peak_order==i)[0][0]]
    for j in range(num_img):
        cmap_data = coeffs_reshape[j][:, :, i]
        ax[0, j].imshow(cmap_data, vmin=np.percentile(cmap_data, vmin_pct), cmap=cmap_name)
        ax[0, j].axis("off")
    fig.tight_layout()
plt.show()

In [None]:
# visualize coefficient maps; values below each image's median are clipped to the lowest colour
vmin_pct = 50
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(120, 10), squeeze=False)
    # colormap by peak-energy rank, consistent with the loading-vector colours
    cmap_name = cm_rep[np.where(peak_order==i)[0][0]]
    for j in range(num_img):
        map_ij = coeffs_reshape[j][:, :, i]
        ax[0, j].imshow(map_ij, vmin=np.percentile(map_ij, vmin_pct), cmap=cmap_name)
        ax[0, j].axis("off")
    fig.tight_layout()
plt.show()

In [None]:
# visualize coefficient maps (viridis, vmin = each image's median) with a colorbar per component.
# squeeze=False gives a 2-D axes array for any num_img, so one code path handles a
# single image as well (previously a separate, diverging branch).
for i in range(num_comp):
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j in range(num_img):
        map_ij = coeffs_reshape[j][:, :, i]
        tmp = ax[0, j].imshow(map_ij, vmin=np.percentile(map_ij, 50), cmap="viridis")
        ax[0, j].axis("off")
    # Reserve the right margin for the colorbar. The single-image branch used
    # tight_layout here, which stretched the image underneath the colorbar axes.
    fig.subplots_adjust(right=0.9)
    # NOTE(review): each panel has its own vmin/vmax; the colorbar reflects only the last panel.
    fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))

plt.show()

In [None]:
print(*mse_losses, sep="\n")
# training curve: one point per epoch (each value is the MSE summed over that epoch's mini-batches)
fig, ax = plt.subplots(1, 1, figsize=(8, 5))
ax.plot(np.arange(1, len(mse_losses) + 1), mse_losses, "-")
ax.set_xlabel("epoch")
ax.set_ylabel("MSE loss (sum over mini-batches)")

fig.tight_layout()
plt.show()

In [None]:
# Save each component's coefficient maps as one float32 TIFF stack of shape (image, row, column).
# tifffile.imsave is deprecated (removed in recent tifffile releases); imwrite is the replacement.
for i in range(num_comp):
    tilt_series = np.stack([coeffs_reshape[j][:, :, i].astype(np.float32) for j in range(num_img)])
    tifffile.imwrite("tilt_series_20221110_latent_%02d.tif"%(i+1), tilt_series)

In [None]:
def sigmoid(x):
    """
    Numerically stable logistic function 1 / (1 + exp(-x)), element-wise.

    Only exp(-|x|) (always <= 1) is evaluated, so large-magnitude inputs never
    overflow. A scalar input returns a NumPy scalar; an array input returns an
    array of the same shape.
    """
    x = np.asarray(x)
    z = np.exp(-np.abs(x))
    out = np.where(x >= 0, 1 / (1 + z), z / (1 + z))
    return out[()] if out.ndim == 0 else out

In [None]:
# Reconstruct spectra from a single latent component.
latent_idx = 2  # 0-based index of the latent component to keep
linear_part = ae_coeffs[:, latent_idx:latent_idx+1] @ ae_comp_vectors[latent_idx:latent_idx+1, :]
# The model's decoder ends in nn.Hardsigmoid (model.decoder[1]), not a logistic
# sigmoid, so apply that same module to stay consistent with the trained network.
with torch.no_grad():
    output = model.decoder[1](torch.from_numpy(linear_part)).numpy()
print(output.shape)

# undo the pixel shuffle, then split back into per-image maps
output_reshape = np.zeros_like(output)
print(output_reshape.shape)
output_reshape[ri] = output
output_reshape = reshape_coeff(output_reshape, data_shape)
print(len(output_reshape), output_reshape[0].shape)

In [None]:
# export the reconstruction of a single image
img_idx = 6  # 0-based image index; the file name uses img_idx + 1
reconstructed = np.asarray(output_reshape[img_idx])
# tifffile.imsave is deprecated (removed in recent releases); imwrite is the replacement
tifffile.imwrite("reconstructed_SI_L_H3_latent_2_%02d.tif"%(img_idx+1), reconstructed)

In [None]:
# Export every image's reconstruction as its own TIFF.
# NOTE(review): the name says "full", but output_reshape holds whatever the reconstruction cell produced — confirm.
for i in range(num_img):
    reconstructed = np.asarray(output_reshape[i])
    tifffile.imwrite("reconstructed_SI_L_H3_full_%02d.tif"%(i+1), reconstructed)

In [None]:
%matplotlib widget
# Interactive HyperSpy viewer for the first image's reconstructed spectra
# (uses the widget backend selected by the magic above).
ref_data = output_reshape[0]
hys.signals.Signal1D(ref_data).plot()

In [None]:
# 2D subspace
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    ax.set_xlabel("loading vector %d"%(c1+1), fontsize=15)
    ax.set_ylabel("loading vector %d"%(c2+1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.canvas.draw()
    fig.tight_layout()

x_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=0)
y_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=1)

pyw.interact(projection, c1=x_widget, c2=y_widget)
plt.show()