In [None]:
from typing import Union, Sequence
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import matplotlib.patches as patches
from omegaconf import OmegaConf
import hydra
from hydra import initialize_config_dir, compose
import time
import imageio

import numpy as np
import cv2
import torch
import torch.nn.functional as F

import spnf
from spnf.utils import set_seed, make_coord_grid, apply_to_tensors, to_py, interpolate_covariance_matrices_numpy
from spnf.sample import (
    rand_ortho, logrand, construct_covariance, sample_gaussian_delta, sample_ellipsoid_delta
)
from spnf.trainer import Trainer

# Directory with the Hydra configs that live next to the `spnf` package sources.
config_dir = Path(spnf.__file__).parent.joinpath("configs")
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_cfg(config_name, config_dir: Union[str, Path], overrides: Union[Sequence[str], None] = None):
    """Compose the Hydra config ``config_name`` found in ``config_dir``.

    Args:
        config_name: Name of the primary config to compose (e.g. ``"train"``).
        config_dir: Directory containing the config files. A ``str`` or a
            relative path is accepted; it is made absolute before being handed
            to Hydra, because ``initialize_config_dir`` rejects relative paths.
        overrides: Optional Hydra override strings (``"key=value"``).
            ``None`` means "no overrides".

    Returns:
        The config object produced by ``hydra.compose``.
    """
    abs_config_dir = Path(config_dir).expanduser().resolve()
    override_list = [] if overrides is None else list(overrides)
    with initialize_config_dir(version_base=None, config_dir=str(abs_config_dir)):
        return compose(config_name=config_name, overrides=override_list)

def render_kernel(kind, cov_mat, interval=(-3.0, 3.0), res=1024, a=1.0, colormap=None):
    """Rasterise a 2-D filter kernel shaped by ``cov_mat`` into a 3-channel uint8 image.

    The kernel is evaluated on a ``res x res`` grid covering ``interval`` on
    both axes, as a function of the quadratic form ``q = x^T cov_mat^{-1} x``.

    Args:
        kind: ``'gaussian'``, ``'uniform_ellipsoid'`` or ``'lanczos'``
            (case-insensitive).
        cov_mat: 2x2 tensor; it is inverted and multiplied with (N, 2) grid
            points. The grid is created on its device and with its dtype.
        interval: ``(lo, hi)`` extent sampled along each axis.
        res: Number of samples per axis.
        a: Lanczos window parameter; only used when ``kind == 'lanczos'``.
        colormap: Optional OpenCV colormap id passed to ``cv2.applyColorMap``.
            When ``None`` the grayscale image is replicated to three channels.

    Returns:
        ``numpy.ndarray`` of dtype uint8 produced by OpenCV (3-channel BGR).

    Raises:
        ValueError: If ``kind`` is not one of the supported names.
    """
    kind_key = kind.lower()
    # Validate first so a typo fails immediately instead of after the tensor work.
    if kind_key not in ('gaussian', 'uniform_ellipsoid', 'lanczos'):
        raise ValueError("kind must be 'gaussian', 'uniform_ellipsoid', or 'lanczos'")

    device = cov_mat.device
    lo, hi = float(interval[0]), float(interval[1])
    # Match cov_mat's dtype so the matmul below also works for float64 input.
    axis = torch.linspace(lo, hi, res, device=device, dtype=cov_mat.dtype)
    grid = torch.stack(torch.meshgrid(axis, axis, indexing='xy'), dim=-1).reshape(-1, 2)  # (res*res, 2)

    inv_cov = torch.linalg.inv(cov_mat)
    # Quadratic form x^T Sigma^{-1} x. Clamp at zero: rounding can push values
    # near the centre slightly negative, and sqrt() would then produce NaNs
    # (which would poison the max-normalisation in the Lanczos branch).
    q = ((grid @ inv_cov) * grid).sum(-1).clamp(min=0.0)

    if kind_key == 'gaussian':
        ker01 = torch.exp(-0.5 * q).clamp(0, 1)                       # [0,1]
    elif kind_key == 'uniform_ellipsoid':
        ker01 = (q.sqrt() < 1).to(grid.dtype)                         # {0,1}
    else:  # 'lanczos'
        t = q.sqrt()
        # torch.sinc(x) = sin(pi*x)/(pi*x)
        ker = torch.sinc(t) * torch.sinc(t / a)                       # can be [-1,1]
        # NOTE(review): a textbook Lanczos kernel is zero for t >= a; the side
        # lobes are kept here as in the original - confirm this is intended.
        # The 2.5 gain brightens the lobes; the magnitude is then shown, saturated at 1.
        ker = ker / torch.max(ker.abs()) * 2.5
        ker01 = ker.abs().clamp(0, 1)

    img_u8 = (ker01.reshape(res, res) * 255).round().to(torch.uint8).cpu().numpy()

    if colormap is not None:
        vis = cv2.applyColorMap(img_u8, colormap)                     # 3ch BGR
    else:
        vis = cv2.cvtColor(img_u8, cv2.COLOR_GRAY2BGR)                # grayscale 3ch

    return vis

