[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bobleesj/quantem.widget/blob/main/notebooks/show4dstem/show4dstem_simple.ipynb)

# Show4DSTEM — Quick Demo

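Synthetic 4D-STEM dataset with a bright-field disk, six first-order Bragg reflections,
a weaker ring of six second-order spots (only four of which fall inside the default
64×64 detector), a faint amorphous background, scan-position-dependent intensity
variation, and Poisson shot noise. The data are generated with vectorized PyTorch
operations, which run on the GPU (MPS/CUDA) when one is available.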

For **5D time/tilt series** support, see `show4dstem_5d.ipynb`.

In [1]:
# Install in Google Colab
try:
    import google.colab
    !pip install -q -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ quantem-widget
except ImportError:
    pass  # Not in Colab, skip

In [2]:
# Developer conveniences: `%autoreload 2` re-imports edited modules before each
# cell runs, and ANYWIDGET_HMR=1 turns on anywidget's hot module replacement.
# Best-effort only — failures are deliberately ignored so the demo still runs.
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    pass  # autoreload unavailable (Colab Python 3.12+)

env: ANYWIDGET_HMR=1


In [3]:
import math

import numpy as np
import torch
from quantem.widget import Show4DSTEM

# Pick the compute device: Apple-silicon GPU first, then NVIDIA GPU, else CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


def _hexagonal_spots(rr, cc, center_row, center_col, radius, sigma, amplitude, angle_offset=0.0):
    """Return six isotropic Gaussian spots spaced 60° apart on a circle.

    Spot k sits at angle ``angle_offset + k * pi / 3`` and ``radius`` pixels from
    (center_row, center_col); the row offset uses sin, the column offset uses cos.
    ``rr``/``cc`` are the detector row/column coordinate grids and the result has
    their shape, device and dtype.
    """
    ring = torch.zeros_like(rr)
    for k in range(6):
        angle = angle_offset + k * math.pi / 3
        # Plain Python floats: no per-spot 0-d device tensors are needed.
        spot_row = center_row + radius * math.sin(angle)
        spot_col = center_col + radius * math.cos(angle)
        d2 = (rr - spot_row) ** 2 + (cc - spot_col) ** 2
        ring += amplitude * torch.exp(-d2 / (2 * sigma**2))
    return ring


def make_4dstem(scan_rows=16, scan_cols=16, det_rows=64, det_cols=64, seed=None):
    """4D-STEM dataset with BF disk, Bragg spots, and amorphous background (PyTorch).

    Tensors are built on the module-level ``device``; Poisson sampling falls back
    to the CPU on MPS, which does not implement ``torch.poisson``.

    Args:
        scan_rows, scan_cols: Scan grid size.
        det_rows, det_cols: Detector size in pixels. Every detector-plane length
            (BF disk radius, ring radii, spot widths, background falloff) is
            multiplied by ``min(det_rows, det_cols) / 64``. At 64×64 the BF radius
            is 8 px and the rings sit at 20 px and 35 px; other sizes show the same
            pattern scaled instead of pushing spots off the detector.
        seed: Optional integer seed for the Poisson noise (local generator).
            ``None`` draws from torch's global RNG.

    Returns:
        float32 NumPy array of shape (scan_rows, scan_cols, det_rows, det_cols).

    Note:
        The second-order ring (radius 35 ≈ √3 × 20 on 64×64) is wider than half the
        detector. On square detectors the two spots directly above and below the
        center lie outside it, so only four of the six are visible.
    """
    # Detector-plane length unit; exactly 1.0 for a 64×64 detector.
    scale = min(det_rows, det_cols) / 64

    # Detector coordinate grids
    dr = torch.arange(det_rows, device=device, dtype=torch.float32)
    dc = torch.arange(det_cols, device=device, dtype=torch.float32)
    rr, cc = torch.meshgrid(dr, dc, indexing="ij")  # (det_rows, det_cols)
    cr, cc0 = det_rows / 2, det_cols / 2
    center_dist = ((rr - cr) ** 2 + (cc - cc0) ** 2).sqrt()

    # Amorphous background (radial falloff) — same for all positions
    bg = 0.05 * torch.exp(-center_dist / (30 * scale))

    # BF disk (sharp circular edge with internal modulation)
    bf = (center_dist < 8 * scale).float() * (1.0 + 0.2 * torch.cos(center_dist * 0.5 / scale))

    # First-order Bragg spots, then a weaker second-order ring rotated by 30°
    spots = _hexagonal_spots(rr, cc, cr, cc0, radius=20 * scale, sigma=2.5 * scale, amplitude=0.4)
    spots = spots + _hexagonal_spots(
        rr, cc, cr, cc0, radius=35 * scale, sigma=2.0 * scale, amplitude=0.1, angle_offset=math.pi / 6
    )

    base = bg + bf + spots  # (det_rows, det_cols)

    # Scan-position-dependent modulation (smooth sin×cos variation across the scan)
    si = torch.arange(scan_rows, device=device, dtype=torch.float32)
    sj = torch.arange(scan_cols, device=device, dtype=torch.float32)
    si_grid, sj_grid = torch.meshgrid(si, sj, indexing="ij")  # (scan_rows, scan_cols)
    modulation = 1.0 + 0.15 * torch.sin(2 * math.pi * si_grid / scan_rows) * torch.cos(
        2 * math.pi * sj_grid / scan_cols
    )

    # Broadcast: (scan_rows, scan_cols, 1, 1) * (1, 1, det_rows, det_cols)
    data = modulation[:, :, None, None] * base[None, None, :, :]

    # Poisson shot noise: scale intensities to expected counts (×200), sample,
    # then scale back so the mean intensity is unchanged.
    rates = data.clamp(min=0) * 200
    if rates.device.type == "mps":
        # MPS does not implement torch.poisson; sample on CPU instead.
        rates = rates.cpu()
    generator = None
    if seed is not None:
        generator = torch.Generator(device=rates.device).manual_seed(seed)
    data = torch.poisson(rates, generator=generator) / 200

    return data.cpu().numpy()


data = make_4dstem()
# Sanity-check the generated stack before handing it to the widget.
assert data.ndim == 4, "expected (scan_rows, scan_cols, det_rows, det_cols)"
assert np.isfinite(data).all() and (data >= 0).all(), "noisy intensities must be finite and non-negative"
print(f"Shape: {data.shape}, dtype: {data.dtype}")
print(f"Range: [{data.min():.3f}, {data.max():.3f}]")

  start_thread=_should_start_thread(path),


Using device: mps
Shape: (16, 16, 64, 64), dtype: float32
Range: [0.000, 1.625]


In [4]:
# Build the interactive viewer on the synthetic stack, locate the bright-field
# disk, and add a circular ROI; the calls are order-dependent.
w = Show4DSTEM(data)
w.auto_detect_center()  # detection result is read back below via center_row/center_col/bf_radius
w.roi_circle()  # NOTE(review): the summary shows this ROI at the detected center with r ≈ bf_radius / 2 — confirm defaults
print(f"Detected center: ({w.center_row:.1f}, {w.center_col:.1f}), BF radius: {w.bf_radius:.1f}")
w  # last expression: renders the widget

Detected center: (32.0, 32.1), BF radius: 9.7


Show4DSTEM(shape=(16, 16, 64, 64), sampling=(1.0 Å, 1.0 px), pos=(8, 8))

## Inspect Widget State

In [5]:
# Plain-text report of the widget state: scan/detector sizes, position, center, ROI and display settings.
w.summary()

Show4DSTEM
════════════════════════════════
Scan:     16×16 (1.00 Å/px)
Detector: 64×64 (1.0000 px/px)
Position: (8, 8)
Center:   (32.0, 32.1)  BF r=9.7 px
Display:  DC masked
ROI:      circle at (32.0, 32.1) r=4.9
DP view:  inferno, linear, 0.0-100.0%
VI view:  inferno, linear, 0.0-100.0%
