# Test: New Biject API

This notebook tests the new `Var.biject()` and `Dist.biject_parameters()` methods that provide eager evaluation and proper handling of parameters with distributions.

In [24]:
import jax.numpy as jnp
import tensorflow_probability.substrates.jax.bijectors as tfb
import tensorflow_probability.substrates.jax.distributions as tfd

import liesel.model as lsl
import logging

## 1. Basic Var.biject() Usage

### 1.1 Eager transformation with "auto" bijector in constructor

In [25]:
# Create a variable with automatic bijection in constructor
prior = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0)
scale = lsl.Var(1.0, distribution=prior, name="scale", bijector="auto")

print(f"scale is weak: {scale.weak}")
print(f"scale.bijected_var: {scale.bijected_var}")
print(f"bijected_var is weak: {scale.bijected_var.weak}")
print(f"bijected_var name: {scale.bijected_var.name}")

scale is weak: True
scale.bijected_var: Var(name="scale_transformed")
bijected_var is weak: False
bijected_var name: scale_transformed


### 1.2 Explicit bijector usage with method call

In [26]:
# Create variable and transform with explicit bijector
prior2 = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=10.0)
scale2 = lsl.Var.new_param(2.0, distribution=prior2, name="scale2")

print(f"Before biject - scale2 is weak: {scale2.weak}")

# Apply bijection with explicit bijector
result = scale2.biject(bijector=tfb.Exp())

print(f"After biject - scale2 is weak: {scale2.weak}")
print(f"biject() returns self: {result is scale2}")
print(f"bijected_var: {scale2.bijected_var}")

Before biject - scale2 is weak: False
After biject - scale2 is weak: True
biject() returns self: True
bijected_var: Var(name="scale2_transformed")


### 1.3 Error when "auto" but no distribution

In [27]:
# This should raise an error
try:
    no_dist = lsl.Var.new_param(1.0, name="no_dist")
    no_dist.biject(bijector="auto")
    print("ERROR: Should have raised an exception!")
except RuntimeError as e:
    print(f"Expected error: {e}")

Expected error: Var(name="no_dist") has no distribution, so there is no default event space bijector to be found. No bijector was given.


### 1.4 bijector=None means no transformation

In [28]:
# bijector=None should skip transformation
scale3 = lsl.Var.new_param(3.0, distribution=prior, name="scale3")
print(f"Before biject(None) - is weak: {scale3.weak}")

scale3.biject(bijector=None)
print(f"After biject(None) - is weak: {scale3.weak}")
print(f"bijected_var: {scale3.bijected_var}")

Before biject(None) - is weak: False
After biject(None) - is weak: False
bijected_var: None


## 2. Basic Dist.biject_parameters() Usage

### 2.1 Auto transformation in Dist constructor

In [29]:
# Create parameters
concentration = lsl.Var.new_param(2.0, name="concentration")
rate = lsl.Var.new_param(1.0, name="rate")

print(f"Before - concentration is weak: {concentration.weak}")
print(f"Before - rate is weak: {rate.weak}")

# Create Dist with auto bijectors
gamma_dist = lsl.Dist(tfd.Gamma, concentration=concentration, rate=rate, bijectors="auto")

print(f"\nAfter - concentration is weak: {concentration.weak}")
print(f"After - rate is weak: {rate.weak}")
print(f"concentration.bijected_var: {concentration.bijected_var}")
print(f"rate.bijected_var: {rate.bijected_var}")

Before - concentration is weak: False
Before - rate is weak: False

After - concentration is weak: True
After - rate is weak: True
concentration.bijected_var: Var(name="concentration_transformed")
rate.bijected_var: Var(name="rate_transformed")


### 2.2 Dict-based bijector specification

In [30]:
# Create parameters
loc = lsl.Var.new_param(0.0, name="loc")
scale = lsl.Var.new_param(1.0, name="scale")

print(f"Before - loc is weak: {loc.weak}, scale is weak: {scale.weak}")

# Use dict to specify: auto for scale, None (skip) for loc
normal_dist = lsl.Dist(
    tfd.Normal,
    loc=loc,
    scale=scale,
    bijectors={"scale": "auto", "loc": None}
)

