# Testing Time-Indexed RealNVP Implementation

This notebook comprehensively tests the time-indexed RealNVP flow T(t, x) for:
1. **Invertibility**: T^{-1}(t, T(t, x)) = x
2. **Differentiability w.r.t. time**: ∂T/∂t exists and is smooth
3. **Differentiability w.r.t. input**: ∂T/∂x exists with correct Jacobian
4. **Log-determinant consistency**: the log-determinants returned by the forward and inverse passes are negatives of each other, i.e. log|det(∂T/∂x)| + log|det(∂T⁻¹/∂y)| = 0
5. **Multiple scales**: from a 32-dimensional vector flow, through MNIST 28×28 and CIFAR-10 32×32 images, up to ImageNet 128×128 and 224×224

In [7]:
import torch
import torch.nn as nn
import gc
import numpy as np
import matplotlib.pyplot as plt
from realnvp import TimeIndexedRealNVP, create_vector_flow, create_mnist_flow,create_cifar10_flow, create_imagenet_flow, create_imagenet_flow_stable
import warnings
warnings.filterwarnings('ignore')

# Seed NumPy and PyTorch so repeated runs draw the same random tensors
np.random.seed(42)
torch.manual_seed(42)

# Prefer a CUDA device when PyTorch can see one; otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print("Using device: {}".format(device))
print("=" * 80)

print("GPUs available:", torch.cuda.device_count())

# One summary line per visible GPU: index, name, total memory in GiB
for i in range(torch.cuda.device_count()):
    props = torch.cuda.get_device_properties(i)
    print("GPU {}: {}, {:.1f} GB".format(i, props.name, props.total_memory / 1024**3))

Using device: cpu
GPUs available: 0


## 1. Count Parameters

In [3]:
def count_params(model):
    """Count a model's trainable parameters and the memory they occupy.

    Args:
        model: Object exposing ``parameters()`` that yields tensors
            (e.g. a ``torch.nn.Module``).

    Returns:
        Tuple ``(n, mb)`` where ``n`` is the total number of scalar entries in
        all parameters with ``requires_grad=True`` and ``mb`` is their storage
        size in megabytes (1 MB = 10**6 bytes).
    """
    trainable = [p for p in model.parameters() if p.requires_grad]
    n = sum(p.numel() for p in trainable)
    # Use each tensor's real element size instead of a fixed bytes-per-value
    # constant, so the figure is right for float32, float64 and mixed dtypes.
    n_bytes = sum(p.numel() * p.element_size() for p in trainable)
    return n, n_bytes / 1e6


## 2. Check Invertibility

In [4]:
@torch.no_grad()
def check_invertibility(model, x, t, name, atol=1e-5, rtol=1e-5, logdet_atol=1e-8):
    """Check that ``model.inverse`` undoes ``model`` and that the log-dets cancel.

    Runs ``y, ld1 = model(x, t)`` followed by ``x_rec, ld2 = model.inverse(y, t)``,
    prints per-sample reconstruction statistics, and reports a pass/fail verdict.
    The model is switched to eval mode and left there.

    Args:
        model: Flow exposing ``eval()``, ``__call__(x, t)`` and ``inverse(y, t)``,
            each of the latter two returning ``(tensor, logdet)``.
        x: Input batch; the first axis is treated as the batch axis.
        t: Time values passed through to the model unchanged.
        name: Label used in the printed report.
        atol: Absolute tolerance on the worst per-sample L2 reconstruction error.
        rtol: Relative tolerance, scaled by the largest ``|x|`` entry.
        logdet_atol: Tolerance on ``|ld1 + ld2|``. The default of 1e-8 is the
            previously hard-coded value; loosen it for lower-precision dtypes.

    Returns:
        A plain Python ``bool``: ``True`` when both the reconstruction and the
        log-determinant criteria are met.
    """
    model.eval()
    y, ld1 = model(x, t)
    x_rec, ld2 = model.inverse(y, t)

    diff = (x_rec - x).reshape(x.shape[0], -1)
    l2 = diff.norm(dim=1)
    linf = diff.abs().max(dim=1).values
    # The small constant guards against division by zero for an all-zero sample
    rel = l2 / (x.reshape(x.shape[0], -1).norm(dim=1) + 1e-12)

    # logdet from inverse should be the negative of forward
    logdet_consistency = (ld1 + ld2).abs()

    print(f"\n[{name}] Invertibility check")
    print(f"  ||x_rec - x||_2    : mean {l2.mean():.3e} | max {l2.max():.3e}")
    print(f"  ||x_rec - x||_inf  : mean {linf.mean():.3e} | max {linf.max():.3e}")
    print(f"  rel L2 error       : mean {rel.mean():.3e} | max {rel.max():.3e}")
    print(f"  |logdet + logdet^-1|: mean {logdet_consistency.mean():.3e} | max {logdet_consistency.max():.3e}")

    # Convert each comparison to a Python bool: `and` between 0-dim tensors would
    # otherwise hand back a tensor (possibly on the GPU) instead of True/False.
    recon_ok = bool(l2.max() < atol + rtol * x.abs().max())
    logdet_ok = bool(logdet_consistency.max() < logdet_atol)
    ok = recon_ok and logdet_ok
    print(f"  PASS: {ok}")
    return ok

