In [None]:
# J. Ryu, Electron Microscopy and Spectroscopy Lab., Seoul National University
# 20220413

import time
import glob
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import tkinter.filedialog as tkf
import tifffile
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
def zero_one_rescale(spectrum):
    """
    Linearly map one spectrum onto the [0, 1] interval.

    Negative values are clipped to zero first, the minimum is then subtracted,
    and the result is divided by its maximum. A constant spectrum (zero range)
    is returned as all zeros rather than being divided by zero.
    """
    shifted = spectrum.clip(min=0.0)
    shifted = shifted - np.min(shifted)
    peak = np.max(shifted)
    return shifted / peak if peak != 0 else shifted

def reshape_coeff(coeffs, new_shape):
    """
    Split a stacked coefficient matrix back into one map per scan.

    Parameters
    ----------
    coeffs : ndarray, shape (N, C)
        Rows are the scan positions of all scans concatenated in order.
    new_shape : array-like of shape (num_scans, >=2)
        ``new_shape[i][0] x new_shape[i][1]`` is the scan grid of scan ``i``.

    Returns
    -------
    list of ndarray
        One array of shape (new_shape[i][0], new_shape[i][1], C) per scan.
    """
    coeff_reshape = []
    start = 0
    for i in range(len(new_shape)):
        rows, cols = int(new_shape[i][0]), int(new_shape[i][1])
        n_pos = rows * cols
        # Walk a running offset instead of np.delete-ing consumed rows: np.delete
        # copies the whole remaining array on every iteration (quadratic cost).
        block = coeffs[start:start + n_pos]
        coeff_reshape.append(np.reshape(block, (rows, cols, -1)))
        start += n_pos

    return coeff_reshape

def indices_at_r(shape, radius, center=None):
    """
    Locate the pixels on a ring of given (rounded) radius, ordered by angle.

    Parameters
    ----------
    shape : tuple of int
        2D image shape (rows, cols).
    radius : int or float
        Pixels whose rounded distance from ``center`` equals this value are kept.
    center : sequence of 2 floats, optional
        (row, col) of the ring centre; defaults to the geometric centre of the image.

    Returns
    -------
    r_sort : tuple of ndarray
        (col_indices, row_indices) of the ring pixels, ordered by increasing angle.
        Note the (x, y) order, the reverse of ``np.where``.
    a_sort : ndarray
        The matching rounded angles in degrees (angle of (dx, dy) shifted by +180).
    """
    y, x = np.indices(shape)
    # `is None` rather than `not center`: truth-testing a NumPy-array center raises ValueError.
    if center is None:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])
    dy = y - center[0]
    dx = x - center[1]

    r = np.around(np.hypot(dy, dx))
    ri = np.where(r == radius)

    # Vectorised equivalent of np.angle(complex(dx, dy), deg=True) for every pixel,
    # shifted from (-180, 180] to (0, 360] and rounded to whole degrees.
    angle_arr = np.around(np.degrees(np.arctan2(dy, dx)) + 180)

    ring_angles = angle_arr[ri]
    ai = np.argsort(ring_angles)
    r_sort = (ri[1][ai], ri[0][ai])
    a_sort = np.sort(ring_angles)

    return r_sort, a_sort

def circle_flatten(f_stack, radial_range, c_pos):
    """
    Gather the detector pixels lying on a set of concentric rings.

    Parameters
    ----------
    f_stack : ndarray
        Stack whose last two axes are the 2D detector image (passed to
        ``indices_at_r`` as ``f_stack.shape[2:]``); the first two axes are kept.
    radial_range : sequence of 3 ints
        (start, stop, step) for ``range`` over the ring radii.
    c_pos : sequence of 2 floats or None
        (row, col) ring centre, forwarded to ``indices_at_r``.

    Returns
    -------
    ndarray of shape (f_stack.shape[0], f_stack.shape[1], n_ring_pixels)
        Ring pixels of every pattern, ordered by radius and then by angle.
    """
    start, stop, step = radial_range[0], radial_range[1], radial_range[2]
    cols, rows = [], []
    for radius in range(start, stop, step):
        (ring_x, ring_y), _ = indices_at_r(f_stack.shape[2:], radius, c_pos)
        cols += ring_x.tolist()
        rows += ring_y.tolist()

    return f_stack[:, :, np.asarray(rows), np.asarray(cols)]

