In [None]:
# Development conveniences:
#   - `%autoreload 2` reloads edited modules (e.g. quantem.widget) before each cell runs.
#   - ANYWIDGET_HMR=1 is exported to the environment. This presumably enables anywidget
#     hot module reloading of the widget frontend; confirm against the anywidget docs.
# NOTE: IPython rewrites these magic lines before execution. Outside IPython they are
# a SyntaxError, which this try/except cannot catch.
try:
    %load_ext autoreload
    %autoreload 2
    %env ANYWIDGET_HMR=1
except Exception:
    pass  # best-effort: autoreload unavailable (e.g. Colab Python 3.12+); the notebook still runs without it

In [2]:
import numpy as np
import torch
import time
from quantem.widget import Show3D

## Stress Test: 1000 frames × 250×250

Through-focal series of a hexagonal crystal lattice.
- 1000 frames × 250×250 × 4 bytes = 250 MB total
- Buffer: 64 frames per chunk (64 × 0.25 MB = 16 MB, well under the 128 MB cap)
- At 60 FPS the buffer lasts ~1 s (~2 s at the 30 FPS used below), leaving plenty of prefetch headroom

**Expected:** Smooth, stutter-free playback at the 30 FPS configured below and when raised to 60 FPS.

In [3]:
def generate_focal_series(n_frames, size, device="cpu", n_key=50, batch_size=200):
    """Generate a synthetic through-focal image series of a hexagonal lattice.

    A lattice of Gaussian "atoms" plus a faint circular boundary ring is built
    in Fourier space. ``n_key`` key frames with increasing Gaussian blur are
    computed with one inverse FFT each. Every output frame then linearly
    interpolates between the two key frames that bracket its |defocus|, and
    adds radial cosine fringes whose amplitude and frequency grow with |defocus|.
    Output frames are produced in batches of at most ``batch_size``.

    Args:
        n_frames: Number of output frames. Defocus runs linearly from -1 to 1.
            ``0`` returns an empty stack.
        size: Height and width of each (square) frame, in pixels.
        device: Torch device used for the intermediate computation.
        n_key: Number of pre-computed blur key frames (must be >= 2).
        batch_size: Maximum number of output frames computed per batch (>= 1).

    Returns:
        A float32 tensor of shape ``(n_frames, size, size)``, allocated on the
        default device (normally CPU) regardless of ``device``. Frames at
        defocus ``d`` and ``-d`` are identical, because only ``|d|`` is used.

    Raises:
        ValueError: If ``n_key < 2`` or ``batch_size < 1``.
    """
    if n_key < 2:
        # Interpolation needs a lower and an upper key frame (idx_low + 1).
        raise ValueError(f"n_key must be >= 2, got {n_key}")
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    t0 = time.perf_counter()

    # Frequency grids (cycles/pixel), broadcasting to (size, size)
    fy = torch.fft.fftfreq(size, device=device).unsqueeze(1)
    fx = torch.fft.fftfreq(size, device=device).unsqueeze(0)
    freq_r2 = fy ** 2 + fx ** 2

    # Build the lattice via FFT: unit impulses at hexagonal sites (odd rows
    # shifted by half a spacing, row pitch ~spacing*sqrt(3)/2), convolved with
    # a Gaussian atom profile by multiplying with its Fourier transform.
    # The FFT convolution is circular, so atoms at an edge wrap to the opposite edge.
    spacing = size / 10
    atom_sigma = spacing * 0.06
    impulses = torch.zeros(size, size, device=device)
    for i in range(-2, 15):
        for j in range(-2, 15):
            cx = i * spacing + (j % 2) * spacing * 0.5
            cy = j * spacing * 0.866
            ix, iy = int(round(cx)), int(round(cy))
            if 0 <= ix < size and 0 <= iy < size:
                impulses[iy, ix] = 1.0
    gauss_atom = torch.exp(-2 * np.pi ** 2 * atom_sigma ** 2 * freq_r2)
    lattice_fft = torch.fft.fft2(impulses) * gauss_atom

    # Nanoparticle boundary ring (added in Fourier space so it is blurred too)
    y = torch.linspace(0, size - 1, size, device=device)
    x = torch.linspace(0, size - 1, size, device=device)
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    r_center = torch.sqrt((xx - size / 2) ** 2 + (yy - size / 2) ** 2)
    ring = torch.exp(-((r_center - size * 0.35) / (size * 0.01)) ** 2) * 0.3
    lattice_fft = lattice_fft + torch.fft.fft2(ring)

    # Pre-compute n_key blur key frames (n_key inverse FFTs instead of n_frames).
    # Blur sigma grows quadratically with |defocus|, from 0 up to spacing/2.
    key_abs = torch.linspace(0, 1, n_key, device=device)
    key_sigmas = (key_abs ** 2 * spacing * 0.5).view(n_key, 1, 1)
    key_filters = torch.exp(-2 * np.pi ** 2 * key_sigmas ** 2 * freq_r2.unsqueeze(0))
    key_frames = torch.fft.ifft2(lattice_fft.unsqueeze(0) * key_filters).real

    # Precompute Fresnel radial grid
    r_norm_sq = (r_center / size) ** 2

    # Map each output frame to its two bracketing key frames via |defocus|.
    # The clamp keeps idx_low + 1 in range at |defocus| == 1 (frac becomes 1 there).
    defocus = torch.linspace(-1, 1, n_frames, device=device)
    abs_defocus = torch.abs(defocus)
    idx_float = abs_defocus * (n_key - 1)
    idx_low = idx_float.long().clamp(max=n_key - 2)
    frac = idx_float - idx_low.float()

    result = torch.empty(n_frames, size, size, dtype=torch.float32)
    # max(1, ...) keeps range() valid when n_frames == 0 (a step of 0 raises ValueError).
    step = max(1, min(batch_size, n_frames))

    for b0 in range(0, n_frames, step):
        b1 = min(b0 + step, n_frames)
        nb = b1 - b0
        df = defocus[b0:b1]
        bf = frac[b0:b1].view(nb, 1, 1)
        bi = idx_low[b0:b1]

        batch = torch.lerp(key_frames[bi], key_frames[bi + 1], bf)

        fringe_amp = torch.abs(df).view(nb, 1, 1) * 0.3
        fringe_k = (1.0 + torch.abs(df) * 30.0).view(nb, 1, 1)
        # Assignment copies the batch from `device` into the result tensor.
        result[b0:b1] = batch + fringe_amp * torch.cos(fringe_k * r_norm_sq.unsqueeze(0) * 2 * np.pi) + 1.0

    elapsed = time.perf_counter() - t0
    print(f"Generated {n_frames}x{size}x{size} focal series in {elapsed:.1f}s "
          f"({result.nelement() * result.element_size() / 1e6:.0f} MB float32)")
    return result

