# Raw Mix vs Normalized Mix (Demonstration Only)

This notebook contrasts naive algebraic mixing with the normalized mixing pipeline used in this project. The goal is to make it visually and quantitatively clear why normalization is required for EBSD pattern mixing.

**WARNING**

The raw-mix result shown here is **non-physical** and **not representative** of experimental EBSD mixing. It is included **only** for demonstration. **Do not use** the raw-mix output for training, validation, or any production workflow.

## Purpose

We compare:

- **Raw algebraic mix (no normalization):** `C_raw = x*A + y*B`, where `x` is the weight of A and `y = 1 - x` is the weight of B
- **Normalized mix (project pipeline):** `C_norm = normalize(x*A + y*B)`

The normalized mix uses the same normalization logic as the main pipeline (mask-aware, min-max normalization).

## Raw-mix implementation detail (demo only)

For the non-physical demonstration, inputs are converted to **8-bit (0-255)** and mixed using **integer arithmetic only**. Weights are quantized to integers in the 0-255 range that sum to 255; the weighted sum is accumulated in a wider integer type, divided by 255 with rounding, and then clipped back to uint8. Any float conversion happens **only after** mixing for visualization.

In [None]:
from __future__ import annotations

from pathlib import Path
import sys

import numpy as np
import matplotlib.pyplot as plt

try:
    import ipywidgets as widgets
    from IPython.display import display
    HAS_WIDGETS = True
except ImportError:
    HAS_WIDGETS = False

try:
    import pandas as pd
    HAS_PANDAS = True
except ImportError:
    HAS_PANDAS = False


def find_repo_root(start: Path) -> Path:
    """Locate the repository root at or above *start*.

    A directory qualifies when it contains both a ``src`` and a ``data``
    entry. The search begins at *start* and then walks up through its
    parents; *start* itself is returned when no directory qualifies.
    """

    def looks_like_root(folder: Path) -> bool:
        return (folder / 'src').exists() and (folder / 'data').exists()

    lineage = (start, *start.parents)
    return next((folder for folder in lineage if looks_like_root(folder)), start)


# Resolve the repository root from the current working directory and put it
# on sys.path (once) before the ``src`` package imports below run.
REPO_ROOT = find_repo_root(Path.cwd())
if str(REPO_ROOT) not in sys.path:
    sys.path.insert(0, str(REPO_ROOT))

# Project-local helpers; deliberately placed after the sys.path update above.
from src.utils.io import read_image_16bit, to_float01
from src.preprocessing.mask import build_mask_with_metadata, apply_circular_mask
from src.preprocessing.normalise import normalize_image
from src.generation.mix import mix_then_normalize, mix_normalize_then_mix

# Default input pair (A and B) used by the demo cells.
DEFAULT_DIR = REPO_ROOT / 'data' / 'raw' / 'Double Pattern Data' / 'Good Pattern'
DEFAULT_A = DEFAULT_DIR / 'Perfect_BCC-1.bmp'
DEFAULT_B = DEFAULT_DIR / 'Perfect_FCC-1.bmp'

# Demo settings: weight applied to image A, which branch of normalized_mix()
# to run, and the method name forwarded to the project normalization helpers.
WEIGHT_A = 0.50
NORMALIZED_PIPELINE = 'mix_then_normalize'
NORMALIZE_METHOD = 'min_max'

# Echo the resolved paths so a wrong working directory is easy to spot.
print(f'Repo root: {REPO_ROOT}')
print(f'Default A: {DEFAULT_A}')
print(f'Default B: {DEFAULT_B}')


In [None]:
def list_candidates(directory: Path) -> list[Path]:
    """Return the sorted entries of *directory* that carry an image suffix.

    The suffix comparison is case-insensitive. A directory that does not
    exist yields an empty list.
    """
    image_suffixes = ('.bmp', '.png', '.tif', '.tiff', '.jpg', '.jpeg')
    if directory.exists():
        matches = [
            entry
            for entry in directory.iterdir()
            if entry.suffix.lower() in image_suffixes
        ]
        matches.sort()
        return matches
    return []


