[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bobleesj/quantem.widget/blob/main/notebooks/show4d/show4d_simple.ipynb)

# Show4D — Quick Demo

Synthetic nanoparticle sample: ~200 crystalline particles with 10 distinct zone-axis
orientations scattered on an amorphous substrate (128×128 scan, 128×128 detector).
Click on a nanoparticle to see its zone-axis diffraction pattern with Bragg spots,
Kikuchi bands, and a HOLZ ring. Click on the background to see diffuse amorphous rings.
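Particles vary in size and cluster density — some isolated, some aggregated.
The generated dataset is a float32 array of roughly 1 GB, so make sure the runtime has enough free memory.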

In [6]:
# Development conveniences: with `%autoreload 2`, edited Python modules are re-imported
# automatically before each cell runs, so library changes show up without a kernel restart.
%load_ext autoreload
%autoreload 2
# NOTE(review): presumably enables anywidget hot module replacement of the widget's
# front-end during development — confirm; it should not be needed just to view the demo.
%env ANYWIDGET_HMR=1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: ANYWIDGET_HMR=1


In [7]:
import numpy as np
import torch

# Pick the fastest available backend: Apple-silicon MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


def make_nanoparticle_sample(nav=128, det=128, n_particles=200, n_orientations=10,
                             seed=42, device=None):
    """Synthesize a 4D diffraction cube of crystalline nanoparticles on an amorphous substrate.

    Every scan position gets one detector image. Positions covered by a particle get
    that particle's orientation template (Bragg spots, Kikuchi bands, HOLZ ring) scaled
    by a random intensity factor in [0.5, 1). All other positions get a diffuse
    amorphous-ring pattern with small Gaussian jitter. Poisson noise is applied to
    every pattern.

    Detector-space templates are computed with torch on ``device``. Per-position
    assignment and noise use NumPy. The code loops over orientations and over batches
    of positions, never over individual pixels.

    Parameters
    ----------
    nav : int
        Scan grid size; the scan is ``nav x nav`` positions.
    det : int
        Detector size; each pattern is ``det x det`` pixels.
    n_particles : int
        Number of disk-shaped particles to place. ``0`` gives a purely amorphous sample.
    n_orientations : int
        Number of distinct lattice orientations shared among the particles.
    seed : int
        Seed for ``np.random.default_rng``. The default (42) reproduces the original demo.
    device : torch.device, str, or None
        Device for the template computation. ``None`` picks MPS, then CUDA, then CPU.

    Returns
    -------
    data : np.ndarray
        float32 array of shape ``(nav, nav, det, det)``.
    particle_map : np.ndarray
        int32 array of shape ``(nav, nav)``. Each entry is the orientation index of the
        particle covering that scan position, or ``-1`` for bare substrate.
    """
    if device is None:
        # Same preference order as before: MPS, then CUDA, then CPU.
        device = torch.device(
            "mps" if torch.backends.mps.is_available() else
            "cuda" if torch.cuda.is_available() else "cpu"
        )
    rng = np.random.default_rng(seed)

    # The helpers below must run in this order: they share `rng`, and the draw order
    # determines the output for a given seed.
    particle_map = _particle_orientation_map(rng, nav, n_particles, n_orientations)
    crystalline_frac = np.sum(particle_map >= 0) / particle_map.size
    print(f"Particle coverage: {crystalline_frac:.1%}")

    amorphous_np, templates = _diffraction_templates(rng, det, n_orientations, device)
    data = _assign_patterns(rng, particle_map, amorphous_np, templates)
    return data, particle_map


def _particle_orientation_map(rng, nav, n_particles, n_orientations):
    """Scatter disk particles over the scan; return each pixel's orientation index (-1 = substrate)."""
    radii = rng.lognormal(mean=1.5, sigma=0.5, size=n_particles).clip(2, 15).astype(np.float32)
    # beta(2, 1.2) is skewed toward 1, so particle centers crowd toward higher row indices.
    centers_row = (rng.beta(2, 1.2, n_particles) * (nav - 4) + 2).astype(np.float32)
    centers_col = rng.uniform(2, nav - 2, n_particles).astype(np.float32)
    orientations_arr = rng.integers(0, n_orientations, size=n_particles)

    particle_map = np.full((nav, nav), -1, dtype=np.int32)
    if n_particles == 0:
        # Nothing to place, and np.argmax below would raise on an empty particle axis.
        return particle_map

    # Broadcast to (n_particles, nav, nav): squared distance of every pixel to every center.
    ii = np.arange(nav, dtype=np.float32)[:, None]
    jj = np.arange(nav, dtype=np.float32)[None, :]
    dist_sq = ((ii[None] - centers_row[:, None, None]) ** 2 +
               (jj[None] - centers_col[:, None, None]) ** 2)
    inside = dist_sq <= radii[:, None, None] ** 2

    # "Last particle wins": argmax over the reversed particle axis finds the
    # highest-index particle covering each pixel.
    any_inside = inside.any(axis=0)
    last_idx = n_particles - 1 - np.argmax(inside[::-1], axis=0)
    particle_map[any_inside] = orientations_arr[last_idx[any_inside]]
    return particle_map


def _diffraction_templates(rng, det, n_orientations, device):
    """Compute the amorphous pattern and one crystalline template per orientation.

    Returns float32 NumPy arrays ``(amorphous, templates)`` with shapes ``(det, det)``
    and ``(n_orientations, det, det)``.
    """
    c = det // 2

    # Detector coordinate grid and radius from the pattern center.
    yy, xx = torch.meshgrid(
        torch.arange(det, device=device, dtype=torch.float32),
        torch.arange(det, device=device, dtype=torch.float32), indexing="ij",
    )
    r = torch.sqrt((xx - c) ** 2 + (yy - c) ** 2)

    # Per-orientation lattice: rotation, two basis-vector lengths (px), and the angle between them.
    grain_rot = torch.tensor(rng.uniform(0, np.pi, n_orientations), device=device, dtype=torch.float32)
    grain_a1 = torch.tensor(rng.uniform(11, 22, n_orientations), device=device, dtype=torch.float32)
    grain_a2 = torch.tensor(rng.uniform(11, 22, n_orientations), device=device, dtype=torch.float32)
    grain_angle_rad = torch.deg2rad(torch.tensor(rng.uniform(55, 125, n_orientations), device=device, dtype=torch.float32))

    # All (h, k) with |h|, |k| <= 4 except (0, 0): 80 reflections.
    hk_list = [(h, k) for h in range(-4, 5) for k in range(-4, 5) if not (h == 0 and k == 0)]
    hk = torch.tensor(hk_list, device=device, dtype=torch.float32)  # (80, 2)

    # Kikuchi bands for the (1,0), (0,1) and (1,1) directions, with their strengths.
    dh_arr = torch.tensor([1, 0, 1], device=device, dtype=torch.float32)
    dk_arr = torch.tensor([0, 1, 1], device=device, dtype=torch.float32)
    s_arr = torch.tensor([0.018, 0.014, 0.009], device=device, dtype=torch.float32)

    # Amorphous pattern: diffuse background plus broad rings at r = 15, 32 and 48 px.
    amorphous_t = (0.02 * torch.exp(-r / 45) + 0.12 * torch.exp(-((r - 15) ** 2) / 50)
                 + 0.06 * torch.exp(-((r - 32) ** 2) / 80) + 0.03 * torch.exp(-((r - 48) ** 2) / 100))
    amorphous_np = amorphous_t.cpu().numpy().astype(np.float32)

    templates = np.zeros((n_orientations, det, det), dtype=np.float32)
    for o in range(n_orientations):
        rot = grain_rot[o]
        a1m, a2m = grain_a1[o], grain_a2[o]
        angle = grain_angle_rad[o]

        # Lattice basis vectors in detector pixels.
        a1x, a1y = a1m * torch.cos(rot), a1m * torch.sin(rot)
        a2x, a2y = a2m * torch.cos(rot + angle), a2m * torch.sin(rot + angle)

        # Background + central beam (flat disk of radius 7 px with a 1.5 px linear edge).
        dp = 0.03 * torch.exp(-r / 50) + torch.clamp(1.0 - torch.clamp(r - 7, min=0) / 1.5, min=0, max=1)

        # Bragg spots for all reflections at once; keep those near or on the detector.
        spot_x = c + hk[:, 0] * a1x + hk[:, 1] * a2x  # (80,)
        spot_y = c + hk[:, 0] * a1y + hk[:, 1] * a2y
        valid = (spot_x > -5) & (spot_x < det + 5) & (spot_y > -5) & (spot_y < det + 5)
        sx, sy = spot_x[valid], spot_y[valid]
        h_v, k_v = hk[valid, 0], hk[valid, 1]

        # Strong spots when h + k is even, weak otherwise; radial falloff via g_sq.
        f = torch.where((h_v + k_v) % 2 == 0, 0.6, 0.07)
        g_sq = (sx - c) ** 2 + (sy - c) ** 2

        # All spot Gaussians at once: (n_valid, det, det) -> sum -> (det, det).
        dx = xx[None] - sx[:, None, None]
        dy = yy[None] - sy[:, None, None]
        dp = dp + (f[:, None, None] * torch.exp(-g_sq[:, None, None] / 5500)
                   * torch.exp(-(dx ** 2 + dy ** 2) / 6.5)).sum(dim=0)

        # Kikuchi band pairs at +/- |g|/2 perpendicular distance, for all 3 directions at once.
        gx = dh_arr * a1x + dk_arr * a2x  # (3,)
        gy = dh_arr * a1y + dk_arr * a2y
        g_len = torch.sqrt(gx ** 2 + gy ** 2)
        band_valid = g_len >= 1
        if band_valid.any():
            gx_v, gy_v = gx[band_valid], gy[band_valid]
            g_len_v, s_v = g_len[band_valid], s_arr[band_valid]
            perp = ((xx[None] - c) * (-gy_v[:, None, None])
                    + (yy[None] - c) * gx_v[:, None, None]) / g_len_v[:, None, None]
            half_g = g_len_v[:, None, None] / 2
            band = torch.exp(-((perp - half_g) ** 2) / 16) + torch.exp(-((perp + half_g) ** 2) / 16)
            dp = dp + (s_v[:, None, None] * band * torch.exp(-r / 55)[None]).sum(dim=0)

        # HOLZ ring: thin ring at r = 50 px.
        dp = dp + 0.02 * torch.exp(-((r - 50) ** 2) / 4.5)
        templates[o] = dp.cpu().numpy()

    return amorphous_np, templates


def _assign_patterns(rng, particle_map, amorphous_np, templates, batch_size=2048):
    """Build the 4D cube by giving every scan position a noisy pattern chosen via ``particle_map``.

    Positions are processed in batches of ``batch_size`` to bound temporary memory.
    The batch size also fixes the order of random draws, so changing it changes the output.
    """
    det_shape = amorphous_np.shape
    data = np.zeros(particle_map.shape + det_shape, dtype=np.float32)

    # Amorphous positions: shared pattern + Gaussian jitter, then Poisson noise
    # on the pattern scaled by 300 (rescaled back afterwards).
    amorphous_idx = np.argwhere(particle_map == -1)
    for start in range(0, len(amorphous_idx), batch_size):
        idx = amorphous_idx[start:start + batch_size]
        n = len(idx)
        batch = np.empty((n,) + det_shape, dtype=np.float32)
        batch[:] = amorphous_np
        batch += 0.008 * rng.standard_normal((n,) + det_shape).astype(np.float32)
        np.maximum(batch, 0, out=batch)
        data[idx[:, 0], idx[:, 1]] = rng.poisson(np.clip(batch * 300, 0, 1e6)).astype(np.float32) / 300

    # Crystalline positions: orientation template times a per-position factor in
    # [0.5, 1), then Poisson noise on the pattern scaled by 400.
    for o in range(templates.shape[0]):
        o_idx = np.argwhere(particle_map == o)
        if len(o_idx) == 0:
            continue
        t_vals = (0.5 + 0.5 * rng.random(len(o_idx))).astype(np.float32)
        for start in range(0, len(o_idx), batch_size):
            idx = o_idx[start:start + batch_size]
            t_batch = t_vals[start:start + batch_size]
            scaled = templates[o][None, :, :] * t_batch[:, None, None]
            np.maximum(scaled, 0, out=scaled)
            data[idx[:, 0], idx[:, 1]] = rng.poisson(np.clip(scaled * 400, 0, 1e6)).astype(np.float32) / 400

    return data


data, particle_map = make_nanoparticle_sample()

# One-shot sanity report on the generated cube (same lines as individual prints).
print("\n".join([
    f"Shape: {data.shape}, dtype: {data.dtype}",
    f"Range: [{data.min():.3f}, {data.max():.3f}]",
    f"Memory: {data.nbytes / 1e6:.0f} MB",
    f"Orientations: {np.unique(particle_map[particle_map >= 0]).size}",
]))

Using device: mps
Particle coverage: 57.9%
Shape: (128, 128, 128, 128), dtype: float32
Range: [0.000, 1.237]
Memory: 1074 MB
Orientations: 10


In [8]:
from quantem.widget import Show4D

# Calibrations as reported by w.summary(): scan step in Å/px, detector step in mrad/px.
w = Show4D(
    data,
    title="Nanoparticle Diffraction",
    nav_pixel_size=2.39, nav_pixel_unit="\u00c5",  # real-space scan: 2.39 Å per pixel
    sig_pixel_size=0.46, sig_pixel_unit="mrad",    # detector: 0.46 mrad per pixel
)

In [9]:
# Render the interactive viewer (the bare last expression is displayed by Jupyter).
w

Show4D(shape=(128, 128, 128, 128), pos=(64, 64))

## Inspect Widget State

In [5]:
# Text summary of the widget state: title, nav/signal calibration, position, display settings.
w.summary()

Nanoparticle Diffraction
════════════════════════════════
Nav:      128×128 (2.39 Å/px)
Signal:   128×128 (0.46 mrad/px)
Position: (64, 64)
Display:  inferno | auto contrast | linear
