In [None]:
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import hyperspy.api as hys
import tkinter.filedialog as tkf
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA
from torch.utils.tensorboard import SummaryWriter
from torchvision.utils import make_grid

In [None]:
# discrete colour table plus the matching normaliser / mappable for colorbars
color_rep = [
    "black", "red", "green", "blue",
    "orange", "purple", "yellow", "lime",
    "cyan", "magenta", "lightgray", "peru",
    "springgreen", "deepskyblue", "hotpink", "darkgray",
]
print(len(color_rep))
custom_cmap = mcolors.ListedColormap(color_rep)
# bin edges -1, 0, ..., len(color_rep)-1 -> one bin per colour
bounds = np.arange(-1, len(color_rep))
norm = mcolors.BoundaryNorm(bounds, len(color_rep))
sm = cm.ScalarMappable(norm=norm, cmap=custom_cmap)
sm.set_array([])

# colormap names (strings), one per entry
cm_rep = ["gray", "Reds", "Greens", "Blues", "Oranges", "Purples"]
print(len(cm_rep))

def zero_one_rescale(spectrum):
    """
    Rescale a single spectrum to the range 0.0 - 1.0.

    Negative values are clipped to 0.0 first, then the minimum is
    subtracted and the result is divided by its maximum.  When the
    shifted spectrum is identically zero it is returned undivided.
    The input array is not modified.
    """
    clipped = spectrum.clip(min=0.0)
    shifted = clipped - np.min(clipped)

    peak = np.max(shifted)
    if peak == 0:
        # flat spectrum: nothing to scale, and dividing would give NaN
        return shifted

    return shifted / peak

def threed_roll_axis(img):
    """
    move axis 2 of an array to the front, e.g. (a, b, c) -> (c, a, b)
    """
    return np.moveaxis(img, 2, 0)

def data_load(adr, rescale=False, crop=None, roll_axis=False):
    """
    load one or more spectrum images with hyperspy

    Parameters
    ----------
    adr : iterable of file paths, each passed to ``hys.load``
    rescale : if True, every spectrum ``temp[j, k]`` is replaced by
        ``zero_one_rescale(temp[j, k])``
    crop : optional sequence; when truthy, the signal axis is sliced with
        ``isig[crop[0]:crop[1]]`` (presumably in axis units for float
        values - confirm against the hyperspy indexing rules)
    roll_axis : if True, axis 2 is moved to the front with
        ``threed_roll_axis``; only applied when ``crop`` is not given

    Returns
    -------
    storage : list with one data array per file
    shape : ``np.asarray`` of the shapes of those arrays
    """
    storage = []
    shape = []
    # a dedicated loop name keeps the `adr` argument from being overwritten
    for path in adr:
        signal = hys.load(path)
        if crop:
            print(signal.axes_manager[2])
            temp = signal.isig[crop[0]:crop[1]].data
        else:
            temp = signal.data
            if roll_axis:
                temp = threed_roll_axis(temp)
        print(temp.shape)

        if rescale:
            # the rescaled values are written back in place; an integer
            # array would truncate them to 0 or 1, so switch to float first
            if not np.issubdtype(temp.dtype, np.floating):
                temp = temp.astype(np.float64)
            for j in range(temp.shape[0]):
                for k in range(temp.shape[1]):
                    temp[j, k] = zero_one_rescale(temp[j, k])
        shape.append(temp.shape)
        storage.append(temp)

    shape = np.asarray(shape)
    return storage, shape

def binning_SI(si, bin_y, bin_x, str_y, str_x, offset, depth, rescale=True):
    """
    re-bin a spectrum image

    A window of bin_y x bin_x pixels is moved over the first two axes
    with strides str_y / str_x.  For every window position the channels
    offset .. offset+depth-1 are averaged over the window; with
    rescale=True each averaged spectrum is passed through
    zero_one_rescale.  The result has shape
    (number of row positions, number of column positions, depth).
    """
    cube = np.asarray(si)
    row_starts = range(0, cube.shape[0] - bin_y + 1, str_y)
    col_starts = range(0, cube.shape[1] - bin_x + 1, str_x)
    channels = slice(offset, offset + depth)

    def window_spectrum(r, c):
        # mean spectrum of the window whose top-left pixel is (r, c)
        avg = np.mean(cube[r:r + bin_y, c:c + bin_x, channels], axis=(0, 1))
        return zero_one_rescale(avg) if rescale else avg

    spectra = [window_spectrum(r, c) for r in row_starts for c in col_starts]

    return np.asarray(spectra).reshape(len(row_starts), len(col_starts), depth)

