In [None]:
import time
import numpy as np
import scipy.stats as stats
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import tifffile
import ipywidgets as pyw
import tkinter.filedialog as tkf
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
# create a customized colorbar: a discrete colormap with one named colour per
# integer label; the boundaries -1, 0, ..., n-1 define n bins
color_rep = [
    "black", "red", "green", "blue", "orange", "purple", "yellow", "lime",
    "cyan", "magenta", "lightgray", "peru", "springgreen", "deepskyblue",
    "hotpink", "darkgray",
]
print(len(color_rep))
custom_cmap = mcolors.ListedColormap(color_rep)
bounds = np.arange(-1, len(color_rep))
norm = mcolors.BoundaryNorm(ncolors=len(color_rep), boundaries=bounds)
sm = cm.ScalarMappable(norm=norm, cmap=custom_cmap)
sm.set_array([])

# names of sequential matplotlib colormaps (presumably one per component in later plots)
cm_rep = ["gray", "Reds", "Greens", "Blues", "Oranges", "Purples"]
print(len(cm_rep))

def zero_one_rescale(spectrum):
    """
    Min-max rescale one spectrum into the range [0.0, 1.0].

    Negative values are clipped to 0 first, the clipped minimum is subtracted,
    and the result is divided by its maximum. A constant input (zero range)
    is returned as all zeros rather than dividing by zero.
    """
    shifted = spectrum.clip(min=0.0)
    shifted = shifted - np.min(shifted)
    peak = np.max(shifted)
    if peak == 0:
        return shifted
    return shifted / peak

def indices_at_r(shape, radius, center=None):
    """
    Pixel indices lying on one ring, ordered by angle.

    A pixel belongs to the ring when its distance to `center`, rounded to the
    nearest integer, equals `radius`.

    Parameters
    ----------
    shape : tuple of int
        (height, width) of the 2D grid.
    radius : int
        Ring radius in pixels.
    center : sequence of 2 numbers, optional
        (row, col) of the ring centre. Defaults to ((height-1)/2, (width-1)/2).
        May be a list, tuple or numpy array.

    Returns
    -------
    r_sort : tuple (ndarray, ndarray)
        (column indices, row indices) of the ring pixels sorted by angle.
        Note the (x, y) order, opposite to np.where.
    a_sort : ndarray
        The matching angles in degrees: angle of (dx, dy) shifted by +180 and
        rounded to integers, in ascending order.
    """
    y, x = np.indices(shape)
    # `is None` instead of `not center`: truth-testing a numpy array center raises
    if center is None:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    dy = y - center[0]
    dx = x - center[1]
    r = np.around(np.hypot(dy, dx))
    ri = np.where(r == radius)

    # vectorised form of np.angle(complex(dx, dy), deg=True) for every pixel
    angle_arr = np.arctan2(dy, dx) * (180 / np.pi)
    angle_arr = np.around(angle_arr + 180)

    ring_angles = angle_arr[ri]
    ai = np.argsort(ring_angles)
    r_sort = (ri[1][ai], ri[0][ai])
    a_sort = np.sort(ring_angles)

    return r_sort, a_sort

def circle_flatten(f_stack, radial_range, c_pos):
    """
    Flatten every 2D pattern of a 4D stack into the pixels of concentric rings.

    Rings of radius radial_range[0], radial_range[0] + radial_range[2], ...
    (stop exclusive at radial_range[1]) are collected around `c_pos`, each ring
    ordered by angle via `indices_at_r`, and concatenated.

    Returns f_stack[:, :, rows, cols], i.e. the first two axes are kept and the
    last axis enumerates the ring pixels.
    """
    cols = []
    rows = []
    for radius in range(radial_range[0], radial_range[1], radial_range[2]):
        (ring_cols, ring_rows), _ = indices_at_r(f_stack.shape[2:], radius, c_pos)
        cols += ring_cols.tolist()
        rows += ring_rows.tolist()

    return f_stack[:, :, np.asarray(rows), np.asarray(cols)]

