In [None]:
import sys
sys.path.append("/home/jinseuk56/Desktop/github_repo/AEs/VAE/")
from AEs_module import *
import time
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import ipywidgets as pyw
import tkinter.filedialog as tkf
import tifffile
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
# Ask the user to pick one or more input files via a Tk file dialog.
# A cancelled dialog returns an empty selection; stop here with a clear message
# instead of letting load_data fail later on an empty file list.
file_adr = tkf.askopenfilenames()
if not file_adr:
    raise ValueError("No input file was selected.")
print(*file_adr, sep="\n")

In [None]:
# Load the selected files with the project loader (AEs_module.load_data).
# NOTE(review): dat_dim=4 presumably means a 4-D (scan x scan x detector x detector) dataset — confirm.
load_options = {"dat_dim": 4, "dat_unit": '1/nm', "rescale": False}
data_load = load_data(file_adr, **load_options)

In [None]:
# Run the loader's center-finding step with visual feedback on a log intensity scale.
center_options = {
    "cbox_edge": 7,
    "center_remove": 0,
    "result_visual": True,
    "log_scale": True,
}
data_load.find_center(**center_options)

In [None]:
# Build the network inputs. w_size=32 determines the image side used later
# (data_load.w_size*2 in the coordinate grid and decoder).
input_options = dict(
    min_val=1E-6,
    max_normalize=True,
    log_scale=False,
    radial_flat=False,
    w_size=32,
    radial_range=None,
    final_dim=1,
)
data_load.make_input(**input_options)

In [None]:
# Select the compute device: the first CUDA GPU when available, otherwise the CPU.
# The variable keeps the name `cuda_device` because later cells refer to it.
if torch.cuda.is_available():
    print("%d gpu available"%(torch.cuda.device_count()))
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))
else:
    # An explicit CPU device (rather than None) gives every later
    # `.to(cuda_device)` call a well-defined target.
    cuda_device = torch.device("cpu")
    print("no gpu available; using cpu")

In [None]:
# Encoder configuration: latent size, hidden layer widths, and which spatial
# transformations (rotation / translation) the encoder is asked to infer.
rotation_check = True
angle_std = np.pi/4  # passed as ang_std to ivVAE_KLD in the training cell
translation_check = False
translation_std = 0.1

parallel_ = True

num_comp = 2
enc_hid_dim = [512]

enc_model = ivVAEFCNN_encoder(data_load.s_length, enc_hid_dim, num_comp,
                              rotation_check, translation_check, translation_std)

if parallel_:
    enc_model = nn.DataParallel(enc_model)

# .to() accepts both CUDA and CPU devices, whereas .cuda() requires a GPU.
enc_model.to(cuda_device)
print(enc_model)

In [None]:
# Split the prepared inputs into consecutive mini-batches of at most batch_size
# samples; the final mini-batch holds whatever remains.
batch_size = 510
n_inputs = len(data_load.dataset_input)
mini_batches = []
for b_start in range(0, n_inputs, batch_size):
    mini_batches.append(data_load.dataset_input[b_start:b_start + batch_size])
print(len(mini_batches))
print(len(mini_batches[-1]))

In [None]:
# Pixel coordinates of a (w_size*2) x (w_size*2) image on [-1, 1], flattened to
# one (x, y) row per pixel, as a float32 tensor without gradient tracking.
img_side_len = data_load.w_size*2
grid = np.linspace(1, -1, img_side_len)
X, Y = np.meshgrid(grid, grid)
img_coord = torch.from_numpy(np.column_stack((X.ravel(), Y.ravel()))).to(torch.float32)
img_coord.requires_grad_(False)
n_dim = img_coord.size(1)
print(img_coord.shape)

In [None]:
# Broadcast the single coordinate grid over a full mini-batch (expand creates a
# view, not a copy) and move it to the compute device.
n_pixels, n_axes = img_coord.shape
coord = img_coord.expand(batch_size, n_pixels, n_axes).requires_grad_(requires_grad=False).to(cuda_device)
n_coord = coord.size(1)
print(n_coord)
print(coord.shape)

In [None]:
# Coordinate-based decoder; the training cell calls it as dec_model(coords, z)
# and reshapes its output to (w_size*2, w_size*2) images.
hid_dim = 512
num_hid = 1
dec_model = ivVAEFCNN_decoder(n_coord, n_dim, num_comp, hid_dim, num_hid, data_load.w_size*2, bi_lin=True)