def select_default_paths() -> tuple[Path, Path]:
    """Choose the A/B input pair for the demo.

    Returns ``(DEFAULT_A, DEFAULT_B)`` when both files exist. Otherwise the
    first two entries of ``list_candidates(DEFAULT_DIR)`` are used.

    Raises:
        FileNotFoundError: when fewer than two candidates are available.
    """
    if not (DEFAULT_A.exists() and DEFAULT_B.exists()):
        fallback = list_candidates(DEFAULT_DIR)
        if len(fallback) < 2:
            raise FileNotFoundError('No suitable A/B inputs found. Update A_PATH and B_PATH.')
        first, second = fallback[:2]
        return first, second
    return DEFAULT_A, DEFAULT_B


def load_image_u16(path: Path) -> np.ndarray:
    """Read the image at *path* through the project's ``read_image_16bit`` helper."""
    pixels = read_image_16bit(path)
    return pixels


def to_uint8(image_u16: np.ndarray) -> np.ndarray:
    scaled = (image_u16.astype(np.uint32) * 255 + 32767) // 65535
    return scaled.astype(np.uint8)


def to_float(image_u16: np.ndarray) -> np.ndarray:
    """Convert *image_u16* with the project's ``to_float01`` helper, as float32."""
    converted = to_float01(image_u16)
    return converted.astype(np.float32)


def build_mask(image: np.ndarray) -> np.ndarray:
    """Build a mask for *image* and print the detection metadata.

    Delegates to the project's ``build_mask_with_metadata`` with existing-mask
    detection enabled, then reports the ``already_masked`` and
    ``outside_zero_fraction`` metadata entries before returning the mask.
    """
    detection_options = {
        'detect_existing': True,
        'zero_tolerance': 5e-4,
        'outside_zero_fraction': 0.98,
    }
    mask, details = build_mask_with_metadata(image, **detection_options)
    report = (
        ('Mask detected:', 'already_masked'),
        ('Outside-zero fraction:', 'outside_zero_fraction'),
    )
    for label, key in report:
        print(label, details.get(key))
    return mask


def image_stats(image: np.ndarray, mask: np.ndarray | None) -> dict[str, float]:
    """Summarise the intensity distribution of *image*.

    When *mask* is given, only the pixels it selects are considered;
    otherwise every pixel is used. The result holds min, max, mean, std,
    the 1st and 99th percentiles, and the two corresponding spreads.
    """
    values = image.ravel() if mask is None else image[mask]
    low_pct, high_pct = np.percentile(values, [1, 99])
    lowest = values.min()
    highest = values.max()

    summary: dict[str, float] = {}
    summary['min'] = float(lowest)
    summary['max'] = float(highest)
    summary['mean'] = float(values.mean())
    summary['std'] = float(values.std())
    summary['p1'] = float(low_pct)
    summary['p99'] = float(high_pct)
    # Subtract in the array's own dtype before converting, like the extrema.
    summary['range'] = float(highest - lowest)
    summary['p99-p1'] = float(high_pct - low_pct)
    return summary


def saturation_stats(image_u8: np.ndarray, mask: np.ndarray | None) -> dict[str, float]:
    """Report the fraction of pixels equal to 0 and equal to 255.

    Only pixels selected by *mask* are counted when a mask is supplied. The
    denominator is floored at 1 so an empty selection yields 0.0 fractions.
    """
    pixels = image_u8.ravel() if mask is None else image_u8[mask]
    denominator = max(int(pixels.size), 1)
    zero_count = float(np.count_nonzero(pixels == 0))
    full_count = float(np.count_nonzero(pixels == 255))
    return {
        'zero_frac': zero_count / denominator,
        'sat_255_frac': full_count / denominator,
    }