def flattening(fdata, flat_option="box", crop_dist=None, c_pos=None, show=True):
    """
    Flatten the last two (detector) axes of a 4D stack into one feature axis.

    Parameters
    ----------
    fdata : ndarray
        4D stack; the first two axes are kept, the last two are flattened.
    flat_option : {"box", "radial"}
        "box": optionally crop a square around ``c_pos``, then flatten.
        "radial": keep only pixels on the rings given by ``crop_dist`` (see ``circle_flatten``).
    crop_dist : int or sequence of 3 ints, optional
        For "box", the half-width of the square crop (falsy means no crop).
        For "radial", the (start, stop, step) ring radii.
    c_pos : sequence of 2 floats, optional
        (row, col) centre of the crop or rings. Required for a "box" crop.
    show : bool, default True
        For a "box" crop, display the log of the mean cropped pattern.

    Returns
    -------
    ndarray of shape (fdata.shape[0], fdata.shape[1], n_features),
    or None (after printing a warning) for an unknown option or a missing radial range.
    """
    fdata_shape = fdata.shape

    if flat_option == "box":
        if not crop_dist:
            return fdata.reshape(fdata_shape[0], fdata_shape[1], -1)

        # The crop bounds depend only on c_pos and crop_dist. They used to be computed
        # inside a loop over the notebook-global `num_img`, which raised NameError outside
        # the notebook and left the bounds undefined when that loop was empty.
        h_si = int(np.floor(c_pos[0] - crop_dist))
        h_fi = int(np.ceil(c_pos[0] + crop_dist))
        w_si = int(np.floor(c_pos[1] - crop_dist))
        w_fi = int(np.ceil(c_pos[1] + crop_dist))
        tmp = fdata[:, :, h_si:h_fi, w_si:w_fi]

        if show:
            fig, ax = plt.subplots(1, 1, figsize=(5, 5))
            ax.imshow(np.log(np.mean(tmp, axis=(0, 1))), cmap="viridis")
            ax.axis("off")
            plt.show()

        return tmp.reshape(fdata_shape[0], fdata_shape[1], -1)

    elif flat_option == "radial":
        if crop_dist is None:
            # len(None) used to raise TypeError here; warn and bail out instead.
            print("Warning! 'crop_dist' must be a list containing 3 elements")
            return
        if len(crop_dist) != 3:
            print("Warning! 'crop_dist' must be a list containing 3 elements")
        return circle_flatten(fdata, crop_dist, c_pos)

    else:
        print("Warning! Wrong option ('flat_option')")
        return

