In [None]:
# Install kneed: it provides the KneeLocator class imported below and used by
# GridSegmentation.elbow_method to pick the number of clusters.
%pip install kneed

In [None]:
import os
import re
import shutil
import random

import numpy as np
import matplotlib.pyplot as plt
import stempy.io as stio
import stempy.image as stim
import h5py
import ncempy
import ipywidgets as widgets

from pathlib import Path
from PIL import Image

from matplotlib.colors import LogNorm, PowerNorm
from matplotlib.patches import Rectangle

from skimage.measure import label, regionprops
from skimage.color import rgb2lab, rgb2gray
from skimage.feature import peak_local_max
from skimage import filters, morphology, segmentation as seg
from skimage.metrics import structural_similarity as ssim

from sklearn.decomposition import PCA
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score

from IPython.display import display, clear_output
from scipy import ndimage


import warnings
from sklearn.exceptions import ConvergenceWarning
from kneed import KneeLocator

In [None]:
#initialize
class GridSegmentation:
    """Cluster square real-space blocks of a 4D scan by their diffraction peaks.

    Pipeline (see ``main``):
      1. Mask the real-space image (Gaussian blur + Otsu threshold) and find
         peaks in the log-scaled diffraction pattern of the masked region.
      2. Tile the real-space image into ``block_size`` x ``block_size`` blocks
         and keep the blocks that are sufficiently covered by the mask.
      3. For every kept block, record the intensity found near each global
         peak position; these intensities are the clustering features.
      4. Run K-Means (user-chosen k, or k from the elbow method) and plot the
         clustered blocks and the summed diffraction pattern of each cluster.

    Conventions used throughout the class:
      * ``i`` / ``x`` index axis 0 (rows) of ``bf`` and ``sp``;
        ``j`` / ``y`` index axis 1 (columns).
      * Block coordinates are stored as ``(x1, y1, x2, y2)`` with *inclusive*
        end coordinates.

    Args:
        bf: 2D real-space image; must support ``.shape`` and ``.astype``.
        sp: 4D data indexed as ``sp[x, y, :, :]``; a slice must support
            ``.sum(axis=(0, 1))``. Presumably a stempy sparse array or an
            equivalent ndarray.
        gif_counter: frame number used for the file name when ``is_gif``.
        is_gif: when true, ``main`` saves the cluster overlay to
            ``./gif/<gif_counter>.png``.
        block_size: edge length of a real-space block, in scan pixels.
        clusters: number of K-Means clusters; 0 selects it with the elbow
            method.
        mask_threshold: minimum fraction of a block's pixels that must lie
            inside the real-space mask for the block to be kept.
        max_k_clusters: largest k tried by the elbow method.
        threshold: threshold handed to ``ncempy`` ``peakFind2D``.
        sigma: Gaussian sigma applied to the diffraction pattern before peak
            finding.
        peak_square_size: edge length of the search window around each peak.
        real_bg_sigma: Gaussian sigma applied to ``bf`` before Otsu
            thresholding.
        peak_threshold: minimum ``log1p`` intensity inside the search window
            for a peak to count as present in a block.
        facecolor_visible: truthy to fill the cluster rectangles.
        edge_points: stored but not used by any active code path.
        intensity_weight: multiplier applied to the stored peak intensities.
        plot_dp_log: when true, ``main`` plots the diffraction pattern of
            every clustered block.
        print_cluster_boxes: when true, ``main`` prints the block coordinates
            of every cluster.
        center_mask_radius: radius (detector pixels) of the disc zeroed around
            the pattern centre before peak finding.
        edge_mask_radius: everything at or beyond this radius is zeroed
            before peak finding. ``None`` keeps the historical value of
            ``detector_height - 300``; on detectors of 300 pixels or fewer
            that masks the whole pattern, so pass an explicit radius there.
    """

    def __init__(self, bf, sp, gif_counter=0, is_gif=False, block_size=11, clusters=0, mask_threshold=0.5,
                 max_k_clusters=15, threshold=0.2, sigma=5, peak_square_size=11, real_bg_sigma=7, peak_threshold=2, facecolor_visible=0,
                 edge_points=0, intensity_weight=1, plot_dp_log=True, print_cluster_boxes=True,
                 center_mask_radius=30, edge_mask_radius=None):

        # main variables
        self.bf = bf
        self.sp = sp
        self.gif_counter = gif_counter
        self.is_gif = is_gif
        self.block_size = block_size
        self.clusters = clusters
        self.max_k_clusters = max_k_clusters
        self.threshold = threshold
        self.sigma = sigma
        self.edge_points = edge_points
        self.plot_dp_log = plot_dp_log
        self.print_cluster_boxes = print_cluster_boxes
        self.intensity_weight = intensity_weight
        self.peak_square_size = peak_square_size
        self.real_bg_sigma = real_bg_sigma
        self.peak_threshold = peak_threshold
        self.mask_threshold = mask_threshold
        self.facecolor_visible = facecolor_visible
        self.center_mask_radius = center_mask_radius
        self.edge_mask_radius = edge_mask_radius

        # number of whole blocks along axis 0 (x) and axis 1 (y) of real space
        self.num_blocks_x = bf.shape[0] // block_size
        self.num_blocks_y = bf.shape[1] // block_size

        # (x1, y1, x2, y2) of every block, in row-major (i, j) order
        self.block_coordinates = []

        # (dp_log, i, j) for every block that passed the mask test
        self.valid_dp_logs = []

        # True where block (i, j) passed the mask test; filled by bf_grid_segment
        self.block_validity = np.zeros((self.num_blocks_x, self.num_blocks_y), dtype=np.bool_)

        # valid blocks grouped by cluster label; filled by cluster_colored_blocks
        self.clustered_data = {}

        # (row, col) positions of the peaks found in the summed diffraction pattern
        self.peak_locations = []

        # one row per valid block, one column per peak (+1 spare column that
        # stays zero); rebuilt by features_similarity
        self.features = np.zeros((len(self.valid_dp_logs), len(self.peak_locations) + 1))

        # real-space mask; replaced by enhanced_dp_sum
        self.rs_mask = np.zeros(self.bf.shape, dtype=np.bool_)

    def enhanced_dp_sum(self):
        """Build the real-space mask and return the masked, log-scaled pattern.

        Side effects: stores the mask in ``self.rs_mask`` and shows the
        resulting pattern with matplotlib.

        Returns:
            2D array: ``log1p`` of the diffraction pattern obtained from the
            masked real-space region, with the central disc and the outer
            region set to 0.
        """
        # real-space mask: blur, then keep everything at or above the Otsu threshold
        blurred_bf = ndimage.gaussian_filter(self.bf.astype(np.float32), self.real_bg_sigma)
        self.rs_mask = blurred_bf >= filters.threshold_otsu(blurred_bf)

        # diffraction pattern of the masked real-space region
        # (presumably a sum over the masked scan positions - see stempy docs)
        dp_mask_log = np.log1p(stim.mask_real_space(self.sp, self.rs_mask))

        n_rows, n_cols = dp_mask_log.shape[0], dp_mask_log.shape[1]
        # centre per axis, so the radial mask is also right on non-square detectors
        center_row, center_col = n_rows / 2, n_cols / 2

        outer_radius = self.edge_mask_radius
        if outer_radius is None:
            # historical hard-coded default
            outer_radius = n_rows - 300

        rows, cols = np.ogrid[:n_rows, :n_cols]
        distance_from_center = np.sqrt((cols - center_col) ** 2 + (rows - center_row) ** 2)

        # zero the direct beam (centre) and the detector rim
        blocked = (distance_from_center <= self.center_mask_radius) | (distance_from_center >= outer_radius)
        dp_mask_log[blocked] = 0

        plt.imshow(dp_mask_log)
        plt.show()

        return dp_mask_log

    def show_all_peaks(self, dp_log, blobs):
        """Show ``dp_log`` in grayscale with every entry of ``blobs`` marked.

        Args:
            dp_log: 2D image to display.
            blobs: iterable of peak positions whose first two entries are
                read as ``(row, col)``.
        """
        plt.imshow(dp_log, cmap='gray')
        plt.title('Detected Blobs')

        for blob in blobs:
            y, x = blob[:2]
            plt.scatter(x, y, color='red', s=30, marker='o', alpha=0.2)

        plt.show()

    # function to find and store peaks from summed dp
    def generate_all_dp_sum_non_zero(self):
        """Find peaks in the masked summed pattern and store their positions.

        Peaks that land on a zero-intensity pixel (i.e. inside the masked
        centre/rim) are discarded.

        Returns:
            ``(N, 2)`` integer array of ``(row, col)`` peak positions, also
            stored in ``self.peak_locations``.
        """
        dp_log = self.enhanced_dp_sum()

        blurred_image = ndimage.gaussian_filter(dp_log.astype(np.float32), self.sigma)
        blobs = ncempy.algo.peak_find.peakFind2D(blurred_image, self.threshold)

        # Writing into an image and reading back the non-zero pixels both
        # drops zero-intensity peaks and de-duplicates peaks that truncate to
        # the same pixel.
        all_dp_sum = np.zeros(dp_log.shape)
        for blob in blobs:
            y, x = blob[:2]
            intensity = dp_log[int(y), int(x)]
            if intensity != 0:
                all_dp_sum[int(y), int(x)] = intensity

        self.peak_locations = np.transpose(np.nonzero(all_dp_sum))

        self.show_all_peaks(dp_log, blobs)

        return self.peak_locations

    def bf_grid_segment(self):
        """Tile real space into blocks and keep the ones covered by the mask.

        A block is kept when the fraction of its pixels inside
        ``self.rs_mask`` is at least ``self.mask_threshold``. For each kept
        block the ``log1p`` of its summed diffraction pattern is stored.

        Rebuilds ``self.block_coordinates``, ``self.valid_dp_logs`` and
        ``self.block_validity`` from scratch, so calling it repeatedly does
        not accumulate duplicates.
        """
        size = self.block_size
        self.block_coordinates = []
        self.valid_dp_logs = []
        self.block_validity = np.zeros((self.num_blocks_x, self.num_blocks_y), dtype=np.bool_)

        for i in range(self.num_blocks_x):
            for j in range(self.num_blocks_y):
                # top-left corner: i walks axis 0 (x), j walks axis 1 (y)
                x1, y1 = i * size, j * size
                # bottom-right corner, inclusive
                x2, y2 = x1 + size - 1, y1 + size - 1

                self.block_coordinates.append((x1, y1, x2, y2))

                # +1 because x2/y2 are inclusive; must match the data slice below
                segment_mask = self.rs_mask[x1:x2 + 1, y1:y2 + 1]
                if np.mean(segment_mask) < self.mask_threshold:
                    continue

                # summed diffraction pattern of this block
                dp = self.sp[x1:x2 + 1, y1:y2 + 1, :, :].sum(axis=(0, 1))
                self.valid_dp_logs.append((np.log1p(dp), i, j))
                self.block_validity[i, j] = True

    def is_at_edge(self, i, j):
        """Return True if interior block ``(i, j)`` has an invalid 4-neighbour.

        Blocks on the outer border of the grid always return False, as in the
        original implementation. Relies on ``self.block_validity`` as filled
        by ``bf_grid_segment``.
        """
        validity = self.block_validity
        n_x, n_y = validity.shape

        # only interior blocks have all four neighbours
        if not (0 < i < n_x - 1 and 0 < j < n_y - 1):
            return False

        all_neighbours_valid = (validity[i - 1, j] and validity[i + 1, j]
                                and validity[i, j - 1] and validity[i, j + 1])
        return not all_neighbours_valid

    def features_similarity(self):
        """Build the feature matrix used for clustering.

        For every valid block and every global peak position, look at the
        ``peak_square_size`` window around the peak in the block's own
        pattern. If the window maximum reaches ``peak_threshold`` the feature
        is that maximum times ``intensity_weight``; otherwise it stays 0.

        The matrix is allocated here (rows = valid blocks, columns = peaks
        plus one spare zero column kept for shape compatibility), so the
        method does not depend on ``self.features`` having been pre-sized.
        """
        half_size = self.peak_square_size // 2
        features = np.zeros((len(self.valid_dp_logs), len(self.peak_locations) + 1))

        for idx, (dp_log, _i, _j) in enumerate(self.valid_dp_logs):
            n_rows, n_cols = dp_log.shape[0], dp_log.shape[1]

            for loc_idx, (y, x) in enumerate(self.peak_locations):
                # search window around the peak, clipped to the detector
                window = dp_log[max(y - half_size, 0):min(y + half_size + 1, n_rows),
                                max(x - half_size, 0):min(x + half_size + 1, n_cols)]
                if window.size == 0:
                    continue

                strongest = np.max(window)
                if strongest >= self.peak_threshold:
                    features[idx, loc_idx] = strongest * self.intensity_weight

        self.features = features

    # find the optimal number of clusters using the elbow method
    def elbow_method(self):
        """Choose k from the K-Means inertia curve and plot the curve.

        k runs from 2 to ``max_k_clusters``, capped at the number of feature
        rows because K-Means cannot fit more clusters than samples.

        Returns:
            int: the knee reported by ``KneeLocator``; the smallest tried k
            when no knee is found; 1 when fewer than two samples exist.
        """
        n_samples = len(self.features)
        largest_k = min(int(self.max_k_clusters), n_samples)
        if largest_k < 2:
            # nothing to compare; a single cluster is the only option
            return 1

        K = list(range(2, largest_k + 1))
        distortions = []
        for k in K:
            kmeans = KMeans(n_clusters=k, random_state=0, n_init=30)
            kmeans.fit(self.features)
            # inertia = sum of squared distances to the closest centroid
            distortions.append(kmeans.inertia_)

        # A knee needs at least three points on the curve.
        knee = None
        if len(K) >= 3:
            knee = KneeLocator(K, distortions, curve='convex', direction='decreasing').knee
        if knee is None:
            warnings.warn(f"No elbow found in the inertia curve; falling back to k = {K[0]}")
            knee = K[0]
        opt_k = int(knee)

        plt.plot(K, distortions, 'bo-')
        plt.xlabel('Number of Clusters')
        plt.ylabel('Distortion (Inertia)')
        plt.title('Elbow Method For finding n_clusters')
        plt.axvline(x=opt_k, color='r', linestyle='--', label=f'Optimal k = {opt_k}')

        # Tangent at the chosen k from its neighbours, clamped to the curve so
        # the first/last k neither wrap around nor run off the end.
        opt_index = K.index(opt_k)
        lower = max(opt_index - 1, 0)
        upper = min(opt_index + 1, len(K) - 1)
        if upper > lower:
            slope = (distortions[upper] - distortions[lower]) / (K[upper] - K[lower])
            intercept = distortions[opt_index] - slope * K[opt_index]

            tangent_offset_x = np.linspace(K[0], K[-1], 200)
            tangent_offset_y = slope * tangent_offset_x + intercept
            plt.plot(tangent_offset_x, tangent_offset_y, 'g--', label='tangent_offset Line at Optimal k')

        plt.legend()
        plt.grid(True)
        plt.show()

        return opt_k

    # creates and sorts clusters by color
    def cluster_colored_blocks(self, cluster_labels, save_path=None):
        """Draw the clustered blocks on ``bf`` and plot each cluster's summed pattern.

        Args:
            cluster_labels: one label per entry of ``self.valid_dp_logs``,
                in the same order.
            save_path: optional file path; when given, the overlay figure is
                saved there *before* it is shown (saving after ``plt.show()``
                can produce a blank image).

        Returns:
            dict: cluster label -> list of ``(x1, y1, x2, y2)`` block
            coordinates, ordered by label.

        Side effects: rebuilds ``self.clustered_data``.
        """
        labels = [int(lab) for lab in cluster_labels]
        unique_labels = sorted(set(labels))

        # endpoint=False: hsv is cyclic, so sampling both 0 and 1 would give
        # the first and last cluster the same red.
        palette = plt.cm.hsv(np.linspace(0, 1, len(unique_labels), endpoint=False))
        color_of = {lab: tuple(palette[n]) for n, lab in enumerate(unique_labels)}

        # group the valid blocks by cluster (fresh dict on every call)
        self.clustered_data = {lab: [] for lab in unique_labels}
        for entry, lab in zip(self.valid_dp_logs, labels):
            self.clustered_data[lab].append(entry)

        # keyed by label so missing/non-contiguous labels cannot mis-index
        cluster_boxes = {lab: [] for lab in unique_labels}
        cluster_sums = {lab: np.zeros(self.sp.shape[2:]) for lab in unique_labels}
        legend_handles = []

        fig, ax = plt.subplots(figsize=(8, 8))
        ax.imshow(self.bf, cmap='gray')
        ax.set_title('Valid Squares on BF Image with Clusters (K-Means)')

        for lab in unique_labels:
            color = color_of[lab]
            # toggle facecolor
            facecolor = color if self.facecolor_visible else 'none'

            legend_handles.append(Rectangle((0, 0), 1, 1, edgecolor=color, facecolor=facecolor,
                                            alpha=1, label=f'Cluster {lab}'))

            for _dp_log, i, j in self.clustered_data[lab]:
                x1, y1, x2, y2 = self.block_coordinates[i * self.num_blocks_y + j]

                # matplotlib wants (horizontal, vertical) = (column, row) = (y1, x1)
                ax.add_patch(Rectangle((y1, x1), self.block_size, self.block_size,
                                       edgecolor=color, facecolor=facecolor, alpha=1))

                cluster_boxes[lab].append((x1, y1, x2, y2))

                # accumulate the raw (not log-scaled) pattern of the block
                cluster_sums[lab] += self.sp[x1:x2 + 1, y1:y2 + 1, :, :].sum(axis=(0, 1))

        ax.legend(handles=legend_handles, loc='upper right')
        if save_path is not None:
            fig.savefig(save_path)
        plt.show()

        # plot the summed diffraction patterns for each cluster
        for lab in unique_labels:
            plt.imshow(np.log1p(cluster_sums[lab]), cmap='viridis')
            plt.title(f'Summed Diffraction Space for Cluster {lab}')
            plt.colorbar()
            plt.show()

        return cluster_boxes

    def main(self):
        """Run the whole pipeline: peaks, blocks, features, clustering, plots.

        Raises:
            ValueError: if no block passes the real-space mask test.
        """
        # peaks of the masked, summed diffraction pattern
        self.peak_locations = self.generate_all_dp_sum_non_zero()

        # tile real space and keep the blocks covered by the mask
        self.bf_grid_segment()

        if not self.valid_dp_logs:
            raise ValueError("No valid blocks found")

        # per-block peak intensities (allocates self.features itself)
        self.features_similarity()

        # user-defined cluster number, or 0 -> elbow method
        if self.clusters:
            best_k = int(self.clusters)
        else:
            best_k = self.elbow_method()

        kmeans = KMeans(n_clusters=best_k, random_state=0, n_init=15)
        cluster_labels = kmeans.fit_predict(self.features)

        print(len(np.unique(cluster_labels)))

        # The frame for the gif has to be written while the overlay figure
        # still exists, so the path is handed to the plotting method.
        save_path = None
        if self.is_gif:
            Path('./gif/').mkdir(parents=True, exist_ok=True)
            save_path = os.path.join('./gif/', f'{self.gif_counter:03d}.png')

        cluster_boxes = self.cluster_colored_blocks(cluster_labels, save_path=save_path)

        # to print indices of each real space block
        if self.print_cluster_boxes:
            print("Cluster Boxes:")
            for cluster_num, boxes in cluster_boxes.items():
                print(f"Cluster {cluster_num}: {boxes}")

        # to show dp of each block
        if self.plot_dp_log:
            for cluster_num, boxes in cluster_boxes.items():
                for x1, y1, x2, y2 in boxes:
                    dp = self.sp[x1:x2 + 1, y1:y2 + 1, :, :].sum(axis=(0, 1))

                    plt.figure(figsize=(3, 3))
                    plt.imshow(np.log1p(dp))
                    plt.title(f"Group {cluster_num}")
                    plt.colorbar()
                    plt.show()