## 3. Check Differentiability

In [8]:
def check_differentiability_t(model, x, name, eps_list=(1e-1, 3e-2, 1e-2, 3e-3, 1e-3)):
    """Compare the autograd time-derivative of the flow with finite differences.

    For each sample a random ``t`` in [0, 1) is drawn, the scalar
    ``f_i = sum(model(x, t)[0][i])`` is differentiated w.r.t. ``t_i`` with autograd,
    and the result is compared with a one-sided difference quotient for every
    step size in ``eps_list``. A report is printed. The model is switched to
    eval mode and left there.

    Args:
        model: Flow whose ``__call__(x, t)`` returns ``(y, logdet)``.
        x: Input batch; the first axis is treated as the batch axis.
        name: Label used in the printed report.
        eps_list: Finite-difference step sizes. Perturbed times are guaranteed
            to stay inside [0, 1] for steps up to 0.5.

    Returns:
        Dict mapping each ``eps`` to ``{"abs_err": float, "rel_err": float}``,
        the batch means of the absolute and relative disagreement between the
        finite-difference and autograd derivatives.
    """
    model.eval()
    B = x.shape[0]
    device = x.device
    D = x.reshape(B, -1).shape[1]  # Total dimension

    # random t in [0,1]
    t = torch.rand(B, device=device, dtype=x.dtype, requires_grad=True)

    # Autograd gradient of sum(y_i) w.r.t. t_i (per-sample)
    y, _ = model(x, t)
    f = y.reshape(B, -1).sum(dim=1)          # [B]
    g_aut, = torch.autograd.grad(f.sum(), t, create_graph=False)  # [B]
    g_aut = g_aut.detach()

    # Per-element gradients (normalized by dimension)
    g_aut_per_elem = g_aut / D

    print(f"\n[{name}] Differentiability in time: autograd vs finite-difference")
    print(f"  Dimension: {D:,}")
    print(f"  Total grad stats: mean {g_aut.mean().item():.3e} | std {g_aut.std().item():.3e}")
    print(f"  Per-elem grad:    mean {g_aut_per_elem.mean().item():.3e} | std {g_aut_per_elem.std().item():.3e}")

    t0 = t.detach()
    y0 = y.detach()
    results = {}

    # Finite differences
    with torch.no_grad():
        for eps in eps_list:
            # Clamping t + eps to 1 while still dividing by eps under-estimates the
            # slope for samples near t = 1. Step backwards for those samples instead
            # and divide by the signed step that was actually taken.
            forward_step = torch.full_like(t0, eps)
            step = torch.where(t0 + eps > 1.0, -forward_step, forward_step)
            y_shift, _ = model(x, t0 + step)
            fd_sum = (y_shift - y0).reshape(B, -1).sum(dim=1) / step  # [B]

            abs_err = (fd_sum - g_aut).abs()
            rel_err = abs_err / (g_aut.abs() + 1e-12)

            # Per-element errors; the relative error is unchanged by this scaling
            abs_err_per_elem = abs_err / D

            mean_abs = abs_err.mean().item()
            mean_rel = rel_err.mean().item()
            results[eps] = {"abs_err": mean_abs, "rel_err": mean_rel}

            print(f"  eps={eps:>7.1e} | total_err {mean_abs:.3e} | "
                  f"per_elem_err {abs_err_per_elem.mean().item():.3e} | rel_err {mean_rel:.3e}")

    return results

