In [None]:
import os
import time
from pathlib import Path
from typing import Union, Sequence

import hydra
import imageio
import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from hydra import initialize_config_dir, compose
from omegaconf import OmegaConf
from tqdm.notebook import tqdm

import spnf
from spnf.utils import set_seed, make_coord_grid, apply_to_tensors, to_py
from spnf.sample import (
    rand_ortho, logrand, construct_covariance, sample_gaussian_delta, sample_ellipsoid_delta
)
from spnf.trainer import Trainer

# Hydra config tree shipped alongside the spnf package sources
config_dir = Path(spnf.__file__).parent / "configs"
# Prefer the GPU when one is visible to torch
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_cfg(config_name, config_dir: Path, overrides: list):
    """Compose a Hydra config rooted at ``config_dir``.

    Args:
        config_name: Name of the primary config to compose (e.g. ``"train"``).
        config_dir: Directory holding the Hydra config tree.
        overrides: Hydra override strings such as ``"trainer.steps=100"``.

    Returns:
        The composed config object returned by ``hydra.compose``.
    """
    hydra_context = initialize_config_dir(version_base=None, config_dir=str(config_dir))
    with hydra_context:
        composed = compose(config_name=config_name, overrides=overrides)
    return composed

## Training visualization for neural fields

In [None]:
# --- Training schedule ---
lambda_decay_start = 500  # forwarded to scheduler.lr_lambda.decay_start
vis_steps = 800           # steps recorded frame-by-frame for the training video
total_steps = 2000        # forwarded to trainer.steps
SRC_RES = 512             # grid resize used for the data and for rendering predictions
FPS = 30

# Root directory for experiment outputs. Set SPNF_EXPERIMENTS_ROOT to redirect outputs on
# another machine instead of editing a hardcoded absolute path; the default keeps the old location.
EXPERIMENTS_ROOT = os.environ.get("SPNF_EXPERIMENTS_ROOT", "/home/myaldiz/Data/Experiments/spnf")

cfg = get_cfg(
    config_name="train",
    config_dir=config_dir,
    overrides=[
        f"data.grid.resize={SRC_RES}",
        "data.bounds=1.0",
        "tensorboard=null",
        f"trainer.steps={total_steps}",
        # Braces are doubled so the literal `${task_name}` reaches Hydra for interpolation
        f"paths.output_dir={EXPERIMENTS_ROOT}/${{task_name}}",
        f"scheduler.lr_lambda.decay_start={lambda_decay_start}",
        "trainer.compile_train_step=False",
        "trainer.mc_samples_train=0"
    ],
)

# High-quality HEVC encoding for final videos
video_settings = dict(
    fps=FPS,
    codec='libx265',
    pixelformat='yuv420p',
    ffmpeg_params=[
        '-crf', '10',          # Lower CRF = Higher Quality (10 is extremely high quality)
        '-preset', 'veryslow', # Takes longer to render, but best compression
        '-tag:v', 'hvc1',      # Critical for Apple compatibility
        '-tune', 'grain',      # Preserves noise/aliasing details
        '-x265-params', 'keyint=30:bframes=0' # Keyframe at most every 30 frames, no B-frames
    ]
)
# Faster, lower-quality variant for quick iteration (replaces ffmpeg_params entirely)
dev_video_settings = video_settings | dict(
    ffmpeg_params=[
        '-crf', '20',
        '-preset', 'fast',
        '-tag:v', 'hvc1',
        '-tune', 'grain',
    ]
)