def flattening(fdata, flat_option="box", crop_dist=None, c_pos=None):
    """
    Flatten the 2D pattern at every scan position of a 4D stack into a 1D vector.

    Parameters
    ----------
    fdata : ndarray
        4D stack; the first two axes are kept and the last two (the pattern)
        are flattened.
    flat_option : {"box", "radial"}
        "box": crop a square around `c_pos` (or keep the whole pattern) and
        flatten it row by row. "radial": concatenate ring pixels via
        `circle_flatten`.
    crop_dist : int or sequence of 3 ints, optional
        "box": half width of the square crop; if falsy the whole pattern is
        flattened. "radial": [start, stop, step] of the ring radii.
    c_pos : sequence of 2 numbers, optional
        (row, col) of the pattern centre; required whenever a crop is taken.

    Returns
    -------
    For "box", an ndarray of shape (fdata.shape[0], fdata.shape[1], n_pixels).
    For "radial", whatever `circle_flatten` returns.
    None for an unknown `flat_option`.

    Raises
    ------
    ValueError
        For "box" with `crop_dist` but no `c_pos`, or a crop window that does
        not fit inside the pattern.
    """
    fdata_shape = fdata.shape
    if flat_option == "box":
        if not crop_dist:
            return fdata.reshape(fdata_shape[0], fdata_shape[1], -1)
        if c_pos is None:
            raise ValueError("'c_pos' is required when 'crop_dist' is given")

        # window [c - d, c + d) along each detector axis; the bounds depend only on
        # c_pos and crop_dist (no dependence on any global image count)
        h_si = int(np.floor(c_pos[0] - crop_dist))
        h_fi = int(np.ceil(c_pos[0] + crop_dist))
        w_si = int(np.floor(c_pos[1] - crop_dist))
        w_fi = int(np.ceil(c_pos[1] + crop_dist))
        if h_si < 0 or w_si < 0 or h_fi > fdata_shape[2] or w_fi > fdata_shape[3]:
            # a negative start would wrap around and an overlong end would silently
            # shrink the box, producing vectors of the wrong length
            raise ValueError("crop box around 'c_pos' exceeds the pattern bounds")

        tmp = fdata[:, :, h_si:h_fi, w_si:w_fi]

        # preview of the cropped region: log of the mean pattern
        fig, ax = plt.subplots(1, 1, figsize=(5, 5))
        ax.imshow(np.log(np.mean(tmp, axis=(0, 1))), cmap="viridis")
        ax.axis("off")
        plt.show()

        return tmp.reshape(fdata_shape[0], fdata_shape[1], -1)

    elif flat_option == "radial":
        if crop_dist is None or len(crop_dist) != 3:
            print("Warning! 'crop_dist' must be a list containing 3 elements")
        return circle_flatten(fdata, crop_dist, c_pos)

    else:
        print("Warning! Wrong option ('flat_option')")
        return None

def fourd_roll_axis(stack):
    """
    Move axes 2 and 3 of `stack` to the front.

    (a, b, c, d, ...) -> (c, d, a, b, ...); any trailing axes keep their order.
    Returns a view, not a copy.
    """
    return np.moveaxis(stack, (2, 3), (0, 1))

def reshape_coeff(coeffs, new_shape):
    """
    Reshape a coefficient matrix to restore the original scanning shapes.

    Parameters
    ----------
    coeffs : ndarray
        Rows of all datasets concatenated in order; row count must equal
        sum(h_i * w_i).
    new_shape : ndarray, shape (n_datasets, 2)
        Scan shape (h_i, w_i) of each dataset.

    Returns
    -------
    list of ndarray
        One array of shape (h_i, w_i, -1) per dataset. The arrays are views of
        `coeffs` (no copy).

    Raises
    ------
    ValueError
        If `coeffs` has fewer rows than the scan shapes require.
    """
    coeff_reshape = []
    offset = 0
    # walk through coeffs with a running offset instead of repeatedly np.delete-ing
    # the consumed rows (each delete copied the whole remainder: O(n^2))
    for i in range(len(new_shape)):
        h, w = int(new_shape[i, 0]), int(new_shape[i, 1])
        block = coeffs[offset:offset + h*w]
        if block.shape[0] != h*w:
            raise ValueError("coeffs has too few rows for scan shape %d of %d"
                             % (i, len(new_shape)))
        coeff_reshape.append(np.reshape(block, (h, w, -1)))
        offset += h*w

    return coeff_reshape

def radial_indices(shape, radial_range, center=None):
    """
    Binary (0.0 / 1.0) mask selecting pixels by their distance from `center`.

    If `radial_range` contains at least two distinct values, pixels with
    radial_range[0] < r <= radial_range[1] are 1. Otherwise (e.g. [r0, r0])
    only pixels whose rounded distance equals round(radial_range[0]) are 1.

    Parameters
    ----------
    shape : tuple of int
        (height, width) of the mask.
    radial_range : sequence of numbers
        Inner/outer radius, or a repeated single radius.
    center : sequence of 2 numbers, optional
        (row, col) centre; list, tuple or numpy array. Defaults to
        ((height-1)/2, (width-1)/2).

    Returns
    -------
    ndarray of float with the given shape.
    """
    y, x = np.indices(shape)
    # `is None` instead of `not center`: truth-testing a numpy array center raises
    if center is None:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        ri[r <= radial_range[0]] = 0
        ri[r > radial_range[1]] = 0
    else:
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

In [None]:
file_adr = list()  # paths of the TIFF stacks to analyse; filled by the file-dialog cell

In [None]:
# Ask for one or more files and append them (re-run the cell to add more).
chosen_files = tkf.askopenfilenames()
file_adr.extend(chosen_files)
print(len(file_adr))
print("\n".join(file_adr))

In [None]:
# number of selected datasets; later cells loop over range(num_img)
num_img = len(file_adr)
print(num_img)

