# Show4DSTEM Export + Reproducibility Demo

Demonstrates programmatic exports with overlay and scale-bar parity, figure templates, sequence manifests, and a session reproducibility report.

In [None]:
# Best-effort developer setup: load IPython's autoreload extension, reload
# modules before each execution (mode 2), and set ANYWIDGET_HMR=1 (presumably
# anywidget hot module reloading — confirm). Any exception raised while
# applying these is deliberately ignored so the rest of the notebook runs.
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    pass

In [None]:
import json
import math
import pathlib

import numpy as np
import torch
from quantem.widget import Show4DSTEM


def make_crystal_4dstem(
    scan_shape=(40, 40),
    det_shape=(72, 72),
    n_frames=1,
    seed=7,
    device=None,
):
    """Generate a synthetic, crystal-like 4D-STEM dataset.

    Each scan position receives a diffraction pattern built from a central
    direct beam, eight Gaussian Bragg spots whose positions and amplitudes
    follow a smooth synthetic strain/defect field, and a weak diffuse ring.
    Poisson shot noise plus additive Gaussian noise is then applied and the
    result is clipped to be non-negative.

    Parameters
    ----------
    scan_shape : tuple of int
        ``(rows, cols)`` of the real-space scan grid.
    det_shape : tuple of int
        ``(rows, cols)`` of each detector (diffraction) image.
    n_frames : int
        Number of frames; the structural field drifts slightly per frame.
    seed : int
        Passed to ``torch.manual_seed``. Note this reseeds torch's global
        RNG as a side effect.
    device : torch.device, str, or None
        Device to compute on. ``None`` picks MPS, then CUDA, then CPU.
        Strings such as ``"cpu"`` or ``"cuda:0"`` are accepted.

    Returns
    -------
    numpy.ndarray
        float32 array of shape ``(*scan_shape, *det_shape)`` when
        ``n_frames == 1``, otherwise ``(n_frames, *scan_shape, *det_shape)``.
    """
    if device is None:
        device = torch.device(
            "mps"
            if torch.backends.mps.is_available()
            else "cuda"
            if torch.cuda.is_available()
            else "cpu"
        )
    else:
        # Normalise strings like "cpu"; ``device.type`` is read further down
        # and a plain string has no such attribute.
        device = torch.device(device)
    torch.manual_seed(seed)

    sy, sx = scan_shape
    ky, kx = det_shape

    # Real-space scan coordinates in [-1, 1], each of shape (sy, sx).
    y = torch.linspace(-1.0, 1.0, sy, device=device)
    x = torch.linspace(-1.0, 1.0, sx, device=device)
    Y, X = torch.meshgrid(y, x, indexing="ij")

    # One phase per frame, spanning [0, 2*pi]; shape (n_frames, 1, 1).
    frame_axis = torch.linspace(0.0, 1.0, n_frames, device=device)
    phase_t = 2.0 * math.pi * frame_axis[:, None, None]

    # Synthetic strain/defect field with crystal-like periodicity.
    lattice = 0.35 * torch.sin(2.0 * math.pi * (7.0 * X + 0.35 * Y))
    lattice += 0.30 * torch.sin(2.0 * math.pi * (7.0 * Y - 0.25 * X))
    defect = 0.55 * torch.exp(-((X + 0.20) ** 2 + (Y - 0.25) ** 2) / 0.05)
    defect += 0.35 * torch.exp(-((X - 0.35) ** 2 + (Y + 0.10) ** 2) / 0.03)
    structural = lattice + defect

    # Small frame-dependent modulation -> shape (n_frames, sy, sx).
    structural_t = structural[None, :, :] + 0.08 * torch.sin(phase_t + 5.0 * X[None, :, :])

    # Detector pixel coordinates, broadcastable as (1, 1, 1, ky, kx).
    ky_axis = torch.linspace(0.0, float(ky - 1), ky, device=device)
    kx_axis = torch.linspace(0.0, float(kx - 1), kx, device=device)
    KY, KX = torch.meshgrid(ky_axis, kx_axis, indexing="ij")
    KY = KY[None, None, None, :, :]
    KX = KX[None, None, None, :, :]

    # Detector centre in pixel units.
    cy = (ky - 1) / 2.0
    cx = (kx - 1) / 2.0

    # Shift Bragg spots according to local strain/field; (n_frames, sy, sx, 1, 1).
    shift_y = 1.4 * structural_t[:, :, :, None, None]
    shift_x = -1.1 * structural_t[:, :, :, None, None]

    # Per-frame global intensity modulation; shape (n_frames, 1, 1, 1, 1).
    frame_mod = 1.0 + 0.20 * torch.sin(phase_t)[:, :, :, None, None]

    dp = torch.zeros((n_frames, sy, sx, ky, kx), device=device)

    # Central direct beam: isotropic Gaussian at the detector centre.
    sigma_direct = 2.8
    direct_amp = 180.0 * frame_mod
    dp += direct_amp * torch.exp(-((KY - cy) ** 2 + (KX - cx) ** 2) / (2.0 * sigma_direct**2))

    # Bragg spots as (row offset, col offset, amplitude) relative to the centre.
    spot_sigma = 2.1
    base_spots = [
        (-16.0, 0.0, 75.0),
        (16.0, 0.0, 75.0),
        (0.0, -16.0, 72.0),
        (0.0, 16.0, 72.0),
        (-11.0, -11.0, 58.0),
        (11.0, 11.0, 58.0),
        (-11.0, 11.0, 54.0),
        (11.0, -11.0, 54.0),
    ]

    # Spots brighten where the structural field is positive.
    amp_field = 1.0 + 0.45 * torch.relu(structural_t)[:, :, :, None, None]
    for dy, dx, amp in base_spots:
        cy_t = cy + dy + shift_y
        cx_t = cx + dx + shift_x
        dp += (amp * amp_field) * torch.exp(
            -((KY - cy_t) ** 2 + (KX - cx_t) ** 2) / (2.0 * spot_sigma**2)
        )

    # Weak diffuse ring: Gaussian in radius centred at 23 px, sigma 7 px.
    ring_r = torch.sqrt((KY - cy) ** 2 + (KX - cx) ** 2)
    dp += 12.0 * torch.exp(-((ring_r - 23.0) ** 2) / (2.0 * 7.0**2))

    # Poisson shot noise, then additive Gaussian noise, then clip to >= 0.
    if device.type == "mps":
        # MPS currently lacks aten::poisson; sample on CPU then move back.
        shot = torch.poisson(torch.clamp(dp.cpu(), min=0.0)).to(device)
    else:
        shot = torch.poisson(torch.clamp(dp, min=0.0))
    dp = shot + 0.6 * torch.randn_like(shot)
    dp = torch.clamp(dp, min=0.0).to(torch.float32)

    arr = dp.cpu().numpy()
    # Drop the frame axis for single-frame data.
    if n_frames == 1:
        return arr[0]
    return arr


