In [None]:
import time
import numpy as np
import tifffile
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import matplotlib.cm as cm
import ipywidgets as pyw
import hyperspy.api as hys
import tkinter.filedialog as tkf
from tabulate import tabulate
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch import linalg as LA

In [None]:
# Build a discrete colour table (one named colour per bin) together with the
# norm and ScalarMappable needed to draw a matching colour bar.
color_rep = [
    "black", "red", "green", "blue",
    "orange", "purple", "yellow", "lime",
    "cyan", "magenta", "lightgray", "peru",
    "springgreen", "deepskyblue", "hotpink", "darkgray",
]
n_colors = len(color_rep)
print(n_colors)

custom_cmap = mcolors.ListedColormap(color_rep)
# Edges -1, 0, ..., n_colors-1 delimit exactly n_colors bins.
bounds = np.arange(-1, n_colors)
norm = mcolors.BoundaryNorm(boundaries=bounds, ncolors=n_colors)
sm = cm.ScalarMappable(norm=norm, cmap=custom_cmap)
sm.set_array([])

# Names of sequential matplotlib colormaps.
cm_rep = ["gray", "Reds", "Greens", "Blues", "Oranges", "Purples"]
print(len(cm_rep))

In [None]:
def reshape_coeff(coeffs, new_shape):
    """
    Split a stacked coefficient matrix back into one map per scan.

    Parameters
    ----------
    coeffs : array-like
        Coefficients with one row per scan position; the rows of all scans
        are stacked along axis 0 in scan order.
    new_shape : array-like of int, shape (n_scans, 2)
        (rows, columns) of every scan, in the same order as the stacked rows.

    Returns
    -------
    list of numpy.ndarray
        One array of shape (rows, columns, -1) per entry of ``new_shape``.

    Raises
    ------
    ValueError
        If ``coeffs`` holds fewer rows than ``new_shape`` requires.
    """
    coeffs = np.asarray(coeffs)
    # Accept plain lists as well as arrays; an empty input gives zero scans.
    new_shape = np.asarray(new_shape, dtype=int).reshape(-1, 2)

    coeff_reshape = []
    start = 0
    # Walk through the rows with a running offset instead of deleting the
    # consumed rows, which would copy the remaining matrix for every scan.
    for n_rows, n_cols in new_shape:
        stop = start + int(n_rows) * int(n_cols)
        if stop > len(coeffs):
            raise ValueError(
                "coeffs has %d rows but the scan shapes need at least %d"
                % (len(coeffs), stop)
            )
        coeff_reshape.append(coeffs[start:stop].reshape(int(n_rows), int(n_cols), -1))
        start = stop

    return coeff_reshape

def radial_indices(shape, radial_range, center=None):
    """
    Build a 0/1 mask selecting pixels by their distance from a centre.

    Parameters
    ----------
    shape : tuple of int
        (rows, columns) of the mask.
    radial_range : sequence of two numbers
        ``[inner, outer]``. When the two values differ, pixels with
        ``inner < r <= outer`` are kept. When they are equal, only pixels
        whose rounded radius equals ``round(inner)`` are kept.
    center : sequence of two numbers, optional
        (row, column) of the centre. ``None`` or an empty sequence selects
        the geometric centre of the index grid.

    Returns
    -------
    numpy.ndarray
        Float array of the given shape holding 1 for kept pixels, 0 elsewhere.
    """
    y, x = np.indices(shape)
    # `not center` is ambiguous (and raises) for NumPy arrays, so test
    # explicitly for "no centre given".
    if center is None or np.size(center) == 0:
        center = np.array([(y.max()-y.min())/2.0, (x.max()-x.min())/2.0])

    r = np.hypot(y - center[0], x - center[1])
    ri = np.ones(r.shape)

    if len(np.unique(radial_range)) > 1:
        # Annulus: drop everything at or inside the inner radius and
        # everything beyond the outer radius.
        ri[(r <= radial_range[0]) | (r > radial_range[1])] = 0
    else:
        # Degenerate range: keep a one-pixel-wide ring at the rounded radius.
        ri[np.round(r) != round(radial_range[0])] = 0

    return ri

In [None]:
# Paths of the input files; starts empty and is filled by the file-dialog cell.
file_adr = []

In [None]:
# Ask for one or more files and append them to the running list, so this
# cell can be executed repeatedly to collect files from several dialogs.
selected_files = tkf.askopenfilenames()
file_adr.extend(selected_files)
print(len(file_adr))
print(*file_adr, sep="\n")

