In [1]:
from typing import Union, Sequence
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import matplotlib.patches as patches
from omegaconf import OmegaConf
import hydra
from hydra import initialize_config_dir, compose
import time
import imageio

import numpy as np

import torch
import torch.nn.functional as F

import spnf
from spnf.utils import set_seed, make_coord_grid, apply_to_tensors, to_py
from spnf.sample import (
    rand_ortho, logrand, construct_covariance, sample_gaussian_delta, sample_ellipsoid_delta
)

config_dir = Path(spnf.__file__).parent / "configs"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
def get_cfg(config_name, config_dir: Path, overrides: list):
    with initialize_config_dir(version_base=None, config_dir=str(config_dir)):
        cfg = compose(config_name=config_name, overrides=overrides)
    return cfg

class Trainer(torch.nn.Module):
    def __init__(self, cfg) -> None:
        super().__init__()
        set_seed(cfg.trainer.seed)
        self.cfg = cfg
        self.model = hydra.utils.instantiate(cfg.model)
        self.data = hydra.utils.instantiate(cfg.data)
        self.optimizer = hydra.utils.instantiate(cfg.optimizer, params=self.model.parameters())
        self.scheduler = hydra.utils.instantiate(cfg.scheduler, optimizer=self.optimizer)
        if cfg.get("tensorboard", None) is not None:
            self.writer = hydra.utils.instantiate(cfg.tensorboard)
        self.global_step = 0
        
    @property
    def device(self):
        return self.model.device
    
    def state_dict(self):
        state_dict = dict(
            model=self.model.state_dict(),
            global_step=self.global_step,
            optimizer=self.optimizer.state_dict(),
        )
        if self.scheduler:
            state_dict["scheduler"] = self.scheduler.state_dict()
        return state_dict
    
    def load_state_dict(self, state_dict):
        self.model.load_state_dict(state_dict["model"])
        self.global_step = state_dict["global_step"]
        self.optimizer.load_state_dict(state_dict["optimizer"])
        if self.scheduler and "scheduler" in state_dict:
            self.scheduler.load_state_dict(state_dict["scheduler"])
        
    def generate_training_data(self):
        batch = {}
        coords = self.data.generate_coords(self.cfg.trainer.batch_size)
        batch['coords'] = coords
        
        # Perturb the coordinates during training
        if self.cfg.trainer.perturb_coords is not None:
            eigenvalues = logrand(
                self.cfg.trainer.covariance_eigenvalue_logrange[0],
                self.cfg.trainer.covariance_eigenvalue_logrange[1],
                (self.cfg.trainer.batch_size, self.model.input_dim),
                device=coords.device
            )
            eigenvectors = rand_ortho(
                self.model.input_dim,
                self.cfg.trainer.batch_size,
                device=coords.device,
            )
            batch["covariances"] = construct_covariance(eigenvectors, eigenvalues)
            
            if self.cfg.trainer.perturb_coords == "gaussian":
                deltas = sample_gaussian_delta(eigenvectors, eigenvalues)
            elif self.cfg.trainer.perturb_coords == "uniform_ellipsoid":
                deltas = sample_ellipsoid_delta(eigenvectors, eigenvalues)
            elif self.cfg.trainer.perturb_coords == "lanczos":
                raise NotImplementedError("Lanczos sampling not implemented yet.")
            else:
                raise ValueError(f"Unknown perturbation type: {self.cfg.trainer.perturb_coords}")
            
            batch['deltas'] = deltas
            
            coords = coords + deltas
                
        return batch | dict(gt_signal=self.data(coords))
    
    def train_step(self) -> dict[str, torch.Tensor]:
        batch = self.generate_training_data()
        
        pred = self.model(batch)
        loss = self.model.loss(pred, batch)

        loss_sum = sum(loss.values())
        self.optimizer.zero_grad()
        loss_sum.backward()
        self.optimizer.step()
        if self.scheduler:
            self.scheduler.step()

        return loss
    
    def generate_grid_data(self, resolution: Union[int, Sequence[int]]=512, bounds=1.0, **kwargs) -> dict[str, torch.Tensor]:
        # Setup grid resolution
        if isinstance(resolution, int):
            grid_res = (resolution,) * self.model.input_dim
        else:
            grid_res = tuple(resolution)
        
        
        # Create coordinate grid
        coord_grid = make_coord_grid(
            self.model.input_dim,
            resolution=grid_res,
            bounds=bounds,
            device=self.device,
        )
        
        # Flatten coordinate grid
        total_pixels = coord_grid.shape[:-1].numel() 
        flattened_coords = coord_grid.view(total_pixels, self.model.input_dim)
        
        # Inference
        with torch.no_grad():
            output = self.model.forward(
                {"coords": flattened_coords} | kwargs
            )
            
        def reshape_if_matching(t: torch.Tensor):
            if t.shape[0] == total_pixels:
                new_shape = grid_res + t.shape[1:]
                return t.view(new_shape)
            return t
            
        return apply_to_tensors(output, reshape_if_matching)

    def fit(self, num_steps=None, no_tqdm=False) -> dict:
        # 1. Compile train step
        if self.cfg.trainer.compile_train_step:
            train_step = torch.compile(self.train_step)
        else:
            train_step = self.train_step
        
        if num_steps is None:
            num_steps = self.cfg.trainer.steps

        pbar = tqdm(total=num_steps, desc="Training", dynamic_ncols=True, disable=no_tqdm)
        
        stats_history = []
        for _ in range(num_steps):
            stats = train_step()
            self.global_step += 1
            
            # Convert tensors to python scalars/lists
            step_stats = {k: to_py(v) for k, v in stats.items()}
            stats_history.append(step_stats)

            # Update progress bar (using the raw tensor items for display is fine/fast)
            pbar.update(1)
            pbar.set_postfix({k: f"{v:.4f}"  for k, v in step_stats.items() if isinstance(v, float)})

        pbar.close()
        
        # Concatenate (Transpose)
        if stats_history:
            final_stats = {k: [step[k] for step in stats_history] for k in stats_history[0]}
        else:
            final_stats = {}
            
        return final_stats