def raw_mix_uint8(image_a_u8: np.ndarray, image_b_u8: np.ndarray, weight_a: float) -> np.ndarray:
    """Blend two 8-bit images with integer arithmetic only (demo, non-physical).

    The weight of A is quantized to ``round(weight_a * 255)`` and the weight
    of B is the remainder to 255. The weighted sum is divided by 255 with
    rounding and clipped to 0-255.

    Args:
        image_a_u8: First image, values in 0-255.
        image_b_u8: Second image, broadcast-compatible with ``image_a_u8``.
        weight_a: Weight of A. Values in [0, 1] give a convex blend. Values
            outside that range make one integer weight negative or larger
            than 255; the result is then simply clipped to 0-255.

    Returns:
        The mixed image as uint8.
    """
    weight_a_int = int(round(weight_a * 255))
    weight_b_int = 255 - weight_a_int
    # Accumulate in signed 64-bit. With unsigned accumulators a negative
    # integer weight (weight_a outside [0, 1]) raises OverflowError under
    # NumPy 2 promotion rules, so the clip below was never reached.
    mixed = (
        weight_a_int * image_a_u8.astype(np.int64)
        + weight_b_int * image_b_u8.astype(np.int64)
    )
    # +127 before the floor division rounds to nearest instead of truncating.
    mixed = (mixed + 127) // 255
    return np.clip(mixed, 0, 255).astype(np.uint8)


def normalized_mix(
    image_a: np.ndarray,
    image_b: np.ndarray,
    weight_a: float,
    pipeline: str,
    mask: np.ndarray | None,
    normalize_method: str,
) -> np.ndarray:
    """Mix two images through one of the project's normalized pipelines.

    Args:
        image_a: First input image.
        image_b: Second input image.
        weight_a: Weight given to ``image_a``, forwarded to the mix helper.
        pipeline: ``'mix_then_normalize'`` or ``'normalize_then_mix'``.
        mask: Optional mask forwarded to the helpers and re-applied to the
            result when present.
        normalize_method: Method name forwarded to the normalization helpers.

    Returns:
        The mixed image clipped to [0, 1] as float32.

    Raises:
        ValueError: if ``pipeline`` is not one of the two supported names.
    """
    if pipeline not in ('mix_then_normalize', 'normalize_then_mix'):
        raise ValueError(f'Unknown pipeline: {pipeline}')

    # Keyword arguments common to both project mix helpers.
    shared_kwargs = {
        'normalize_method': normalize_method,
        'mask': mask,
        'normalize_smart': True,
    }

    if pipeline == 'mix_then_normalize':
        blended = mix_then_normalize(
            image_a, image_b, weight_a=weight_a, **shared_kwargs
        )
    else:
        # Normalize each input on its own first, then mix without a second
        # normalization pass.
        pre_normalized = [
            normalize_image(source, method=normalize_method, mask=mask, smart_minmax=True)
            for source in (image_a, image_b)
        ]
        blended = mix_normalize_then_mix(
            pre_normalized[0],
            pre_normalized[1],
            weight_a=weight_a,
            normalize_after_mix=False,
            **shared_kwargs,
        )

    if mask is not None:
        blended = apply_circular_mask(blended, mask)
    return np.clip(blended, 0.0, 1.0).astype(np.float32)


