In [None]:
import sys
sys.path.append("/home/jinseuk56/Desktop/github_repo/AEs/VAE/")
from AEs_module import *
import time
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import tkinter.filedialog as tkf
import tifffile
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
# Let the user pick one or more input files via a Tk dialog, then echo the
# chosen paths, one per line.
file_adr = tkf.askopenfilenames()
print("\n".join(file_adr))

In [None]:
# Load the selected files with the project's loader (from AEs_module).
# NOTE(review): `dat_dim` here is a load_data option (presumably data rank) and is
# unrelated to the `dat_dim` list built later in this notebook — confirm.
data_load = load_data(
    file_adr,
    dat_dim=4,
    dat_unit='1/nm',
    rescale=False,
)

In [None]:
# Locate the pattern center and show the result on a log scale.
# NOTE(review): the exact meaning of cbox_edge / center_remove is defined in
# AEs_module.find_center — confirm there.
data_load.find_center(
    cbox_edge=10,
    center_remove=0,
    result_visual=True,
    log_scale=True,
)

In [None]:
# Build the network input. Later cells read data_load.dataset_input (the samples)
# and data_load.w_size (the window half-size used to size the network).
data_load.make_input(
    min_val=1E-6,
    max_normalize=True,
    log_scale=False,
    radial_flat=False,
    w_size=32,
    radial_range=None,
    final_dim=2,
)

In [None]:
# Select the first CUDA device if one exists and report its name and memory;
# otherwise leave `cuda_device` as None.
if not torch.cuda.is_available():
    cuda_device = None
else:
    print("%d gpu available" % torch.cuda.device_count())
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))

In [None]:
# Encoder architecture: one entry per convolutional stage in each list below.
num_comp = 2  # number of latent components (decoder input size)
channels = [32, 64, 128, 256]
kernels = [4, 4, 4, 4]
padding = [1, 1, 1, 1]
stride = [2, 2, 2, 2]
pooling = [1, 1, 1, 1]

# The per-stage lists are zipped below, so they must line up one-to-one.
assert len(kernels) == len(padding) == len(stride) == len(pooling), \
    "kernels, padding, stride and pooling must have the same length"

# Spatial size after each stage, using the Conv2d output-size formula
#   out = floor((in + 2*padding - kernel) / stride) + 1      (dilation = 1)
# followed by an integer division by `pooling`.
# NOTE(review): assumes each pooling stage shrinks the size by a factor of
# `pooling` (kernel = stride = pooling) — confirm in CAE2D_encoder.
# NOTE(review): the input side length is taken as 2*w_size, matching the
# decoder output size (2*w_size)**2 used for dec_model.
# Integer floor division at every stage matches PyTorch's rounding exactly,
# instead of carrying a float and truncating only when storing.
dat_dim = []
tmp_dim = data_load.w_size*2
for stage, (k, p, s, pool) in enumerate(zip(kernels, padding, stride, pooling)):
    tmp_dim = (tmp_dim + 2*p - k) // s + 1
    tmp_dim //= pool
    if tmp_dim < 1:
        raise ValueError("encoder stage %d produces a non-positive spatial size (%d)"
                         % (stage, tmp_dim))
    dat_dim.append(tmp_dim)

print(dat_dim)
print(kernels)
print(channels)
print(padding)
print(stride)
print(pooling)

In [None]:
# Build the convolutional encoder from the architecture lists. dat_dim[-1] is the
# spatial size after the last stage. Optionally wrap it in nn.DataParallel, then
# move it to the selected CUDA device.
parallel_ = False  # True -> wrap the models in nn.DataParallel

enc_model = CAE2D_encoder(dat_dim[-1], channels, kernels, stride, padding, pooling)
enc_model = nn.DataParallel(enc_model) if parallel_ else enc_model
enc_model.cuda(cuda_device)
print(enc_model)

