In [1]:
import os
import numpy as np
import tifffile as tiff

In [3]:


def _sharpest_layer_indices(stack, k=3):
    """Return the indices of the ``k`` sharpest layers of ``stack`` along axis 0.

    Sharpness of a layer is the mean gradient magnitude
    ``mean(sqrt(gx**2 + gy**2))`` computed with ``np.gradient``. Indices are
    returned sharpest first; ties keep ascending layer order because
    ``sorted`` is stable even with ``reverse=True``.

    Args:
        stack: Array whose axis 0 indexes layers; each ``stack[i]`` is
            expected to be 2D (``np.gradient`` must return exactly two arrays).
        k: Number of indices to return.

    Returns:
        List of at most ``k`` integer layer indices.
    """
    scores = []
    for i in range(stack.shape[0]):
        # Convert one layer at a time instead of the whole stack to keep peak
        # memory low. A uniform rescale (e.g. /65535) would not change the ranking.
        layer = stack[i].astype(np.float64)
        gy, gx = np.gradient(layer)
        scores.append(float(np.mean(np.sqrt(gx ** 2 + gy ** 2))))
    return sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:k]


def replace_with_3_sharpest_layers(image_dir):
    """Rewrite every TIFF under ``image_dir`` in place so it has 3 layers.

    Walks ``image_dir`` recursively. For each file whose name ends in
    ``.tif``/``.tiff`` (case-insensitive):

    * 2D images ``(H, W)`` are duplicated into a ``(3, H, W)`` stack.
    * Stacks with fewer than 3 layers along axis 0 are left untouched and a
      warning is printed; the walk continues with the next file.
    * Stacks with exactly 3 layers are left untouched.
    * Stacks with more than 3 layers are replaced by the 3 layers with the
      highest mean gradient magnitude, ordered sharpest first.

    Files are OVERWRITTEN in place with their original dtype; no backup is made.

    Args:
        image_dir: Root directory to search.

    Returns:
        None. Progress is reported with ``print``.
    """
    for root, _, files in os.walk(image_dir):
        for file in files:
            if not file.lower().endswith(('.tiff', '.tif')):
                continue
            img_path = os.path.join(root, file)
            image = tiff.imread(img_path)

            # 2D grayscale: duplicate into a 3-layer stack.
            if image.ndim == 2:
                print(f"Image {img_path} is a 2D grayscale image (shape: {image.shape}).")
                three_channel_image = np.stack([image] * 3, axis=0)  # (3, H, W)
                tiff.imwrite(img_path, three_channel_image.astype(image.dtype))
                print(f"Replaced {img_path} with a 3-channel image.")
                continue

            # Too few layers: skip this file only. The previous `break` also
            # skipped every remaining file in the same directory.
            if image.ndim > 2 and image.shape[0] < 3:
                print(f"Warning: Image {img_path} has less than 3 layers. (Found {image.shape[0]} layers)")
                continue

            if image.ndim > 2 and image.shape[0] > 3:
                print(f"Processing {img_path} with {image.shape[0]} layers.")
                top_3_indices = _sharpest_layer_indices(image, 3)
                new_image = image[top_3_indices, :, :]
                # Keep the source dtype (as the 2D branch does). Forcing uint16
                # would truncate float data and change the dtype of other inputs.
                tiff.imwrite(img_path, new_image.astype(image.dtype))
                print(f"Replaced {img_path} with the 3 sharpest layers.")

# Example usage -- WARNING: this overwrites TIFF files under `image_dir` in place.
# Set the TEST_IMG_DIR environment variable to point at a different folder.
image_dir = os.environ.get("TEST_IMG_DIR", r"C:\Users\k54739\Today_data\segmentation\test_img")
if not os.path.isdir(image_dir):
    # os.walk yields nothing for a missing path, so a typo would otherwise look like success.
    raise FileNotFoundError(f"Image directory not found: {image_dir}")
replace_with_3_sharpest_layers(image_dir)