if parallel_:
    dec_model = nn.DataParallel(dec_model)

# .to() accepts both CUDA and CPU devices, whereas .cuda() requires a GPU.
dec_model.to(cuda_device)
print(dec_model)

In [None]:
# One Adam optimizer over the encoder and decoder parameters; glob_iter counts
# epochs across repeated runs of the training cell.
glob_iter = 0
params = [*enc_model.parameters(), *dec_model.parameters()]
optimizer = optim.Adam(params)

In [None]:
# Train the encoder/decoder for n_epoch epochs. Every `report_every` epochs (and
# always after the final epoch) plot the loss curve, print a loss table, and show
# reconstructions plus latent / transformation maps. The *_coeffs_reshape maps
# left behind by the last report are used by later cells.
start = time.time()
n_fig = 5
n_epoch = 100

beta = 4.0
gamma = 1000
C_max = torch.Tensor([25.0]).to(cuda_device)
C_stop_iter = int(n_epoch/3)

l_rate = 0.001
optimizer.param_groups[0]['lr'] = l_rate

# max(1, ...) prevents a modulo-by-zero when n_epoch < 10.
report_every = max(1, n_epoch // 10)
img_side = data_load.w_size*2


def _scatter_to_maps(values):
    """Reorder per-sample values by data_load.ri and reshape them into maps.

    Row k of ``values`` is written to row ``data_load.ri[k]`` of a zero array of
    the same shape (presumably undoing the loader's sample shuffle — TODO confirm
    in AEs_module). Returns ``(coeffs, maps)``, where ``coeffs`` is that array and
    ``maps = reshape_coeff(coeffs, data_load.data_shape)``, indexed below as
    ``maps[image][:, :, feature]``.
    """
    coeffs = np.zeros_like(values)
    coeffs[data_load.ri] = values
    return coeffs, reshape_coeff(coeffs, data_load.data_shape)


# The generation cell calls dec_model.eval(); re-running this cell afterwards
# must put both models back into training mode.
enc_model.train()
dec_model.train()

loss_plot = []
for epoch in range(n_epoch):
    glob_iter += 1
    loss_epoch = 0
    recon_loss = 0
    KLD_loss = 0

    latent_z = []
    z_mu = []
    z_logvar = []
    mu_rot = []
    logvar_rot = []
    mu_trans = []
    logvar_trans = []
    for m_batch in mini_batches:
        # Out-of-place clamp: torch.from_numpy shares memory with m_batch (a slice
        # of data_load.dataset_input), so clamp_ would silently modify the dataset.
        x = torch.from_numpy(m_batch).clamp(min=0.001, max=0.999)
        x = x.to(torch.float32)
        x = x.to(cuda_device)

        # The last mini-batch can be shorter than batch_size; slice the expanded
        # coordinate grid so both inputs share the same batch dimension.
        batch_coord = coord[:x.size(0)]
        tf_coord, mu, logvar, z, rot_mu, rot_logvar, rot_z, trans_mu, trans_logvar, trans_z = enc_model(x, batch_coord)
        x_ = dec_model(tf_coord.contiguous(), z)

        reconstruction_error = reconstruction_loss(x_.squeeze(), x, mean=False, loss_fn="BCE")
        KL_divergence = ivVAE_KLD(mu, logvar, rot_mu, rot_logvar, ang_std=angle_std, mean=False, mode="normal",
                                  beta=beta, gamma=gamma, C_max=C_max, C_stop_iter=C_stop_iter, glob_iter=glob_iter)

        loss = reconstruction_error + KL_divergence

        loss_epoch += loss.item()
        recon_loss += reconstruction_error.item()
        KLD_loss += KL_divergence.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        latent_z.extend(z.detach().cpu().numpy().tolist())
        z_mu.extend(mu.detach().cpu().numpy().tolist())
        z_logvar.extend(logvar.detach().cpu().numpy().tolist())
        if rotation_check:
            mu_rot.extend(rot_mu.detach().cpu().numpy().tolist())
            logvar_rot.extend(rot_logvar.detach().cpu().numpy().tolist())

        if translation_check:
            mu_trans.extend(trans_mu.detach().cpu().numpy().tolist())
            logvar_trans.extend(trans_logvar.detach().cpu().numpy().tolist())

    loss_plot.append(loss_epoch/data_load.total_num)

    latent_z = np.asarray(latent_z)
    z_mu = np.asarray(z_mu)
    z_logvar = np.asarray(z_logvar)
    mu_rot = np.asarray(mu_rot).reshape(-1, 1)
    logvar_rot = np.asarray(logvar_rot).reshape(-1, 1)
    mu_trans = np.asarray(mu_trans)
    logvar_trans = np.asarray(logvar_trans)

    # memory_summary is CUDA-only.
    if epoch == 0 and torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=cuda_device))

    # Report periodically, and always after the last epoch so that the maps
    # used by later cells reflect the final model.
    if (epoch+1) % report_every != 0 and epoch != n_epoch - 1:
        continue

    fig, ax = plt.subplots(1, 1)
    ax.plot(np.arange(epoch+1)+1, loss_plot, "k-")
    ax.set(xlabel="epoch", ylabel="loss per sample", title="training loss")
    ax.grid()
    plt.show()

    print(tabulate([
                    ["epoch", epoch+1],
                    ["total loss", loss_epoch/data_load.total_num],
                    ["reconstruction error", recon_loss/data_load.total_num],
                    ["KL divergence", KLD_loss/data_load.total_num],
                    # Epoch-level ratio of summed reconstruction error to summed KL term.
                    ["error ratio", recon_loss/KLD_loss if KLD_loss != 0 else float("nan")],
                    ["iteration ratio", C_stop_iter/glob_iter]
                    ]))
    print("%.2f minutes have passed"%((time.time()-start)/60))

    # Inputs (top row) and reconstructions (bottom row) from the last mini-batch;
    # never index beyond that mini-batch's size.
    n_show = min(n_fig, x.size(0))
    fig, ax = plt.subplots(2, n_show, figsize=(5*n_show, 5*2), squeeze=False)
    for i in range(n_show):
        ax[0][i].imshow(x[i].detach().cpu().numpy().astype(np.float32).reshape(img_side, img_side), cmap="inferno")
        ax[1][i].imshow(x_[i].detach().cpu().numpy().astype(np.float32).reshape(img_side, img_side), cmap="inferno")
    fig.tight_layout()
    plt.show()

    # Histograms of z, mu and log(var) per latent component.
    fig, ax = plt.subplots(1, 3, figsize=(15, 5))
    latent_sets = [
        ("latent z distribution", latent_z),
        ("z mu distribution", z_mu),
        ("z log(var) distribution", z_logvar),
    ]
    latent_maps = []
    for panel, (title, values) in zip(ax, latent_sets):
        coeffs, maps = _scatter_to_maps(values)
        latent_maps.append(maps)
        panel.set_title(title)
        for j in range(num_comp):
            panel.hist(coeffs[:, j], bins=50, alpha=(1.0-j*(1/num_comp)))
    plt.show()
    latent_z_coeffs_reshape, z_mu_coeffs_reshape, z_logvar_coeffs_reshape = latent_maps

    # Per-image maps: rows are z, mu, log(var); columns are latent components.
    # squeeze=False keeps ax 2-D even when num_comp == 1.
    for i in range(data_load.num_img):
        fig, ax = plt.subplots(3, num_comp, figsize=(5*num_comp, 15), squeeze=False)
        for row, maps in enumerate(latent_maps):
            for j in range(num_comp):
                ax[row][j].imshow(maps[i][:, :, j], cmap="inferno")
                ax[row][j].axis("off")
        plt.show()

    if rotation_check:
        # Computed once per report, then shown for each image.
        rot_mu_coeffs, rot_mu_coeffs_reshape = _scatter_to_maps(mu_rot)
        rot_logvar_coeffs, rot_logvar_coeffs_reshape = _scatter_to_maps(logvar_rot)
        for i in range(data_load.num_img):
            fig, ax = plt.subplots(1, 3, figsize=(15, 5))
            ax[0].hist(rot_mu_coeffs[:, 0], bins=50, alpha=0.5)
            ax[0].hist(rot_logvar_coeffs[:, 0], bins=50, alpha=0.5)
            ax[0].set_title("rotation mu and log(var) distribution")
            ax[1].imshow(rot_mu_coeffs_reshape[i][:, :, 0], cmap="inferno")
            ax[1].axis("off")
            ax[2].imshow(rot_logvar_coeffs_reshape[i][:, :, 0], cmap="inferno")
            ax[2].axis("off")
            plt.show()

    if translation_check:
        # Computed once per report, then shown for each image.
        trans_mu_coeffs, trans_mu_coeffs_reshape = _scatter_to_maps(mu_trans)
        trans_logvar_coeffs, trans_logvar_coeffs_reshape = _scatter_to_maps(logvar_trans)
        for i in range(data_load.num_img):
            fig, ax = plt.subplots(2, 3, figsize=(15, 10))
            ax[0][0].hist(trans_mu_coeffs[:, 0], bins=50, alpha=0.5)
            ax[0][0].hist(trans_mu_coeffs[:, 1], bins=50, alpha=0.5)
            ax[0][0].set_title("translation mu distribution y and x")
            ax[1][0].hist(trans_logvar_coeffs[:, 0], bins=50, alpha=0.5)
            ax[1][0].hist(trans_logvar_coeffs[:, 1], bins=50, alpha=0.5)
            ax[1][0].set_title("translation log(var) distribution y and x")
            for j in range(2):
                ax[0][j+1].imshow(trans_mu_coeffs_reshape[i][:, :, j], cmap="inferno")
                ax[0][j+1].axis("off")
                ax[1][j+1].imshow(trans_logvar_coeffs_reshape[i][:, :, j], cmap="inferno")
                ax[1][j+1].axis("off")
            plt.show()