In [None]:
# Sizes produced by a stack of transposed convolutions, starting from the
# encoder's final spatial size:
#   out = (in - 1)*stride + kernel - 2*padding + output_padding   (dilation = 1)
# NOTE(review): dec_model in this notebook is built with linFE_decoder, which is
# not given these lists; this cell only reports the sizes.
dec_kernel = [4, 4, 4, 4]
dec_stride = [2, 2, 2, 2]
dec_padding = [1, 1, 1, 1]
dec_outpad = [0, 0, 1, 1]

dec_dim = []
enc_dim = dat_dim[-1]
for k, s, p, op in zip(dec_kernel, dec_stride, dec_padding, dec_outpad):
    enc_dim = (enc_dim - 1)*s + k - 2*p + op
    dec_dim.append(enc_dim)

print(dec_dim)

# Reverse every list in place after the sizes were computed in forward order.
# NOTE(review): presumably the order some conv-decoder constructor expects — confirm.
for lst in (dec_kernel, dec_stride, dec_padding, dec_outpad):
    lst.reverse()

# Size remaining after one final unpadded, stride-1 convolution of width final_kernel.
final_kernel = 4
print(dec_dim[-1] - final_kernel + 1)

In [None]:
# Decoder from num_comp latent values to (2*w_size)**2 outputs per sample.
# NOTE(review): presumably decoder[0] is a linear layer whose weight columns are
# the loading vectors (the training cell clamps that weight) — confirm in AEs_module.
dec_model = linFE_decoder(num_comp, (2*data_load.w_size)**2)
dec_model = nn.DataParallel(dec_model) if parallel_ else dec_model
dec_model.cuda(cuda_device)
print(dec_model)

In [None]:
# Split the inputs into consecutive chunks of at most batch_size samples.
# The last chunk may be smaller; samples are not shuffled here.
batch_size = 510
n_samples = len(data_load.dataset_input)
mini_batches = [
    data_load.dataset_input[lo:lo + batch_size]
    for lo in range(0, n_samples, batch_size)
]
print(len(mini_batches))
print(len(mini_batches[-1]))

In [None]:
# Optimize encoder and decoder jointly with Adam (PyTorch default hyperparameters).
params = [*enc_model.parameters(), *dec_model.parameters()]
optimizer = optim.Adam(params)

In [None]:
# Train the encoder and decoder jointly with a binary cross-entropy reconstruction
# loss. After every optimizer step the decoder weight is clamped to be
# non-negative, so its columns (the loading vectors) stay >= 0. The latent
# coefficients from the final epoch are collected into `ae_coeffs`.
# NOTE(review): mini-batches are visited in the same fixed order every epoch.
if not mini_batches:
    raise ValueError("mini_batches is empty; nothing to train on")