def plot_comparison(
    image_a: np.ndarray,
    image_b: np.ndarray,
    c_raw_u8: np.ndarray,
    c_norm: np.ndarray,
    mask: np.ndarray | None,
    title: str,
) -> None:
    """Plot the inputs, both mixes, their difference, histograms and stats.

    Produces three figures (a 2x2 image grid, a difference map and an
    overlaid histogram) and then shows summary statistics, as a DataFrame
    when pandas is available and as printed dicts otherwise.

    Args:
        image_a: Input A, displayed on a 0-1 grey scale.
        image_b: Input B, displayed on a 0-1 grey scale.
        c_raw_u8: Raw 8-bit mix; divided by 255 here for display only.
        c_norm: Normalized mix, displayed on a 0-1 grey scale.
        mask: Optional mask selecting the pixels used for the histogram
            and statistics; all pixels are used when it is None.
        title: Figure-level title of the image grid.
    """
    # Float conversion is for visualization only; the mixing stayed integer.
    c_raw = c_raw_u8.astype(np.float32) / 255.0

    fig, axes = plt.subplots(2, 2, figsize=(10, 8))
    axes = axes.ravel()
    axes[0].imshow(image_a, cmap='gray', vmin=0, vmax=1)
    axes[0].set_title('A (input)')
    axes[0].axis('off')

    axes[1].imshow(image_b, cmap='gray', vmin=0, vmax=1)
    axes[1].set_title('B (input)')
    axes[1].axis('off')

    axes[2].imshow(c_raw, cmap='gray', vmin=0, vmax=1)
    axes[2].set_title('C_raw (8-bit, no normalization)')
    axes[2].axis('off')

    axes[3].imshow(c_norm, cmap='gray', vmin=0, vmax=1)
    axes[3].set_title('C_norm (normalized)')
    axes[3].axis('off')

    fig.suptitle(title)
    fig.tight_layout()
    plt.show()

    # Signed difference on a fixed symmetric colour scale.
    diff = c_norm - c_raw
    plt.figure(figsize=(6, 5))
    plt.imshow(diff, cmap='coolwarm', vmin=-0.5, vmax=0.5)
    plt.title('Difference (C_norm - C_raw)')
    plt.axis('off')
    plt.colorbar()
    plt.show()

    if mask is not None:
        raw_data = c_raw[mask]
        norm_data = c_norm[mask]
    else:
        raw_data = c_raw.ravel()
        norm_data = c_norm.ravel()
    plt.figure(figsize=(6, 4))
    plt.hist(raw_data, bins=100, alpha=0.6, label='C_raw (scaled)')
    plt.hist(norm_data, bins=100, alpha=0.6, label='C_norm')
    plt.title('Histogram inside mask')
    plt.xlabel('Intensity')
    plt.ylabel('Count')
    plt.legend()
    plt.show()

    stats_raw = image_stats(c_raw, mask)
    stats_raw.update(saturation_stats(c_raw_u8, mask))
    stats_norm = image_stats(c_norm, mask)
    stats_norm.update({
        'near_zero_frac': float((norm_data < 0.01).mean()),
        'near_one_frac': float((norm_data > 0.99).mean()),
    })
    if HAS_PANDAS:
        df = pd.DataFrame([stats_raw, stats_norm], index=['C_raw_u8', 'C_norm'])
        # The module-level ``display`` is only bound when the ipywidgets
        # import succeeded, so resolve it here and fall back to print.
        try:
            from IPython.display import display as show_table
        except ImportError:
            show_table = print
        show_table(df)
    else:
        print('C_raw stats:', stats_raw)
        print('C_norm stats:', stats_norm)


In [None]:
# Pick the default input pair (or the first two candidates as a fallback).
A_PATH, B_PATH = select_default_paths()

# Load both inputs and derive the float versions used by the normalized mix.
image_a_u16 = load_image_u16(A_PATH)
image_b_u16 = load_image_u16(B_PATH)
image_a = to_float(image_a_u16)
image_b = to_float(image_b_u16)
# The mask is built from image A only and then applied to both inputs.
mask = build_mask(image_a)

image_a = apply_circular_mask(image_a, mask)
image_b = apply_circular_mask(image_b, mask)

# 8-bit copies for the raw (non-physical) mix, masked the same way.
image_a_u8 = apply_circular_mask(to_uint8(image_a_u16), mask)
image_b_u8 = apply_circular_mask(to_uint8(image_b_u16), mask)