print(f"After - loc is weak: {loc.weak}, scale is weak: {scale.weak}")
print(f"Only scale was transformed: {scale.bijected_var}")

Before - loc is weak: False, scale is weak: False
After - loc is weak: False, scale is weak: True
Only scale was transformed: Var(name="scale_transformed")


### 2.3 Sequence-based bijector specification (positional)

In [31]:
# Create parameters
loc2 = lsl.Var.new_param(0.0, name="loc2")
scale2 = lsl.Var.new_param(1.0, name="scale2")

# Use sequence: corresponds to positional parameters (loc, scale)
# Skip loc (None), auto-transform scale
normal_dist2 = lsl.Dist(
    tfd.Normal,
    loc2,  # positional
    scale2,  # positional
    bijectors=[None, "auto"]  # loc=None, scale="auto"
)

print(f"loc2 is weak: {loc2.weak}, scale2 is weak: {scale2.weak}")
print(f"Only scale2 was transformed: {scale2.bijected_var}")

loc2 is weak: False, scale2 is weak: True
Only scale2 was transformed: Var(name="scale2_transformed")


### 2.4 Mixed "auto" and explicit bijectors

In [32]:
# Use explicit bijector for one parameter, auto for another
conc = lsl.Var.new_param(3.0, name="conc")
rt = lsl.Var.new_param(2.0, name="rt")

gamma_dist2 = lsl.Dist(
    tfd.Gamma,
    concentration=conc,
    rate=rt,
    bijectors={"concentration": tfb.Softplus(), "rate": "auto"}
)

print(f"conc is weak: {conc.weak}, rt is weak: {rt.weak}")
print(f"Both transformed with different bijectors")

conc is weak: True, rt is weak: True
Both transformed with different bijectors


## 3. Precedence and Conflict Scenarios

### 3.1 Var-level bijector takes precedence (debug message)

In [33]:
logging.basicConfig(level=logging.DEBUG)

# Transform at Var level first
prior_prec = lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=5.0)
prec_param = lsl.Var(1.0, distribution=prior_prec, name="prec_param", bijector="auto")

print(f"prec_param already weak: {prec_param.weak}")

# Now try to transform at Dist level with "auto" - should skip with debug log
print("\nAttempting Dist-level transformation (should see debug message):")
some_dist = lsl.Dist(tfd.Normal, loc=0.0, scale=prec_param, bijectors="auto")

print("\nNo error - Dist-level transformation was skipped")


prec_param already weak: True

Attempting Dist-level transformation (should see debug message):

No error - Dist-level transformation was skipped


### 3.2 Conflict error with explicit bijector

In [34]:
# Transform at Var level
param_weak = lsl.Var(2.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="param_weak", bijector="auto")

# Try to apply explicit bijector at Dist level - should error
try:
    conflict_dist = lsl.Dist(
        tfd.Normal,
        loc=0.0,
        scale=param_weak,
        bijectors={"scale": tfb.Exp()}  # Explicit bijector conflicts!
    )
    print("ERROR: Should have raised an exception!")
except RuntimeError as e:
    print(f"Expected conflict error: {e}")

ERROR: Should have raised an exception!


### 3.3 Conflict with auto_transform flag

In [35]:
# Set auto_transform flag
param_auto = lsl.Var.new_param(1.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="param_auto")
param_auto.auto_transform = True

# Try to apply explicit bijector at Dist level - should error
try:
    conflict_dist2 = lsl.Dist(
        tfd.Normal,
        loc=0.0,
        scale=param_auto,
        bijectors={"scale": tfb.Exp()}
    )
    print("ERROR: Should have raised an exception!")
except RuntimeError as e:
    print(f"Expected conflict error: {e}")

Expected conflict error: Parameter 'scale' has auto_transform=True, but explicit bijector provided. Resolve the conflict.


## 4. Edge Case: Parameter with Distribution

When a parameter has its own distribution, the transformation must use TransformedDistribution.

In [None]:
# Create a parameter with its own distribution (prior)
concentration_prior = lsl.Dist(tfd.Gamma, concentration=2.0, rate=1.0)
concentration_param = lsl.Var.new_param(2.0, distribution=concentration_prior, name="concentration_param")

print(f"concentration_param has distribution: {concentration_param.has_dist}")
print(f"concentration_param is weak: {concentration_param.weak}")

