# Testing Time-Indexed RealNVP Implementation

This notebook comprehensively tests the time-indexed RealNVP flow T(t, x) for:
1. **Invertibility**: T^{-1}(t, T(t, x)) = x
2. **Differentiability w.r.t. time**: ∂T/∂t exists and is smooth
3. **Differentiability w.r.t. input**: ∂T/∂x exists with correct Jacobian
4. **Log-determinant consistency**: log|det(∂T/∂x)| matches forward and inverse
5. **Multiple scales**: From tiny 8×8 to ImageNet 224×224

In [1]:
import torch
import torch.nn as nn
import gc
import numpy as np
import matplotlib.pyplot as plt
from realnvp import TimeIndexedRealNVP, create_vector_flow, create_mnist_flow,create_cifar10_flow, create_imagenet_flow, create_imagenet_flow_stable
import warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
print("="*80)

print("GPUs available:", torch.cuda.device_count())

for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print(f"GPU {i}: {props.name}, {props.total_memory/1024**3:.1f} GB")

Using device: cpu
GPUs available: 0


## 1. Count Parameters

In [2]:
def count_params(model):
    n = sum(p.numel() for p in model.parameters() if p.requires_grad)
    mb = n * 8 / 1e6  # float32
    return n, mb


## 2. Check Invertibility

In [3]:
@torch.no_grad()
def check_invertibility(model, x, t, name, atol=1e-5, rtol=1e-5):
    model.eval()
    y, ld1 = model(x, t)
    x_rec, ld2 = model.inverse(y, t)

    diff = (x_rec - x).reshape(x.shape[0], -1)
    l2 = diff.norm(dim=1)
    linf = diff.abs().max(dim=1).values
    rel = l2 / (x.reshape(x.shape[0], -1).norm(dim=1) + 1e-12)

    # logdet from inverse should be the negative of forward
    logdet_consistency = (ld1 + ld2).abs()

    print(f"\n[{name}] Invertibility check")
    print(f"  ||x_rec - x||_2    : mean {l2.mean():.3e} | max {l2.max():.3e}")
    print(f"  ||x_rec - x||_inf  : mean {linf.mean():.3e} | max {linf.max():.3e}")
    print(f"  rel L2 error       : mean {rel.mean():.3e} | max {rel.max():.3e}")
    print(f"  |logdet + logdet^-1|: mean {logdet_consistency.mean():.3e} | max {logdet_consistency.max():.3e}")

    ok = (l2.max() < atol + rtol * x.abs().max()) and (logdet_consistency.max() < 1e-8)    
    print(f"  PASS: {ok}")
    return ok

## 3. Check Differentiability

In [4]:
def check_differentiability_t(model, x, name, eps_list=(1e-1, 3e-2, 1e-2, 3e-3, 1e-3)):
    """Check differentiability w.r.t. time using autograd vs finite differences"""
    model.eval()
    B = x.shape[0]
    device = x.device
    D = x.view(B, -1).shape[1]  # Total dimension
    
    # random t in [0,1]
    t = torch.rand(B, device=device, dtype=x.dtype, requires_grad=True)

    # Autograd gradient of sum(y_i) w.r.t. t_i (per-sample)
    y, _ = model(x, t)
    f = y.view(B, -1).sum(dim=1)          # [B]
    g_aut, = torch.autograd.grad(f.sum(), t, create_graph=False)  # [B]
    
    # Per-element gradients (normalized by dimension)
    g_aut_per_elem = g_aut / D

    print(f"\n[{name}] Differentiability in time: autograd vs finite-difference")
    print(f"  Dimension: {D:,}")
    print(f"  Total grad stats: mean {g_aut.mean().item():.3e} | std {g_aut.std().item():.3e}")
    print(f"  Per-elem grad:    mean {g_aut_per_elem.mean().item():.3e} | std {g_aut_per_elem.std().item():.3e}")

    # Finite differences
    with torch.no_grad():
        for eps in eps_list:
            t_plus = (t.detach() + eps).clamp(0.0, 1.0)
            y_plus, _ = model(x, t_plus)
            fd = (y_plus - y.detach()) / eps        # same shape as y
            fd_sum = fd.view(B, -1).sum(dim=1)      # [B]

            abs_err = (fd_sum - g_aut.detach()).abs()
            rel_err = abs_err / (g_aut.detach().abs() + 1e-12)
            
            # Per-element errors
            abs_err_per_elem = abs_err / D
            rel_err_per_elem = rel_err  # relative error doesn't change with normalization

            print(f"  eps={eps:>7.1e} | total_err {abs_err.mean():.3e} | "
                  f"per_elem_err {abs_err_per_elem.mean():.3e} | rel_err {rel_err_per_elem.mean():.3e}")