In [None]:
# Number of selected files; used as the loop bound by the cells below.
num_img = len(file_adr)
print(num_img)

In [None]:
# Read every selected TIFF stack and remember its first two dimensions,
# which are later used to restore the per-file map shapes.
data_original, scan_shapes = [], []
for img_idx in range(num_img):
    stack = tifffile.imread(file_adr[img_idx])
    print(stack.shape)
    scan_shapes.append(list(stack.shape[:2]))
    data_original.append(stack)

data_shape = np.asarray(scan_shapes)

In [None]:
# Estimate the centre of every dataset as the intensity centre of mass of the
# mean pattern, evaluated inside a central box of edge length `cbox_edge`.
center_pos = []
cbox_edge = 100
for i in range(num_img):
    # Average over the first two axes to get one mean pattern.
    mean_dp = np.mean(data_original[i], axis=(0, 1))
    n_rows, n_cols = mean_dp.shape
    # Margin between the frame border and the central box. Clamp at 0 and use
    # explicit end indices: the slice `[m:-m]` is empty for m == 0 and wrong
    # for m < 0, which happens when the frame is not larger than the box.
    cbox_outy = max(int(n_rows/2 - cbox_edge/2), 0)
    cbox_outx = max(int(n_cols/2 - cbox_edge/2), 0)
    center_box = mean_dp[cbox_outy:n_rows-cbox_outy, cbox_outx:n_cols-cbox_outx]
    Y, X = np.indices(center_box.shape)
    total_intensity = np.sum(center_box)
    com_y = np.sum(center_box * Y) / total_intensity
    com_x = np.sum(center_box * X) / total_intensity
    # Shift from box coordinates back to frame coordinates.
    c_pos = [np.around(com_y+cbox_outy), np.around(com_x+cbox_outx)]
    center_pos.append(c_pos)
print(center_pos)

In [None]:
# Zero out the central disc (r <= center_radius) and everything beyond
# r = 100 in every pattern, using a mask centred on the estimated centre.
center_radius = 5
for i in range(num_img):
    # Deliberately not named `ri`: that global holds the shuffle permutation
    # needed later to un-shuffle the coefficients, and re-running this cell
    # must not overwrite it.
    radial_mask = radial_indices(data_original[i].shape[2:], [center_radius, 100], center=center_pos[i])
    # The 2D mask broadcasts over the leading axes of the stack.
    data_original[i] = np.multiply(data_original[i], radial_mask)

In [None]:
# Crop a (2*side_length) x (2*side_length) window around the centre of every
# pattern and stack the crops of all files along the first axis.
side_length = 30
box_size = np.array([side_length, side_length])
crops = []
for i in range(num_img):
    h_si = np.floor(center_pos[i][0]-box_size[0]).astype(int)
    h_fi = np.ceil(center_pos[i][0]+box_size[0]).astype(int)
    w_si = np.floor(center_pos[i][1]-box_size[1]).astype(int)
    w_fi = np.ceil(center_pos[i][1]+box_size[1]).astype(int)

    # A window reaching outside the frame would be silently truncated (or
    # wrapped for negative starts) by slicing; fail loudly instead.
    frame_h, frame_w = data_original[i].shape[2:4]
    if h_si < 0 or w_si < 0 or h_fi > frame_h or w_fi > frame_w:
        raise ValueError(
            "crop window [%d:%d, %d:%d] of file %d lies outside the %d x %d frame"
            % (h_si, h_fi, w_si, w_fi, i, frame_h, frame_w)
        )

    crops.append(data_original[i][:, :, h_si:h_fi, w_si:w_fi])

# Concatenate the arrays directly rather than round-tripping every value
# through nested Python lists.
dataset = np.concatenate(crops, axis=0) if crops else np.asarray([])
print(dataset.shape)

In [None]:
# Edge length of one cropped pattern.
w_size = 2*side_length
print(w_size)

# One pattern per row of the first axis, with negative values clipped to zero.
dataset_flat = dataset.reshape((-1, w_size, w_size)).clip(min=0.0)
# Per-pattern maximum, used as the normalisation factor.
mask = dataset_flat.max(axis=(1, 2))
print(mask.shape)
# All-zero patterns keep a factor of 1 so the division below stays finite.
mask[mask == 0] = 1.0
#dataset_flat = np.log(dataset_flat)
dataset_flat = dataset_flat / mask[:, None, None]
print(dataset_flat.shape)

