From bcf84ea46b27f9967c148a78d5ff6d31c6f14a41 Mon Sep 17 00:00:00 2001 From: Valentin De Bortoli Date: Wed, 6 May 2026 10:06:29 -0700 Subject: [PATCH] Unify discrete samplers with 3-way routing representation. PiperOrigin-RevId: 911402681 --- hackable_diffusion/lib/sampling/__init__.py | 2 + .../lib/sampling/discrete_step_sampler.py | 710 ++++++++++++------ .../sampling/discrete_step_sampler_test.py | 300 +++++++- 3 files changed, 765 insertions(+), 247 deletions(-) diff --git a/hackable_diffusion/lib/sampling/__init__.py b/hackable_diffusion/lib/sampling/__init__.py index eeca37a..29bcfcc 100644 --- a/hackable_diffusion/lib/sampling/__init__.py +++ b/hackable_diffusion/lib/sampling/__init__.py @@ -31,6 +31,8 @@ from hackable_diffusion.lib.sampling.discrete_step_sampler import NoRemaskingFn from hackable_diffusion.lib.sampling.discrete_step_sampler import RemaskingFn from hackable_diffusion.lib.sampling.discrete_step_sampler import RescaledRemaskingFn +from hackable_diffusion.lib.sampling.discrete_step_sampler import RoutingStrategy +from hackable_diffusion.lib.sampling.discrete_step_sampler import RoutingWeights from hackable_diffusion.lib.sampling.discrete_step_sampler import UnMaskingStep from hackable_diffusion.lib.sampling.gaussian_step_sampler import AdjustedDDIMStep from hackable_diffusion.lib.sampling.gaussian_step_sampler import DDIMStep diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler.py b/hackable_diffusion/lib/sampling/discrete_step_sampler.py index 38551fe..746489d 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler.py @@ -27,11 +27,30 @@ state. The update is also in charge of computing other auxiliary informations such as volatility, drifts, etc. -The `InferenceFn is also called within the step and converted into the +The `InferenceFn` is also called within the step and converted into the relevant representation, for instance score, velocity, etc. 
+ +This module also introduces the concepts of Routing and Planning for discrete +diffusion: +- Routing: Defines the transition probabilities at each position among the + three actions: stay at current token, sample from invariant distribution + (noise), + or use predicted clean token. Samplers compute these weights based on the + diffusion posterior. +- Planning: An optional mechanism that intercepts and modifies these routing + weights before they are applied. For example, a planner might force the + most confident positions to go clean (budgeting) instead of sampling + stochastically. + +How they interact: +Samplers first compute the default routing weights. If a planner is present, +it transforms these weights (e.g., zeroing out some pathways, forcing others). +Finally, `_sample_routing` samples the next state based on these (possibly +modified) weights. """ import dataclasses +import enum from typing import Protocol from hackable_diffusion.lib import hd_typing @@ -52,6 +71,7 @@ DataArray = hd_typing.DataArray TargetInfo = hd_typing.TargetInfo TimeArray = hd_typing.TimeArray +PRNGKey = hd_typing.PRNGKey DiffusionStep = base.DiffusionStep StepInfo = base.StepInfo @@ -225,6 +245,121 @@ def __call__(self, xt: DataArray) -> DataArray: return xt == mask_value +################################################################################ +# MARK: Routing and planning +################################################################################ + +# Almost all discrete samplers compute a 3-way routing for each token position: +# 0 = stay at current token (xt) +# 1 = sample from invariant distribution (noise) +# 2 = use predicted clean token (x0) +# +# The routing weights (p_stay, p_noise, p_clean) are computed by each +# sampler and applied via the shared `_sample_routing` helper. +# IntegratedDiscreteDDIMStep is an exception as it integrates the routing +# probabilities into the update rule. 
+ + +class RoutingAction(enum.IntEnum): + STAY = 0 + NOISE = 1 + CLEAN = 2 + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class RoutingWeights: + stay: Float['...'] + noise: Float['...'] + clean: Float['...'] + + +def _sample_routing( + *, + routing_weights: RoutingWeights, + xt: DataArray, + x0: DataArray, + x_noise: DataArray, + key: PRNGKey, +) -> DataArray: + """Apply 3-way routing to construct the next state. + + 3-way routing determines the next state by sampling from a mixture of three + possible actions at each position: + 1. STAY: Keep the current token `xt`. + 2. NOISE: Sample a new token from the invariant distribution `x_noise`. + 3. CLEAN: Use the predicted clean token `x0`. + + The computation operates by: + 1. Concatenating the weights for stay, noise, and clean along the last axis. + 2. Sampling an action index (0, 1, or 2) for each position using + `jax.random.categorical`. + 3. Selecting the corresponding token (xt, x_noise, or x0) based on the sampled + action. + + Args: + routing_weights: Routing weights containing stay, noise, and clean arrays. + xt: Current state. Shape ``(*, 1)``. + x0: Predicted clean state. Shape ``(*, 1)``. + x_noise: Sample from invariant distribution. Shape ``(*, 1)``. + key: Random key for categorical sampling. + + Returns: + The new state ``new_xt``. Shape ``(*, 1)``. + """ + weights = jnp.concatenate( + [routing_weights.stay, routing_weights.noise, routing_weights.clean], + axis=-1, + ) + action = jax.random.categorical( + key=key, logits=jnp.log(jnp.maximum(weights, 1e-12)) + ) + new_xt = jnp.where( + action[..., None] == RoutingAction.CLEAN, + x0, + jnp.where(action[..., None] == RoutingAction.NOISE, x_noise, xt), + ) + return new_xt + + +class RoutingStrategy(Protocol): + """Protocol for transforming routing weights. + + A planner takes the routing weights computed by a sampler and + optionally transforms them before they are applied via ``_sample_routing``. 
+ This allows injecting different selection strategies (e.g. greedy top-k) + without modifying the sampler logic. + + When no planner is used (``planner=None``), the routing weights are + applied as-is via stochastic categorical sampling. + """ + + def __call__( + self, + routing_weights: RoutingWeights, + logits: Float['... M'], + x0: DataArray, + xt: DataArray, + time: TimeArray, + next_time: TimeArray, + key: PRNGKey, + ) -> RoutingWeights: + """Transforms routing weights. + + Args: + routing_weights: Per-position routing weights. + logits: Model logits ``(*, M)``. + x0: Sampled clean token ``(*, 1)``. + xt: Current state ``(*, 1)``. + time: Current diffusion time. + next_time: Next diffusion time. + key: Random key. + + Returns: + Transformed routing weights. + """ + ... + + ################################################################################ # MARK: UnMasking Step ################################################################################ @@ -234,6 +369,22 @@ def __call__(self, xt: DataArray) -> DataArray: class UnMaskingStep(SamplerStep): """Unmasking step following https://arxiv.org/abs/2406.04329. + This sampler uses the 3-way routing representation. For each token position + we compute the probabilities of three actions: + - STAY: keep the current token. + - NOISE: sample from the invariant distribution (remasking). + - CLEAN: use the predicted clean token x0. + + For masked tokens: + p_clean = (alpha_s - (1 - p_st) * alpha_t) / (1 - alpha_t) + p_noise = p_st + p_stay = 1 - p_clean - p_noise + For unmasked tokens: + p_clean = 0 + p_noise = p_st + p_stay = 1 - p_st + where p_st is the remasking rate. + Attributes: corruption_process: The corruption process to use. 
remasking_fn: The remasking function to use, see @@ -246,6 +397,7 @@ class UnMaskingStep(SamplerStep): """ corruption_process: CategoricalProcess + planner: RoutingStrategy | None = None remasking_fn: RemaskingFn = NoRemaskingFn() corruption_mask_fn: CorruptedMaskFn = AllCorruptedMaskFn() temperature: float = 1.0 @@ -297,56 +449,84 @@ def update( time = current_step_info.time next_time = next_step_info.time - time = utils.bcast_right(time, xt.ndim) - next_time = utils.bcast_right(next_time, xt.ndim) + time_bcast = utils.bcast_right(time, xt.ndim) + next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Sample from p_{0|t} - + # Get model predictions logits = self.corruption_process.convert_predictions( prediction, xt, - time, + time_bcast, )['logits'] logits = logits / self.temperature - key, subkey = jax.random.split(key) - sample = jax.random.categorical(key=subkey, logits=logits)[..., None] - # (bsz, *seq_len, 1) - - # Split xt into masked and unmasked regions - - currently_masked = self.corruption_mask_fn(xt) - currently_unmasked = jnp.invert(currently_masked) - - # Denoising + _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) - alpha_s = self.corruption_process.schedule.alpha(next_time) - alpha_t = self.corruption_process.schedule.alpha(time) - - p_st = self.remasking_fn(s=next_time, t=time) + # Sample candidates + x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] + x_noise = self.corruption_process.sample_from_invariant( + noise_key, data_spec=xt + ) - prob = (alpha_s - (1.0 - p_st) * alpha_t) / (1.0 - alpha_t) - # Denoising probability following https://arxiv.org/abs/2503.00307v1 - # If no remasking, p_st = 0, so prob = (alpha_s - alpha_t) / (1.0 - alpha_t) - prob = jnp.broadcast_to(prob, currently_masked.shape) + currently_masked = self.corruption_mask_fn(xt) # (bsz, seq_len, 1) - key, subkey = jax.random.split(key) - to_unmask = currently_masked * jax.random.bernoulli(subkey, prob) + # 
Denoising rates + alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) + alpha_t = self.corruption_process.schedule.alpha(time_bcast) - new_xt = jnp.where(to_unmask, sample, xt) + # Routing decomposition logic + # See docstring for formulae and https://arxiv.org/abs/2503.00307 for + # details. + + p_st = self.remasking_fn(s=next_time_bcast, t=time_bcast) + + p_clean_masked = (alpha_s - (1.0 - p_st) * alpha_t) / (1.0 - alpha_t) + p_noise_masked = p_st + p_stay_masked = 1.0 - p_clean_masked - p_noise_masked + # Denoising probability following https://arxiv.org/abs/2503.00307 + # If no remasking (https://arxiv.org/abs/2503.00307), p_st = 0, + # so p_clean = (alpha_s - alpha_t) / (1.0 - alpha_t) + # These are the routing weights for masked positions xt. + # With prob p_clean, we replace xt with the predicted token x0. + # With prob p_noise, we replace xt with the invariant token x_noise. + # With prob p_stay, we keep the current token xt. + + # Routing weights for unmasked tokens: + p_stay_unmasked = 1.0 - p_st + p_noise_unmasked = p_st + p_clean_unmasked = jnp.zeros_like(p_st) + # Same as above, but for unmasked positions. + # Note that if p_st = 0, then p_noise = 0, and p_stay = 1, which means + # that unmasked tokens are never remasked. + + # Combine based on masking state + # See https://arxiv.org/abs/2503.00307 for an example of the combination of + # probabilities for masked and unmasked tokens. 
+ p_stay = jnp.where(currently_masked, p_stay_masked, p_stay_unmasked) + p_noise = jnp.where(currently_masked, p_noise_masked, p_noise_unmasked) + p_clean = jnp.where(currently_masked, p_clean_masked, p_clean_unmasked) + + routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean) + # (bsz, seq_len, 3) + + # Apply planner transformation (if any) + if self.planner: + routing_weights = self.planner( + routing_weights, logits, x0, xt, time, next_time, plan_key + ) - # Renoising following https://arxiv.org/abs/2503.00307 - key_noise, key_remask = jax.random.split(key) - noise_sample = self.corruption_process.sample_from_invariant( - key=key_noise, - data_spec=xt, + # xt ~ p(x_s|x_0, x_t) (with optional remasking and planning) + new_xt = _sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=route_key, ) - p_st = jnp.broadcast_to(p_st, currently_unmasked.shape) - to_remask = currently_unmasked * jax.random.bernoulli(key_remask, p_st) + # This is the new state after sampling using the routing weights. - new_xt = jnp.where(to_remask, noise_sample, new_xt) new_xt = self.corruption_process.post_corruption_fn(new_xt) # Replace the unused tokens with the unused_token. @@ -389,21 +569,57 @@ class DiscreteDDIMStep(SamplerStep): Diffusion Models in Discrete State-Spaces" (known as D3PM, see https://arxiv.org/abs/2107.03006). - Given the forward process with density p(x_t|x_0) it computes the reverse - process by first sampling from p(x_0|x_t) to obtain x_0. + This sampler uses the 3-way routing representation. 
Given the forward process + with density p(x_t|x_0), we decompose the reverse posterior + p(x_s|x_t, x_0) into three components: - Then it samples x_s (for s < t) using the following formula: + p(x_s|x_t,x_0) = p_stay * δ_{x_t}(x_s) + p_noise * π(x_s) + + p_clean * δ_{x_0}(x_s) - p(x_s|x_t,x_0) ∝ p(x_s|x_0) * p(x_t|x_s) (1) + where: + - p_stay: probability of staying at x_t + - p_noise: probability of jumping to invariant noise + - p_clean: probability of jumping to the predicted x_0 - In order to compute (1) we recall that for any s, t such that s < t we have: + **Derivation.** Recall that for the forward process: - p(x_t|x_s) = (α_t/α_s) * δ_{x_s}(x_t) + (1 - α_t/α_s) * π(x_t) + p(x_t|x_s) = r * δ_{x_t}(x_s) + (1 - r) * π(x_t) + p(x_s|x_0) = α_s * δ_{x_0}(x_s) + (1 - α_s) * π(x_s) - The computation of the probability happens in the logits space. + where r = α_t/α_s. The posterior is proportional to their product: + + p(x_s|x_t,x_0) ∝ p(x_t|x_s) * p(x_s|x_0) + + Expanding gives four cross-terms: + + (T1) r * α_s * δ_{x_t}(x_s) * δ_{x_0}(x_s) + (T2) r * (1-α_s) * δ_{x_t}(x_s) * π(x_s) = r*(1-α_s)*π(x_t) * δ_{x_t} + (T3) (1-r) * α_s * π(x_t) * δ_{x_0}(x_s) + (T4) (1-r) * (1-α_s) * π(x_t) * π(x_s) + + Collecting by routing outcome: + - p_stay ∝ r * (1-α_s) * π(x_t) [T2: δ_{x_t} · π gives π(x_t)*δ_{x_t}] + - p_noise ∝ (1-r) * (1-α_s) * π(x_t) [T4: π(x_t) · π(x_s)] + - p_clean ∝ (1-r) * α_s * π(x_t) [T3: π(x_t) · δ_{x_0}] + + **Handling x_0 = x_t.** When x_0 = x_t, routing to CLEAN produces the same + output as STAY (both emit x_t). We therefore merge all such mass into p_stay: + + 1. T1 contributes r * α_s to p_stay (the fourth cross-term + δ_{x_t} * δ_{x_0}, which is non-zero only when x_0 = x_t). + 2. p_clean is added to p_stay and then zeroed out, since the CLEAN action + would be a no-op. + + This ensures p_clean = 0 whenever x_0 = x_t, which is important for + planners that use p_clean > 0 as an eligibility signal (e.g. GreedyPlanner). 
+ + Note: when π = δ_MASK (masking process) and x_t = MASK, this reduces to: + P(unmask) = (α_s - α_t) / (1 - α_t), + which coincides with the UnMaskingStep formula (without remasking). """ corruption_process: CategoricalProcess + planner: RoutingStrategy | None = None temperature: float = 1.0 logits_dtype: jnp.dtype = jnp.float32 @@ -452,65 +668,226 @@ def update( xt = current_step.xt unused_mask = xt == self.corruption_process.unused_token - # The mask is True if the token is unused. time = current_step_info.time next_time = next_step_info.time - time = utils.bcast_right(time, xt.ndim) - next_time = utils.bcast_right(next_time, xt.ndim) + time_bcast = utils.bcast_right(time, xt.ndim) + next_time_bcast = utils.bcast_right(next_time, xt.ndim) key = next_step_info.rng - # Sample from p_{0|t} + # Get model predictions logits = self.corruption_process.convert_predictions( prediction, xt, - time, + time_bcast, )['logits'] logits = logits / self.temperature - x0 = jax.random.categorical(key=key, logits=logits)[..., None] + _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) + + # Sample candidates + x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] + x_noise = self.corruption_process.sample_from_invariant( + noise_key, data_spec=xt + ) + + # Schedule + alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) + alpha_t = self.corruption_process.schedule.alpha(time_bcast) + ratio = alpha_t / alpha_s # (bsz, *seq_len, 1) - key, _ = jax.random.split(key) - # Compute the probability vector + # Routing weights (unnormalized). + # See the class docstring for the full derivation of terms T1–T4. + pi_xt = self.corruption_process.invariant_probs_vec[xt[..., 0]][..., None] + + # T2 → stay, T4 → noise, T3 → clean + p_stay = ratio * (1.0 - alpha_s) * pi_xt + p_noise = (1.0 - ratio) * (1.0 - alpha_s) * pi_xt + p_clean = (1.0 - ratio) * alpha_s * pi_xt + + # When x_0 = x_t, routing to CLEAN produces the same output as STAY. 
+ # Merge T1 (r * α_s) and p_clean into p_stay, and zero out p_clean. + # This ensures planners see p_clean = 0 for no-op positions. + x0_eq_xt = (x0 == xt).astype(jnp.float32) + p_stay = p_stay + x0_eq_xt * (ratio * alpha_s + p_clean) + p_clean = (1.0 - x0_eq_xt) * p_clean + + routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean) + # (bsz, *seq_len, 3) + + # Apply planner transformation (if any) + if self.planner: + routing_weights = self.planner( + routing_weights, logits, x0, xt, time, next_time, plan_key + ) - xt_oh = jax.nn.one_hot( - xt[..., 0], num_classes=self.corruption_process.process_num_categories + # xt ~ p(x_s|x_0, x_t) + # This is the new state after sampling using the routing weights. + new_xt = _sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=route_key, ) - x0_oh = jax.nn.one_hot( - x0[..., 0], num_classes=self.corruption_process.process_num_categories + + new_xt = self.corruption_process.post_corruption_fn(new_xt) + + # Replace the unused tokens with the unused_token. + new_xt = jnp.where( + unused_mask, self.corruption_process.unused_token, new_xt ) - # (bsz, *seq_len, M) - alpha_s = self.corruption_process.schedule.alpha(next_time) - alpha_t = self.corruption_process.schedule.alpha(time) - alpha_s = jnp.broadcast_to(alpha_s, x0_oh.shape) - alpha_t = jnp.broadcast_to(alpha_t, x0_oh.shape) - ratio = alpha_t / alpha_s - # (bsz, *seq_len, M) + return DiffusionStep( + xt=new_xt, + step_info=next_step_info, + aux={'logits': logits}, + ) + # `logits` need to be passed in `aux` dictionary to a performance + # bug when using TPU. Needs to be investigated. 
+ + @kt.typechecked + def finalize( + self, + prediction: TargetInfo, + current_step: DiffusionStep, + last_step_info: StepInfo, + ) -> DiffusionStep: + return self.update( + prediction, + current_step, + last_step_info, + ) + + +################################################################################ +# MARK: Discrete Flow Matching Step +################################################################################ + + +@dataclasses.dataclass(frozen=True, kw_only=True) +class DiscreteFlowMatchingStep(SamplerStep): + """Discrete Flow Matching step following https://arxiv.org/abs/2407.15595. + + This sampler uses the 3-way routing representation. The update rule + decomposes naturally into: + + p(x_s) = p_stay * δ_{x_t} + p_up * p_x0 + p_down * π + + where: + - p_stay = 1 - p_up - p_down + - p_up = (α_s - α_t) / (1 - α_t) * (1 + stoch_coeff) + - p_down = (α_s - α_t) / α_t * stoch_coeff + + Attributes: + corruption_process: The corruption process to use. + temperature: The temperature to use. + stoch_coeff: The stochasticity coefficient (default 0.0). Higher values + introduce more noise during the denoising process. 
+ """ - first_logit = jnp.log( - ratio * xt_oh - + (1.0 - ratio) * self.corruption_process.invariant_probs_vec[xt] + corruption_process: CategoricalProcess + planner: RoutingStrategy | None = None + temperature: float = 1.0 + stoch_coeff: float = 0.0 + + @kt.typechecked + def initialize( + self, + initial_noise: DataArray, + initial_step_info: StepInfo, + ) -> DiffusionStep: + + init_logits = jnp.repeat( + initial_noise, self.corruption_process.num_categories, axis=-1 ) - second_logit = jnp.log( - alpha_s * x0_oh - + (1.0 - alpha_s) * self.corruption_process.invariant_probs_vec + init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) - jnp.inf + + return DiffusionStep( + xt=initial_noise, + step_info=initial_step_info, + aux={'logits': init_logits}, + ) + + @kt.typechecked + def update( + self, + prediction: TargetInfo, + current_step: DiffusionStep, + next_step_info: StepInfo, + ) -> DiffusionStep: + + current_step_info = current_step.step_info + xt = current_step.xt + + unused_mask = xt == self.corruption_process.unused_token + + time = current_step_info.time + next_time = next_step_info.time + time_bcast = utils.bcast_right(time, xt.ndim) + next_time_bcast = utils.bcast_right(next_time, xt.ndim) + key = next_step_info.rng + + # Get model predictions + logits = self.corruption_process.convert_predictions( + prediction, + xt, + time_bcast, + )['logits'] + logits = logits / self.temperature + + _, x0_key, noise_key, plan_key, route_key = jax.random.split(key, 5) + + # Sample candidates + x0 = jax.random.categorical(key=x0_key, logits=logits)[..., None] + x_noise = self.corruption_process.sample_from_invariant( + noise_key, data_spec=xt + ) + + # Denoising rates + alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) + alpha_t = self.corruption_process.schedule.alpha(time_bcast) + + prob_up = ( + (alpha_s - alpha_t) + / jnp.maximum(1.0 - alpha_t, 1e-12) + * (1.0 + self.stoch_coeff) + ) + prob_down = ( + (alpha_s - alpha_t) / jnp.maximum(alpha_t, 
1e-12) * self.stoch_coeff + ) + + # Clip and rescale to ensure valid probabilities + raw_p_up = jnp.maximum(prob_up, 0.0) + raw_p_down = jnp.maximum(prob_down, 0.0) + sum_jumps = raw_p_up + raw_p_down + scale_factor = jnp.maximum(1.0, sum_jumps) + + # Compute the probabilities for the three routing options. + # This is computed according to https://arxiv.org/abs/2407.15595. + p_clean = raw_p_up / scale_factor + p_noise = raw_p_down / scale_factor + p_stay = 1.0 - p_clean - p_noise + + routing_weights = RoutingWeights(stay=p_stay, noise=p_noise, clean=p_clean) + # (bsz, *seq_len, 3) + + # Apply planner transformation (if any) + if self.planner: + routing_weights = self.planner( + routing_weights, logits, x0, xt, time, next_time, plan_key + ) + + # xt ~ p(x_s|x_0, x_t) + # This is the new state after sampling using the routing weights. + new_xt = _sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=route_key, ) - total_logit = first_logit + second_logit - # Do not use this sampler for masking. - # What could happen is xt is unmasked (assume at first position) so the - # first logits (first_logit) is [value, -inf, ..., -inf]. Then assume - # that the predictionfor x0 is different than xt - # (can never happen in unmasking),assume that the second position is the - # one chosen by the x0 predictor. Then we have for the second logits - # (second_logit): [-inf, value, -inf, ..., -inf, value_mask]. - # So when we add them together we get [-inf, ..., -inf]. - # jax.random.categorical will then return the first position. - # This is not what we want and this behavior should not be accepted. - - # Sample from the distribution defined by logits - new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None] new_xt = self.corruption_process.post_corruption_fn(new_xt) # Replace the unused tokens with the unused_token. 
@@ -523,8 +900,6 @@ def update( step_info=next_step_info, aux={'logits': logits}, ) - # `logits` need to be passed in `aux` dictionary to a performance - # bug when using TPU. Needs to be investigated. @kt.typechecked def finalize( @@ -543,6 +918,10 @@ def finalize( ################################################################################ # MARK: Integrated DDIM Step ################################################################################ +# Note: IntegratedDiscreteDDIMStep does NOT fit the 3-way routing scheme +# because it marginalizes over x_0 rather than sampling a single x_0. +# It is kept as-is with direct categorical sampling. +################################################################################ @dataclasses.dataclass(frozen=True, kw_only=True) @@ -573,7 +952,8 @@ class IntegratedDiscreteDDIMStep(SamplerStep): In particular, we use the following formula: - p(x_s|x_t) = p(x_t|x_s) * sum_{x_0} (p(x_0|x_t) / p(x_t|x_0)) p(x_s|x_0) (2) + p(x_s|x_t) = p(x_t|x_s) * sum_{x_0} (p(x_0|x_t) / p(x_t|x_0)) p(x_s|x_0) + (2) Denoting w(x_0, x_t) = p(x_0|x_t) / p(x_t|x_0) and W(x_t) = sum_{x_0} w(x_0, x_t) we have: @@ -684,10 +1064,8 @@ def update( p_xs = q_xt_given_xs * expected_xs_given_x0 # (bsz, *seq_len, M) - # Convert back to logits for safe categorical sampling + # Convert to logits and sample. 
total_logit = jnp.log(jnp.clip(p_xs, min=1e-12)) - - # Sample and format the new state new_xt = jax.random.categorical(key=key, logits=total_logit)[..., None] new_xt = self.corruption_process.post_corruption_fn(new_xt) @@ -716,161 +1094,3 @@ def finalize( current_step, last_step_info, ) - - -################################################################################ -# MARK: Discrete Flow Matching Step -################################################################################ - - -@dataclasses.dataclass(frozen=True, kw_only=True) -class DiscreteFlowMatchingStep(SamplerStep): - """Discrete Flow Matching step following https://arxiv.org/abs/2407.15595. - - This sampler is the simplest variant of Algorithm 1 in Discrete Flow Matching, - Gat et. al., 2024, https://arxiv.org/abs/2407.15595. It implements the - update rule based on the probability velocity derived for the probability - path family in (9). - - The update rule is: - x_{t-dt} ~ (1 - prob_jump) * delta_{x_t} + prob_jump * prediction - - where prob_jump = (alpha_s - alpha_t) / (1 - alpha_t). Note that alpha(t) in - this codebase is the probability of keeping the original value, which - corresponds to 1 - kappa(t) in the paper if the time is reversed. - - Attributes: - corruption_process: The corruption process to use. - temperature: The temperature to use. - gamma: The corrector term (default 0.0). Higher values introduce more noise - during the denoising process, which can improve sample quality. 
- """ - - corruption_process: CategoricalProcess - temperature: float = 1.0 - gamma: float = 0.0 - logits_dtype: jnp.dtype = jnp.float32 - - @kt.typechecked - def initialize( - self, - initial_noise: DataArray, - initial_step_info: StepInfo, - ) -> DiffusionStep: - - init_logits = jnp.repeat( - initial_noise, self.corruption_process.num_categories, axis=-1 - ) - init_logits = jnp.zeros_like(init_logits, dtype=self.logits_dtype) - - return DiffusionStep( - xt=initial_noise, - step_info=initial_step_info, - aux={'logits': init_logits}, - ) - - @kt.typechecked - def update( - self, - prediction: TargetInfo, - current_step: DiffusionStep, - next_step_info: StepInfo, - ) -> DiffusionStep: - - current_step_info = current_step.step_info - xt = current_step.xt - - unused_mask = xt == self.corruption_process.unused_token - - time = current_step_info.time - next_time = next_step_info.time - time_bcast = utils.bcast_right(time, xt.ndim) - next_time_bcast = utils.bcast_right(next_time, xt.ndim) - key = next_step_info.rng - - # Sample from p_{0|t} - logits = self.corruption_process.convert_predictions( - prediction, - xt, - time_bcast, - )['logits'] - logits = logits / self.temperature - - _, sample_key, noise_key, jump_key = jax.random.split(key, 4) - sample = jax.random.categorical(key=sample_key, logits=logits)[..., None] - noise_sample = self.corruption_process.sample_from_invariant( - noise_key, data_spec=xt - ) - - # Denoising - alpha_s = self.corruption_process.schedule.alpha(next_time_bcast) - alpha_t = self.corruption_process.schedule.alpha(time_bcast) - - # prob_up is the probability of switching from the current state to the - # predicted data state. 
Following the paper's formula (24): - # u_fwd = (dot_kappa / (1 - kappa)) * (p_data - delta_xt) - # prob_down is the probability of switching back to noise (corrector logic): - # u_bwd = (dot_kappa / kappa) * (delta_xt - p_noise) - # Following the paper's formula (26), the combined velocity is: - # u_bar = (1 + gamma) * u_fwd - gamma * u_bwd. - # Note that since u_bwd (u^(0) in the paper) involves (delta_xt - p_noise), - # it has negative jump rates back to noise. Subtracting it (-gamma * u_bwd) - # results in positive jump probabilities in the discretization. - - # We discretize this as a jump process where each token has probability - # prob_up of jumping to data and prob_down of jumping to noise. - - prob_up = ( - (alpha_s - alpha_t) - / jnp.maximum(1.0 - alpha_t, 1e-12) - * (1.0 + self.gamma) - ) - prob_down = (alpha_s - alpha_t) / jnp.maximum(alpha_t, 1e-12) * self.gamma - - # Calculate raw, unclipped probabilities - raw_p_up = jnp.maximum(prob_up, 0.0) - raw_p_down = jnp.maximum(prob_down, 0.0) - sum_jumps = raw_p_up + raw_p_down - - # If the sum exceeds 1.0, scale them down proportionally to maintain their - # ratio - scale_factor = jnp.maximum(1.0, sum_jumps) - - p_up = raw_p_up / scale_factor - p_down = raw_p_down / scale_factor - p_stay = 1.0 - p_up - p_down - - probs = jnp.stack([p_stay, p_up, p_down], axis=-1) - probs = jnp.broadcast_to(probs, xt.shape + (3,)) - jump_type = jax.random.categorical( - jump_key, logits=jnp.log(jnp.maximum(probs, 1e-12)) - ) - - # 0: stay, 1: jump to data, 2: jump to noise - new_xt = jnp.where(jump_type == 1, sample, xt) - new_xt = jnp.where(jump_type == 2, noise_sample, new_xt) - new_xt = self.corruption_process.post_corruption_fn(new_xt) - - # Replace the unused tokens with the unused_token. 
- new_xt = jnp.where( - unused_mask, self.corruption_process.unused_token, new_xt - ) - - return DiffusionStep( - xt=new_xt, - step_info=next_step_info, - aux={'logits': logits}, - ) - - @kt.typechecked - def finalize( - self, - prediction: TargetInfo, - current_step: DiffusionStep, - last_step_info: StepInfo, - ) -> DiffusionStep: - return self.update( - prediction, - current_step, - last_step_info, - ) diff --git a/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py b/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py index 75cd683..9301226 100644 --- a/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py +++ b/hackable_diffusion/lib/sampling/discrete_step_sampler_test.py @@ -572,7 +572,7 @@ def test_initialize(self): init_logits = jnp.repeat( self.initial_noise, self.process.num_categories, axis=-1 ) - init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) + init_logits = jnp.zeros_like(init_logits, dtype=jnp.float32) - jnp.inf chex.assert_trees_all_equal( initial_step, @@ -649,7 +649,7 @@ def test_update_with_gamma(self): # Use gamma that won't clip. dfm_step_gamma = discrete_step_sampler.DiscreteFlowMatchingStep( - corruption_process=self.process, gamma=1.0 + corruption_process=self.process, stoch_coeff=1.0 ) next_step_info = StepInfo( @@ -667,5 +667,301 @@ def test_update_with_gamma(self): self.assertTrue(jnp.any(next_step.xt != 1)) +class DDIMRoutingEquivalenceTest(absltest.TestCase): + """Verify routing-based DDIM matches the original logit-space computation. + + The original DiscreteDDIMStep computed the reverse posterior in full + M-dimensional logit space: + + first_logit[k] = log(r * 1[k=xt] + (1-r) * π(xt)) + second_logit[k] = log(αs * 1[k=x0] + (1-αs) * π(k)) + total_logit = first_logit + second_logit + + The routing reformulation decomposes this into 3-way routing weights. + This test checks that both produce exactly the same distribution over + output tokens, including the edge case where x0 == xt. 
+ """ + + def _posterior_distribution( + self, xt, x0, alpha_s, alpha_t, invariant_probs_vec + ): + """Compute the exact posterior in probability space. + + p(x_s | x_t, x_0) ∝ p(x_t | x_s) * p(x_s | x_0) + + Evaluated for every x_s in {0, ..., M-1}. + + Args: + xt: Current token. + x0: Predicted clean token. + alpha_s: Diffusion schedule value at time s. + alpha_t: Diffusion schedule value at time t. + invariant_probs_vec: Invariant distribution. + + Returns: + The M-dimensional posterior distribution. + """ + voc_size = int(invariant_probs_vec.shape[0]) + ratio = alpha_t / alpha_s + + # Build unnormalized weight for each x_s value. + weights = [] + for xs in range(voc_size): + # p(x_t | x_s) = r * 1[xs=xt] + (1-r) * π(xt) + p_xt_given_xs = ratio * float(xs == xt) + (1.0 - ratio) * float( + invariant_probs_vec[xt] + ) + # p(x_s | x_0) = α_s * 1[xs=x0] + (1-α_s) * π(xs) + p_xs_given_x0 = alpha_s * float(xs == x0) + (1.0 - alpha_s) * float( + invariant_probs_vec[xs] + ) + weights.append(p_xt_given_xs * p_xs_given_x0) + + weights = jnp.array(weights) + return weights / jnp.sum(weights) + + def _routing_distribution( + self, xt, x0, alpha_s, alpha_t, invariant_probs_vec + ): + """Compute the routing-based posterior distribution. + + Mirrors the actual code in DiscreteDDIMStep.update. + + Args: + xt: Current token. + x0: Predicted clean token. + alpha_s: Diffusion schedule value at time s. + alpha_t: Diffusion schedule value at time t. + invariant_probs_vec: Invariant distribution. + + Returns: + The M-dimensional posterior distribution. + """ + ratio = alpha_t / alpha_s + pi_xt = float(invariant_probs_vec[xt]) + + # T2 → stay, T4 → noise, T3 → clean + p_stay = ratio * (1.0 - alpha_s) * pi_xt + p_noise = (1.0 - ratio) * (1.0 - alpha_s) * pi_xt + p_clean = (1.0 - ratio) * alpha_s * pi_xt + + # When x0 == xt, CLEAN is a no-op. Merge T1 and p_clean into p_stay. 
+ if x0 == xt: + p_stay = p_stay + ratio * alpha_s + p_clean + p_clean = 0.0 + + total = p_stay + p_noise + p_clean + p_stay_norm = p_stay / total + p_noise_norm = p_noise / total + p_clean_norm = p_clean / total + + # Build the M-dimensional output distribution by marginalizing + # over the routing action: + # P(output=k) = P(STAY)*1[k=xt] + P(NOISE)*π(k) + P(CLEAN)*1[k=x0] + inv_probs = [float(p) for p in invariant_probs_vec] + dist = [p_noise_norm * inv_probs[k] for k in range(len(inv_probs))] + dist[xt] += p_stay_norm + dist[x0] += p_clean_norm + return jnp.array(dist) + + def test_equivalence_x0_neq_xt(self): + """Test routing matches posterior when x0 != xt.""" + voc_size = 5 + invariant_probs = jnp.array([0.1, 0.3, 0.2, 0.25, 0.15]) + + for xt_val in range(voc_size): + for x0_val in range(voc_size): + if x0_val == xt_val: + continue + for alpha_s_val in [0.2, 0.5, 0.8]: + alpha_t = 0.05 + p_exact = self._posterior_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + p_route = self._routing_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + chex.assert_trees_all_close(p_exact, p_route, atol=1e-6) + + def test_equivalence_x0_eq_xt(self): + """Test routing matches posterior when x0 == xt (the T1 cross-term).""" + voc_size = 5 + invariant_probs = jnp.array([0.1, 0.3, 0.2, 0.25, 0.15]) + + for xt_val in range(voc_size): + x0_val = xt_val + for alpha_s_val in [0.2, 0.5, 0.8]: + alpha_t = 0.05 + p_exact = self._posterior_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + p_route = self._routing_distribution( + xt_val, x0_val, alpha_s_val, alpha_t, invariant_probs + ) + chex.assert_trees_all_close(p_exact, p_route, atol=1e-6) + + def test_equivalence_nonuniform_invariant(self): + """Test with a highly non-uniform invariant distribution.""" + voc_size = 3 + invariant_probs = jnp.array([0.01, 0.01, 0.98]) + + for xt_val in range(voc_size): + for x0_val in range(voc_size): + p_exact = 
self._posterior_distribution( + xt_val, x0_val, 0.3, 0.7, invariant_probs + ) + p_route = self._routing_distribution( + xt_val, x0_val, 0.3, 0.7, invariant_probs + ) + chex.assert_trees_all_close(p_exact, p_route, atol=1e-6) + + +class ApplyRoutingTest(absltest.TestCase): + """Tests for the _sample_routing helper.""" + + def test_deterministic_stay(self): + # routing_weights = [1, 0, 0] means stay. + routing_weights = discrete_step_sampler.RoutingWeights( + stay=jnp.array([[[1.0], [1.0]]]), + noise=jnp.array([[[0.0], [0.0]]]), + clean=jnp.array([[[0.0], [0.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=key, + ) + chex.assert_trees_all_equal(new_xt, xt) + + def test_deterministic_clean(self): + # routing_weights = [0, 0, 1] means jump to x0. + routing_weights = discrete_step_sampler.RoutingWeights( + stay=jnp.array([[[0.0], [0.0]]]), + noise=jnp.array([[[0.0], [0.0]]]), + clean=jnp.array([[[1.0], [1.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=key, + ) + chex.assert_trees_all_equal(new_xt, x0) + + def test_deterministic_noise(self): + # routing_weights = [0, 1, 0] means jump to noise. 
+ routing_weights = discrete_step_sampler.RoutingWeights( + stay=jnp.array([[[0.0], [0.0]]]), + noise=jnp.array([[[1.0], [1.0]]]), + clean=jnp.array([[[0.0], [0.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=key, + ) + chex.assert_trees_all_equal(new_xt, x_noise) + + def test_mixed_routing(self): + # Position 0: deterministic stay, Position 1: deterministic clean. + routing_weights = discrete_step_sampler.RoutingWeights( + stay=jnp.array([[[1.0], [0.0]]]), + noise=jnp.array([[[0.0], [0.0]]]), + clean=jnp.array([[[0.0], [1.0]]]), + ) + xt = jnp.array([[[3], [5]]]) + x0 = jnp.array([[[0], [1]]]) + x_noise = jnp.array([[[2], [2]]]) + key = jax.random.PRNGKey(0) + + new_xt = discrete_step_sampler._sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=key, + ) + expected = jnp.array([[[3], [1]]]) + chex.assert_trees_all_equal(new_xt, expected) + + def test_stochastic_routing(self): + # 50/50 stay vs clean — results should vary across seeds. + routing_weights = discrete_step_sampler.RoutingWeights( + stay=jnp.array([[[0.5]]]), + noise=jnp.array([[[0.0]]]), + clean=jnp.array([[[0.5]]]), + ) + xt = jnp.array([[[3]]]) + x0 = jnp.array([[[0]]]) + x_noise = jnp.array([[[2]]]) + + results = set() + for seed in range(50): + new_xt = discrete_step_sampler._sample_routing( + routing_weights=routing_weights, + xt=xt, + x0=x0, + x_noise=x_noise, + key=jax.random.PRNGKey(seed), + ) + results.add(int(new_xt[0, 0, 0])) + + # Should see both stay (3) and clean (0). 
+    self.assertIn(3, results)
+    self.assertIn(0, results)
+
+  def test_routing_constants(self):
+    self.assertEqual(discrete_step_sampler.RoutingAction.STAY, 0)
+    self.assertEqual(discrete_step_sampler.RoutingAction.NOISE, 1)
+    self.assertEqual(discrete_step_sampler.RoutingAction.CLEAN, 2)
+
+
+class PlannerProtocolTest(absltest.TestCase):
+
+  def test_identity_planner(self):
+
+    class IdentityPlanner:
+
+      def __call__(self, routing_weights, logits, x0, xt, time, next_time, key):
+        return routing_weights
+
+    planner = IdentityPlanner()
+    routing_weights = discrete_step_sampler.RoutingWeights(
+        stay=jnp.array([[[0.2]]]),
+        noise=jnp.array([[[0.3]]]),
+        clean=jnp.array([[[0.5]]]),
+    )
+    # Dummy arguments: the identity planner ignores all inputs except routing_weights.
+    logits = jnp.zeros((1, 1, 5))
+    x0 = jnp.zeros((1, 1, 1))
+    xt = jnp.zeros((1, 1, 1))
+    time = jnp.array([1.0])
+    next_time = jnp.array([0.5])
+    key = jax.random.PRNGKey(0)
+
+    out = planner(routing_weights, logits, x0, xt, time, next_time, key)
+    chex.assert_trees_all_equal(out, routing_weights)
+
+
 if __name__ == '__main__':
   absltest.main()