def check_differentiability_x(model, x_template, name, eps_list=(1e-3, 3e-4, 1e-4, 3e-5, 1e-5)):
    """Check differentiability w.r.t. input x using autograd vs finite differences"""
    model.eval()
    B = x_template.shape[0]
    device = x_template.device
    D = x_template.view(B, -1).shape[1]  # Total dimension
    
    # Create input x that requires grad
    x = x_template.clone().detach().requires_grad_(True)
    t = torch.rand(B, device=device, dtype=x.dtype)  # Fixed time for this test
    
    # Forward pass
    y, logdet = model(x, t)
    
    # Compute scalar output for differentiation (sum of all outputs)
    f = y.view(B, -1).sum(dim=1).sum()  # Single scalar
    
    # Autograd gradient w.r.t. x
    g_aut = torch.autograd.grad(f, x, create_graph=False)[0]  # Same shape as x
    g_aut_flat = g_aut.view(B, -1)  # [B, D]
    
    # Per-element gradient statistics
    g_aut_per_elem = g_aut_flat.abs().mean()

    print(f"\n[{name}] Differentiability in input: autograd vs finite-difference")
    print(f"  Input shape: {x.shape}")
    print(f"  Total dimensions: {D:,}")
    print(f"  Grad stats: mean {g_aut_flat.mean().item():.3e} | std {g_aut_flat.std().item():.3e}")
    print(f"  Grad magnitude: mean |∇x| {g_aut_per_elem.item():.3e}")

    # Finite differences - test a few random directions
    with torch.no_grad():
        # Create random perturbation directions (normalized)
        directions = torch.randn_like(x)
        directions = directions / directions.view(B, -1).norm(dim=1, keepdim=True).view(B, *([1] * (x.dim()-1)))
        
        for eps in eps_list:
            # Perturb input
            x_plus = x.detach() + eps * directions
            
            # Forward pass with perturbed input
            y_plus, logdet_plus = model(x_plus, t)
            f_plus = y_plus.view(B, -1).sum(dim=1).sum()
            
            # Finite difference approximation
            fd_scalar = (f_plus - f.detach()) / eps
            
            # Directional derivative from autograd: ∇f · direction
            directional_grad = (g_aut.detach() * directions).view(B, -1).sum(dim=1).sum()
            
            # Compare
            abs_err = abs(fd_scalar - directional_grad)
            rel_err = abs_err / (abs(directional_grad) + 1e-12)
            
            print(f"  eps={eps:>7.1e} | abs_err {abs_err:.3e} | rel_err {rel_err:.3e}")

## 4. Dataset-specific runners

In [5]:
def quick_run_imagenet(device="cpu"):
    print("\n=== IMAGENET-LIKE IMAGE FLOW ===")
    model = create_imagenet_flow(
        resolution=128,  # Smaller than default 224 for memory efficiency
        num_layers=6, time_embed_dim=512,
        img_base_channels=256, img_blocks=3, img_groups=32,
        img_log_scale_clamp=10.0, use_permutation=True
    ).to(device)
    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")
    x = torch.randn(4, 3, 128, 128, device=device, dtype=torch.float32)  # Very small batch for memory
    t = torch.rand(4, device=device, dtype=torch.float32)
    check_invertibility(model, x, t, "ImageNet (3x128x128)")
    check_differentiability_t(model, x, "ImageNet (3x128x128)")
    check_differentiability_x(model, x, "ImageNet (3x128x128)")

def quick_run_imagenet_full(device="cpu"):
    print("\n=== FULL IMAGENET FLOW ===")
    model = create_imagenet_flow(
        resolution=224,  # Full resolution
        num_layers=4, time_embed_dim=256,
        img_base_channels=128, img_blocks=4, img_groups=32,
        img_log_scale_clamp=10.0, use_permutation=True
    ).to(device)
    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")
    x = torch.randn(6, 3, 224, 224, device=device, dtype=torch.float32)
    t = torch.rand(6, device=device, dtype=torch.float32)
    check_invertibility(model, x, t, "ImageNet (3x224x224)")
    check_differentiability_t(model, x, "ImageNet (3x224x224)")
    check_differentiability_x(model, x, "ImageNet (3x224x224)")

