[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bobleesj/quantem.widget/blob/main/notebooks/show4d/show4d_simple.ipynb)

# Show4D — Quick Demo

Synthetic nanoparticle sample: ~200 crystalline particles with 10 distinct zone-axis
orientations scattered on an amorphous substrate (128×128 scan, 128×128 detector;
about 1 GB of float32 data in memory). Click a nanoparticle to see its zone-axis
diffraction pattern with Bragg spots, Kikuchi bands, and a HOLZ ring. Click the
background to see diffuse amorphous rings.
Particles vary in size and cluster density — some isolated, some aggregated.

In [1]:
# Development conveniences (IPython magics, not plain Python).
# autoreload mode 2 re-imports edited modules before each cell runs (see IPython autoreload docs).
%load_ext autoreload
%autoreload 2
# Presumably enables anywidget hot module replacement for front-end edits — confirm against anywidget docs.
%env ANYWIDGET_HMR=1

env: ANYWIDGET_HMR=1


In [None]:
import math

import numpy as np
import torch

# Pick the fastest available torch backend: Apple Metal (MPS), then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


def _build_particle_map(rng, nav, n_particles, n_orientations):
    """Paint random circular particles onto a (nav, nav) orientation-index map.

    Pixels stay -1 (amorphous substrate) unless covered by a particle, in which
    case they hold that particle's orientation index in [0, n_orientations).
    Where particles overlap, the one drawn later wins.
    """
    particle_map = np.full((nav, nav), -1, dtype=int)
    radii = rng.lognormal(mean=1.5, sigma=0.5, size=n_particles).clip(2, 15)
    # Beta(2, 1.2) skews particle centres toward high row indices, so particle
    # density varies across the scan.
    centers_row = rng.beta(2, 1.2, n_particles) * (nav - 4) + 2
    centers_col = rng.uniform(2, nav - 2, n_particles)
    orientations = rng.integers(0, n_orientations, size=n_particles)

    rows, cols = np.ogrid[:nav, :nav]
    for r0, c0, rad, o in zip(centers_row, centers_col, radii, orientations):
        particle_map[(rows - r0) ** 2 + (cols - c0) ** 2 <= rad ** 2] = o
    return particle_map


def _amorphous_pattern(r):
    """Diffuse halo plus three broad rings (r = 15, 32, 48 px) for the amorphous substrate."""
    return (0.02 * torch.exp(-r / 45)
            + 0.12 * torch.exp(-((r - 15) ** 2) / 50)
            + 0.06 * torch.exp(-((r - 32) ** 2) / 80)
            + 0.03 * torch.exp(-((r - 48) ** 2) / 100))


def _zone_axis_template(xx, yy, r, c, det, rot, g1_len, g2_len, angle_deg):
    """Noise-free zone-axis pattern for one 2D reciprocal lattice, as a (det, det) tensor.

    The lattice is spanned by g1 (length ``g1_len`` px at angle ``rot`` rad) and
    g2 (length ``g2_len`` px at ``angle_deg`` degrees from g1). The pattern holds
    a weak diffuse background, the central disk, Bragg spots, Kikuchi bands and
    a HOLZ ring.
    """
    angle = np.radians(angle_deg)
    g1x, g1y = g1_len * np.cos(rot), g1_len * np.sin(rot)
    g2x, g2y = g2_len * np.cos(rot + angle), g2_len * np.sin(rot + angle)

    # Weak diffuse background plus the central disk: flat out to r = 7 px, then a
    # 1.5 px linear edge ramp down to zero.
    dp = 0.03 * torch.exp(-r / 50)
    dp = dp + torch.clamp(1.0 - torch.clamp(r - 7, min=0) / 1.5, min=0, max=1)

    # Bragg spots at h*g1 + k*g2 for |h|, |k| <= 4. Spots with odd h + k are weak;
    # all spots fade with |g| through the exp(-|g|^2 / 5500) envelope.
    for h in range(-4, 5):
        for k in range(-4, 5):
            if h == 0 and k == 0:
                continue  # the direct beam is the central disk above
            sx = c + h * g1x + k * g2x
            sy = c + h * g1y + k * g2y
            if not (-5 < sx < det + 5 and -5 < sy < det + 5):
                continue  # spot centre lies more than 5 px off the detector
            f = 0.6 if (h + k) % 2 == 0 else 0.07
            g_sq = (sx - c) ** 2 + (sy - c) ** 2
            dp = dp + f * math.exp(-g_sq / 5500) * torch.exp(-((xx - sx) ** 2 + (yy - sy) ** 2) / 6.5)

    # Kikuchi bands for the 10, 01 and 11 reflections. A Kikuchi line pair is
    # perpendicular to g and separated by |g|. At a zone axis, the two edges sit
    # at +/- |g|/2 measured *along* g, so project each pixel onto g (not onto
    # the perpendicular, which would draw the bands rotated by 90 degrees).
    for dh, dk, strength in ((1, 0, 0.018), (0, 1, 0.014), (1, 1, 0.009)):
        gx = dh * g1x + dk * g2x
        gy = dh * g1y + dk * g2y
        g_len = math.hypot(gx, gy)
        if g_len < 1:
            continue  # degenerate reflection; avoid dividing by ~0
        along_g = ((xx - c) * gx + (yy - c) * gy) / g_len
        band = (torch.exp(-((along_g - g_len / 2) ** 2) / 16)
                + torch.exp(-((along_g + g_len / 2) ** 2) / 16))
        dp = dp + strength * band * torch.exp(-r / 55)

    # Faint HOLZ ring at r = 50 px.
    dp = dp + 0.02 * torch.exp(-((r - 50) ** 2) / 4.5)
    return dp