## Training visualization for neural fields

In [3]:
lambda_decay_start = 500
vis_steps = 800
total_steps = 2000
SRC_RES = 512
cfg = get_cfg(
    config_name="train",
    config_dir=config_dir,
    overrides=[
        f"data.grid.resize={SRC_RES}",
        "data.bounds=1.0", 
        "tensorboard=null",
        f"trainer.steps={total_steps}",
        "paths.output_dir=/home/myaldiz/Data/Experiments/spnf/${task_name}",
        f"scheduler.lr_lambda.decay_start={lambda_decay_start}",
        "trainer.compile_train_step=False",
        "trainer.perturb_coords=null"
    ],
)
# Create output directories
output_dir = Path(cfg.paths.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = Path(cfg.trainer.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Save the config to output directory
OmegaConf.save(config=cfg, f=output_dir / "config_singlescale.yaml", resolve=True)

In [None]:
trainer = Trainer(cfg).to(device)

loss_average = []
predictions = []
steps_history = []  # New: Track the actual step number for the X-axis

current_step = 0
frame_count = 0
pbar = tqdm(total=vis_steps, desc="Training")

while current_step < vis_steps:
    # Determine stride based on how many frames we have generated
    if current_step < 90:
        stride = 1
    elif current_step < 210:
        stride = 2
    else:
        stride = 5
        
    # Ensure we don't exceed vis_steps
    if current_step + stride > vis_steps:
        stride = vis_steps - current_step
        if stride == 0: break

    # Fit for 'stride' steps
    stats = trainer.fit(num_steps=stride, no_tqdm=True)
    
    # Update trackers
    current_step += stride
    frame_count += 1
    pbar.update(stride)

    # Store Data
    loss_average.append(np.mean(stats['mse_loss']))
    steps_history.append(current_step) # Store actual X-axis location
    
    # Generate visualization
    output = trainer.generate_grid_data(resolution=SRC_RES)
    predictions.append(output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy())

pbar.close()

# --- VIDEO GENERATION ---
video_output = output_dir / "visualizations" / "singlescale_train_vis.mp4"
video_output.parent.mkdir(parents=True, exist_ok=True)
writer = imageio.get_writer(
    video_output, 
    fps=30, 
    codec='libx264', 
    quality=None, 
    pixelformat='yuv420p',
    ffmpeg_params=[
        '-crf', '18', 
        '-preset', 'veryslow', 
        '-tune', 'grain'
    ]
)

# Settings
plt.style.use('default') 
max_loss = max(loss_average) * 1.1

# 1. Geometry Calculation
sample = predictions[0]
if sample.ndim == 3 and sample.shape[0] in [1, 3]:
    h, w = sample.shape[1], sample.shape[2]
else:
    h, w = sample.shape[0], sample.shape[1]

fig_width = 12 
img_height_in = fig_width * (h / w)
loss_height_in = 4.0 
total_height = img_height_in + loss_height_in

# Calculate relative height ratios for manual positioning
h_img_ratio = img_height_in / total_height
h_loss_ratio = loss_height_in / total_height

print("Rendering video frames...")
# Zip with steps_history to get the correct X-axis value
for i, (loss, pred, step_num) in enumerate(tqdm(zip(loss_average, predictions, steps_history), total=len(predictions))):
    
    # Create figure without subplots/layouts to allow manual placement
    fig = plt.figure(figsize=(fig_width, total_height), dpi=100)
    
    # Bottom Plot
    margin_bottom = 0.08
    h_loss_actual = h_loss_ratio - margin_bottom - 0.02
    ax_loss = fig.add_axes([0.12, margin_bottom, 0.82, h_loss_actual])
    
    margin_top = 0.02
    y_img_start = h_loss_ratio + 0.02
    h_img_actual = h_img_ratio - margin_top - 0.02
    ax_img = fig.add_axes([0.02, y_img_start, 0.96, h_img_actual])

    # --- 1. Image Plot ---
    if pred.ndim == 3 and pred.shape[0] in [1, 3]: 
        pred = np.transpose(pred, (1, 2, 0))
    
    ax_img.imshow(pred, aspect='auto') 
    
    # Hide all ticks/spines for the image
    ax_img.set_xticks([])
    ax_img.set_yticks([])
    for spine in ax_img.spines.values():
        spine.set_visible(False) 
    
    # --- 2. Loss Plot ---
    # MODIFICATION: Use steps_history for X-axis data
    current_steps_data = steps_history[:i+1]
    current_loss_data = loss_average[:i+1]
    
    ax_loss.plot(current_steps_data, current_loss_data, color='tab:blue', linewidth=3)
    
    # MODIFICATION: Use step_num (the actual training step) for the scatter X-coordinate
    ax_loss.scatter(step_num, loss, color='tab:red', s=100, zorder=5) 
    
    # MODIFICATION: Set X limit to the total vis_steps (800) rather than frame count
    ax_loss.set_xlim(0, vis_steps)
    ax_loss.set_ylim(0, max_loss)
    
    ax_loss.set_ylabel('MSE Loss', fontsize=24, labelpad=15, fontweight='medium')
    ax_loss.set_xlabel('Steps', fontsize=18) # Optional: Label x-axis
    ax_loss.tick_params(axis='both', which='major', labelsize=16)
    ax_loss.grid(True, linestyle='--', alpha=0.5, linewidth=1.5)

    # --- 3. Save Frame ---
    fig.canvas.draw()
    frame = np.array(fig.canvas.buffer_rgba())
    
    writer.append_data(frame)
    plt.close(fig)

writer.close()
print(f"Video saved to {video_output}")

In [None]:
# Train remaining steps
trainer.fit(num_steps=total_steps - vis_steps)

# Save the state dict
state_dict = trainer.state_dict()
checkpoint_path = checkpoint_dir / "singlescale.pth"
torch.save(state_dict, str(checkpoint_path))

## Visualize aliasing artifacts

In [None]:
checkpoint_path = checkpoint_dir / "singlescale.pth"
if not checkpoint_path.exists():
    # Fit the model
    trainer.fit()

    # Save the state dict
    state_dict = trainer.state_dict()
    torch.save(state_dict, str(checkpoint_path))
else:
    # Load the state dict
    trainer = Trainer(cfg).to(device)
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
    trainer.load_state_dict(state_dict)
    print(f"Loaded model from {checkpoint_path}")

In [None]:
# ============ Visualization Script ============
COLOR_BOX = '#D46CCD'
COLOR_POINTS = '#4E71BE'
SAMPLE_RES = 32
FPS = 30
SECONDS = 10
FRAMES = FPS * SECONDS
video_output = output_dir / "visualizations" / "singlescale_aliasing_vis.mp4"

# --- Keyframes (Fixed Timing) ---
# Format: (Time_Percent, Center_X, Center_Y, Span)
keyframes = [
    (0.00, -0.5,  0.5, 0.4), # Start Top-Left
    (0.40,  0.5,  0.5, 0.4), # Move Right
    (0.50,  0.5, -0.5, 1.0), # Zoom Out & Down
    (0.90, -0.5, -0.5, 1.0), # Move Left (Zoomed Out)
    (1.00, -0.5,  0.5, 0.4), # Zoom In & Up
]

def get_trajectory_point_time(t_global, keyframes):
    for k in range(len(keyframes) - 1):
        t0, x0, y0, s0 = keyframes[k]
        t1, x1, y1, s1 = keyframes[k+1]
        if t0 <= t_global <= t1 + 1e-5:
            segment_duration = t1 - t0
            if segment_duration <= 1e-5: return x1, y1, s1
            local_t = np.clip((t_global - t0) / segment_duration, 0, 1)
            smooth_t = local_t * local_t * (3 - 2 * local_t) 
            current_x = x0 + (x1 - x0) * smooth_t
            current_y = y0 + (y1 - y0) * smooth_t
            current_s = s0 + (s1 - s0) * smooth_t
            return current_x, current_y, current_s
    return keyframes[-1][1:]

def clamp_bounds(cx, cy, span, limit=1.0):
    half_s = span / 2.0
    if cx - half_s < -limit: cx = -limit + half_s
    if cx + half_s > limit:  cx = limit - half_s
    if cy - half_s < -limit: cy = -limit + half_s
    if cy + half_s > limit:  cy = limit - half_s
    return cx, cy

# --- Setup Plot ---
plt.style.use('seaborn-v0_8-white') 
fig, axs = plt.subplots(1, 2, figsize=(10, 5.5), dpi=200, constrained_layout=True)
fig.patch.set_facecolor('white')

# Context View
full_field = trainer.data
output = trainer.generate_grid_data(resolution=SRC_RES)
reconstructed_data = output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()
im_context = axs[0].imshow(reconstructed_data, extent=(-1, 1, 1, -1), cmap='gray', vmin=0, vmax=1)
axs[0].set_title("Learned Signal (Neural Field)", fontsize=14, fontweight='bold', color='gray')
axs[0].axis('off')

rect = patches.Rectangle((0,0), 0, 0, linewidth=2.5, edgecolor=COLOR_BOX, facecolor='none')
axs[0].add_patch(rect)

# FIX 1: Smaller points (s=4) and alpha=0.6 for better visibility without occlusion
scat = axs[0].scatter([], [], s=4, c=COLOR_POINTS, edgecolors='none', alpha=0.6)

# Naive & Ours Views
dummy_data = np.zeros((SAMPLE_RES, SAMPLE_RES, 3))
im_naive = axs[1].imshow(dummy_data, vmin=0, vmax=1, interpolation='nearest') 
axs[1].set_title("Naive Sampling (Aliased)", fontsize=14, fontweight='bold', color=COLOR_BOX)
axs[1].axis('off')

# --- Render Loop ---
# FIX 2: High Quality Compression settings
# -crf 18: Visually lossless (lower is better quality, 0 is lossless)
# -tune grain: Preserves high frequency noise/flicker
# -preset veryslow: Best compression efficiency for the file size
writer = imageio.get_writer(
    video_output, 
    fps=FPS,
    codec='libx264', 
    quality=None, 
    pixelformat='yuv420p',
    ffmpeg_params=[
        '-crf', '18', 
        '-preset', 'veryslow', 
        '-tune', 'grain'
    ]
)

print(f"Rendering {FRAMES} frames...")

for i in tqdm(range(FRAMES)):
    t = i / (FRAMES - 1)
    
    # Logic
    cx, cy, span = get_trajectory_point_time(t, keyframes)
    cx, cy = clamp_bounds(cx, cy, span, limit=0.99)
    bounds = (cx - span/2, cy - span/2, cx + span/2, cy + span/2)
    coords = make_coord_grid(ndim=2, resolution=SAMPLE_RES, bounds=bounds, device=device)

    # Naive
    with torch.no_grad():
        naive_out = trainer.model.forward(
            {"coords": coords.view(-1, 2)},
        )
        naive_out = naive_out["filtered_signal"].view(SAMPLE_RES, SAMPLE_RES, 3)
        naive_out = naive_out.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()

    # Updates
    rect.set_xy((bounds[0], bounds[1])) 
    rect.set_width(span)
    rect.set_height(span)
    
    c_np = coords.view(-1,2).cpu().numpy()
    scat.set_offsets(c_np)
    
    # Image Updates
    im_naive.set_data(naive_out)
    
    fig.canvas.draw()
    frame = np.asarray(fig.canvas.buffer_rgba())
    writer.append_data(frame)

writer.close()
plt.close()
print(f"Done. Saved {video_output}")

## Multiscale training

In [5]:
lambda_decay_start = 500
vis_steps = 800
total_steps = 2000
SRC_RES = 512
cfg = get_cfg(
    config_name="train",
    config_dir=config_dir,
    overrides=[
        f"data.grid.resize={SRC_RES}",
        "data.bounds=1.0", 
        "tensorboard=null",
        f"trainer.steps={total_steps}",
        "paths.output_dir=/home/myaldiz/Data/Experiments/spnf/${task_name}",
        f"scheduler.lr_lambda.decay_start={lambda_decay_start}",
        "trainer.compile_train_step=True",
        "trainer.perturb_coords=gaussian"
    ],
)
# Create output directories
output_dir = Path(cfg.paths.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = Path(cfg.trainer.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)

# Save the config to output directory
OmegaConf.save(config=cfg, f=output_dir / "config_multiscale.yaml", resolve=True)

In [6]:
trainer = Trainer(cfg).to(device)

loss_average = []
predictions = []
steps_history = []  # New: Track the actual step number for the X-axis

current_step = 0
frame_count = 0
pbar = tqdm(total=vis_steps, desc="Training multiscale")

while current_step < vis_steps:
    # Determine stride based on how many frames we have generated
    if current_step < 90:
        stride = 1
    elif current_step < 210:
        stride = 2
    else:
        stride = 5
        
    # Ensure we don't exceed vis_steps
    if current_step + stride > vis_steps:
        stride = vis_steps - current_step
        if stride == 0: break

    # Fit for 'stride' steps
    stats = trainer.fit(num_steps=stride, no_tqdm=True)
    
    # Update trackers
    current_step += stride
    frame_count += 1
    pbar.update(stride)

    # Store Data
    loss_average.append(np.mean(stats['mse_loss']))
    steps_history.append(current_step) # Store actual X-axis location
    
    # Generate visualization
    output = trainer.generate_grid_data(resolution=SRC_RES)
    predictions.append(output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy())

pbar.close()

# --- VIDEO GENERATION ---
video_output = output_dir / "visualizations" / "multiscale_train_vis.mp4"
video_output.parent.mkdir(parents=True, exist_ok=True)
writer = imageio.get_writer(
    video_output, 
    fps=30, 
    codec='libx264', 
    quality=None, 
    pixelformat='yuv420p',
    ffmpeg_params=[
        '-crf', '18', 
        '-preset', 'veryslow', 
        '-tune', 'grain'
    ]
)

# Settings
plt.style.use('default') 
max_loss = max(loss_average) * 1.1

# 1. Geometry Calculation
sample = predictions[0]
if sample.ndim == 3 and sample.shape[0] in [1, 3]:
    h, w = sample.shape[1], sample.shape[2]
else:
    h, w = sample.shape[0], sample.shape[1]

fig_width = 12 
img_height_in = fig_width * (h / w)
loss_height_in = 4.0 
total_height = img_height_in + loss_height_in

# Calculate relative height ratios for manual positioning
h_img_ratio = img_height_in / total_height
h_loss_ratio = loss_height_in / total_height

print("Rendering video frames...")
# Zip with steps_history to get the correct X-axis value
for i, (loss, pred, step_num) in enumerate(tqdm(zip(loss_average, predictions, steps_history), total=len(predictions))):
    
    # Create figure without subplots/layouts to allow manual placement
    fig = plt.figure(figsize=(fig_width, total_height), dpi=100)
    
    # Bottom Plot
    margin_bottom = 0.08
    h_loss_actual = h_loss_ratio - margin_bottom - 0.02
    ax_loss = fig.add_axes([0.12, margin_bottom, 0.82, h_loss_actual])
    
    margin_top = 0.02
    y_img_start = h_loss_ratio + 0.02
    h_img_actual = h_img_ratio - margin_top - 0.02
    ax_img = fig.add_axes([0.02, y_img_start, 0.96, h_img_actual])

    # --- 1. Image Plot ---
    if pred.ndim == 3 and pred.shape[0] in [1, 3]: 
        pred = np.transpose(pred, (1, 2, 0))
    
    ax_img.imshow(pred, aspect='auto') 
    
    # Hide all ticks/spines for the image
    ax_img.set_xticks([])
    ax_img.set_yticks([])
    for spine in ax_img.spines.values():
        spine.set_visible(False) 
    
    # --- 2. Loss Plot ---
    # MODIFICATION: Use steps_history for X-axis data
    current_steps_data = steps_history[:i+1]
    current_loss_data = loss_average[:i+1]
    
    ax_loss.plot(current_steps_data, current_loss_data, color='tab:blue', linewidth=3)
    
    # MODIFICATION: Use step_num (the actual training step) for the scatter X-coordinate
    ax_loss.scatter(step_num, loss, color='tab:red', s=100, zorder=5) 
    
    # MODIFICATION: Set X limit to the total vis_steps (800) rather than frame count
    ax_loss.set_xlim(0, vis_steps)
    ax_loss.set_ylim(0, max_loss)
    
    ax_loss.set_ylabel('MSE Loss', fontsize=24, labelpad=15, fontweight='medium')
    ax_loss.set_xlabel('Steps', fontsize=18) # Optional: Label x-axis
    ax_loss.tick_params(axis='both', which='major', labelsize=16)
    ax_loss.grid(True, linestyle='--', alpha=0.5, linewidth=1.5)

    # --- 3. Save Frame ---
    fig.canvas.draw()
    frame = np.array(fig.canvas.buffer_rgba())
    
    writer.append_data(frame)
    plt.close(fig)

writer.close()
print(f"Video saved to {video_output}")

Training multiscale:   0%|          | 0/800 [00:00<?, ?it/s]

W1209 22:27:52.427000 2071494 site-packages/torch/_logging/_internal.py:1199] [4/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
W1209 22:27:55.211000 2071494 site-packages/torch/_dynamo/convert_frame.py:1358] [8/8] torch._dynamo hit config.recompile_limit (8)
W1209 22:27:55.211000 2071494 site-packages/torch/_dynamo/convert_frame.py:1358] [8/8]    function: 'scheduler' (/home/myaldiz/GitHub/Spectral-Prefiltering-of-Neural-Fields/spnf/utils.py:121)
W1209 22:27:55.211000 2071494 site-packages/torch/_dynamo/convert_frame.py:1358] [8/8]    last reason: 8/7: (step / 100) == 0.08  # multiplier = np.clip(multiplier, 0.0, 1.0)  # GitHub/Spectral-Prefiltering-of-Neural-Fields/spnf/utils.py:149 in scheduler (_numpy/_util.py:177 in _try_convert_to_tensor)
W1209 22:27:55.211000 2071494 site-packages/torch/_dynamo/convert_frame.py:1358] [8/8] To log all recompilation reasons, use TORCH_LOGS="recompiles".
W1209 22:27:55.211000 2071494 site-packages/torch/_dyna

