# Probabilistic PCA

## Import dependencies for the model

In [None]:
import pickle

import numpy as np
np.random.seed(13)

import scipy.sparse.linalg

import matplotlib.pyplot as plt

from tqdm import tqdm

import h5py as h5


## Probabilistic PCA Class

In [None]:
class ProbabilisticPCA(object):
    """Probabilistic PCA (Tipping & Bishop, 1999) fitted with a truncated SVD.

    Only the leading ``n_singular_values_to_approximate_variance`` singular
    values of the centered data are computed. The noise variance is estimated
    from those beyond ``n_components``; the uncomputed tail of the spectrum is
    treated as zero. When the two counts are equal, the estimate is 0.

    Attributes set by ``fit``:
        mean_: per-feature mean of the training data.
        components_: ``(n_components, n_features)`` loading matrix (W
            transposed), rows ordered by decreasing singular value.
        variance_: one-element array holding the noise variance sigma^2.
    """

    def __init__(self, n_components, n_singular_values_to_approximate_variance):
        """Store hyper-parameters; the model itself is fitted by ``fit``.

        Args:
            n_components: dimensionality of the latent space (must be > 0).
            n_singular_values_to_approximate_variance: number of leading
                singular values to compute for the noise-variance estimate;
                must be at least ``n_components``.
        """
        assert n_components > 0
        assert n_singular_values_to_approximate_variance >= n_components

        self.n_components_ = n_components
        self.n_singular_values_to_approximate_variance_ = n_singular_values_to_approximate_variance

        self.mean_ = None
        self.components_ = None
        self.variance_ = None

    # Adapted from: https://github.com/davidstutz/probabilistic-pca
    def fit(self, data):
        """Fit the model to ``data``.

        Args:
            data: 2-D array of shape ``(n_samples, n_features)``; rows are samples.

        Returns:
            self, with ``mean_``, ``components_`` and ``variance_`` populated.

        Raises:
            ValueError: if ``min(data.shape) - 1 < n_components``, i.e. the
                truncated SVD cannot produce ``n_components`` singular triplets.
        """
        n_samples, n_features = data.shape
        n_components = self.n_components_

        mean = np.mean(data, axis=0)
        mean_centered_data = data - mean

        # The ARPACK-based svds requires 0 < k < min(data.shape); clamping to
        # n_samples alone could still request k == min(data.shape).
        max_k = min(n_samples, n_features) - 1
        n_singular_values = min(self.n_singular_values_to_approximate_variance_, max_k)
        if n_singular_values < n_components:
            raise ValueError(
                "n_components={} requires min(data.shape) > {}, got data of shape {}".format(
                    n_components, n_components, data.shape))

        # A single truncated SVD yields both the principal subspace and the
        # trailing singular values used for the noise-variance estimate.
        _, s_all, Vt_all = scipy.sparse.linalg.svds(mean_centered_data, k=n_singular_values)

        # svds does not return singular values in descending order; sort so
        # that components_[0] belongs to the largest singular value.
        order = np.argsort(s_all)[::-1]
        s_all = s_all[order]
        Vt = Vt_all[order[:n_components]]

        # Eigenvalues of the sample covariance matrix (ddof=1).
        e_all = s_all ** 2 / (n_samples - 1)
        e = e_all[:n_components]

        # ML noise variance: the discarded eigenvalues averaged over the d - q
        # discarded directions, where d is the number of features.
        var = np.sum(e_all[n_components:]) / (n_features - n_components)

        # W = U_q (Lambda_q - sigma^2 I)^(1/2). Clipping at 0 guards against
        # tiny negative differences caused by round-off.
        L_m = np.diag(np.sqrt(np.maximum(e - var, 0.0)))
        V = Vt.T.dot(L_m)

        self.mean_ = mean
        self.components_ = V.T
        self.variance_ = np.array([var])

        return self

    # Adapted from: https://github.com/davidstutz/probabilistic-pca
    def transform(self, data):
        """Return the posterior mean of the latent variables for ``data``.

        Computes E[z | x] = M^-1 W^T (x - mean_) with M = W^T W + sigma^2 I.

        Args:
            data: array of shape ``(n_samples, n_features)`` or a single
                sample of shape ``(n_features,)``.

        Returns:
            Array of shape ``(n_samples, n_components)``, or
            ``(n_components,)`` for a single sample.
        """
        V = self.components_
        M = V.dot(V.T) + np.eye(self.n_components_) * self.variance_

        mean_centered_data = data - self.mean_

        # M is symmetric, so solving M Z^T = V X^T avoids forming an explicit inverse.
        projection = np.linalg.solve(M, V.dot(mean_centered_data.T)).T

        return projection


