In [3]:
# %%
from typing import Union, Sequence
from pathlib import Path
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import matplotlib.patches as patches
from omegaconf import OmegaConf
import hydra
from hydra import initialize_config_dir, compose
import time
import imageio as iio 

import numpy as np
import cv2
import scipy.fft 
import torch
import torch.nn.functional as F
from torch import Tensor 

import spnf
from spnf.utils import set_seed, make_coord_grid, apply_to_tensors, to_py, interpolate_covariance_matrices_numpy
from spnf.sample import (
    rand_ortho, logrand, construct_covariance, sample_gaussian_delta, sample_ellipsoid_delta
)
from spnf.trainer import Trainer

# Hydra config directory bundled inside the installed ``spnf`` package.
config_dir = Path(spnf.__file__).parent.joinpath("configs")
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_cfg(config_name, config_dir: Path, overrides: list):
    """Compose a Hydra config by name from an explicit config directory.

    Args:
        config_name: Name of the primary config to load (e.g. ``"train"``).
        config_dir: Directory holding the Hydra configs; handed to Hydra as a string.
        overrides: Hydra override strings such as ``"key=value"``.

    Returns:
        The config object produced by ``hydra.compose``.
    """
    # Compose inside the initialisation context and return the result after
    # the context has been exited.
    with initialize_config_dir(config_dir=str(config_dir), version_base=None):
        composed = compose(overrides=overrides, config_name=config_name)
    return composed

def render_kernel(kind, cov_mat, interval=(-3.0, 3.0), res=1024, a=1.0, colormap=None):
    """Rasterise a 2-D filter kernel with covariance ``cov_mat`` as an 8-bit image.

    The kernel is evaluated on a ``res x res`` grid that spans ``interval`` on
    both axes. It uses the quadratic form ``q = x^T cov_mat^{-1} x``.

    Args:
        kind: ``'gaussian'``, ``'uniform_ellipsoid'`` or ``'lanczos'`` (case-insensitive).
        cov_mat: ``(2, 2)`` covariance tensor. The grid is built on its device.
            It is assumed to be symmetric positive definite.
        interval: ``(lo, hi)`` extent of the sampling grid on each axis.
        res: Number of samples per axis.
        a: Lanczos window parameter (only used for ``kind == 'lanczos'``).
        colormap: Optional OpenCV colormap id. With ``None``, a grey image is returned.

    Returns:
        ``np.uint8`` array of shape ``(res, res, 3)`` in OpenCV's BGR channel order.

    Raises:
        ValueError: if ``kind`` is not one of the supported kernels.
    """
    kind_key = kind.lower()
    device = cov_mat.device
    lo, hi = float(interval[0]), float(interval[1])
    x = torch.linspace(lo, hi, res, device=device)
    grid = torch.stack(torch.meshgrid(x, x, indexing='xy'), dim=-1).reshape(-1, 2)

    inv_cov = torch.linalg.inv(cov_mat)
    # Mahalanobis quadratic form per grid point. With a non-diagonal covariance,
    # rounding can push it slightly below zero, and sqrt() would then return NaN.
    # For Lanczos, that NaN would also poison the global max normalisation below.
    # Clamping at 0 is exact for any positive semi-definite covariance.
    q = ((grid @ inv_cov) * grid).sum(-1).clamp(min=0.0)

    if kind_key == 'gaussian':
        ker01 = torch.exp(-0.5 * q).clamp(0, 1)
    elif kind_key == 'uniform_ellipsoid':
        # Indicator of the unit Mahalanobis ball.
        ker01 = (q.sqrt() < 1).to(grid.dtype)
    elif kind_key == 'lanczos':
        t = q.sqrt()
        ker = torch.sinc(t) * torch.sinc(t / a)
        # Normalise, then boost by 2.5 so the small side lobes stay visible.
        # The magnitude is what gets displayed.
        ker = ker / torch.max(ker.abs()) * 2.5
        ker01 = ker.clamp(-1, 1).abs()
    else:
        raise ValueError("Unknown kind")

    img_u8 = (ker01.reshape(res, res) * 255).round().to(torch.uint8).cpu().numpy()

    if colormap is not None:
        return cv2.applyColorMap(img_u8, colormap)
    # Replicate the grey channel so both branches return 3-channel BGR images.
    return cv2.cvtColor(img_u8, cv2.COLOR_GRAY2BGR)