def check_differentiability_x(model, x_template, name, eps_list=(1e-3, 3e-4, 1e-4, 3e-5, 1e-5)):
    """Compare the autograd input-gradient of the flow with finite differences.

    The scalar under test is the sum of every output entry of ``model(x, t)[0]``
    over the whole batch. Its autograd gradient w.r.t. ``x`` is projected onto a
    random per-sample unit direction and compared with a forward difference
    quotient along that direction for each step size in ``eps_list``. Results
    are printed; nothing is returned. The model is left in eval mode.

    Args:
        model: Flow whose ``__call__(x, t)`` returns ``(y, logdet)``.
        x_template: Input batch (first axis = batch); it is cloned, not modified.
        name: Label used in the printed report.
        eps_list: Finite-difference step sizes.
    """
    model.eval()
    batch = x_template.shape[0]
    n_dims = x_template.view(batch, -1).shape[1]

    # Independent leaf copy of the input so gradients can be taken w.r.t. it
    x = x_template.clone().detach().requires_grad_(True)
    # One fixed random time per sample, shared by all evaluations below
    t = torch.rand(batch, device=x_template.device, dtype=x.dtype)

    out, _ = model(x, t)
    total = out.view(batch, -1).sum(dim=1).sum()

    (grad_x,) = torch.autograd.grad(total, x, create_graph=False)
    grad_rows = grad_x.view(batch, -1)  # [batch, n_dims]
    mean_abs_grad = grad_rows.abs().mean()

    print(f"\n[{name}] Differentiability in input: autograd vs finite-difference")
    print(f"  Input shape: {x.shape}")
    print(f"  Total dimensions: {n_dims:,}")
    print(f"  Grad stats: mean {grad_rows.mean().item():.3e} | std {grad_rows.std().item():.3e}")
    print(f"  Grad magnitude: mean |∇x| {mean_abs_grad.item():.3e}")

    with torch.no_grad():
        # Random direction per sample, scaled to unit L2 norm
        raw = torch.randn_like(x)
        broadcast_shape = (batch,) + (1,) * (x.dim() - 1)
        unit = raw / raw.view(batch, -1).norm(dim=1, keepdim=True).view(*broadcast_shape)

        base_point = x.detach()
        base_value = total.detach()
        # Autograd directional derivative; it does not depend on the step size
        slope_autograd = (grad_x.detach() * unit).view(batch, -1).sum(dim=1).sum()

        for eps in eps_list:
            shifted_out, _ = model(base_point + eps * unit, t)
            shifted_value = shifted_out.view(batch, -1).sum(dim=1).sum()
            slope_fd = (shifted_value - base_value) / eps

            abs_err = (slope_fd - slope_autograd).abs()
            rel_err = abs_err / (slope_autograd.abs() + 1e-12)

            print(f"  eps={eps:>7.1e} | abs_err {abs_err:.3e} | rel_err {rel_err:.3e}")

## 4. Dataset-specific runners

In [None]:
def quick_run_imagenet(device="cpu"):
    """Run all three checks on a 128x128 ImageNet-style flow.

    Builds the flow with ``create_imagenet_flow``, prints the parameter count
    and size reported by ``count_params``, then runs the invertibility,
    time-differentiability and input-differentiability checks on a random
    float64 batch of four 3x128x128 inputs.

    Args:
        device: Device on which the model and the test tensors are placed.
    """
    label = "ImageNet (3x128x128)"
    side = 128   # reduced resolution, chosen for memory efficiency
    batch = 4    # very small batch, also for memory

    print("\n=== IMAGENET-LIKE IMAGE FLOW ===")
    flow_config = dict(
        resolution=side,
        num_layers=6,
        time_embed_dim=512,
        img_base_channels=256,
        img_blocks=3,
        img_groups=32,
        img_log_scale_clamp=10.0,
        use_permutation=True,
    )
    model = create_imagenet_flow(**flow_config).to(device)

    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")

    x = torch.randn(batch, 3, side, side, device=device, dtype=torch.float64)
    t = torch.rand(batch, device=device, dtype=torch.float64)

    check_invertibility(model, x, t, label)
    check_differentiability_t(model, x, label)
    check_differentiability_x(model, x, label)