In [None]:
# load 4D-STEM data: each TIFF is read fully into memory; its first two axes
# are recorded as the scan shape used later to rebuild the latent maps
data_original = []
data_shape = []
for idx in range(num_img):
    stack = tifffile.imread(file_adr[idx])
    print(stack.shape)
    data_shape.append([*stack.shape[:2]])
    data_original.append(stack)

data_shape = np.asarray(data_shape)

In [None]:
# find the center position of each dataset: centre of mass of the mean diffraction
# pattern inside a cbox_edge x cbox_edge window around the detector midpoint
center_pos = []
cbox_edge = 100
center_removed_ = False
for i in range(num_img):
    mean_dp = np.mean(data_original[i], axis=(0, 1))
    # window start, clamped at 0 for detectors smaller than cbox_edge; an explicit end
    # index avoids the empty `a:-0` slice that `[a:-a]` produced when a == 0
    cbox_outy = max(int(mean_dp.shape[0]/2 - cbox_edge/2), 0)
    cbox_outx = max(int(mean_dp.shape[1]/2 - cbox_edge/2), 0)
    center_box = mean_dp[cbox_outy:cbox_outy+cbox_edge, cbox_outx:cbox_outx+cbox_edge]
    Y, X = np.indices(center_box.shape)
    box_sum = np.sum(center_box)
    com_y = np.sum(center_box * Y) / box_sum
    com_x = np.sum(center_box * X) / box_sum
    # back to full-detector coordinates, rounded to whole pixels
    c_pos = [np.around(com_y+cbox_outy), np.around(com_x+cbox_outx)]
    center_pos.append(c_pos)
print(center_pos)

In [None]:
# Show each mean diffraction pattern on a log scale with its detected centre in red.
# log(0) on empty pixels would warn, so divide-by-zero warnings are silenced globally.
np.seterr(divide='ignore')
for i in range(num_img):
    log_mean = np.log(np.mean(data_original[i], axis=(0, 1)))
    cy, cx = center_pos[i][0], center_pos[i][1]
    fig, ax = plt.subplots(1, 1, figsize=(5, 5))
    ax.imshow(log_mean, cmap="viridis")
    ax.scatter(cx, cy, c="r", s=10)
    ax.axis("off")
    plt.show()

In [None]:
# get rid of the center beam (optional)
# Every pattern is multiplied by a binary annulus mask: pixels with r <= center_radius
# (the direct beam) and pixels with r > outer_radius are set to 0.
center_removed_ = True
center_radius = 10
outer_radius = 100  # NOTE(review): pixels beyond this radius are zeroed as well
data_cr = []
for i in range(num_img):
    # named `beam_mask` (not `ri`) so re-running this cell cannot overwrite the
    # global shuffle permutation `ri` that un-shuffles the latent maps later
    beam_mask = radial_indices(data_original[i].shape[2:], [center_radius, outer_radius],
                               center=center_pos[i])
    data_cr.append(np.multiply(data_original[i], beam_mask))

In [None]:
# Visual check of the beam-removed patterns (only when the optional cell above ran).
if center_removed_:
    for idx in range(num_img):
        log_mean_cr = np.log(np.mean(data_cr[idx], axis=(0, 1)))
        cy, cx = center_pos[idx][0], center_pos[idx][1]
        fig, ax = plt.subplots(1, 1, figsize=(10, 10))
        ax.imshow(log_mean_cr, cmap="viridis")
        ax.scatter(cx, cy, c="r", s=10)
        ax.axis("off")
        plt.show()

In [None]:
# 2D diffraction pattern -> 1D data
# option 1 : flatten a (2*w_size) x (2*w_size) box around each pattern centre
radial_flat_ = False

w_size = 30
src_data = data_cr if center_removed_ else data_original
dataset = [
    flattening(src_data[i], flat_option="box", crop_dist=w_size, c_pos=center_pos[i])
    for i in range(num_img)
]

s_length = (w_size*2)**2

In [None]:
# 2D diffraction pattern -> 1D data
# option 2 : flatten radially -- concatenate the pixels of the rings
# radial_range[0] <= r < radial_range[1] (step radial_range[2]), each ring ordered by angle
radial_flat_ = True

radial_range = [44, 64, 1]
# ring pixel positions on a (2R x 2R) canvas centred at (R, R), R = radial_range[1];
# they are reused later to paint 1D vectors back into 2D images
canvas_shape = (radial_range[1]*2, radial_range[1]*2)
canvas_center = (radial_range[1], radial_range[1])
k_indx, k_indy, a_ind = [], [], []
for radius in range(radial_range[0], radial_range[1], radial_range[2]):
    (ring_x, ring_y), ring_angle = indices_at_r(canvas_shape, radius, canvas_center)
    k_indx += ring_x.tolist()
    k_indy += ring_y.tolist()
    a_ind += ring_angle.tolist()

s_length = len(k_indx)

k_indx, k_indy, a_ind = np.asarray(k_indx), np.asarray(k_indy), np.asarray(a_ind)
print(k_indx.shape, k_indy.shape, a_ind.shape)