def image_spectrum(image: Tensor) -> Tensor:
    """Return the centred Fourier magnitude of the periodic component of ``image[0]``.

    The DFT treats the image as periodic. The jump between opposite borders
    therefore produces a strong cross-shaped artefact in the spectrum. To avoid
    it, the image is split into periodic + smooth parts (Moisan's
    periodic-plus-smooth decomposition), and only the periodic part
    (``img - smooth``) is transformed.

    Args:
        image: Tensor whose first index selects a 2-D ``(M, N)`` image, e.g. a
            ``(1, M, N)`` single-channel tensor. It may live on any device and
            may require grad; it is detached and copied to the CPU.

    Returns:
        ``(1, M, N)`` tensor of FFT magnitudes, shifted so that the zero
        frequency sits at index ``(M // 2, N // 2)``.
    """
    # .detach().cpu() lets GPU tensors and tensors that require grad through;
    # calling .numpy() directly would raise for both.
    img = image[0].detach().cpu().numpy()
    M, N = img.shape

    # Boundary image: the wrap-around jumps between opposite borders.
    v = np.zeros_like(img)
    v[0, :] = img[-1, :] - img[0, :]
    v[-1, :] = img[0, :] - img[-1, :]
    v[:, 0] += img[:, -1] - img[:, 0]
    v[:, -1] += img[:, 0] - img[:, -1]
    v_hat = np.fft.fftn(v)

    # Solve the discrete Poisson problem for the smooth component in the
    # Fourier domain. The denominator is zero only at the DC term, which is
    # fixed to 0 so the smooth component has zero mean.
    q = np.arange(M).reshape(M, 1)
    r = np.arange(N).reshape(1, N)
    den = 2 * np.cos(2 * np.pi * q / M) + 2 * np.cos(2 * np.pi * r / N) - 4
    s = np.divide(v_hat, den, out=np.zeros_like(v_hat), where=den != 0)
    s[0, 0] = 0
    smooth_component = np.real(np.fft.ifftn(s))

    magnitudes = np.abs(scipy.fft.fftshift(scipy.fft.fft2(img - smooth_component)))
    return torch.as_tensor(magnitudes[None])

