# *Lux*: Iterative Optimization for Linear Models

This tutorial demonstrates the {py:func}`~pollux.Lux.optimize_iterative` method, an alternating optimization scheme for *Lux* models with linear transforms. It exploits the linear structure of the model to solve each sub-problem exactly using weighted least squares, and often converges faster than gradient-based optimization with Adam (the default for {py:func}`~pollux.Lux.optimize`).

This tutorial builds on the [Getting Started tutorial](Lux-linear-simulated-data.ipynb), so we'll skip the basic introductions and jump straight to comparing the two optimization approaches.

We'll start with the standard imports:

In [None]:
import time

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import numpyro
from helpers import make_simulated_linear_data

import pollux as plx
from pollux.models import LinearTransform

jax.config.update("jax_enable_x64", True)
%matplotlib inline

## Generating simulated data

We'll generate the same simulated data as in the Getting Started tutorial:

In [None]:
n_stars = 2048
n_latents = 8
n_labels = 2
n_flux = 128

rng = np.random.default_rng(seed=8675309)

A = np.zeros((n_labels, n_latents))
A[0, 0] = 1.0
A[1, 1] = 1.0

B = rng.normal(scale=0.1, size=(n_flux, n_latents))
B[:, 0] = B[:, 0] + 4 * np.exp(-0.5 * (np.arange(n_flux) - n_flux / 2) ** 2 / 5**2)
B[:, 1] = B[:, 1] + 2 * np.exp(-0.5 * (np.arange(n_flux) - n_flux / 4) ** 2 / 3**2)

data, truth = make_simulated_linear_data(
    n_stars=n_stars,
    n_latents=n_latents,
    n_flux=n_flux,
    n_labels=n_labels,
    A=A,
    B=B,
    rng=rng,
)
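
The `make_simulated_linear_data` helper lives in the tutorial's `helpers` module and is not shown here. As a rough sketch of the kind of data such a helper plausibly produces (an assumption, not its actual implementation), a linear latent-variable model draws latent vectors and maps them through `A` and `B` with added Gaussian noise. The function name `simulate_linear_sketch`, the noise levels, and the returned keys are all hypothetical:

```python
import numpy as np


def simulate_linear_sketch(A, B, n_stars, rng, label_err=0.1, flux_err=0.01):
    """Hypothetical stand-in for the helper: latents pushed through linear maps."""
    n_latents = A.shape[1]
    # One latent vector per star, drawn from a unit Gaussian (assumed prior)
    z = rng.normal(size=(n_stars, n_latents))
    # Labels and fluxes are linear in the latents, plus Gaussian noise
    label = z @ A.T + rng.normal(scale=label_err, size=(n_stars, A.shape[0]))
    flux = z @ B.T + rng.normal(scale=flux_err, size=(n_stars, B.shape[0]))
    return {"label": label, "flux": flux, "latents": z}


rng_sketch = np.random.default_rng(0)
A_sketch = np.eye(2, 8)  # labels read off the first two latent dimensions
B_sketch = rng_sketch.normal(scale=0.1, size=(16, 8))
sim = simulate_linear_sketch(A_sketch, B_sketch, n_stars=32, rng=rng_sketch)
```

The `_sketch` suffixes keep these toy variables from overwriting the notebook's own `A`, `B`, and `rng`.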

Package the data with preprocessors and create train/test splits:

In [None]:
all_data = plx.data.PolluxData(
    flux=plx.data.OutputData(
        data["flux"],
        err=data["flux_err"],
        preprocessor=plx.data.ShiftScalePreprocessor.from_data(data["flux"]),
    ),
    label=plx.data.OutputData(
        data["label"],
        err=data["label_err"],
        preprocessor=plx.data.ShiftScalePreprocessor.from_data(data["label"]),
    ),
)

preprocessed_data = all_data.preprocess()
train_data = preprocessed_data[: n_stars // 2]
test_data = preprocessed_data[n_stars // 2 :]

## Setting up the model

We create a *Lux* model with two linear outputs, exactly as in the Getting Started tutorial:

In [None]:
model = plx.LuxModel(latent_size=n_latents)
model.register_output("label", LinearTransform(output_size=n_labels))
model.register_output("flux", LinearTransform(output_size=n_flux))

## Comparing optimization methods

Now we'll compare the standard `optimize()` method with the `optimize_iterative()` method.

### Standard optimization with `optimize()`

The standard approach uses gradient-based optimization (SVI with Adam) to jointly optimize all parameters:

In [None]:
t0 = time.time()
opt_pars_svi, svi_results = model.optimize(
    train_data,
    rng_key=jax.random.PRNGKey(112358),
    optimizer=numpyro.optim.Adam(1e-3),
    num_steps=10_000,
    svi_run_kwargs={"progress_bar": False},
)
svi_results.losses.block_until_ready()
svi_time = time.time() - t0
print(f"SVI optimization time: {svi_time:.2f} seconds")

### Iterative optimization with `optimize_iterative()`

The iterative approach exploits the linear structure of the model. For a linear transform such as `y = A @ z`, the problem is linear in the latents `z` when `A` is held fixed, and linear in `A` when `z` is held fixed, so each sub-problem can be solved exactly using weighted least squares. The algorithm alternates between these two steps:

1. **Fix output parameters, solve for latents**: Given the current A matrices, solve for optimal z using least squares
2. **Fix latents, solve for output parameters**: Given the current z, solve for optimal A matrices using least squares

This is repeated until convergence:

In [None]:
t0 = time.time()
iterative_result = model.optimize_iterative(
    train_data,
    max_cycles=50,
    tol=1e-6,
    rng_key=jax.random.PRNGKey(112358),
    progress=False,
)
iterative_time = time.time() - t0
print(f"Iterative optimization time: {iterative_time:.2f} seconds")
print(
    f"Converged: {iterative_result.converged} after {iterative_result.n_cycles} cycles"
)
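
The two alternating steps described above can be sketched in plain NumPy for a single linear output. This is a minimal illustration of alternating weighted least squares, not the internals of `optimize_iterative()`: the function names, the small `ridge` term (added to keep the normal equations well conditioned), and the relative-change stopping rule are all assumptions made for the sketch.

```python
import numpy as np


def solve_latents(Y, ivar, A, ridge=1e-6):
    """Step 1: per-object weighted least squares for z, holding A fixed."""
    n_lat = A.shape[1]
    Z = np.empty((Y.shape[0], n_lat))
    for n in range(Y.shape[0]):
        AtW = A.T * ivar[n]  # (n_lat, n_out), weights applied per output
        Z[n] = np.linalg.solve(AtW @ A + ridge * np.eye(n_lat), AtW @ Y[n])
    return Z


def solve_outputs(Y, ivar, Z, ridge=1e-6):
    """Step 2: per-output weighted least squares for each row of A, holding Z fixed."""
    n_lat = Z.shape[1]
    A = np.empty((Y.shape[1], n_lat))
    for j in range(Y.shape[1]):
        ZtW = Z.T * ivar[:, j]  # (n_lat, n_obj), weights applied per object
        A[j] = np.linalg.solve(ZtW @ Z + ridge * np.eye(n_lat), ZtW @ Y[:, j])
    return A


def objective(Y, ivar, A, Z, ridge=1e-6):
    """Weighted squared residuals plus the ridge penalty used in both solves."""
    chi2 = np.sum(ivar * (Y - Z @ A.T) ** 2)
    return float(chi2 + ridge * (np.sum(Z**2) + np.sum(A**2)))


# Toy data: Y = Z_true @ A_true.T plus heteroscedastic noise
rng_als = np.random.default_rng(0)
n_obj, n_out, n_lat = 64, 12, 3
Z_true = rng_als.normal(size=(n_obj, n_lat))
A_true = rng_als.normal(size=(n_out, n_lat))
sigma = rng_als.uniform(0.05, 0.2, size=(n_obj, n_out))
Y = Z_true @ A_true.T + sigma * rng_als.normal(size=(n_obj, n_out))
ivar = 1.0 / sigma**2

A_fit = rng_als.normal(size=(n_out, n_lat))  # random starting point
losses = []
for cycle in range(50):
    Z_fit = solve_latents(Y, ivar, A_fit)
    A_fit = solve_outputs(Y, ivar, Z_fit)
    losses.append(objective(Y, ivar, A_fit, Z_fit))
    # Stop when the relative change in the objective is small
    if cycle > 0 and abs(losses[-2] - losses[-1]) <= 1e-6 * abs(losses[-2]):
        break
```

Because each step exactly minimizes the same objective over one block of parameters, the recorded objective cannot increase from one cycle to the next (up to round-off).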

## Comparing convergence

Let's visualize how the loss evolves for both methods:

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(12, 4), layout="constrained")

# SVI loss trajectory
axes[0].semilogy(svi_results.losses)
axes[0].set(xlabel="Step", ylabel="Loss", title="SVI Optimization")
axes[0].axhline(
    svi_results.losses[-1],
    color="tab:orange",
    ls="--",
    label=f"Final: {svi_results.losses[-1]:.1f}",
)
axes[0].legend()

# Iterative loss trajectory
axes[1].semilogy(iterative_result.losses_per_cycle)
axes[1].set(xlabel="Cycle", ylabel="Loss", title="Iterative Optimization")
axes[1].axhline(
    iterative_result.losses_per_cycle[-1],
    color="tab:orange",
    ls="--",
    label=f"Final: {iterative_result.losses_per_cycle[-1]:.1f}",
)
axes[1].legend()

plt.suptitle("Convergence Comparison", fontsize=14)

The iterative method typically converges in far fewer iterations (cycles) than SVI requires steps, and each cycle involves closed-form solutions rather than gradient computations.
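
To make the contrast between a closed-form solve and gradient steps concrete, here is a small sketch on a single weighted least squares problem. It uses plain fixed-step gradient descent rather than Adam, and the toy data and variable names are invented for illustration:

```python
import numpy as np

rng_cf = np.random.default_rng(1)
X = rng_cf.normal(size=(200, 5))
w_true = rng_cf.normal(size=5)
sigma = rng_cf.uniform(0.1, 0.5, size=200)
y = X @ w_true + sigma * rng_cf.normal(size=200)
ivar = 1.0 / sigma**2


def loss(w):
    """Half the weighted sum of squared residuals."""
    return 0.5 * float(np.sum(ivar * (y - X @ w) ** 2))


# Closed form: solve the weighted normal equations once
XtW = X.T * ivar
H = XtW @ X  # Hessian of the loss
w_closed = np.linalg.solve(H, XtW @ y)

# Plain gradient descent with step 1/L, where L is the largest Hessian eigenvalue
step = 1.0 / np.linalg.eigvalsh(H)[-1]
w_gd = np.zeros(5)
gd_losses = []
for _ in range(200):
    grad = -XtW @ (y - X @ w_gd)
    w_gd = w_gd - step * grad
    gd_losses.append(loss(w_gd))
```

The closed-form solution sits at the minimum of the quadratic loss after one linear solve, whereas gradient descent only approaches that value over many steps.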

## Evaluating on the test set

To evaluate the model on unseen data, we need to infer the latents for the test set while keeping the learned model parameters (A matrices) fixed. For SVI, we do this by passing the trained parameters as `fixed_pars`. For the iterative method, we restrict the optimization to a single latents block and pass the trained parameters through `initial_params`:

In [None]:
# Create fixed_pars containing the trained model parameters (everything except latents)
fixed_pars_svi = {k: v for k, v in opt_pars_svi.items() if k != "latents"}

# Create test data with only flux (we want to predict labels using only flux)
test_flux_only = plx.data.PolluxData(flux=test_data["flux"])

# Optimize latents for test set using SVI (fix model parameters)
# Use names=["flux"] to only model the flux output (since we don't have label data)
test_pars_svi, _ = model.optimize(
    test_flux_only,
    rng_key=jax.random.PRNGKey(42),
    optimizer=numpyro.optim.Adam(1e-3),
    num_steps=2000,
    fixed_pars=fixed_pars_svi,
    names=["flux"],
    svi_run_kwargs={"progress_bar": False},
)
# Merge the fixed parameters back with the optimized latents
test_pars_svi = {**fixed_pars_svi, **test_pars_svi}

Now we do the same for the iterative method:

In [None]:
from pollux.models.iterative import ParameterBlock

opt_pars_iter = iterative_result.params
fixed_pars_iter = {k: v for k, v in opt_pars_iter.items() if k != "latents"}

# Optimize latents for test set using iterative method
# Only optimize latents (not output parameters) since we're using fixed trained params
# Pass initial_params with the trained A matrices and zero latents
initial_test_pars = {
    **fixed_pars_iter,
    "latents": jnp.zeros((len(test_flux_only), model.latent_size)),
}
test_blocks = [ParameterBlock("latents", "latents", optimizer="least_squares")]
test_result_iter = model.optimize_iterative(
    test_flux_only,
    blocks=test_blocks,
    max_cycles=50,
    tol=1e-6,
    initial_params=initial_test_pars,
    progress=False,
)
# Merge trained output params with optimized test latents
test_pars_iter = {**fixed_pars_iter, "latents": test_result_iter.params["latents"]}
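
With the output matrices held fixed, inferring test-set latents from flux alone reduces to one weighted least squares solve per star, after which the labels follow from the label matrix. The sketch below illustrates this with toy arrays. `B_fixed` and `A_fixed` are hypothetical stand-ins for trained flux and label matrices (not the actual *Lux* parameter layout), and any prior or regularization on the latents is omitted:

```python
import numpy as np

rng_t = np.random.default_rng(2)
n_test, n_lat, n_pix, n_lab = 50, 3, 20, 2
B_fixed = rng_t.normal(size=(n_pix, n_lat))  # stand-in for a trained flux matrix
A_fixed = rng_t.normal(size=(n_lab, n_lat))  # stand-in for a trained label matrix

# Simulated test fluxes with known latents, so the recovery can be checked
z_true = rng_t.normal(size=(n_test, n_lat))
flux_err = 0.05
flux_obs = z_true @ B_fixed.T + flux_err * rng_t.normal(size=(n_test, n_pix))
ivar = np.full_like(flux_obs, 1.0 / flux_err**2)

# Per-star normal equations (B^T W_n B) z_n = B^T W_n f_n, batched over stars
lhs = np.einsum("pk,np,pl->nkl", B_fixed, ivar, B_fixed)  # (n_test, n_lat, n_lat)
rhs = np.einsum("pk,np->nk", B_fixed, ivar * flux_obs)  # (n_test, n_lat)
z_hat = np.linalg.solve(lhs, rhs[..., None])[..., 0]

# Labels are predicted from the inferred latents using the fixed label matrix
label_pred = z_hat @ A_fixed.T
```

Since the latents are the only free parameters here, a single solve reaches the optimum and no alternation is needed.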

In [None]:
# Predictions on test set from both methods
pred_svi = model.predict_outputs(test_pars_svi["latents"], test_pars_svi)
pred_iter = model.predict_outputs(test_pars_iter["latents"], test_pars_iter)

In [None]:
pt_style = {"ls": "none", "ms": 2.0, "alpha": 0.5, "marker": "o", "color": "k"}

fig, axes = plt.subplots(2, 2, figsize=(10, 10), layout="constrained")

# Top row: SVI predictions vs true
for i in range(2):
    axes[0, i].plot(pred_svi["label"][:, i], test_data["label"].data[:, i], **pt_style)
    axes[0, i].set(xlabel=f"Predicted label {i}", ylabel=f"True label {i}")
    axes[0, i].axline([0, 0], slope=1, color="tab:green", zorder=-100)
axes[0, 0].set_title("SVI: Label 0")
axes[0, 1].set_title("SVI: Label 1")

# Bottom row: Iterative predictions vs true
for i in range(2):
    axes[1, i].plot(pred_iter["label"][:, i], test_data["label"].data[:, i], **pt_style)
    axes[1, i].set(xlabel=f"Predicted label {i}", ylabel=f"True label {i}")
    axes[1, i].axline([0, 0], slope=1, color="tab:green", zorder=-100)
axes[1, 0].set_title("Iterative: Label 0")
axes[1, 1].set_title("Iterative: Label 1")

fig.suptitle("Test Set: Predicted vs. True Labels", fontsize=16)

Visually, both the SVI and iterative methods predict the test-set labels well, and the iterative optimization appears slightly more accurate for this toy example.

Here's another way to look at the test set prediction accuracy:

In [None]:
# Compute prediction errors on test set
svi_label_rmse = np.sqrt(np.mean((pred_svi["label"] - test_data["label"].data) ** 2))
iter_label_rmse = np.sqrt(np.mean((pred_iter["label"] - test_data["label"].data) ** 2))

svi_flux_rmse = np.sqrt(np.mean((pred_svi["flux"] - test_data["flux"].data) ** 2))
iter_flux_rmse = np.sqrt(np.mean((pred_iter["flux"] - test_data["flux"].data) ** 2))

print("Test Set Performance Comparison")
print("=" * 65)
print(f"{'Method':<20} {'Time (s)':<12} {'Label RMSE':<15} {'Flux RMSE':<15}")
print("-" * 65)
print(
    f"{'SVI (10k steps)':<20} {svi_time:<12.2f} {svi_label_rmse:<15.4f} {svi_flux_rmse:<15.4f}"
)
print(
    f"{'Iterative':<20} {iterative_time:<12.2f} {iter_label_rmse:<15.4f} {iter_flux_rmse:<15.4f}"
)

## When to use iterative optimization

The iterative optimization approach works well when your model is purely linear. In that case it should outperform gradient-based methods in speed and convergence, and often in prediction accuracy as well, because each sub-problem has a closed-form linear least squares solution.

For models with non-linear transforms (e.g., neural networks, Gaussian processes), you should use the standard `optimize()` method with gradient-based optimization.