In [4]:
import numpy as np
from collections import defaultdict

# -----------------------------
# Step 1: Initialize regions
# -----------------------------
def initialize_regions(image):
    """Give every pixel of a 2-D image its own region.

    Args:
        image: 2-D array; ``image.shape`` is unpacked as ``(h, w)``.

    Returns:
        ``(labels, means)``. ``labels`` is an ``(h, w)`` array numbering the
        pixels 0 .. h*w-1 in row-major order. ``means`` is a flat float64
        copy of the pixel values, indexed by those labels.
    """
    rows, cols = image.shape
    pixel_ids = np.arange(rows * cols).reshape(rows, cols)
    # astype() copies, so the per-region means never alias the caller's image.
    pixel_values = image.astype(np.float64).reshape(-1)
    return pixel_ids, pixel_values

# -----------------------------
# Step 2: Compute connection numbers
# -----------------------------
def compute_connection_numbers(labels):
    """Count 4-connected boundary pixel pairs between differently labelled regions.

    Args:
        labels: 2-D label array; ``labels.shape`` is unpacked as ``(h, w)``.

    Returns:
        ``defaultdict(int)`` mapping ``frozenset({a, b})`` to the number of
        horizontally or vertically adjacent pixel pairs whose labels are ``a``
        and ``b`` (with ``a != b``).
    """
    height, width = labels.shape
    counts = defaultdict(int)
    # Look right, then down, from every pixel in raster order. This visits
    # each 4-connected pixel pair exactly once. region_fusion iterates the keys
    # in insertion order, so this visiting order also fixes its merge order.
    offsets = ((0, 1), (1, 0))
    for row in range(height):
        for col in range(width):
            here = labels[row, col]
            for d_row, d_col in offsets:
                r, c = row + d_row, col + d_col
                if r >= height or c >= width:
                    continue
                there = labels[r, c]
                if there == here:
                    continue
                counts[frozenset((here, there))] += 1
    return counts

# -----------------------------
# Step 3: Compute fusion cost
# (Very simplified version of paper's Equation 12)
# -----------------------------
def fusion_cost(mean_i, mean_j, size_i, size_j, beta, c_ij):
    """Return the energy change of fusing regions i and j.

    The data part is the squared error introduced by replacing both region
    means with their size-weighted mean. It is algebraically equal to
    ``size_i*size_j/(size_i+size_j) * (mean_i - mean_j)**2``. The saved
    boundary penalty ``beta * c_ij`` is subtracted from it. A negative result
    means fusing lowers the energy; region_fusion fuses when the cost is < 0.

    Args:
        mean_i, mean_j: mean intensities of the two regions.
        size_i, size_j: pixel counts of the two regions (their sum must be
            non-zero).
        beta: weight of the boundary penalty.
        c_ij: number of boundary pixel pairs shared by the two regions.
    """
    total = size_i + size_j
    pooled = (mean_i * size_i + mean_j * size_j) / total
    dev_i = mean_i - pooled
    dev_j = mean_j - pooled
    squared_error = dev_i ** 2 * size_i + dev_j ** 2 * size_j
    penalty = beta * c_ij
    return squared_error - penalty

# -----------------------------
# Step 4: Main region fusion loop
# -----------------------------
def region_fusion(image, beta=50.0, max_iterations=20):
    """Smooth a 2-D image by greedily fusing adjacent regions.

    Every pixel starts as its own region. Each pass visits the adjacent
    region pairs produced by ``compute_connection_numbers``, in that dict's
    insertion order. A pair is fused when ``fusion_cost`` is negative.
    Passes repeat until one fuses nothing or ``max_iterations`` is reached.

    Args:
        image: 2-D array; ``image.shape`` is unpacked as ``(h, w)``.
        beta: weight of the boundary penalty passed to ``fusion_cost``;
            larger values make fusion more likely.
        max_iterations: upper bound on the number of passes.

    Returns:
        A float64 array shaped like ``image``. Each pixel holds the mean
        intensity of the region it belongs to after fusion.
    """
    h, w = image.shape
    labels, means = initialize_regions(image)
    conn = compute_connection_numbers(labels)

    # Region sizes in pixels; a size of 0 marks a region that was absorbed.
    sizes = np.ones(h * w, dtype=np.int32)
    # parent[r] is the region that r was absorbed into (parent[r] == r while
    # r is alive). Pixel relabelling is deferred to the end of each pass, so
    # a merge costs O(1) instead of a full-image scan.
    parent = np.arange(h * w)

    for _ in range(max_iterations):
        merged = False
        # Go through all neighbouring pairs.
        for pair in list(conn.keys()):
            i, j = list(pair)
            # `conn` is only rebuilt between passes, so it may still name
            # regions absorbed earlier in this pass. Their stats are zeroed
            # and their adjacency is stale. Evaluating them would "fuse" an
            # empty region (the cost is always negative) and could revive a
            # dead label, letting non-adjacent regions merge. Skip such pairs
            # until the adjacency is recomputed.
            if sizes[i] == 0 or sizes[j] == 0:
                continue
            c_ij = conn[pair]
            if c_ij == 0:
                continue
            cost = fusion_cost(means[i], means[j], sizes[i], sizes[j], beta, c_ij)
            if cost < 0:
                # Merge j into i (direction is arbitrary).
                new_size = sizes[i] + sizes[j]
                new_mean = (means[i] * sizes[i] + means[j] * sizes[j]) / new_size
                sizes[i] = new_size
                means[i] = new_mean
                sizes[j] = 0
                means[j] = 0
                parent[j] = i
                merged = True

        # Nothing changed: labels and adjacency are already current.
        if not merged:
            break

        # Collapse absorption chains (a -> b -> c) by pointer jumping, so every
        # label maps straight to its surviving region. Then relabel all pixels
        # in one vectorised step.
        while True:
            jumped = parent[parent]
            if np.array_equal(jumped, parent):
                break
            parent = jumped
        labels = parent[labels]
        # Recompute connection numbers after merges.
        conn = compute_connection_numbers(labels)

    # Final reconstruction: each pixel takes the mean of its final region.
    return means[labels]
"""
# -----------------------------
# Example usage
# -----------------------------
if __name__ == "__main__":
    # Load a grayscale image
    img = cv2.imread("example.jpg", cv2.IMREAD_GRAYSCALE)
    result = region_fusion(img, beta=30.0, max_iterations=30)

    # Display results
    cv2.imshow("Original", img)
    cv2.imshow("Region Fusion Smoothed", result.astype(np.uint8))
    cv2.waitKey(0)
    cv2.destroyAllWindows()
"""

'\n# -----------------------------\n# Example usage\n# -----------------------------\nif __name__ == "__main__":\n    # Load a grayscale image\n    img = cv2.imread("example.jpg", cv2.IMREAD_GRAYSCALE)\n    result = region_fusion(img, beta=30.0, max_iterations=30)\n\n    # Display results\n    cv2.imshow("Original", img)\n    cv2.imshow("Region Fusion Smoothed", result.astype(np.uint8))\n    cv2.waitKey(0)\n    cv2.destroyAllWindows()\n'