Rendering video frames...


  0%|          | 0/268 [00:00<?, ?it/s]

Video saved to /home/myaldiz/Data/Experiments/spnf/spnf-oia/visualizations/multiscale_train_vis.mp4


In [7]:
# Train remaining steps
trainer.fit(num_steps=total_steps - vis_steps)

# Save the state dict
state_dict = trainer.state_dict()
checkpoint_path = checkpoint_dir / "multiscale.pth"
torch.save(state_dict, str(checkpoint_path))

Training:   0%|          | 0/1200 [00:00<?, ?it/s]

In [12]:
# %%
# ============ Visualization Script (Updated for Prefiltering) ============
COLOR_BOX = '#D46CCD'
COLOR_POINTS = '#4E71BE'
SAMPLE_RES = 32
FPS = 30
SECONDS = 10
FRAMES = FPS * SECONDS
video_output = output_dir / "visualizations" / "multiscale_aliasing_vis.mp4"

# --- Keyframes (Fixed Timing) ---
# Format: (Time_Percent, Center_X, Center_Y, Span)
keyframes = [
    (0.00, -0.5,  0.5, 0.4), # Start Top-Left
    (0.40,  0.5,  0.5, 0.4), # Move Right
    (0.50,  0.5, -0.5, 1.0), # Zoom Out & Down
    (0.90, -0.5, -0.5, 1.0), # Move Left (Zoomed Out)
    (1.00, -0.5,  0.5, 0.4), # Zoom In & Up
]

