In [None]:
from VAEs_module import *
import time
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import ipywidgets as pyw
import tkinter.filedialog as tkf
import tifffile
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA
import torchvision
import torchvision.transforms as transforms

In [None]:
# Interactively choose the input file(s); echo one selected path per line.
# NOTE: this depends on a manual dialog, so "Restart & Run All" is not
# reproducible without the same selection.
file_adr = tkf.askopenfilenames()
print("\n".join(file_adr))

In [None]:
# Load the selected file(s) with the project helper `load_data`
# (presumably provided by the `VAEs_module` star import — confirm there).
data_load = load_data(
    file_adr,
    dat_dim=4,
    dat_unit='nm',
    rescale=False,
)

In [None]:
# Force the first loaded dataset into a 4-D stack of shape (1, 121, 64, 64).
# NOTE(review): 121 frames of 64x64 pixels are hard-coded for this particular
# input file — adapt these numbers for any other dataset.
data_load.data_storage[0] = data_load.data_storage[0].reshape(1, 121, 64, 64)
# Keep the recorded shape in sync with the leading two axes of the new stack
# (presumably (1, 121), assuming data_storage is a list — confirm).
data_load.data_shape[0] = data_load.data_storage[0].shape[:2]

In [None]:
# Run the project's center search on the loaded data and display the result
# on a log scale. Parameter semantics live in VAEs_module (confirm there).
data_load.find_center(
    cbox_edge=32,
    center_remove=0,
    result_visual=True,
    log_scale=True,
)

In [None]:
# Manually pin the center of the first dataset to a fixed position, presumably
# replacing whatever find_center computed in the previous cell.
# NOTE(review): [32.0, 32.0] looks like the middle of the 64x64 frames used
# above — confirm the coordinate ordering expected by make_input.
data_load.center_pos[0] = [32.0, 32.0]

In [None]:
# Build the network input images from the loaded data (project helper).
# w_size=32 presumably yields 2*w_size = 64-pixel-wide images, matching later
# cells that use data_load.w_size*2 as the image side — confirm in VAEs_module.
data_load.make_input(
    min_val=1E-6,
    max_normalize=True,
    log_scale=False,
    radial_flat=False,
    w_size=32,
    radial_range=None,
    final_dim=2,
)

In [None]:
# Random augmentation pipeline: numpy image -> PIL -> random rotation within
# ±30° plus translation of up to 5% of the image size per axis (torchvision
# RandomAffine semantics) -> back to a tensor.
rand_affine = transforms.Compose(
    [
        transforms.ToPILImage(),
        transforms.RandomAffine(degrees=(-30, 30), translate=(0.05, 0.05)),
        transforms.ToTensor(),
    ]
)

In [None]:
# Data augmentation: append `aug_times` randomly rotated/translated copies of
# every input image to the original set, then preview one image.
num_imgs = len(data_load.dataset_flat)
aug_times = 9

# Convert each image to float32 once, instead of on every augmentation pass
# (the float32 cast is presumably needed for ToPILImage's 2-D float mode 'F').
source_f32 = [img.astype(np.float32) for img in data_load.dataset_flat]

# Same draw order as a nested loop: all images for pass 0, then pass 1, ...
tmp = np.asarray([
    rand_affine(img).squeeze().numpy()
    for _ in range(aug_times)
    for img in source_f32
])
print(tmp.shape)
dataset_flat = np.concatenate([data_load.dataset_flat, tmp], axis=0)
print(dataset_flat.shape)

# Index 150 falls in the augmented part when num_imgs is 121; clamp it so the
# preview still works for smaller datasets instead of raising IndexError.
preview_idx = min(150, len(dataset_flat) - 1)
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(dataset_flat[preview_idx], cmap="viridis")
ax.axis("off")
plt.show()

In [None]:
# Shuffle all (original + augmented) images once, and flatten each one into a
# row vector for the fully-connected encoder.
# NOTE: no random seed is set in this notebook, so the order differs per run.
total_num = len(dataset_flat)
ri = np.random.permutation(total_num)  # random ordering of every sample index
dataset_input = dataset_flat[ri].reshape(total_num, -1)

In [None]:
# Select the compute device: the first CUDA GPU if one exists (and report on
# it), otherwise None.
# NOTE(review): later cells call .cuda(cuda_device), so the notebook effectively
# still requires a GPU even though this cell tolerates its absence.
cuda_device = None
if torch.cuda.is_available():
    print("%d gpu available" % torch.cuda.device_count())
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))

In [None]:
# Encoder configuration and construction.
rotation_check = True
translation_check = True
angle_std = np.pi/4       # std of the rotation prior used in the training loss
translation_std = 0.05

parallel_ = True          # wrap the model in nn.DataParallel

num_comp = 2              # number of latent components
enc_hid_dim = [256]       # hidden layer widths of the encoder

