In [2]:
from PySide6.QtWidgets import QWidget, QVBoxLayout, QLabel, QSlider, QPushButton, QComboBox, QStackedWidget, QDoubleSpinBox, QGridLayout
from PySide6.QtCore import Qt, Signal
import numpy as np
import imageio # For general image loading (can use Pillow too)
import skimage.filters
import skimage.morphology
from skimage.color import rgb2gray
from scipy.ndimage import convolve

from modules.i_image_module import IImageModule
from image_data_store import ImageDataStore

In [3]:
##
class BaseParamsWidget(QWidget):
    """Base class for an operation's parameter panel.

    Subclasses build their input controls in ``__init__`` and override
    :meth:`get_params` to report the current values as a dict.
    """

    def get_params(self) -> dict:
        """Return this panel's parameters as a dict (must be overridden).

        Raises:
            NotImplementedError: always; the message names the concrete
                subclass that did not override this method.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement get_params()")

class NoParamsWidget(BaseParamsWidget):
    """Placeholder panel shown for operations that take no parameters."""

    def __init__(self, parent=None):
        super().__init__(parent)
        notice = QLabel("This operation has no parameters.")
        notice.setStyleSheet("font-style: italic; color: gray;")

        column = QVBoxLayout(self)
        column.addWidget(notice)
        column.addStretch()

    def get_params(self) -> dict:
        """Nothing to configure, so the parameter mapping is empty."""
        return dict()


In [4]:
##
class GaussianParamsWidget(BaseParamsWidget):
    """Parameter panel for Gaussian blur: one sigma value in [0.1, 25.0], default 1.0."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.sigma_spinbox = QDoubleSpinBox()
        self.sigma_spinbox.setRange(0.1, 25.0)
        self.sigma_spinbox.setValue(1.0)
        self.sigma_spinbox.setSingleStep(0.1)

        column = QVBoxLayout(self)
        column.setContentsMargins(0, 0, 0, 0)
        for item in (QLabel("Sigma (Standard Deviation):"), self.sigma_spinbox):
            column.addWidget(item)
        column.addStretch()

    def get_params(self) -> dict:
        """Return ``{'sigma': <current spinbox value>}``."""
        sigma = self.sigma_spinbox.value()
        return {'sigma': sigma}



In [5]:
##
class PowerLawParamsWidget(BaseParamsWidget):
    """Parameter panel for power-law (gamma) correction: gamma in [0.01, 5.0], default 1.0."""

    def __init__(self, parent=None):
        super().__init__(parent)
        self.gamma_spinbox = QDoubleSpinBox()
        self.gamma_spinbox.setRange(0.01, 5.0)
        self.gamma_spinbox.setValue(1.0)
        self.gamma_spinbox.setSingleStep(0.1)

        column = QVBoxLayout(self)
        column.setContentsMargins(0, 0, 0, 0)
        for item in (QLabel("Gamma:"), self.gamma_spinbox):
            column.addWidget(item)
        column.addStretch()

    def get_params(self) -> dict:
        """Return ``{'gamma': <current spinbox value>}``."""
        gamma = self.gamma_spinbox.value()
        return {'gamma': gamma}