# Pick the compute device for the demo: Apple MPS first, then CUDA, else CPU.
if torch.backends.mps.is_available():
    _device_kind = "mps"
elif torch.cuda.is_available():
    _device_kind = "cuda"
else:
    _device_kind = "cpu"
DEVICE = torch.device(_device_kind)
del _device_kind
print(f"Using device: {DEVICE}")


In [None]:
# Output directory for every export in this demo. NOTE: the path is relative
# to the current working directory, so run from the directory that contains
# ``notebooks/`` (or adjust the path).
OUT = pathlib.Path('notebooks/show4dstem/paper_exports')
OUT.mkdir(parents=True, exist_ok=True)

# Dataset geometry, named once so the ROI placement below stays consistent
# if the scan size is changed.
SCAN_SHAPE = (40, 40)
DET_SHAPE = (72, 72)

data = make_crystal_4dstem(scan_shape=SCAN_SHAPE, det_shape=DET_SHAPE, n_frames=4, seed=12, device=DEVICE)
w = Show4DSTEM(data, pixel_size=0.78, k_pixel_size=0.52)
w.auto_detect_center()

# Annular ROI with inner/outer radii.
w.roi_annular()
w.roi_radius = 14
w.roi_radius_inner = 7

# Rectangular virtual-image ROI centred on the scan grid (20, 20 for a 40x40
# scan). Assumes vi_roi_* coordinates are scan-grid pixels — confirm.
w.vi_roi_mode = 'rect'
w.vi_roi_center_row = SCAN_SHAPE[0] // 2
w.vi_roi_center_col = SCAN_SHAPE[1] // 2
w.vi_roi_width = 14
w.vi_roi_height = 10

