[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/bobleesj/quantem.widget/blob/main/notebooks/show4dstem/show4dstem_5d.ipynb)

# Show4DSTEM — 5D Time/Tilt Series

Show4DSTEM supports **5D datasets** with shape `(n_frames, scan_rows, scan_cols, det_rows, det_cols)`.
The extra leading dimension represents a tilt series, time series, or any other
sequential acquisition. A frame slider appears at the bottom of the widget for navigation.

Synthetic data is generated with **PyTorch** (GPU-accelerated on MPS/CUDA when available, otherwise on CPU).

In [11]:
# Install in Google Colab
try:
    import google.colab
    !pip install -q -i https://test.pypi.org/simple/ --extra-index-url https://pypi.org/simple/ quantem-widget
except ImportError:
    pass  # Not in Colab, skip

In [12]:
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    pass  # autoreload unavailable (Colab Python 3.12+)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
env: ANYWIDGET_HMR=1


In [13]:
import os
from pathlib import Path

# Presumably works around the Intel OpenMP "OMP: Error #15 ... already initialized" abort
# that occurs when two OpenMP runtimes get loaded (e.g. by torch and numpy). The runtime
# itself calls this flag an unsafe workaround; it is set before importing torch so the
# runtime sees it.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import torch
import numpy as np
from quantem.widget import Show4DSTEM


def _select_device() -> torch.device:
    """Return the preferred torch device: Apple MPS, then CUDA, then CPU."""
    if torch.backends.mps.is_available():
        return torch.device("mps")
    if torch.cuda.is_available():
        return torch.device("cuda")
    return torch.device("cpu")


device = _select_device()
print(f"Using device: {device}")

Using device: mps


## 1. 5D Tilt Series

Generate a 5D dataset `(n_tilts, scan_rows, scan_cols, det_rows, det_cols)`.
Each tilt angle shifts the BF disk and modulates Bragg spot intensities via the
excitation error. The frame slider at the bottom of the widget navigates between tilts.

In [14]:
def make_5d_tilt(n_tilts=10, scan_rows=12, scan_cols=12, det_rows=64, det_cols=64, seed=None):
    """Synthesize a 5D tilt series with a tilt-dependent BF shift and Bragg excitation.

    Everything except the noise is computed with vectorized torch ops on the
    notebook-level ``device``. The Poisson noise is drawn in NumPy because
    torch.poisson may fail on MPS.

    Parameters
    ----------
    n_tilts : int
        Number of frames. Tilt values are spaced linearly over [-0.4, 0.4] and fed
        to sin/cos (i.e. used as radians).
    scan_rows, scan_cols : int
        Scan grid size.
    det_rows, det_cols : int
        Detector size in pixels.
    seed : int or None
        Seed for the Poisson noise. ``None`` (default) keeps the original behavior of
        drawing from NumPy's global random state; pass an int for reproducible data.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n_tilts, scan_rows, scan_cols, det_rows, det_cols)``.
    """
    dr = torch.arange(det_rows, device=device, dtype=torch.float32)
    dc = torch.arange(det_cols, device=device, dtype=torch.float32)
    rr, cc = torch.meshgrid(dr, dc, indexing="ij")  # (det_rows, det_cols)
    cr, cc0 = det_rows / 2, det_cols / 2
    center_dist = ((rr - cr) ** 2 + (cc - cc0) ** 2).sqrt()

    # Smooth radial background: (dr, dc)
    bg = 0.05 * torch.exp(-center_dist / 30)

    # Six first-order Bragg spots on a 20 px ring, 60 degrees apart: (6, dr, dc)
    angles = torch.arange(6, device=device, dtype=torch.float32) * (torch.pi / 3)  # (6,)
    spot_r = cr + 20 * torch.sin(angles)
    spot_c = cc0 + 20 * torch.cos(angles)
    d2 = (rr[None] - spot_r[:, None, None]) ** 2 + (cc[None] - spot_c[:, None, None]) ** 2
    spot_maps = torch.exp(-d2 / (2 * 2.5**2))  # Gaussian spots, sigma = 2.5 px

    # Scan modulation varies along scan rows only: (scan_rows, scan_cols)
    si = torch.arange(scan_rows, device=device, dtype=torch.float32)
    sj = torch.arange(scan_cols, device=device, dtype=torch.float32)
    si_g, _ = torch.meshgrid(si, sj, indexing="ij")
    scan_mod = 1.0 + 0.1 * torch.sin(2 * torch.pi * si_g / scan_rows)

    # Tilt values: (n_tilts,)
    tilts = torch.linspace(-0.4, 0.4, n_tilts, device=device)

    # BF disk (radius 8 px) whose center moves with tilt: (n_tilts, dr, dc)
    shift_r = cr + 3 * torch.sin(tilts)
    shift_c = cc0 + 3 * torch.cos(tilts) - 3  # equals cc0 at zero tilt
    d_bf = ((rr[None] - shift_r[:, None, None]) ** 2 + (cc[None] - shift_c[:, None, None]) ** 2).sqrt()
    bf = (d_bf < 8).float() * (1.0 + 0.2 * torch.cos(d_bf * 0.5))

    # Excitation weight per tilt per spot, floored so no spot fully vanishes: (n_tilts, 6)
    excitation = 0.4 * torch.clamp(torch.cos(tilts[:, None] * 5 + angles[None, :]), min=0.05)
    # Weighted sum of spot maps per tilt: (n_tilts, dr, dc)
    spots = torch.einsum("ts,sdr->tdr", excitation, spot_maps)

    # Combine: (n_tilts, dr, dc)
    patterns = bg[None] + bf + spots

    # Broadcast with scan modulation: (n_tilts, scan_rows, scan_cols, dr, dc)
    data_5d = patterns[:, None, None, :, :] * scan_mod[None, :, :, None, None]

    # Poisson noise in NumPy (torch.poisson may fail on MPS). The legacy global
    # ``np.random`` is used when seed is None so existing behavior is unchanged.
    noise_rng = np.random if seed is None else np.random.default_rng(seed)
    dose = 200  # counts per unit intensity
    data_np = data_5d.cpu().numpy()
    return noise_rng.poisson(np.clip(data_np, 0, None) * dose).astype(np.float32) / dose


tilt_data = make_5d_tilt(n_tilts=10, seed=0)
print(f"5D tilt shape: {tilt_data.shape}")
print(f"Range: [{tilt_data.min():.3f}, {tilt_data.max():.3f}]")

5D tilt shape: (10, 12, 12, 64, 64)
Range: [0.000, 1.640]


In [15]:
# Build the widget; the leading axis of the 5D array becomes the "Tilt" frame axis.
w_tilt = Show4DSTEM(tilt_data, frame_dim_label="Tilt")
w_tilt.auto_detect_center()  # presumably fits the BF disk center/radius (see summary() below)
w_tilt.roi_circle()
print(f"Tilts: {w_tilt.n_frames}, scan: {w_tilt.shape_rows}\u00d7{w_tilt.shape_cols}")
print("Use the frame slider or [ / ] keys to navigate tilts")
w_tilt

Tilts: 10, scan: 12×12
Use the frame slider or [ / ] keys to navigate tilts


Show4DSTEM(shape=(10, 12, 12, 64, 64), sampling=(1.0 Å, 1.0 px), pos=(6, 6), tilt=0)

## 2. 5D Time Series (In-Situ)

Simulate an in-situ experiment where Bragg spots gradually intensify over time
(e.g., crystallization from an amorphous precursor). Use `frame_dim_label="Time"`
to label the frame axis accordingly.

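Playback controls match Show3D: **fps slider**, **loop**, and **bounce** (ping-pong).
The same settings can be set from Python via `frame_fps`, `frame_loop`, and `frame_boomerang`, as in the cell below.
Use the transport buttons (rewind/play/forward/stop) or `[`/`]` keys to navigate frames.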

In [16]:
def make_5d_time(n_times=8, scan_rows=12, scan_cols=12, det_rows=64, det_cols=64, seed=None):
    """Synthesize a 5D crystallization time series: Bragg spots emerge from an amorphous ring.

    A crystallinity factor ramps linearly from 0 (first frame: amorphous ring only)
    to 1 (last frame: Bragg spots only). Everything except the noise is computed
    with vectorized torch ops on the notebook-level ``device``. The Poisson noise is
    drawn in NumPy because torch.poisson may fail on MPS.

    Parameters
    ----------
    n_times : int
        Number of time frames.
    scan_rows, scan_cols : int
        Scan grid size.
    det_rows, det_cols : int
        Detector size in pixels.
    seed : int or None
        Seed for the Poisson noise. ``None`` (default) keeps the original behavior of
        drawing from NumPy's global random state; pass an int for reproducible data.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(n_times, scan_rows, scan_cols, det_rows, det_cols)``.
    """
    dr = torch.arange(det_rows, device=device, dtype=torch.float32)
    dc = torch.arange(det_cols, device=device, dtype=torch.float32)
    rr, cc = torch.meshgrid(dr, dc, indexing="ij")  # (dr, dc)
    cr, cc0 = det_rows / 2, det_cols / 2
    center_dist = ((rr - cr) ** 2 + (cc - cc0) ** 2).sqrt()

    # Time-independent pieces: radial background and a centered BF disk (radius 8 px)
    bg = 0.05 * torch.exp(-center_dist / 30)  # (dr, dc)
    bf = (center_dist < 8).float() * (1.0 + 0.2 * torch.cos(center_dist * 0.5))  # (dr, dc)

    # Six Bragg spots on a 20 px ring, summed into one crystalline template: (dr, dc)
    angles = torch.arange(6, device=device, dtype=torch.float32) * (torch.pi / 3)
    spot_r = cr + 20 * torch.sin(angles)
    spot_c = cc0 + 20 * torch.cos(angles)
    d2 = (rr[None] - spot_r[:, None, None]) ** 2 + (cc[None] - spot_c[:, None, None]) ** 2
    spot_maps = torch.exp(-d2 / (2 * 2.5**2))  # (6, dr, dc), Gaussian sigma = 2.5 px
    spot_sum = 0.4 * spot_maps.sum(dim=0)

    # Amorphous ring template (Gaussian ring at radius 22 px): (dr, dc)
    amorphous_ring = 0.15 * torch.exp(-((center_dist - 22) ** 2) / (2 * 5**2))

    # Scan modulation varies along scan rows only: (scan_rows, scan_cols)
    si = torch.arange(scan_rows, device=device, dtype=torch.float32)
    sj = torch.arange(scan_cols, device=device, dtype=torch.float32)
    si_g, _ = torch.meshgrid(si, sj, indexing="ij")
    scan_mod = 1.0 + 0.1 * torch.sin(2 * torch.pi * si_g / scan_rows)

    # Crystallinity ramp: (n_times, 1, 1)
    c = torch.linspace(0.0, 1.0, n_times, device=device)[:, None, None]

    # Per-time patterns via broadcasting: (n_times, dr, dc)
    patterns = bg[None] + bf[None] + c * spot_sum[None] + (1 - c) * amorphous_ring[None]

    # Broadcast with scan modulation: (n_times, scan_rows, scan_cols, dr, dc)
    data_5d = patterns[:, None, None, :, :] * scan_mod[None, :, :, None, None]

    # Poisson noise in NumPy (torch.poisson may fail on MPS). The legacy global
    # ``np.random`` is used when seed is None so existing behavior is unchanged.
    noise_rng = np.random if seed is None else np.random.default_rng(seed)
    dose = 200  # counts per unit intensity
    data_np = data_5d.cpu().numpy()
    return noise_rng.poisson(np.clip(data_np, 0, None) * dose).astype(np.float32) / dose


time_data = make_5d_time(n_times=8, seed=0)
print(f"5D time shape: {time_data.shape}")

5D time shape: (8, 12, 12, 64, 64)


In [None]:
# Time-series widget: slow, looping, ping-pong playback makes the crystallization easy to follow.
w_time = Show4DSTEM(time_data, frame_dim_label="Time")
w_time.auto_detect_center()
w_time.roi_circle()

# Playback settings (the same fps / loop / bounce controls exposed in the widget UI)
w_time.frame_loop = True
w_time.frame_fps = 3.0         # slower playback for crystallization
w_time.frame_boomerang = True  # bounce back and forth

last_frame = w_time.n_frames - 1
bounce_state = "on" if w_time.frame_boomerang else "off"
print(f"Time steps: {w_time.n_frames}")
print(f"Frame 0: amorphous diffuse ring, Frame {last_frame}: crystalline Bragg spots")
print(f"Playback: {w_time.frame_fps} fps, bounce={bounce_state}")
w_time

## 3. 5D State Persistence

`state_dict()` and `save()` include `frame_idx` and `frame_dim_label` for 5D data.
Pass the saved file back via `Show4DSTEM(..., state=path)` to restore the exact tilt/time position, e.g. after a kernel restart.

In [18]:
# Navigate to tilt 5 and persist the widget state (including frame_idx) to disk
state_path = "show4dstem_5d_state.json"
w_tilt.frame_idx = 5
w_tilt.save(state_path)

# state_dict() exposes the same 5D fields in memory
state = w_tilt.state_dict()
print(f"Saved state \u2014 frame_idx: {state['frame_idx']}, frame_dim_label: {state['frame_dim_label']}")

# Restore into a new widget built from the same data, and check the round trip
w_restored_5d = Show4DSTEM(tilt_data, state=state_path)
print(f"Restored frame_idx: {w_restored_5d.frame_idx}")
assert w_restored_5d.frame_idx == w_tilt.frame_idx, "frame_idx was not restored"
w_restored_5d

Saved state — frame_idx: 5, frame_dim_label: Tilt
Restored frame_idx: 5


Show4DSTEM(shape=(10, 12, 12, 64, 64), sampling=(1.0 Å, 1.0 px), pos=(6, 6), tilt=5)

In [19]:
from pathlib import Path

# Remove the state file saved above; missing_ok keeps this cell safe to re-run
state_file = Path("show4dstem_5d_state.json")
state_file.unlink(missing_ok=True)

## 4. Inspect 5D Widget

Use `summary()` to see the full 5D state, including the frame dimension.

In [10]:
# Text overview of the widget state: frames (with label and current index), scan and
# detector sizes, position, detected center, display mode, and ROI
w_tilt.summary()

Show4DSTEM
════════════════════════════════
Frames:   10 (Tilt), current: 5
Scan:     12×12 (1.00 Å/px)
Detector: 64×64 (1.0000 px/px)
Position: (7, 3)
Center:   (32.0, 31.8)  BF r=8.4 px
Display:  DC masked
ROI:      annular at (32.0, 31.8) r=25.1
