In [None]:
from VAEs_module import *
import time
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import hyperspy.api as hys
import tkinter.filedialog as tkf
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
# Running list of selected file paths; the dialog cell below appends to it
# and may be re-run to add more files.
file_adr = list()

In [None]:
# Ask for one or more files and append their paths to the running list,
# then report how many paths are held and list them one per line.
file_adr += tkf.askopenfilenames()
print(len(file_adr))
print("\n".join(file_adr))

In [None]:
# Load every selected file as a 4-D data set (unit label '1/nm', no rescaling).
data_load = load_data(
    file_adr,
    dat_dim=4,
    dat_unit='1/nm',
    rescale=False,
)

In [None]:
# Locate the pattern centre; result_visual=True presumably displays the
# result (linear intensity scale, since log_scale=False) -- confirm.
data_load.find_center(
    cbox_edge=40,
    center_remove=0,
    result_visual=True,
    log_scale=False,
)

In [None]:
# Build the network input from the loaded data. w_size is presumably the
# side length later used to reshape patterns for display (data_load.w_size).
data_load.make_input(
    min_val=1E-6,
    max_normalize=True,
    log_scale=False,
    radial_flat=False,
    w_size=64,
    radial_range=None,
    final_dim=1,
)

In [None]:
# Select the compute device: the first GPU when CUDA is present (and report
# its name and memory state), otherwise None.
cuda_device = None
if torch.cuda.is_available():
    cuda_device = torch.device("cuda:0")
    print("%d gpu available" % torch.cuda.device_count())
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))

In [None]:
# Wrap the models in nn.DataParallel when True.
parallel_ = True

# Latent dimension (also used by the decoder and the plots) and the
# encoder's hidden-layer widths.
num_comp = 2
enc_hid_dim = [512]

enc_model = VAEFCNN_encoder(data_load.s_length, num_comp, enc_hid_dim)

if parallel_:
    enc_model = nn.DataParallel(enc_model)
# `.to` accepts either a CUDA device or None (no move), so this also works on
# a CPU-only machine; `.cuda(None)` would require a GPU.
enc_model.to(cuda_device)
print(enc_model)

In [None]:
# Hidden-layer widths of the decoder.
hidden_dim = [512]

dec_model = VAEFCNN_decoder(num_comp, hidden_dim, data_load.s_length)

if parallel_:
    dec_model = nn.DataParallel(dec_model)

# `.to` accepts either a CUDA device or None (no move), so this also works on
# a CPU-only machine; `.cuda(None)` would require a GPU.
dec_model.to(cuda_device)
print(dec_model)

In [None]:
# Cut the network input into consecutive, unshuffled slices of `batch_size`
# rows (the last slice may be shorter). The training cell collects latent
# vectors in this same order.
# NOTE(review): `dataset_input` is not defined in the visible cells --
# presumably it comes from VAEs_module / data_load; confirm.
batch_size = 783
mini_batches = []
for start_row in range(0, len(dataset_input), batch_size):
    mini_batches.append(dataset_input[start_row:start_row + batch_size])
print(len(mini_batches))

In [None]:
# Training hyper-parameters.
n_epoch = 100
l_rate = 0.001

# A single Adam optimizer updates encoder and decoder weights together.
params = [*enc_model.parameters(), *dec_model.parameters()]
optimizer = optim.Adam(params, lr=l_rate)

