In [None]:
import torch
import numpy as np
from sklearn.utils.extmath import randomized_svd
from sklearn.decomposition import TruncatedSVD

import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import seaborn as sns

from tqdm import tqdm

import h5py

from utils.checkpoints import load_siren_from_checkpoint
from utils.siren_utils import get_transposed_mgrid

%matplotlib notebook

In [None]:
def slider_plot_latent_dim(latent_matrix, img_real, normalize_per_dimension=True, cm='magma'):
    """Show a target image next to one latent dimension, selectable via a slider.

    Parameters
    ----------
    latent_matrix : torch.Tensor or array-like
        Latent values indexed as ``latent_matrix[row, col, dim]``; the slider
        ranges over the last axis (1-based in the UI).
    img_real : array-like
        Reference image drawn in grayscale in the left panel.
    normalize_per_dimension : bool, default True
        If True, the colour limits follow the min/max of the dimension currently
        shown. If False, every dimension uses the global min/max of
        ``latent_matrix`` so magnitudes are comparable across dimensions.
    cm : str, default 'magma'
        Colormap for the latent panel.

    Raises
    ------
    ValueError
        If ``latent_matrix`` is not 3-D or has an empty last axis.
    """
    # Normalise the input to a NumPy array once: detaches / moves torch tensors
    # and makes slicing and min/max behave identically for both input types.
    if isinstance(latent_matrix, torch.Tensor):
        latent_matrix = latent_matrix.detach().cpu().numpy()
    else:
        latent_matrix = np.asarray(latent_matrix)
    if latent_matrix.ndim != 3 or latent_matrix.shape[-1] == 0:
        raise ValueError(f'expected a (rows, cols, dims) array with dims >= 1, got shape {latent_matrix.shape}')
    max_z = float(np.max(latent_matrix))
    min_z = float(np.min(latent_matrix))

    fig, axes = plt.subplots(1, 2, figsize=(11, 6))

    # display real image on the left
    im_real = axes[0].imshow(img_real, cmap='gray')
    axes[0].get_xaxis().set_ticks([])
    axes[0].get_yaxis().set_ticks([])
    fig.colorbar(im_real, ax=axes[0])

    # display initial latent dimension; with a shared scale the colour limits
    # are pinned to the global range here and never change afterwards
    if normalize_per_dimension:
        im = axes[1].imshow(latent_matrix[:, :, 0], cmap=cm)
    else:
        im = axes[1].imshow(latent_matrix[:, :, 0], vmax=max_z, vmin=min_z, cmap=cm)
    fig.colorbar(im, ax=axes[1])
    axes[1].get_xaxis().set_ticks([])
    axes[1].get_yaxis().set_ticks([])

    # setup slider (1-based dimension numbers)
    n_dim = latent_matrix.shape[-1]
    slider_ax = fig.add_axes([0.25, 0.1, 0.65, 0.03])
    dim_slider = Slider(ax=slider_ax, label='Dimension', orientation='horizontal',
                        valinit=1, valmin=1, valmax=n_dim, valstep=1, closedmax=True)
    # Matplotlib widgets stop responding once garbage-collected; the figure
    # outlives this call, so keep the slider reachable from it.
    fig.latent_dim_slider = dim_slider

    fig.subplots_adjust(bottom=0.25)
    axes[0].set_title('Target image')
    axes[1].set_title('Latent representation')
    for image_ax in axes:
        image_ax.grid(False)

    def dim_update(val):
        # slider values are 1-based and may be floats; convert to a 0-based index
        idx = int(round(dim_slider.val)) - 1
        latent_slice = latent_matrix[:, :, idx]
        # Update the existing image instead of stacking a new artist per move;
        # the attached colorbar follows the image's norm automatically.
        im.set_data(latent_slice)
        if normalize_per_dimension:
            im.set_clim(float(latent_slice.min()), float(latent_slice.max()))
        fig.canvas.draw_idle()

    dim_slider.on_changed(dim_update)
    fig.suptitle('Single latent dimensions')

    plt.show()

In [None]:
model_path = '../trained_models/siren/2023-12-08_16:02:59/checkpoint_26000'
img_path = '../datasets/LoDoPaB/ground_truth_train/ground_truth_train_000.hdf5'

# load SIREN
siren = load_siren_from_checkpoint(model_path)

# load actual image: the checkpoint stores an index into the HDF5 'data'
# dataset (presumably the image this SIREN was trained on — confirm)
img_idx = torch.load(model_path, map_location='cpu')['img_idx']
# open read-only and close the file as soon as the slice is copied into memory
with h5py.File(img_path, 'r') as h5_file:
    real_image = h5_file['data'][img_idx]

# get input coords for SIREN
latent_dim = 64  # expected width of the SIREN's latent output
coord_side_length = 2 * 355  # NOTE(review): origin of 2 * 355 is not documented here — confirm
coords = get_transposed_mgrid(coord_side_length)
print(f'coords shape: {coords.shape}')

In [None]:
# Get latent vectors by passing coords to SIREN (inference only, no autograd graph)
with torch.no_grad():
    z, _ = siren(coords)

# One latent vector per grid coordinate -> (side, side, latent_dim) stack.
# Use the configured latent_dim rather than a second hard-coded 64.
assert z.shape[-1] == latent_dim, f'expected {latent_dim} latent dims, got {z.shape[-1]}'
z_for_plot = z.reshape(coord_side_length, coord_side_length, latent_dim)