def reshape_coeff(coeffs, new_shape):
    """
    reshape a coefficient matrix to restore the original scanning shapes.

    Parameters
    ----------
    coeffs : array whose rows are the per-pixel coefficient vectors of all
        images stacked one after another
    new_shape : sequence of shapes, one per image; only the first two
        entries (rows, columns) of each shape are used

    Returns
    -------
    list with one array of shape (rows, columns, -1) per image, taken from
    consecutive row ranges of ``coeffs``.  Rows left over after the last
    image are ignored.

    Raises
    ------
    ValueError
        if ``coeffs`` has fewer rows than the shapes require
    """
    coeffs = np.asarray(coeffs)
    coeff_reshape = []
    start = 0
    # walk through coeffs with a running offset instead of deleting the
    # consumed rows, which would copy the remaining matrix for every image
    for entry in np.asarray(new_shape):
        n_rows, n_cols = int(entry[0]), int(entry[1])
        stop = start + n_rows * n_cols
        if stop > coeffs.shape[0]:
            raise ValueError(
                "coeffs has %d rows but the requested shapes need at least %d"
                % (coeffs.shape[0], stop)
            )
        coeff_reshape.append(np.reshape(coeffs[start:stop], (n_rows, n_cols, -1)))
        start = stop

    return coeff_reshape

In [None]:
# list that collects the paths of the spectrum-image files to analyse;
# running this cell again empties the selection
file_adr = []

In [None]:
# open a file dialog and append the chosen paths to file_adr
# (running this cell again adds to the list instead of replacing it)
file_adr.extend(tkf.askopenfilenames())
print(len(file_adr))
print(*file_adr, sep="\n")

In [None]:
# number of selected files; must be re-evaluated whenever file_adr changes
num_img = len(file_adr)
print(num_img)

In [None]:
# load spectrum images
# cr_range = [crop start, crop end, step]; start/end are handed to data_load
# as the crop window, all three are used to build the plotting axis e_range
cr_range = [1.0, 3.0, 0.01] # actual input
data_storage, data_shape = data_load(file_adr, rescale=False, crop=cr_range, roll_axis=False)
print(len(data_storage))
print(data_shape)

e_range = np.arange(cr_range[0], cr_range[1], cr_range[2])
print(len(e_range))

# re-bin spectrum images in order to mitigate noises
bin_y = 1 # binning size (height)
bin_x = 1 # binning size (width)
str_y = 1 # stride height-direction
str_x = 1 # stride width-direction

dataset = []
data_shape_new = []
offset = 0
depth = len(e_range)
for img in data_storage:
    print(img.shape)
    # binning_SI reshapes to exactly `depth` channels; report a shortfall
    # here with a readable message instead of a bare reshape error
    if img.shape[-1] < offset + depth:
        raise ValueError(
            "loaded image has %d channels but %d are required by cr_range; "
            "check that the step in cr_range matches the data"
            % (img.shape[-1], offset + depth)
        )
    processed = binning_SI(img, bin_y, bin_x, str_y, str_x, offset, depth, rescale=True) # include the step for re-scaling the actual input
    print(processed.shape)
    data_shape_new.append(processed.shape)
    dataset.append(processed)

data_shape_new = np.asarray(data_shape_new)
print(data_shape_new)

# create the input dataset: one row per pixel, images stacked in load order.
# Iterating over `dataset` itself keeps this independent of a possibly stale
# image count from another cell.
if dataset:
    dataset_flat = np.concatenate(
        [binned.clip(min=0.0).reshape(-1, depth) for binned in dataset], axis=0
    ).astype(np.float64, copy=False)
else:
    dataset_flat = np.asarray([])
print(dataset_flat.shape)

In [None]:
# random ordering of all spectra; ri maps shuffled position -> original row
# NOTE(review): no random seed is set, so the ordering differs between runs
total_num = len(dataset_flat)
ri = np.random.permutation(total_num)

dataset_input = dataset_flat[ri]