src_data = data_cr if center_removed_ else data_original
dataset = [circle_flatten(src_data[i], radial_range, center_pos[i]) for i in range(num_img)]

In [None]:
# create the input dataset: one row per scan position, all files stacked in order,
# negative values clipped to 0
rows = []
for i in range(num_img):
    per_file = dataset[i]
    print(per_file.shape)
    rows.extend(per_file.reshape(-1, s_length))
dataset_flat = np.asarray(rows).clip(min=0.0)
del rows
print(dataset_flat.min(), dataset_flat.max())
print(dataset_flat.shape)

In [None]:
# convert values into log scale (optional); zero counts become 1 first, i.e. log -> 0
# NOTE: running this cell twice takes the log of the log
zero_mask = dataset_flat == 0.0
dataset_flat[zero_mask] = 1.0
dataset_flat = np.log(dataset_flat)
print(dataset_flat.min(), dataset_flat.max())

In [None]:
# max-normalize each flattened diffraction pattern (optional)
row_max = np.max(dataset_flat, axis=1, keepdims=True)
# a row whose maximum is <= 0 would give 0/0 = NaN (poisoning the BCE loss) or flip
# signs; such rows are left unscaled and are then clipped to 0 below
row_max[row_max <= 0] = 1.0
dataset_flat = dataset_flat / row_max
dataset_flat = dataset_flat.clip(min=0.0)
print(dataset_flat.min(), dataset_flat.max())
print(dataset_flat.shape)

In [None]:
# Shuffle the patterns once with a fixed seed so runs are reproducible.
# `ri` is kept: dataset_input[k] == dataset_flat[ri[k]], which the latent-map cells
# use to put the results back into scan order.
shuffle_seed = 0
total_num = len(dataset_flat)
ri = np.random.default_rng(shuffle_seed).permutation(total_num)

dataset_input = dataset_flat[ri]

In [None]:
class invar1DVAE_encoder(nn.Module):
    """
    MLP encoder for a rotation- and/or translation-invariant VAE on 1D patterns.

    The last layer outputs 2*z_dim values (means, then log-variances) where
    z_dim = img_z_dim + 1 (rotation, if rot_check) + 2 (translation, if trans_check).
    Latent column order: [rotation angle] [x, y translation] [content latents].

    Parameters
    ----------
    in_dim : int
        Length of each flattened input pattern.
    h_dim : list of int
        Hidden layer widths, each followed by LeakyReLU(0.1).
    z_dim : int
        Number of content (image) latent variables.
    rot_check, trans_check : bool
        Whether a rotation angle / a 2D translation latent is modelled.
    trans_std : float
        Scale applied to the sampled translation latent before shifting coordinates.

    Raises
    ------
    ValueError
        If neither rot_check nor trans_check is enabled.
    """
    def __init__(self,  in_dim, h_dim, z_dim, rot_check=True, trans_check=True, trans_std=0.1):
        super(invar1DVAE_encoder, self).__init__()

        self.img_z_dim = z_dim

        self.rot_check = rot_check
        self.trans_check = trans_check
        self.trans_std = trans_std

        if not self.rot_check and not self.trans_check:
            # fail loudly instead of returning a module without an encoder
            raise ValueError("at least one invariant property (rot_check or trans_check) must be chosen")

        self.z_dim = self.img_z_dim
        if self.rot_check:
            self.z_dim += 1
        if self.trans_check:
            self.z_dim += 2

        enc_net = []
        enc_net.append(nn.Linear(in_dim, h_dim[0]))
        enc_net.append(nn.LeakyReLU(0.1))
        for i in range(1, len(h_dim)):
            enc_net.append(nn.Linear(h_dim[i-1], h_dim[i]))
            enc_net.append(nn.LeakyReLU(0.1))
        enc_net.append(nn.Linear(h_dim[-1], 2*self.z_dim))

        self.encoder = nn.Sequential(*enc_net)

    def encode(self, x):
        """Return (mu, logvar), each of shape (batch, z_dim)."""
        latent = self.encoder(x)
        mu = latent[:, :self.z_dim]
        logvar = latent[:, self.z_dim:]

        return mu, logvar

    def rotation(self, coord, z):
        """
        Rotate every coordinate of each sample by its angle z (radians).

        coord: (batch, n_coord, 2); z: (batch,). Row vector @ [[cos, sin], [-sin, cos]]
        gives (x cos - y sin, x sin + y cos), a counter-clockwise rotation.
        """
        rot_matrix = torch.stack((torch.cos(z), torch.sin(z), -torch.sin(z), torch.cos(z)), dim=1)
        rot_matrix = rot_matrix.view(-1, 2, 2)

        return torch.bmm(coord, rot_matrix)

    def translation(self, coord, z):
        """Shift every coordinate of each sample by z * trans_std; z: (batch, 2)."""
        trans_z = z * self.trans_std
        trans_z = trans_z.unsqueeze(1)

        return coord + trans_z

    def reparametrization(self, mu, logvar):
        """Sample mu + sigma * eps with sigma = exp(0.5 * logvar) (logvar is a log-variance)."""
        return mu+torch.exp(0.5*logvar)*torch.randn_like(logvar)

    def forward(self, x, coord):
        """
        Encode `x`, sample the latents and transform the coordinate grid.

        `coord` is truncated to the batch size of `x` (the last mini-batch may be
        smaller).

        Returns
        -------
        (coord, mu, logvar, z, rot_mu, rot_logvar, rot_z, trans_z)
            coord : transformed coordinates.
            mu, logvar : remaining latent statistics after the rotation column is
                removed; when trans_check is set they still contain the two
                translation columns, so the standard-normal KL also covers them.
            z : content latents only.
            rot_* : rotation statistics/sample, or None without rot_check.
            trans_z : (batch, 2) translation sample, or None without trans_check.
        """
        if coord.size(0) != x.size(0):
            coord = coord[:x.size(0)]

        mu, logvar = self.encode(x)
        z = self.reparametrization(mu, logvar)

        rot_mu = None
        rot_logvar = None
        rot_z = None

        trans_z = None

        if self.rot_check:
            rot_mu = mu[:, 0]
            mu = mu[:, 1:]

            rot_logvar = logvar[:, 0]
            logvar = logvar[:, 1:]

            rot_z = z[:, 0]
            z = z[:, 1:]

            coord = self.rotation(coord, rot_z)

        if self.trans_check:
            # the two translation columns come first; content latents follow them
            # (previously both were taken from z[:, 2:], i.e. the content columns)
            trans_z = z[:, :2]
            z = z[:, 2:]

            coord = self.translation(coord, trans_z)

        return coord, mu, logvar, z, rot_mu, rot_logvar, rot_z, trans_z