In [6]:
##
class ConvolutionParamsWidget(BaseParamsWidget):
    """Parameter panel for a user-defined 3x3 convolution kernel.

    The grid starts as the identity kernel (1.0 in the centre, 0.0 elsewhere).
    ``self.kernel_inputs[r][c]`` is the spinbox for kernel row ``r``, column ``c``.
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)
        layout.addWidget(QLabel("3x3 Kernel:"))

        grid_layout = QGridLayout()
        self.kernel_inputs = []
        for r in range(3):
            row_inputs = []
            for c in range(3):
                spinbox = QDoubleSpinBox()
                # Qt's default of 2 decimals cannot hold normalised weights such
                # as a 1/9 box blur (0.1111); 4 decimals can.
                spinbox.setDecimals(4)
                spinbox.setMinimum(-100.0)
                spinbox.setMaximum(100.0)
                # Identity-like default: only the centre tap is 1.0.
                spinbox.setValue(1.0 if (r, c) == (1, 1) else 0.0)
                grid_layout.addWidget(spinbox, r, c)
                row_inputs.append(spinbox)
            self.kernel_inputs.append(row_inputs)
        layout.addLayout(grid_layout)
        # Keep the grid top-aligned, like the other parameter panels.
        layout.addStretch()

    def get_params(self) -> dict:
        """Return ``{'kernel': ndarray}`` with the current 3x3 weights as floats."""
        kernel = np.array([[spinbox.value() for spinbox in row] for row in self.kernel_inputs],
                          dtype=float)
        return {'kernel': kernel}



In [7]:
##
class CLAHEParamsWidget(BaseParamsWidget):
    """Parameter panel for CLAHE: histogram clip limit and tile-grid size."""

    def __init__(self, parent=None):
        super().__init__(parent)
        layout = QVBoxLayout(self)
        layout.setContentsMargins(0, 0, 0, 0)

        # Clip limit: multiplier on a tile's mean histogram-bin count.
        layout.addWidget(QLabel("Clip Limit:"))
        self.clip_spinbox = QDoubleSpinBox()
        self.clip_spinbox.setMinimum(0.1)
        self.clip_spinbox.setMaximum(10.0)
        self.clip_spinbox.setValue(2.0)  # Default value
        self.clip_spinbox.setSingleStep(0.1)
        layout.addWidget(self.clip_spinbox)

        # Tile grid size: the image is split into tile_size x tile_size tiles.
        layout.addWidget(QLabel("Tile Size (Grid):"))
        self.tile_spinbox = QDoubleSpinBox()
        # Whole numbers only: a fractional entry such as 8.7 would otherwise be
        # shown to the user but silently truncated by get_params().
        self.tile_spinbox.setDecimals(0)
        self.tile_spinbox.setMinimum(2)
        self.tile_spinbox.setMaximum(32)
        self.tile_spinbox.setValue(8)  # Default 8x8 tiles
        self.tile_spinbox.setSingleStep(1)
        layout.addWidget(self.tile_spinbox)

        layout.addStretch()

    def get_params(self) -> dict:
        """Return ``{'clip_limit': float, 'tile_size': int}``."""
        return {
            'clip_limit': self.clip_spinbox.value(),
            'tile_size': int(round(self.tile_spinbox.value())),
        }

##

In [8]:
##
class DopplerBoostParamsWidget(BaseParamsWidget):
    """Parameter panel for relativistic Doppler boosting.

    * velocity - fraction of the speed of light, 0.1 to 0.9 (default 0.5).
    * angle    - viewing angle in degrees, 0 to 180 (default 90, edge-on).
    """

    def __init__(self, parent=None):
        super().__init__(parent)
        column = QVBoxLayout(self)
        column.setContentsMargins(0, 0, 0, 0)

        self.velocity_spinbox = self._add_spinbox(
            column, "Velocity (fraction of c):", 0.1, 0.9, 0.5, 0.05)
        self.angle_spinbox = self._add_spinbox(
            column, "Viewing Angle (degrees):", 0, 180, 90, 5)

        column.addStretch()

    @staticmethod
    def _add_spinbox(layout, caption, minimum, maximum, value, step):
        """Append a caption label and a configured QDoubleSpinBox; return the spinbox."""
        layout.addWidget(QLabel(caption))
        spinbox = QDoubleSpinBox()
        spinbox.setMinimum(minimum)
        spinbox.setMaximum(maximum)
        spinbox.setValue(value)
        spinbox.setSingleStep(step)
        layout.addWidget(spinbox)
        return spinbox

    def get_params(self) -> dict:
        """Return ``{'velocity': float, 'angle': float}`` from the two spinboxes."""
        velocity = self.velocity_spinbox.value()
        angle = self.angle_spinbox.value()
        return {'velocity': velocity, 'angle': angle}

In [9]:
##
# Define a custom control widget
class SampleControlsWidget(QWidget):
    """Control panel: choose an operation, edit its parameters, request processing.

    Each operation has a parameter panel inside a QStackedWidget that follows
    the combo-box selection. "Apply Processing" emits ``process_requested``
    with the active panel's ``get_params()`` dict plus an ``'operation'`` key
    holding the selected operation name.
    """

    # Signal to request processing from the module manager
    process_requested = Signal(dict)

    def __init__(self, module_manager, parent=None):
        super().__init__(parent)
        self.module_manager = module_manager
        # Operation name -> its parameter panel.
        self.param_widgets = {}
        self.setup_ui()

    def setup_ui(self):
        """Build the selector, the parameter stack and the apply button, then wire signals."""
        root = QVBoxLayout(self)
        root.addWidget(QLabel("<h3>Control Panel</h3>"))
        root.addWidget(QLabel("Operation:"))

        self.operation_selector = QComboBox()
        root.addWidget(self.operation_selector)

        # Stacked widget to hold the parameter UIs
        self.params_stack = QStackedWidget()
        root.addWidget(self.params_stack)

        # (display name, parameter-panel class), in combo-box order.
        catalogue = (
            ("Gaussian Blur", GaussianParamsWidget),
            ("Sobel Edge Detect", NoParamsWidget),
            ("Power Law (Gamma)", PowerLawParamsWidget),
            ("Convolution", ConvolutionParamsWidget),
            ("CLAHE Enhancement", CLAHEParamsWidget),
            ("Relativistic Doppler Boosting", DopplerBoostParamsWidget),
            ("Canny Edge Detection", NoParamsWidget),
        )
        for label_text, panel_cls in catalogue:
            panel = panel_cls()
            self.params_stack.addWidget(panel)
            self.param_widgets[label_text] = panel
            self.operation_selector.addItem(label_text)

        self.apply_button = QPushButton("Apply Processing")
        root.addWidget(self.apply_button)

        self.apply_button.clicked.connect(self._on_apply_clicked)
        self.operation_selector.currentTextChanged.connect(self._on_operation_changed)

    def _on_apply_clicked(self):
        """Collect the active panel's parameters, tag them with the operation, emit."""
        chosen = self.operation_selector.currentText()
        payload = self.param_widgets[chosen].get_params()
        payload['operation'] = chosen  # Add operation name to params
        self.process_requested.emit(payload)

    def _on_operation_changed(self, operation_name: str):
        """Show the parameter panel registered for ``operation_name``, if any."""
        panel = self.param_widgets.get(operation_name)
        if panel is not None:
            self.params_stack.setCurrentWidget(panel)