def get_trajectory_point_time(t_global, keyframes):
    for k in range(len(keyframes) - 1):
        t0, x0, y0, s0 = keyframes[k]
        t1, x1, y1, s1 = keyframes[k+1]
        if t0 <= t_global <= t1 + 1e-5:
            segment_duration = t1 - t0
            if segment_duration <= 1e-5: return x1, y1, s1
            local_t = np.clip((t_global - t0) / segment_duration, 0, 1)
            smooth_t = local_t * local_t * (3 - 2 * local_t) 
            current_x = x0 + (x1 - x0) * smooth_t
            current_y = y0 + (y1 - y0) * smooth_t
            current_s = s0 + (s1 - s0) * smooth_t
            return current_x, current_y, current_s
    return keyframes[-1][1:]

def clamp_bounds(cx, cy, span, limit=1.0):
    half_s = span / 2.0
    if cx - half_s < -limit: cx = -limit + half_s
    if cx + half_s > limit:  cx = limit - half_s
    if cy - half_s < -limit: cy = -limit + half_s
    if cy + half_s > limit:  cy = limit - half_s
    return cx, cy

# --- Setup Plot (Now 1 Row, 3 Columns) ---
plt.style.use('seaborn-v0_8-white') 
fig, axs = plt.subplots(1, 3, figsize=(15, 5.5), dpi=200, constrained_layout=True)
fig.patch.set_facecolor('white')