# Create output directories
output_dir = Path(cfg.paths.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
checkpoint_dir = Path(cfg.trainer.checkpoint_dir)
checkpoint_dir.mkdir(parents=True, exist_ok=True)
video_output_dir = output_dir / "visualizations"
video_output_dir.mkdir(parents=True, exist_ok=True)

# Save the resolved config next to the outputs for reproducibility
OmegaConf.save(config=cfg, f=output_dir / "config_singlescale.yaml", resolve=True)

In [None]:
VIS_PATCH_RES = 64
vis_patch_x, vis_patch_y = 190, 120  # top-left pixel (column, row) of the zoom-in patch
trainer = Trainer(cfg).to(device)


def _to_hwc(image):
    """Return ``image`` with a leading channel axis (size 1 or 3) moved to the end; else unchanged."""
    if image.ndim == 3 and image.shape[0] in (1, 3):
        return np.transpose(image, (1, 2, 0))
    return image


def _crop_vis_patch(image):
    """Crop the VIS_PATCH_RES x VIS_PATCH_RES zoom patch from a rows-first (H, W, ...) image."""
    return image[vis_patch_y:vis_patch_y + VIS_PATCH_RES, vis_patch_x:vis_patch_x + VIS_PATCH_RES]


loss_average = []   # mean MSE of each training chunk
predictions = []    # rendered field after each chunk, mapped to [0, 1]
steps_history = []  # training step reached after each chunk (x-axis of the loss plot)

current_step = 0
frame_count = 0
pbar = tqdm(total=vis_steps, desc="Training")

while current_step < vis_steps:
    # Training steps per video frame: 1 before step 90, 2 before step 210, then 5
    # (presumably to spend more frames on the fast-changing early phase).
    if current_step < 90:
        stride = 1
    elif current_step < 210:
        stride = 2
    else:
        stride = 5

    # Ensure we don't exceed vis_steps
    if current_step + stride > vis_steps:
        stride = vis_steps - current_step
        if stride == 0:
            break

    # Fit for 'stride' steps
    stats = trainer.fit(num_steps=stride, no_tqdm=True)

    current_step += stride
    frame_count += 1
    pbar.update(stride)

    loss_average.append(np.mean(stats['mse_loss']))
    steps_history.append(current_step)

    # Render the current field; clamp to [-1, 1] then map to [0, 1] for display/encoding
    output = trainer.generate_grid_data(resolution=SRC_RES)
    predictions.append(output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy())

pbar.close()

# --- VIDEO GENERATION ---
video_output = video_output_dir / "singlescale_train_vis.mov"
video_patch_output = video_output_dir / "singlescale_train_vis_patch.mov"

plt.style.use('default')
# Fixed y-limit (10% headroom) so the loss axis does not rescale between frames
max_loss = max(loss_average) * 1.1

# Geometry: the image panel keeps the prediction's aspect ratio, the loss panel has a fixed height
h, w = _to_hwc(predictions[0]).shape[:2]
fig_width = 12
img_height_in = fig_width * (h / w)
loss_height_in = 4.0
total_height = img_height_in + loss_height_in

# Relative height ratios (figure-fraction units) for manual axes placement
h_img_ratio = img_height_in / total_height
h_loss_ratio = loss_height_in / total_height

print("Rendering video frames...")
# Context managers close (and finalize) both videos even if rendering fails midway
with imageio.get_writer(video_output, **video_settings) as writer, \
        imageio.get_writer(video_patch_output, **video_settings) as writer_patch:
    for i, (loss, pred, step_num) in enumerate(
        tqdm(zip(loss_average, predictions, steps_history), total=len(predictions))
    ):
        # Plain figure (no subplots/layout engine) so the axes can be placed manually
        fig = plt.figure(figsize=(fig_width, total_height), dpi=100)

        # Loss axes at the bottom
        margin_bottom = 0.08
        h_loss_actual = h_loss_ratio - margin_bottom - 0.02
        ax_loss = fig.add_axes([0.12, margin_bottom, 0.82, h_loss_actual])

        # Image axes above the loss axes
        margin_top = 0.02
        y_img_start = h_loss_ratio + 0.02
        h_img_actual = h_img_ratio - margin_top - 0.02
        ax_img = fig.add_axes([0.02, y_img_start, 0.96, h_img_actual])

        # --- 1. Image Plot ---
        pred = _to_hwc(pred)
        ax_img.imshow(pred, aspect='auto')
        ax_img.set_xticks([])
        ax_img.set_yticks([])
        for spine in ax_img.spines.values():
            spine.set_visible(False)

        # --- 2. Loss Plot --- curve up to this frame, against the real training step numbers
        ax_loss.plot(steps_history[:i + 1], loss_average[:i + 1], color='tab:blue', linewidth=3)
        ax_loss.scatter(step_num, loss, color='tab:red', s=100, zorder=5)
        # Span the whole visualized run so the curve grows instead of the axis rescaling
        ax_loss.set_xlim(0, vis_steps)
        ax_loss.set_ylim(0, max_loss)

        ax_loss.set_ylabel('MSE Loss', fontsize=24, labelpad=15, fontweight='medium')
        ax_loss.set_xlabel('Steps', fontsize=18)
        ax_loss.tick_params(axis='both', which='major', labelsize=16)
        ax_loss.grid(True, linestyle='--', alpha=0.5, linewidth=1.5)

        # --- 3. Save Frame ---
        fig.canvas.draw()
        frame = np.array(fig.canvas.buffer_rgba())
        writer.append_data(frame)

        patch = _crop_vis_patch(pred)
        writer_patch.append_data((patch * 255).astype(np.uint8))
        plt.close(fig)

print(f"Videos are saved: {video_output}, {video_patch_output}")

# Save last frame patch as PNG, using the same channel-last conversion as the video frames
last_frame_output = video_output_dir / "singlescale_train_vis_patch.png"
patch = _crop_vis_patch(_to_hwc(predictions[-1]))
imageio.imwrite(
    last_frame_output,
    (patch * 255).astype(np.uint8)
)
print(f"Last frame image saved: {last_frame_output}")

In [None]:
# Finish the schedule: train the steps that came after the visualized window
remaining_steps = total_steps - vis_steps
trainer.fit(num_steps=remaining_steps)

# Persist the trained single-scale model so later cells can reload it
checkpoint_path = checkpoint_dir / "singlescale.pth"
state_dict = trainer.state_dict()
torch.save(state_dict, str(checkpoint_path))

## Visualize aliasing artifacts

In [None]:
# Reuse the saved single-scale checkpoint when present; otherwise train and save it.
checkpoint_path = checkpoint_dir / "singlescale.pth"
if checkpoint_path.exists():
    # Rebuild the trainer, then restore its weights from disk
    trainer = Trainer(cfg).to(device)
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load checkpoints you produced
    state_dict = torch.load(checkpoint_path, map_location=device, weights_only=False)
    trainer.load_state_dict(state_dict)
    print(f"Loaded model from {checkpoint_path}")
else:
    # Relies on `trainer` created in an earlier cell
    trainer.fit()
    state_dict = trainer.state_dict()
    torch.save(state_dict, str(checkpoint_path))

In [None]:
# ============ Visualization Script ============
COLOR_BOX = '#D46CCD'     # sampling-window outline and naive-view title
COLOR_POINTS = '#4E71BE'  # sample-point markers
SAMPLE_RES = 32           # the moving window is sampled on a SAMPLE_RES x SAMPLE_RES grid
# Take the frame rate from the encoder settings so FRAMES always yields exactly SECONDS of video
FPS = video_settings["fps"]
SECONDS = 12
FRAMES = FPS * SECONDS
video_output = video_output_dir / "singlescale_aliasing_vis.mov"

# --- Keyframes (Fixed Timing) ---
# Format: (Time_Percent, Center_X, Center_Y, Span)
# Time is a fraction of the clip in [0, 1]; the last keyframe repeats the first so the video loops.
keyframes = [
    (0.00, -0.5,  0.5, 0.65),
    (0.40,  0.5,  0.5, 0.65),
    (0.50,  0.5, -0.4, 0.95),
    (0.90, -0.5, -0.4, 0.95),
    (1.00, -0.5,  0.5, 0.65),
]

def get_trajectory_point_time(t_global, keyframes):
    """Interpolate the window pose ``(center_x, center_y, span)`` at normalized time ``t_global``.

    Args:
        t_global: Normalized time (the render loop passes values in [0, 1]).
        keyframes: Sequence of ``(time, center_x, center_y, span)`` tuples sorted by ``time``.

    Returns:
        ``(x, y, span)``. Inside a segment the pose is eased with smoothstep
        (``3t^2 - 2t^3``, zero velocity at each keyframe). Before the first keyframe
        the first pose is held; after the last keyframe the last pose is held.
    """
    if t_global < keyframes[0][0]:
        # Hold the first pose; the segment scan below would otherwise fall through to the LAST keyframe.
        return keyframes[0][1:]

    for (t0, x0, y0, s0), (t1, x1, y1, s1) in zip(keyframes, keyframes[1:]):
        # Small tolerance so t_global == 1.0 (with float noise) still lands in the final segment
        if t0 <= t_global <= t1 + 1e-5:
            segment_duration = t1 - t0
            if segment_duration <= 1e-5:
                # Degenerate (zero-length) segment: jump straight to its end pose
                return x1, y1, s1
            local_t = np.clip((t_global - t0) / segment_duration, 0, 1)
            smooth_t = local_t * local_t * (3 - 2 * local_t)
            current_x = x0 + (x1 - x0) * smooth_t
            current_y = y0 + (y1 - y0) * smooth_t
            current_s = s0 + (s1 - s0) * smooth_t
            return current_x, current_y, current_s

    return keyframes[-1][1:]

def clamp_bounds(cx, cy, span, limit=1.0):
    """Shift a square window of side ``span`` centered at ``(cx, cy)`` to lie inside ``[-limit, limit]``.

    Each axis is handled independently: first push the center right/up if the
    low edge pokes out, then left/down if the high edge does.

    Returns:
        The adjusted ``(cx, cy)``; ``span`` is never changed.
    """
    half = span / 2.0

    def _shift_inside(center):
        # Low edge first, then high edge (the high-edge check wins if span > 2 * limit)
        if center - half < -limit:
            center = -limit + half
        if center + half > limit:
            center = limit - half
        return center

    return _shift_inside(cx), _shift_inside(cy)

# --- Setup Plot ---
plt.style.use('seaborn-v0_8-white')
fig, axs = plt.subplots(1, 2, figsize=(10, 5.5), dpi=200, constrained_layout=True)
fig.patch.set_facecolor('white')

# Context view: the full learned signal with the moving sampling window drawn on top
full_field = trainer.data  # NOTE(review): unused below; kept in case later cells rely on it
output = trainer.generate_grid_data(resolution=SRC_RES)
reconstructed_data = output["filtered_signal"].clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()
# extent=(left, right, bottom, top)=(-1, 1, 1, -1): data coords in [-1, 1], y increasing downward
im_context = axs[0].imshow(reconstructed_data, extent=(-1, 1, 1, -1), cmap='gray', vmin=0, vmax=1)
axs[0].set_title("Learned Signal (Neural Field)", fontsize=14, fontweight='bold', color='gray')
axs[0].axis('off')

# Sampling-window outline; position and size are updated every frame
rect = patches.Rectangle((0,0), 0, 0, linewidth=2.5, edgecolor=COLOR_BOX, facecolor='none')
axs[0].add_patch(rect)

# Small, semi-transparent markers so the sample points do not hide the signal underneath
scat = axs[0].scatter([], [], s=4, c=COLOR_POINTS, edgecolors='none', alpha=0.6)

# Naive view: the field queried only at the window's SAMPLE_RES x SAMPLE_RES grid points
dummy_data = np.zeros((SAMPLE_RES, SAMPLE_RES, 3))
im_naive = axs[1].imshow(dummy_data, vmin=0, vmax=1, interpolation='nearest')
axs[1].set_title("Naive Sampling (Aliased)", fontsize=14, fontweight='bold', color=COLOR_BOX)
axs[1].axis('off')

# --- Render Loop ---
# Guard the time normalization so a single-frame render (FRAMES == 1) does not divide by zero
t_denominator = max(FRAMES - 1, 1)

try:
    # Context manager finalizes the video even if a frame fails to render
    with imageio.get_writer(video_output, **video_settings) as writer:
        print(f"Rendering {FRAMES} frames...")

        for i in tqdm(range(FRAMES)):
            t = i / t_denominator

            # Window pose at this time, kept inside [-0.99, 0.99]
            cx, cy, span = get_trajectory_point_time(t, keyframes)
            cx, cy = clamp_bounds(cx, cy, span, limit=0.99)
            # presumably (x_min, y_min, x_max, y_max) as make_coord_grid expects — confirm
            bounds = (cx - span/2, cy - span/2, cx + span/2, cy + span/2)
            coords = make_coord_grid(ndim=2, resolution=SAMPLE_RES, bounds=bounds, device=device)

            # Naive: evaluate the field directly at the window's grid points
            with torch.no_grad():
                naive_out = trainer.model.forward(
                    {"coords": coords.view(-1, 2)},
                )
                naive_out = naive_out["filtered_signal"].view(SAMPLE_RES, SAMPLE_RES, 3)
                naive_out = naive_out.clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()

            # Move the window outline and sample markers
            rect.set_xy((bounds[0], bounds[1]))
            rect.set_width(span)
            rect.set_height(span)

            c_np = coords.view(-1, 2).cpu().numpy()
            scat.set_offsets(c_np)

            im_naive.set_data(naive_out)

            fig.canvas.draw()
            frame = np.asarray(fig.canvas.buffer_rgba())
            writer.append_data(frame)
finally:
    # Close this specific figure (plt.close() alone closes whichever figure is current)
    plt.close(fig)

print(f"Done. Saved {video_output}")