start = time.time()
n_epoch = 200
# Report about 10 times per run; max() avoids a modulo by zero when n_epoch < 10.
log_every = max(1, n_epoch // 10)
ae_coeffs = []
ae_bias = []  # currently unused

# Unwrap nn.DataParallel once; later cells also read `model_access`.
model_access = dec_model.module if parallel_ else dec_model

for epoch in range(n_epoch):
    for m_batch in mini_batches:
        x = torch.from_numpy(m_batch).to(torch.float32).to(cuda_device)

        encoded = enc_model(x)
        decoded = dec_model(encoded)

        # Give the flat reconstruction the input's own shape so the loss compares
        # pixel-for-pixel. The decoder is built with (2*w_size)**2 outputs, so a
        # fixed (w_size, w_size) view would not match the input.
        main_loss = F.binary_cross_entropy(decoded.view_as(x), x, reduction="mean")

        optimizer.zero_grad()
        main_loss.backward()
        optimizer.step()

        # Projection step: keep the decoder weight (loading vectors) non-negative.
        model_access.decoder[0].weight.data.clamp_(min=0.0)

        if epoch == n_epoch - 1:
            # Coefficients of this batch, computed before this step's update.
            ae_coeffs.extend(encoded.detach().cpu().numpy().tolist())

    # The memory report only exists for CUDA devices.
    if epoch == 0 and cuda_device is not None:
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch + 1) % log_every == 0:
        print(tabulate([
                        ["epoch", epoch+1],
                        ["main loss", main_loss.item()],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # Show each decoder weight column as an image shaped like one input sample.
        # NOTE(review): assumes the input's last two axes are the image axes.
        comp_weights = model_access.decoder[0].weight.data.cpu()
        img_shape = x.shape[-2:]
        # squeeze=False keeps `ax` 2-D, so num_comp == 1 also works.
        fig, ax = plt.subplots(1, num_comp, figsize=(5*num_comp, 5), squeeze=False)
        for c in range(num_comp):
            ax[0, c].imshow(comp_weights[:, c].reshape(img_shape), cmap="viridis")
            ax[0, c].axis("off")
        fig.tight_layout()
        plt.show()

print("The training has been finished.")

In [None]:
# Final-epoch latent coefficients (one row per sample) and the decoder weight
# transposed so that each row is one loading vector.
ae_coeffs = np.array(ae_coeffs)
ae_comp_vectors = np.transpose(model_access.decoder[0].weight.data.cpu().numpy())
print(ae_coeffs.shape)
print(ae_comp_vectors.shape)

# Scatter the coefficient rows back to positions data_load.ri, then reshape
# them into per-image coefficient maps.
# NOTE(review): presumably ri is the index order used when building the
# training input, so this restores the original sample order — confirm.
coeffs = np.zeros(ae_coeffs.shape, dtype=ae_coeffs.dtype)
coeffs[data_load.ri] = ae_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_load.data_shape)

In [None]:
# visualize loading vectors
# Show each learned loading vector (a row of ae_comp_vectors) as a square image.
# The side length is derived from the vector length, so it always matches the
# decoder's output size. That size is (2*w_size)**2, so a (w_size, w_size)
# reshape cannot hold it.
vec_len = ae_comp_vectors.shape[1]
side = int(round(np.sqrt(vec_len)))
assert side * side == vec_len, "loading vectors of length %d are not square images" % vec_len

for i in range(num_comp):
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(ae_comp_vectors[i].reshape(side, side), cmap="viridis")
    ax.set_title("loading vector %d" % (i+1), fontsize=10)
    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# visualize the coefficient maps
# For each component, show its coefficient map for every loaded image side by side.
# coeffs_reshape[j][:, :, i] is image j's map for component i.
# squeeze=False keeps `ax` 2-D, so one code path handles one or many images.
n_img = data_load.num_img
for i in range(num_comp):
    fig, ax = plt.subplots(1, n_img, figsize=(5*n_img, 5), squeeze=False)
    for j in range(n_img):
        ax[0, j].imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        ax[0, j].set_title("loading vector %d map"%(i+1), fontsize=10)
        ax[0, j].axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# 2D subspace
# Interactive 2-D projection: scatter the coefficients of one component
# against another.
# NOTE(review): redrawing this existing figure from the widget callback only
# refreshes the display with an interactive backend (e.g. %matplotlib widget).
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    """Redraw the scatter of component ``c1`` (x-axis) against ``c2`` (y-axis).

    c1, c2: zero-based column indices into ``coeffs``; labels show them 1-based.
    """
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    ax.set_xlabel("loading vector %d"%(c1+1), fontsize=15)
    ax.set_ylabel("loading vector %d"%(c2+1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.tight_layout()        # lay out first, then request the redraw
    fig.canvas.draw_idle()

# Start on components 1 vs 2 (0-based 0 and 1). Slider values must lie in
# [0, num_comp-1]; anything larger is clamped by the slider.
x_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=0)
y_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=min(1, num_comp-1))

pyw.interact(projection, c1=x_widget, c2=y_widget)
plt.show()