## Utility functions

In [None]:
# https://stackoverflow.com/questions/39382412/crop-center-portion-of-a-numpy-image
def crop_center(img, cropx, cropy):
    """Return the centered ``cropy`` x ``cropx`` window of ``img``.

    The first two axes of ``img`` are (rows, columns). Any trailing axes
    (e.g. channels) are kept whole. If the crop is larger than the image
    along an axis, the start index is clamped to 0 instead of going negative
    (a negative start would wrap around and return an off-centre,
    undersized slice).

    Args:
        img: array with at least two dimensions.
        cropx: crop width (number of columns).
        cropy: crop height (number of rows).

    Returns:
        A view of ``img`` of shape ``(<=cropy, <=cropx, *img.shape[2:])``.
    """
    y, x = img.shape[:2]
    startx = max(x // 2 - (cropx // 2), 0)
    starty = max(y // 2 - (cropy // 2), 0)
    return img[starty : starty + cropy, startx : startx + cropx]

def center_crop_images(img_data, dataset_size, center_crop_target_width, center_crop_target_height, display_progress_bar=False):
    """Center-crop the first ``dataset_size`` images of ``img_data``.

    Args:
        img_data: indexable collection of 2-D images; each one is passed to
            ``crop_center``.
        dataset_size: number of leading images to crop.
        center_crop_target_width: width (columns) of each crop.
        center_crop_target_height: height (rows) of each crop.
        display_progress_bar: if True, wrap the loop in a ``tqdm`` progress bar.

    Returns:
        float64 array of shape
        ``(dataset_size, center_crop_target_height, center_crop_target_width)``.
    """
    # crop_center returns (rows, cols) == (height, width), so allocate in that
    # order; the reversed order made non-square targets fail on assignment.
    img_data_center_cropped = np.zeros((dataset_size, center_crop_target_height, center_crop_target_width))

    # Only attach the progress bar when requested.
    img_indices = range(dataset_size)
    if display_progress_bar:
        img_indices = tqdm(img_indices)

    for img_index in img_indices:
        img_data_center_cropped[img_index] = crop_center(img_data[img_index], center_crop_target_width, center_crop_target_height)

    return img_data_center_cropped

def plot_components_3d(r, xlim=None, ylim=None, figsize=(12, 12), nbins=50):
    """Show a 3x3 pair plot of the first three columns of ``r``.

    Panel (i, j) is a histogram of column i when i == j. Otherwise it is a
    hexbin density of column j (x axis) against column i (y axis). Column
    labels ("PC n") sit above the top row and row labels left of the first
    column.

    Args:
        r: 2-D array; columns 0-2 are plotted.
        xlim: optional x-axis limits applied to every panel.
        ylim: optional y-axis limits applied to off-diagonal panels only.
        figsize: figure size passed to ``plt.subplots``.
        nbins: number of bins for the diagonal histograms.
    """
    fig, axes = plt.subplots(3, 3, figsize=figsize)
    n_cols = axes.shape[1]

    # axes.flat walks the grid row by row, matching a nested row/col loop.
    for flat_index, panel in enumerate(axes.flat):
        i, j = divmod(flat_index, n_cols)
        diagonal = i == j

        if diagonal:
            panel.hist(r[:, i], bins=nbins)
        else:
            panel.hexbin(r[:, j], r[:, i], mincnt=1)

        if xlim is not None:
            panel.set_xlim(xlim)
        if not diagonal and ylim is not None:
            panel.set_ylim(ylim)

        if i == 0:
            panel.set_xlabel("PC {}".format(j + 1), fontsize=20)
            panel.xaxis.set_label_position('top')

        if j == 0:
            panel.set_ylabel("PC {}".format(i + 1), fontsize=20)

    fig.tight_layout()
    plt.show()


## Fit the model to the data

In [None]:
dataset_name = "3iyf-10K-mixed-hit-99"
downsampled_images_output_subdir = "downsample-128x128"

dataset_size = 10000
# NOTE(review): not used by the cells shown here; presumably meant for
# center_crop_images preprocessing — confirm before removing.
center_crop_target_height = 100
center_crop_target_width = 100

h5_file = (
    "/reg/data/ana03/scratch/deebanr/{}/dataset/{}/"
    "cspi_synthetic_dataset_diffraction_patterns_3iyf-10K-mixed-hit_uniform_quat_"
    "dataset-size={}_diffraction-pattern-shape=1024x1040.hdf5"
).format(dataset_name, downsampled_images_output_subdir, dataset_size)
# h5_file = "/reg/data/ana03/scratch/deebanr/{}/dataset/cspi_synthetic_dataset_diffraction_patterns_3iyf-10K-mixed-hit_uniform_quat_dataset-size={}_diffraction-pattern-shape=1024x1040.hdf5".format(dataset_name, dataset_size)

# Load the whole dataset into memory inside a context manager, so the HDF5
# file is closed even if the read raises.
with h5.File(h5_file, 'r') as h5_file_handle:
    data_to_fit_and_project = h5_file_handle["downsampled_diffraction_patterns"][:]
    # data_to_fit_and_project = h5_file_handle["diffraction_patterns"][:]

# Flatten every image into one row. Using the actual sample count keeps one
# image per row even if the file's length differs from dataset_size.
data_to_fit_and_project = data_to_fit_and_project.reshape((data_to_fit_and_project.shape[0], -1))

print("Fitting Probabilistic PCA to vectorized images of shape: {}".format(data_to_fit_and_project.shape))
probabilistic_pca = ProbabilisticPCA(n_components=3, n_singular_values_to_approximate_variance=3)
probabilistic_pca.fit(data_to_fit_and_project)


## Save the model

In [None]:
# Persist the fitted model with pickle. Unpickling later requires the
# ProbabilisticPCA class to be defined under the same name.
probabilistic_pca_file = "probabilistic-pca-{}-{}-dataset_size={}.pkl".format(dataset_name, downsampled_images_output_subdir, dataset_size)
with open(probabilistic_pca_file, 'wb') as probabilistic_pca_file_handle:
    pickle.dump(probabilistic_pca, probabilistic_pca_file_handle)

# Report the output path, not the (now closed) file handle's repr.
print("Saved Probabilistic PCA model to: {}".format(probabilistic_pca_file))


## Load the model

In [None]:
# Read the pickled model back from disk as a round-trip check.
# Only unpickle files you trust: unpickling can execute arbitrary code.
with open(probabilistic_pca_file, 'rb') as probabilistic_pca_file_handle:
    pickled_model_bytes = probabilistic_pca_file_handle.read()
loaded_probabilistic_pca = pickle.loads(pickled_model_bytes)

print(f"Loaded Probabilistic PCA model from: {probabilistic_pca_file}")
print(f"Principal components:\n{loaded_probabilistic_pca.components_}")


## Project the data using the model

In [None]:
# Project with the model reloaded from disk, so the saved artifact is
# exercised. It is a pickle round trip of probabilistic_pca, so the
# projections match those of the in-memory model.
latent_projections = loaded_probabilistic_pca.transform(data_to_fit_and_project)


## Plot the projection

In [None]:
# Pair plot of the three latent coordinates; 100 histogram bins on the diagonal.
plot_components_3d(latent_projections, nbins=100, figsize=(9, 9))