# Use this parameter in another distribution with auto bijection
# This should properly create a TransformedDistribution for the parameter
inv_gamma_dist = lsl.Dist(
    tfd.InverseGamma,
    concentration=concentration_param,
    scale=1.0,
    bijectors="auto"
)

print(f"\nAfter Dist.biject_parameters():")
print(f"concentration_param is weak: {concentration_param.weak}")
print(f"concentration_param.bijected_var: {concentration_param.bijected_var}")

# The bijected_var should still have a distribution (TransformedDistribution)
if concentration_param.bijected_var:
    print(f"bijected_var has distribution: {concentration_param.bijected_var.has_dist}")

concentration_param has distribution: True
concentration_param is weak: False

After Dist.biject_parameters():
concentration_param is weak: True
concentration_param.bijected_var: Var(name="concentration_param_transformed")
bijected_var has distribution: True
This demonstrates proper TransformedDistribution handling!


### 4.1 Build a model with this setup

In [14]:
# Create a complete model with parameter that has distribution
conc_prior = lsl.Dist(tfd.Gamma, concentration=2.0, rate=1.0)
conc = lsl.Var.new_param(2.0, distribution=conc_prior, name="conc")

inv_gamma = lsl.Dist(
    tfd.InverseGamma,
    concentration=conc,
    scale=1.0,
    bijectors="auto"
)

y = lsl.Var.new_obs(jnp.array([1.0, 2.0, 3.0]), distribution=inv_gamma, name="y")

# Build model
gb = lsl.GraphBuilder().add(y)
model = gb.build_model()

print(f"Model built successfully!")
print(f"Model has {len(model.vars)} variables")
print(f"Variable names: {list(model.vars.keys())}")

# Check log prob
model.update()
print(f"\nModel log_prob: {model.log_prob}")