# 1. Context View (Ground Truth)
full_field = trainer.data
output = trainer.generate_grid_data(resolution=SRC_RES)
reconstructed_data = output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()
im_context = axs[0].imshow(reconstructed_data, extent=(-1, 1, 1, -1), cmap='gray', vmin=0, vmax=1)
axs[0].set_title("Learned Signal (Neural Field)", fontsize=14, fontweight='bold', color='gray')
axs[0].axis('off')

rect = patches.Rectangle((0,0), 0, 0, linewidth=2.5, edgecolor=COLOR_BOX, facecolor='none')
axs[0].add_patch(rect)
scat = axs[0].scatter([], [], s=4, c=COLOR_POINTS, edgecolors='none', alpha=0.6)

# Initialize dummy data for dynamic plots
dummy_data = np.zeros((SAMPLE_RES, SAMPLE_RES, 3))

# 2. Naive Sampling (Aliased)
im_naive = axs[1].imshow(dummy_data, vmin=0, vmax=1, interpolation='nearest') 
axs[1].set_title("Naive Sampling (Aliased)", fontsize=14, fontweight='bold', color=COLOR_BOX)
axs[1].axis('off')

# 3. Prefiltered (Ours) - New Column
im_filtered = axs[2].imshow(dummy_data, vmin=0, vmax=1, interpolation='nearest')
axs[2].set_title("Prefiltered Input (Ours)", fontsize=14, fontweight='bold', color='tab:green')
axs[2].axis('off')

