In [1]:
# Development conveniences: automatically re-import edited Python modules
# (e.g. quantem) before each cell runs, and turn on anywidget hot-module
# replacement so widget front-end edits reload without restarting the kernel.
%load_ext autoreload
%autoreload 2
%env ANYWIDGET_HMR=1

env: ANYWIDGET_HMR=1


# Merge4DSTEM — Simple Demo

Stack multiple 4D-STEM datasets along a time axis → 5D output.

This notebook generates three synthetic 4D-STEM datasets, one per time point
(simulating a crystal lattice whose Bragg spots drift outward under beam
exposure), and merges them into a single 5D array of shape
`(time, scan_r, scan_c, det_r, det_c)`.

In [None]:
import os

# Intel's OpenMP runtime aborts when a second copy of itself gets loaded (a
# typical clash between PyTorch and NumPy/MKL in conda environments on macOS).
# This flag only suppresses that abort rather than fixing the clash, and it has
# to be set before torch is imported.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np

device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Parameters
SEED = 0                 # seeds scan variation + noise so Restart & Run All is reproducible
scan_r, scan_c = 32, 32  # probe positions (real space)
det_r, det_c = 64, 64    # detector pixels (reciprocal space)
n_frames = 3             # time points

# Bragg spot centres (ky, kx) in detector pixels relative to the detector
# centre, before any drift: (n_spots, 2)
bragg_base = torch.tensor([
    [10, 0], [-10, 0], [0, 10], [0, -10],
    [7, 7], [-7, -7], [7, -7], [-7, 7],
], device=device, dtype=torch.float32)


def simulate_4dstem_series(
    bragg_base,
    n_frames,
    scan_shape,
    det_shape,
    *,
    drift_per_frame=0.5,
    central_amp=100.0,
    central_sigma=3.0,
    spot_amp=30.0,
    spot_sigma=1.5,
    scan_var_low=0.8,
    scan_var_span=0.4,
):
    """Simulate a noisy 4D-STEM time series of a drifting crystal lattice.

    Each frame's diffraction pattern is a Gaussian direct beam plus one
    Gaussian spot per row of ``bragg_base``. Every non-zero spot coordinate
    moves away from zero by ``drift_per_frame`` pixels per frame (sign-wise,
    so diagonal spots shift along both axes). Each probe position scales the
    pattern by a uniform random factor in
    ``[scan_var_low, scan_var_low + scan_var_span)``. The result then gets
    Gaussian noise with std ``sqrt(signal)`` (an approximation of shot noise)
    and is clamped at zero.

    Args:
        bragg_base: (n_spots, 2) float tensor of (ky, kx) spot centres in
            detector pixels relative to the detector centre. Its device is
            also the output device.
        n_frames: number of time points.
        scan_shape: (scan_r, scan_c) probe grid.
        det_shape: (det_r, det_c) detector size in pixels.
        drift_per_frame, central_amp, central_sigma, spot_amp, spot_sigma,
        scan_var_low, scan_var_span: simulation constants (sigmas and drift
            in detector pixels).

    Returns:
        Tensor of shape (n_frames, scan_r, scan_c, det_r, det_c) on
        ``bragg_base.device``. Random draws come from the global torch RNG,
        so seed it first (``torch.manual_seed``) for reproducible output.
    """
    dev = bragg_base.device
    (n_sr, n_sc), (n_dr, n_dc) = scan_shape, det_shape

    # Detector coordinates centred on the pattern: (det_r, det_c)
    ky = torch.arange(n_dr, device=dev, dtype=torch.float32) - n_dr / 2
    kx = torch.arange(n_dc, device=dev, dtype=torch.float32) - n_dc / 2
    KY, KX = torch.meshgrid(ky, kx, indexing="ij")

    # Time-dependent drift: (n_frames, 1, 1) * (n_spots, 2) -> (n_frames, n_spots, 2)
    drift = drift_per_frame * torch.arange(n_frames, device=dev, dtype=torch.float32)
    bragg_pos = bragg_base + drift[:, None, None] * torch.sign(bragg_base)

    # Direct beam plus every Bragg spot of every frame in a single broadcast:
    # spot centres (n_frames, n_spots, 1, 1) against the (det_r, det_c) grid,
    # then summed over the spot axis.
    central = central_amp * torch.exp(-(KY**2 + KX**2) / (2 * central_sigma**2))
    qy = bragg_pos[:, :, 0, None, None]
    qx = bragg_pos[:, :, 1, None, None]
    spots = spot_amp * torch.exp(-((KY - qy)**2 + (KX - qx)**2) / (2 * spot_sigma**2))
    patterns = central + spots.sum(dim=1)  # (n_frames, det_r, det_c)

    # Per-probe intensity scale: (n_frames, scan_r, scan_c, 1, 1)
    scan_var = scan_var_low + scan_var_span * torch.rand(n_frames, n_sr, n_sc, 1, 1, device=dev)
    clean = patterns[:, None, None] * scan_var  # (n_frames, scan_r, scan_c, det_r, det_c)

    # Gaussian approximation of shot noise (std = sqrt(signal)); stays on device.
    noise = torch.sqrt(torch.clamp(clean, min=0)) * torch.randn_like(clean)
    return torch.clamp(clean + noise, min=0)


