# Transformation Comparison: Default vs Manual

This notebook demonstrates a case where NumPyro's default transformation (softplus) is outperformed by a manually chosen transformation (exp) for small positive parameters.


In [4]:
# Let's create a case where manual transformation clearly outperforms default
import numpy as np
import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import SVI, Trace_ELBO
from numpyro.optim import Adam
from numpyro.infer.autoguide import AutoNormal

# print("=== CASE WHERE MANUAL TRANSFORMATION WINS ===")
# print("Model: z ~ N(0,1), y|z ~ N(z, σ), y=1.0, σ ~ HalfNormal(0.01)")
# print("True σ = 0.01 (very small positive value)")
# print()

# Model with very small positive parameter
def model(y):
    z = numpyro.sample('z', dist.Normal(0., 1.))
    sigma = numpyro.sample('sigma', dist.HalfNormal(0.01))  # Very small sigma
    numpyro.sample('y', dist.Normal(z, sigma), obs=y)

# default transformation
def guide_softplus(y):
    z_loc = numpyro.param('z_loc', 0.0)
    z_scale = numpyro.param('z_scale', 1.0, constraint=dist.constraints.positive)
    numpyro.sample('z', dist.Normal(z_loc, z_scale))
    
    sigma_unconstrained = numpyro.param('sigma_softplus', 0.0)
    sigma = jax.nn.softplus(sigma_unconstrained)
    numpyro.sample('sigma', dist.Delta(sigma))

guide_softplus = AutoNormal(model)

# Guide with exp (manual) - should work better for very small values
def guide_exp(y):
    z_loc = numpyro.param('z_loc', 0.0)
    z_scale = numpyro.param('z_scale', 1.0, constraint=dist.constraints.positive)
    numpyro.sample('z', dist.Normal(z_loc, z_scale))
    
    sigma_unconstrained = numpyro.param('sigma_exp', 0.0)
    sigma = jnp.exp(sigma_unconstrained)
    numpyro.sample('sigma', dist.Delta(sigma))

# Guide with custom transformation for very small values
def guide_custom(y):
    z_loc = numpyro.param('z_loc', 0.0)
    z_scale = numpyro.param('z_scale', 1.0, constraint=dist.constraints.positive)
    numpyro.sample('z', dist.Normal(z_loc, z_scale))
    
    # Custom: sigma = 0.01 * exp(x) - this centers the transformation around 0.01
    sigma_unconstrained = numpyro.param('sigma_custom', 0.0)
    sigma = 0.01 * jnp.exp(sigma_unconstrained)
    numpyro.sample('sigma', dist.Delta(sigma))

# Run optimization
elbo = Trace_ELBO(num_particles=100)
optimizer = Adam(1e-2)

print("Running SVI with different transformations...")
svi_softplus = SVI(model, guide_softplus, optimizer, elbo)
svi_exp = SVI(model, guide_exp, optimizer, elbo)
svi_custom = SVI(model, guide_custom, optimizer, elbo)

state_softplus = svi_softplus.init(jax.random.PRNGKey(42), y=jnp.array(1.0))
state_exp = svi_exp.init(jax.random.PRNGKey(42), y=jnp.array(1.0))
state_custom = svi_custom.init(jax.random.PRNGKey(42), y=jnp.array(1.0))

# Run for 1000 steps
for i in range(1000):
    state_softplus, loss_softplus = svi_softplus.update(state_softplus, y=jnp.array(1.0))
    state_exp, loss_exp = svi_exp.update(state_exp, y=jnp.array(1.0))
    state_custom, loss_custom = svi_custom.update(state_custom, y=jnp.array(1.0))
    
    if i % 200 == 0:
        print(f"Step {i}: Softplus={loss_softplus:.6f}, Exp={loss_exp:.6f}, Custom={loss_custom:.6f}")

# Check learned parameters
params_softplus = svi_softplus.get_params(state_softplus)
params_exp = svi_exp.get_params(state_exp)
params_custom = svi_custom.get_params(state_custom)
# print(params_softplus)