def quick_run_imagenet_full(device="cpu"):
    """Run all three checks on a full-resolution (224x224) ImageNet-style flow.

    Builds the flow with ``create_imagenet_flow``, prints the parameter count
    and size reported by ``count_params``, then runs the invertibility,
    time-differentiability and input-differentiability checks on a random
    float64 batch of six 3x224x224 inputs.

    Args:
        device: Device on which the model and the test tensors are placed.
    """
    label = "ImageNet (3x224x224)"
    input_shape = (6, 3, 224, 224)

    print("\n=== FULL IMAGENET FLOW ===")
    flow_config = dict(
        resolution=224,  # full resolution
        num_layers=4,
        time_embed_dim=256,
        img_base_channels=128,
        img_blocks=4,
        img_groups=32,
        img_log_scale_clamp=10.0,
        use_permutation=True,
    )
    model = create_imagenet_flow(**flow_config).to(device)

    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")

    x = torch.randn(*input_shape, device=device, dtype=torch.float64)
    t = torch.rand(input_shape[0], device=device, dtype=torch.float64)

    check_invertibility(model, x, t, label)
    check_differentiability_t(model, x, label)
    check_differentiability_x(model, x, label)

def quick_run_vector(device="cpu"):
    """Run all three checks on a 32-dimensional vector flow.

    Builds the flow with ``create_vector_flow``, prints the parameter count and
    size reported by ``count_params``, then runs the invertibility,
    time-differentiability and input-differentiability checks on a random
    float64 batch of eight 32-dimensional vectors.

    Args:
        device: Device on which the model and the test tensors are placed.
    """
    label = "vector(32)"
    batch, dim = 8, 32

    print("\n=== VECTOR FLOW ===")
    flow_config = dict(
        dim=dim,
        num_layers=6,
        time_embed_dim=64,
        hidden=512,
        mlp_blocks=3,
        activation="gelu",
        use_layernorm=False,
        use_permutation=True,
    )
    model = create_vector_flow(**flow_config).to(device)

    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")

    x = torch.randn(batch, dim, device=device, dtype=torch.float64)
    t = torch.rand(batch, device=device, dtype=torch.float64)

    check_invertibility(model, x, t, label)
    check_differentiability_t(model, x, label)
    check_differentiability_x(model, x, label)

def quick_run_mnist_like(device="cpu"):
    """Run all three checks on an MNIST-shaped (1x28x28) image flow.

    Builds the flow with ``create_mnist_flow`` in image mode, prints the
    parameter count and size reported by ``count_params``, then runs the
    invertibility, time-differentiability and input-differentiability checks on
    a random float64 batch of four 1x28x28 inputs.

    Args:
        device: Device on which the model and the test tensors are placed.
    """
    label = "MNIST (1x28x28)"
    input_shape = (4, 1, 28, 28)

    print("\n=== MNIST-LIKE IMAGE FLOW ===")
    flow_config = dict(
        image_mode=True,
        num_layers=6,
        time_embed_dim=128,
        img_base_channels=96,
        img_blocks=3,
        img_groups=32,
        img_log_scale_clamp=10.0,
        use_permutation=True,
    )
    model = create_mnist_flow(**flow_config).to(device)

    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")

    x = torch.randn(*input_shape, device=device, dtype=torch.float64)
    t = torch.rand(input_shape[0], device=device, dtype=torch.float64)

    check_invertibility(model, x, t, label)
    check_differentiability_t(model, x, label)
    check_differentiability_x(model, x, label)