def _spectrum(picture):
    """Return a log-scaled, centre-cropped spectrum of ``picture``, mapped towards ``[-1, 1]``.

    Args:
        picture: ``(C, H, W)`` tensor. Channels are averaged before the FFT.

    Returns:
        2-D tensor covering roughly the central quarter of each frequency axis.
        Log10 magnitudes in ``[1.5, 5.0]`` map linearly to ``[-1, 1]``.
        Smaller values saturate at ``-1``; values above 5.0 exceed ``1`` and are
        not clipped here.
    """
    spectrum = image_spectrum(picture.mean(dim=0, keepdim=True))[0]
    height, width = spectrum.shape
    # Crop each axis independently so non-square images keep their own centre.
    # Explicit end indices avoid the empty ``0:-0`` slice when the crop is 0.
    crop_h = (height - height // 4) // 2
    crop_w = (width - width // 4) // 2
    centre = spectrum[crop_h:height - crop_h, crop_w:width - crop_w]
    log_mag = centre.clamp(1e-5).log10()
    return (log_mag - 1.5).clamp(0) / 3.5 * 2 - 1

def spectrum_vis(x_rgb_m1p1):
    """Render the log spectrum of an ``(H, W, C)`` image tensor as a grey 3-channel uint8 image.

    The name suggests the input holds values in ``[-1, 1]``. The spectrum
    returned by ``_spectrum`` is mapped from ``[-1, 1]`` to ``[0, 255]`` and
    clipped, then replicated into OpenCV's BGR layout.
    """
    chw = x_rgb_m1p1.permute(2, 0, 1).cpu()
    spec = _spectrum(chw)
    if isinstance(spec, torch.Tensor):
        spec = spec.detach().cpu()
    spec = np.asarray(spec.squeeze())
    gray = np.clip(127.5 * (spec + 1.0), 0, 255).round().astype(np.uint8)
    return cv2.cvtColor(gray, cv2.COLOR_GRAY2BGR)

def render_basis_frame(A, covariance, t_plot, proj_dir, num_basis_vis):
    """Plot how a Gaussian prefilter attenuates a set of Fourier bases and return the frame.

    Each row ``a`` of ``A`` is projected onto ``proj_dir``. The 1-D wave
    ``cos(2*pi*(a . proj_dir) * t)`` is drawn over ``t_plot``, scaled by the
    attenuation ``exp(-2*pi^2 * a^T covariance a)``, with the unattenuated wave
    shown faintly in grey behind it.

    Args:
        A: ``(K, 2)`` frequency matrix, one basis frequency per row (K >= num_basis_vis).
        covariance: ``(2, 2)`` prefilter covariance.
        t_plot: 1-D tensor of sample positions along ``proj_dir``.
        proj_dir: ``(2, 1)`` projection direction.
        num_basis_vis: Number of rows of ``A`` to draw, one panel each.

    Returns:
        ``(H, W, 4)`` uint8 RGBA array of the rendered figure. The size depends
        only on ``num_basis_vis``, so video frames stay consistent.
    """
    freqs_1d = A @ proj_dir  # (K, 1)
    # diag(A Σ A^T) without materialising the full K x K matrix.
    quad_form = ((A @ covariance) * A).sum(dim=-1)
    attenuation = torch.exp(-2 * (torch.pi ** 2) * quad_form)

    phases = freqs_1d @ t_plot.unsqueeze(0) * 2 * np.pi  # (K, T)
    ref_waves = torch.cos(phases)
    waves = ref_waves * attenuation.unsqueeze(1)

    # Detach so the helper also works if A is a parameter that requires grad.
    ref_np = ref_waves.detach().cpu().numpy()
    waves_np = waves.detach().cpu().numpy()
    att_np = attenuation.detach().cpu().numpy()
    t_np = t_plot.detach().cpu().numpy()

    # Size the grid from num_basis_vis (at most 5 columns). A fixed 2x5 grid
    # raises IndexError for more than 10 panels and shows boxed empty panels
    # for fewer. For 10 panels this reproduces the original 2x5 / (15, 4) layout.
    ncols = max(1, min(5, num_basis_vis))
    nrows = max(1, (num_basis_vis + ncols - 1) // ncols)
    fig_basis, axes = plt.subplots(
        nrows, ncols, figsize=(3 * ncols, 2 * nrows), sharex=True, sharey=True, squeeze=False
    )
    for b_i, ax in enumerate(axes.flatten()):
        ax.axis('off')
        if b_i >= num_basis_vis:
            continue
        # Reference (unattenuated) wave in faint grey.
        ax.plot(t_np, ref_np[b_i], color='gray', alpha=0.2, linewidth=1)
        # Attenuated wave, coloured by its attenuation factor.
        ax.plot(t_np, waves_np[b_i], color=plt.cm.magma(att_np[b_i]), linewidth=1.5)
        ax.set_ylim(-1.2, 1.2)

    fig_basis.tight_layout()
    fig_basis.canvas.draw()
    frame = np.array(fig_basis.canvas.buffer_rgba())
    plt.close(fig_basis)
    return frame

# %% [markdown]
# ## Training visualization for neural fields with Spectral Prefiltering

# %%
import os

filter_type = "gaussian"
lambda_decay_start = 100
train_steps_per_model = 350
num_models = 10
SRC_RES = 512
FPS = 30

# Per-model log10 of the isotropic variance, from -3.5 (blurry) to -7.0 (sharp).
# Each model is trained with covariance 10**log_var * I.
cov_logvars = np.linspace(-3.5, -7.0, num_models)

# Experiment output root. Set SPNF_OUTPUT_ROOT to redirect outputs without
# editing the notebook; the default keeps the original location.
output_root = os.environ.get("SPNF_OUTPUT_ROOT", "/home/myaldiz/Data/Experiments/spnf")

cfg = get_cfg(
    config_name="train",
    config_dir=config_dir,
    overrides=[
        f"data.grid.resize={SRC_RES}",
        "data.bounds=1.0",
        "tensorboard=null",
        f"trainer.steps={train_steps_per_model}",
        # "${task_name}" is deliberately kept as a plain string so Hydra/OmegaConf
        # interpolates it, not Python.
        f"paths.output_dir={output_root}/" + "${task_name}",
        f"scheduler.lr_lambda.decay_start={lambda_decay_start}",
        "trainer.compile_train_step=True",
        "trainer.mc_samples_train=0",
        "model.encoder.length_distribution_param=500.0",
    ],
)

# HEVC (libx265) in yuv420p. The hvc1 tag is commonly needed for HEVC playback
# in Apple players.
video_settings = dict(fps=FPS, codec='libx265', pixelformat='yuv420p')
ffmpeg_params = [
    '-crf', '20', '-preset', 'fast', '-tag:v', 'hvc1', '-tune', 'grain',
]

output_dir = Path(cfg.paths.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
video_output_dir = output_dir / "visualizations"
video_output_dir.mkdir(parents=True, exist_ok=True)

# %%
class TrainerModified(Trainer):
    """``Trainer`` that can pin every training sample to one fixed covariance.

    When ``current_covariance`` is set, each training batch gets that covariance
    broadcast over the batch under ``"covariances"``. The batch is also tagged
    with the module-level ``filter_type``, which is read at call time.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Fixed covariance to force onto every batch; None keeps the base behaviour.
        self.current_covariance = None

    @torch.no_grad()
    def generate_data_train(self, coords=None, eigenvalues=None, eigenvectors=None, perturb_kernel=None) -> dict[str, torch.Tensor]:
        batch = super().generate_data_train(coords, eigenvalues, eigenvectors, perturb_kernel)
        fixed_cov = self.current_covariance
        if fixed_cov is None:
            return batch
        n_samples = batch["coords"].shape[0]
        # expand() broadcasts the covariance as a view; no per-sample copy is made.
        batch["covariances"] = fixed_cov.expand(n_samples, -1, -1)
        batch["filter_type"] = filter_type
        return batch

def _style_inset(ax, title):
    """Hide ticks, draw a white 2 px frame and add a white-on-black title to an inset axis."""
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_color('white')
        spine.set_linewidth(2)
    ax.set_title(title, fontsize=12, color='white', backgroundcolor='black')


def _compose_sweep_frame(pred_img, kern_img, spect_img, fig_width=12, inset_size=0.25):
    """Draw the prediction with spectrum (top-left) and kernel (top-right) insets; return RGBA uint8."""
    h, w = pred_img.shape[:2]
    img_height_in = fig_width * (h / w)
    # Inset height as a figure fraction that keeps the inset square in inches.
    inset_h = inset_size * (fig_width / img_height_in)

    fig = plt.figure(figsize=(fig_width, img_height_in), dpi=100)
    ax_img = fig.add_axes([0, 0, 1, 1])
    if pred_img.ndim == 3 and pred_img.shape[-1] == 1:
        # imshow rejects (H, W, 1) arrays; show single-channel predictions in grey.
        ax_img.imshow(pred_img[..., 0], cmap='gray', vmin=0.0, vmax=1.0)
    else:
        ax_img.imshow(pred_img)
    ax_img.axis('off')

    # Top right: central half of the rendered kernel.
    ax_kern = fig.add_axes([0.98 - inset_size, 0.98 - inset_h, inset_size, inset_h])
    center = kern_img.shape[0] // 2
    q_size = kern_img.shape[0] // 4
    ax_kern.imshow(kern_img[center - q_size:center + q_size, center - q_size:center + q_size])
    _style_inset(ax_kern, "Spatial Kernel")

    # Top left: spectrum of the prediction.
    ax_spect = fig.add_axes([0.02, 0.98 - inset_h, inset_size, inset_h])
    ax_spect.imshow(spect_img)
    _style_inset(ax_spect, "Spectrum")

    fig.canvas.draw()
    frame = np.array(fig.canvas.buffer_rgba())
    plt.close(fig)
    return frame


final_frames_main = []
# Covariance matrices kept for interpolation in the basis video.
saved_covariances = []
# Same torch RNG state before every model construction, so differences between
# sweep frames come from the covariance rather than from random init. If
# Trainer seeds itself from cfg this is a harmless no-op — TODO confirm.
sweep_seed = 0

# --- Setup Basis Selection ---
torch.manual_seed(sweep_seed)
temp_trainer = TrainerModified(cfg).to(device)
num_basis_vis = 10
basis_matrix = temp_trainer.model.encoder.A
magnitudes = basis_matrix.norm(dim=1)
sorted_indices = torch.argsort(magnitudes)
# Evenly spaced picks across the frequency-magnitude ordering (low to high).
selected_indices = sorted_indices[torch.linspace(0, len(magnitudes) - 1, num_basis_vis).long()]
# Detach so the kept rows never carry autograd history into the plotting code.
vis_A = basis_matrix[selected_indices].detach()

t_plot = torch.linspace(-0.5, 0.5, 200, device=device)
proj_dir = F.normalize(torch.randn(2, 1, device=device), dim=0)
del temp_trainer
torch.cuda.empty_cache()

# --- Main Multi-Model Loop ---
print(f"Training {num_models} models from scratch...")

for model_idx, log_var in enumerate(cov_logvars):
    print(f"--- Model {model_idx+1}/{num_models} (LogVar: {log_var:.2f}) ---")

    torch.manual_seed(sweep_seed)
    trainer = TrainerModified(cfg).to(device)

    # Isotropic covariance with variance 10**log_var.
    cov_val = 10**log_var
    cov_mat = torch.eye(2, device=device) * cov_val
    trainer.current_covariance = cov_mat
    saved_covariances.append(cov_mat)

    trainer.fit(num_steps=train_steps_per_model, no_tqdm=True)

    # Prediction in [0, 1], converted to channel-last for plotting.
    output = trainer.generate_grid_data(resolution=SRC_RES)
    pred_img = output["filtered_signal"].detach().clamp(-1.0, 1.0).add(1.0).div(2.0).cpu().numpy()
    if pred_img.ndim == 3 and pred_img.shape[0] in [1, 3]:
        pred_img = np.transpose(pred_img, (1, 2, 0))

    # Insets.
    kern_img = render_kernel(kind=filter_type, cov_mat=cov_mat, interval=(-0.25, 0.25), res=256)
    pred_tensor_m1p1 = torch.from_numpy(pred_img) * 2.0 - 1.0
    spect_img = spectrum_vis(pred_tensor_m1p1)

    final_frames_main.append(_compose_sweep_frame(pred_img, kern_img, spect_img))

    del trainer
    torch.cuda.empty_cache()

Training 10 models from scratch...
--- Model 1/10 (LogVar: -3.50) ---
--- Model 2/10 (LogVar: -3.89) ---
--- Model 3/10 (LogVar: -4.28) ---
--- Model 4/10 (LogVar: -4.67) ---
--- Model 5/10 (LogVar: -5.06) ---
--- Model 6/10 (LogVar: -5.44) ---
--- Model 7/10 (LogVar: -5.83) ---
--- Model 8/10 (LogVar: -6.22) ---
--- Model 9/10 (LogVar: -6.61) ---
--- Model 10/10 (LogVar: -7.00) ---


In [4]:
# %%
# --- Interpolation & Video Generation ---

video_main_path = str(video_output_dir / "fourier_vis_sweep.mp4")
video_basis_path = str(video_output_dir / "fourier_basis_vis_sweep.mp4")

# Timing is derived from FPS: each trained model is held for 0.5 s, then
# cross-faded into the next over 1 s (15 and 30 frames at 30 fps).
frames_per_hold = int(round(0.5 * FPS))
frames_per_transition = FPS

print("Rendering interpolated videos...")

# The context managers close and finalise both files even if rendering fails midway.
with iio.get_writer(video_main_path, **video_settings, ffmpeg_params=ffmpeg_params) as writer_main, \
        iio.get_writer(video_basis_path, **video_settings, ffmpeg_params=ffmpeg_params) as writer_basis:
    last_idx = len(final_frames_main) - 1
    for i, (current_main, current_cov) in enumerate(zip(final_frames_main, saved_covariances)):
        # The basis frame for the current covariance is rendered once per hold.
        current_basis = render_basis_frame(vis_A, current_cov, t_plot, proj_dir, num_basis_vis)

        # 1. HOLD
        for _ in range(frames_per_hold):
            writer_main.append_data(current_main)
            writer_basis.append_data(current_basis)

        if i >= last_idx:
            continue

        # 2. TRANSITION to the next model
        next_main = final_frames_main[i + 1]
        next_cov = saved_covariances[i + 1]
        for t in range(frames_per_transition):
            # Strictly interior blend weights. The endpoints (alpha = 0 / 1) are
            # already covered by the hold frames, and this form never divides by
            # zero when frames_per_transition == 1.
            alpha = (t + 1) / (frames_per_transition + 1)

            # A. Main video: pixel-space cross dissolve.
            interp_main = cv2.addWeighted(current_main, 1 - alpha, next_main, alpha, 0)
            writer_main.append_data(interp_main)

            # B. Basis video: linear interpolation of the covariance matrix itself,
            # rendered fresh for each frame.
            interp_cov = (1 - alpha) * current_cov + alpha * next_cov
            interp_basis = render_basis_frame(vis_A, interp_cov, t_plot, proj_dir, num_basis_vis)
            writer_basis.append_data(interp_basis)

print(f"Main video saved to: {video_main_path}")
print(f"Basis video saved to: {video_basis_path}")

Rendering interpolated videos...


x265 [info]: HEVC encoder version 3.5+1-f0c1022b6
x265 [info]: build info [Linux][GCC 8.3.0][64 bit] 8bit+10bit+12bit
x265 [info]: using cpu capabilities: MMX2 SSE2Fast LZCNT SSSE3 SSE4.2 AVX FMA3 BMI2 AVX2
x265 [info]: Main profile, Level-4 (Main tier)
x265 [info]: Thread pool created using 24 threads
x265 [info]: Slices                              : 1
x265 [info]: frame threads / pool features       : 4 / wpp(19 rows)
x265 [info]: Coding QT: max CU size, min CU size : 64 / 8
x265 [info]: Residual QT: max TU size, max depth : 32 / 1 inter / 1 intra
x265 [info]: ME / range / subpel / merge         : hex / 57 / 2 / 2
x265 [info]: Keyframe min / max / scenecut / bias  : 25 / 250 / 40 / 5.00 
x265 [info]: Lookahead / bframes / badapt        : 15 / 4 / 0
x265 [info]: b-pyramid / weightp / weightb       : 1 / 1 / 0
x265 [info]: References / ref-limit  cu / depth  : 3 / on / on
x265 [info]: Rate Control / qCompress            : CRF-20.0 / 0.60
x265 [info]: tools: rd=2 psy-rd=4.00 signhide t

x265 [info]: frame I:      2, Avg QP:15.50  kb/s: 69421.20
x265 [info]: frame P:     83, Avg QP:16.25  kb/s: 4978.46 
x265 [info]: frame B:    335, Avg QP:16.22  kb/s: 623.98  
x265 [info]: Weighted P-Frames: Y:15.7% UV:9.6%
x265 [info]: consecutive B-frames: 1.2% 0.0% 0.0% 1.2% 97.6% 

encoded 420 frames in 49.03s (8.57 fps), 1812.12 kb/s, Avg QP:16.22


Main video saved to: /home/myaldiz/Data/Experiments/spnf/spnf-oia/visualizations/fourier_vis_sweep.mp4
Basis video saved to: /home/myaldiz/Data/Experiments/spnf/spnf-oia/visualizations/fourier_basis_vis_sweep.mp4


x265 [info]: frame I:      2, Avg QP:16.50  kb/s: 6334.44 
x265 [info]: frame P:     83, Avg QP:14.86  kb/s: 1497.65 
x265 [info]: frame B:    335, Avg QP:14.86  kb/s: 295.78  
x265 [info]: Weighted P-Frames: Y:51.8% UV:37.3%
x265 [info]: consecutive B-frames: 1.2% 0.0% 0.0% 1.2% 97.6% 

encoded 420 frames in 49.71s (8.45 fps), 562.05 kb/s, Avg QP:14.87