torch.manual_seed(SEED)
data_5d = simulate_4dstem_series(bragg_base, n_frames, (scan_r, scan_c), (det_r, det_c))
assert data_5d.shape == (n_frames, scan_r, scan_c, det_r, det_c)

# Split into a list of 4D arrays, one per time point
sources = [data_5d[t].cpu().numpy() for t in range(n_frames)]
for t, s in enumerate(sources):
    print(f"Frame {t}: shape={s.shape}, mean={s.mean():.2f}")
print(f"\nGenerated {len(sources)} source datasets")

In [3]:
from quantem.widget import Merge4DSTEM

# Calibrations passed to the widget (the widget reports them in these units).
REAL_PIXEL_SIZE = 2.39  # real-space sampling, Å per scan step
K_PIXEL_SIZE = 0.46     # reciprocal-space sampling, mrad per detector pixel

# Widget that stacks the 4D sources along a new frame axis labelled "Time";
# the bare `w` on the last line renders it.
w = Merge4DSTEM(
    sources,
    pixel_size=REAL_PIXEL_SIZE,
    k_pixel_size=K_PIXEL_SIZE,
    frame_dim_label="Time",
    title="Merge4DSTEM — Crystal Lattice Time Series",
)
w

  start_thread=_should_start_thread(path),


Merge4DSTEM(sources=3, scan=(32, 32), det=(64, 64), merged=False, device=mps)

In [4]:
# Merge all sources → 5D: the N 4D sources are stacked along a new leading
# frame axis, giving (N, scan_r, scan_c, det_r, det_c); summary() then reports
# calibration, shapes, device and merge status.
w.merge()
w.summary()

Merge4DSTEM
Sources:    3
Shape:      scan=(32, 32), det=(64, 64)
Real cal:   2.39 Å/px
K cal:      0.46 mrad/px
Merged:     True
Output:     [3, 32, 32, 64, 64]
Device:     mps
Display:    cmap=inferno, log_scale=False
Status:     OK - Merged 3 sources -> (3, 32, 32, 64, 64) on mps


In [5]:
# Access the merged result and check that it really is the sources stacked
# along a new leading frame axis before reporting its shape, dtype and size.
arr = w.result_array
assert arr.shape == (len(sources), *sources[0].shape), arr.shape
print(f"Merged array shape: {arr.shape}")
print(f"Merged array dtype: {arr.dtype}")
print(f"Merged array size: {arr.nbytes / 1e6:.1f} MB")  # decimal MB (1e6 bytes)

Merged array shape: (3, 32, 32, 64, 64)
Merged array dtype: float32
Merged array size: 50.3 MB


In [6]:
# Open merged result in Show4DSTEM: the viewer wraps the 5D result with the
# same real/reciprocal calibration, starting on frame 0; the bare `viewer`
# line renders it.
viewer = w.to_show4dstem()
viewer

Show4DSTEM(shape=(3, 32, 32, 64, 64), sampling=(2.39 Å, 0.46 mrad), pos=(16, 16), frame=0)