def quick_run_cifar(device="cpu"):
    """Run all three checks on a CIFAR-10-shaped (3x32x32) image flow.

    Builds the flow with ``create_cifar10_flow``, prints the parameter count and
    size reported by ``count_params``, then runs the invertibility,
    time-differentiability and input-differentiability checks on a random
    float64 batch of four 3x32x32 inputs.

    Args:
        device: Device on which the model and the test tensors are placed.
    """
    label = "CIFAR (3x32x32)"
    input_shape = (4, 3, 32, 32)  # small batch keeps the run fast

    print("\n=== CIFAR-10 IMAGE FLOW ===")
    flow_config = dict(
        num_layers=8,
        time_embed_dim=128,
        img_base_channels=128,
        img_blocks=4,
        img_groups=32,
        img_log_scale_clamp=10.0,
        use_permutation=True,
    )
    model = create_cifar10_flow(**flow_config).to(device)

    nparams, mb = count_params(model)
    print(f"Params: {nparams:,} ({mb:.2f} MB)")

    x = torch.randn(*input_shape, device=device, dtype=torch.float64)
    t = torch.rand(input_shape[0], device=device, dtype=torch.float64)

    check_invertibility(model, x, t, label)
    check_differentiability_t(model, x, label)
    check_differentiability_x(model, x, label)

## 5. Tests

In [10]:
# Pick the device once, then run every configuration from smallest to largest.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)

# After each run, release cached GPU blocks and collect garbage so the next
# (larger) model starts with as much free memory as possible.
for _runner in (
    quick_run_vector,
    quick_run_mnist_like,
    quick_run_cifar,
    quick_run_imagenet,
    quick_run_imagenet_full,
):
    _runner(device)
    torch.cuda.empty_cache()
    gc.collect()

print("\nAll checks done.")

Device: cpu

=== VECTOR FLOW ===
Params: 3,499,200 (27.99 MB)

[vector(32)] Invertibility check
  ||x_rec - x||_2    : mean 5.518e-16 | max 7.973e-16
  ||x_rec - x||_inf  : mean 3.608e-16 | max 4.441e-16
  rel L2 error       : mean 9.486e-17 | max 1.287e-16
  |logdet + logdet^-1|: mean 8.882e-16 | max 1.776e-15
  PASS: True

[vector(32)] Differentiability in time: autograd vs finite-difference
  Dimension: 32
  Total grad stats: mean -8.494e+00 | std 8.408e+01
  Per-elem grad:    mean -2.654e-01 | std 2.628e+00
  eps=1.0e-01 | total_err 7.089e+01 | per_elem_err 2.215e+00 | rel_err 9.994e-01
  eps=3.0e-02 | total_err 7.030e+01 | per_elem_err 2.197e+00 | rel_err 9.882e-01
  eps=1.0e-02 | total_err 6.912e+01 | per_elem_err 2.160e+00 | rel_err 9.551e-01
  eps=3.0e-03 | total_err 6.830e+01 | per_elem_err 2.134e+00 | rel_err 9.501e-01
  eps=1.0e-03 | total_err 6.303e+01 | per_elem_err 1.970e+00 | rel_err 7.952e-01

=== MNIST-LIKE IMAGE FLOW ===
Params: 3,463,500 (27.71 MB)


