In [None]:
# --- Environment setup: imports and global plotting defaults ---
import sys
# Extend the module search path so that AEs_module (imported next) can be found.
# NOTE(review): absolute, machine-specific path - adjust to the local checkout.
sys.path.append("E:/github_repo/AEs/")
# NOTE(review): wildcard import; load_data, CAE1D_encoder, linFE_decoder and
# reshape_coeff used further down are presumably provided by this module - confirm.
from AEs_module import *
import time
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw  # sliders / interact for the 2D-subspace cell
import tkinter.filedialog as tkf  # file-picker dialog for choosing the input files
import tifffile
from tabulate import tabulate  # formats the training progress report
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA
# Use Cambria for all matplotlib text.
# NOTE(review): presumably requires the font to be installed locally - verify.
plt.rcParams['font.family'] = 'Cambria'

In [None]:
# Pick the input file(s) through a file dialog; several files may be selected.
file_adr = tkf.askopenfilenames()
# Echo the selection, one path per line.
print("\n".join(file_adr))

In [None]:
# Build the data container from the selected files.
# NOTE(review): the meaning of each keyword is defined by load_data in
# AEs_module, which is not visible here - check its docstring before changing.
data_load = load_data(
    file_adr,
    dat_dim=3,
    dat_unit='eV',
    cr_range=[1.0, 3.56, 0.01],
    dat_scale=1.0,
    rescale=False,
    DM_file=True,
)

In [None]:
# binning (optional)
# rescale_0to1: rescale each data from 0 to 1
bin_y, bin_x = 4, 4  # binning size (height, width)
str_y, str_x = 4, 4  # stride (height-direction, width-direction)

data_load.binning(
    bin_y, bin_x, str_y, str_x,
    offset=0,
    rescale_0to1=False,
)

In [None]:
# Prepare the network input held by the data container.
# NOTE(review): option semantics are defined by make_input in AEs_module
# (not visible here) - confirm before changing any of them.
data_load.make_input(
    min_val=0.0,
    max_normalize=True,
    rescale_0to1=False,
    final_dim=1,
)

In [None]:
# Select the first CUDA device when one is present and report on it;
# otherwise leave cuda_device as None.
if not torch.cuda.is_available():
    cuda_device = None
else:
    print(f"{torch.cuda.device_count():d} gpu available")
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    # Make it the current device, then show its memory statistics.
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))

In [None]:
# Encoder hyper-parameters, one list entry per layer.
num_comp = 5
channels = [8, 16, 32, num_comp]
kernels = [64, 32, 16, 0]  # the last entry is a placeholder, filled in below
pooling = [2, 2, 2, 2]

# Track the signal length after each layer: a kernel of size k removes k-1
# points, then the length is divided by the pooling factor.
dat_dim = []
running_len = data_load.num_dim
for layer, k_size in enumerate(kernels):
    running_len = (running_len + (1 - k_size)) / pooling[layer]
    dat_dim.append(int(running_len))

# Size the last kernel from the previous layer's length and the last pooling
# factor, and record the final length as a single value.
kernels[-1] = dat_dim[-2] - pooling[-1] + 1
dat_dim[-1] = 1

for summary in (dat_dim, kernels, channels, pooling):
    print(summary)

In [None]:
# Build the convolutional encoder, optionally wrap it for multi-GPU use,
# move it to the selected device and report its trainable parameters.
parallel_ = False

enc_model = CAE1D_encoder(data_load.num_dim, channels, kernels, pooling)
if parallel_:
    enc_model = nn.DataParallel(enc_model)
enc_model.cuda(cuda_device)

# Shapes of all trainable tensors, followed by their total element count.
trainable = [w for w in enc_model.parameters() if w.requires_grad]
for w in trainable:
    print(w.data.shape)
train_params = sum(w.numel() for w in trainable)
print(train_params)
print(enc_model)

In [None]:
# Build the decoder (num_comp inputs, data_load.num_dim outputs), wrap it
# for multi-GPU use when requested, and move it to the selected device.
dec_model = linFE_decoder(num_comp, data_load.num_dim)
dec_model = nn.DataParallel(dec_model) if parallel_ else dec_model

dec_model.cuda(cuda_device)
print(dec_model)