print(f'z_for_plot.shape: {z_for_plot.shape}')

In [None]:
# Browse the raw SIREN latent dimensions; each one is colour-scaled to its own range.
slider_plot_latent_dim(
    latent_matrix=z_for_plot,
    img_real=real_image,
    normalize_per_dimension=True,
    cm='gray',
);

In [None]:
# Project every latent vector onto the top-k right singular directions of z.
# A fixed random_state makes the randomized SVD reproducible across runs.
k = 20
u, s, vh = randomized_svd(z.numpy(), n_components=k, random_state=0)

# u @ diag(s), computed by broadcasting instead of materialising the k x k diagonal
reduced_z = u * s
reduced_z_matrix = reduced_z.reshape(coord_side_length, coord_side_length, k)

fig, ax = plt.subplots(1, 1)
# explicit x/y keeps this working across seaborn versions (positional data changed meaning in 0.12)
sns.lineplot(x=np.arange(len(s)), y=s, ax=ax)
ax.set(xlabel='Index', ylabel='Singular value', title=f'First {k} singular values')
plt.show()

In [None]:
# Browse the top-k projections from the randomized SVD.
slider_plot_latent_dim(
    latent_matrix=reduced_z_matrix,
    img_real=real_image,
    normalize_per_dimension=True,
    cm='magma',
);

In [None]:
# sklearn's TruncatedSVD (randomized solver by default) as a cross-check of the
# manual projection; seeded so repeated runs give the same components.
svd = TruncatedSVD(n_components=k, random_state=0)
# fit_transform is equivalent (up to rounding) to fit(...) followed by transform(...) on the same data
reduction_trunc_svd = svd.fit_transform(z.numpy()).reshape(coord_side_length, coord_side_length, k)

In [None]:
# Browse the TruncatedSVD projections.
slider_plot_latent_dim(
    latent_matrix=reduction_trunc_svd,
    img_real=real_image,
    normalize_per_dimension=True,
    cm='magma',
);

In [None]:
def plot_all_dimensions_one_figure(latent_matrix=None, n_cols=8, cmap='magma'):
    """Plot every latent dimension as its own panel in one grid figure.

    Parameters
    ----------
    latent_matrix : torch.Tensor or array-like, optional
        Latent values indexed as ``latent_matrix[row, col, dim]``. Defaults to
        the notebook-level ``z_for_plot`` (looked up at call time).
    n_cols : int, default 8
        Panels per row; the number of rows is derived from the dimension count
        (64 dimensions with the defaults give the original 8 x 8 layout).
    cmap : str, default 'magma'
        Colormap used for every panel.
    """
    if latent_matrix is None:
        latent_matrix = z_for_plot
    if isinstance(latent_matrix, torch.Tensor):
        latent_matrix = latent_matrix.detach().cpu().numpy()
    else:
        latent_matrix = np.asarray(latent_matrix)

    n_dim = latent_matrix.shape[-1]
    n_rows = max(1, -(-n_dim // n_cols))  # ceiling division
    fig, ax = plt.subplots(n_rows, n_cols, figsize=(2.5 * n_cols, 2.5 * n_rows), squeeze=False)
    # ax.flat is row-major, so panel (i, j) shows dimension n_cols * i + j
    for dim, a in enumerate(ax.flat):
        if dim < n_dim:
            a.imshow(latent_matrix[:, :, dim], cmap=cmap)
            a.get_xaxis().set_ticks([])
            a.get_yaxis().set_ticks([])
        else:
            # hide the unused panels of a partially filled last row
            a.axis('off')
    plt.tight_layout()
    plt.show()

In [None]:
# Exact Singular Value Decomposition, as a reference for the randomized variants.
# Only run for small z: the cost grows with the number of coordinates (rows).
MAX_EXACT_SVD_ROWS = 100
if z.shape[0] <= MAX_EXACT_SVD_ROWS:
    # economy SVD: U is (rows, r) rather than (rows, rows), r = min(rows, cols)
    U, S, Vh = np.linalg.svd(z.numpy(), full_matrices=False)
    print(f'U shape: {U.shape}')
    print(f'S shape: {S.shape}')
    print(f'Vh shape: {Vh.shape}')

    # Draw on a fresh figure; plt.close() here would close the most recent
    # figure, which may be one of the interactive slider plots above.
    fig, ax = plt.subplots()
    ax.plot(S)
    ax.set(title='Singular Values', xlabel='Index', ylabel='Singular Value')
    plt.show()

    # Separate name so the notebook-level k used by the randomized cells is not
    # clobbered; never ask for more components than there are singular values.
    k_exact = min(10, S.shape[0])

    # Projection U_k @ diag(S_k); the rank-k reconstruction of z would be
    # reduced_z @ Vh[:k_exact, :].
    reduced_z = U[:, :k_exact] * S[:k_exact]
    reduced_z_matrix = reduced_z.reshape(coord_side_length, coord_side_length, k_exact)
else:
    print(f'Skipping exact SVD: z has {z.shape[0]} rows (> {MAX_EXACT_SVD_ROWS}); '
          'any previously computed reduced_z_matrix is left unchanged.')

In [None]:
# Browse reduced_z_matrix: the exact-SVD projection if that branch ran,
# otherwise the one from the randomized SVD cell.
slider_plot_latent_dim(
    latent_matrix=reduced_z_matrix,
    img_real=real_image,
    normalize_per_dimension=True,
    cm='magma',
);