In [None]:
# Dev convenience (best-effort): auto-reload edited Python modules on each cell run
# and set ANYWIDGET_HMR=1 (presumably anywidget's frontend hot-reload switch — confirm).
# Because these run in order, a failure in %load_ext also skips the later lines.
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    # Deliberately swallowed so the notebook still runs where autoreload is missing
    pass  # autoreload unavailable (Colab Python 3.12+)

In [None]:
import numpy as np
import torch
import time
import quantem.widget
from quantem.widget import Show3D
# Record which quantem.widget build produced the outputs below
print("quantem.widget", quantem.widget.__version__)

## Generate synthetic focal series (crystal lattice with evolving defocus)
Through-focal series of a hexagonal crystal lattice using FFT-based convolution:
- Sharp atomic columns at focus, heavily blurred away from focus
- Fresnel-like concentric fringes that grow with defocus magnitude
- Nanoparticle boundary ring for additional visual interest
- Noise-free frames: `generate_focal_series` currently adds no Poisson noise (a light, high-dose level could be added later without masking the defocus effects)

In [None]:
def generate_focal_series(n_frames, size, device="cpu"):
    """Generate a synthetic through-focal series of a hexagonal crystal lattice.

    The lattice (Gaussian atoms on a hexagonal grid) plus a faint circular
    "nanoparticle" ring is built once in Fourier space. Defocus blur is applied
    by precomputing ``N_KEY`` Gaussian-blurred key frames in one batched FFT and
    linearly interpolating between the two key frames bracketing each output
    frame's ``|defocus|``, so the FFT cost does not grow with ``n_frames``.
    Concentric cosine "Fresnel-like" fringes, whose amplitude and frequency grow
    with ``|defocus|``, are then added together with a constant offset of 1.0.
    No noise is added.

    Args:
        n_frames: Number of output frames; defocus runs linearly from -1 to +1.
            ``0`` returns an empty ``(0, size, size)`` tensor.
        size: Edge length in pixels of each square frame.
        device: Torch device used for the FFT / interpolation work.

    Returns:
        Tensor of shape ``(n_frames, size, size)`` allocated with torch's default
        device and dtype (CPU float32 unless those defaults were changed). Blur
        and fringes depend only on ``|defocus|``, so frames at +d and -d match.
    """
    t0 = time.perf_counter()  # monotonic clock; time.time() can jump
    # Frequency grids (cycles/pixel) used to apply Gaussian blurs in Fourier space
    fy = torch.fft.fftfreq(size, device=device).unsqueeze(1)
    fx = torch.fft.fftfreq(size, device=device).unsqueeze(0)
    freq_r2 = fy ** 2 + fx ** 2
    # Build lattice via FFT: unit impulses at atom sites convolved with a Gaussian
    spacing = size / 10
    atom_sigma = spacing * 0.06
    impulses = torch.zeros(size, size, device=device)
    # The index range deliberately overshoots the frame; off-frame sites are dropped
    for i in range(-2, 15):
        for j in range(-2, 15):
            # Odd rows shift by half a spacing and rows are ~sqrt(3)/2 apart -> hexagonal
            cx = i * spacing + (j % 2) * spacing * 0.5
            cy = j * spacing * 0.866
            ix, iy = int(round(cx)), int(round(cy))
            if 0 <= ix < size and 0 <= iy < size:
                impulses[iy, ix] = 1.0
    # Fourier transform of a Gaussian with std sigma: exp(-2 pi^2 sigma^2 f^2)
    gauss_atom = torch.exp(-2 * np.pi ** 2 * atom_sigma ** 2 * freq_r2)
    lattice_fft = torch.fft.fft2(impulses) * gauss_atom
    # Nanoparticle boundary ring: thin Gaussian annulus of radius 0.35 * size
    y = torch.linspace(0, size - 1, size, device=device)
    x = torch.linspace(0, size - 1, size, device=device)
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    r_center = torch.sqrt((xx - size / 2) ** 2 + (yy - size / 2) ** 2)
    ring = torch.exp(-((r_center - size * 0.35) / (size * 0.01)) ** 2) * 0.3
    lattice_fft = lattice_fft + torch.fft.fft2(ring)
    # Pre-compute N_KEY blurred key frames in one batched inverse FFT.
    # Blur sigma grows quadratically with |defocus|: 0 in focus, spacing / 2 at |defocus| = 1.
    N_KEY = 50
    key_abs = torch.linspace(0, 1, N_KEY, device=device)
    key_sigmas = (key_abs ** 2 * spacing * 0.5).view(N_KEY, 1, 1)
    key_filters = torch.exp(-2 * np.pi ** 2 * key_sigmas ** 2 * freq_r2.unsqueeze(0))
    key_frames = torch.fft.ifft2(lattice_fft.unsqueeze(0) * key_filters).real
    # Normalized squared radius drives the fringe phase
    r_norm_sq = (r_center / size) ** 2
    # Map each output frame to its two bracketing key frames via |defocus|
    defocus = torch.linspace(-1, 1, n_frames, device=device)
    abs_defocus = torch.abs(defocus)
    idx_float = abs_defocus * (N_KEY - 1)
    # Clamp keeps idx_low + 1 in range; at |defocus| == 1, frac == 1 picks the last key
    idx_low = idx_float.long().clamp(max=N_KEY - 2)
    frac = idx_float - idx_low.float()
    # Pre-allocate output on the default device (avoids a final torch.cat copy);
    # each batch computed on `device` is copied into it.
    result = torch.empty(n_frames, size, size)
    # Keep the step >= 1: range() rejects a zero step, which n_frames == 0 would produce
    batch_size = max(1, min(200, n_frames))
    for b0 in range(0, n_frames, batch_size):
        b1 = min(b0 + batch_size, n_frames)
        nb = b1 - b0
        df = defocus[b0:b1]
        bf = frac[b0:b1].view(nb, 1, 1)
        bi = idx_low[b0:b1]
        # Interpolate between adjacent key frames (no per-frame FFT)
        batch = torch.lerp(key_frames[bi], key_frames[bi + 1], bf)
        # Fresnel-like concentric fringes: amplitude and radial frequency grow with |defocus|
        fringe_amp = torch.abs(df).view(nb, 1, 1) * 0.3
        fringe_k = (1.0 + torch.abs(df) * 30.0).view(nb, 1, 1)
        result[b0:b1] = batch + fringe_amp * torch.cos(fringe_k * r_norm_sq.unsqueeze(0) * 2 * np.pi) + 1.0
    elapsed = time.perf_counter() - t0
    # Report the real footprint rather than assuming 4-byte float32 elements
    size_mb = result.nelement() * result.element_size() / 1e6
    dtype_name = str(result.dtype).replace("torch.", "")
    print(f"Generated {n_frames}×{size}×{size} focal series in {elapsed:.1f}s "
          f"({size_mb:.0f} MB {dtype_name})")
    return result