In [None]:
# Train the VAE. Per mini-batch the loss is the summed binary cross-entropy
# between input and reconstruction plus the KL divergence of N(mu, exp(logvar))
# from N(0, I). Every `report_every` epochs this shows the loss curve, a table
# of losses divided by `total_num`, example reconstructions and maps of the
# sampled latent coordinates.
# NOTE(review): `total_num`, `ri`, `data_shape` and `num_img` are not defined
# in the visible cells -- presumably they come from VAEs_module / data_load;
# confirm.
start = time.time()
loss_plot = []
n_fig = 5
# Report about ten times per run; max(1, ...) avoids a modulo by zero when
# n_epoch < 10.
report_every = max(1, n_epoch // 10)

# A later cell calls dec_model.eval(); switch back so a re-run really trains.
enc_model.train()
dec_model.train()

for epoch in range(n_epoch):
    loss_epoch = 0
    recon_loss = 0
    KLD_loss = 0

    latent_z = []
    z_mu = []
    z_logvar = []
    for m_batch in mini_batches:
        # torch.from_numpy shares memory with the NumPy batch (a slice of the
        # input array), so clamp out of place: clamp_ would overwrite the data.
        x = torch.from_numpy(m_batch).clamp(min=0.001, max=0.999)
        x = x.to(torch.float32).to(cuda_device)

        mu, logvar, z = enc_model(x)
        x_ = dec_model(z)

        # Reshape to the target's shape (rather than squeeze) so a final
        # mini-batch holding a single sample keeps its batch axis.
        reconstruction_error = F.binary_cross_entropy(x_.reshape(x.shape), x, reduction="sum")
        KL_divergence = -0.5*torch.sum(1+logvar-mu**2-logvar.exp())

        loss = reconstruction_error + KL_divergence
        loss_epoch += loss.item()
        recon_loss += reconstruction_error.item()
        KLD_loss += KL_divergence.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Collected in mini-batch order; the plots scatter them back via `ri`.
        latent_z.extend(z.detach().cpu().tolist())
        z_mu.extend(mu.detach().cpu().tolist())
        z_logvar.extend(logvar.detach().cpu().tolist())

    loss_plot.append(loss_epoch/total_num)

    latent_z = np.asarray(latent_z)
    z_mu = np.asarray(z_mu)
    z_logvar = np.asarray(z_logvar)

    if epoch == 0 and torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        fig, ax = plt.subplots(1, 1)
        ax.plot(np.arange(1, epoch+2), loss_plot, "k-")
        ax.grid()
        plt.show()

        # Ratio of the epoch totals. The last batch's tensors would describe
        # only one mini-batch and print as tensor reprs.
        error_ratio = recon_loss/KLD_loss if KLD_loss != 0 else float("nan")
        print(tabulate([
                        ["epoch", epoch+1],
                        ["total loss", loss_epoch/total_num],
                        ["reconstruction error", recon_loss/total_num],
                        ["KL divergence", KLD_loss/total_num],
                        ["error ratio", error_ratio],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # Inputs (top row) vs reconstructions (bottom row) from the last
        # mini-batch, which may hold fewer than n_fig samples. squeeze=False
        # keeps `ax` 2-D even when n_fig == 1.
        n_show = min(n_fig, x.shape[0])
        fig, ax = plt.subplots(2, n_fig, figsize=(5*n_fig, 5*2), squeeze=False)
        for i in range(n_show):
            ax[0][i].imshow(x[i].detach().cpu().reshape(data_load.w_size, data_load.w_size).numpy(), cmap="inferno")
            ax[1][i].imshow(x_[i].squeeze().detach().cpu().reshape(data_load.w_size, data_load.w_size).numpy(), cmap="inferno")
        fig.tight_layout()
        plt.show()

        coeffs = np.zeros_like(latent_z)
        coeffs[ri] = latent_z.copy()
        coeffs_reshape = reshape_coeff(coeffs, data_shape)

        fig, ax = plt.subplots(1, 1)
        ax.grid()
        for i in range(num_comp):
            ax.hist(coeffs[:, i], bins=50, alpha=(1.0-i*(1/num_comp)))
        plt.show()

        # One figure per latent component, one panel per image. squeeze=False
        # gives a 2-D Axes array for any num_img, including 1.
        for i in range(num_comp):
            fig, ax = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
            for j in range(num_img):
                ax[0][j].imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
                ax[0][j].axis("off")
            plt.show()

print("The training has been finished.")

In [None]:
# Render subsequent matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Render subsequent matplotlib figures in separate interactive Qt windows.
%matplotlib qt

In [None]:
# Sampled latent coordinates z: scatter the rows into positions `ri`
# (presumably undoing a row permutation -- confirm) and reshape them into
# per-image maps.
coeffs = np.zeros_like(latent_z)
coeffs[ri] = latent_z.copy()
coeffs_reshape = reshape_coeff(coeffs, data_shape)

# Overlaid histograms of all latent components.
fig, ax = plt.subplots(1, 1)
ax.grid()
for comp in range(num_comp):
    ax.hist(coeffs[:, comp], bins=50, alpha=1.0 - comp * (1 / num_comp))
plt.show()

# One figure per component, one panel per image. squeeze=False keeps a 2-D
# Axes array, so the single-image case needs no separate branch.
for comp in range(num_comp):
    fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
    for img_idx in range(num_img):
        axes[0, img_idx].imshow(coeffs_reshape[img_idx][:, :, comp], cmap="viridis")
        axes[0, img_idx].axis("off")
    plt.show()

In [None]:
# Latent means mu: scatter the rows into positions `ri` (presumably undoing
# a row permutation -- confirm) and reshape them into per-image maps.
coeffs = np.zeros_like(z_mu)
coeffs[ri] = z_mu.copy()
coeffs_reshape = reshape_coeff(coeffs, data_shape)

# Overlaid histograms of all latent components.
fig, ax = plt.subplots(1, 1)
ax.grid()
for comp in range(num_comp):
    ax.hist(coeffs[:, comp], bins=50, alpha=1.0 - comp * (1 / num_comp))
plt.show()

# One figure per component, one panel per image. squeeze=False keeps a 2-D
# Axes array, so the single-image case needs no separate branch.
for comp in range(num_comp):
    fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
    for img_idx in range(num_img):
        axes[0, img_idx].imshow(coeffs_reshape[img_idx][:, :, comp], cmap="viridis")
        axes[0, img_idx].axis("off")
    plt.show()

In [None]:
# Latent standard deviations: exp(0.5 * logvar) turns the log-variance into
# sigma. Rows are scattered into positions `ri` (presumably undoing a row
# permutation -- confirm) and reshaped into per-image maps.
coeffs = np.zeros_like(z_logvar)
coeffs[ri] = np.exp(0.5 * z_logvar)
coeffs_reshape = reshape_coeff(coeffs, data_shape)

# Overlaid histograms of all latent components.
fig, ax = plt.subplots(1, 1)
ax.grid()
for comp in range(num_comp):
    ax.hist(coeffs[:, comp], bins=50, alpha=1.0 - comp * (1 / num_comp))
plt.show()

# One figure per component, one panel per image. squeeze=False keeps a 2-D
# Axes array, so the single-image case needs no separate branch.
for comp in range(num_comp):
    fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
    for img_idx in range(num_img):
        axes[0, img_idx].imshow(coeffs_reshape[img_idx][:, :, comp], cmap="viridis")
        axes[0, img_idx].axis("off")
    plt.show()

In [None]:
# Build an n_sample x n_sample grid of 2-D latent points to decode. Both grid
# axes use the same n_sample values. They are drawn without replacement from
# n_sample*10 evenly spaced points in [-sigma, sigma], with probability
# proportional to the standard normal density, so the grid is denser near 0.
assert num_comp == 2, "the latent grid assumes a 2-D latent space (num_comp == 2)"
n_sample = 10
sigma = 5.0
sample_seed = 0  # fixed for reproducibility; set to None for a new draw each run
rng = np.random.default_rng(sample_seed)

z_test = np.linspace(-sigma, sigma, n_sample*10, endpoint=True)
# Unnormalised N(0, 1) density. The 1/sqrt(2*pi) factor cancels in the
# normalisation below, so no scipy.stats object is needed.
norm_pdf = np.exp(-0.5 * z_test**2)
norm_pdf = norm_pdf / np.sum(norm_pdf)
z_test = np.sort(rng.choice(z_test, n_sample, replace=False, p=norm_pdf))
z_test = np.meshgrid(z_test, z_test)
z_test = np.stack((z_test[0].flatten(), z_test[1].flatten()), axis=1)
print(z_test.shape)
z_test = torch.from_numpy(z_test).to(torch.float32).to(cuda_device)

In [None]:
# Decode the latent grid. eval() switches the decoder to inference mode, and
# no_grad() skips building an autograd graph that is never back-propagated.
dec_model.eval()
with torch.no_grad():
    generated = dec_model(z_test)
print(generated.shape)

In [None]:
# Show the decoded patterns as an n_sample x n_sample mosaic, one panel per
# latent grid point, with almost no spacing between panels.
fig, ax = plt.subplots(n_sample, n_sample, figsize=(30, 30))
side = data_load.w_size
for idx, panel in enumerate(ax.flat):
    pattern = generated[idx].squeeze().data.cpu().reshape(side, side)
    panel.imshow(pattern, cmap="jet")
    panel.axis("off")
plt.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()

In [None]:
# 2D subspace
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    ax.set_xlabel("loading vector %d"%(c1+1), fontsize=15)
    ax.set_ylabel("loading vector %d"%(c2+1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.canvas.draw()
    fig.tight_layout()

x_widget = pyw.IntSlider(min=0, max=z_dim-1, step=1, value=1)
y_widget = pyw.IntSlider(min=0, max=z_dim-1, step=1, value=2)

pyw.interact(projection, c1=x_widget, c2=y_widget)
plt.show()