In [10]:
##
class SampleImageModule(IImageModule):
    """Example image module exposing a SampleControlsWidget control panel.

    Processing requests from the panel are forwarded to the panel's module
    manager via ``apply_processing_to_current_image``.

    NOTE(review): ``load_image`` and ``process_image`` in this notebook are
    module-level functions taking ``self``, not methods of this class; confirm
    how (or whether) they are attached to it.
    """

    def __init__(self):
        super().__init__()
        # Created lazily by create_control_widget() and reused afterwards.
        self._controls_widget = None

    def get_name(self) -> str:
        """Display name of this module."""
        return "Sample Module"

    def get_supported_formats(self) -> list[str]:
        """Image file extensions this module reports as supported."""
        formats = ["png", "jpg", "jpeg", "bmp", "gif", "tiff"]  # Common formats
        return formats

    def create_control_widget(self, parent=None, module_manager=None) -> QWidget:
        """Return the control panel, building it on the first call only.

        ``parent`` and ``module_manager`` are only used on that first call;
        later calls return the cached widget unchanged.
        """
        if self._controls_widget is not None:
            return self._controls_widget
        panel = SampleControlsWidget(module_manager, parent)
        # Route the panel's processing requests to this module's handler.
        panel.process_requested.connect(self._handle_processing_request)
        self._controls_widget = panel
        return panel

    def _handle_processing_request(self, params: dict):
        """Forward ``params`` to the panel's module manager when both exist."""
        panel = self._controls_widget
        if not panel:
            return
        manager = panel.module_manager
        if manager:
            manager.apply_processing_to_current_image(params)

    