In [None]:
class simple_VAE(nn.Module):
    """
    Fully connected variational autoencoder.

    encoder: in_dim -> h_dim[0] -> ... -> h_dim[-1] -> 2*z_dim
             (the output is split into mean and log-variance)
    decoder: z_dim -> h_dim[-1] -> ... -> h_dim[0] -> in_dim, followed by Tanh

    Every hidden layer is Linear -> BatchNorm1d -> LeakyReLU(0.1).

    Parameters
    ----------
    in_dim : length of one input vector
    h_dim : sequence of hidden layer widths (at least one entry)
    z_dim : dimension of the latent space

    Note: the decoder's BatchNorm1d layers used to be appended to the
    encoder's layer list after the encoder had already been built, so they
    were never part of the network.  They are now in the decoder, which
    changes the decoder's layer indices (relevant when loading a state_dict
    saved with the previous layout).
    """
    def __init__(self, in_dim, h_dim, z_dim):
        super().__init__()

        self.in_dim = in_dim
        self.z_dim = z_dim
        self.h_dim = h_dim

        hidden = list(self.h_dim)

        enc_net = []
        enc_sizes = [self.in_dim] + hidden
        for n_in, n_out in zip(enc_sizes[:-1], enc_sizes[1:]):
            enc_net.extend(self._hidden_block(n_in, n_out))
        # last encoder layer emits mean and log-variance side by side
        enc_net.append(nn.Linear(self.h_dim[-1], 2*z_dim))

        self.encoder = nn.Sequential(*enc_net)

        dec_net = []
        # mirror of the encoder: hidden widths in reverse order
        dec_sizes = [z_dim] + hidden[::-1]
        for n_in, n_out in zip(dec_sizes[:-1], dec_sizes[1:]):
            dec_net.extend(self._hidden_block(n_in, n_out))
        dec_net.append(nn.Linear(self.h_dim[0], self.in_dim))
        dec_net.append(nn.Tanh())

        self.decoder = nn.Sequential(*dec_net)

    @staticmethod
    def _hidden_block(n_in, n_out):
        """Return the layers of one hidden block: Linear, BatchNorm1d, LeakyReLU."""
        return [nn.Linear(n_in, n_out), nn.BatchNorm1d(n_out), nn.LeakyReLU(0.1)]

    def encode(self, x):
        """Return (mu, logvar): the two halves of the encoder output."""
        latent = self.encoder(x)
        mu = latent[:, :self.z_dim]
        logvar = latent[:, self.z_dim:]

        return mu, logvar

    def reparametrization(self, mu, logvar):
        """Sample z = mu + exp(0.5*logvar) * eps with eps drawn from a standard normal."""
        return mu+torch.exp(0.5*logvar)*torch.randn_like(logvar)

    def decode(self, z):
        """Map latent vectors back to the input space."""
        return self.decoder(z)

    def forward(self, x):
        """Return (mu, logvar, z, reconstruction) for a batch x."""
        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)

        return mu, logvar, z, self.decode(z)

In [None]:
# use the first CUDA device when one exists; otherwise cuda_device stays None
if not torch.cuda.is_available():
    cuda_device = None
else:
    print("%d gpu available"%(torch.cuda.device_count()))
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)

In [None]:
# network size: hidden layer widths and latent dimension
h_dim = [200]
z_dim = 5

model = simple_VAE(depth, h_dim, z_dim)
# cuda_device is None when no GPU was found; calling .cuda() in that case
# would raise, so the model is then left on the CPU
if cuda_device is not None:
    model.cuda(cuda_device)
print(model)

In [None]:
# split the shuffled spectra into consecutive slices of batch_size rows;
# the last slice is shorter when the total is not a multiple of batch_size
# NOTE(review): a final slice holding a single spectrum would presumably be
# rejected by BatchNorm1d in training mode - confirm for the data size used
batch_size = 783
mini_batches = [dataset_input[k:k+batch_size] for k in range(0, len(dataset_input), batch_size)]
print(len(mini_batches))

In [None]:
# number of passes over the whole dataset
n_epoch = 2000
# Adam over all model parameters; no hyper-parameters are passed, so the
# library defaults apply
optimizer = optim.Adam(model.parameters())
print(optimizer)