## Training visualization for neural fields

In [None]:
filter_type = "gaussian" # Options: 'gaussian', 'uniform_ellipsoid', 'lanczos'
lambda_decay_start = 500
vis_steps = 800      # steps trained while recording the visualisation
total_steps = 2000   # full training budget (also passed to trainer.steps)
SRC_RES = 512
FPS = 30

# The visualisation phase is a prefix of the full schedule.
assert 0 < vis_steps <= total_steps, "vis_steps must be in (0, total_steps]"

# Root folder for experiment outputs, derived from the current user's home
# directory instead of a hard-coded absolute path to one specific account.
OUTPUT_ROOT = Path.home() / "Data" / "Experiments" / "spnf"

cfg = get_cfg(
    config_name="train",
    config_dir=config_dir,
    overrides=[
        f"data.grid.resize={SRC_RES}",
        "data.bounds=1.0", 
        "tensorboard=null",
        f"trainer.steps={total_steps}",
        # `${task_name}` is deliberately left unexpanded here; it is an
        # interpolation resolved by the config system, not by Python.
        "paths.output_dir=" + OUTPUT_ROOT.as_posix() + "/${task_name}",
        f"scheduler.lr_lambda.decay_start={lambda_decay_start}",
        "trainer.compile_train_step=True",
    ],
)

# Encoder settings for the final videos (H.265, slower preset, lower CRF).
video_settings = dict(
    fps=FPS,
    codec='libx265', 
    pixelformat='yuv420p',
    ffmpeg_params = [
        '-crf', '18', '-preset', 'slow', 
        '-tag:v', 'hvc1', '-tune', 'grain'
    ]
)
# Same as above but with a faster preset / higher CRF for quick iteration.
dev_video_settings = {
    **video_settings,
    "ffmpeg_params": [
        '-crf', '20', 
        '-preset', 'fast',
        '-tag:v', 'hvc1',
        '-tune', 'grain',
    ],
}