enc_model = ivVAEFCNN_encoder(data_load.s_length, enc_hid_dim, num_comp,
                              rotation_check, translation_check, translation_std)
enc_model = nn.DataParallel(enc_model) if parallel_ else enc_model

enc_model.cuda(cuda_device)
print(enc_model)

In [None]:
# Split the shuffled inputs into consecutive mini-batches of batch_size rows;
# the last one may be shorter.
# NOTE(review): 1210 presumably equals the whole augmented set (121 x 10), so
# training runs as a single batch — confirm if the data size changes.
batch_size = 1210
mini_batches = [
    dataset_input[start:start + batch_size]
    for start in range(0, len(dataset_input), batch_size)
]
print(len(mini_batches))
print(len(mini_batches[-1]))

In [None]:
# Normalized pixel-coordinate grid for a (2*w_size) x (2*w_size) image: both axes
# run from 1 down to -1, and img_coord holds one (X, Y) row per pixel.
grid = np.linspace(1, -1, data_load.w_size*2)
X, Y = np.meshgrid(grid, grid)
img_coord = (
    torch.from_numpy(np.column_stack((X.ravel(), Y.ravel())))
    .to(torch.float32)
    .requires_grad_(False)
)
n_dim = img_coord.size(1)  # coordinates per pixel (2)
print(img_coord.shape)

In [None]:
# Broadcast the pixel grid to one (n_coord, 2) copy per sample in a full batch
# (a view via expand, no copy), then move it to the compute device.
coord = (
    img_coord.expand(batch_size, *img_coord.shape)
    .requires_grad_(False)
    .to(cuda_device)
)
n_coord = coord.size(1)  # number of pixels per image
print(n_coord)
print(coord.shape)

In [None]:
# Decoder configuration and construction: one hidden layer of width 256,
# producing images of side data_load.w_size*2.
hid_dim, num_hid = 256, 1
dec_model = ivVAEFCNN_decoder(
    n_coord, n_dim, num_comp, hid_dim, num_hid,
    data_load.w_size*2, bi_lin=True,
)
dec_model = nn.DataParallel(dec_model) if parallel_ else dec_model

dec_model.cuda(cuda_device)
print(dec_model)

In [None]:
# Training schedule, plus a single Adam optimizer over encoder and decoder weights.
n_epoch = 2000
l_rate = 1e-3
params = [*enc_model.parameters(), *dec_model.parameters()]
optimizer = optim.Adam(params, lr=l_rate)