print("The training has been finished.")

In [None]:
# Show the z-mu map of every latent component for each image, optionally saving
# each map as a TIFF file in the working directory.
save_result = False

for i in range(data_load.num_img):
    # squeeze=False keeps ax 2-D so ax[0][j] also works when num_comp == 1.
    fig, ax = plt.subplots(1, num_comp, figsize=(5*num_comp, 5), squeeze=False)
    for j in range(num_comp):
        ax[0][j].imshow(z_mu_coeffs_reshape[i][:, :, j], cmap="inferno")
        ax[0][j].axis("off")
        if save_result:
            # With several images, include the image index so later images do
            # not overwrite earlier files; a single image keeps the old name.
            if data_load.num_img > 1:
                out_name = "img%d_z_mu_comp_%d.tif"%(i+1, j+1)
            else:
                out_name = "z_mu_comp_%d.tif"%(j+1)
            tifffile.imwrite(out_name, z_mu_coeffs_reshape[i][:, :, j])
    plt.show()

In [None]:
# Build an n_sample x n_sample grid of 2-D latent codes for visualizing the
# decoder. Per-axis values are drawn without replacement from a fine linspace on
# [-sigma, sigma], weighted by the standard-normal pdf, then sorted.
n_sample = 20
sigma = 5.0

# The grid spans exactly two latent dimensions, so the decoder must expect two.
assert num_comp == 2, "latent-grid visualization requires num_comp == 2 (got %d)" % num_comp