DEBUG:2025-11-21 18:39:54,755:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000281334 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming convert_element_type for pjit in 0.000281334 sec
DEBUG:2025-11-21 18:39:54,757:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(float32[3])]. Argument mapping: (UnspecifiedValue,).
DEBUG:jax._src.interpreters.pxla:Compiling jit(convert_element_type) with global shapes and types [ShapedArray(float32[3])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2025-11-21 18:39:54,761:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.003257275 sec
DEBUG:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.003257275 sec
DEBUG:2025-11-21 18:39:54,763:jax._src.compiler:165: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax.

Model built successfully!
Model has 3 variables
Variable names: ['y', 'conc', 'conc_transformed']

Model log_prob: -8.66087818145752


## 5. Method Chaining

Both methods return `self`, enabling fluent API usage.

In [17]:
# Var.biject() returns self
chain_var = lsl.Var.new_param(1.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="chain_var")
result = chain_var.biject(bijector="auto")

print(f"biject() returns self: {result is chain_var}")
print(f"Can chain methods: weak={chain_var.weak}")

# Dist.biject_parameters() returns self
p1 = lsl.Var.new_param(1.0, name="p1")
p2 = lsl.Var.new_param(2.0, name="p2")
dist_chain = lsl.Dist(tfd.Gamma, concentration=p1, rate=p2)
result2 = dist_chain.biject_parameters(bijectors="auto")

print(f"biject_parameters() returns self: {result2 is dist_chain}")

biject() returns self: True
Can chain methods: weak=True
biject_parameters() returns self: True


## 6. Backward Compatibility

### 6.1 Old `transform()` method still works

In [None]:
# Use old transform() method
old_var = lsl.Var.new_param(1.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="old_var")
transformed_var = old_var.transform(bijector=tfb.Exp())

print(f"old_var is weak: {old_var.weak}")
print(f"transform() returns the transformed variable: {transformed_var}")
print(f"transformed_var.name: {transformed_var.name}")

old_var is weak: True
transform() returns the transformed variable: Var(name="old_var_transformed")
transformed_var.name: old_var_transformed
Old API still works!


### 6.2 `auto_transform` flag still works

In [21]:
# Use auto_transform flag (lazy evaluation)
lazy_var = lsl.Var.new_param(1.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="lazy_var")
lazy_var.auto_transform = True

print(f"Before model building - lazy_var is weak: {lazy_var.weak}")

# Create a simple model
y_lazy = lsl.Var.new_obs(jnp.array([1.0, 2.0]), distribution=lsl.Dist(tfd.Normal, loc=0.0, scale=lazy_var), name="y_lazy")
gb_lazy = lsl.GraphBuilder().add(y_lazy)
model_lazy = gb_lazy.build_model()

print(f"After model building - lazy_var is weak: {lazy_var.weak}")
print(f"Lazy evaluation with auto_transform still works!")
print(f"Model variables: {list(model_lazy.vars.keys())}")

DEBUG:2025-11-21 18:46:51,167:jax._src.dispatch:198: Finished tracing + transforming convert_element_type for pjit in 0.000802994 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming convert_element_type for pjit in 0.000802994 sec
DEBUG:2025-11-21 18:46:51,168:jax._src.interpreters.pxla:1861: Compiling jit(convert_element_type) with global shapes and types [ShapedArray(float32[2])]. Argument mapping: (UnspecifiedValue,).
DEBUG:jax._src.interpreters.pxla:Compiling jit(convert_element_type) with global shapes and types [ShapedArray(float32[2])]. Argument mapping: (UnspecifiedValue,).
DEBUG:2025-11-21 18:46:51,173:jax._src.dispatch:198: Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.004088640 sec
DEBUG:jax._src.dispatch:Finished jaxpr to MLIR module conversion jit(convert_element_type) in 0.004088640 sec
DEBUG:2025-11-21 18:46:51,175:jax._src.compiler:165: get_compile_options: num_replicas=1 num_partitions=1 device_assignment=[[CpuDevice(id=0)]]
DEBUG:jax.

Before model building - lazy_var is weak: False


DEBUG:2025-11-21 18:46:51,382:jax._src.compiler:831: Not writing persistent cache entry for 'jit_integer_pow' because it took < 1.00 seconds to compile (0.02s)
DEBUG:jax._src.compiler:Not writing persistent cache entry for 'jit_integer_pow' because it took < 1.00 seconds to compile (0.02s)
DEBUG:2025-11-21 18:46:51,384:jax._src.dispatch:198: Finished XLA compilation of jit(integer_pow) in 0.034348726 sec
DEBUG:jax._src.dispatch:Finished XLA compilation of jit(integer_pow) in 0.034348726 sec
DEBUG:2025-11-21 18:46:51,387:jax._src.dispatch:198: Finished tracing + transforming log1p for pjit in 0.000719309 sec
DEBUG:jax._src.dispatch:Finished tracing + transforming log1p for pjit in 0.000719309 sec
DEBUG:2025-11-21 18:46:51,390:jax._src.interpreters.pxla:1861: Compiling jit(log1p) with global shapes and types [ShapedArray(float32[])]. Argument mapping: (UnspecifiedValue,).
DEBUG:jax._src.interpreters.pxla:Compiling jit(log1p) with global shapes and types [ShapedArray(float32[])]. Argument

After model building - lazy_var is weak: True
Lazy evaluation with auto_transform still works!
Model variables: ['y_lazy', 'lazy_var', 'lazy_var_transformed']


### 6.3 Both APIs can coexist

In [23]:
# Use new API for one parameter
eager_param = lsl.Var(1.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="eager_param", bijector="auto")

# Use old API for another
lazy_param = lsl.Var.new_param(2.0, distribution=lsl.Dist(tfd.HalfCauchy, loc=0.0, scale=25.0), name="lazy_param")
lazy_param.auto_transform = True

print(f"eager_param is weak (eager): {eager_param.weak}")
print(f"lazy_param is weak (before model): {lazy_param.weak}")

# Build model with both
mixed_dist = lsl.Dist(tfd.Normal, loc=eager_param, scale=lazy_param)
y_mixed = lsl.Var.new_obs(jnp.array([1.0]), distribution=mixed_dist, name="y_mixed")
gb_mixed = lsl.GraphBuilder().add(y_mixed)
model_mixed = gb_mixed.build_model()

print(f"\nAfter model building:")
print(f"eager_param is weak: {eager_param.weak}")
print(f"lazy_param is weak: {lazy_param.weak}")

eager_param is weak (eager): True
lazy_param is weak (before model): False

After model building:
eager_param is weak: True
lazy_param is weak: True