# Display scaling per panel.
w.show_fft = True
w.dp_scale_mode = 'log'
w.vi_scale_mode = 'linear'
w.fft_scale_mode = 'log'

# Defaults applied to exports triggered from the widget.
w.export_default_view = 'all'
w.export_default_format = 'png'
w.export_include_overlays = True
w.export_include_scalebar = True
w

In [None]:
# Scan position and frame shared by the single-image and figure exports.
single_kwargs = dict(position=(18, 23), frame_idx=2)
# Overlay / scale-bar flags shared by the image and sequence exports.
decor_kwargs = dict(include_overlays=True, include_scalebar=True)

img_dp = w.save_image(
    OUT / 'single_dp_overlay.png',
    view='diffraction',
    **single_kwargs,
    **decor_kwargs,
)
img_all_pdf = w.save_image(
    OUT / 'single_all_publication.pdf',
    view='all',
    format='pdf',
    **single_kwargs,
    **decor_kwargs,
)
fig_pub = w.save_figure(
    OUT / 'figure_dp_vi_fft.png',
    template='publication_dp_vi_fft',
    annotations={'diffraction': 'Annular ROI', 'virtual': 'High-info zone', 'fft': 'Lattice'},
    **single_kwargs,
)

# Scan positions visited, in order, by the path-mode sequence export.
path_points = [(6, 6), (10, 12), (15, 18), (20, 24), (24, 28), (30, 33)]
seq_manifest = w.save_sequence(
    OUT / 'path_sequence',
    mode='path',
    view='all',
    format='png',
    frame_idx=1,
    path_points=path_points,
    filename_prefix='path_run',
    **decor_kwargs,
)
report = w.save_reproducibility_report(OUT / 'session_reproducibility.json')

for label, saved in (
    ('Saved:', img_dp),
    ('Saved:', img_all_pdf),
    ('Saved:', fig_pub),
    ('Saved sequence manifest:', seq_manifest),
    ('Saved report:', report),
):
    print(label, saved)

In [None]:
# Load the JSON files written by the previous cell. JSON text is UTF-8
# (RFC 8259), so the encoding is explicit instead of the locale default.
# 'single_dp_overlay.json' is presumably the metadata sidecar written next to
# 'single_dp_overlay.png' by save_image — confirm.
meta = json.loads((OUT / 'single_dp_overlay.json').read_text(encoding='utf-8'))
seq = json.loads((OUT / 'path_sequence' / 'save_sequence_manifest.json').read_text(encoding='utf-8'))
rep = json.loads((OUT / 'session_reproducibility.json').read_text(encoding='utf-8'))

# Selected fields from the export metadata, sequence manifest and report.
summary = {
    'metadata_version': meta['metadata_version'],
    'widget_version': meta['widget_version'],
    'view': meta['view'],
    'format': meta['format'],
    'include_overlays': meta['include_overlays'],
    'include_scalebar': meta['include_scalebar'],
    'sequence_exports': seq['n_exports'],
    'report_exports_logged': rep['n_exports'],
}
summary

In [None]:
# Every regular file under OUT (recursively), sorted by path; the cell shows
# the first 20.
exported_files = sorted(str(p) for p in OUT.rglob('*') if p.is_file())
exported_files[:20]