In [None]:
# Development conveniences: load IPython's autoreload extension and use mode 2,
# which re-imports edited modules before each cell executes.
%load_ext autoreload
%autoreload 2
# Set ANYWIDGET_HMR=1 in the kernel's environment (anywidget's hot-module-reload
# switch for front-end code during development).
%env ANYWIDGET_HMR=1

# Merge4DSTEM — All Features
Comprehensive demo of the Merge4DSTEM widget showing:
1. **Synthetic time-series generation** — crystal lattice with beam damage evolution
2. **Widget construction** with calibration
3. **Merge workflow** — merge, inspect result, save
4. **State persistence** — save/load display settings
5. **Tool lock/hide** — programmatic UI customization
6. **Open in Show4DSTEM** — view merged 5D data

In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
import torch
import numpy as np
# Pick the first available backend: Apple MPS, then CUDA, then CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Seed torch's RNG so the synthetic noise is reproducible on Restart & Run All
# (identical results are only expected when re-running on the same device type).
SEED = 0
torch.manual_seed(SEED)

# Simulate a 4D-STEM time series of a crystal under beam damage
# Each time frame shows progressively weaker Bragg spots
scan_r, scan_c = 48, 48
det_r, det_c = 64, 64
n_frames = 5
pixel_size = 2.39  # Å/px
k_pixel_size = 0.46  # mrad/px
damage_rate = 0.15  # fraction of Bragg intensity lost per frame

# Detector coordinates centred on the pattern: (det_r, det_c)
ky = torch.arange(det_r, device=device, dtype=torch.float32) - det_r / 2
kx = torch.arange(det_c, device=device, dtype=torch.float32) - det_c / 2
KY, KX = torch.meshgrid(ky, kx, indexing="ij")

# Bragg spot centres as (ky, kx) offsets in detector pixels: (n_spots, 2)
bragg_base = torch.tensor([
    [12, 0], [-12, 0], [0, 12], [0, -12],
    [6, 10.4], [-6, -10.4], [6, -10.4], [-6, 10.4],
    [12, 10.4], [-12, -10.4], [12, -10.4], [-12, 10.4],
], device=device, dtype=torch.float32)

# Beam damage decay per frame: (n_frames, 1, 1, 1, 1)
t_idx = torch.arange(n_frames, device=device, dtype=torch.float32)
damage = (1.0 - damage_rate * t_idx).reshape(-1, 1, 1, 1, 1)

# Central beam: same strength in every frame → (1, 1, 1, det_r, det_c)
central = 200.0 * torch.exp(-(KY**2 + KX**2) / (2 * 4.0**2))
central = central.reshape(1, 1, 1, det_r, det_c)

# Amorphous ring (grows with frame index): (n_frames, 1, 1, det_r, det_c)
R = torch.sqrt(KY**2 + KX**2).reshape(1, 1, 1, det_r, det_c)
amorphous = 5.0 * (1 + 0.3 * t_idx.reshape(-1, 1, 1, 1, 1)) * torch.exp(
    -((R - 18)**2) / (2 * 4.0**2)
)

# Bragg spots, broadcast over the spot axis instead of looping in Python:
# (n_spots, det_r, det_c) summed over spots → (1, 1, 1, det_r, det_c)
spot_ky = bragg_base[:, 0].reshape(-1, 1, 1)
spot_kx = bragg_base[:, 1].reshape(-1, 1, 1)
spot_stack = 50.0 * torch.exp(
    -((KY - spot_ky)**2 + (KX - spot_kx)**2) / (2 * 1.2**2)
)
bragg_sum = spot_stack.sum(dim=0).reshape(1, 1, 1, det_r, det_c)

# Combine: central + amorphous + damage-scaled Bragg → (n_frames, 1, 1, det_r, det_c)
dp_all = central + amorphous + damage * bragg_sum

# Thickness gradient, increasing with column index (0.6 at column 0, approaching 1.4
# at the last column): (1, 1, scan_c, 1, 1)
col_idx = torch.arange(scan_c, device=device, dtype=torch.float32)
thickness = (0.6 + 0.8 * col_idx / scan_c).reshape(1, 1, scan_c, 1, 1)

# Sinusoidal row variation of ±0.2 around 0.9: (1, scan_r, 1, 1, 1)
row_idx = torch.arange(scan_r, device=device, dtype=torch.float32)
row_var = (0.9 + 0.2 * torch.sin(2 * np.pi * row_idx / scan_r)).reshape(1, scan_r, 1, 1, 1)

# Full 5D data: (n_frames, scan_r, scan_c, det_r, det_c)
data_5d = dp_all * thickness * row_var

# Gaussian noise (shot-noise-like: σ = √I); clamp so intensities stay non-negative
noise = torch.sqrt(torch.clamp(data_5d, min=0)) * torch.randn_like(data_5d)
data_5d = torch.clamp(data_5d + noise, min=0)

# Split into a list of 4D numpy arrays, one per time frame
sources = [data_5d[t].cpu().numpy() for t in range(n_frames)]
for t, frame in enumerate(sources):
    damage_val = 1.0 - damage_rate * t
    print(f"Frame {t}: shape={frame.shape}, mean={frame.mean():.2f}, damage={damage_val:.2f}")
