#!/usr/bin/env python
# coding: utf-8

# # SEM Crystal Analyzer
# 
# This notebook provides an interactive tool for analyzing crystals in SEM (Scanning Electron Microscope) images. It includes advanced preprocessing, segmentation, and detection methods.
# 
# ## Features:
# - Multiple thresholding methods (Otsu, Adaptive, Canny, Manual)
# - CLAHE (Contrast Limited Adaptive Histogram Equalization)
# - Advanced segmentation (Watershed, Voronoi, Connected Components)
# - Shape-specific crystal detection
# - Automatic removal of border-touching crystals
# - Detailed measurement and export capabilities

# ## 1. Install Required Packages

# In[1]:


# Install required packages if not already installed
import subprocess
import sys

def install(package):
    """Install *package* into the running interpreter's environment via pip.

    Runs ``<sys.executable> -m pip install <package>`` so the package lands in
    the same environment as the kernel. Raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install", package]
    subprocess.check_call(pip_command)

# List of required packages (pip distribution names)
packages = [
    'opencv-python',
    'numpy',
    'matplotlib',
    'pandas',
    'scipy',
    'scikit-image',
    'ipywidgets',
    'pillow'
]

# pip distribution name -> importable module name, for packages whose module
# name is not simply the pip name with '-' replaced by '_'. Without this map
# `__import__('opencv_python')`, `__import__('scikit_image')` and
# `__import__('pillow')` always fail, so pip would be re-run for them on every
# execution of this cell.
import_names = {
    'opencv-python': 'cv2',
    'scikit-image': 'skimage',
    'pillow': 'PIL',
}

print("Installing required packages...")
for package in packages:
    module_name = import_names.get(package, package.replace('-', '_'))
    try:
        __import__(module_name)
        print(f"✓ {package} already installed")
    except ImportError:
        print(f"Installing {package}...")
        install(package)
        print(f"✓ {package} installed")

print("\nAll packages installed successfully!")


# ## 2. Import Libraries and Define the Crystal Counter Class

# In[2]:


# Crystal Counter for SEM Images - Complete Version
# Adapted from various sources including proven SEM analysis techniques

import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import ndimage
from skimage import filters, measure, morphology
from skimage.feature import peak_local_max
import ipywidgets as widgets
from IPython.display import display, clear_output
import io
import base64
from PIL import Image
import warnings
# NOTE(review): this silences *every* warning for the whole kernel session
# (including library deprecation and runtime warnings that could explain odd
# results); consider narrowing it to specific categories or modules.
warnings.filterwarnings('ignore')

class CrystalCounter:
    def __init__(self):
        self.image = None
        self.processed_image = None
        self.results = None
        self.setup_ui()
    
    def setup_ui(self):
        """Create all interactive widgets (for Voila) and wire their callbacks.

        Every processing-parameter widget re-runs ``process_image`` when its
        value changes, and each is registered exactly once. The threshold-method
        dropdown and the CLAHE checkbox also toggle which dependent widgets are
        enabled. Sliders use ``continuous_update=False`` so the expensive
        pipeline runs once when a drag is released instead of on every
        intermediate value.
        """
        # File upload widget
        self.upload_widget = widgets.FileUpload(
            accept='image/*',
            multiple=False,
            description='Upload SEM Image'
        )

        # Image preprocessing parameters
        self.apply_clahe = widgets.Checkbox(
            value=True,
            description='Apply CLAHE',
            tooltip='Contrast Limited Adaptive Histogram Equalization'
        )

        self.clahe_grid_size = widgets.IntSlider(
            value=25, min=3, max=50, step=2,
            description='CLAHE Grid Size:',
            continuous_update=False
        )

        self.clip_limit = widgets.IntSlider(
            value=40, min=0, max=50, step=1,
            description='CLAHE Clip Limit:',
            continuous_update=False
        )

        self.gaussian_blur = widgets.IntSlider(
            value=11, min=1, max=21, step=2,
            description='Gaussian Blur:',
            continuous_update=False
        )

        self.gaussian_blur_sigma = widgets.FloatText(
            value=0,
            description='Gaussian Blur Sigma:'
        )

        self.threshold_type = widgets.Dropdown(
            options=['otsu', 'adaptive', 'canny', 'manual'],
            value='otsu',
            description='Threshold Method:'
        )

        self.manual_threshold = widgets.IntSlider(
            value=127, min=0, max=255,
            description='Manual Threshold:',
            disabled=True,
            continuous_update=False
        )

        # Adaptive threshold parameters
        self.adaptive_block_size = widgets.IntSlider(
            value=17, min=3, max=51, step=2,
            description='Adaptive Block Size:',
            disabled=True,
            continuous_update=False
        )

        self.adaptive_c = widgets.IntSlider(
            value=4, min=1, max=20,
            description='Adaptive C:',
            disabled=True,
            continuous_update=False
        )

        # Morphological parameters
        self.morph_kernel_size = widgets.IntSlider(
            value=3, min=3, max=15, step=2,
            description='Morph Kernel Size:',
            continuous_update=False
        )

        self.morph_iterations = widgets.IntSlider(
            value=1, min=1, max=5,
            description='Morph Iterations:',
            continuous_update=False
        )

        self.extract_edges = widgets.Checkbox(
            value=True,
            description='Extract Edges',
            tooltip='Extract edges using morphological gradient'
        )

        # Segmentation method
        self.segmentation_method = widgets.Dropdown(
            options=['watershed', 'voronoi', 'connected_components'],
            value='watershed',
            description='Segmentation:'
        )

        # Crystal detection parameters
        self.min_area = widgets.IntSlider(
            value=100, min=10, max=5000,
            description='Min Crystal Area:',
            continuous_update=False
        )

        self.max_area = widgets.IntSlider(
            value=10000, min=100, max=100000,
            description='Max Crystal Area:',
            continuous_update=False
        )

        # Shape detection parameters
        self.shape_selector = widgets.Dropdown(
            options=['all', 'hexagonal', 'cubic', 'circular'],
            value='all',
            description='Crystal Shape:'
        )

        # Process button
        self.process_button = widgets.Button(
            description='Process Image',
            button_style='success'
        )

        # Export button
        self.export_button = widgets.Button(
            description='Export Results',
            button_style='info'
        )

        # Output areas
        self.output = widgets.Output()
        self.results_output = widgets.Output()

        # Setup callbacks
        self.upload_widget.observe(self.on_upload, names='value')

        # Any parameter change re-runs the pipeline. Registering from one tuple
        # guarantees each widget is observed exactly once (a duplicate
        # registration would run the whole pipeline twice per change).
        parameter_widgets = (
            self.apply_clahe,
            self.clahe_grid_size,
            self.clip_limit,
            self.gaussian_blur,
            self.gaussian_blur_sigma,
            self.threshold_type,
            self.manual_threshold,
            self.adaptive_block_size,
            self.adaptive_c,
            self.morph_kernel_size,
            self.morph_iterations,
            self.extract_edges,
            self.segmentation_method,
            self.min_area,
            self.max_area,
            self.shape_selector,
        )
        for parameter_widget in parameter_widgets:
            parameter_widget.observe(self.process_image, names='value')

        self.process_button.on_click(self.process_image)
        self.export_button.on_click(self.export_results)
        self.threshold_type.observe(self.on_threshold_type_change, names='value')
        self.apply_clahe.observe(self.on_clahe_change, names='value')
    
    def on_threshold_type_change(self, change):
        """Enable/disable threshold-specific parameters"""
        if self.threshold_type.value == 'manual':
            self.manual_threshold.disabled = False
            self.adaptive_block_size.disabled = True
            self.adaptive_c.disabled = True
        elif self.threshold_type.value == 'adaptive':
            self.manual_threshold.disabled = True
            self.adaptive_block_size.disabled = False
            self.adaptive_c.disabled = False
        else:  # otsu or canny
            self.manual_threshold.disabled = True
            self.adaptive_block_size.disabled = True
            self.adaptive_c.disabled = True
    
    def on_clahe_change(self, change):
        """Enable/disable CLAHE grid size"""
        self.clahe_grid_size.disabled = not self.apply_clahe.value
        self.clip_limit.disabled = not self.apply_clahe.value
    
    def display_ui(self):
        """Lay out every control and output area and render it (used by Voila).

        Layout, top to bottom:
        - a title plus the upload widget;
        - a 400px-wide parameter column beside a full-width output column;
        - the Process / Export buttons.
        """
        def titled_box(heading_html, *children):
            # Vertical box whose first child is an HTML heading.
            return widgets.VBox([widgets.HTML(heading_html), *children])

        upload_section = titled_box(
            "<h2>Crystal Counter for SEM Images</h2>",
            self.upload_widget,
        )

        preprocessing_section = titled_box(
            "<h3>Preprocessing Parameters</h3>",
            self.apply_clahe,
            self.clahe_grid_size,
            self.clip_limit,
            self.gaussian_blur,
            self.gaussian_blur_sigma,
            widgets.HTML("<h4>Thresholding</h4>"),
            self.threshold_type,
            self.manual_threshold,
            self.adaptive_block_size,
            self.adaptive_c,
            widgets.HTML("<h4>Morphology</h4>"),
            self.morph_kernel_size,
            self.morph_iterations,
            self.extract_edges,
        )

        detection_section = titled_box(
            "<h3>Detection Parameters</h3>",
            self.segmentation_method,
            self.shape_selector,
            self.min_area,
            self.max_area,
        )

        parameter_column = widgets.VBox(
            [preprocessing_section, detection_section],
            layout=widgets.Layout(width='400px'),
        )
        result_column = widgets.VBox(
            [self.output, self.results_output],
            layout=widgets.Layout(width='100%'),
        )
        button_row = widgets.HBox([self.process_button, self.export_button])

        display(widgets.VBox([
            upload_section,
            widgets.HBox([parameter_column, result_column]),
            button_row,
        ]))
    
    def on_upload(self, change):
        """Decode the uploaded file into ``self.image`` (BGR, bottom 10% cropped).

        Triggered by the FileUpload widget's ``value`` observer. ``change`` is
        the traitlets event and is not used. Two ``value`` formats are accepted:
        - ipywidgets 8: a sequence of dicts with ``name``/``content`` keys;
        - ipywidgets 7.x: a dict keyed by file name whose entries hold
          ``content``.

        On success, any results from a previously loaded image are discarded.
        On any error a message is shown in the output area and the previously
        loaded image is kept.
        """
        value = self.upload_widget.value
        if len(value) == 0:
            return

        try:
            if isinstance(value, dict):
                # ipywidgets 7.x: {file_name: {'metadata': {...}, 'content': ...}}
                file_name, uploaded_file = next(iter(value.items()))
            else:
                # ipywidgets 8.x: (dict(name=..., content=..., ...), ...)
                uploaded_file = value[0]
                file_name = uploaded_file['name']
            image_data = uploaded_file['content']

            # Decode with PIL first (better TIF support than OpenCV), then
            # normalise to 3-channel RGB.
            pil_image = Image.open(io.BytesIO(image_data))
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')
            image_array = np.array(pil_image)
            original_shape = image_array.shape

            # Remove the lower stripe: the bottom 10% of rows is dropped
            # (presumably the SEM information bar — TODO confirm for your
            # instrument).
            cutoff = int(0.9 * image_array.shape[0])
            image_array = image_array[:cutoff, :, :]

            # OpenCV works in BGR order.
            self.image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)

            # Results computed for a previous image no longer apply.
            self.processed_image = None
            self.results = None
            with self.results_output:
                clear_output()

            with self.output:
                clear_output()
                print("✅ Image uploaded successfully!")
                print(f"Original dimensions: {original_shape}")
                print(f"Image dimensions: {self.image.shape}")
                print(f"Image format: {file_name}")

        except Exception as e:
            with self.output:
                clear_output()
                print(f"❌ Error loading image: {str(e)}")
                print("Please ensure the file is a valid image format.")
    
    def preprocess_image(self):
        """Enhanced preprocessing for SEM images based on proven methods"""
        if self.image is None:
            return None, None
        
        # Convert to grayscale
        gray = cv2.cvtColor(self.image, cv2.COLOR_BGR2GRAY)
        
        # Apply CLAHE if enabled
        if self.apply_clahe.value:
            clahe = cv2.createCLAHE(clipLimit=self.clip_limit.value, tileGridSize=(self.clahe_grid_size.value, self.clahe_grid_size.value))
            enhanced = clahe.apply(gray)
        else:
            enhanced = gray.copy()
        
        # Apply Gaussian blur
        blur_size = self.gaussian_blur.value
        blur_size_sigma = self.gaussian_blur_sigma.value
        if blur_size % 2 == 0:  # Ensure odd size
            blur_size += 1
        blurred = cv2.GaussianBlur(enhanced, (blur_size, blur_size), blur_size_sigma)
        
        # Apply thresholding based on method
        if self.threshold_type.value == 'otsu':
            _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        elif self.threshold_type.value == 'adaptive':
            block_size = self.adaptive_block_size.value
            if block_size % 2 == 0:  # Ensure odd size
                block_size += 1
            binary = cv2.adaptiveThreshold(
                blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C, 
                cv2.THRESH_BINARY_INV, block_size, self.adaptive_c.value
            )
        elif self.threshold_type.value == 'canny':
            # Use Canny edge detection
            threshold = self._get_brightness_threshold(blurred, 1/6)
            binary = cv2.Canny(blurred, threshold1=0, threshold2=threshold, apertureSize=3, L2gradient=False)
        else:  # manual
            _, binary = cv2.threshold(blurred, self.manual_threshold.value, 255, cv2.THRESH_BINARY_INV)
        
        # Enhanced morphological operations
        kernel_size = self.morph_kernel_size.value
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (kernel_size, kernel_size))
        
        if self.threshold_type.value == 'canny':
            # For Canny, we already have edges, so minimal processing
            cleaned = binary
        else:
            # Remove noise
            cleaned = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=self.morph_iterations.value)
            # Close gaps
            cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_CLOSE, kernel, iterations=self.morph_iterations.value)
            
            # For non-Canny methods, we might want to extract edges
            if self.extract_edges.value and self.threshold_type.value != 'canny':
                cleaned = cv2.morphologyEx(cleaned, cv2.MORPH_GRADIENT, kernel)
        
        return enhanced, cleaned
    
    def _get_brightness_threshold(self, image, pixel_percentage):
        """Get brightness threshold for given percentage of pixels"""
        if pixel_percentage > 1:
            return 255
        elif pixel_percentage <= 0:
            return 0
        
        # Calculate histogram
        hist = cv2.calcHist([image], [0], None, [256], [0, 256])
        
        # Calculate cumulative distribution
        pixel_count = image.shape[0] * image.shape[1]
        threshold_count = pixel_percentage * pixel_count
        cumsum = 0
        
        for i in range(256):
            cumsum += hist[i]
            if cumsum >= threshold_count:
                return i
        
        return 255
    
    def remove_border_crystals(self, markers):
        """Remove crystals touching the image border"""
        # Get unique labels touching the borders
        border_labels = set()
        
        # Check top and bottom borders
        border_labels.update(np.unique(markers[0, :]))
        border_labels.update(np.unique(markers[-1, :]))
        
        # Check left and right borders
        border_labels.update(np.unique(markers[:, 0]))
        border_labels.update(np.unique(markers[:, -1]))
        
        # Remove background and border labels
        border_labels.discard(-1)  # Remove border marker
        border_labels.discard(0)   # Remove background
        border_labels.discard(1)   # Remove background
        
        # Create a copy and remove border crystals
        markers_cleaned = markers.copy()
        for label in border_labels:
            markers_cleaned[markers == label] = 1
        
        return markers_cleaned
    
    def watershed_segmentation(self, binary_image):
        """Separate touching crystals with a marker-based watershed.

        Markers are built as follows:
        - seeds: pixels whose distance-transform value exceeds max / e, one
          label per connected component (labels 2..);
        - unknown (0): the band between the seeds and the 3x dilated mask;
        - background (1): everything else.
        After ``cv2.watershed``, border-touching regions are relabelled as
        background.

        Args:
            binary_image: uint8 mask, 255 foreground / 0 background.

        Returns:
            ``(segmented, markers)``: a uint8 mask with 255 on every label > 1,
            and the final label image (-1 boundaries, 1 background).
        """
        distance = cv2.distanceTransform(binary_image, cv2.DIST_L2, 5)

        # Seeds: the "deep" interior of each blob, using max / e as in the old code.
        seed_level = distance.max() / np.e
        _, seeds = cv2.threshold(distance, seed_level, 255, cv2.THRESH_BINARY)
        seeds = seeds.astype(np.uint8)

        # Dilated foreground: anything outside it is certainly background.
        ellipse_3x3 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        dilated_foreground = cv2.dilate(binary_image, ellipse_3x3, iterations=3)
        undecided = cv2.subtract(dilated_foreground, seeds)

        # Shift component labels by one so background is 1; 0 marks "let watershed decide".
        _, labels = cv2.connectedComponents(seeds)
        labels = labels + 1
        labels[undecided == 255] = 0

        bgr_input = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2BGR)
        labels = cv2.watershed(bgr_input, labels)
        labels = self.remove_border_crystals(labels)

        segmented = np.zeros_like(binary_image)
        segmented[labels > 1] = 255
        return segmented, labels
    
    def voronoi_segmentation(self, binary_image):
        """Split the foreground into regions grown from distance-transform peaks.

        Labelling:
        - seeds: local maxima of the distance transform, at least
          ``max(1, int(max_distance / e))`` pixels apart, each dilated slightly
          and given its own label (2..);
        - unknown (0): foreground pixels not covered by a seed, left for
          ``cv2.watershed`` to assign;
        - background (1): every zero pixel of ``binary_image``.
        Border-touching regions are then relabelled as background.

        Side effects: sets ``self.voronoi_centers`` to the peak (row, col)
        coordinates, or ``[]`` when there are none, and
        ``self.voronoi_markers_original`` to the watershed labels before border
        removal. Both are used for visualization.

        Args:
            binary_image: uint8 mask, nonzero = foreground.

        Returns:
            ``(segmented, markers)``: a uint8 mask with 255 on every label > 1,
            and the label image. When no peak is found, returns
            ``binary_image`` and an all-ones label image.
        """
        dist_transform = cv2.distanceTransform(binary_image, cv2.DIST_L2, 5)

        # peak_local_max needs min_distance >= 1. max / e rounds down to 0 for
        # thin foreground (e.g. edge maps), which would turn almost every
        # foreground pixel into a seed.
        min_distance = max(1, int(dist_transform.max() / np.e))
        coordinates = peak_local_max(
            dist_transform,
            min_distance=min_distance,
            exclude_border=False
        )

        # Store Voronoi centers for visualization
        self.voronoi_centers = coordinates

        if len(coordinates) == 0:
            self.voronoi_centers = []
            self.voronoi_markers_original = np.ones_like(binary_image)
            return binary_image, np.ones_like(binary_image)

        # Binary seed mask: 255 at each centre. Storing per-seed labels in a
        # uint8 image would wrap past 255 and silently drop seeds;
        # connectedComponents assigns the labels instead.
        seed_mask = np.zeros(binary_image.shape, dtype=np.uint8)
        seed_mask[coordinates[:, 0], coordinates[:, 1]] = 255
        seed_mask = cv2.dilate(seed_mask, None, iterations=2)

        _, markers_labeled = cv2.connectedComponents(seed_mask)
        markers_labeled = markers_labeled + 1  # seeds -> 2.., everything else -> 1
        # Foreground not covered by a seed must be "unknown" (0) so watershed
        # grows the seed regions into it. Otherwise every pixel is already
        # labelled and watershed has nothing to fill.
        markers_labeled[(binary_image > 0) & (seed_mask == 0)] = 0
        markers_labeled[binary_image == 0] = 1  # Background

        img_for_watershed = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2BGR)
        final_markers = cv2.watershed(img_for_watershed, markers_labeled)

        # Keep the labels before border removal for visualization.
        self.voronoi_markers_original = final_markers.copy()

        final_markers = self.remove_border_crystals(final_markers)

        segmented = np.zeros_like(binary_image)
        segmented[final_markers > 1] = 255
        return segmented, final_markers
    
    def create_voronoi_visualization(self, image_shape, markers):
        """Create a visualization of Voronoi diagram with borders"""
        # Create a color image for visualization
        voronoi_viz = np.zeros((image_shape[0], image_shape[1], 3), dtype=np.uint8)
        
        # Generate random colors for each region
        unique_labels = np.unique(markers)
        colors = {}
        for label in unique_labels:
            if label > 0:  # Skip background and borders
                colors[label] = (np.random.randint(50, 200), 
                               np.random.randint(50, 200), 
                               np.random.randint(50, 200))
        
        # Color each region
        for label, color in colors.items():
            voronoi_viz[markers == label] = color
        
        # Highlight borders in white
        voronoi_viz[markers == -1] = (255, 255, 255)
        
        # Mark Voronoi centers if available
        if hasattr(self, 'voronoi_centers') and len(self.voronoi_centers) > 0:
            for center in self.voronoi_centers:
                cv2.circle(voronoi_viz, (center[1], center[0]), 3, (255, 0, 0), -1)
                cv2.circle(voronoi_viz, (center[1], center[0]), 4, (255, 255, 255), 1)
        
        return voronoi_viz
    
    def extract_voronoi_borders(self, markers):
        """Return the boundary pixels (label -1) of a label image as uint8 masks.

        Returns:
            ``(borders, borders_dilated)``. ``borders`` is 255 where
            ``markers == -1`` and 0 elsewhere. ``borders_dilated`` is the same
            mask dilated once with a 3x3 cross, to make it more visible.
        """
        borders = np.where(markers == -1, 255, 0).astype(np.uint8)
        cross_3x3 = cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3))
        thickened = cv2.dilate(borders, cross_3x3, iterations=1)
        return borders, thickened
    
    def get_voronoi_statistics(self, markers):
        """Calculate statistics about Voronoi regions"""
        unique_labels = np.unique(markers)
        valid_labels = [l for l in unique_labels if l > 1]  # Exclude background and borders
        
        stats = {
            'num_regions': len(valid_labels),
            'region_sizes': [],
            'region_neighbor_counts': []
        }
        
        # Calculate size of each region
        for label in valid_labels:
            region_size = np.sum(markers == label)
            stats['region_sizes'].append(region_size)
        
        # Calculate number of neighbors for each region (simplified)
        # This is a basic implementation - for accurate neighbor counting,
        # you'd need to analyze the actual Voronoi diagram structure
        
        return stats
    
    def detect_crystals(self, binary_image, markers=None):
        """Find crystal outlines and measure them.

        Contour source: when ``markers`` is given, every label > 1 is drawn as
        foreground (boundary pixels -1 excluded); otherwise ``binary_image`` is
        used directly. Only external contours are considered.

        Each contour is replaced by its convex hull for measurement. It is kept
        when the hull area lies within the Min/Max area sliders and, unless
        'all' is selected, when ``is_target_shape`` accepts it.

        Returns:
            list of dicts with keys:
            - ``contour``: the convex hull;
            - ``area``, ``perimeter``: hull area and perimeter;
            - ``circularity``: 4·π·area / perimeter²;
            - ``convexity``: raw contour area / hull area;
            - ``center_x``, ``center_y``;
            - ``angle``, ``aspect_ratio``, ``major_axis``, ``minor_axis``: from
              an ellipse fit, when the hull has >= 5 points.
        """
        if markers is not None:
            crystal_mask = np.zeros_like(binary_image)
            crystal_mask[markers > 1] = 255
            contour_source = crystal_mask
        else:
            contour_source = binary_image
        contours, _ = cv2.findContours(contour_source, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        min_area = self.min_area.value
        max_area = self.max_area.value
        detected_crystals = []

        for raw_contour in contours:
            # Measure on the convex hull (bridges notches in the outline), but
            # keep the raw area. Convexity must compare raw vs hull: taking the
            # hull of the hull would always give 1.0.
            raw_area = cv2.contourArea(raw_contour)
            contour = cv2.convexHull(raw_contour)
            area = cv2.contourArea(contour)

            # Filter by area
            if area < min_area or area > max_area:
                continue

            perimeter = cv2.arcLength(contour, True)
            if perimeter == 0:
                continue
            circularity = 4 * np.pi * area / (perimeter * perimeter)

            if area == 0:
                continue
            convexity = raw_area / area

            # Shape-specific filtering
            if self.shape_selector.value != 'all':
                if not self.is_target_shape(contour, circularity):
                    continue

            moments = cv2.moments(contour)
            if moments['m00'] != 0:
                cx = int(moments['m10'] / moments['m00'])
                cy = int(moments['m01'] / moments['m00'])
            else:
                cx, cy = 0, 0

            # fitEllipse needs at least 5 points.
            if len(contour) >= 5:
                ellipse = cv2.fitEllipse(contour)
                angle = ellipse[2]
                major_axis = max(ellipse[1])
                minor_axis = min(ellipse[1])
                aspect_ratio = major_axis / minor_axis if minor_axis > 0 else 0
            else:
                # Fall back to the diameter of a circle with the same area.
                angle = 0
                aspect_ratio = 1
                major_axis = minor_axis = np.sqrt(area / np.pi) * 2

            detected_crystals.append({
                'contour': contour,
                'area': area,
                'perimeter': perimeter,
                'circularity': circularity,
                'convexity': convexity,
                'center_x': cx,
                'center_y': cy,
                'angle': angle,
                'aspect_ratio': aspect_ratio,
                'major_axis': major_axis,
                'minor_axis': minor_axis
            })

        return detected_crystals
    
    def is_target_shape(self, contour, circularity):
        """Check a contour against the shape chosen in ``shape_selector``.

        The contour is simplified with ``cv2.approxPolyDP``, using a tolerance
        of 2% of its closed perimeter. Acceptance rules:
        - 'hexagonal': 5-8 vertices;
        - 'cubic': 3-5 vertices;
        - 'circular': ``circularity > 0.7``;
        - any other selection: always accepted.
        """
        tolerance = 0.02 * cv2.arcLength(contour, True)
        vertex_count = len(cv2.approxPolyDP(contour, tolerance, True))
        shape = self.shape_selector.value

        if shape == 'hexagonal':
            return 5 <= vertex_count <= 8
        if shape == 'cubic':
            return 3 <= vertex_count <= 5
        if shape == 'circular':
            return circularity > 0.7
        return True
    
    def visualize_results(self, gray_image, crystals):
        """Draw the detected crystals on a BGR copy of ``gray_image``.

        For each crystal:
        - its contour in green, 2 px thick;
        - a filled dot of radius 3 at its centre, BGR (255, 0, 0);
        - its 1-based index, placed 10 px up and left of the centre.
        """
        canvas = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2BGR)

        for number, crystal in enumerate(crystals, start=1):
            center = (crystal['center_x'], crystal['center_y'])
            cv2.drawContours(canvas, [crystal['contour']], -1, (0, 255, 0), 2)
            cv2.circle(canvas, center, 3, (255, 0, 0), -1)
            label_origin = (center[0] - 10, center[1] - 10)
            cv2.putText(canvas, str(number), label_origin,
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1)

        return canvas
    
    def process_image(self, button):
        """Run the full analysis on ``self.image`` and display the results.

        Serves as the Process-button handler and as the observer for the
        parameter widgets, so ``button`` is a Button or a change event. It is
        not used.

        Steps:
        1. preprocessing;
        2. segmentation ('watershed' or 'voronoi'; 'connected_components' uses
           the binary mask as is);
        3. crystal detection;
        4. summary figure and statistics in ``self.output``;
        5. results table.

        ``self.results`` / ``self.processed_image`` are reset before the run,
        so a failed run cannot leave results from an earlier run behind.
        """
        if self.image is None:
            with self.output:
                clear_output()
                print("❌ Please upload an image first!")
            return

        with self.output:
            clear_output()
            print("🔄 Processing image...")

        # Discard the previous run's results and table. display_results_table
        # returns without clearing its area when no crystals are found, so the
        # table is cleared here.
        self.results = None
        self.processed_image = None
        with self.results_output:
            clear_output()

        try:
            enhanced, binary = self.preprocess_image()

            method = self.segmentation_method.value
            markers = None
            segmented = binary
            if method == 'watershed':
                segmented, markers = self.watershed_segmentation(binary)
            elif method == 'voronoi':
                segmented, markers = self.voronoi_segmentation(binary)
            # For connected_components, we just use the binary image as is

            crystals = self.detect_crystals(segmented, markers)
            result_image = self.visualize_results(enhanced, crystals)

            self.results = crystals
            self.processed_image = result_image

            with self.output:
                clear_output()
                self._plot_analysis(enhanced, binary, segmented, result_image, crystals, method)
                print(f"✅ Processing complete! Detected {len(crystals)} crystals.")

            self.display_results_table()

        except Exception as e:
            with self.output:
                clear_output()
                print(f"❌ Error processing image: {str(e)}")
                import traceback
                traceback.print_exc()

    def _plot_analysis(self, enhanced, binary, segmented, result_image, crystals, method):
        """Draw the summary figure and print summary statistics.

        Panels:
        - original image;
        - segmented mask;
        - annotated detections;
        - area histogram (with mean circularity and area printed);
        - for ``method == 'voronoi'``, a third row with the Voronoi regions
          and their boundaries over the enhanced image.

        Meant to be called inside ``with self.output:``.
        """
        show_voronoi = method == 'voronoi'
        n_rows = 3 if show_voronoi else 2
        fig, axes = plt.subplots(n_rows, 2, figsize=(15, 6 * n_rows))

        axes[0, 0].imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
        axes[0, 0].set_title('Original Image')
        axes[0, 0].axis('off')

        axes[0, 1].imshow(segmented, cmap='gray')
        axes[0, 1].set_title(f'Segmented Image ({method})')
        axes[0, 1].axis('off')

        axes[1, 0].imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
        axes[1, 0].set_title(f'Detected Crystals: {len(crystals)}')
        axes[1, 0].axis('off')

        if crystals:
            areas = [c['area'] for c in crystals]
            circularity_values = [c['circularity'] for c in crystals]

            axes[1, 1].hist(areas, bins=20, alpha=0.7, color='green', edgecolor='black')
            axes[1, 1].set_title('Crystal Size Distribution')
            axes[1, 1].set_xlabel('Area (pixels²)')
            axes[1, 1].set_ylabel('Count')
            axes[1, 1].grid(True, alpha=0.3)

            print(f"Mean circularity: {np.mean(circularity_values):.3f}")
            print(f"Mean area: {np.mean(areas):.1f} ± {np.std(areas):.1f} px²")
        else:
            axes[1, 1].text(0.5, 0.5, 'No crystals detected',
                            ha='center', va='center', transform=axes[1, 1].transAxes)
            axes[1, 1].set_title('Crystal Size Distribution')

        if show_voronoi and hasattr(self, 'voronoi_markers_original'):
            voronoi_markers = self.voronoi_markers_original

            voronoi_viz = self.create_voronoi_visualization(binary.shape, voronoi_markers)
            axes[2, 0].imshow(voronoi_viz)
            axes[2, 0].set_title('Voronoi Diagram (with borders and centers)')
            axes[2, 0].axis('off')

            # The overlay is shown as RGB (gray -> three identical channels), so
            # (255, 255, 0) is yellow and (255, 0, 0) red, matching the diagram.
            overlay = cv2.cvtColor(enhanced, cv2.COLOR_GRAY2RGB)
            border_mask = (voronoi_markers == -1).astype(np.uint8)
            border_mask_dilated = cv2.dilate(border_mask,
                                             cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)))
            overlay[border_mask_dilated > 0] = [255, 255, 0]  # Yellow borders

            for row, col in self.voronoi_centers:
                cv2.circle(overlay, (int(col), int(row)), 3, (255, 0, 0), -1)

            axes[2, 1].imshow(overlay)
            axes[2, 1].set_title('Voronoi Borders on Original')
            axes[2, 1].axis('off')

            print(f"Number of Voronoi centers: {len(self.voronoi_centers)}")

        plt.tight_layout()
        plt.show()
    
    def display_results_table(self):
        """Display the per-crystal measurements and summary statistics.

        Clears ``self.results_output`` and writes into it a table with one row
        per entry of ``self.results`` (values rounded for display only),
        followed by the crystal count and the mean ± standard deviation of
        area, circularity and aspect ratio. Returns immediately when there are
        no results.

        The summary statistics are computed from the unrounded measurements.
        With a single crystal the sample standard deviation is undefined
        (pandas returns NaN), so only the mean is printed in that case instead
        of ``± nan``.
        """
        if not self.results:
            return

        # Output column -> (key in each crystal dict, decimals shown in the
        # table; None means the value is shown unchanged).
        column_spec = {
            'Area_px²': ('area', 2),
            'Perimeter_px': ('perimeter', 2),
            'Circularity': ('circularity', 3),
            'Convexity': ('convexity', 3),
            'Center_X': ('center_x', None),
            'Center_Y': ('center_y', None),
            'Angle_degrees': ('angle', 1),
            'Aspect_Ratio': ('aspect_ratio', 2),
            'Major_Axis_px': ('major_axis', 2),
            'Minor_Axis_px': ('minor_axis', 2),
        }

        with self.results_output:
            clear_output()

            # Build the raw (full precision) and display (rounded) rows together.
            raw_rows = []
            display_rows = []
            for crystal_id, crystal in enumerate(self.results, start=1):
                raw = {'Crystal_ID': crystal_id}
                shown = {'Crystal_ID': crystal_id}
                for column, (key, decimals) in column_spec.items():
                    value = crystal[key]
                    raw[column] = value
                    shown[column] = value if decimals is None else round(value, decimals)
                raw_rows.append(raw)
                display_rows.append(shown)

            df_raw = pd.DataFrame(raw_rows)
            df_display = pd.DataFrame(display_rows)

            def mean_pm_std(column, decimals):
                """Format 'mean ± std' for a column, or just the mean for one row."""
                values = df_raw[column]
                mean_text = f"{values.mean():.{decimals}f}"
                if len(values) < 2:
                    # Sample std of a single value is NaN; don't print '± nan'.
                    return mean_text
                return f"{mean_text} ± {values.std():.{decimals}f}"

            print("📊 Detailed Crystal Analysis Results:")
            print("=" * 60)
            display(df_display)

            # Summary statistics
            print("\n📈 Summary Statistics:")
            print(f"Total Crystals Detected: {len(self.results)}")
            print(f"Average Area: {mean_pm_std('Area_px²', 2)} px²")
            print(f"Average Circularity: {mean_pm_std('Circularity', 3)}")
            print(f"Average Aspect Ratio: {mean_pm_std('Aspect_Ratio', 2)}")
    
    def export_results(self, button):
        """Offer the current results as a downloadable CSV file.

        Builds a table with one row per entry of ``self.results`` (full
        precision, no rounding), embeds the CSV in a base64 ``data:`` URI and
        shows a download link in ``self.results_output``, so no server-side
        file access is needed (e.g. under Voila). ``button`` is presumably the
        clicked widget passed to the callback; it is not used.

        The link declares the registered ``text/csv`` media type with an
        explicit UTF-8 charset (the headers contain non-ASCII ``²``); the
        previous ``file/csv`` is not a valid media type. Errors are printed to
        ``self.results_output`` rather than raised.
        """
        if not self.results:
            with self.results_output:
                print("❌ No results to export! Please process an image first.")
            return

        # CSV column -> key in each crystal dict (values exported unrounded).
        field_map = (
            ('Area_px²', 'area'),
            ('Perimeter_px', 'perimeter'),
            ('Circularity', 'circularity'),
            ('Convexity', 'convexity'),
            ('Center_X', 'center_x'),
            ('Center_Y', 'center_y'),
            ('Angle_degrees', 'angle'),
            ('Aspect_Ratio', 'aspect_ratio'),
            ('Major_Axis_px', 'major_axis'),
            ('Minor_Axis_px', 'minor_axis'),
        )

        try:
            rows = [
                {'Crystal_ID': crystal_id, **{column: crystal[key] for column, key in field_map}}
                for crystal_id, crystal in enumerate(self.results, start=1)
            ]
            df = pd.DataFrame(rows)

            # Embed the CSV in a data: URI so the browser downloads it directly.
            csv_string = df.to_csv(index=False)
            b64 = base64.b64encode(csv_string.encode('utf-8')).decode('ascii')
            href = (
                f'<a href="data:text/csv;charset=utf-8;base64,{b64}" '
                f'download="crystal_analysis_results.csv">Download Results CSV</a>'
            )

            with self.results_output:
                print("💾 Results exported successfully!")
                display(widgets.HTML(href))

        except Exception as e:
            with self.results_output:
                print(f"❌ Error exporting results: {str(e)}")

# Confirms that the cell defining CrystalCounter ran to completion.
print("✅ Crystal Counter class defined successfully!")


# ## 3. Initialize and Run the Application

# In[3]:


# Build the Crystal Counter interface and show it in the notebook
def main():
    """Create a CrystalCounter, render its interface and hand the instance back.

    Returns:
    CrystalCounter: the application whose widgets were just displayed.
    """
    crystal_counter = CrystalCounter()
    crystal_counter.display_ui()
    return crystal_counter

# Keep a module-level handle to the app for later cells,
# e.g. save_configuration(crystal_app).
print("Initializing SEM Crystal Analyzer...")
crystal_app = main()
print("\n✅ Application ready! Please upload an SEM image to begin.")


# ## 4. Usage Instructions
# 
# ### Getting Started:
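# 1. **Upload an Image**: Click the "Upload SEM Image" button and select your SEM image file (supports TIF, PNG, JPG, etc.)
# 2. **Adjust Parameters**: Tune the preprocessing and detection options described below
# 3. **Process**: Run the analysis and review the detected crystals and plots
# 4. **Export**: Save the measurements to CSV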
# 
# ### Preprocessing Options:
# - **CLAHE**: Enable/disable Contrast Limited Adaptive Histogram Equalization for better contrast
# - **Gaussian Blur**: Adjust the blur kernel size to reduce noise
# - **Threshold Method**:
#   - **Canny**: Best for high-contrast images with clear edges
#   - **Otsu**: Automatic thresholding, good for bimodal histograms
#   - **Adaptive**: Good for images with varying lighting
#   - **Manual**: Set your own threshold value
# 
# ### Segmentation Methods:
# - **Watershed**: Best for separating touching crystals
# - **Voronoi**: Uses Voronoi diagrams for segmentation
# - **Connected Components**: Simple segmentation for well-separated crystals
# 
# ### Detection Parameters:
# - **Crystal Shape**: Filter by specific shapes (hexagonal, cubic, circular) or detect all
# - **Min/Max Area**: Set the size range for valid crystals
# 
# ### Tips for Best Results:
# 1. Start with **Canny** edge detection for clear, high-contrast images
# 2. Enable **CLAHE** for images with poor contrast
# 3. Use **Watershed** segmentation when crystals are touching
# 4. Adjust **morphological operations** to clean up noise
# 5. Set appropriate **area limits** to filter out noise and artifacts

# ## 5. Example Processing Pipeline

# In[4]:


# Example: Optimal settings for different image types
# Printed as one block; entries ending in "\n" leave a blank line between
# scenarios.
print("\n".join([
    "📋 Recommended Settings for Common Scenarios:\n",
    "1. Clear, High-Contrast SEM Images:",
    "   - Threshold: Canny",
    "   - CLAHE: Disabled",
    "   - Segmentation: Connected Components",
    "   - Gaussian Blur: 3-5\n",
    "2. Low-Contrast or Noisy Images:",
    "   - Threshold: Otsu or Adaptive",
    "   - CLAHE: Enabled (Grid Size: 25)",
    "   - Segmentation: Watershed",
    "   - Gaussian Blur: 5-9\n",
    "3. Images with Touching Crystals:",
    "   - Threshold: Canny or Otsu",
    "   - CLAHE: As needed",
    "   - Segmentation: Watershed or Voronoi",
    "   - Morphology: 2-3 iterations\n",
    "4. Images with Specific Crystal Shapes:",
    "   - Set Crystal Shape filter accordingly",
    "   - Adjust Min/Max Area based on expected sizes",
    "   - Use appropriate threshold method for your contrast",
]))


# ## 6. Advanced Features

# In[5]:


# Additional utility functions for batch processing (optional)

def process_multiple_images(image_paths, settings):
    """
    Template for running the crystal analysis over a batch of images.

    Parameters:
    image_paths: iterable of image file paths, handled in order
    settings: dictionary of processing parameters (not used yet by this template)

    Returns:
    results: list with one summary dict per path, with keys 'filename',
             'crystal_count', 'mean_area' and 'mean_circularity'. The numeric
             fields are placeholders (0) until processing is implemented.
    """
    def _placeholder_summary(path):
        # The image is read but not analysed yet.
        _image = cv2.imread(path)
        # TODO: apply `settings` to `_image` and report real measurements
        return {
            'filename': path,
            'crystal_count': 0,      # Placeholder
            'mean_area': 0,          # Placeholder
            'mean_circularity': 0,   # Placeholder
        }

    return [_placeholder_summary(path) for path in image_paths]

# Scale bar detection function (if needed)
def detect_scale_bar(image):
    """
    Placeholder for locating the SEM scale bar and deriving a pixel-to-unit factor.

    Parameters:
    image: input SEM image (not inspected yet)

    Returns:
    (scale_factor, unit): currently always (10.0, 'µm'), i.e. 10 pixels per
    micrometer, whatever `image` contains.
    """
    # TODO: add OCR + scale bar detection; until then a fixed calibration is used.
    placeholder_calibration = (10.0, 'µm')  # (pixels per micrometer, unit)
    return placeholder_calibration

print("\n".join([
    "✅ Additional utility functions loaded",
    "\nNote: These are template functions that can be extended for:",
    "- Batch processing multiple images",
    "- Automatic scale bar detection",
    "- Statistical analysis across multiple samples",
]))


# ## 7. Troubleshooting

# In[6]:


# Printed as one block; entries ending in "\n" leave a blank line between
# problems.
print("\n".join([
    "🔧 Troubleshooting Guide:\n",
    "Problem: No crystals detected",
    "Solutions:",
    "- Check threshold method - try different options",
    "- Adjust min/max area parameters",
    "- Enable CLAHE if image has low contrast",
    "- Reduce Gaussian blur if features are small\n",
    "Problem: Too many false detections",
    "Solutions:",
    "- Increase minimum area threshold",
    "- Increase morphological iterations",
    "- Use shape-specific filtering",
    "- Adjust threshold parameters\n",
    "Problem: Touching crystals not separated",
    "Solutions:",
    "- Use Watershed or Voronoi segmentation",
    "- Increase morphological erosion",
    "- Adjust segmentation parameters\n",
    "Problem: Image processing is slow",
    "Solutions:",
    "- Reduce image size before processing",
    "- Decrease morphological iterations",
    "- Use simpler segmentation method",
]))


# ## 8. Save Your Work

# In[7]:


# Function to save the current configuration
def save_configuration(app, filename='sem_config.json'):
    """Write the app's current widget settings to ``filename`` as indented JSON.

    The file has two sections, ``preprocessing`` and ``detection``. Each maps a
    widget attribute name on ``app`` to that widget's current ``.value``.
    """
    import json

    # Section -> widget attribute names, in the order they appear in the file.
    section_fields = {
        'preprocessing': (
            'apply_clahe',
            'clahe_grid_size',
            'gaussian_blur',
            'threshold_type',
            'manual_threshold',
            'adaptive_block_size',
            'adaptive_c',
            'morph_kernel_size',
            'morph_iterations',
            'extract_edges',
        ),
        'detection': (
            'segmentation_method',
            'shape_selector',
            'min_area',
            'max_area',
        ),
    }
    config = {
        section: {attr: getattr(app, attr).value for attr in attrs}
        for section, attrs in section_fields.items()
    }

    with open(filename, 'w') as handle:
        json.dump(config, handle, indent=4)

    print(f"✅ Configuration saved to {filename}")

# Function to load a configuration
def load_configuration(app, filename='sem_config.json'):
    """Load parameter settings from a JSON file written by ``save_configuration``.

    Parameters:
    app: object exposing the widget attributes listed in ``fields`` below, each
         with a settable ``.value`` (normally the CrystalCounter instance)
    filename: path of the JSON configuration file

    Errors (missing file, invalid JSON, missing settings, rejected values) are
    printed, not raised.

    Every setting is looked up before any widget is modified, so a file that
    lacks a section or key leaves ``app`` untouched instead of half-updated.
    If a widget rejects a value while the settings are being applied, the
    widgets already changed are restored to their previous values.
    """
    import json

    # (section, widget attribute) in the order save_configuration writes them;
    # widgets are also updated in this order.
    fields = (
        ('preprocessing', 'apply_clahe'),
        ('preprocessing', 'clahe_grid_size'),
        ('preprocessing', 'gaussian_blur'),
        ('preprocessing', 'threshold_type'),
        ('preprocessing', 'manual_threshold'),
        ('preprocessing', 'adaptive_block_size'),
        ('preprocessing', 'adaptive_c'),
        ('preprocessing', 'morph_kernel_size'),
        ('preprocessing', 'morph_iterations'),
        ('preprocessing', 'extract_edges'),
        ('detection', 'segmentation_method'),
        ('detection', 'shape_selector'),
        ('detection', 'min_area'),
        ('detection', 'max_area'),
    )

    try:
        with open(filename, 'r') as f:
            config = json.load(f)

        # Resolve every value first: a missing section/key raises KeyError
        # here, before any widget has been changed.
        new_values = [(name, config[section][name]) for section, name in fields]

        applied = []  # (widget, previous value), used for rollback
        try:
            for name, value in new_values:
                widget = getattr(app, name)
                previous = widget.value
                widget.value = value
                applied.append((widget, previous))
        except Exception:
            # Best-effort rollback so the UI is not left with a mix of old and
            # new settings. A failure while restoring is ignored so that it
            # cannot hide the original error, which is re-raised and reported.
            for widget, previous in reversed(applied):
                try:
                    widget.value = previous
                except Exception:
                    pass
            raise

        print(f"✅ Configuration loaded from {filename}")

    except FileNotFoundError:
        print(f"❌ Configuration file {filename} not found")
    except Exception as e:
        print(f"❌ Error loading configuration: {str(e)}")

print("\n".join([
    "Configuration save/load functions ready!",
    "\nTo save current settings: save_configuration(crystal_app)",
    "To load saved settings: load_configuration(crystal_app)",
]))


# ## 9. Voronoi Segmentation Visualization
# 
# The Voronoi segmentation method shows how the crystals are distributed across the image and where the boundaries between neighbouring crystals lie.

# In[9]:


# Demonstration of Voronoi visualization capabilities

def demonstrate_voronoi_features(app):
    """
    Plot diagnostic views of the Voronoi segmentation stored on ``app``.

    Draws a 2x2 figure:
    1. Voronoi borders returned by ``app.extract_voronoi_borders``
    2. The dilated variant of those borders (easier to see)
    3. The distance transform of the current binary image, with the Voronoi
       centers (ultimate eroded points) marked as blue crosses
    4. A histogram of the Voronoi region sizes from ``app.get_voronoi_statistics``
    It then prints the number of regions and centers and a region-size summary.

    Parameters:
    app: CrystalCounter instance that has processed an image

    Returns None. It prints a message and returns early in three cases:
    - no image has been processed;
    - the selected segmentation method is not 'voronoi';
    - ``app`` holds no Voronoi markers. That last case previously produced no
      output at all.
    """
    if getattr(app, 'results', None) is None:
        print("❌ Please process an image with Voronoi segmentation first!")
        return

    if app.segmentation_method.value != 'voronoi':
        print("❌ Please select 'voronoi' as the segmentation method and reprocess!")
        return

    markers = getattr(app, 'voronoi_markers_original', None)
    if markers is None:
        print("❌ No Voronoi segmentation data found. "
              "Please reprocess the image with 'voronoi' selected!")
        return

    # Centers are stored as (row, col) pairs.
    centers = getattr(app, 'voronoi_centers', [])

    # Extract borders
    borders, borders_dilated = app.extract_voronoi_borders(markers)

    # Get statistics
    stats = app.get_voronoi_statistics(markers)
    # len() rather than truthiness: works whether region_sizes is a list or a
    # NumPy array (a multi-element array has no truth value).
    region_sizes = stats['region_sizes']
    has_regions = region_sizes is not None and len(region_sizes) > 0

    _, axes = plt.subplots(2, 2, figsize=(15, 12))

    # 1. Just the borders
    axes[0, 0].imshow(borders, cmap='gray')
    axes[0, 0].set_title('Voronoi Borders (Original)')
    axes[0, 0].axis('off')

    # 2. Dilated borders for better visibility
    axes[0, 1].imshow(borders_dilated, cmap='gray')
    axes[0, 1].set_title('Voronoi Borders (Dilated)')
    axes[0, 1].axis('off')

    # 3. Distance transform of the current binary mask, with centers overlaid
    _, binary = app.preprocess_image()
    dist_transform = cv2.distanceTransform(binary, cv2.DIST_L2, 5)
    axes[1, 0].imshow(dist_transform, cmap='hot')
    if len(centers) > 0:
        # scatter takes (x, y) = (col, row)
        rows = [c[0] for c in centers]
        cols = [c[1] for c in centers]
        axes[1, 0].scatter(cols, rows, c='blue', s=50, marker='x')
    axes[1, 0].set_title('Distance Transform with Voronoi Centers')
    axes[1, 0].axis('off')

    # 4. Region size distribution
    if has_regions:
        axes[1, 1].hist(region_sizes, bins=20, alpha=0.7, color='purple', edgecolor='black')
        axes[1, 1].set_title('Voronoi Region Size Distribution')
        axes[1, 1].set_xlabel('Region Size (pixels)')
        axes[1, 1].set_ylabel('Count')
        axes[1, 1].grid(True, alpha=0.3)
    else:
        # Avoid leaving an empty, unlabeled panel.
        axes[1, 1].text(0.5, 0.5, 'No Voronoi regions',
                        ha='center', va='center', transform=axes[1, 1].transAxes)
        axes[1, 1].set_title('Voronoi Region Size Distribution')
        axes[1, 1].axis('off')

    plt.tight_layout()
    plt.show()

    # Print statistics
    print("\n📊 Voronoi Segmentation Statistics:")
    print(f"Number of Voronoi regions: {stats['num_regions']}")
    print(f"Number of Voronoi centers: {len(centers)}")
    if has_regions:
        print(f"Average region size: {np.mean(region_sizes):.1f} pixels")
        print(f"Min region size: {np.min(region_sizes)} pixels")
        print(f"Max region size: {np.max(region_sizes)} pixels")

print("\n".join([
    "✅ Voronoi visualization functions ready!",
    "\nTo see Voronoi features after processing with Voronoi segmentation:",
    "demonstrate_voronoi_features(crystal_app)",
]))


# ## 10. Understanding the Voronoi Algorithm
# 
# The Voronoi segmentation method works as follows:
# 
# 1. **Find Crystal Centers**: Uses distance transform and peak detection to find ultimate eroded points (UEPs)
# 2. **Create Voronoi Diagram**: Each center becomes a seed for a Voronoi region
# 3. **Apply Watershed**: Uses watershed algorithm with Voronoi seeds to create borders
# 4. **Remove Border Crystals**: Eliminates incomplete crystals touching image edges
# 
# ### Advantages of Voronoi Segmentation:
# - Excellent for uniformly distributed crystals
# - Creates natural boundaries between regions
# - Can handle touching crystals well
# - Provides information about crystal packing and distribution
# 
# ### When to Use Voronoi:
# - When crystals are roughly similar in size
# - When you need to analyze spatial distribution
# - When crystals are touching but have clear centers
# - When you want to study crystal packing density

# In[10]:


# Closing banner; entries starting with "\n" leave a blank line before them.
print("\n".join([
    "\n" + "=" * 60,
    "🎉 SEM Crystal Analyzer Successfully Loaded!",
    "=" * 60,
    "\nYou can now:",
    "1. Upload an SEM image using the interface above",
    "2. Adjust parameters to optimize detection",
    "3. Process the image and view results",
    "4. Export results to CSV",
    "\nFor Voronoi visualization, select 'voronoi' as segmentation method",
    "and run: demonstrate_voronoi_features(crystal_app)",
    "\nFor help, refer to the usage instructions and troubleshooting guide above.",
    "\nHappy analyzing! 🔬",
]))

Installing required packages...
Installing opencv-python...
✓ opencv-python installed
✓ numpy already installed
✓ matplotlib already installed
✓ pandas already installed
✓ scipy already installed
Installing scikit-image...
✓ scikit-image installed
✓ ipywidgets already installed
Installing pillow...
✓ pillow installed

All packages installed successfully!
✅ Crystal Counter class defined successfully!
Initializing SEM Crystal Analyzer...


VBox(children=(VBox(children=(HTML(value='<h2>Crystal Counter for SEM Images</h2>'), FileUpload(value=(), acce…


✅ Application ready! Please upload an SEM image to begin.
📋 Recommended Settings for Common Scenarios:

1. Clear, High-Contrast SEM Images:
   - Threshold: Canny
   - CLAHE: Disabled
   - Segmentation: Connected Components
   - Gaussian Blur: 3-5

2. Low-Contrast or Noisy Images:
   - Threshold: Otsu or Adaptive
   - CLAHE: Enabled (Grid Size: 25)
   - Segmentation: Watershed
   - Gaussian Blur: 5-9

3. Images with Touching Crystals:
   - Threshold: Canny or Otsu
   - CLAHE: As needed
   - Segmentation: Watershed or Voronoi
   - Morphology: 2-3 iterations

4. Images with Specific Crystal Shapes:
   - Set Crystal Shape filter accordingly
   - Adjust Min/Max Area based on expected sizes
   - Use appropriate threshold method for your contrast
✅ Additional utility functions loaded

Note: These are template functions that can be extended for:
- Batch processing multiple images
- Automatic scale bar detection
- Statistical analysis across multiple samples
🔧 Troubleshooting Guide:

Problem: 

In [27]:
# Scratch check: how many pixels does cv2.circle set for a radius-5 circle
# drawn without an explicit thickness on a blank 100x100 canvas?
img = np.zeros((100, 100), dtype=np.uint8)
cv2.circle(img, (50, 50), 5, 255)
# np.argwhere gives one (row, col) pair per matching pixel -> shape (N, 2)
print(np.argwhere(img == 255).shape)

(28, 2)


In [33]:
# Interactive help lookup (IPython `?` syntax, not plain Python) that prints the
# OpenCV docstring shown below.
# NOTE(review): leftover debugging; remove before sharing the notebook.
cv2.GaussianBlur?

[31mDocstring:[39m
GaussianBlur(src, ksize, sigmaX[, dst[, sigmaY[, borderType[, hint]]]]) -> dst
.   @brief Blurs an image using a Gaussian filter.
.   
.   The function convolves the source image with the specified Gaussian kernel. In-place filtering is
.   supported.
.   
.   @param src input image; the image can have any number of channels, which are processed
.   independently, but the depth should be CV_8U, CV_16U, CV_16S, CV_32F or CV_64F.
.   @param dst output image of the same size and type as src.
.   @param ksize Gaussian kernel size. ksize.width and ksize.height can differ but they both must be
.   positive and odd. Or, they can be zero's and then they are computed from sigma.
.   @param sigmaX Gaussian kernel standard deviation in X direction.
.   @param sigmaY Gaussian kernel standard deviation in Y direction; if sigmaY is zero, it is set to be
.   equal to sigmaX, if both sigmas are zeros, they are computed from ksize.width and ksize.height,
.   respectively (see #getG