In [None]:
# Draw a random permutation of all pattern indices and reorder the data with
# it. `ri` is kept because it is needed later to undo the shuffle.
total_num = dataset_flat.shape[0]
ri = np.random.choice(total_num, size=total_num, replace=False)

dataset_input = np.take(dataset_flat, ri, axis=0)

In [None]:
class conv_ae(nn.Module):
    """
    Convolutional encoder followed by a single linear decoder layer.

    The encoder is a chain of (Conv2d, BatchNorm2d, LeakyReLU(0.1), AvgPool2d)
    stages ending in a Flatten. The decoder is one bias-free Linear layer with
    a Sigmoid, mapping the code to a flattened (2*input_size)**2 image.

    Parameters
    ----------
    input_size : int
        Half of the edge length of the square input images.
    encoded_dimension : int
        Length of the flattened encoder output (input width of the decoder).
    channels, kernels, poolings : sequences of int
        Output channels, convolution kernel size and pooling size per stage.
    """

    @staticmethod
    def _conv_stage(n_in, n_out, kernel, pool):
        """Return the four layers of one encoder stage, in order."""
        return [
            nn.Conv2d(n_in, n_out, kernel, bias=True),
            nn.BatchNorm2d(n_out),
            nn.LeakyReLU(0.1),
            nn.AvgPool2d(pool),
        ]

    def __init__(self, input_size, encoded_dimension, channels, kernels, poolings):
        super().__init__()

        self.input_size = input_size
        self.encoded_dimension = encoded_dimension

        # The first stage takes the single-channel image; every later stage
        # takes the previous stage's channels.
        layers = self._conv_stage(1, channels[0], kernels[0], poolings[0])
        for idx in range(1, len(channels)):
            layers.extend(
                self._conv_stage(channels[idx-1], channels[idx], kernels[idx], poolings[idx])
            )
        layers.append(nn.Flatten())
        self.encoder = nn.Sequential(*layers)

        n_pixels = (2*self.input_size)**2
        self.decoder = nn.Sequential(
            nn.Linear(self.encoded_dimension, n_pixels, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return ``(encoded, decoded)`` for a batch of images."""
        side = self.input_size*2
        # Insert the channel axis expected by Conv2d.
        images = x.view(-1, 1, side, side)
        encoded = self.encoder(images)
        return encoded, self.decoder(encoded)

In [None]:
# Select the first CUDA device when one is present; otherwise leave
# `cuda_device` as None.
cuda_device = None
if torch.cuda.is_available():
    print("%d gpu available"%(torch.cuda.device_count()))
    cuda_device = torch.device("cuda:0")
    print(torch.cuda.get_device_name(cuda_device))
    torch.cuda.set_device(cuda_device)

In [None]:
# Encoder layout: the last stage outputs `num_comp` channels. Its kernel size
# is a placeholder (0) that is computed below.
num_comp = 5
channels = [100, 100, 100, num_comp]
kernels = [4, 3, 3, 0]
pooling = [2, 2, 2, 2]

# Track the feature-map edge length after every stage: a convolution removes
# (kernel - 1) pixels and the pooling then divides the length.
dat_dim = []
tmp_dim = side_length*2
for kernel_size, pool_size in zip(kernels, pooling):
    tmp_dim = (tmp_dim + (1 - kernel_size)) / pool_size
    dat_dim.append(int(tmp_dim))

# Pick the last kernel so that the final convolution leaves exactly
# `pooling[-1]` pixels, which the last pooling reduces to a single pixel.
kernels[-1] = dat_dim[-2] - pooling[-1] + 1
dat_dim[-1] = 1

for setting in (dat_dim, kernels, channels, pooling):
    print(setting)

In [None]:
# Instantiate the autoencoder; the code length is the number of output
# channels times the (single-pixel) size of the last feature map.
model = conv_ae(side_length, dat_dim[-1]*channels[-1], channels, kernels, pooling)

# Wrap the network in DataParallel when True. Later cells must then reach the
# underlying network through `model.module`.
parallel_ = True

if parallel_:
    model = nn.DataParallel(model)

# Only move the model when a CUDA device was actually selected; calling
# `.cuda()` without one raises.
if cuda_device is not None:
    model.cuda(cuda_device)
print(model)

In [None]:
# Cut the shuffled data into consecutive chunks of `batch_size` patterns; the
# last chunk holds the remainder and may be smaller.
batch_size = 1280
batch_starts = range(0, len(dataset_input), batch_size)
mini_batches = [dataset_input[first:first + batch_size] for first in batch_starts]

In [None]:
# Training hyper-parameters read by the optimizer and training cells.
n_epoch = 500    # number of passes over all mini-batches
l_rate = 0.001   # learning rate handed to the optimizer
lmbd_tv = 0.0    # weight of the penalty on the gradient of each decoder weight column
lmbd_reg = 0.0   # overall weight of the L1/L2 penalties on the decoder weight columns
alpha = 1.0      # L1/L2 mix: L1 is scaled by alpha, L2 by (1 - alpha)

In [None]:
#optimizer = optim.SGD(model.parameters(), lr=l_rate)
optimizer = optim.Adam(model.parameters(), lr=l_rate)
# Reach the underlying network whether or not it is wrapped in DataParallel
# (a bare model has no `.module` attribute).
decoder_linear = getattr(model, "module", model).decoder[0]
# Start from an orthogonal decoder weight matrix.
torch.nn.init.orthogonal_(decoder_linear.weight)
print(optimizer)

In [None]:
# Train the autoencoder. Each batch is reconstructed with a binary
# cross-entropy loss plus optional TV / L1 / L2 penalties on the columns of
# the decoder weight matrix, and the decoder weights are clamped to be
# non-negative after every optimizer step.
start = time.time()
ae_coeffs = []
ae_bias = []

# Resolve the underlying network once; it does not change during training.
model_access = model.module if parallel_ else model
# Report ten times per run, but never use an interval of 0 (n_epoch < 10
# would otherwise cause a modulo-by-zero).
report_every = max(1, int(n_epoch/10))

for epoch in range(n_epoch):
    for m_batch in mini_batches:

        x = torch.from_numpy(m_batch)
        x = x.to(torch.float32)
        x = x.to(cuda_device)
        x.requires_grad_(requires_grad=False)

        # Flattened target matching the decoder output layout.
        flat_x = x.view(-1, w_size**2)

        encoded, decoded = model(x)

        l1_reg = 0
        l2_reg = 0
        tv_reg = 0

        # One decoder weight column per component. The loop index has its own
        # name so it cannot be confused with any batch counter.
        dec_weight = model_access.decoder[0].weight
        for comp in range(num_comp):
            comp_vec = dec_weight[:, comp]
            tv_reg += lmbd_tv*LA.norm(torch.gradient(comp_vec)[0], 1)
            l1_reg += lmbd_reg*alpha*(LA.norm(comp_vec, 1))
            l2_reg += lmbd_reg*(1-alpha)*(LA.norm(comp_vec, 2))

        main_loss = F.binary_cross_entropy(decoded, flat_x, reduction="mean")
        #main_loss = LA.norm((decoded - flat_x), 2) / len(m_batch)
        total_loss = main_loss + l1_reg + l2_reg + tv_reg

        optimizer.zero_grad()
        total_loss.backward()
        optimizer.step()

        # Keep the decoder weights non-negative.
        dec_weight.data.clamp_(min=0.0)

        if epoch == n_epoch-1:
            # NOTE: `encoded` was computed before this batch's optimizer step,
            # so the stored codes come from the weights used in the forward pass.
            coeff_batch = encoded.data.cpu().numpy().tolist()
            ae_coeffs.extend(coeff_batch)

    # The memory summary only exists for CUDA; skip it on CPU-only machines.
    if epoch == 0 and torch.cuda.is_available():
        print(torch.cuda.memory_summary(device=cuda_device))

    if (epoch+1) % report_every == 0:
        # Values of the last mini-batch of this epoch. float() also accepts
        # the plain 0 the penalties keep when num_comp is 0.
        total_value = total_loss.item()
        l1_value, l2_value, tv_value = float(l1_reg), float(l2_reg), float(tv_reg)
        print(tabulate([
                        ["epoch", epoch+1],
                        ["total loss", total_value],
                        ["-main loss", main_loss.item(), main_loss.item()*100/total_value],
                        ["-L1 reg.", l1_value, l1_value*100/total_value],
                        ["-L2 reg.", l2_value, l2_value*100/total_value],
                        ["-TV reg.", tv_value, tv_value*100/total_value],
                        ]))
        print("%.2f minutes have passed"%((time.time()-start)/60))

        # squeeze=False keeps `ax` two-dimensional, so indexing also works
        # when there is a single component.
        fig, ax = plt.subplots(1, num_comp, figsize=(5*num_comp, 5), squeeze=False)
        for comp in range(num_comp):
            ax[0, comp].imshow(model_access.decoder[0].weight.data.cpu()[:, comp].reshape(w_size, w_size), cmap="viridis")
            ax[0, comp].axis("off")
        fig.tight_layout()
        plt.show()

print("The training has been finished.")

In [None]:
ae_coeffs = np.asarray(ae_coeffs)
# Reach the underlying network whether or not it is wrapped in DataParallel.
# Transposed so that every row is one component (loading) vector.
ae_comp_vectors = getattr(model, "module", model).decoder[0].weight.data.cpu().numpy().T
print(ae_coeffs.shape)
print(ae_comp_vectors.shape)

# `ri` must still be the shuffle permutation (one index per coefficient row);
# fail with a clear message if it has been overwritten in the meantime.
if np.shape(ri) != (len(ae_coeffs),):
    raise ValueError(
        "ri has shape %s but a permutation of length %d is required; "
        "re-run the shuffling cell and the training"
        % (np.shape(ri), len(ae_coeffs))
    )

# convert the coefficient matrix into coefficient maps
# Undo the shuffle: row k of ae_coeffs belongs to original pattern ri[k].
coeffs = np.zeros_like(ae_coeffs)
coeffs[ri] = ae_coeffs
coeffs_reshape = reshape_coeff(coeffs, data_shape)

In [None]:
# visualize loading vectors
# One figure per component: the component vector reshaped to the crop size.
pattern_shape = box_size*2
for comp_idx in range(num_comp):
    fig, ax = plt.subplots(figsize=(5, 5))
    loading_image = ae_comp_vectors[comp_idx].reshape(pattern_shape)
    ax.imshow(loading_image, cmap="viridis")
    ax.axis("off")
    fig.tight_layout()
    plt.show()

In [None]:
# Render figures inside the notebook output.
%matplotlib inline

In [None]:
# Switch to the Qt backend; whichever backend cell ran last is the one in effect.
%matplotlib qt

In [None]:
# visualize the coefficient maps
# One figure per component with one row per file. A single file uses a
# smaller figure than the multi-file layout.
if num_img != 1:
    map_figsize = (10*num_img, 10)
else:
    map_figsize = (5*num_img, 5)

for i in range(num_comp):
    # squeeze=False always returns a 2D array of axes, so the same indexing
    # serves one file and many files alike.
    fig, ax = plt.subplots(num_img, 1, figsize=map_figsize, squeeze=False)
    for j in range(num_img):
        panel = ax[j, 0]
        tmp = panel.imshow(coeffs_reshape[j][:, :, i], cmap="viridis")
        panel.set_title("loading vector %d map"%(i+1), fontsize=10)
        panel.axis("off")
        #fig.colorbar(tmp, cax=fig.add_axes([0.92, 0.15, 0.04, 0.7]))
    plt.show()

In [None]:
# 2D subspace
# Interactive scatter plot of the coefficients of two selectable components.
%matplotlib qt
fig, ax = plt.subplots(1, 1, figsize=(7, 7))

def projection(c1, c2):
    """Redraw the scatter of coefficient column ``c1`` against column ``c2``.

    Draws into the module-level ``fig``/``ax`` created above and reads the
    module-level ``coeffs`` array; ``c1`` and ``c2`` are zero-based column
    indices.
    """
    # Clear the previous scatter before drawing the new pair.
    ax.cla()
    ax.scatter(coeffs[:, c1], coeffs[:, c2], s=30, c="black", alpha=0.5)
    ax.grid()
    # Labels are shown one-based.
    ax.set_xlabel("loading vector %d"%(c1+1), fontsize=15)
    ax.set_ylabel("loading vector %d"%(c2+1), fontsize=15)
    ax.tick_params(axis="both", labelsize=15)
    fig.canvas.draw()
    fig.tight_layout()

# One slider per axis, limited to the valid component indices.
x_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=1)
y_widget = pyw.IntSlider(min=0, max=num_comp-1, step=1, value=2)

# Call `projection` again whenever a slider changes.
pyw.interact(projection, c1=x_widget, c2=y_widget)
plt.show()