In [4]:
# Stress-test stack: 1000 frames of 250x250 pixels (~250 MB as float32).
data = generate_focal_series(n_frames=1000, size=250)

Generated 1000x250x250 focal series in 0.2s (250 MB float32)


In [5]:
# Derive the dimensions from the generated stack so the labels, title and
# buffer report stay in sync with whatever size `data` actually has.
n_frames, height, width = data.shape
bytes_per_frame = height * width * data.element_size()

t0 = time.perf_counter()
# One defocus label per frame, spanning -500..500 nm (the generator's defocus
# runs over linspace(-1, 1, n_frames)). int(round(...)) prints whole nm exactly
# like the ".0f" format, but never produces a "-0" label for values just below zero.
labels = [f"C10={int(round(float(d)))}nm" for d in np.linspace(-500, 500, n_frames)]
w = Show3D(
    data,
    title=f"Stress Test: {n_frames} x {height} x {width}",
    fps=30,
    labels=labels,
    pixel_size=0.2,
)
print(f"Widget created in {time.perf_counter() - t0:.2f}s")
# NOTE(review): `_buffer_size` is a private Show3D attribute and may change between versions.
print(f"Buffer: {w._buffer_size} frames ({w._buffer_size * bytes_per_frame / 1e6:.0f} MB)")
w

Widget created in 0.05s
Buffer: 64 frames (16 MB)


<quantem.widget.show3d.Show3D object at 0x31380ea50>