<a href="https://colab.research.google.com/github/balakg/ipy-vision/blob/main/notebooks/textures/heeger_bergen_texture_synthesis.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Heeger-Bergen Texture Synthesis

This notebook demonstrates an implementation of the **Heeger-Bergen algorithm** [1] for texture synthesis. The core goal of texture synthesis is to create a new image that is perceptually identical to a source sample but is not a literal pixel-for-pixel copy.

---


## How it Works

The algorithm operates on the principle that a texture is characterized by its **first-order statistics**: the histogram of pixel values together with the histograms of filter responses at multiple scales and orientations.

1.  **Color Decorrelation (PCA):** We first transform the RGB image into PCA space to obtain three decorrelated channels. Processing the raw RGB channels separately would cause color bleeding, because they are highly correlated.
2.  **Pyramid Decomposition:** Each channel is decomposed into a **Steerable Pyramid** [2]. This breaks the image down into different spatial frequencies and orientations (e.g., vertical edges vs. horizontal gradients).
3.  **Histogram Matching:** The algorithm iteratively forces the synthesized noise image to match the histogram of the target image at every level and subband of the pyramid.
4.  **Reconstruction:** The pyramid is collapsed back into a spatial image, and the process repeats until the noise transforms into a coherent texture.

[1] Heeger, David J., and James R. Bergen. "Pyramid-based texture analysis/synthesis." In Proceedings of the 22nd annual conference on Computer graphics and interactive techniques, pp. 229-238. 1995.

[2] Simoncelli, Eero P., and William T. Freeman. "The steerable pyramid: A flexible architecture for multi-scale derivative computation." In Proceedings of the International Conference on Image Processing, vol. 3, pp. 444-447. IEEE, 1995.

In [None]:
!pip install pyrtools

In [None]:
import os
import glob
import subprocess
import numpy as np
import matplotlib.pyplot as plt
import pyrtools as pt
import cv2
from skimage import exposure
import ipywidgets as widgets
from ipywidgets import interact, IntSlider, Dropdown
from sklearn.decomposition import PCA
from google.colab import output
output.no_vertical_scroll()


# --- 1. Repository & Path Setup ---
if not os.path.exists("ipy-vision"):
    print(f"Cloning ipy-vision...")
    !git clone --depth 1 https://github.com/balakg/ipy-vision.git

# Define assets directory
assets_dir = os.path.join("ipy-vision", "assets", "textures")

# --- 2. Dynamic Image Discovery ---
def get_image_options(directory):
    """Build the label -> path mapping of the images found in *directory*.

    Files are recognised by their extension (``.jpg``, ``.jpeg``, ``.png``),
    compared case-insensitively so that names such as ``photo.Jpg`` are not
    skipped. Only *directory* itself is scanned, not its sub-directories.

    Args:
        directory: Folder to look in.

    Returns:
        dict: Display label -> file path, inserted in sorted path order. A
        label is the file name without its extension, with ``_`` and ``-``
        replaced by spaces and passed through ``str.capitalize``. When two
        files produce the same label (e.g. ``wood.jpg`` and ``wood.png``) the
        path that sorts last wins. The dict is empty when nothing matches.
    """
    image_suffixes = {'.jpg', '.jpeg', '.png'}
    found_files = [
        path
        for path in glob.glob(os.path.join(directory, '*'))
        if os.path.splitext(path)[1].lower() in image_suffixes
    ]

    # Create dictionary: {"Clean Name": "path/to/file.png"}
    options = {}
    for path in sorted(found_files):
        stem = os.path.splitext(os.path.basename(path))[0]
        label = stem.replace('_', ' ').replace('-', ' ').capitalize()
        options[label] = path

    return options