In [None]:
#Example Usage

# Loaded datasets keyed by file path. The interact callback below fires on
# every slider change; without this cache the file would be re-read and the
# real-space image re-summed each time.
_DATASET_CACHE = {}


def update_plots(block_size, threshold, sigma, intensity_weight, clusters, peak_threshold, mask_threshold, facecolor_visible, real_bg_sigma):
    """Interact callback: run GridSegmentation with the current widget values.

    All arguments are forwarded unchanged to the GridSegmentation keyword of
    the same name. Per-block printing and per-block diffraction plots are
    switched off to keep the interactive output short.
    """
    h5_path = '/data.h5'

    if h5_path not in _DATASET_CACHE:
        sp = stio.load_electron_counts(h5_path)
        # summing over axes (2, 3) leaves the real-space image
        _DATASET_CACHE[h5_path] = (sp, sp.sum(axis=(2, 3)))
    sp, bf = _DATASET_CACHE[h5_path]

    segmenter = GridSegmentation(bf, sp, gif_counter=0, is_gif=False, block_size=block_size, clusters=clusters, max_k_clusters=15, real_bg_sigma=real_bg_sigma,
                                 threshold=threshold, sigma=sigma, peak_threshold=peak_threshold, mask_threshold=mask_threshold, facecolor_visible=facecolor_visible,
                                 edge_points=0, intensity_weight=intensity_weight, plot_dp_log=False, print_cluster_boxes=False)
    segmenter.main()