In [11]:
##
def load_image(self, file_path: str):
    """Read an image file for this module.

    NOTE(review): defined at module level with a ``self`` parameter (its body
    was indented like a method) - presumably meant to be a method of
    ``SampleImageModule``; confirm how it gets attached.

    Args:
        self: Unused.
        file_path: Path to the image file (``str``; ``os.PathLike`` also works).

    Returns:
        ``(True, image_data, metadata, None)`` on success, or
        ``(False, None, {}, None)`` if reading fails (the error is printed).
        RGB/RGBA arrays are returned as read; 2D grayscale arrays get a
        leading axis, i.e. shape (1, H, W). ``metadata['name']`` is the file
        name without its directory. The last element is the session-id slot,
        left as ``None`` for the store to fill in.
    """
    from pathlib import Path

    try:
        image_data = imageio.imread(file_path)
        if image_data.ndim == 2:
            # Grayscale: add a leading axis -> (1, H, W).
            # NOTE(review): this is a leading, not trailing, axis; code expecting
            # a 2D grayscale image must squeeze it - confirm this shape is wanted.
            image_data = image_data[np.newaxis, :]
        elif not (image_data.ndim == 3 and image_data.shape[2] in (3, 4)):
            # RGB/RGBA (H, W, 3|4) passes through untouched; anything else is
            # kept as read but flagged.
            print(f"Warning: Unexpected image dimensions {image_data.shape}")

        # Path(...).name also works for PathLike input and, on Windows, for
        # backslash-separated paths, which str.split('/') did not handle.
        metadata = {'name': Path(file_path).name}
        # Add more metadata: original_shape, file_size, etc.
        return True, image_data, metadata, None  # Session ID generated by store
    except Exception as e:
        # Best effort: report the problem and signal failure to the caller.
        print(f"Error loading 2D image {file_path}: {e}")
        return False, None, {}, None

   