# --- 3. Heeger-Bergen Logic ---
def match_histogram(source, template):
    """Return *source* with its values remapped to follow *template*'s histogram.

    Both arrays are flattened before being handed to
    ``exposure.match_histograms``; the result is reshaped back to
    ``source.shape``.
    """
    original_shape = source.shape
    flat_result = exposure.match_histograms(source.ravel(), template.ravel())
    return flat_result.reshape(original_shape)

class PCATextureSynthesis:
    """Heeger-Bergen texture synthesis carried out in PCA colour space.

    The whole synthesis runs inside the constructor. Afterwards ``history``
    holds one RGB frame per step: index 0 is the initial noise image and
    index ``i`` is the result after ``i`` iterations.
    """

    def __init__(self, target_rgb, max_iters, height=4, order=3):
        """Analyse *target_rgb* and precompute ``max_iters`` synthesis steps.

        Args:
            target_rgb: Target texture. It is reshaped to ``(-1, 3)`` and its
                first two dimensions are used as the channel plane shape.
            max_iters: Number of synthesis iterations to run.
            height: ``height`` argument forwarded to the steerable pyramids.
            order: ``order`` argument forwarded to the steerable pyramids.
        """
        self.shape = target_rgb.shape
        self.height = height
        self.order = order

        # Fit the PCA on the target's pixels and keep one 2-D plane per
        # principal component.
        self.pca = PCA(n_components=3)
        target_scores = self.pca.fit_transform(target_rgb.reshape(-1, 3))
        plane_shape = self.shape[:2]
        self.target_pca_channels = []
        for component in range(3):
            plane = target_scores[:, component].reshape(plane_shape)
            self.target_pca_channels.append(plane)

        # One reference pyramid per PCA channel; its sub-bands supply the
        # histograms that the synthesis is matched against.
        self.target_pyrs = []
        for plane in self.target_pca_channels:
            reference_pyr = pt.pyramids.SteerablePyramidSpace(
                plane, height=height, order=order, edge_type='reflect1'
            )
            self.target_pyrs.append(reference_pyr)

        # Start from uniform noise, expressed in the same PCA space.
        noise_rgb = np.random.uniform(0, 1, self.shape)
        init_noise_pca = self.pca.transform(noise_rgb.reshape(-1, 3)).reshape(self.shape)
        first_frame = self.pca.inverse_transform(init_noise_pca.reshape(-1, 3))
        self.history = [first_frame.reshape(self.shape)]

        self._precompute(max_iters, init_noise_pca)

    def _precompute(self, max_iters, current_pca_img):
        """Run the synthesis loop, appending one clipped RGB frame per iteration.

        Args:
            max_iters: Number of iterations to perform.
            current_pca_img: Starting image in PCA space; its last axis is
                split into the three channels that are processed separately.
        """
        channels = [current_pca_img[..., idx] for idx in range(3)]
        for _ in range(max_iters):
            updated = []
            for idx, channel in enumerate(channels):
                target_pyr = self.target_pyrs[idx]

                # Match the channel's own histogram first, then every pyramid
                # sub-band against the corresponding target sub-band.
                matched = match_histogram(channel, self.target_pca_channels[idx])
                pyr = pt.pyramids.SteerablePyramidSpace(
                    matched, height=self.height, order=self.order, edge_type='reflect1'
                )
                matched_coeffs = {}
                for key, band in pyr.pyr_coeffs.items():
                    matched_coeffs[key] = match_histogram(band, target_pyr.pyr_coeffs[key])
                pyr.pyr_coeffs = matched_coeffs

                # NOTE(review): the pyramids are built with 'reflect1' edges
                # but collapsed with 'circular' - confirm this is intentional.
                updated.append(pyr.recon_pyr(edge_type='circular'))
            channels = updated

            # Back to RGB; clipped to [0, 1] before being stored.
            stacked = np.stack(channels, axis=-1)
            rgb_frame = self.pca.inverse_transform(stacked.reshape(-1, 3)).reshape(self.shape)
            self.history.append(np.clip(rgb_frame, 0, 1))

