In [1]:
import numpy as np
import time

In [2]:
from scripts.svgd import SVGD
from scripts.log_posterior import log_posterior
from scripts.grad_log_posterior import grad_log_posterior

In [3]:
def test_gradient_validity():
    """Sanity-check the output of ``grad_log_posterior`` on random particles.

    Draws 100 two-dimensional particles from a standard normal and asserts
    that the returned gradient has the same shape as the input and holds
    only finite numbers.

    Raises:
        AssertionError: if the gradient shape differs from the particle
            shape, or if the gradient contains NaN or infinite entries.
    """
    # Seed the global RNG (as the other tests in this notebook do) so that a
    # failure can be reproduced with the exact same particles.
    np.random.seed(0)
    particles = np.random.randn(100, 2)
    grad = grad_log_posterior(particles)
    assert grad.shape == particles.shape, "Gradient shape mismatch."
    assert not np.isnan(grad).any(), "Gradient contains NaNs."
    # NaN is not the only way a gradient can blow up: +/-inf would pass the
    # check above and still poison every subsequent particle update.
    assert np.isfinite(grad).all(), "Gradient contains non-finite values."
    print("✅ Gradient shape and NaN check passed.")

In [4]:
def compute_kl(p_samples, log_q_fn, log_p_fn=None):
    """Average ``log p(x) - log q(x)`` over the rows of ``p_samples``.

    When the rows are draws from ``p`` and both log-densities are properly
    normalised, this average is a Monte Carlo estimate of KL(p || q). With
    the default reference density (see ``log_p_fn``) the normalising
    constant is omitted, so the result is offset by an additive constant
    and is only meaningful for comparing values against each other.

    Args:
        p_samples: array-like of shape (n_samples, dim); anything accepted
            by ``np.asarray`` (an ndarray is passed through unchanged).
        log_q_fn: callable mapping the sample array to one log-density
            value per row.
        log_p_fn: optional callable with the same contract as ``log_q_fn``
            giving the reference log-density. When ``None`` (the default),
            ``-0.5 * ||x||^2`` is used, i.e. the log-density of a standard
            normal up to its additive normalising constant.

    Returns:
        The mean of ``log_p - log_q`` over all rows, as a NumPy scalar.
    """
    # np.asarray is a no-op for ndarrays, so existing callers see identical
    # behaviour; it only adds support for lists / nested sequences.
    samples = np.asarray(p_samples)
    log_q = log_q_fn(samples)
    if log_p_fn is None:
        # Unnormalised standard-normal log-density, summed over coordinates.
        log_p = -0.5 * np.sum(samples ** 2, axis=1)
    else:
        log_p = log_p_fn(samples)
    return np.mean(log_p - log_q)

In [5]:
def test_kl_divergence_decreases():
    """Check that the KL proxy never increases across SVGD checkpoints.

    Runs 300 SVGD updates on 100 two-dimensional particles (seed 42) and
    records ``compute_kl(particles, log_posterior)`` right after the update
    at iterations 0, 50, 100, 150, 200 and 250. The recorded sequence must
    be non-increasing.

    Raises:
        AssertionError: if any recorded value is larger than the one
            recorded before it; the message lists every recorded value.
    """
    np.random.seed(42)
    svgd = SVGD()
    n_particles, dim = 100, 2
    iterations = 300

    particles = np.random.randn(n_particles, dim)
    kl_values = []

    for step in range(iterations):
        particles = svgd.update(particles, grad_log_posterior)
        if step % 50:
            # Not a checkpoint iteration; keep updating.
            continue
        kl_values.append(compute_kl(particles, log_posterior))

    # Compare every checkpoint with its successor.
    non_increasing = all(
        kl_values[k] >= kl_values[k + 1] for k in range(len(kl_values) - 1)
    )
    assert non_increasing, \
        f"KL did not decrease over time: {kl_values}"
    print("✅ KL divergence decreases test passed.")

In [6]:
def test_wall_time_under_limit():
    """Check that 1000 SVGD updates finish within a 20-second budget.

    Uses 100 two-dimensional particles drawn with seed 0. Only the update
    loop is timed; seeding, construction of the sampler and particle
    initialisation are excluded. The result depends on the machine running
    the notebook.

    Raises:
        AssertionError: if the loop takes longer than the budget.
    """
    np.random.seed(0)
    svgd = SVGD()
    n_particles = 100
    dim = 2
    iterations = 1000
    time_budget_seconds = 20

    particles = np.random.randn(n_particles, dim)
    # perf_counter is monotonic and high-resolution; time.time() follows the
    # system clock and can jump (NTP sync, manual change) mid-measurement,
    # which would produce a bogus or even negative duration.
    start_time = time.perf_counter()

    for _ in range(iterations):
        particles = svgd.update(particles, grad_log_posterior)

    elapsed = time.perf_counter() - start_time
    assert elapsed <= time_budget_seconds, f"Wall-time exceeded: {elapsed:.2f}s > {time_budget_seconds}s"
    print(f"✅ Wall-time test passed: {elapsed:.2f} seconds for {iterations} iterations")

In [7]:
if __name__ == "__main__":
    # Run the checks in a fixed order; the first failing assertion aborts
    # the run, so later checks are not executed after a failure.
    for check in (
        test_gradient_validity,
        test_kl_divergence_decreases,
        test_wall_time_under_limit,
    ):
        check()

✅ Gradient shape and NaN check passed.
✅ KL divergence decreases test passed.
✅ Wall-time test passed: 0.89 seconds for 1000 iterations