In [12]:
##
def process_image(self, image_data: np.ndarray, metadata: dict, params: dict) -> np.ndarray:
        processed_data = image_data.copy()

        operation = params.get('operation')

        if operation == "Gaussian Blur":
            sigma = params.get('sigma', 1.0)
            # skimage.filters.gaussian expects float data
            processed_data = skimage.filters.gaussian(processed_data.astype(float), sigma=sigma, preserve_range=True)
        elif operation == "Median Filter":
            filter_size = params.get('filter_size', 3)
            if filter_size <= 1: return processed_data # No change
            # skimage.filters.median
            if processed_data.ndim == 3 and processed_data.shape[2] in [3, 4]: # RGB/RGBA
                # Apply to each channel
                channels = []
                for i in range(processed_data.shape[2]):
                    channels.append(skimage.filters.median(processed_data[:,:,i], footprint=skimage.morphology.disk(int(filter_size/2))))
                processed_data = np.stack(channels, axis=-1)
            else:
                processed_data = skimage.filters.median(processed_data, footprint=skimage.morphology.disk(int(filter_size/2)))
        elif operation == "Sobel Edge Detect":
            # Sobel works on 2D (grayscale) images. Convert if necessary.
            if processed_data.ndim == 3 and processed_data.shape[2] in [3, 4]:
                grayscale_img = rgb2gray(processed_data[:,:,:3])
            else:
                grayscale_img = processed_data
            
            processed_data = skimage.filters.sobel(grayscale_img)
        elif operation == "Power Law (Gamma)":
            gamma = params.get('gamma', 1.0)
            # Normalize to [0, 1]
            input_float = processed_data.astype(float)
            max_val = np.max(input_float)
            if max_val > 0:
                normalized = input_float / max_val
                # Apply gamma correction
                gamma_corrected = np.power(normalized, gamma)
                # Scale back to original range
                processed_data = gamma_corrected * max_val
        
        elif operation == "Convolution":
            kernel = params.get('kernel')
            if kernel is not None:
                # Convolve works best on float images
                input_float = processed_data.astype(float)
                if input_float.ndim == 3 and input_float.shape[2] in [3, 4]: # RGB/RGBA
                    channels = []
                    for i in range(input_float.shape[2]):
                        channels.append(convolve(input_float[:,:,i], kernel, mode='reflect'))
                    processed_data = np.stack(channels, axis=-1)
                else:
                    processed_data = convolve(input_float, kernel, mode='reflect')
        elif operation == "CLAHE Enhancement":
            clip_limit = params.get('clip_limit', 2.0)
            tile_size = params.get('tile_size', 8)
            
            # Convert to grayscale if color image
            if processed_data.ndim == 3 and processed_data.shape[2] in [3, 4]:
                grayscale_img = rgb2gray(processed_data[:,:,:3])
                was_color = True
            else:
                grayscale_img = processed_data.copy()
                was_color = False
            
            # Normalize to [0, 255] range
            img_min = grayscale_img.min()
            img_max = grayscale_img.max()
            if img_max > img_min:
                normalized = ((grayscale_img - img_min) / (img_max - img_min) * 255).astype(np.uint8)
            else:
                normalized = grayscale_img.astype(np.uint8)
            
            # Calculate tile dimensions
            height, width = normalized.shape
            tile_h = height // tile_size
            tile_w = width // tile_size
            
            # Create output array
            clahe_img = np.zeros_like(normalized, dtype=np.float32)
            
            # Process each tile
            for i in range(tile_size):
                for j in range(tile_size):
                    # Calculate tile boundaries
                    y1 = i * tile_h
                    y2 = (i + 1) * tile_h if i < tile_size - 1 else height
                    x1 = j * tile_w
                    x2 = (j + 1) * tile_w if j < tile_size - 1 else width
                    
                    tile = normalized[y1:y2, x1:x2]
                    
                    # Compute histogram
                    hist, bins = np.histogram(tile.flatten(), bins=256, range=(0, 256))
                    
                    # Clip histogram
                    clip_threshold = clip_limit * (tile.size / 256.0)
                    clipped_hist = np.minimum(hist, clip_threshold)
                    
                    # Redistribute clipped pixels
                    excess = np.sum(hist - clipped_hist)
                    redistribute = excess / 256.0
                    clipped_hist = clipped_hist + redistribute
                    
                    # Compute CDF
                    cdf = np.cumsum(clipped_hist)
                    cdf = cdf / cdf[-1]  # Normalize to [0, 1]
                    
                    # Map intensities using CDF
                    equalized_tile = np.interp(tile.flatten(), bins[:-1], cdf * 255)
                    equalized_tile = equalized_tile.reshape(tile.shape)
                    
                    # Store result
                    clahe_img[y1:y2, x1:x2] = equalized_tile
            
            # Convert back to original range
            if img_max > img_min:
                clahe_img = (clahe_img / 255.0) * (img_max - img_min) + img_min
            
            # Store result
            processed_data = clahe_img.astype(image_data.dtype)


           

        elif operation == "Relativistic Doppler Boosting":
            velocity = params.get('velocity', 0.5)  # fraction of c
            viewing_angle = params.get('angle', 90)  # degrees
            
            # Convert to grayscale if needed
            if processed_data.ndim == 3 and processed_data.shape[2] in [3, 4]:
                grayscale_img = rgb2gray(processed_data[:,:,:3])
            else:
                grayscale_img = processed_data.copy()
            
            # Find image center (black hole location)
            height, width = grayscale_img.shape
            center_y = height // 2
            center_x = width // 2
            
            # Create coordinate grids
            y, x = np.ogrid[0:height, 0:width]
            dy = y - center_y
            dx = x - center_x
            
            # Calculate distance from center
            r = np.sqrt(dx**2 + dy**2)
            
            # Calculate angle around ring (φ)
            # phi = 0 at right, π/2 at top, π at left, -π/2 at bottom
            phi = np.arctan2(dy, dx)
            
            # Calculate Lorentz factor
            gamma = 1.0 / np.sqrt(1.0 - velocity**2)
            
            # Convert viewing angle to radians
            theta = np.radians(viewing_angle)
            
            # Calculate velocity component toward observer
            # For edge-on view (θ=90°): v_radial = v × cos(φ)
            # cos(φ) = +1 at right (perpendicular), 0 at bottom (approaching), -1 at left
            beta_radial = velocity * np.sin(theta) * np.cos(phi)
            
            # Calculate Doppler factor: δ = 1 / [γ(1 - β·cos(φ))]
            doppler_factor = 1.0 / (gamma * (1.0 - beta_radial))
            
            # Clip to reasonable range (avoid extreme values)
            doppler_factor = np.clip(doppler_factor, 0.1, 10.0)
            
            # Apply relativistic beaming: I' = I × δ³
            boosted = grayscale_img.astype(float) * (doppler_factor ** 3)
            
            # Create ring mask (apply only to ring, not center shadow)
            inner_radius = height * 0.15  # Shadow boundary
            outer_radius = height * 0.48  # Outer ring edge
            ring_mask = (r > inner_radius) & (r < outer_radius)
            
            # Apply boosting only to ring region
            result = grayscale_img.astype(float).copy()
            result[ring_mask] = boosted[ring_mask]
            
            # Normalize to avoid overflow
            result = np.clip(result, 0, 255)
            
            # Convert back to original type
            processed_data = result.astype(image_data.dtype)

        elif operation == "Canny Edge Detection":
            # Convert to grayscale
            if processed_data.ndim == 3 and processed_data.shape[2] in [3, 4]:
                grayscale_img = rgb2gray(processed_data[:,:,:3])
            else:
                grayscale_img = processed_data.copy()
            
            # Step 1: Gaussian smoothing
            smoothed = skimage.filters.gaussian(grayscale_img, sigma=1.0)
            
            # Step 2: Calculate gradients using Sobel operators
            sobel_x = np.array([[-1, 0, 1],
                                [-2, 0, 2],
                                [-1, 0, 1]], dtype=float)
            sobel_y = np.array([[-1, -2, -1],
                                [ 0,  0,  0],
                                [ 1,  2,  1]], dtype=float)
            
            Gx = convolve(smoothed, sobel_x)
            Gy = convolve(smoothed, sobel_y)
            
            # Step 3: Calculate magnitude and direction
            G = np.sqrt(Gx**2 + Gy**2)
            theta = np.arctan2(Gy, Gx)
            
            # Step 4: Non-maximum suppression
            height, width = G.shape
            nms = np.zeros_like(G)
            
            # Convert angle to degrees and round to nearest 45°
            angle_deg = np.degrees(theta) % 180
            
            for i in range(1, height - 1):
                for j in range(1, width - 1):
                    # Determine direction
                    angle = angle_deg[i, j]
                    
                    # 0° or 180° - horizontal edge (compare top and bottom)
                    if (0 <= angle < 22.5) or (157.5 <= angle <= 180):
                        neighbors = [G[i, j-1], G[i, j+1]]
                    # 45° - diagonal edge (compare NE and SW)
                    elif 22.5 <= angle < 67.5:
                        neighbors = [G[i-1, j+1], G[i+1, j-1]]
                    # 90° - vertical edge (compare left and right)
                    elif 67.5 <= angle < 112.5:
                        neighbors = [G[i-1, j], G[i+1, j]]
                    # 135° - diagonal edge (compare NW and SE)
                    else:
                        neighbors = [G[i-1, j-1], G[i+1, j+1]]
                    
                    # Keep pixel only if it's maximum along gradient direction
                    if G[i, j] >= max(neighbors):
                        nms[i, j] = G[i, j]
            
            # Step 5: Double thresholding
            T_high = 0.3 * np.max(nms)
            T_low = 0.1 * np.max(nms)
            
            strong_edges = nms > T_high
            weak_edges = (nms >= T_low) & (nms <= T_high)
            
            # Step 6: Edge tracking by hysteresis
            from scipy.ndimage import binary_dilation
            
            edges = strong_edges.copy()
            
            # Iteratively grow strong edges into weak edges
            for _ in range(10):
                dilated = binary_dilation(edges)
                new_edges = dilated & weak_edges
                if not np.any(new_edges & ~edges):
                    break  # No new edges found
                edges = edges | new_edges
            
            # Convert to 8-bit image
            processed_data = (edges * 255).astype(np.uint8)


        # Ensure output data type is consistent (e.g., convert back to uint8 if processing changed it)
        processed_data = processed_data.astype(image_data.dtype)

        return processed_data