z_candidates = np.linspace(-sigma, sigma, n_sample*10, endpoint=True)
rv = stats.norm(0, 1)
norm_pdf = rv.pdf(z_candidates)
norm_pdf = norm_pdf / np.sum(norm_pdf)
z_axis = np.sort(np.random.choice(z_candidates, n_sample, replace=False, p=norm_pdf))
z_grid = np.meshgrid(z_axis, z_axis)
z_test_np = np.stack((z_grid[0].flatten(), z_grid[1].flatten()), axis=1)
print(z_test_np.shape)
z_test = torch.from_numpy(z_test_np).to(torch.float32).to(cuda_device)

# One copy of the pixel-coordinate grid per latent code.
coord_test = img_coord.expand(n_sample**2, img_coord.size(0), img_coord.size(1))
coord_test = coord_test.to(cuda_device)
print(coord_test.shape)

In [None]:
# Decode every latent code in z_test into an image. This is inference only, so
# autograd is disabled to avoid holding a computation graph for n_sample**2 images.
# NOTE: dec_model is left in eval mode after this cell.
dec_model.eval()
with torch.no_grad():
    generated = dec_model(coord_test.contiguous(), z_test)
print(generated.shape)
generated = generated.reshape(n_sample**2, data_load.w_size*2, data_load.w_size*2)
print(generated.shape)

In [None]:
# Tile the decoded images in an n_sample x n_sample grid with a jet colormap.
fig, ax = plt.subplots(n_sample, n_sample, figsize=(30, 30))
gen_images = generated.data.cpu().numpy().astype(np.float32)
for panel, gen_img in zip(ax.flat, gen_images):
    panel.imshow(gen_img.squeeze(), cmap="jet")
    panel.axis("off")
plt.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()