def quick_run_vector(device="cpu"):
    print("\n=== VECTOR FLOW ===")
    model = create_vector_flow(
        dim=32, num_layers=6, time_embed_dim=64,
        hidden=512, mlp_blocks=3, activation="gelu",
        use_layernorm=False, use_permutation=True
    ).to(device)
    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")
    x = torch.randn(8, 32, device=device, dtype=torch.float32)
    t = torch.rand(8, device=device, dtype=torch.float32)
    check_invertibility(model, x, t, "vector(32)")
    check_differentiability_t(model, x, "vector(32)")
    check_differentiability_x(model, x, "vector(32)")

def quick_run_mnist_like(device="cpu"):
    print("\n=== MNIST-LIKE IMAGE FLOW ===")
    model = create_mnist_flow(
        image_mode=True, num_layers=6, time_embed_dim=128,
        img_base_channels=96, img_blocks=3, img_groups=32,
        img_log_scale_clamp=10.0, use_permutation=True
    ).to(device)
    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")
    x = torch.randn(4, 1, 28, 28, device=device, dtype=torch.float32)
    t = torch.rand(4, device=device, dtype=torch.float32)
    check_invertibility(model, x, t, "MNIST (1x28x28)")
    check_differentiability_t(model, x, "MNIST (1x28x28)")
    check_differentiability_x(model, x, "MNIST (1x28x28)")

def quick_run_cifar(device="cpu"):
    print("\n=== CIFAR-10 IMAGE FLOW ===")
    model = create_cifar10_flow(
        num_layers=8, time_embed_dim=128,
        img_base_channels=128, img_blocks=4, img_groups=32,
        img_log_scale_clamp=10.0, use_permutation=True
    ).to(device)
    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")
    x = torch.randn(4, 3, 32, 32, device=device, dtype=torch.float32)  # keep small batch for speed
    t = torch.rand(4, device=device, dtype=torch.float32)
    check_invertibility(model, x, t, "CIFAR (3x32x32)")
    check_differentiability_t(model, x, "CIFAR (3x32x32)")
    check_differentiability_x(model, x, "CIFAR (3x32x32)")

## 5. Tests

In [6]:
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

quick_run_vector(device)
torch.cuda.empty_cache(); gc.collect()  # Cleanup

quick_run_mnist_like(device)
torch.cuda.empty_cache(); gc.collect()  # Cleanup

quick_run_cifar(device)
torch.cuda.empty_cache(); gc.collect()  # Cleanup

quick_run_imagenet(device)
torch.cuda.empty_cache(); gc.collect()  # Cleanup

quick_run_imagenet_full(device)  # Now there's room!
torch.cuda.empty_cache(); gc.collect()
print("\nAll checks done.")

Device: cpu

=== VECTOR FLOW ===
Params: 3,499,200 (27.99 MB)

[vector(32)] Invertibility check
  ||x_rec - x||_2    : mean 2.484e-07 | max 4.971e-07
  ||x_rec - x||_inf  : mean 1.788e-07 | max 4.768e-07
  rel L2 error       : mean 4.306e-08 | max 7.834e-08
  |logdet + logdet^-1|: mean 7.153e-07 | max 9.537e-07
  PASS: False

[vector(32)] Differentiability in time: autograd vs finite-difference
  Dimension: 32
  Total grad stats: mean 2.346e+00 | std 6.113e+01
  Per-elem grad:    mean 7.331e-02 | std 1.910e+00
  eps=1.0e-01 | total_err 4.944e+01 | per_elem_err 1.545e+00 | rel_err 1.008e+00
  eps=3.0e-02 | total_err 4.999e+01 | per_elem_err 1.562e+00 | rel_err 1.012e+00
  eps=1.0e-02 | total_err 5.002e+01 | per_elem_err 1.563e+00 | rel_err 1.017e+00
  eps=3.0e-03 | total_err 4.987e+01 | per_elem_err 1.558e+00 | rel_err 1.018e+00
  eps=1.0e-03 | total_err 6.285e+01 | per_elem_err 1.964e+00 | rel_err 1.251e+00

[vector(32)] Differentiability in input: autograd vs finite-difference
  Input