# Compute both mixes with the same weight for A.
c_raw_u8 = raw_mix_uint8(image_a_u8, image_b_u8, WEIGHT_A)
c_norm = normalized_mix(
    image_a, image_b, WEIGHT_A, NORMALIZED_PIPELINE, mask, NORMALIZE_METHOD
)

title = f'Raw mix vs normalized mix (x={WEIGHT_A:.2f}, pipeline={NORMALIZED_PIPELINE})'
plot_comparison(image_a, image_b, c_raw_u8, c_norm, mask, title)


## Interpretation

- **C_raw** is computed from 8-bit inputs with integer arithmetic only and no normalization. It often shows
  compressed dynamic range or saturation, which is non-physical for EBSD mixing.
- **C_norm** stretches intensities based on the active (masked) region, matching the project
  mixing assumptions.
- The difference map and histogram make the mismatch in intensity distribution explicit.

**Reminder:** The raw-mix output is non-physical and should never be used for training, validation,
or production.

In [None]:
# Interactive version of the comparison; only built when ipywidgets imported.
if HAS_WIDGETS:
    # Dropdown options are the file names found in DEFAULT_DIR; the first two
    # entries are the initial A/B selection (B falls back to A's entry when
    # only one candidate exists).
    candidates = list_candidates(DEFAULT_DIR)
    labels = [p.name for p in candidates] if candidates else []
    default_a = labels[0] if labels else None
    default_b = labels[1] if len(labels) > 1 else default_a

    a_dropdown = widgets.Dropdown(options=labels, value=default_a, description='A')
    b_dropdown = widgets.Dropdown(options=labels, value=default_b, description='B')
    weight_slider = widgets.FloatSlider(value=WEIGHT_A, min=0.0, max=1.0, step=0.05, description='x')
    pipeline_dropdown = widgets.Dropdown(
        options=['mix_then_normalize', 'normalize_then_mix'],
        value=NORMALIZED_PIPELINE,
        description='pipeline',
    )

    def _update(a_name: str, b_name: str, weight: float, pipeline: str) -> None:
        """Recompute both mixes for the current widget values and re-plot."""
        if not labels:
            print('No candidates found. Update DEFAULT_DIR or paths manually.')
            return
        a_path = DEFAULT_DIR / a_name
        b_path = DEFAULT_DIR / b_name
        image_a_u16 = load_image_u16(a_path)
        image_b_u16 = load_image_u16(b_path)
        image_a = to_float(image_a_u16)
        image_b = to_float(image_b_u16)
        # Reuse the module-level ``mask`` from the earlier cell while its
        # shape matches the selected image A; otherwise build a new one.
        mask_local = mask
        if mask_local is None or mask_local.shape != image_a.shape:
            mask_local = build_mask(image_a)
        image_a = apply_circular_mask(image_a, mask_local)
        image_b = apply_circular_mask(image_b, mask_local)
        image_a_u8 = apply_circular_mask(to_uint8(image_a_u16), mask_local)
        image_b_u8 = apply_circular_mask(to_uint8(image_b_u16), mask_local)
        c_raw_u8 = raw_mix_uint8(image_a_u8, image_b_u8, weight)
        c_norm = normalized_mix(image_a, image_b, weight, pipeline, mask_local, NORMALIZE_METHOD)
        title = f'Raw mix vs normalized mix (x={weight:.2f}, pipeline={pipeline})'
        plot_comparison(image_a, image_b, c_raw_u8, c_norm, mask_local, title)

    # Wire the four controls to _update; its output is shown below the controls.
    ui = widgets.VBox([a_dropdown, b_dropdown, weight_slider, pipeline_dropdown])
    out = widgets.interactive_output(
        _update,
        {
            'a_name': a_dropdown,
            'b_name': b_dropdown,
            'weight': weight_slider,
            'pipeline': pipeline_dropdown,
        },
    )
    display(ui, out)
else:
    print('ipywidgets is not installed. Set WEIGHT_A, A_PATH, B_PATH manually and re-run the cells above.')