In [None]:
# Jointly train encoder + decoder for n_epoch epochs over mini_batches.
# Each epoch accumulates the loss terms and collects the per-sample latent
# outputs (latent_z, z_mu, z_logvar, rot_theta, trans_delta) as numpy arrays.
# Every ~10% of training it plots the loss curve, prints a loss table, and
# shows inputs vs. reconstructions from the last mini-batch.
start = time.time()
loss_plot = []
n_fig = 5
img_side = data_load.w_size*2
# Report every 10% of training; max(1, ...) avoids a zero interval (and a
# ZeroDivisionError in the modulo below) when n_epoch < 10.
log_every = max(1, n_epoch // 10)
for epoch in range(n_epoch):
    loss_epoch = 0
    recon_loss = 0
    KLD_loss = 0
    KLD_rot_loss = 0

    latent_z = []
    z_mu = []
    z_logvar = []
    rot_theta = []
    trans_delta = []
    for i, m_batch in enumerate(mini_batches):

        # Out-of-place clamp: torch.from_numpy shares memory with m_batch (a view
        # of dataset_input), so an in-place clamp_ would rewrite the data array.
        # Clamping keeps targets strictly inside (0, 1) for binary cross-entropy.
        x = torch.from_numpy(m_batch).clamp(min=0.001, max=0.999)
        x = x.to(torch.float32)
        x = x.to(cuda_device)
        n_batch = x.size(0)

        # coord was expanded for a full batch_size; slice it so a shorter final
        # mini-batch still lines up with x.
        tf_coord, mu, logvar, z, rot_mu, rot_logvar, rot_z, trans_z = enc_model(x, coord[:n_batch])
        x_ = dec_model(tf_coord.contiguous(), z)

        reconstruction_error = F.binary_cross_entropy(x_, x, reduction="sum")
        # KL(N(mu, exp(logvar)) || N(0, 1)), summed over samples and latent dims.
        KL_divergence = -0.5*torch.sum(1+logvar-mu**2-logvar.exp())
        # Equals KL(N(rot_mu, exp(rot_logvar)**2) || N(0, angle_std**2)), i.e. it
        # treats rot_logvar as log(std), not log(variance).
        # NOTE(review): confirm this matches how the encoder samples rot_z.
        KL_divergence_rot = torch.sum(-rot_logvar + np.log(angle_std) + (torch.exp(rot_logvar)**2 +
                                                              rot_mu**2)/2/angle_std**2 - 0.5)

        loss = reconstruction_error + KL_divergence + KL_divergence_rot

        loss_epoch += loss.item()
        recon_loss += reconstruction_error.item()
        KLD_loss += KL_divergence.item()
        KLD_rot_loss += KL_divergence_rot.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        latent_z.extend(z.data.cpu().numpy().tolist())
        z_mu.extend(mu.data.cpu().numpy().tolist())
        z_logvar.extend(logvar.data.cpu().numpy().tolist())
        rot_theta.extend(rot_z.data.cpu().numpy().tolist())
        if translation_check:
            trans_delta.extend(trans_z.data.cpu().numpy().tolist())

    # Per-sample average of the summed epoch loss.
    loss_plot.append(loss_epoch/total_num)

    latent_z = np.asarray(latent_z)
    z_mu = np.asarray(z_mu)
    z_logvar = np.asarray(z_logvar)
    rot_theta = np.asarray(rot_theta)
    trans_delta = np.asarray(trans_delta)

    if epoch == 0:
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % log_every == 0:
        fig, ax = plt.subplots(1, 1)
        ax.plot(np.arange(epoch+1)+1, loss_plot, "k-")
        ax.grid()
        plt.show()

        print(tabulate([
                        ["epoch", epoch+1],
                        ["total loss", loss_epoch/total_num],
                        ["reconstruction error", recon_loss/total_num],
                        ["KL divergence", KLD_loss/total_num],
                        ["KL divergence (rotation)", KLD_rot_loss/total_num],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # Show at most n_fig examples from the last mini-batch (it may hold fewer);
        # squeeze=False keeps ax 2-D even when only one column is drawn.
        n_show = min(n_fig, n_batch)
        fig, ax = plt.subplots(2, n_show, figsize=(5*n_show, 5*2), squeeze=False)
        for i in range(n_show):
            ax[0][i].imshow(x[i].data.cpu().numpy().astype(np.float32).reshape(img_side, img_side), cmap="inferno")
            ax[1][i].imshow(x_[i].data.cpu().numpy().astype(np.float32).reshape(img_side, img_side), cmap="inferno")
        fig.tight_layout()
        plt.show()
print("The training has been finished.")

In [None]:
# Build an n_sample x n_sample grid of 2-D latent codes for visualizing the
# learned manifold. The 1-D axis values are drawn without replacement from a
# fine grid on [-sigma, sigma], weighted by the standard-normal pdf, so they
# cluster near 0 (where the N(0, 1) prior used in training puts its mass).
assert num_comp == 2, "the latent grid below is built for exactly 2 latent components"
n_sample = 20
sigma = 5.0
z_test = np.linspace(-sigma, sigma, n_sample*10, endpoint=True)
rv = stats.norm(0, 1)
norm_pdf = rv.pdf(z_test)
norm_pdf = norm_pdf / np.sum(norm_pdf)
z_test = np.sort(np.random.choice(z_test, n_sample, replace=False, p=norm_pdf))
z_test = np.meshgrid(z_test, z_test)
z_test = np.stack((z_test[0].flatten(), z_test[1].flatten()), axis=1)
print(z_test.shape)
z_test = torch.from_numpy(z_test).to(torch.float32).to(cuda_device)

# One copy of the (n_coord, 2) pixel grid per latent code, giving shape
# (n_sample**2, n_coord, 2) like the training `coord`. expand() needs all three
# dims here: dim 0 of img_coord is not a singleton, so a 2-D target such as
# (n, n_coord*2) cannot be expanded to.
coord_test = img_coord.expand(n_sample**2, img_coord.size(0), img_coord.size(1))
coord_test = coord_test.to(cuda_device)
print(coord_test.shape)

In [None]:
# Decode every latent code in z_test on the fixed (untransformed) pixel grid,
# then reshape the result into n_sample**2 square images.
# no_grad(): this is pure inference, so building an autograd graph for all
# n_sample**2 images would only waste (GPU) memory.
dec_model.eval()
with torch.no_grad():
    generated = dec_model(coord_test.contiguous(), z_test)
print(generated.shape)
generated = generated.view(n_sample**2, data_load.w_size*2, data_load.w_size*2)
print(generated.shape)

In [None]:
# Tile the decoded images into an n_sample x n_sample grid of borderless panels;
# panel k (row-major) shows generated[k].
fig, ax = plt.subplots(n_sample, n_sample, figsize=(30, 30))
for idx, panel in enumerate(ax.flat):
    img = generated[idx].squeeze().data.cpu().numpy().astype(np.float32)
    panel.imshow(img, cmap="jet")
    panel.axis("off")
fig.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()