Running SVI with different transformations...
Step 0: Softplus=288.345947, Exp=4997.666016, Custom=11260.791992
Step 200: Softplus=29.039198, Exp=531.310547, Custom=269.575958
Step 400: Softplus=12.312260, Exp=213.392532, Custom=100.111549
Step 600: Softplus=7.708896, Exp=117.829994, Custom=65.084518
Step 800: Softplus=5.404984, Exp=75.140633, Custom=37.067211


In [5]:

sigma_softplus = jax.nn.softplus(params_softplus['sigma_auto_loc'])
sigma_exp = jnp.exp(params_exp['sigma_exp'])
sigma_custom = 0.01 * jnp.exp(params_custom['sigma_custom'])

print("\\nLearned sigma values:")
print(f"Softplus: {sigma_softplus:.6f}")
print(f"Exp:      {sigma_exp:.6f}")
print(f"Custom:   {sigma_custom:.6f}")
print(f"True:     0.01")



\nLearned sigma values:
Softplus: 0.023971
Exp:      0.104820
Custom:   0.029475
True:     0.01


In [6]:
# now we run and collect tracking values
# default transformation
guide_softplus = AutoNormal(model)

# Guide with exp (manual) - should work better for very small values
def guide_exp(y):
    z_loc = numpyro.param('z_loc', 0.0)
    z_scale = numpyro.param('z_scale', 1.0, constraint=dist.constraints.positive)
    numpyro.sample('z', dist.Normal(z_loc, z_scale))
    
    sigma_unconstrained = numpyro.param('sigma_exp', 0.0)
    sigma = jnp.exp(sigma_unconstrained)
    numpyro.sample('sigma', dist.Delta(sigma))

# Guide with custom transformation for very small values
def guide_custom(y):
    z_loc = numpyro.param('z_loc', 0.0)
    z_scale = numpyro.param('z_scale', 1.0, constraint=dist.constraints.positive)
    numpyro.sample('z', dist.Normal(z_loc, z_scale))
    
    # Custom: sigma = 0.01 * exp(x) - this centers the transformation around 0.01
    sigma_unconstrained = numpyro.param('sigma_custom', 0.0)
    sigma = 0.01 * jnp.exp(sigma_unconstrained)
    numpyro.sample('sigma', dist.Delta(sigma))

# Run optimization and collect intermediate results
elbo = Trace_ELBO(num_particles=100)
optimizer = Adam(1e-2)

print("Running SVI with different transformations...")
svi_softplus = SVI(model, guide_softplus, optimizer, elbo)
svi_exp = SVI(model, guide_exp, optimizer, elbo)
svi_custom = SVI(model, guide_custom, optimizer, elbo)

state_softplus = svi_softplus.init(jax.random.PRNGKey(42), y=jnp.array(1.0))
state_exp = svi_exp.init(jax.random.PRNGKey(42), y=jnp.array(1.0))
state_custom = svi_custom.init(jax.random.PRNGKey(42), y=jnp.array(1.0))

# Store results for plotting
results = {
    'softplus': {'losses': [], 'z_locs': [], 'z_scales': [], 'sigmas': []},
    'exp': {'losses': [], 'z_locs': [], 'z_scales': [], 'sigmas': []},
    'custom': {'losses': [], 'z_locs': [], 'z_scales': [], 'sigmas': []}
}

# Run for 10,000 steps and collect data
for i in range(10000):
    state_softplus, loss_softplus = svi_softplus.update(state_softplus, y=jnp.array(1.0))
    state_exp, loss_exp = svi_exp.update(state_exp, y=jnp.array(1.0))
    state_custom, loss_custom = svi_custom.update(state_custom, y=jnp.array(1.0))
    
    if i % 200 == 0:  # Collect data every 200 steps
        # Get parameters
        params_softplus = svi_softplus.get_params(state_softplus)
        params_exp = svi_exp.get_params(state_exp)
        params_custom = svi_custom.get_params(state_custom)
        
        # Store results
        results['softplus']['losses'].append(loss_softplus)
        results['softplus']['z_locs'].append(float(params_softplus['z_auto_loc']))
        results['softplus']['z_scales'].append(float(params_softplus['z_auto_scale']))
        results['softplus']['sigmas'].append(float(jax.nn.softplus(params_softplus['sigma_auto_loc'])))
        
        results['exp']['losses'].append(loss_exp)
        results['exp']['z_locs'].append(float(params_exp['z_loc']))
        results['exp']['z_scales'].append(float(params_exp['z_scale']))
        results['exp']['sigmas'].append(float(jnp.exp(params_exp['sigma_exp'])))
        
        results['custom']['losses'].append(loss_custom)
        results['custom']['z_locs'].append(float(params_custom['z_loc']))
        results['custom']['z_scales'].append(float(params_custom['z_scale']))
        results['custom']['sigmas'].append(float(0.01 * jnp.exp(params_custom['sigma_custom'])))