# --- Render Loop ---
writer = imageio.get_writer(
    video_output, 
    fps=FPS,
    codec='libx264', 
    quality=None, 
    pixelformat='yuv420p',
    ffmpeg_params=[
        '-crf', '5', 
        '-preset', 'veryslow', 
        '-tune', 'grain'
    ]
)

print(f"Rendering {FRAMES} frames...")

for i in tqdm(range(FRAMES)):
    t = i / (FRAMES - 1)
    
    # Logic
    cx, cy, span = get_trajectory_point_time(t, keyframes)
    cx, cy = clamp_bounds(cx, cy, span, limit=0.99)
    bounds = (cx - span/2, cy - span/2, cx + span/2, cy + span/2)
    
    # Generate Coords
    coords = make_coord_grid(ndim=2, resolution=SAMPLE_RES, bounds=bounds, device=device)
    flattened_coords = coords.view(-1, 2)
    
    with torch.no_grad():
        # --- A. Naive Inference (Coords Only) ---
        naive_out = trainer.model.forward(
            {"coords": flattened_coords},
        )
        naive_out = naive_out["filtered_signal"].view(SAMPLE_RES, SAMPLE_RES, 3)
        naive_out = naive_out.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()

        # --- B. Prefiltered Inference (Coords + Covariance) ---
        # 1. Calculate pixel width based on current zoom (span)
        pixel_width = span / SAMPLE_RES
        
        # 2. Construct isotropic covariance matrix: Diagonal(pixel_width^2)
        # Shape: (N, 2, 2)
        # Note: Depending on your specific kernel definition, you might scale this 
        # (e.g., pixel_width/2 for radius). Using width^2 is standard for variance.
        # Match variance of a box filter: sigma^2 = width^2 / 12
        cov_val = pixel_width ** 2 / 12.0
        covariances = torch.eye(2, device=device).unsqueeze(0).repeat(flattened_coords.shape[0], 1, 1) * cov_val
        
        filtered_out = trainer.model.forward(
            {"coords": flattened_coords, "covariances": covariances},
        )
        filtered_out = filtered_out["filtered_signal"].view(SAMPLE_RES, SAMPLE_RES, 3)
        filtered_out = filtered_out.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()

    # Updates
    rect.set_xy((bounds[0], bounds[1])) 
    rect.set_width(span)
    rect.set_height(span)
    
    c_np = flattened_coords.cpu().numpy()
    scat.set_offsets(c_np)
    
    # Image Updates
    im_naive.set_data(naive_out)
    im_filtered.set_data(filtered_out) # Update third column
    
    fig.canvas.draw()
    frame = np.asarray(fig.canvas.buffer_rgba())
    writer.append_data(frame)

writer.close()
plt.close()
print(f"Done. Saved {video_output}")

Rendering 300 frames...


  0%|          | 0/300 [00:00<?, ?it/s]



Done. Saved /home/myaldiz/Data/Experiments/spnf/spnf-oia/visualizations/multiscale_aliasing_vis.mp4
