In [7]:
!pip install opencv-python scikit-image pillow



In [10]:
# Crystal Counter for SEM Images - Voila Compatible
# Adapted from open source solutions including CorrieGunter/particle_counter and pygempick

import cv2
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from scipy import ndimage
from skimage import filters, measure, morphology
import ipywidgets as widgets
from IPython.display import display, clear_output
import io
import base64
from PIL import Image
import warnings
warnings.filterwarnings('ignore')

class CrystalCounter:
    """Interactive ipywidgets application that counts crystals in SEM images.

    Workflow: the user uploads an image, tunes the preprocessing / shape
    parameters with the widgets, presses "Process Image" and may then export
    the per-crystal measurements as a CSV download link.

    Attributes:
        image: Uploaded image as a BGR uint8 array, or None before an upload.
        processed_image: Annotated BGR result of the last successful run.
        results: List of per-crystal dicts from the last successful run, or
            None when there is nothing to show/export.
    """

    # PIL modes that carry more than 8 bits per pixel. They are rescaled to
    # 8 bit before the RGB conversion so the dynamic range is kept.
    _HIGH_DEPTH_MODES = ('I', 'F', 'I;16', 'I;16L', 'I;16B', 'I;16N')

    # (output column, key in the crystal dict, decimals used for on-screen
    # rounding or None to leave the value untouched). The order defines the
    # column order of both the displayed table and the exported CSV.
    _RESULT_COLUMNS = (
        ('Area_px²', 'area', 2),
        ('Perimeter_px', 'perimeter', 2),
        ('Circularity', 'circularity', 3),
        ('Convexity', 'convexity', 3),
        ('Center_X', 'center_x', None),
        ('Center_Y', 'center_y', None),
        ('Angle_degrees', 'angle', 1),
        ('Aspect_Ratio', 'aspect_ratio', 2),
        ('Major_Axis_px', 'major_axis', 2),
        ('Minor_Axis_px', 'minor_axis', 2),
    )

    def __init__(self):
        """Initialise empty state and build all widgets."""
        self.image = None
        self.processed_image = None
        self.results = None
        self.setup_ui()

    def setup_ui(self):
        """Create interactive widgets for Voila and wire up their callbacks."""
        # File upload widget
        self.upload_widget = widgets.FileUpload(
            accept='image/*',
            multiple=False,
            description='Upload SEM Image'
        )

        # Shape detection parameters
        self.shape_selector = widgets.Dropdown(
            options=['hexagonal', 'cubic', 'circular', 'all'],
            value='all',
            description='Crystal Shape:'
        )

        # Image preprocessing parameters (odd values only: 1, 3, ..., 15)
        self.gaussian_blur = widgets.IntSlider(
            value=3, min=1, max=15, step=2,
            description='Gaussian Blur:'
        )

        self.threshold_type = widgets.Dropdown(
            options=['otsu', 'adaptive', 'manual'],
            value='otsu',
            description='Threshold Method:'
        )

        # Created exactly once; only editable while the 'manual' method is
        # selected, so its initial state follows the dropdown's initial value.
        self.manual_threshold = widgets.IntSlider(
            value=127, min=0, max=255,
            description='Manual Threshold:',
            disabled=(self.threshold_type.value != 'manual')
        )
        self.threshold_type.observe(self.on_threshold_type_change, names='value')

        # Crystal detection parameters
        self.min_area = widgets.IntSlider(
            value=100, min=10, max=5000,
            description='Min Crystal Area:'
        )

        self.max_area = widgets.IntSlider(
            value=10000, min=100, max=100000,
            description='Max Crystal Area:'
        )

        # Circularity for shape filtering
        self.min_circularity = widgets.FloatSlider(
            value=0.3, min=0.1, max=1.0, step=0.1,
            description='Min Circularity:'
        )

        # Convexity for shape filtering
        self.min_convexity = widgets.FloatSlider(
            value=0.5, min=0.1, max=1.0, step=0.1,
            description='Min Convexity:'
        )

        # Optional splitting of touching crystals
        self.use_watershed = widgets.Checkbox(
            value=True,
            description='Use Watershed Segmentation',
            tooltip='Helps separate touching crystals'
        )

        # Process button
        self.process_button = widgets.Button(
            description='Process Image',
            button_style='success'
        )

        # Export button
        self.export_button = widgets.Button(
            description='Export Results',
            button_style='info'
        )

        # Output areas: status/plots and the results table respectively
        self.output = widgets.Output()
        self.results_output = widgets.Output()

        # Setup callbacks
        self.upload_widget.observe(self.on_upload, names='value')
        self.process_button.on_click(self.process_image)
        self.export_button.on_click(self.export_results)

    def on_threshold_type_change(self, change):
        """Enable the manual threshold slider only for the 'manual' method.

        Args:
            change: Traitlets change notification (unused; the current value
                is read from the dropdown itself).
        """
        self.manual_threshold.disabled = (self.threshold_type.value != 'manual')

    def display_ui(self):
        """Lay out all widgets and display the complete UI for Voila."""
        upload_box = widgets.VBox([
            widgets.HTML("<h2>Crystal Counter for SEM Images</h2>"),
            self.upload_widget
        ])

        params_box = widgets.VBox([
            widgets.HTML("<h3>Detection Parameters</h3>"),
            self.shape_selector,
            widgets.HTML("<h4>Preprocessing</h4>"),
            self.gaussian_blur,
            self.threshold_type,
            self.manual_threshold,
            widgets.HTML("<h4>Segmentation</h4>"),
            # Without this entry the checkbox exists but can never be toggled.
            self.use_watershed,
            widgets.HTML("<h4>Crystal Properties</h4>"),
            self.min_area,
            self.max_area,
            self.min_circularity,
            self.min_convexity
        ])

        controls_box = widgets.HBox([
            self.process_button,
            self.export_button
        ])

        main_ui = widgets.VBox([
            upload_box,
            params_box,
            controls_box,
            self.output,
            self.results_output
        ])

        display(main_ui)

    @staticmethod
    def _extract_upload(value):
        """Return ``(file_name, content)`` of the first uploaded file, or None.

        Handles both layouts of ``FileUpload.value``: a tuple of dict-like
        entries with 'name' and 'content' keys (ipywidgets 8) and a dict
        mapping file name to an entry with a 'content' key (ipywidgets 7).
        """
        if not value:
            return None
        if isinstance(value, dict):
            file_name, entry = next(iter(value.items()))
            return file_name, entry['content']
        entry = value[0]
        return entry['name'], entry['content']

    @staticmethod
    def _normalize_to_uint8(array):
        """Linearly rescale ``array`` to the full 0-255 range as uint8.

        A constant input (no dynamic range) maps to all zeros.
        """
        data = np.asarray(array, dtype=np.float64)
        low, high = data.min(), data.max()
        if high <= low:
            return np.zeros(data.shape, dtype=np.uint8)
        return np.round((data - low) * (255.0 / (high - low))).astype(np.uint8)

    def on_upload(self, change):
        """Decode the uploaded file into ``self.image`` (BGR, uint8).

        On success any results of a previously processed image are discarded
        so they cannot be exported against the new image. On failure an error
        message is shown and the previously loaded image (if any) is kept.

        Args:
            change: Traitlets change notification (unused; the widget value is
                read directly).
        """
        upload = self._extract_upload(self.upload_widget.value)
        if upload is None:
            return
        file_name, image_data = upload

        try:
            # PIL is used for decoding because of its better TIF support.
            pil_image = Image.open(io.BytesIO(image_data))

            # 16/32-bit and float images are rescaled first; a direct
            # conversion to RGB would clip or be rejected by PIL.
            if pil_image.mode in self._HIGH_DEPTH_MODES:
                pil_image = Image.fromarray(
                    self._normalize_to_uint8(np.asarray(pil_image))
                )

            # Convert to RGB if necessary (palette, greyscale, RGBA, ...)
            if pil_image.mode != 'RGB':
                pil_image = pil_image.convert('RGB')

            # OpenCV expects BGR channel order
            image_array = np.array(pil_image)
            self.image = cv2.cvtColor(image_array, cv2.COLOR_RGB2BGR)

            # Results of the previous image no longer apply.
            self.results = None
            self.processed_image = None
            with self.results_output:
                clear_output()

            with self.output:
                clear_output()
                print("✅ Image uploaded successfully!")
                print(f"Image dimensions: {self.image.shape}")
                print(f"Image format: {file_name}")

        except Exception as e:
            with self.output:
                clear_output()
                print(f"❌ Error loading image: {str(e)}")
                print("Please ensure the file is a valid image format.")

    def preprocess_image(self):
        """Enhance the image and binarise it (crystals white on black).

        Returns:
            Tuple ``(enhanced, filled)``: the CLAHE-enhanced greyscale image
            and the hole-filled binary mask, or ``(None, None)`` when no image
            has been uploaded.
        """
        if self.image is None:
            return None, None

        # Convert to grayscale
        gray = cv2.cvtColor(self.image, cv2.COLOR_BGR2GRAY)

        # Enhance contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization)
        clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        enhanced = clahe.apply(gray)

        # Bilateral filter reduces noise while keeping edges sharp
        denoised = cv2.bilateralFilter(enhanced, 9, 75, 75)

        # Gaussian blur; OpenCV requires an odd kernel size, so guard against
        # an even value being set programmatically on the slider.
        blur_size = self.gaussian_blur.value
        if blur_size % 2 == 0:
            blur_size += 1
        blurred = cv2.GaussianBlur(denoised, (blur_size, blur_size), 0)

        # Apply thresholding
        if self.threshold_type.value == 'otsu':
            _, binary = cv2.threshold(blurred, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
        elif self.threshold_type.value == 'adaptive':
            binary = cv2.adaptiveThreshold(
                blurred, 255, cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
                cv2.THRESH_BINARY, 21, 5
            )
        else:  # manual
            _, binary = cv2.threshold(blurred, self.manual_threshold.value, 255, cv2.THRESH_BINARY)

        # Crystals should be white on a black background: if more than half of
        # the mask is white, assume the polarity is reversed and invert it.
        white_pixels = np.sum(binary == 255)
        if white_pixels > binary.size * 0.5:
            binary = cv2.bitwise_not(binary)

        # Opening removes speckles, closing bridges small gaps
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
        binary = cv2.morphologyEx(binary, cv2.MORPH_OPEN, kernel, iterations=2)
        binary = cv2.morphologyEx(binary, cv2.MORPH_CLOSE, kernel, iterations=2)

        # Fill holes in crystals by redrawing every outer contour as solid
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        filled = np.zeros_like(binary)
        cv2.drawContours(filled, contours, -1, 255, -1)

        return enhanced, filled

    def watershed_segmentation(self, binary_image):
        """Apply watershed segmentation to separate touching crystals.

        Args:
            binary_image: Single-channel uint8 mask (crystals 255, rest 0).

        Returns:
            Tuple ``(segmented, markers)``: the uint8 mask of all watershed
            regions other than background, and the int32 marker image returned
            by ``cv2.watershed`` (-1 on region boundaries, 1 for background).
        """
        # Distance of each foreground pixel to the nearest background pixel
        dist_transform = cv2.distanceTransform(binary_image, cv2.DIST_L2, 5)

        # Seed ("sure foreground") = pixels deeper than half of the peak depth
        # of THEIR OWN connected component. A single global peak would leave
        # small isolated crystals without any seed, and the watershed would
        # then absorb them into the background.
        n_labels, labels = cv2.connectedComponents(binary_image)
        peaks = np.asarray(
            ndimage.maximum(dist_transform, labels=labels, index=np.arange(n_labels))
        )
        core_mask = (labels > 0) & (dist_transform > 0.5 * peaks[labels])
        sure_fg = np.where(core_mask, 255, 0).astype(np.uint8)

        # Sure background: everything outside the dilated mask
        kernel = np.ones((3, 3), np.uint8)
        sure_bg = cv2.dilate(binary_image, kernel, iterations=3)

        # Unknown region: inside the dilated mask but not part of any seed
        unknown = cv2.subtract(sure_bg, sure_fg)

        # Marker labelling
        _, markers = cv2.connectedComponents(sure_fg)

        # Add 1 to all labels so that sure background is not 0, but 1
        markers = markers + 1

        # Mark the region of unknown with zero (to be decided by the watershed)
        markers[unknown == 255] = 0

        # cv2.watershed needs a 3-channel image
        img_for_watershed = cv2.cvtColor(binary_image, cv2.COLOR_GRAY2BGR)
        markers = cv2.watershed(img_for_watershed, markers)

        # Keep every region except background (1) and boundaries (-1)
        segmented = np.zeros_like(binary_image)
        segmented[markers > 1] = 255

        return segmented, markers

    def detect_crystals(self, binary_image):
        """Detect crystals in a binary mask based on the widget criteria.

        Args:
            binary_image: Single-channel uint8 mask (crystals 255, rest 0).

        Returns:
            List of dicts, one per accepted contour, with the keys 'contour',
            'area', 'perimeter', 'circularity', 'convexity', 'center_x',
            'center_y', 'angle', 'aspect_ratio', 'major_axis', 'minor_axis'.
        """
        contours, _ = cv2.findContours(binary_image, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        detected_crystals = []

        for contour in contours:
            area = cv2.contourArea(contour)

            # Filter by area
            if area < self.min_area.value or area > self.max_area.value:
                continue

            perimeter = cv2.arcLength(contour, True)
            if perimeter == 0:
                continue

            # 1.0 for a perfect circle, smaller for elongated/ragged outlines
            circularity = 4 * np.pi * area / (perimeter * perimeter)

            # Convexity: share of the convex hull actually covered
            hull = cv2.convexHull(contour)
            hull_area = cv2.contourArea(hull)
            if hull_area == 0:
                continue
            convexity = area / hull_area

            # Filter by shape properties
            if circularity < self.min_circularity.value or convexity < self.min_convexity.value:
                continue

            # Shape-specific filtering
            if self.shape_selector.value != 'all':
                if not self.is_target_shape(contour):
                    continue

            # Centroid from image moments
            moments = cv2.moments(contour)
            if moments['m00'] != 0:
                cx = int(moments['m10'] / moments['m00'])
                cy = int(moments['m01'] / moments['m00'])
            else:
                cx, cy = 0, 0

            # Fit ellipse for orientation (needs at least 5 contour points)
            if len(contour) >= 5:
                ellipse = cv2.fitEllipse(contour)
                angle = ellipse[2]
                major_axis = max(ellipse[1])
                minor_axis = min(ellipse[1])
                aspect_ratio = major_axis / minor_axis if minor_axis > 0 else 0
            else:
                # Fall back to the diameter of a circle with the same area
                angle = 0
                aspect_ratio = 1
                major_axis = minor_axis = np.sqrt(area / np.pi) * 2

            detected_crystals.append({
                'contour': contour,
                'area': area,
                'perimeter': perimeter,
                'circularity': circularity,
                'convexity': convexity,
                'center_x': cx,
                'center_y': cy,
                'angle': angle,
                'aspect_ratio': aspect_ratio,
                'major_axis': major_axis,
                'minor_axis': minor_axis
            })

        return detected_crystals

    def is_target_shape(self, contour):
        """Return True if ``contour`` matches the shape chosen in the dropdown.

        'hexagonal' and 'cubic' use the vertex count of a polygon
        approximation with a tolerance around the ideal 6 and 4 vertices;
        'circular' requires a circularity above 0.7. Any other selection
        (e.g. 'all') accepts every contour.
        """
        shape = self.shape_selector.value

        if shape == 'circular':
            area = cv2.contourArea(contour)
            perimeter = cv2.arcLength(contour, True)
            if perimeter <= 0:
                # A degenerate contour has no defined circularity
                return False
            return 4 * np.pi * area / (perimeter * perimeter) > 0.7

        # Approximate contour to polygon
        epsilon = 0.02 * cv2.arcLength(contour, True)
        vertex_count = len(cv2.approxPolyDP(contour, epsilon, True))

        if shape == 'hexagonal':
            # Ideal is 6 vertices; 5-8 tolerates imperfect outlines
            return 5 <= vertex_count <= 8
        if shape == 'cubic':
            # Ideal is 4 vertices (squares/rectangles); 3-5 tolerates noise
            return 3 <= vertex_count <= 5

        return True

    def visualize_results(self, gray_image, crystals):
        """Draw contour, centre point and running number of every crystal.

        Args:
            gray_image: Single-channel image used as the drawing background.
            crystals: Crystal dicts as returned by ``detect_crystals``.

        Returns:
            BGR image with the annotations drawn on it.
        """
        result_image = cv2.cvtColor(gray_image, cv2.COLOR_GRAY2BGR)

        for i, crystal in enumerate(crystals):
            # Draw contour
            cv2.drawContours(result_image, [crystal['contour']], -1, (0, 255, 0), 2)

            # Draw center point
            cv2.circle(result_image, (crystal['center_x'], crystal['center_y']), 3, (255, 0, 0), -1)

            # Add label
            cv2.putText(result_image, str(i+1),
                       (crystal['center_x']-10, crystal['center_y']-10),
                       cv2.FONT_HERSHEY_SIMPLEX, 0.5, (255, 255, 0), 1)

        return result_image

    def process_image(self, button):
        """Run the full pipeline on the uploaded image and show the results.

        Args:
            button: The Button instance passed by ``on_click`` (unused).
        """
        if self.image is None:
            with self.output:
                clear_output()
                print("❌ Please upload an image first!")
            return

        # Forget the previous run up front so that a failure below cannot leave
        # an outdated table on screen or outdated data behind the export button.
        self.results = None
        self.processed_image = None
        with self.results_output:
            clear_output()

        with self.output:
            clear_output()
            print("🔄 Processing image...")

        try:
            # Preprocess image
            gray, binary = self.preprocess_image()

            # Apply watershed segmentation if enabled
            if self.use_watershed.value:
                binary, _ = self.watershed_segmentation(binary)

            # Detect crystals
            crystals = self.detect_crystals(binary)

            # Create visualization
            result_image = self.visualize_results(gray, crystals)

            # Store results
            self.results = crystals
            self.processed_image = result_image

            # Display results
            with self.output:
                clear_output()

                fig, axes = plt.subplots(2, 2, figsize=(15, 12))

                # Original image
                axes[0, 0].imshow(cv2.cvtColor(self.image, cv2.COLOR_BGR2RGB))
                axes[0, 0].set_title('Original Image')
                axes[0, 0].axis('off')

                # Binary mask actually used for detection
                axes[0, 1].imshow(binary, cmap='gray')
                axes[0, 1].set_title('Binary Image')
                axes[0, 1].axis('off')

                # Results
                axes[1, 0].imshow(cv2.cvtColor(result_image, cv2.COLOR_BGR2RGB))
                axes[1, 0].set_title(f'Detected Crystals: {len(crystals)}')
                axes[1, 0].axis('off')

                # Size distribution histogram
                if crystals:
                    areas = [c['area'] for c in crystals]
                    axes[1, 1].hist(areas, bins=20, alpha=0.7, color='green')
                    axes[1, 1].set_title('Crystal Size Distribution')
                    axes[1, 1].set_xlabel('Area (pixels²)')
                    axes[1, 1].set_ylabel('Count')
                else:
                    axes[1, 1].text(0.5, 0.5, 'No crystals detected',
                                   ha='center', va='center', transform=axes[1, 1].transAxes)
                    axes[1, 1].set_title('Crystal Size Distribution')

                plt.tight_layout()
                plt.show()

                print(f"✅ Processing complete! Detected {len(crystals)} crystals.")

            # Display detailed results
            self.display_results_table()

        except Exception as e:
            with self.output:
                clear_output()
                print(f"❌ Error processing image: {str(e)}")

    def _results_records(self, rounded=False):
        """Build one row dict per detected crystal in ``_RESULT_COLUMNS`` order.

        Args:
            rounded: When True, round each value to the number of decimals
                configured for its column (used for on-screen display).

        Returns:
            List of dicts keyed by column name; empty when there are no results.
        """
        records = []
        for crystal_id, crystal in enumerate(self.results or [], start=1):
            row = {'Crystal_ID': crystal_id}
            for column, key, decimals in self._RESULT_COLUMNS:
                value = crystal[key]
                if rounded and decimals is not None:
                    value = round(value, decimals)
                row[column] = value
            records.append(row)
        return records

    @staticmethod
    def _csv_download_link(csv_text, filename="crystal_analysis_results.csv"):
        """Return an HTML anchor embedding ``csv_text`` as a base64 data URI.

        The payload is UTF-8 encoded and declared as such, because the column
        headers contain non-ASCII characters.
        """
        b64 = base64.b64encode(csv_text.encode('utf-8')).decode('ascii')
        return (
            f'<a href="data:text/csv;charset=utf-8;base64,{b64}" '
            f'download="{filename}">Download Results CSV</a>'
        )

    def display_results_table(self):
        """Show the per-crystal table and summary statistics.

        Clears the results area first; with no results it is left empty.
        """
        with self.results_output:
            clear_output()

            if not self.results:
                return

            df = pd.DataFrame(self._results_records(rounded=True))

            print("📊 Detailed Crystal Analysis Results:")
            print("=" * 60)
            display(df)

            # Summary statistics
            print("\n📈 Summary Statistics:")
            print(f"Total Crystals Detected: {len(self.results)}")
            print(f"Average Area: {df['Area_px²'].mean():.2f} ± {df['Area_px²'].std():.2f} px²")
            print(f"Average Circularity: {df['Circularity'].mean():.3f} ± {df['Circularity'].std():.3f}")
            print(f"Average Aspect Ratio: {df['Aspect_Ratio'].mean():.2f} ± {df['Aspect_Ratio'].std():.2f}")

    def export_results(self, button):
        """Offer the unrounded results as a CSV download link.

        Args:
            button: The Button instance passed by ``on_click`` (unused).
        """
        if not self.results:
            with self.results_output:
                print("❌ No results to export! Please process an image first.")
            return

        try:
            df = pd.DataFrame(self._results_records(rounded=False))
            csv_string = df.to_csv(index=False)

            # In Voila there is no server-side file to fetch, so the CSV is
            # embedded directly in the link.
            href = self._csv_download_link(csv_string)

            with self.results_output:
                print("💾 Results exported successfully!")
                display(widgets.HTML(href))

        except Exception as e:
            with self.results_output:
                print(f"❌ Error exporting results: {str(e)}")

# Initialize and display the application
def main():
    """Create a CrystalCounter, render its widget UI and return the instance."""
    counter = CrystalCounter()
    counter.display_ui()
    return counter

# Build and show the app when this cell is executed directly (notebook and
# Voila kernels run cells with __name__ == "__main__"). The instance is kept
# in `crystal_app` so it stays referenced and can be inspected from later cells.
if __name__ == "__main__":
    crystal_app = main()

VBox(children=(VBox(children=(HTML(value='<h2>Crystal Counter for SEM Images</h2>'), FileUpload(value=(), acce…