print("\\nFinal Results:")
print(f"Softplus final loss: {results['softplus']['losses'][-1]:.6f}")
print(f"Exp final loss:      {results['exp']['losses'][-1]:.6f}")
print(f"Custom final loss:   {results['custom']['losses'][-1]:.6f}")

print("\\nLearned sigma values:")
print(f"Softplus: {results['softplus']['sigmas'][-1]:.6f}")
print(f"Exp:      {results['exp']['sigmas'][-1]:.6f}")
print(f"Custom:   {results['custom']['sigmas'][-1]:.6f}")
print(f"True:     0.01")


Running SVI with different transformations...
\nFinal Results:
Softplus final loss: 2.080383
Exp final loss:      -2.773820
Custom final loss:   -2.488031
\nLearned sigma values:
Softplus: 0.008191
Exp:      0.006041
Custom:   0.008955
True:     0.01


In [7]:
# Create comprehensive plots
fig, axes = plt.subplots(2, 1, figsize=(15, 10))
fig.suptitle('Transformation Comparison: Posterior Convergence', fontsize=16)

# Convert to numpy arrays for plotting
steps = np.arange(0, 10000, 200)


# Plot 1: Sigma parameter convergence
axes[0].plot(steps, results['softplus']['sigmas'], label='Softplus', linewidth=2)
axes[0].plot(steps, results['exp']['sigmas'], label='Exp', linewidth=2)
axes[0].plot(steps, results['custom']['sigmas'], label='Custom', linewidth=2, color='purple')
axes[0].axhline(y=0.01, color='red', linestyle='--', label='True σ=0.01', linewidth=2)
axes[0].set_xlabel('Optimization Steps')
axes[0].set_ylabel('Learned σ')
axes[0].set_title('σ Parameter Convergence')
axes[0].legend()
axes[0].grid(True, alpha=0.3)


# Plot 5: Final posterior distributions for z
z_range = np.linspace(0, 2, 1000)
for i, (name, color) in enumerate([('softplus', 'blue'), ('exp', 'green'), ('custom', 'purple')]):
    z_loc = results[name]['z_locs'][-1]
    z_scale = results[name]['z_scales'][-1]
    z_pdf = (1/(z_scale * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((z_range - z_loc)/z_scale)**2)
    axes[1].plot(z_range, z_pdf, label=f'{name.capitalize()}', color=color, linewidth=2)

# True posterior (approximate)
true_z_loc = 1.0  # Approximate true posterior mean
true_z_scale = 0.01  # Approximate true posterior std
true_z_pdf = (1/(true_z_scale * np.sqrt(2*np.pi))) * np.exp(-0.5 * ((z_range - true_z_loc)/true_z_scale)**2)
axes[1].plot(z_range, true_z_pdf, label='True Posterior', color='red', linestyle='--', linewidth=2, alpha=0.5)

axes[1].set_xlabel('z')
axes[1].set_ylabel('Density')
axes[1].set_title('Final Z Posterior Distributions')
axes[1].legend()
axes[1].grid(True, alpha=0.3)
axes[1].set_xlim(0.9, 1.1)


plt.tight_layout()
plt.show()

print("\\nVisual Analysis:")
print("1. Loss convergence shows which transformation optimizes fastest")
print("2. σ parameter convergence shows which gets closest to true value (0.01)")
print("3. Z posterior distributions show the quality of the learned approximation")
print("4. The custom transformation should show better convergence to the true σ value")


NameError: name 'plt' is not defined