In [None]:
class invar1DVAE_decoder(nn.Module):
    """
    Coordinate-based decoder that predicts one value in (0, 1) per coordinate.

    For each (sample, coordinate) pair the first hidden vector is
        linear_coord(coord) + linear_img(z) [+ bi_linear(coord, z) if bi_lin],
    followed by `num_hid` blocks of Linear -> BatchNorm1d -> Tanh and a final
    Linear -> Sigmoid.

    Parameters
    ----------
    n_coord : int
        Number of coordinates per sample.
    n_dim : int
        Size of each coordinate vector.
    z_dim : int
        Number of latent variables.
    hid_dim : int
        Hidden width.
    num_hid : int
        Number of Linear/BatchNorm1d/Tanh blocks.
    bi_lin : bool
        Add a bilinear coordinate-latent interaction term.
    """
    def __init__(self, n_coord, n_dim, z_dim, hid_dim, num_hid, bi_lin=False):
        super(invar1DVAE_decoder, self).__init__()

        self.n_coord = n_coord
        self.n_dim = n_dim
        self.z_dim = z_dim
        self.bi_lin = bi_lin

        # bias-free input projections (creation order kept: it fixes parameter init order)
        self.linear_coord = nn.Linear(n_dim, hid_dim, bias=False)
        self.linear_img = nn.Linear(z_dim, hid_dim, bias=False)
        if bi_lin:
            self.bi_linear = nn.Bilinear(n_dim, z_dim, hid_dim, bias=False)

        hidden_blocks = []
        for _ in range(num_hid):
            hidden_blocks += [nn.Linear(hid_dim, hid_dim, bias=True),
                              nn.BatchNorm1d(hid_dim),
                              nn.Tanh()]
        self.decoder = nn.Sequential(*hidden_blocks,
                                     nn.Linear(hid_dim, 1, bias=True),
                                     nn.Sigmoid())

    def forward(self, coord, z):
        """
        coord : tensor viewable as (batch * n_coord, n_dim)
        z : tensor (batch, z_dim)
        returns : tensor (batch, n_coord)
        """
        n_batch = z.size(0)

        flat_coord = coord.view(coord.size(0)*coord.size(1), -1).contiguous()
        hidden = self.linear_coord(flat_coord).view(n_batch, self.n_coord, -1)
        # the latent term is shared by every coordinate of a sample
        hidden = hidden + self.linear_img(z).unsqueeze(1)
        if self.bi_lin:
            z_rep = z.unsqueeze(1).expand(n_batch, self.n_coord, self.z_dim).contiguous()
            hidden = hidden + self.bi_linear(coord, z_rep)

        out = self.decoder(hidden.view(n_batch*self.n_coord, -1).contiguous())

        return out.view(n_batch, self.n_coord)

In [None]:
# Pick the compute device: the first CUDA GPU when available, otherwise the CPU.
# (The name `cuda_device` is kept because later cells use it.)
if torch.cuda.is_available():
    print("%d gpu available"%(torch.cuda.device_count()))
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)
    print(torch.cuda.memory_summary(device=cuda_device))
else:
    # a real device object instead of None, so `.to(cuda_device)` works on CPU-only machines
    cuda_device = torch.device("cpu")
    print("no gpu available, using the cpu")