# --- 4. Launch ---
# Output area that update_display() prints and plots into.
output_plot = widgets.Output()

# Per-image memo keyed by file path. update_display() stores a
# (target image, synthesis engine) pair here so that switching back to a
# texture does not repeat its synthesis.
cache = {}

# Label -> path mapping offered by the texture dropdown.
image_options = get_image_options(assets_dir)
if not image_options:
    print(f"⚠️ No images found in {assets_dir}. Please add some and refresh.")

def update_display(change=None):
    """Redraw the dashboard for the current dropdown and slider selection.

    The first time a texture is selected its synthesis is run and the result
    is stored in ``cache``; later calls (including slider moves) only re-plot
    the cached frames. All output goes to ``output_plot``.

    Args:
        change: Change notification supplied by ``observe``; unused. It is
            ``None`` for the direct initial call.
    """
    path = image_path.value
    iters = iterations.value

    # Nothing selected (e.g. no textures were discovered): report it instead
    # of handing None to cv2.imread.
    if path is None:
        with output_plot:
            output_plot.clear_output(wait=True)
            print("No texture selected.")
        return

    # 1. Run the synthesis once per image; afterwards it is served from cache.
    if path not in cache:
        # Show the dropdown label for this path, falling back to the path.
        label = next((name for name, p in image_options.items() if p == path), path)
        with output_plot:
            output_plot.clear_output(wait=True)
            print(f"Synthesizing {label}...")

        img_bgr = cv2.imread(path)
        if img_bgr is None:
            # cv2.imread reports an unreadable file by returning None.
            with output_plot:
                output_plot.clear_output(wait=True)
                print(f"Could not read image: {path}")
            return

        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        target_img = np.clip(cv2.resize(img_rgb, (256, 256)) / 255.0, 0, 1)
        # Precompute as many iterations as the slider can select so that
        # engine.history[iters] is always in range.
        cache[path] = (target_img, PCATextureSynthesis(target_img, max_iters=iterations.max))

    target_img, engine = cache[path]

    # 2. Update the plot using the cached history (no synthesis involved).
    with output_plot:
        output_plot.clear_output(wait=True)
        _, axes = plt.subplots(1, 2, figsize=(10, 5))
        axes[0].imshow(target_img)
        axes[0].set_title("Sample Texture")
        axes[1].imshow(engine.history[iters])
        axes[1].set_title(f"Synthesized (Iteration {iters})")
        for axis in axes:
            axis.axis('off')
        plt.show()

# --- 5. Widget Setup ---
# NOTE(review): the title colour is #ffffff, which presumably assumes a dark
# notebook theme - confirm it stays readable on a light background.
title_html = """
<h2 style="color: #ffffff; margin-top: 0px; margin-bottom: 5px; font-family: sans-serif;">
    Heeger-Bergen Texture Synthesis Demo
</h2>
<p style="color: #7f8c8d; margin-top: 0px; font-family: sans-serif; margin-bottom: 15px;">
    Adjust slider to visualize synthesized output based on iteration.
</p>
"""
title_widget = widgets.HTML(title_html)

# Controls: which texture to show and which iteration of its history.
image_path = Dropdown(options=image_options, description='Texture:')
iterations = IntSlider(min=0, max=6, step=1, value=0, description='Iteration:')

# Both controls redraw through the same callback; update_display() itself
# decides whether a synthesis run is needed for the selected texture.
for control in (image_path, iterations):
    control.observe(update_display, names='value')

# Draw once for the initial selection.
update_display()

# Layout: a bordered control panel stacked above the plot output.
controls_layout = widgets.Layout(width='fit-content', border='1px solid gray', padding='10px')
controls = widgets.VBox([title_widget, image_path, iterations], layout=controls_layout)

dashboard_layout = widgets.Layout(width='fit-content', padding='10px')
dashboard = widgets.VBox([controls, output_plot], layout=dashboard_layout)

display(dashboard)