# Interactive controls for update_plots. Every control keeps its full
# description visible ('description_width': 'initial').
block_size_slider = widgets.IntSlider(value=11, min=5, max=20, step=1, description='Grid Size:', style={'description_width': 'initial'})
threshold_slider = widgets.FloatSlider(value=0.2, min=0.1, max=1.0, step=0.01, description='Peak Finding Threshold:', style={'description_width': 'initial'})
sigma_slider = widgets.FloatSlider(value=5, min=1, max=10, step=0.1, description='Peak Finding Sigma:', style={'description_width': 'initial'})
real_bg_sigma_slider = widgets.FloatSlider(value=7, min=1, max=10, step=0.1, description='RS Background Sigma:', style={'description_width': 'initial'})
intensity_slider = widgets.FloatSlider(value=1, min=1, max=1000, step=1, description='Intensity Weight:', style={'description_width': 'initial'})
# Cluster count is an integer (0 = choose automatically with the elbow method),
# so use an IntSlider instead of a float slider with step=1.
cluster_slider = widgets.IntSlider(value=0, min=0, max=15, step=1, description='Num Clusters:', style={'description_width': 'initial'})
peak_threshold_slider = widgets.FloatSlider(value=2, min=0, max=10, step=0.01, description='Peak Intensity Threshold:', style={'description_width': 'initial'})
# mask_threshold is the minimum fraction of a block that must lie inside the
# real-space mask; the label now says so. The variable name (including its
# historical spelling) is kept in case other cells refer to it.
mask_threshold_silder = widgets.FloatSlider(value=0.5, min=0, max=1, step=0.01, description='Mask Coverage Threshold:', style={'description_width': 'initial'})
# 0 = outline only, 1 = filled rectangles; an integer toggle, not a float.
facecolor_visible_slider = widgets.IntSlider(value=0, min=0, max=1, step=1, description='Facecolor:', style={'description_width': 'initial'})
widgets.interact(update_plots, block_size=block_size_slider, real_bg_sigma=real_bg_sigma_slider,
                 threshold=threshold_slider, sigma=sigma_slider, intensity_weight=intensity_slider, facecolor_visible=facecolor_visible_slider,
                 clusters=cluster_slider, peak_threshold=peak_threshold_slider, mask_threshold=mask_threshold_silder)