In [None]:
# Split the shuffled patterns into consecutive mini-batches (the last may be shorter).
batch_size = 1024
batch_starts = range(0, len(dataset_input), batch_size)
mini_batches = [dataset_input[s:s + batch_size] for s in batch_starts]
print(len(mini_batches))
print(len(mini_batches[-1]))

In [None]:
# Decoder coordinate (x, y) for every element of a flattened pattern. The grid is
# centred on the pattern centre so that the encoder's rotation (about the origin)
# rotates the pattern about its centre.
if radial_flat_:
    # ring pixels sit on a (2R x 2R) canvas centred at (R, R); map them to [-1, 1]
    # (raw pixel indices would put the rotation centre at the canvas corner)
    ring_radius = radial_range[1]
    img_coord = np.stack([k_indx, k_indy], 1).astype(np.float64)
    img_coord = (img_coord - ring_radius) / ring_radius
else:
    # the box crop is (2*w_size) x (2*w_size) pixels (s_length = (2*w_size)**2),
    # flattened row by row, so the grid needs 2*w_size points per axis
    grid = np.linspace(1, -1, 2*w_size)
    X, Y = np.meshgrid(grid, grid)
    img_coord = np.stack([X.ravel(), Y.ravel()], 1)

img_coord = torch.from_numpy(img_coord)
img_coord = img_coord.to(torch.float32)
img_coord = img_coord.requires_grad_(requires_grad=False)
n_dim = img_coord.size(1)
assert img_coord.size(0) == s_length, "coordinate grid does not match the flattened pattern length"
print(img_coord.shape)

In [None]:
# Broadcast the coordinate grid over a full mini-batch (expand: no copy) and move it to the device.
coord = (
    img_coord.expand(batch_size, img_coord.size(0), img_coord.size(1))
    .requires_grad_(requires_grad=False)
    .to(cuda_device)
)
n_coord = coord.size(1)
print(n_coord)
print(coord.shape)

In [None]:
# Encoder hyper-parameters and model.
h_dim = [256, 256]
num_comp = 2  # number of content latents

rotation_check = True
angle_std = np.pi/4  # std (radians) of the rotation prior used in the rotation KL term
translation_check = False
translation_std = 0.1

enc_model = invar1DVAE_encoder(s_length, h_dim, num_comp,
                               rotation_check, translation_check, translation_std)
# .to() accepts both CUDA and CPU devices (.cuda() fails without a GPU)
enc_model.to(cuda_device)
print(enc_model)

In [None]:
# Decoder hyper-parameters and model.
hid_dim = 256
num_hid = 1
parallel_ = True
dec_model = invar1DVAE_decoder(n_coord, n_dim, num_comp, hid_dim, num_hid, bi_lin=True)
if parallel_:
    dec_model = nn.DataParallel(dec_model)
# .to() accepts both CUDA and CPU devices (.cuda() fails without a GPU)
dec_model.to(cuda_device)
print(dec_model)

In [None]:
# Adam over the encoder and decoder parameters jointly.
n_epoch = 10
l_rate = 0.01
params = [*enc_model.parameters(), *dec_model.parameters()]
optimizer = optim.Adam(params, lr=l_rate)

In [None]:
# Sanity check on the first mini-batch before training: one forward pass, print
# shapes/ranges, then ONE optimisation step (note: this does update the model).
enc_model.train()
dec_model.train()
# out-of-place clamp: torch.from_numpy shares memory, so clamp_ would overwrite dataset_input;
# targets are kept strictly inside (0, 1) for binary cross-entropy
x = torch.from_numpy(mini_batches[0]).clamp(min=1E-3, max=1-1E-3)
x = x.to(torch.float32)
x = x.to(cuda_device)

# the encoder returns 8 values (translation mean/logvar are not returned separately)
tf_coord, mu, logvar, z, rot_mu, rot_logvar, rot_z, trans_z = enc_model(x, coord)
x_ = dec_model(tf_coord.contiguous(), z)

print(x.shape, x_.shape)
print(x.min(), x.max())
print(x_.min(), x_.max())
print(x.dtype, x_.dtype)

reconstruction_error = F.binary_cross_entropy(x_, x, reduction="mean")
KL_divergence = -0.5*torch.mean(1+logvar-mu**2-logvar.exp())
if rot_mu is not None:
    # KL( N(rot_mu, exp(rot_logvar)) || N(0, angle_std**2) ) with rot_logvar a log-variance,
    # consistent with the exp(0.5*logvar) used for sampling in the encoder
    KL_divergence_rot = torch.mean(-0.5*rot_logvar + np.log(angle_std)
                                   + (torch.exp(rot_logvar) + rot_mu**2)/(2*angle_std**2) - 0.5)
else:
    KL_divergence_rot = torch.zeros((), device=x.device)

loss = reconstruction_error + KL_divergence + KL_divergence_rot
optimizer.zero_grad()
loss.backward()
optimizer.step()