## Stress Test 1: 1000 frames × 250×250

In [None]:
# Stress test 1 input: many small frames
data_small = generate_focal_series(n_frames=1000, size=250)

Generated 1000×250×250 focal series in 0.4s (250 MB float32)


In [None]:
# Time widget construction (perf_counter is monotonic) and derive the frame
# count, labels and title from the data so they cannot drift out of sync.
t0 = time.perf_counter()
n_frames, height, width = data_small.shape
# Defocus labels span -500..+500 nm, matching the generator's -1..+1 defocus ramp
labels = [f"C10={d:.0f}nm" for d in np.linspace(-500, 500, n_frames)]
w1 = Show3D(data_small, title=f"Focal Series: {n_frames}×{height}×{width}", fps=30, labels=labels, pixel_size=0.2)
print(f"Widget created in {time.perf_counter() - t0:.2f}s")
w1

Widget created in 0.41s


<quantem.widget.show3d.Show3D object at 0x30ef9d6a0>

## Stress Test 2: 250 frames × 1000×1000 (buffer prefetch test)
This is the target scenario for the sliding buffer optimization.
- 250 frames × 1000×1000 × 4 bytes = 1 GB total data
- Buffer auto-capped at 64 MB → 16 frames per chunk
- At 30 FPS: ~0.53s of buffer, prefetch at 50% (~0.27s)
**How to test:** Press play (the widget starts at 30 FPS), then raise the FPS slider to 60. Playback should stay smooth with no stuttering at both rates.

In [None]:
# Stress test 2 input: few but large frames (~1 GB total)
data_large = generate_focal_series(n_frames=250, size=1000)

In [None]:
# Time widget construction (perf_counter is monotonic) and derive the frame
# count, labels, title and buffer footprint from the data instead of literals.
t0 = time.perf_counter()
n_frames, height, width = data_large.shape
# Defocus labels span -500..+500 nm, matching the generator's -1..+1 defocus ramp
labels = [f"C10={d:.0f}nm" for d in np.linspace(-500, 500, n_frames)]
w2 = Show3D(data_large, title=f"Buffer Test: {n_frames}×{height}×{width}", fps=30, labels=labels, pixel_size=0.1)
print(f"Widget created in {time.perf_counter() - t0:.2f}s")
# NOTE(review): assumes Show3D buffers frames in the input dtype (float32 here) — confirm
frame_mb = height * width * data_large.element_size() / 1e6
print(f"Buffer size: {w2._buffer_size} frames ({w2._buffer_size * frame_mb:.0f} MB)")
w2