In [None]:
# Cut the input set into consecutive mini-batches of at most batch_size rows.
batch_size = 970
n_samples = len(data_load.dataset_input)
mini_batches = []
for first in range(0, n_samples, batch_size):
    mini_batches.append(data_load.dataset_input[first:first + batch_size])

# Number of batches, then the size of the (possibly shorter) last one.
print(len(mini_batches))
print(len(mini_batches[-1]))

In [None]:
# Learning rate for the optimizer below.
l_rate = 0.001

# Encoder and decoder are optimized jointly by a single Adam instance.
params = list(enc_model.parameters()) + list(dec_model.parameters())
# Pass l_rate explicitly: without lr=..., Adam silently uses its own default
# step size and changing l_rate above has no effect on training.
optimizer = optim.Adam(params, lr=l_rate)

# Initialize the decoder's first layer with an orthogonal weight matrix
# (in-place, so the optimizer keeps tracking the same parameter tensor).
# Alternative: torch.nn.init.xavier_normal_(dec_model.decoder[0].weight)
torch.nn.init.orthogonal_(dec_model.decoder[0].weight)
print(optimizer)

In [None]:
# Train encoder and decoder jointly on the mini-batches with an MSE
# reconstruction loss, keeping the decoder weights non-negative.
start = time.time()
n_epoch = 50
# Progress is reported ten times per run. The interval is clamped to at least
# 1 because n_epoch // 10 is 0 for n_epoch < 10, which would make the modulo
# below raise ZeroDivisionError.
report_every = max(1, n_epoch // 10)

ae_coeffs = []   # encoder outputs collected during the final epoch
ae_bias = []     # not filled in this cell; kept so later code can rely on the name
ce_losses = []   # stays all-zero while the cross-entropy loss is disabled
mse_losses = []  # per-epoch sum of the mini-batch MSE losses
for epoch in range(n_epoch):
    tmp_ce = 0
    tmp_mse = 0
    for m_batch in mini_batches:
        # Tensors created by from_numpy do not require grad, so no extra
        # requires_grad_(False) call is needed.
        x = torch.from_numpy(m_batch).to(torch.float32).to(cuda_device)

        encoded = enc_model(x)
        decoded = dec_model(encoded)

        # Only the MSE reconstruction loss is active (cross-entropy is disabled).
        mse_loss = F.mse_loss(decoded, x, reduction="mean")
        tmp_mse += mse_loss.item()

        main_loss = mse_loss

        optimizer.zero_grad()
        main_loss.backward()
        optimizer.step()

        # After each step, clamp the decoder weights so they stay >= 0.
        with torch.no_grad():
            dec_model.decoder[0].weight.clamp_(min=0.0)

        if epoch == n_epoch-1:
            # NOTE(review): these coefficients were computed before this
            # batch's weight update, so they lag the final weights by up to
            # one optimizer step - confirm this is acceptable.
            ae_coeffs.extend(encoded.detach().cpu().numpy().tolist())

    ce_losses.append(tmp_ce)
    mse_losses.append(tmp_mse)

    # The memory summary only exists when CUDA is available.
    if epoch == 0 and torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        # "loss" is the loss of the last mini-batch of this epoch.
        print(tabulate([
                        ["epoch", epoch+1],
                        ["loss", main_loss.item()],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # Plot the current decoder weight columns, one panel per component.
        # squeeze=False keeps `ax` two-dimensional even when num_comp == 1.
        weights = dec_model.decoder[0].weight.detach().cpu().numpy()
        fig, ax = plt.subplots(1, num_comp, figsize=(5*num_comp, 5), squeeze=False)
        for comp in range(num_comp):
            ax[0, comp].plot(data_load.dat_dim_range, weights[:, comp])
        fig.tight_layout()
        plt.show()

print("The training has been finished.")

In [None]:
# Turn the collected encoder outputs into an array and read the component
# vectors from the decoder weight matrix (transposed).
ae_coeffs = np.asarray(ae_coeffs)
ae_comp_vectors = dec_model.decoder[0].weight.detach().cpu().numpy().transpose()
for arr in (ae_coeffs, ae_comp_vectors):
    print(arr.shape)

# convert the coefficient matrix into coefficient maps
# Rows are first placed at the positions given by data_load.ri.
# NOTE(review): presumably this restores the original sample order - confirm.
coeffs = np.zeros_like(ae_coeffs)
coeffs[data_load.ri] = ae_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_load.data_shape)

In [None]:
# Locate the maximum of each component vector, map the index onto the data
# axis, and derive the ordering of the components by that position.
peak_ind = ae_comp_vectors.argmax(axis=1)
peak_pos = data_load.dat_dim_range[peak_ind]
peak_order = np.argsort(peak_pos)
for result in (peak_pos, peak_order):
    print(result)

In [None]:
# create a customized colorbar
# One discrete colour per entry, with integer boundaries starting at -1.
color_rep = ["black", "red", "green", "blue", "purple", "orange"]
n_colors = len(color_rep)
print(n_colors)

custom_cmap = mcolors.ListedColormap(color_rep)
bounds = np.arange(-1, n_colors)
norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=n_colors)
sm = cm.ScalarMappable(norm=norm, cmap=custom_cmap)
sm.set_array([])

# Sequential colormaps listed in the same colour order as color_rep.
cm_rep = ["Greys", "Reds", "Greens", "Blues", "Purples", "Oranges"]
print(len(cm_rep))

In [None]:
# visualize loading vectors
# All loading vectors share one axis; each curve takes the colour at the
# position of its index within peak_order.
fig, ax = plt.subplots(1, 1, figsize=(10, 10))
for comp_idx in range(num_comp):
    rank = np.where(peak_order == comp_idx)[0][0]
    ax.plot(
        data_load.dat_dim_range,
        ae_comp_vectors[comp_idx],
        "-",
        c=color_rep[rank],
        label="loading vector %d" % (comp_idx + 1),
        linewidth=5,
    )

text_size = 30
ax.set_xlabel("Energy Loss (eV)", fontsize=text_size)
ax.set_ylabel("Intensity (arb. unit)", fontsize=text_size)
ax.tick_params(axis="both", labelsize=text_size)

fig.tight_layout()
plt.show()

In [None]:
# visualize coefficient maps
# One figure per component, one panel per image; all panels of a figure share
# the colour limits taken from that component's full coefficient column.
for i in range(num_comp):
    # squeeze=False keeps `ax` a 2-D array even when data_load.num_img == 1;
    # with the default, subplots returns a bare Axes and ax[j] would fail.
    fig, ax = plt.subplots(1, data_load.num_img, figsize=(120, 10), squeeze=False)
    min_val = np.min(coeffs[:, i])
    max_val = np.max(coeffs[:, i])
    # Colormap chosen by the position of component i within peak_order.
    cmap_name = cm_rep[np.where(peak_order==i)[0][0]]
    for j in range(data_load.num_img):
        ax[0, j].imshow(coeffs_reshape[j][:, :, i],
                        vmin=min_val, vmax=max_val, cmap=cmap_name)
        ax[0, j].axis("off")
    fig.tight_layout()
plt.show()

In [None]:
# visualize colorbars
# One stand-alone colorbar per component, using the same colour limits and
# colormap selection as the coefficient maps.
for i in range(num_comp):
    fig, ax = plt.subplots(1, 1, figsize=(10, 10))
    column = coeffs[:, i]
    min_val, max_val = np.min(column), np.max(column)
    # A blank image serves only as the mappable that the colorbar is built from.
    tmp = ax.imshow(
        np.zeros((10, 10)),
        vmin=min_val,
        vmax=max_val,
        cmap=cm_rep[np.where(peak_order == i)[0][0]],
    )
    ax.axis("off")
    c_bar = fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    c_bar.ax.tick_params(labelsize=30)
plt.show()

In [None]:
# 2D subspace
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    """Redraw the scatter plot of coefficient column c1 against column c2.

    c1, c2: zero-based column indices into `coeffs`; the axis labels show
    them one-based.
    """
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    for set_label, comp in ((ax.set_xlabel, c1), (ax.set_ylabel, c2)):
        set_label("loading vector %d" % (comp + 1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.canvas.draw()
    fig.tight_layout()

# Two integer sliders pick the components shown on the x and y axes.
slider_range = dict(min=0, max=num_comp-1, step=1)
pyw.interact(
    projection,
    c1=pyw.IntSlider(value=0, **slider_range),
    c2=pyw.IntSlider(value=1, **slider_range),
)
plt.show()