In [None]:
# Train the invariant VAE for n_epoch epochs over the shuffled mini-batches.
# Per mini-batch loss (summed over elements):
#   BCE reconstruction + KL(content latents || N(0, 1)) + KL(rotation || N(0, angle_std**2)).
# Latents sampled during the final epoch are collected in dataset_input order.
start = time.time()
latent_z = []
z_mu = []
z_logvar = []
rot_theta = []
trans_delta = []
n_fig = 5
report_every = max(1, n_epoch // 10)  # int(n_epoch/10) is 0 for n_epoch < 10 -> modulo by zero
verbose_batches = False  # set True to print per-batch ranges and loss terms
# a previous dec_model.eval() (generation cell) would otherwise leave BatchNorm frozen
enc_model.train()
dec_model.train()
for epoch in range(n_epoch):
    loss_epoch = 0
    recon_loss = 0
    KLD_loss = 0
    KLD_rot_loss = 0
    for i, m_batch in enumerate(mini_batches):
        # out-of-place clamp: torch.from_numpy shares memory, so clamp_ would overwrite
        # dataset_input; targets are kept strictly inside (0, 1) for BCE
        x = torch.from_numpy(m_batch).clamp(min=0.001, max=0.999)
        x = x.to(torch.float32)
        x = x.to(cuda_device)

        tf_coord, mu, logvar, z, rot_mu, rot_logvar, rot_z, trans_z = enc_model(x, coord)
        x_ = dec_model(tf_coord.contiguous(), z)

        reconstruction_error = F.binary_cross_entropy(x_, x, reduction="sum")
        KL_divergence = -0.5*torch.sum(1+logvar-mu**2-logvar.exp())
        if rot_mu is not None:
            # KL( N(rot_mu, exp(rot_logvar)) || N(0, angle_std**2) ) with rot_logvar a
            # log-variance, consistent with the encoder's exp(0.5*logvar) sampling
            KL_divergence_rot = torch.sum(-0.5*rot_logvar + np.log(angle_std)
                                          + (torch.exp(rot_logvar) + rot_mu**2)/(2*angle_std**2) - 0.5)
        else:
            KL_divergence_rot = torch.zeros((), device=x.device)

        loss = reconstruction_error + KL_divergence + KL_divergence_rot
        if verbose_batches:
            print(x.min(), x.max())
            print(x_.min(), x_.max())
            print(reconstruction_error)
            print(KL_divergence)
            print(KL_divergence_rot)
            print(loss)

        loss_epoch += loss.item()
        recon_loss += reconstruction_error.item()
        KLD_loss += KL_divergence.item()
        KLD_rot_loss += KL_divergence_rot.item()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch == n_epoch-1:
            latent_z.extend(z.detach().cpu().numpy().tolist())
            z_mu.extend(mu.detach().cpu().numpy().tolist())
            z_logvar.extend(logvar.detach().cpu().numpy().tolist())
            if rot_z is not None:
                rot_theta.extend(rot_z.detach().cpu().numpy().tolist())
            if trans_z is not None:
                trans_delta.extend(trans_z.detach().cpu().numpy().tolist())

    if epoch == 0 and torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        print(tabulate([
                        ["epoch", epoch+1],
                        ["total loss", loss_epoch/total_num],
                        ["reconstruction error", recon_loss/total_num],
                        ["KL divergence", KLD_loss/total_num],
                        ["KL divergence (rotation)", KLD_rot_loss/total_num],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # compare inputs (top row) with reconstructions (bottom row) of the last mini-batch,
        # which may hold fewer than n_fig patterns
        n_show = min(n_fig, x.size(0))
        fig, ax = plt.subplots(2, n_fig, figsize=(5*n_fig, 5*2), squeeze=False)
        for i in range(n_show):
            target = x[i].detach().cpu().numpy()
            recon = x_[i].detach().cpu().numpy()
            if radial_flat_:
                # paint the ring samples back onto a (2R x 2R) canvas
                panel_img = np.zeros((radial_range[1]*2, radial_range[1]*2))
                panel_img[k_indy, k_indx] = target
                ax[0][i].imshow(panel_img, cmap="inferno")
                panel_img = np.zeros((radial_range[1]*2, radial_range[1]*2))
                panel_img[k_indy, k_indx] = recon
                ax[1][i].imshow(panel_img, cmap="inferno")
            else:
                side = 2*w_size  # the box crop is (2*w_size) x (2*w_size) pixels
                ax[0][i].imshow(target.reshape(side, side), cmap="inferno")
                ax[1][i].imshow(recon.reshape(side, side), cmap="inferno")
            ax[0][i].axis("off")
            ax[1][i].axis("off")
        fig.tight_layout()
        plt.show()

latent_z = np.asarray(latent_z)
z_mu = np.asarray(z_mu)
z_logvar = np.asarray(z_logvar)
rot_theta = np.asarray(rot_theta)
trans_delta = np.asarray(trans_delta)
print("The training has been finished.")

In [None]:
# Put the last-epoch latent vectors back into scan order (coeffs[ri[k]] = latent_z[k]),
# split them per file, then show a histogram of every component and its spatial map.
coeffs = np.zeros_like(latent_z)
coeffs[ri] = latent_z.copy()
coeffs_reshape = reshape_coeff(coeffs, data_shape)

fig, ax = plt.subplots(1, 1)
ax.grid()
for i in range(num_comp):
    ax.hist(coeffs[:, i], bins=50, alpha=(1.0-i*(1/num_comp)))
plt.show()

for i in range(num_comp):
    # squeeze=False gives a 2D axes grid even for a single file
    fig, axes = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
    for j in range(num_img):
        tmp = axes[0][j].imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        axes[0][j].axis("off")
    plt.show()

In [None]:
# Rotation latent: put the last-epoch angles back into scan order, then show their
# histogram and a spatial map per file.
if rotation_check:
    # reshape a local copy instead of np.expand_dims on the global rot_theta, so
    # re-running this cell does not keep adding axes
    rot_column = np.asarray(rot_theta).reshape(-1, 1)
    coeffs = np.zeros_like(rot_column)
    coeffs[ri] = rot_column
    coeffs_reshape = reshape_coeff(coeffs, data_shape)

    fig, ax = plt.subplots(1, 1)
    ax.grid()
    # single histogram: no alpha derived from a leftover loop index
    ax.hist(coeffs[:, 0], bins=50)
    plt.show()

    fig, axes = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
    for j in range(num_img):
        # each map is (h, w, 1); pass the single channel explicitly
        axes[0][j].imshow(coeffs_reshape[j][:, :, 0], cmap="viridis")
        axes[0][j].axis("off")
    plt.show()

In [None]:
# Translation latent: put the last-epoch (dx, dy) samples back into scan order and show
# a histogram and a spatial map per file for each of the two components.
if translation_check:
    trans_arr = np.asarray(trans_delta).reshape(len(trans_delta), -1)  # (N, 2)
    coeffs = np.zeros_like(trans_arr)
    # use the translation samples (previously the rotation angles were copied here)
    coeffs[ri] = trans_arr
    coeffs_reshape = reshape_coeff(coeffs, data_shape)
    n_trans = coeffs.shape[1]

    fig, ax = plt.subplots(1, 1)
    ax.grid()
    for comp in range(n_trans):
        ax.hist(coeffs[:, comp], bins=50, alpha=1.0 - comp/n_trans)
    plt.show()

    for comp in range(n_trans):
        fig, axes = plt.subplots(1, num_img, figsize=(7*num_img, 7), squeeze=False)
        for j in range(num_img):
            axes[0][j].imshow(coeffs_reshape[j][:, :, comp], cmap="viridis")
            axes[0][j].axis("off")
        plt.show()

In [None]:
# Build an n_sample x n_sample grid of 2D latent vectors for the decoder.
# Candidates are evenly spaced on [-sigma, sigma]; n_sample of them are drawn without
# replacement with probability proportional to the standard-normal pdf, then sorted.
n_sample = 10
sigma = 5.0
candidates = np.linspace(-sigma, sigma, n_sample*10, endpoint=True)
weights = stats.norm(0, 1).pdf(candidates)
weights = weights / np.sum(weights)
axis_vals = np.sort(np.random.choice(candidates, n_sample, replace=False, p=weights))
grid_a, grid_b = np.meshgrid(axis_vals, axis_vals)
z_test = np.stack((grid_a.flatten(), grid_b.flatten()), axis=1)
print(z_test.shape)
z_test = torch.from_numpy(z_test).to(torch.float32).to(cuda_device)

coord_test = img_coord.expand(n_sample**2, img_coord.size(0), img_coord.size(1)).to(cuda_device)
print(coord_test.shape)

In [None]:
# Decode the latent grid. eval() makes BatchNorm use its running statistics;
# no_grad() avoids building an autograd graph for pure inference.
dec_model.eval()
with torch.no_grad():
    generated = dec_model(coord_test.contiguous(), z_test)
print(generated.shape)

In [None]:
# Show the decoded pattern for every point of the n_sample x n_sample latent grid,
# for either flattening layout (radial rings or square box).
fig, ax = plt.subplots(n_sample, n_sample, figsize=(30, 30), squeeze=False)
patterns = generated.detach().cpu().numpy().reshape(len(generated), -1)
for i, a in enumerate(ax.flat):
    if radial_flat_:
        # paint the ring samples back onto a (2R x 2R) canvas
        tmp = np.zeros((radial_range[1]*2, radial_range[1]*2))
        tmp[k_indy, k_indx] = patterns[i]
    else:
        # box layout: (2*w_size) x (2*w_size) pixels flattened row by row
        tmp = patterns[i].reshape(2*w_size, 2*w_size)
    a.imshow(tmp, cmap="jet")
    a.axis("off")
plt.subplots_adjust(hspace=0.01, wspace=0.01)
plt.show()