In [None]:
# Development conveniences for this notebook. Everything here is best-effort:
# the magics run inside a try block, so any exception they raise is swallowed
# and the lines after the failing one are skipped.
try:
    # Load IPython's autoreload extension and select mode 2 so edits to
    # imported modules are picked up without restarting the kernel.
    %load_ext autoreload
    %autoreload 2
    # NOTE(review): presumably turns on anywidget's hot-module-reload for
    # widget front-end development -- confirm against the anywidget docs.
    %env ANYWIDGET_HMR=1
except Exception:
    pass  # autoreload unavailable (Colab Python 3.12+)

In [5]:
import numpy as np
import torch
import time
from quantem.widget import Show3D

## Stress Test: 50 frames x 4000x4000

Large-frame stress test:
- 50 frames x 4000x4000 x 4 bytes = **3.2 GB total data**
- Each frame = 64 MB
- Buffer: 1 frame per chunk (64 MB cap / 64 MB per frame)
- At 30 FPS: each frame needs to transfer in <33 ms (64 MB × 30 frames/s ≈ 1.9 GB/s sustained)

**How to test:**
1. Press play - should start at default 5 FPS
2. Crank FPS slider to 15, then 30
3. Watch for: smooth playback, no freezing, slider tracks position
4. Try zoom/pan on the 4000x4000 canvas

In [6]:
def generate_focal_series(n_frames, size, device="cpu"):
    """Synthesize a through-focus image stack of a blurred hexagonal lattice.

    A lattice of point "atoms" plus a faint ring is built once in Fourier
    space. Fifty key frames with increasing Gaussian blur are computed from
    it, and each output frame is a linear interpolation between the two
    nearest key frames, selected by ``|defocus|``. A radial cosine fringe
    whose amplitude and frequency grow with ``|defocus|`` is added, plus a
    constant offset of 1.0.

    Parameters
    ----------
    n_frames : int
        Number of frames; defocus is sampled linearly from -1 to 1.
        ``0`` yields an empty ``(0, size, size)`` stack.
    size : int
        Height and width of each (square) frame in pixels.
    device : str or torch.device, optional
        Device used for the intermediate FFT work. Defaults to ``"cpu"``.

    Returns
    -------
    torch.Tensor
        Tensor of shape ``(n_frames, size, size)``. It is allocated without a
        device argument, so it lives on the CPU even when ``device`` names an
        accelerator.

    Notes
    -----
    Prints a one-line summary (shape, elapsed seconds, size in GB).
    """
    t0 = time.perf_counter()

    # Squared spatial-frequency magnitude on the FFT grid, shape (size, size).
    fy = torch.fft.fftfreq(size, device=device).unsqueeze(1)
    fx = torch.fft.fftfreq(size, device=device).unsqueeze(0)
    freq_r2 = fy ** 2 + fx ** 2

    spacing = size / 10
    atom_sigma = spacing * 0.06

    # Unit impulses on a hexagonal grid: every other row is shifted by half a
    # spacing and rows are sqrt(3)/2 (~0.866) spacings apart. Sites that fall
    # outside the image are skipped.
    impulses = torch.zeros(size, size, device=device)
    for i in range(-2, 15):
        for j in range(-2, 15):
            cx = i * spacing + (j % 2) * spacing * 0.5
            cy = j * spacing * 0.866
            ix, iy = int(round(cx)), int(round(cy))
            if 0 <= ix < size and 0 <= iy < size:
                impulses[iy, ix] = 1.0

    # Multiplying by a Gaussian in frequency space turns each impulse into a
    # Gaussian blob of width atom_sigma.
    gauss_atom = torch.exp(-2 * np.pi ** 2 * atom_sigma ** 2 * freq_r2)
    lattice_fft = torch.fft.fft2(impulses) * gauss_atom

    # Faint ring centred on the image at radius 0.35 * size.
    y = torch.linspace(0, size - 1, size, device=device)
    x = torch.linspace(0, size - 1, size, device=device)
    yy, xx = torch.meshgrid(y, x, indexing="ij")
    r_center = torch.sqrt((xx - size / 2) ** 2 + (yy - size / 2) ** 2)
    ring = torch.exp(-((r_center - size * 0.35) / (size * 0.01)) ** 2) * 0.3
    lattice_fft = lattice_fft + torch.fft.fft2(ring)

    # Key frames: the same scene under N_KEY blur levels; the blur sigma grows
    # quadratically from 0 up to spacing / 2.
    N_KEY = 50
    key_abs = torch.linspace(0, 1, N_KEY, device=device)
    key_sigmas = (key_abs ** 2 * spacing * 0.5).view(N_KEY, 1, 1)
    key_filters = torch.exp(-2 * np.pi ** 2 * key_sigmas ** 2 * freq_r2.unsqueeze(0))
    key_frames = torch.fft.ifft2(lattice_fft.unsqueeze(0) * key_filters).real

    r_norm_sq = (r_center / size) ** 2

    # Map each frame's |defocus| to a fractional key-frame index. Clamping the
    # lower index to N_KEY - 2 keeps idx_low + 1 in range; at |defocus| == 1
    # frac becomes 1.0, which selects the last key frame exactly.
    defocus = torch.linspace(-1, 1, n_frames, device=device)
    abs_defocus = torch.abs(defocus)
    idx_float = abs_defocus * (N_KEY - 1)
    idx_low = idx_float.long().clamp(max=N_KEY - 2)
    frac = idx_float - idx_low.float()

    # Output buffer is deliberately created without `device` (CPU).
    result = torch.empty(n_frames, size, size)

    # Fixed batch size: the previous min(5, n_frames) produced a zero step for
    # n_frames == 0, which made range() raise ValueError.
    batch_size = 5
    for b0 in range(0, n_frames, batch_size):
        b1 = min(b0 + batch_size, n_frames)
        nb = b1 - b0
        df = defocus[b0:b1]
        bf = frac[b0:b1].view(nb, 1, 1)
        bi = idx_low[b0:b1]
        batch = torch.lerp(key_frames[bi], key_frames[bi + 1], bf)
        fringe_amp = torch.abs(df).view(nb, 1, 1) * 0.3
        fringe_k = (1.0 + torch.abs(df) * 30.0).view(nb, 1, 1)
        result[b0:b1] = batch + fringe_amp * torch.cos(fringe_k * r_norm_sq.unsqueeze(0) * 2 * np.pi) + 1.0

    elapsed = time.perf_counter() - t0
    # Use the tensor's real element size instead of assuming 4-byte floats.
    n_bytes = result.nelement() * result.element_size()
    print(f"Generated {n_frames}x{size}x{size} in {elapsed:.1f}s ({n_bytes / 1e9:.1f} GB)")
    return result

In [None]:
# Build the stress-test stack: 50 frames of 4000x4000 pixels (about 3.2 GB at 4 bytes per pixel).
data = generate_focal_series(n_frames=50, size=4000)