# Create output directories
output_dir = Path(cfg.paths.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = Path(cfg.trainer.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
video_output_dir = output_dir / "visualizations" 
video_output_dir.mkdir(parents=True, exist_ok=True)

# Save the fully resolved config next to the outputs for reproducibility
OmegaConf.save(config=cfg, f=output_dir / "config_multiscale.yaml", resolve=True)

In [None]:
# Build the trainer from the composed config and move it to the target device.
trainer = Trainer(cfg).to(device)

# Write the ground-truth signal to disk as an 8-bit reference image.
gt_signal = trainer.data.grid.to(device)
gt_signal = (
    gt_signal[0]
    .add(1.0).div(2.0)              # (x + 1) / 2: maps -1 -> 0 and +1 -> 1
    .mul(255).round().clamp(0, 255)
    .permute(1, 2, 0)               # move axis 0 last (presumably CHW -> HWC)
    .detach().cpu().numpy().astype(np.uint8)
)
imageio.imwrite(
    output_dir / "visualizations" / "ground_truth.png",
    gt_signal,
)

# Covariance keyframes, interpolated over `repeat_frames` frames.
# Each row holds three parameters per keyframe; their meaning is defined by
# interpolate_covariance_matrices_numpy (not visible in this notebook).
repeat_frames = FPS * 3
times = np.array([0.0, 0.125, 0.25, 0.375, 0.5, 0.75, 0.85, 1.0])
keyframe_indices = np.round(times * (repeat_frames-1)).astype(np.int32)
keyframe_params = np.array([
    [0, -2, -7], [0.5, -4, -4], [0.0, -1, -1], [0.0, -3, -2], [0.5, -3, -2], 
    [0, -1, -1], [1.0, -4, -4], [1.0, -2, -7]
])
vis_covs, vis_eigvals, vis_eigvecs = interpolate_covariance_matrices_numpy(
    keyframe_indices,
    keyframe_params,
)
vis_covs, vis_eigvals, vis_eigvecs = (
    torch.tensor(arr, device=device, dtype=torch.float32)
    for arr in (vis_covs, vis_eigvals, vis_eigvecs)
)

# Crop (size and top-left corner) used later for the zoomed-in patch video.
VIS_PATCH_RES = 64
vis_patch_x, vis_patch_y = 190, 120

# Per-frame records collected while training.
loss_average = []
predictions, kernel_images, train_images = [], [], []
steps_history = []  # actual training step of every recorded frame (X-axis)

current_step = 0
frame_count = 0
noise_steps = 5  # number of consecutive frames that share one covariance
pbar = tqdm(total=vis_steps, desc="Multiscale Training")


def _to_unit_numpy(signal):
    """Clamp to [-1, 1], rescale to [0, 1] and return a NumPy array on the CPU."""
    return signal.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()


while current_step < vis_steps:
    # Record every step for the first 200 steps, every second step afterwards,
    # and never train past vis_steps.
    stride = 1 if current_step < 200 else 2
    stride = min(stride, vis_steps - current_step)

    # Train for `stride` steps and log the mean loss over them.
    stats = trainer.fit(num_steps=stride, no_tqdm=True)
    current_step += stride
    loss_average.append(np.mean(stats['mse_loss']))
    steps_history.append(current_step)

    # Covariance for this frame: advance every `noise_steps` frames, wrap around.
    cov_slot = frame_count // noise_steps
    current_cov = vis_covs[cov_slot % vis_covs.shape[0]]
    current_eigvals = vis_eigvals[cov_slot % vis_eigvals.shape[0]]
    current_eigvecs = vis_eigvecs[cov_slot % vis_eigvecs.shape[0]]

    # Model prediction on the full grid under the current covariance.
    output = trainer.generate_grid_data(
        resolution=SRC_RES,
        covariances=current_cov
    )
    predictions.append(_to_unit_numpy(output["filtered_signal"]))

    # Picture of the filter kernel that corresponds to this covariance.
    kernel_images.append(
        render_kernel(kind=filter_type, cov_mat=current_cov, res=SRC_RES)
    )

    # An example training target for the same coordinates, for visualization.
    num_points = SRC_RES*SRC_RES
    train_batch = trainer.generate_data_train(
        coords=output["coords"],
        eigenvalues=current_eigvals.unsqueeze(0).repeat(num_points, 1),
        eigenvectors=current_eigvecs.unsqueeze(0).repeat(num_points, 1, 1),
    )
    train_images.append(_to_unit_numpy(train_batch["gt_signal"]))

    frame_count += 1
    pbar.update(stride)

pbar.close()

# --- VIDEO GENERATION ---
# Renders four videos from the per-frame records collected during training:
# the prediction with its loss curve, a zoomed patch, the kernel image and the
# example training image.
video_output = output_dir / "visualizations" / "multiscale_train_vis.mov"
video_patch_output = output_dir / "visualizations" / "multiscale_train_vis_patch.mov"
video_kernel_output = output_dir / "visualizations" / "multiscale_train_vis_kernel.mov"
video_data_output = output_dir / "visualizations" / "multiscale_train_vis_data.mov"

# Settings
plt.style.use('default') 
max_loss = max(loss_average) * 1.1  # 10% headroom above the largest loss

# 1. Geometry Calculation
# Accept either channel-first (C, H, W) or channel-last / 2-D predictions.
sample = predictions[0]
if sample.ndim == 3 and sample.shape[0] in [1, 3]:
    h, w = sample.shape[1], sample.shape[2]
else:
    h, w = sample.shape[0], sample.shape[1]

fig_width = 12 
img_height_in = fig_width * (h / w)
loss_height_in = 4.0 
total_height = img_height_in + loss_height_in

# Relative height ratios for manual axes placement
h_img_ratio = img_height_in / total_height
h_loss_ratio = loss_height_in / total_height

# Axes rectangles are the same for every frame, so compute them once.
margin_bottom = 0.08
h_loss_actual = h_loss_ratio - margin_bottom - 0.02
margin_top = 0.02
y_img_start = h_loss_ratio + 0.02
h_img_actual = h_img_ratio - margin_top - 0.02

# Optional slow motion: frames whose index falls in one of the half-open
# (start, stop) ranges are written `repeat_times` times. Empty = disabled.
# Example: repeat_ranges = [(30, 45), (60, 75), (230, 245)]
repeat_ranges, repeat_times = [], 10
slow_frame_ids = {idx for start, stop in repeat_ranges for idx in range(start, stop)}

# Track every writer that was opened so all of them are finalised even if
# rendering fails half-way (otherwise the files are left truncated/locked).
opened_writers = []
try:
    writer = imageio.get_writer(video_output, **video_settings)
    opened_writers.append(writer)
    writer_patch = imageio.get_writer(video_patch_output, **video_settings)
    opened_writers.append(writer_patch)
    writer_kernel = imageio.get_writer(video_kernel_output, **video_settings)
    opened_writers.append(writer_kernel)
    writer_data = imageio.get_writer(video_data_output, **video_settings)
    opened_writers.append(writer_data)

    print("Rendering video frames...")
    # Zip with steps_history to get the correct X-axis value
    for i, (loss, pred, step_num) in enumerate(tqdm(zip(loss_average, predictions, steps_history), total=len(predictions))):

        # Figure without subplots/layout engines so the axes can be placed manually
        fig = plt.figure(figsize=(fig_width, total_height), dpi=100)
        try:
            ax_loss = fig.add_axes([0.12, margin_bottom, 0.82, h_loss_actual])
            ax_img = fig.add_axes([0.02, y_img_start, 0.96, h_img_actual])

            # --- 1. Image Plot ---
            if pred.ndim == 3 and pred.shape[0] in [1, 3]: 
                pred = np.transpose(pred, (1, 2, 0))

            ax_img.imshow(pred, aspect='auto') 

            # Hide all ticks/spines for the image
            ax_img.set_xticks([])
            ax_img.set_yticks([])
            for spine in ax_img.spines.values():
                spine.set_visible(False) 

            # --- 2. Loss Plot ---
            # Loss history up to this frame, plotted against real training steps
            ax_loss.plot(steps_history[:i+1], loss_average[:i+1], color='tab:blue', linewidth=3)
            # Marker for the current frame
            ax_loss.scatter(step_num, loss, color='tab:red', s=100, zorder=5) 

            # Fixed limits so the axes do not rescale from frame to frame
            ax_loss.set_xlim(0, vis_steps)
            ax_loss.set_ylim(0, max_loss)

            ax_loss.set_ylabel('MSE Loss', fontsize=24, labelpad=15, fontweight='medium')
            ax_loss.set_xlabel('Steps', fontsize=18)
            ax_loss.tick_params(axis='both', which='major', labelsize=16)
            ax_loss.grid(True, linestyle='--', alpha=0.5, linewidth=1.5)

            # --- 3. Grab Frame ---
            fig.canvas.draw()
            # np.array copies the canvas buffer, so the figure can be closed right away
            frame = np.array(fig.canvas.buffer_rgba())
        finally:
            plt.close(fig)

        patch = pred[vis_patch_y:vis_patch_y+VIS_PATCH_RES, vis_patch_x:vis_patch_x+VIS_PATCH_RES]
        patch_u8 = (patch * 255).astype(np.uint8)
        data_u8 = (train_images[i] * 255).astype(np.uint8)

        copies = repeat_times if i in slow_frame_ids else 1
        for _copy in range(copies):
            writer.append_data(frame)
            writer_patch.append_data(patch_u8)
            writer_kernel.append_data(kernel_images[i])
            writer_data.append_data(data_u8)
finally:
    for opened_writer in opened_writers:
        opened_writer.close()

print(f"Videos are saved: {video_output}, {video_patch_output}")

# Bring the final prediction into channel-last layout once, using the same
# check as the video loop, so that the crop below indexes rows/columns and
# matches the patch video (cropping the raw array would slice the wrong axes
# for channel-first data).
final_pred = predictions[-1]
if final_pred.ndim == 3 and final_pred.shape[0] in [1, 3]: 
    final_pred = np.transpose(final_pred, (1, 2, 0))

# Save last frame patch as PNG
last_frame_output = output_dir / "visualizations" / "multiscale_train_vis_patch.png"
patch = final_pred[vis_patch_y:vis_patch_y+VIS_PATCH_RES, vis_patch_x:vis_patch_x+VIS_PATCH_RES]
imageio.imwrite(
    last_frame_output, 
    (patch * 255).astype(np.uint8)
)

# Save the last training image as PNG
last_frame_data_output = output_dir / "visualizations" / "multiscale_train_vis_data.png"
imageio.imwrite(
    last_frame_data_output, 
    (train_images[-1] * 255).astype(np.uint8)
)

# Save the last predicted image as PNG
last_frame_image_output = output_dir / "visualizations" / "multiscale_train_vis_image.png"
imageio.imwrite(
    last_frame_image_output, 
    (final_pred * 255).astype(np.uint8)
)

# Save the last kernel image as PNG (already uint8, written as-is)
last_frame_kernel_output = output_dir / "visualizations" / "multiscale_train_vis_kernel.png"
imageio.imwrite(
    last_frame_kernel_output, 
    kernel_images[-1]
)

print(f"Last frame image saved: {last_frame_output}")

In [None]:
# Spend the rest of the training budget: total_steps minus the vis_steps
# that the visualisation phase accounts for.
remaining_steps = total_steps - vis_steps
trainer.fit(num_steps=remaining_steps)

# Persist the trainer state so later cells can reload it instead of retraining.
checkpoint_path = checkpoint_dir / "multiscale.pth"
state_dict = trainer.state_dict()
torch.save(state_dict, str(checkpoint_path))

## Visualize changing covariance matrices vs predicted image

In [None]:
checkpoint_path = checkpoint_dir / "multiscale.pth"

# Always start from a freshly built trainer so this cell works on its own and
# does not rely on a `trainer` variable left behind by an earlier cell.
trainer = Trainer(cfg).to(device)

if checkpoint_path.exists():
    # Load the state dict.
    # NOTE: weights_only=False unpickles arbitrary Python objects - only load
    # checkpoints produced by this notebook or another trusted source.
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
    trainer.load_state_dict(state_dict)
    print(f"Loaded model from {checkpoint_path}")
else:
    # No checkpoint yet: fit the model with the trainer's own defaults ...
    trainer.fit()

    # ... and save the state dict for the next run
    state_dict = trainer.state_dict()
    torch.save(state_dict, str(checkpoint_path))

In [None]:
FPS = 30
SECONDS = 20
FRAMES = FPS * SECONDS

# Covariance keyframes placed at the given fractions of the clip and
# interpolated to one covariance matrix per frame. The three numbers per row
# are interpreted by interpolate_covariance_matrices_numpy (not visible here).
times = np.array([0.  , 0.05, 0.2, 0.35, 0.46, 0.52, 0.64, 0.7 , 0.82, 1.  ])
# Only the covariances are needed; named placeholders are used instead of `_`
# because `_` is IPython's "last output" variable.
vis_covs, _unused_eigvals, _unused_eigvecs = interpolate_covariance_matrices_numpy(
    np.round(times * (FRAMES-1)).astype(np.int32),
    np.array([
        [0, -7, -7], [0, -4, -4], [0, -1, -1], [0, -4, -1.5], [0.125, -4, -1.5], 
        [0.125, -3, -2], [0.25, -3, -2], [0.25, -4, -3], [0.5, -4, -3], [1, -7, -7],
    ])
)
vis_covs = torch.tensor(vis_covs, device=device, dtype=torch.float32)

video_output = output_dir / "visualizations" / "multiscale_filtering.mov"
kernel_output = output_dir / "visualizations" / "multiscale_filtering_kernels.mov"

# Track opened writers so both files are finalised even if a frame fails.
opened_writers = []
try:
    writer = imageio.get_writer(video_output, **dev_video_settings)
    opened_writers.append(writer)
    writer_kernel = imageio.get_writer(kernel_output, **dev_video_settings)
    opened_writers.append(writer_kernel)

    for i in tqdm(range(FRAMES)):
        # Filtered model output for this frame's covariance, converted to uint8
        filtered_out = trainer.generate_grid_data(
            resolution=SRC_RES,
            covariances=vis_covs[i],
            filter_type=filter_type,
        )
        filtered_out = filtered_out["filtered_signal"].add(1.0).div(2.0).mul(255).round().clamp(0, 255).to(torch.uint8)
        # Matching picture of the kernel footprint
        kernel_img = render_kernel(
            kind=filter_type,
            cov_mat=vis_covs[i],
            res=SRC_RES,
        )

        writer.append_data(filtered_out.cpu().numpy())
        writer_kernel.append_data(kernel_img)
finally:
    for opened_writer in opened_writers:
        opened_writer.close()
    

In [None]:
# %%
# ============ Visualization Script (Updated for Prefiltering) ============
# Colours for the sampling-window outline and the sample points.
COLOR_BOX, COLOR_POINTS = '#D46CCD', '#4E71BE'
SAMPLE_RES = 32                  # samples per axis inside the moving window
FPS, SECONDS = 30, 10
FRAMES = SECONDS * FPS
video_output = output_dir.joinpath("visualizations", "multiscale_aliasing_vis.mov")

# --- Keyframes (Fixed Timing) ---
# Each entry: (Time_Percent, Center_X, Center_Y, Span)
keyframes = [
    (0.00, -0.5,  0.5, 0.70),  # Start Top-Left
    (0.40,  0.5,  0.5, 0.70),  # Move Right
    (0.50,  0.5, -0.4, 0.85),  # Zoom Out & Down
    (0.90, -0.5, -0.4, 0.85),  # Move Left (Zoomed Out)
    (1.00, -0.5,  0.5, 0.70),  # Zoom In & Up
]

def get_trajectory_point_time(t_global, keyframes):
    """Interpolate the window pose ``(x, y, span)`` at time ``t_global``.

    Args:
        t_global: Query time, in the same units as the keyframe times.
        keyframes: Sequence of ``(time, x, y, span)`` entries ordered by time.

    Returns:
        A 3-tuple ``(x, y, span)``. Between two keyframes the values are blended
        with a smoothstep ease (``3t^2 - 2t^3``). Times before the first
        keyframe hold the first pose; times after the last hold the last pose.
    """
    # Before the start: hold the first pose rather than falling through to the
    # end-of-trajectory fallback at the bottom.
    if t_global < keyframes[0][0]:
        return tuple(keyframes[0][1:])

    for (t0, x0, y0, s0), (t1, x1, y1, s1) in zip(keyframes, keyframes[1:]):
        # Small tolerance so the segment's end point is not lost to rounding.
        if t0 <= t_global <= t1 + 1e-5:
            segment_duration = t1 - t0
            # Zero-length segment: jump straight to its end pose (avoids 0-division).
            if segment_duration <= 1e-5:
                return x1, y1, s1
            local_t = np.clip((t_global - t0) / segment_duration, 0, 1)
            smooth_t = local_t * local_t * (3 - 2 * local_t)  # smoothstep
            return (
                x0 + (x1 - x0) * smooth_t,
                y0 + (y1 - y0) * smooth_t,
                s0 + (s1 - s0) * smooth_t,
            )

    # Past the last keyframe (or no segment matched): hold the final pose.
    return tuple(keyframes[-1][1:])

def clamp_bounds(cx, cy, span, limit=1.0):
    """Shift the centre of a square window so it fits inside ``[-limit, limit]``.

    Each axis is handled independently: the low edge is checked first, then the
    high edge. If ``span`` is larger than ``2 * limit`` both checks fire and the
    high-edge correction is the one that sticks.

    Args:
        cx, cy: Window centre.
        span: Side length of the window.
        limit: Half-extent of the allowed domain.

    Returns:
        The adjusted ``(cx, cy)``.
    """
    half_span = span / 2.0
    lowest_centre = -limit + half_span
    highest_centre = limit - half_span

    def _pull_inside(centre):
        # Low edge sticks out -> move centre up; then high edge -> move it down.
        if centre - half_span < -limit:
            centre = lowest_centre
        if centre + half_span > limit:
            centre = highest_centre
        return centre

    return _pull_inside(cx), _pull_inside(cy)

# --- Setup Plot (1 Row, 3 Columns) ---
plt.style.use('seaborn-v0_8-white') 
fig, axs = plt.subplots(1, 3, figsize=(15, 5.5), dpi=200, constrained_layout=True)
fig.patch.set_facecolor('white')

# 1. Context view: the trained field's own full-resolution output (this is the
#    model reconstruction, not the ground-truth image).
full_field = trainer.data
output = trainer.generate_grid_data(resolution=SRC_RES)
reconstructed_data = output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()
im_context = axs[0].imshow(reconstructed_data, extent=(-1, 1, 1, -1), cmap='gray', vmin=0, vmax=1)
axs[0].set_title("Learned Signal (Neural Field)", fontsize=14, fontweight='bold', color='gray')
axs[0].axis('off')

# Overlay artists updated every frame: the sampling window and its sample points.
rect = patches.Rectangle((0,0), 0, 0, linewidth=2.5, edgecolor=COLOR_BOX, facecolor='none')
axs[0].add_patch(rect)
scat = axs[0].scatter([], [], s=4, c=COLOR_POINTS, edgecolors='none', alpha=0.6)

# Placeholder image so the two dynamic panels exist before the first frame
dummy_data = np.zeros((SAMPLE_RES, SAMPLE_RES, 3))

# 2. Naive Sampling (Aliased)
im_naive = axs[1].imshow(dummy_data, vmin=0, vmax=1, interpolation='nearest') 
axs[1].set_title("Naive Sampling (Aliased)", fontsize=14, fontweight='bold', color=COLOR_BOX)
axs[1].axis('off')

# 3. Prefiltered (Ours)
im_filtered = axs[2].imshow(dummy_data, vmin=0, vmax=1, interpolation='nearest')
axs[2].set_title("Prefiltered Input (Ours)", fontsize=14, fontweight='bold', color='tab:green')
axs[2].axis('off')

# --- Render Loop ---
# The outer try guarantees the figure is released, the inner one that the
# video file is finalised, even when rendering a frame raises.
try:
    writer = imageio.get_writer(
        video_output, 
        fps=FPS,
        codec='libx264', 
        quality=None, 
        pixelformat='yuv420p',
        ffmpeg_params=[
            '-crf', '5', 
            '-preset', 'veryslow', 
            '-tune', 'grain'
        ]
    )
    try:
        print(f"Rendering {FRAMES} frames...")

        for i in tqdm(range(FRAMES)):
            # Normalised time in [0, 1]; max() avoids 0-division for a single frame
            t = i / max(FRAMES - 1, 1)

            # Window pose for this frame, kept inside the plotted domain
            cx, cy, span = get_trajectory_point_time(t, keyframes)
            cx, cy = clamp_bounds(cx, cy, span, limit=0.99)
            bounds = (cx - span/2, cy - span/2, cx + span/2, cy + span/2)

            # Sample coordinates inside the window
            coords = make_coord_grid(ndim=2, resolution=SAMPLE_RES, bounds=bounds, device=device)
            flattened_coords = coords.view(-1, 2)

            with torch.no_grad():
                # --- A. Naive Inference (Coords Only) ---
                naive_out = trainer.model.forward(
                    {"coords": flattened_coords},
                )
                naive_out = naive_out["filtered_signal"].view(SAMPLE_RES, SAMPLE_RES, 3)
                naive_out = naive_out.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()

                # --- B. Prefiltered Inference (Coords + Covariance) ---
                # Pixel footprint at the current zoom level
                pixel_width = span / SAMPLE_RES

                # Isotropic covariance with variance (pixel_width / 2)^2 for every sample
                cov_val = (pixel_width / 2.0) ** 2
                covariances = torch.eye(2, device=device).unsqueeze(0).repeat(flattened_coords.shape[0], 1, 1) * cov_val

                filtered_out = trainer.model.forward(
                    {"coords": flattened_coords, "covariances": covariances},
                )
                filtered_out = filtered_out["filtered_signal"].view(SAMPLE_RES, SAMPLE_RES, 3)
                filtered_out = filtered_out.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()

            # Move the window outline and the sample markers
            rect.set_xy((bounds[0], bounds[1])) 
            rect.set_width(span)
            rect.set_height(span)

            c_np = flattened_coords.cpu().numpy()
            scat.set_offsets(c_np)

            # Refresh the two sampled panels
            im_naive.set_data(naive_out)
            im_filtered.set_data(filtered_out)

            fig.canvas.draw()
            frame = np.asarray(fig.canvas.buffer_rgba())
            writer.append_data(frame)
    finally:
        writer.close()
finally:
    # Close this specific figure rather than whichever one happens to be current
    plt.close(fig)

print(f"Done. Saved {video_output}")