In [None]:
# Colab-only step: mount Google Drive so that paths under /content/drive
# (used by the configuration cell of this notebook) are accessible.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# WEAK-SUPERVISION MASK GENERATION

## Purpose
Generates weak-supervision segmentation masks by combining original µCT slices with Multi-Otsu segmentation outputs. The script computes confidence heuristics from the original images to identify and mark uncertain regions as ignored, suitable for nnU-Net training with excluded regions.

## Inputs
- Directory containing original µCT slice images (TIFF or PNG)
- Directory containing corresponding Multi-Otsu segmentation slices
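- Files are matched by slice identifier: the filename stem with any `_pores_…` suffix removed (e.g. `roi_0660_slice0780_norm8.tif` ↔ `roi_0660_slice0780_norm8_pores_class0plus1.png`)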

## Outputs
- Directory containing weak-supervision mask slices
- Single-channel images (same spatial dimensions as input)
- Pixel values encode:
  * 0, 1, 2, ... : valid pseudo-labels from Multi-Otsu
  * IGNORE_LABEL : uncertain/ignored regions (default: 255)

## Behavior
For each slice pair:
1. Load original image and its Multi-Otsu segmentation
2. Use Multi-Otsu labels as initial class assignment
3. Compute confidence heuristics from original image:
   - Local intensity variance (heterogeneity)
   - Gradient magnitude (edge strength)
   - Boundary proximity (distance to pore↔solid boundaries)
4. Mark pixels as ignored if any of the uncertainty criteria is met (logical OR)
5. Preserve confident regions unchanged

All operations are deterministic (no randomness, no learning).

## Assumptions
- Each original slice and its segmentation share the same slice identifier (filename stem before `_pores_`); files without a partner are skipped with a warning
- Multi-Otsu segmentation has class values 0, 1, 2, ...
- Original images are grayscale uint8 or uint16
- No missing files or mismatched dimensions

## Dependencies
numpy, scipy, scikit-image, tifffile, Pillow (fallback reader for non-TIFF files)

In [None]:
import os
from pathlib import Path
import numpy as np
from scipy import ndimage
from skimage import filters, morphology
from skimage.util import img_as_ubyte, img_as_float32
import tifffile
from tifffile import TiffFile

## CONFIGURATION

In [11]:
# Input directories
INPUT_ORIGINAL_DIR = "/content/drive/MyDrive/soil_microCT_images/ROI/rehovot_ROI_8bit"
INPUT_MULTIOTSU_DIR = "/content/drive/MyDrive/soil_microCT_images/ROI/multiotsu_pores_outputs/rehovot_ROI_8bit/pores_class0_plus_1"

# Output directory
OUTPUT_DIR = "/content/drive/MyDrive/soil_microCT_images/ROI/weak_supervision_masks/rehovot_ROI_8bit"

# Label written into the output mask for ignored/uncertain pixels
IGNORE_LABEL = 255

# Class label for pores/air (typically 0 = darkest class in Multi-Otsu)
# All other labels are considered solid material
PORE_LABEL = 0

# Confidence thresholds as percentiles (0-100)
# A pixel is marked uncertain when its value is ABOVE the given percentile of
# the slice, so LOWER percentiles ignore MORE pixels (90 -> top ~10%, 80 -> top ~20%).
# A value of 0 disables the corresponding criterion.
# These are applied to per-slice statistics (adaptive)
VARIANCE_PERCENTILE = 90.0       # Percentile for local variance (90 = top 10% marked uncertain)
GRADIENT_PERCENTILE = 90.0       # Percentile for gradient magnitude (90 = top 10% marked uncertain)
BOUNDARY_DISTANCE_THRESHOLD = 3  # Radius (px) by which the pore↔solid boundary rim is grown; 0 disables
                                 # NOTE(review): the rim itself is already 1 px wide on each side, so
                                 # the ignored band is wider than this value — confirm this is intended

# Side length (pixels) of the square window used for local statistics
LOCAL_WINDOW_SIZE = 7

# File extensions to process (compared against the lower-cased file suffix)
EXTS = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"}

## HELPER FUNCTIONS

In [12]:
def ensure_dir(p):
    """Make sure directory ``p`` exists, creating missing parent folders too."""
    target = Path(p)
    target.mkdir(exist_ok=True, parents=True)


def list_images(folder):
    """
    Return the sorted paths (as strings) of the image files directly in ``folder``.

    A regular file counts as an image when its lower-cased suffix is in ``EXTS``.
    If the folder does not exist, a warning is printed and an empty list is returned.
    """
    root = Path(folder)
    if not root.exists():
        print(f"Warning: folder does not exist: {root}")
        return []
    return sorted(
        str(entry)
        for entry in root.iterdir()
        if entry.is_file() and entry.suffix.lower() in EXTS
    )


def extract_slice_identifier(filename):
    """
    Return the slice identifier encoded in a filename.

    The identifier is the file stem (final extension removed) cut off at the
    first '_pores_' marker; a stem without the marker is returned unchanged.
    """
    # 'roi_0660_slice0780_norm8_pores_class0plus1.png' -> 'roi_0660_slice0780_norm8'
    # 'roi_0660_slice0780_norm8.tif'                   -> 'roi_0660_slice0780_norm8'
    stem = Path(filename).stem
    prefix, _marker, _rest = stem.partition("_pores_")
    return prefix


def load_image(path):
    """
    Load an image file as a 2-D grayscale numpy array.

    tifffile is tried first; when it cannot read the file (for example PNG,
    JPEG or BMP input) Pillow is used as a fallback reader.

    Parameters
    ----------
    path : str or path-like
        Image file to read.

    Returns
    -------
    ndarray
        2-D array. Grayscale data keeps the dtype delivered by the reader
        (e.g. uint8 or uint16); colour data is reduced to one channel.

    Raises
    ------
    ValueError
        If tifffile returns an array that is neither 2-D nor a channel-last
        colour image with 3 or 4 channels (e.g. a multi-page stack).
    """
    try:
        # Try tifffile first (handles 16-bit TIFF well)
        img = tifffile.imread(path)
    except Exception:
        # Not a TIFF tifffile can read -> fall through to Pillow below.
        img = None

    if img is not None:
        if img.ndim == 2:
            return img
        if img.ndim == 3 and img.shape[-1] in (3, 4):
            # Average the colour channels only; an alpha channel carries no
            # intensity information and must not leak into the gray value.
            return np.mean(img[..., :3], axis=-1).astype(img.dtype)
        # Averaging over the last axis of anything else (e.g. pages x H x W)
        # would silently produce a meaningless array, so fail loudly instead.
        raise ValueError(
            f"Unsupported image shape {img.shape} in '{path}': "
            "expected a 2-D grayscale or channel-last RGB(A) image"
        )

    # Fallback: Pillow
    from PIL import Image
    with Image.open(path) as pil_img:  # context manager closes the file handle
        mode = pil_img.mode
        # 'L' is already 8-bit grayscale. 'I', 'I;16*' and 'F' are high bit-depth
        # grayscale modes: convert('L') does not rescale them and would squash
        # the data into 0..255, so keep them at their native depth.
        if mode == 'L' or mode == 'F' or mode.startswith('I'):
            return np.array(pil_img)
        return np.array(pil_img.convert('L'))


In [13]:
def compute_local_variance(image, window_size):
    """
    Per-pixel intensity variance inside a ``window_size`` box neighbourhood.

    The image is converted with ``img_as_float32`` and the variance is obtained
    from two box (uniform) filters as E[X^2] - E[X]^2. Large values flag
    heterogeneous, and therefore uncertain, regions.
    """
    img = img_as_float32(image)

    # E[X^2] and (E[X])^2 over the local window
    mean_of_squares = ndimage.uniform_filter(np.square(img), size=window_size)
    square_of_mean = np.square(ndimage.uniform_filter(img, size=window_size))

    # Round-off can push the difference slightly below zero; clamp it.
    return np.maximum(mean_of_squares - square_of_mean, 0)


def compute_gradient_magnitude(image):
    """
    Sobel gradient magnitude of ``image`` (converted with ``img_as_float32``).

    Large values correspond to strong edges, which may be uncertain boundaries.
    """
    img = img_as_float32(image)

    # Derivative along columns (axis=1) and along rows (axis=0)
    d_cols = ndimage.sobel(img, axis=1)
    d_rows = ndimage.sobel(img, axis=0)

    return np.sqrt(d_cols**2 + d_rows**2)


def compute_boundary_mask(segmentation, pore_label, distance_threshold):
    """
    Boolean mask of the uncertainty band around pore↔solid interfaces.

    Only transitions between ``pore_label`` and any other label are used;
    boundaries between two non-pore labels are not marked.

    The band is built from a one-pixel rim on each side of the interface
    (non-pore pixels touching a pore, and pore pixels touching a non-pore),
    which is then grown with a disk of radius ``distance_threshold`` when that
    value is positive.
    """
    is_pore = segmentation == pore_label

    # Non-pore pixels gained by dilating the pore region = solid side of the interface
    solid_rim = morphology.binary_dilation(is_pore) & ~is_pore

    # Pore pixels lost by eroding the pore region = pore side of the interface
    pore_rim = is_pore & ~morphology.binary_erosion(is_pore)

    band = solid_rim | pore_rim

    # Widen the rim into the final uncertainty zone
    if distance_threshold > 0:
        band = morphology.binary_dilation(band, morphology.disk(distance_threshold))

    return band

In [14]:
def generate_weak_supervision_mask(original, multiotsu_seg, config):
    """
    Generate weak-supervision mask by marking uncertain regions.

    A pixel is overwritten with the ignore label when at least one enabled
    criterion fires: local variance above its per-slice percentile, gradient
    magnitude above its per-slice percentile, or membership in the
    pore↔solid boundary band. All other pixels keep their Multi-Otsu label.

    Parameters:
    -----------
    original : ndarray
        Original grayscale image (2-D)
    multiotsu_seg : ndarray
        Multi-Otsu segmentation (class labels: 0, 1, 2, ...), same shape as
        ``original``
    config : dict
        Configuration dictionary with the keys 'ignore_label', 'pore_label',
        'variance_percentile', 'gradient_percentile',
        'boundary_distance_threshold' and 'local_window_size'.
        A percentile or distance of 0 disables that criterion.

    Returns:
    --------
    mask : ndarray (uint8)
        Weak-supervision mask with confident labels and ignored regions

    Raises:
    -------
    ValueError
        If ``original`` is not 2-D, the shapes differ, or a label / the
        ignore label cannot be represented in uint8.
    """
    if original.ndim != 2:
        raise ValueError(f"Expected a 2-D grayscale image, got shape {original.shape}")
    if multiotsu_seg.shape != original.shape:
        raise ValueError(
            f"Shape mismatch: original {original.shape} vs segmentation {multiotsu_seg.shape}"
        )
    h, w = original.shape

    ignore_label = config['ignore_label']
    if not 0 <= ignore_label <= 255:
        raise ValueError(f"ignore_label {ignore_label} does not fit into a uint8 mask")

    if multiotsu_seg.size:
        # The mask is stored as uint8: out-of-range labels would wrap around silently.
        seg_min, seg_max = multiotsu_seg.min(), multiotsu_seg.max()
        if seg_min < 0 or seg_max > 255:
            raise ValueError(
                f"Segmentation labels span [{seg_min}, {seg_max}] and do not fit into uint8"
            )
        # A real class that equals the ignore label becomes indistinguishable from "ignored".
        if np.any(multiotsu_seg == ignore_label):
            print(f"  Warning: segmentation already contains the ignore label ({ignore_label}); "
                  "those pixels will be treated as ignored")

    # Initialize mask with Multi-Otsu labels (astype already returns a new array)
    mask = multiotsu_seg.astype(np.uint8)

    # Compute confidence heuristics
    print("  Computing local variance...", end=" ")
    local_var = compute_local_variance(original, config['local_window_size'])
    print("done")

    print("  Computing gradient magnitude...", end=" ")
    grad_mag = compute_gradient_magnitude(original)
    print("done")

    print("  Computing pore↔solid boundary mask...", end=" ")
    boundary = compute_boundary_mask(multiotsu_seg, config['pore_label'], config['boundary_distance_threshold'])
    print("done")

    # Compute adaptive thresholds from per-slice statistics (percentile-based)
    # This accounts for slice-to-slice variation and float32 normalization
    use_variance = config['variance_percentile'] > 0
    use_gradient = config['gradient_percentile'] > 0
    var_thresh = np.percentile(local_var, config['variance_percentile']) if use_variance else np.inf
    grad_thresh = np.percentile(grad_mag, config['gradient_percentile']) if use_gradient else np.inf

    print(f"  Adaptive variance threshold: {var_thresh:.6f}")
    print(f"  Adaptive gradient threshold: {grad_thresh:.6f}")

    # Identify uncertain regions (logical OR of all criteria)
    uncertain = np.zeros((h, w), dtype=bool)

    # High variance regions (heterogeneous texture)
    if use_variance:
        high_var = local_var > var_thresh
        uncertain |= high_var
        print(f"  High variance pixels: {high_var.sum()} ({100*high_var.mean():.2f}%)")

    # High gradient regions (strong edges)
    if use_gradient:
        high_grad = grad_mag > grad_thresh
        uncertain |= high_grad
        print(f"  High gradient pixels: {high_grad.sum()} ({100*high_grad.mean():.2f}%)")

    # Near pore↔solid boundaries (physically ambiguous interfaces)
    if config['boundary_distance_threshold'] > 0:
        uncertain |= boundary
        print(f"  Pore↔solid boundary pixels: {boundary.sum()} ({100*boundary.mean():.2f}%)")

    # Mark uncertain regions with ignore label
    mask[uncertain] = ignore_label

    total_ignored = uncertain.sum()
    print(f"  Total ignored pixels: {total_ignored} ({100*total_ignored/(h*w):.2f}%)")

    return mask

## MAIN PROCESSING

In [15]:
# Banner (three lines followed by a blank line)
print("=" * 60, "WEAK-SUPERVISION MASK GENERATION", "=" * 60, "", sep="\n")

# Configuration dictionary handed to generate_weak_supervision_mask
config = dict(
    ignore_label=IGNORE_LABEL,
    pore_label=PORE_LABEL,
    variance_percentile=VARIANCE_PERCENTILE,
    gradient_percentile=GRADIENT_PERCENTILE,
    boundary_distance_threshold=BOUNDARY_DISTANCE_THRESHOLD,
    local_window_size=LOCAL_WINDOW_SIZE,
)

# Echo the effective settings, one per line, followed by a blank line
print("Configuration:")
print(
    f"  Input original dir: {INPUT_ORIGINAL_DIR}",
    f"  Input Multi-Otsu dir: {INPUT_MULTIOTSU_DIR}",
    f"  Output dir: {OUTPUT_DIR}",
    f"  Ignore label: {IGNORE_LABEL}",
    f"  Pore label: {PORE_LABEL}",
    f"  Variance percentile: {VARIANCE_PERCENTILE}",
    f"  Gradient percentile: {GRADIENT_PERCENTILE}",
    f"  Boundary distance: {BOUNDARY_DISTANCE_THRESHOLD} px (pore↔solid only)",
    f"  Local window size: {LOCAL_WINDOW_SIZE} px",
    "",
    sep="\n",
)

WEAK-SUPERVISION MASK GENERATION

Configuration:
  Input original dir: /content/drive/MyDrive/soil_microCT_images/ROI/rehovot_ROI_8bit
  Input Multi-Otsu dir: /content/drive/MyDrive/soil_microCT_images/ROI/multiotsu_pores_outputs/rehovot_ROI_8bit/pores_class0_plus_1
  Output dir: /content/drive/MyDrive/soil_microCT_images/ROI/weak_supervision_masks/rehovot_ROI_8bit
  Ignore label: 255
  Pore label: 0
  Variance percentile: 90.0
  Gradient percentile: 90.0
  Boundary distance: 3 px (pore↔solid only)
  Local window size: 7 px



In [16]:
# Create output directory
ensure_dir(OUTPUT_DIR)

# List raw image file paths from both directories
original_files_list = list_images(INPUT_ORIGINAL_DIR)
multiotsu_files_list = list_images(INPUT_MULTIOTSU_DIR)

print(f"Found {len(original_files_list)} original images (raw count)")
print(f"Found {len(multiotsu_files_list)} Multi-Otsu segmentations (raw count)")
print()


def _index_by_identifier(paths, kind):
    """Map slice identifier -> file path, warning when several files share one identifier.

    As with a plain dict comprehension, the last path (in sorted order) wins,
    but the collision is reported instead of being dropped silently.
    """
    index = {}
    for path in paths:
        identifier = extract_slice_identifier(Path(path).name)
        if identifier in index:
            print(f"Warning: {kind} files '{Path(index[identifier]).name}' and '{Path(path).name}' "
                  f"share identifier '{identifier}'; using '{Path(path).name}'.")
        index[identifier] = path
    return index


# Create maps for identifiers to full paths
original_map = _index_by_identifier(original_files_list, "original image")
multiotsu_map = _index_by_identifier(multiotsu_files_list, "Multi-Otsu segmentation")

print(f"Found {len(original_map)} unique original image identifiers")
print(f"Found {len(multiotsu_map)} unique Multi-Otsu segmentation identifiers")
print()

# Find common identifiers to process (set kept for O(1) membership tests below)
common_identifier_set = original_map.keys() & multiotsu_map.keys()
common_identifiers = sorted(common_identifier_set)

if not common_identifiers:
    print("Error: No common image identifiers found between directories. Exiting.")
    # Set n_pairs to 0 to prevent further processing in the next cell
    n_pairs = 0
else:
    n_pairs = len(common_identifiers)
    print(f"Will process {n_pairs} image pairs based on common identifiers.\n")

    # Print warnings for unmatched files
    for identifier, path in original_map.items():
        if identifier not in common_identifier_set:
            print(f"Warning: Original image '{Path(path).name}' (identifier: '{identifier}') has no matching Multi-Otsu segmentation. Skipping.")
    for identifier, path in multiotsu_map.items():
        if identifier not in common_identifier_set:
            print(f"Warning: Multi-Otsu segmentation '{Path(path).name}' (identifier: '{identifier}') has no matching original image. Skipping.")
    print()

Found 661 original images (raw count)
Found 661 Multi-Otsu segmentations (raw count)

Found 661 unique original image identifiers
Found 661 unique Multi-Otsu segmentation identifiers

Will process 661 image pairs based on common identifiers.




In [17]:
# Process image pairs
# common_identifiers and n_pairs come from the pairing cell; with no pairs the loop body never runs.

# Pillow writes the masks so that the ".png" files really contain PNG data.
# (tifffile.imwrite always produces TIFF content, whatever the file suffix is.)
from PIL import Image

n_saved = 0    # masks written successfully
n_failed = 0   # pairs skipped because loading, validation, generation or saving failed

for i, identifier in enumerate(common_identifiers):
    orig_path = original_map[identifier]
    seg_path = multiotsu_map[identifier]

    orig_name = Path(orig_path).name
    seg_name = Path(seg_path).name

    print(f"[{i+1}/{n_pairs}] Processing identifier: {identifier}")
    print(f"  Original: {orig_name}")
    print(f"  Multi-Otsu: {seg_name}")

    # Load images
    try:
        original = load_image(orig_path)
        multiotsu_seg = load_image(seg_path)
    except Exception as e:
        print(f"  Error loading images: {e}")
        print(f"  Skipping pair for identifier '{identifier}'")
        print()
        n_failed += 1
        continue

    # Check dimensions match
    if original.shape != multiotsu_seg.shape:
        print(f"  Error: Shape mismatch for identifier '{identifier}' ({original.shape} vs {multiotsu_seg.shape})")
        print(f"  Skipping pair for identifier '{identifier}'")
        print()
        n_failed += 1
        continue

    # Generate weak-supervision mask
    try:
        ws_mask = generate_weak_supervision_mask(original, multiotsu_seg, config)
    except Exception as e:
        print(f"  Error generating mask for identifier '{identifier}': {e}")
        print(f"  Skipping pair for identifier '{identifier}'")
        print()
        n_failed += 1
        continue

    # Save output
    # Use the identifier as the base filename for consistency
    out_name = identifier + ".png"
    out_path = os.path.join(OUTPUT_DIR, out_name)

    try:
        # Pillow picks the encoder from the ".png" suffix; PNG is lossless,
        # so label values (including the ignore label) are stored exactly.
        Image.fromarray(ws_mask).save(out_path)
        n_saved += 1
        print(f"  Saved: {out_name}")
    except Exception as e:
        n_failed += 1
        print(f"  Error saving mask for identifier '{identifier}': {e}")

    print()

print("=" * 60)
print(f"COMPLETED: Saved {n_saved} of {n_pairs} image pairs ({n_failed} skipped or failed)")
print(f"Output directory: {OUTPUT_DIR}")
print("=" * 60)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  Multi-Otsu: roi_0304_slice0424_norm8_pores_class0plus1.png
  Computing local variance... done
  Computing gradient magnitude... done
  Computing pore↔solid boundary mask... done
  Adaptive variance threshold: 0.065658
  Adaptive gradient threshold: 1.560863
  High variance pixels: 42250 (10.00%)
  High gradient pixels: 42248 (10.00%)
  Pore↔solid boundary pixels: 304621 (72.10%)
  Total ignored pixels: 312633 (74.00%)
  Saved: roi_0304_slice0424_norm8.png

[306/661] Processing identifier: roi_0305_slice0425_norm8
  Original: roi_0305_slice0425_norm8.tif
  Multi-Otsu: roi_0305_slice0425_norm8_pores_class0plus1.png
  Computing local variance... done
  Computing gradient magnitude... done
  Computing pore↔solid boundary mask... done
  Adaptive variance threshold: 0.065155
  Adaptive gradient threshold: 1.558418
  High variance pixels: 42250 (10.00%)
  High gradient pixels: 42249 (10.00%)
  Pore↔solid boundary pixels: 30388