In [None]:
# Train the VAE.  Loss per batch = summed squared reconstruction error
# + KL divergence between the encoded Gaussian and a standard normal.
# The latent samples z of the final epoch are collected in latent_z, in the
# (shuffled) order of dataset_input.
start = time.time()
latent_z = []
model.train()
n_fig = 5
# report ten times over the run; max() keeps the modulo below from dividing
# by zero when n_epoch is smaller than 10
report_every = max(1, n_epoch // 10)
for epoch in range(n_epoch):
    loss_epoch = 0
    recon_loss = 0
    KLD_loss = 0
    for m_batch in mini_batches:

        x = torch.from_numpy(m_batch).to(torch.float32)
        if cuda_device is not None:
            x = x.to(cuda_device)

        mu, logvar, z, x_ = model(x)

        reconstruction_error = F.mse_loss(x_, x, reduction="sum")
        KL_divergence = -0.5*torch.sum(1+logvar-mu**2-logvar.exp())

        loss = reconstruction_error + KL_divergence
        loss_epoch += loss.item()
        recon_loss += reconstruction_error.item()
        KLD_loss += KL_divergence.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch == n_epoch-1:
            latent_z.extend(z.detach().cpu().numpy().tolist())

    # the memory summary only exists for CUDA devices
    if epoch == 0 and cuda_device is not None:
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        # ratio of the epoch totals, consistent with the other rows
        # (previously the last batch's tensors were divided here)
        error_ratio = recon_loss/KLD_loss if KLD_loss != 0 else float("nan")
        print(tabulate([
                        ["epoch", epoch+1],
                        ["total loss", loss_epoch/total_num],
                        ["reconstruction error", recon_loss/total_num],
                        ["KL divergence", KLD_loss/total_num],
                        ["error ratio", error_ratio],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # top row: inputs of the last batch, bottom row: their reconstructions;
        # the last batch may hold fewer than n_fig spectra
        n_show = min(n_fig, x.shape[0])
        fig, ax = plt.subplots(2, n_fig, figsize=(5*n_fig, 5), squeeze=False)
        for col in range(n_show):
            ax[0][col].plot(e_range, x[col].detach().cpu().numpy())
            ax[0][col].grid()
            ax[1][col].plot(e_range, x_[col].detach().cpu().numpy())
            ax[1][col].grid()
        fig.tight_layout()
        plt.show()

print("The training has been finished.")

In [None]:
# latent vectors of the final epoch, still in shuffled order
ae_coeffs = np.asarray(latent_z)
print(ae_coeffs.shape)

# undo the shuffle (row k of ae_coeffs goes to row ri[k]), then cut the
# matrix back into one coefficient map stack per image
coeffs = np.zeros_like(ae_coeffs)
coeffs[ri] = ae_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_shape_new)

In [None]:
# show figures as static images inside the notebook
%matplotlib inline

In [None]:
# alternative to the inline backend: show figures in separate Qt windows
%matplotlib qt

In [None]:
# visualize coefficient maps: one figure per latent component, one panel per
# image; the colour scale of every panel starts at that map's median

single_image = (num_img == 1)
for i in range(z_dim):
    # squeeze=False always yields a 2D axes array, so one code path serves
    # both the single-image and the multi-image case
    fig, ax = plt.subplots(1, num_img, figsize=(5*num_img, 5), squeeze=False)
    for j in range(num_img):
        coeff_map = coeffs_reshape[j][:, :, i]
        tmp = ax[0, j].imshow(coeff_map,
                              vmin=np.percentile(coeff_map, 50), cmap="afmhot")
        ax[0, j].set_title("loading vector %d map"%(i+1), fontsize=10)
        ax[0, j].axis("off")

    if single_image:
        # only the single-image layout gets a colorbar and is shown at once
        fig.tight_layout()
        fig.colorbar(tmp, cax=fig.add_axes([0.88, 0.15, 0.04, 0.7]))
        plt.show()

plt.show()

In [None]:
# 2D subspace: interactive scatter plot of two latent components
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    """
    Redraw the shared axes as a scatter plot of column c1 of `coeffs`
    (x axis) against column c2 (y axis).

    c1 and c2 are zero-based column indices; the axis labels show them
    one-based.
    """
    # clear the previous scatter before drawing the new pair of components
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    ax.set_xlabel("loading vector %d"%(c1+1), fontsize=15)
    ax.set_ylabel("loading vector %d"%(c2+1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.canvas.draw()
    fig.tight_layout()

# one slider per plotted component, covering every latent dimension
x_widget = pyw.IntSlider(min=0, max=z_dim-1, step=1, value=1)
y_widget = pyw.IntSlider(min=0, max=z_dim-1, step=1, value=2)

# projection is called again whenever a slider moves
pyw.interact(projection, c1=x_widget, c2=y_widget)
plt.show()