def fourd_roll_axis(stack):
    """
    Bring axes 2 and 3 to the front: (a, b, c, d, ...) -> (c, d, a, b, ...).

    Presumably used to swap scan and detector axes of a 4D-STEM stack.
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

def radial_indices(shape, radial_range, center=None):
    """
    Build a 0/1 mask that selects pixels by their distance from a centre.

    Parameters
    ----------
    shape : tuple of int
        2D mask shape (rows, cols).
    radial_range : sequence of numbers
        If it holds more than one distinct value, pixels with
        ``radial_range[0] < r <= radial_range[1]`` are 1.
        Otherwise only pixels whose rounded distance equals ``round(radial_range[0])`` are 1.
    center : sequence of 2 floats, optional
        (row, col) of the centre; defaults to the geometric centre of the array.

    Returns
    -------
    ndarray of float with shape ``shape``
        1.0 inside the selected annulus/ring, 0.0 elsewhere.
    """
    y, x = np.indices(shape)
    # `is None` rather than `not center`: truth-testing a NumPy-array center raises ValueError.
    if center is None:
        center = np.array([(y.max() - y.min()) / 2.0, (x.max() - x.min()) / 2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        ri[(r <= radial_range[0]) | (r > radial_range[1])] = 0
    else:
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

In [None]:
# Ask for one or more 4D-STEM TIFF stacks and echo the chosen paths, one per line.
file_adr = tkf.askopenfilenames()
print("\n".join(file_adr))

In [None]:
# Number of selected stacks; every per-image loop below iterates over it.
num_img = len(file_adr)
if num_img == 0:
    # Fail here with a clear message instead of with confusing errors in later cells.
    raise RuntimeError("No files were selected; re-run the file-dialog cell and pick at least one TIFF stack.")
print(num_img)

In [None]:
# load 4D-STEM data
data_original = []
data_shape = []
for i in range(num_img):
    tmp = tifffile.imread(file_adr[i])
    print(tmp.shape)
    data_shape.append(list(tmp.shape[:2]))
    data_original.append(tmp)
    
data_shape = np.asarray(data_shape)

In [None]:
# find the center position
center_pos = []
cbox_edge = 15
center_removed_ = False
for i in range(num_img):
    mean_dp = np.mean(data_original[i], axis=(0, 1))
    cbox_outy = int(mean_dp.shape[0]/2 - cbox_edge/2)
    cbox_outx = int(mean_dp.shape[1]/2 - cbox_edge/2)
    center_box = mean_dp[cbox_outy:-cbox_outy, cbox_outx:-cbox_outx]
    Y, X = np.indices(center_box.shape)
    com_y = np.sum(center_box * Y) / np.sum(center_box)
    com_x = np.sum(center_box * X) / np.sum(center_box)
    c_pos = [np.around(com_y+cbox_outy), np.around(com_x+cbox_outx)]
    center_pos.append(c_pos)
print(center_pos)

In [None]:
# Silence log(0) warnings (a global NumPy setting that later plotting cells also rely on),
# then show each mean diffraction pattern on a log scale with its centre marked in red.
np.seterr(divide='ignore')
for stack, (cy, cx) in zip(data_original, center_pos):
    log_mean = np.log(stack.mean(axis=(0, 1)))
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(log_mean, cmap="viridis")
    ax.scatter(cx, cy, c="r", s=10)
    ax.axis("off")
    plt.show()

In [None]:
# get rid of the center beam (optional)
center_removed_ = True
center_radius = 45
data_cr = []
for i in range(num_img):
    ri = radial_indices(data_original[i].shape[2:], [center_radius, 100], center=center_pos[i])
    data_cr.append(np.multiply(data_original[i], ri))

In [None]:
# Visual check of the centre-beam removal: log of the masked mean pattern, centre in red.
if center_removed_:
    for masked, (cy, cx) in zip(data_cr, center_pos):
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(np.log(masked.mean(axis=(0, 1))), cmap="viridis")
        ax.scatter(cx, cy, c="r", s=10)
        ax.axis("off")
        plt.show()

In [None]:
# Crop a (2 * side_length)-pixel square around each centre. Use the masked data
# if the centre beam was removed, otherwise the raw stacks.
side_length = 50
box_size = np.array([side_length, side_length])
source = data_cr if center_removed_ else data_original
dataset = []
for k in range(num_img):
    cy, cx = center_pos[k]
    top = np.floor(cy - box_size[0]).astype(int)
    bottom = np.ceil(cy + box_size[0]).astype(int)
    left = np.floor(cx - box_size[1]).astype(int)
    right = np.ceil(cx + box_size[1]).astype(int)
    dataset.append(source[k][:, :, top:bottom, left:right])

In [None]:
# Collect every cropped diffraction pattern of every image into one
# (N, w_size, w_size) array with negative values clipped to 0.
w_size = side_length * 2
print(w_size)

patterns = []
for cropped in dataset[:num_img]:
    print(cropped.shape)
    patterns += list(cropped.reshape(-1, w_size, w_size))
dataset_flat = np.asarray(patterns).clip(min=0.0)
del patterns

print(dataset_flat.shape)
print(np.min(dataset_flat), np.max(dataset_flat))
print(np.mean(dataset_flat))

In [None]:
# convert values into log scale (optional)
dataset_flat[np.where(dataset_flat==0.0)] = 1.0
dataset_flat = np.log(dataset_flat)
dataset_flat = dataset_flat.clip(min=0.0)

In [None]:
# max-normalize each flattened diffraction pattern (optional)
dataset_flat = dataset_flat / np.max(dataset_flat, axis=(1, 2))[:, np.newaxis, np.newaxis]
print(np.min(dataset_flat), np.max(dataset_flat))
print(np.mean(dataset_flat))
print(dataset_flat.shape)

In [None]:
# Quick look at one (arbitrarily chosen) preprocessed pattern.
preview_idx = 5
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
ax.imshow(dataset_flat[preview_idx], cmap="viridis")
ax.axis("off")
plt.show()

In [None]:
# Shuffle all patterns once. `ri` is kept so that latent codes can later be put
# back into scan order with `coeffs[ri] = latent`.
total_num = len(dataset_flat)
ri = np.random.permutation(total_num)
dataset_input = dataset_flat[ri]

In [None]:
class simple_VAE(nn.Module):
    """
    Convolutional variational autoencoder for single-channel 2D patterns.

    Encoder: one Conv2d -> BatchNorm2d -> Tanh block per entry of ``channels``
    (followed by AvgPool2d when ``pooling[i] != 1``), then Flatten and a Linear layer
    producing ``2 * z_dim`` values: the first half is mu, the second half log-variance.

    Decoder: Linear back to (channels[-1], final_length, final_length), then
    ConvTranspose2d -> BatchNorm2d -> Tanh blocks down to one channel, and a final
    unpadded Conv2d(kernel ``f_kernel``) followed by Sigmoid, so outputs lie in [0, 1].

    Parameters
    ----------
    final_length : int
        Side length of the encoder's last feature map; must match what the
        conv/pool stack actually produces for the input size.
    channels, kernels, strides, paddings, pooling : sequences
        One entry per encoder block.
    z_dim : int
        Latent dimensionality.
    dec_kernels, dec_strides, dec_paddings, dec_outpads : sequences
        Per-decoder-layer ConvTranspose2d settings. Index i (i >= 1) builds the layer
        mapping channels[i] -> channels[i-1]; index 0 builds the last layer, which maps
        channels[0] -> 1. Layers are applied from index len(channels)-1 down to 0.
    dec_poolings : any
        Unused; kept for signature compatibility.
    f_kernel : int
        Kernel size of the final Conv2d.
    """
    def __init__(self, final_length, channels, kernels, strides, paddings, z_dim, pooling,
                 dec_kernels, dec_strides, dec_paddings, dec_outpads, dec_poolings, f_kernel):
        super(simple_VAE, self).__init__()

        self.z_dim = z_dim
        self.final_length = final_length
        self.channels = channels

        enc_net = []
        in_ch = 1
        for i, out_ch in enumerate(channels):
            enc_net.append(nn.Conv2d(in_ch, out_ch, kernels[i], stride=strides[i],
                                     padding=paddings[i], bias=True))
            enc_net.append(nn.BatchNorm2d(out_ch))
            enc_net.append(nn.Tanh())
            if pooling[i] != 1:
                enc_net.append(nn.AvgPool2d(pooling[i]))
            in_ch = out_ch

        enc_net.append(nn.Flatten())
        enc_net.append(nn.Linear(self.final_length**2*channels[-1], 2*self.z_dim))

        self.encoder = nn.Sequential(*enc_net)

        self.init_decoder = nn.Linear(self.z_dim, self.final_length**2*channels[-1])

        # output_padding comes from the `dec_outpads` argument. The previous code read
        # a notebook-global `dec_outpad`, so the parameter was silently ignored and the
        # class failed with NameError wherever that global did not exist.
        dec_net = []
        for i in range(len(channels)-1, 0, -1):
            dec_net.append(nn.ConvTranspose2d(channels[i], channels[i-1], dec_kernels[i], dec_strides[i],
                                              dec_paddings[i], output_padding=dec_outpads[i], bias=True))
            dec_net.append(nn.BatchNorm2d(channels[i-1]))
            dec_net.append(nn.Tanh())

        dec_net.append(nn.ConvTranspose2d(channels[0], 1, dec_kernels[0], dec_strides[0],
                                          dec_paddings[0], output_padding=dec_outpads[0], bias=True))
        dec_net.append(nn.BatchNorm2d(1))
        dec_net.append(nn.Tanh())

        dec_net.append(nn.Conv2d(1, 1, f_kernel, bias=True))
        dec_net.append(nn.Sigmoid())

        self.decoder = nn.Sequential(*dec_net)

    def encode(self, x):
        """Map x of shape (batch, 1, H, W) to (mu, logvar), each of shape (batch, z_dim)."""
        latent = self.encoder(x)
        mu = latent[:, :self.z_dim]
        logvar = latent[:, self.z_dim:]

        return mu, logvar

    def reparametrization(self, mu, logvar):
        """Sample z = mu + exp(0.5 * logvar) * eps with eps ~ N(0, I) (reparameterisation trick)."""
        return mu+torch.exp(0.5*logvar)*torch.randn_like(logvar)

    def decode(self, z):
        """Map latent codes of shape (batch, z_dim) to patterns of shape (batch, 1, H', W') in [0, 1]."""
        init_decoded = self.init_decoder(z)
        init_decoded = init_decoded.view(-1, self.channels[-1], self.final_length, self.final_length)

        return self.decoder(init_decoded.contiguous())

    def forward(self, x):
        """
        Encode, sample and decode a batch of single-channel images.

        x has shape (batch, H, W); a channel axis is added here.
        Returns (mu, logvar, z, reconstruction).
        """
        x = x.unsqueeze(1)
        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)

        return mu, logvar, z, self.decode(z)

In [None]:
# Use the first CUDA device if there is one (printing its name and memory summary);
# otherwise cuda_device stays None.
cuda_device = None
if torch.cuda.is_available():
    print("%d gpu available" % torch.cuda.device_count())
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))

In [None]:
# Encoder hyper-parameters, one entry per Conv2d block.
# num_comp is the latent dimensionality (passed to the VAE as z_dim).
num_comp = 2
channels = [32, 32, 32, 32]
kernels = [5, 5, 5, 5]
padding = [1, 1, 1, 1]
stride = [1, 1, 1, 1]
pooling = [2, 2, 2, 2]

# Feature-map side length after each conv + pool block:
#   conv: (n - kernel + 2*padding) / stride + 1, then pool: / pooling.
# The running size stays a float and is only truncated when recorded.
dat_dim = []
tmp_dim = w_size
for layer in range(len(kernels)):
    tmp_dim = tmp_dim + (-kernels[layer] + 2 * padding[layer])
    tmp_dim = (tmp_dim / stride[layer] + 1) / pooling[layer]
    dat_dim.append(int(tmp_dim))

for setting in (dat_dim, kernels, channels, padding, stride, pooling):
    print(setting)

In [None]:
# Decoder ConvTranspose2d settings. The size check walks them outward from the
# Linear-layer output (side dat_dim[-1]):
#   out = (in - 1) * stride + kernel - 2 * padding + output_padding
dec_kernel = [3, 3, 3, 3]
dec_stride = [2, 2, 2, 2]
dec_padding = [1, 1, 1, 1]
dec_outpad = [1, 1, 1, 1]

dec_dim = []
side = dat_dim[-1]
for layer in range(len(dec_kernel)):
    side = ((side - 1) * dec_stride[layer] + dec_kernel[layer]
            - 2 * dec_padding[layer] + dec_outpad[layer])
    dec_dim.append(side)

print(dec_dim)

# simple_VAE applies its decoder layers from the last list index down to 0, so
# reverse the lists to make entry 0 above the first layer applied.
for setting in (dec_kernel, dec_stride, dec_padding, dec_outpad):
    setting.reverse()

final_kernel = 5
# side length after the final unpadded Conv2d(final_kernel)
print(dec_dim[-1] - final_kernel + 1)

In [None]:
# Build the VAE. Arguments are passed by keyword: the earlier positional call skipped
# `pooling` and `dec_poolings`, so every decoder setting landed in the wrong parameter
# and the call failed with a TypeError for the missing arguments.
model = simple_VAE(
    final_length=dat_dim[-1],
    channels=channels,
    kernels=kernels,
    strides=stride,
    paddings=padding,
    z_dim=num_comp,
    pooling=pooling,
    dec_kernels=dec_kernel,
    dec_strides=dec_stride,
    dec_paddings=dec_padding,
    dec_outpads=dec_outpad,
    dec_poolings=None,  # accepted but not used by simple_VAE
    f_kernel=final_kernel,
)

parallel_ = True

if parallel_:
    model = nn.DataParallel(model)

# Move to the GPU only when one was found; `.cuda(None)` fails on CPU-only machines.
if cuda_device is not None:
    model.cuda(cuda_device)
print(model)

In [None]:
# Consecutive mini-batches of the shuffled patterns; the last one may be shorter.
batch_size = 256
n_batches = -(-len(dataset_input) // batch_size)  # ceiling division
mini_batches = [dataset_input[b * batch_size:(b + 1) * batch_size] for b in range(n_batches)]
print(len(mini_batches))
print(len(mini_batches[-1]))

In [None]:
# Adam over all model parameters.
n_epoch = 100
l_rate = 0.001
optimizer = optim.Adam(params=model.parameters(), lr=l_rate)
print(optimizer)

In [None]:
def _show_latent_maps(values):
    """
    Histogram each latent component and show it as a map on every scan grid.

    `values` has one row per pattern in shuffled (mini-batch) order; `ri` undoes
    the shuffle before the per-scan reshape. Returns (coeffs, coeffs_reshape).
    """
    scan_ordered = np.zeros_like(values)
    scan_ordered[ri] = values
    maps = reshape_coeff(scan_ordered, data_shape)

    fig, ax = plt.subplots(1, 1)
    ax.grid()
    for c in range(num_comp):
        ax.hist(scan_ordered[:, c], bins=50, alpha=(1.0 - c * (1 / num_comp)))
    plt.show()

    for c in range(num_comp):
        fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
        for scan_map, a in zip(maps, axes[0]):
            a.imshow(scan_map[:, :, c], cmap="viridis")
            a.axis("off")
        plt.show()

    return scan_ordered, maps


start = time.time()
loss_plot = []
n_fig = 5
# int(n_epoch/10) was 0 for n_epoch < 10, which made the modulo below a ZeroDivisionError.
report_every = max(1, n_epoch // 10)

# The generation cell further down switches the model to eval(); re-running training must undo that.
model.train()
for epoch in range(n_epoch):
    loss_epoch = 0.0
    recon_loss = 0.0
    KLD_loss = 0.0

    latent_z = []
    z_mu = []
    z_logvar = []
    for m_batch in mini_batches:
        # Out-of-place clamp: torch.from_numpy shares memory with m_batch, so the
        # previous in-place clamp_ silently rewrote dataset_input itself.
        x = torch.from_numpy(m_batch).clamp(min=0.001, max=0.999).to(torch.float32)
        if cuda_device is not None:
            x = x.to(cuda_device)

        mu, logvar, z, x_ = model(x)

        # squeeze(1) drops only the channel axis. A bare squeeze() also dropped the
        # batch axis when the last mini-batch held a single pattern.
        reconstruction_error = F.binary_cross_entropy(x_.squeeze(1), x, reduction="sum")
        # closed-form KL(N(mu, exp(logvar)) || N(0, I)), summed over batch and latent dims
        KL_divergence = -0.5*torch.sum(1+logvar-mu**2-logvar.exp())

        loss = reconstruction_error + KL_divergence
        loss_epoch += loss.item()
        recon_loss += reconstruction_error.item()
        KLD_loss += KL_divergence.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        latent_z.extend(z.detach().cpu().numpy().tolist())
        z_mu.extend(mu.detach().cpu().numpy().tolist())
        z_logvar.extend(logvar.detach().cpu().numpy().tolist())

    loss_plot.append(loss_epoch/total_num)

    latent_z = np.asarray(latent_z)
    z_mu = np.asarray(z_mu)
    z_logvar = np.asarray(z_logvar)

    if epoch == 0 and cuda_device is not None:
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        fig, ax = plt.subplots(1, 1)
        ax.plot(np.arange(epoch+1)+1, loss_plot, "k-")
        ax.grid()
        plt.show()

        print(tabulate([
                        ["epoch", epoch+1],
                        ["total loss", loss_epoch/total_num],
                        ["reconstruction error", recon_loss/total_num],
                        ["KL divergence", KLD_loss/total_num],
                        # epoch totals, consistent with the rows above
                        # (previously a ratio of last-mini-batch tensors)
                        ["error ratio", recon_loss/KLD_loss if KLD_loss else float("inf")],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # Last mini-batch: inputs (top row) vs reconstructions (bottom row).
        # n_show is capped by the batch length because the last batch can be short.
        n_show = min(n_fig, len(x))
        fig, ax = plt.subplots(2, n_show, figsize=(5*n_show, 5*2), squeeze=False)
        for k in range(n_show):
            ax[0][k].imshow(x[k].detach().cpu(), cmap="inferno")
            ax[1][k].imshow(x_[k].squeeze().detach().cpu(), cmap="inferno")
        fig.tight_layout()
        plt.show()

        coeffs, coeffs_reshape = _show_latent_maps(latent_z)

print("The training has been finished.")

In [None]:
# Latent samples z from the final epoch, restored to scan order. Plots a histogram
# per component, then one map per component for every scan.
coeffs = np.zeros_like(latent_z)
coeffs[ri] = latent_z.copy()
coeffs_reshape = reshape_coeff(coeffs, data_shape)

fig, ax = plt.subplots(1, 1)
ax.grid()
for comp in range(num_comp):
    ax.hist(coeffs[:, comp], bins=50, alpha=(1.0 - comp * (1 / num_comp)))
plt.show()

for comp in range(num_comp):
    fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
    for scan_map, a in zip(coeffs_reshape, axes[0]):
        a.imshow(scan_map[:, :, comp], cmap="viridis")
        a.axis("off")
    plt.show()

In [None]:
# Posterior means mu from the final epoch, restored to scan order. Plots a histogram
# per component, then one map per component for every scan.
coeffs = np.zeros_like(z_mu)
coeffs[ri] = z_mu.copy()
coeffs_reshape = reshape_coeff(coeffs, data_shape)

fig, ax = plt.subplots(1, 1)
ax.grid()
for comp in range(num_comp):
    ax.hist(coeffs[:, comp], bins=50, alpha=(1.0 - comp * (1 / num_comp)))
plt.show()

for comp in range(num_comp):
    fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
    for scan_map, a in zip(coeffs_reshape, axes[0]):
        a.imshow(scan_map[:, :, comp], cmap="viridis")
        a.axis("off")
    plt.show()

In [None]:
# Posterior standard deviations exp(0.5 * logvar) from the final epoch, restored to
# scan order. Plots a histogram per component, then one map per component for every scan.
coeffs = np.zeros_like(z_logvar)
coeffs[ri] = np.exp(0.5 * z_logvar)
coeffs_reshape = reshape_coeff(coeffs, data_shape)

fig, ax = plt.subplots(1, 1)
ax.grid()
for comp in range(num_comp):
    ax.hist(coeffs[:, comp], bins=50, alpha=(1.0 - comp * (1 / num_comp)))
plt.show()

for comp in range(num_comp):
    fig, axes = plt.subplots(1, num_img, figsize=(7 * num_img, 7), squeeze=False)
    for scan_map, a in zip(coeffs_reshape, axes[0]):
        a.imshow(scan_map[:, :, comp], cmap="viridis")
        a.axis("off")
    plt.show()

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Switch matplotlib to the Qt backend (figures open in separate interactive windows).
%matplotlib qt

In [None]:
# Build an n_sample x n_sample grid of 2D latent points to decode. Grid coordinates are
# drawn without replacement from an (n_sample * 10)-point lattice on [-sigma, sigma],
# weighted by the standard-normal density, then sorted. z_test has 2 columns, so this
# assumes the model's latent size (num_comp) is 2.
n_sample = 10
sigma = 5.0
lattice = np.linspace(-sigma, sigma, n_sample * 10, endpoint=True)
weights = stats.norm(0, 1).pdf(lattice)
weights = weights / np.sum(weights)
picked = np.sort(np.random.choice(lattice, n_sample, replace=False, p=weights))
grid_a, grid_b = np.meshgrid(picked, picked)
z_test = np.column_stack((grid_a.ravel(), grid_b.ravel()))
print(z_test.shape)
z_test = torch.from_numpy(z_test).to(torch.float32).to(cuda_device)

In [None]:
# Decode the latent grid with the trained decoder (inference only, so no autograd graph).
model.eval()
# nn.DataParallel keeps the wrapped network in `.module`; with parallel_ = False
# there is no wrapper, and `model.module` raised AttributeError.
vae = model.module if isinstance(model, nn.DataParallel) else model
with torch.no_grad():
    generated = vae.decode(z_test)
print(generated.shape)

In [None]:
# Mosaic of the decoded latent grid: one panel per decoded pattern, in the row-major order of ax.flat.
fig, ax = plt.subplots(n_sample, n_sample, figsize=(30, 30))
for pattern, a in zip(generated, ax.flat):
    a.imshow(pattern.squeeze().data.cpu(), cmap="jet")
    a.axis("off")
plt.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()

In [None]:
# 2D subspace
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    ax.set_xlabel("loading vector %d"%(c1+1), fontsize=15)
    ax.set_ylabel("loading vector %d"%(c2+1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.canvas.draw()
    fig.tight_layout()

x_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=0)
y_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=1)

pyw.interact(projection, c1=x_widget, c2=y_widget)
plt.show()