def make_nanoparticle_sample(nav=128, det=128, n_particles=200, n_orientations=10,
                             *, seed=42, compute_device=None):
    """Nanoparticle sample: crystalline particles on an amorphous substrate.

    Diffraction templates are computed with PyTorch on ``compute_device``.
    Noise and Poisson sampling are done in NumPy on the host, in batches of
    scan positions.

    Parameters
    ----------
    nav : int
        Scan size; the scan grid is (nav, nav).
    det : int
        Detector size; each diffraction pattern is (det, det).
    n_particles : int
        Number of circular particles painted onto the scan.
    n_orientations : int
        Number of distinct lattice templates that particles are assigned to.
    seed : int, keyword-only
        Seed for ``np.random.default_rng``; the same seed gives the same dataset.
    compute_device : torch.device or None, keyword-only
        Device for template computation. ``None`` uses the notebook-level
        ``device`` chosen in this cell.

    Returns
    -------
    data : np.ndarray, float32, shape (nav, nav, det, det)
        Poisson-noisy diffraction pattern at every scan position.
    particle_map : np.ndarray, int, shape (nav, nav)
        Orientation index per scan position; -1 marks amorphous substrate.
    """
    dev = device if compute_device is None else compute_device
    rng = np.random.default_rng(seed)
    c = det // 2

    # Detector coordinates: xx = column (x), yy = row (y); r = radius from centre.
    yy, xx = torch.meshgrid(
        torch.arange(det, device=dev, dtype=torch.float32),
        torch.arange(det, device=dev, dtype=torch.float32),
        indexing="ij",
    )
    r = torch.sqrt((xx - c) ** 2 + (yy - c) ** 2)

    particle_map = _build_particle_map(rng, nav, n_particles, n_orientations)
    crystalline_frac = np.sum(particle_map >= 0) / particle_map.size
    print(f"Particle coverage: {crystalline_frac:.1%}")

    # Per-orientation lattice parameters. The RNG draw order below is fixed:
    # reordering it changes the dataset produced for a given seed.
    grain_rot = rng.uniform(0, np.pi, n_orientations)
    grain_a1 = rng.uniform(11, 22, n_orientations)
    grain_a2 = rng.uniform(11, 22, n_orientations)
    grain_angle = rng.uniform(55, 125, n_orientations)

    amorphous_np = _amorphous_pattern(r).cpu().numpy().astype(np.float32)
    templates = np.zeros((n_orientations, det, det), dtype=np.float32)
    for o in range(n_orientations):
        templates[o] = _zone_axis_template(
            xx, yy, r, c, det, grain_rot[o], grain_a1[o], grain_a2[o], grain_angle[o],
        ).cpu().numpy()

    # Full dataset on the host: nav*nav*det*det float32 (~1 GB at the defaults).
    data = np.zeros((nav, nav, det, det), dtype=np.float32)
    batch_size = 2048  # scan positions per batch; bounds temporary memory

    # Substrate: amorphous pattern + Gaussian noise, clipped at 0, then Poisson
    # sampled at 300 counts per unit intensity.
    amorphous_idx = np.argwhere(particle_map == -1)
    for start in range(0, len(amorphous_idx), batch_size):
        idx = amorphous_idx[start:start + batch_size]
        n = len(idx)
        batch = amorphous_np + 0.008 * rng.standard_normal((n, det, det)).astype(np.float32)
        np.maximum(batch, 0, out=batch)
        data[idx[:, 0], idx[:, 1]] = rng.poisson(np.clip(batch * 300, 0, 1e6)).astype(np.float32) / 300

    # Particles: orientation template scaled by a per-position factor in [0.5, 1),
    # then Poisson sampled at 400 counts per unit intensity.
    for o in range(n_orientations):
        o_idx = np.argwhere(particle_map == o)
        if len(o_idx) == 0:
            continue  # no draws for unused orientations (keeps the RNG stream stable)
        t_vals = (0.5 + 0.5 * rng.random(len(o_idx))).astype(np.float32)
        for start in range(0, len(o_idx), batch_size):
            idx = o_idx[start:start + batch_size]
            scaled = templates[o][None, :, :] * t_vals[start:start + batch_size, None, None]
            np.maximum(scaled, 0, out=scaled)
            data[idx[:, 0], idx[:, 1]] = rng.poisson(np.clip(scaled * 400, 0, 1e6)).astype(np.float32) / 400

    return data, particle_map


# Generate the default 128x128 scan / 128x128 detector dataset and report its basics.
data, particle_map = make_nanoparticle_sample()
print("\n".join([
    f"Shape: {data.shape}, dtype: {data.dtype}",
    f"Range: [{data.min():.3f}, {data.max():.3f}]",
    f"Memory: {data.nbytes / 1e6:.0f} MB",
    # Count distinct orientation indices actually present (-1 = substrate, excluded).
    f"Orientations: {np.unique(particle_map[particle_map >= 0]).size}",
]))

In [None]:
from quantem.widget import Show4D

# Interactive viewer over the (scan_y, scan_x, det_y, det_x) dataset, with
# calibrations for the navigation (scan) axes and signal (detector) axes.
w = Show4D(
    data,
    title="Nanoparticle Diffraction",
    nav_pixel_size=2.39, nav_pixel_unit="\u00c5",  # scan step, in angstrom
    sig_pixel_size=0.46, sig_pixel_unit="mrad",    # detector pixel, angular
)
w  # last expression in the cell: displays the widget