[W903 02:36:50.633738000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.647839000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.672257000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.683331000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.693425000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.707403000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.717766000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.729228000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.733024000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:50.734562000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3


[MNIST (1x28x28)] Invertibility check
  ||x_rec - x||_2    : mean 2.126e-14 | max 2.185e-14
  ||x_rec - x||_inf  : mean 3.331e-15 | max 3.997e-15
  rel L2 error       : mean 7.476e-16 | max 7.781e-16
  |logdet + logdet^-1|: mean 2.842e-14 | max 5.684e-14
  PASS: True


[W903 02:36:51.462508000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.475146000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.486215000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.495538000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.498150000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.502985000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.516814000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.531258000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.542889000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:51.557667000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3


[MNIST (1x28x28)] Differentiability in time: autograd vs finite-difference
  Dimension: 784
  Total grad stats: mean 5.941e+04 | std 3.544e+05
  Per-elem grad:    mean 7.578e+01 | std 4.520e+02


[W903 02:36:52.494648000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.496562000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.508630000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.520955000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.531807000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.542658000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.552703000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.561215000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.563608000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.564445000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3

  eps=1.0e-01 | total_err 2.916e+05 | per_elem_err 3.719e+02 | rel_err 1.007e+00


[W903 02:36:52.912382000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.921355000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.923874000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.925059000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.935671000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.945186000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.954629000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.963907000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.973609000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:52.983289000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3

  eps=3.0e-02 | total_err 2.922e+05 | per_elem_err 3.727e+02 | rel_err 1.013e+00


[W903 02:36:53.327338000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.337429000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.348245000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.358253000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.369619000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.379938000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.382144000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.382966000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.393539000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.403206000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3

  eps=1.0e-02 | total_err 3.005e+05 | per_elem_err 3.833e+02 | rel_err 1.023e+00


[W903 02:36:53.744330000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.754947000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.764545000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.767134000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.769096000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.779865000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.790372000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.801813000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.811773000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.821835000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3

  eps=3.0e-03 | total_err 2.802e+05 | per_elem_err 3.574e+02 | rel_err 9.728e-01


[W903 02:36:53.161149000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.163639000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.164944000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.176114000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.186688000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.197505000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.207279000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:53.218476000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.228094000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.230549000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3

  eps=1.0e-03 | total_err 2.236e+05 | per_elem_err 2.852e+02 | rel_err 7.640e-01

=== CIFAR-10 IMAGE FLOW ===
Params: 10,611,760 (84.89 MB)


[W903 02:36:54.586368000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.588372000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.620054000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.642818000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.666772000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.688092000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.711602000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.733267000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.755515000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:54.775754000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3


[CIFAR (3x32x32)] Invertibility check
  ||x_rec - x||_2    : mean 8.434e-14 | max 8.751e-14
  ||x_rec - x||_inf  : mean 7.647e-15 | max 8.438e-15
  rel L2 error       : mean 1.539e-15 | max 1.633e-15
  |logdet + logdet^-1|: mean 1.137e-13 | max 2.274e-13
  PASS: True


[W903 02:36:57.670410000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.700589000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.727184000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.750224000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.772032000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.794023000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.814094000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.818135000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.819929000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:36:57.843221000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3


[CIFAR (3x32x32)] Differentiability in time: autograd vs finite-difference
  Dimension: 3,072
  Total grad stats: mean -1.111e+06 | std 1.225e+06
  Per-elem grad:    mean -3.618e+02 | std 3.988e+02


[W903 02:37:00.908289000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.931061000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.955379000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.980754000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.001709000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.022979000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.045615000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.050201000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.052685000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:37:00.074557000 NNPACK.cpp:57] Could not initialize NNPACK! Reason: Unsupported hardware.
[W903 02:3

KeyboardInterrupt: 

In [None]:
def run_tests_with_cleanup(device, tests=None):
    """Run each flow test in turn, freeing memory after every one.

    A test that raises an ``Exception`` is reported and the remaining tests
    still run. Cleanup happens in a ``finally`` clause, so it also runs when a
    test is interrupted.

    Args:
        device: Device handed to every test function. May be a string
            (``"cpu"``, ``"cuda"``, ``"cuda:0"``) or a ``torch.device``.
        tests: Optional sequence of ``(display_name, callable)`` pairs; each
            callable is invoked as ``callable(device)``. When omitted, the
            vector, MNIST, CIFAR and ImageNet 128x128 runners are used.
    """
    if tests is None:
        tests = [
            ("Vector", quick_run_vector),
            ("MNIST", quick_run_mnist_like),
            ("CIFAR", quick_run_cifar),
            ("ImageNet 128x128", quick_run_imagenet),
        ]

    # Comparing `device == "cuda"` misses "cuda:0" and torch.device objects, which
    # silently disabled the cleanup below. Normalise and inspect the device type.
    target = torch.device(device)
    on_cuda = target.type == "cuda" and torch.cuda.is_available()

    for test_name, test_func in tests:
        print(f"\n{'='*80}")
        print(f"RUNNING: {test_name}")
        print(f"{'='*80}")

        # Show memory before test
        if on_cuda:
            torch.cuda.empty_cache()
            memory_before = torch.cuda.memory_allocated(target) / 1e9
            print(f"GPU memory before: {memory_before:.2f} GB")

        try:
            # Run the test
            test_func(device)

        except Exception as e:
            # Deliberate best-effort: report and move on to the next test
            print(f"❌ {test_name} failed: {e}")

        finally:
            # Collect Python garbage first so dead tensors return their blocks to
            # the caching allocator; only then can empty_cache() release them.
            gc.collect()
            if on_cuda:
                torch.cuda.empty_cache()

# Use the GPU when PyTorch can see one, then launch the guarded test suite
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print("Device:", device)
run_tests_with_cleanup(device)