print(f"\n{len(sources)} source datasets generated")

## 1. Create the Merge Widget

In [None]:
import quantem.widget
from quantem.widget import Merge4DSTEM
# Record the installed widget version in the cell output first.
print(f"quantem.widget {quantem.widget.__version__}")
# Build the merge widget from the list of 4D source arrays, passing the real-space
# (Å/px) and reciprocal-space (mrad/px) calibrations defined with the synthetic data.
w = Merge4DSTEM(
    sources,
    pixel_size=pixel_size,
    k_pixel_size=k_pixel_size,
    frame_dim_label="Time",
    title="Beam Damage Time Series",
    cmap="inferno",
    log_scale=True,
)
# Keep `w` as the final expression: Jupyter only renders the value of a cell's
# last line, so a bare `w` followed by another statement displays nothing.
w

In [None]:
# Compact repr
# NOTE(review): as the cell's last expression, `w` goes through Jupyter's rich display.
# If Merge4DSTEM renders as a live widget, this shows the widget rather than a text
# repr — confirm, and use print(repr(w)) if a plain-text summary is what is intended.
w

## 2. Merge and Inspect

In [None]:
# Stack all sources on GPU → 5D
# NOTE(review): the device used for stacking is decided inside the widget and is not
# visible in this notebook — confirm against the Merge4DSTEM implementation.
w.merge()
# Last expression of the cell: anything summary() prints or returns becomes the cell output.
w.summary()

In [None]:
# Pull the merged result out of the widget and report its basic properties
arr = w.result_array
size_mb = arr.nbytes / 1e6
print(f"Shape: {arr.shape}")
print(f"Size:  {size_mb:.1f} MB")
print(f"dtype: {arr.dtype}")
# Per-frame mean along the leading axis; differing means show the frames are distinct
for frame_no, frame in enumerate(arr):
    print(f"  Frame {frame_no}: mean={frame.mean():.2f}")

## 3. State Persistence

In [None]:
# Get state dict
# Keep the returned object in `sd` and leave it as the last expression so Jupyter
# displays its contents as the cell output.
sd = w.state_dict()
sd

In [None]:
import tempfile, json
# Save state to JSON
# Reserve a unique temp-file path. delete=False keeps the file on disk after the
# handle closes, so the widget can write to it and later cells can reuse state_path.
# The file is not removed automatically — delete it manually when finished.
with tempfile.NamedTemporaryFile(suffix=".json", delete=False, mode="w") as f:
    state_path = f.name
w.save(state_path)
print(f"Saved to: {state_path}")
# Read the file back inside a context manager so the handle is closed deterministically,
# then pretty-print the parsed JSON.
with open(state_path) as saved_state:
    print(json.dumps(json.load(saved_state), indent=2))

In [None]:
# Restore from saved state
# Build a second widget from the same sources, passing the JSON path written by
# w.save() as `state` instead of repeating the display keyword arguments.
w2 = Merge4DSTEM(sources, state=state_path)
# Print the restored settings for comparison with the first widget, which was created
# with cmap="inferno", log_scale=True, frame_dim_label="Time".
print(f"Restored: cmap={w2.cmap}, log_scale={w2.log_scale}, frame_dim_label={w2.frame_dim_label}")

## 4. Tool Lock / Hide

In [None]:
# Lock merge button
# Pass the tool name via `disabled_tools` at construction time, then echo the
# widget's disabled_tools attribute to confirm what was registered.
w3 = Merge4DSTEM(sources, disabled_tools=["merge"])
print(f"Disabled: {w3.disabled_tools}")
# Last expression: display this widget so the locked control can be inspected.
w3

In [None]:
# Hide sources table and stats
# Pass both tool names via `hidden_tools` at construction time, then echo the
# widget's hidden_tools attribute to confirm what was registered.
w4 = Merge4DSTEM(sources, hidden_tools=["sources", "stats"])
print(f"Hidden: {w4.hidden_tools}")
# Last expression: display this widget so the hidden sections can be checked visually.
w4

In [None]:
# Runtime API
# Start from a widget with default tool settings, then change them after construction.
w5 = Merge4DSTEM(sources)
# The calls are chained, so lock_tool() must return an object exposing hide_tool()
# (presumably the widget itself — confirm against the Merge4DSTEM API).
w5.lock_tool("display").hide_tool("export")
print(f"Disabled: {w5.disabled_tools}")
print(f"Hidden: {w5.hidden_tools}")
# Undo
# Reverse both changes with the matching unlock/show calls and print the lists again.
w5.unlock_tool("display").show_tool("export")
print(f"After undo — Disabled: {w5.disabled_tools}, Hidden: {w5.hidden_tools}")

## 5. Open in Show4DSTEM

In [None]:
# Open merged 5D result in Show4DSTEM viewer
# to_show4dstem() is called on the first widget `w`, the one merge() was run on earlier.
viewer = w.to_show4dstem()
# Plain string: the message has no placeholders, so no f-string prefix is needed.
print("Show4DSTEM loaded with 5D data")
# Last expression: display the returned viewer as the cell output.
viewer