In [None]:
import warnings
from typing import Tuple, Dict, Callable, TypeVar, Optional, Any

import torch
import tqdm
from typing_extensions import Protocol, runtime_checkable
# R: result type of the per-field functions dispatched by ``apply``.
# T: "same type as self" for methods that return a modified copy of their receiver.
R, T = TypeVar("R"), TypeVar("T")

@runtime_checkable
class BatchedData(Protocol):
    """Structural interface for a batch of samples whose fields are tensors.

    Any class defining the members below satisfies this protocol. Because of
    ``@runtime_checkable``, an ``isinstance`` check only tests that they exist.
    """

    def replace(self: T, **vals: torch.Tensor) -> T:
        """Return a copy of this batch in which the named fields hold ``vals``."""
        ...

    def get_batch_idx(self, field_name: str) -> torch.LongTensor | None:
        """Return which sample each row of ``field_name`` belongs to, or None.

        Dense fields, where every sample has the same shape and dimension 0 is the
        batch dimension, should return None. None is treated exactly like
        ``[0, 1, ..., batch_size - 1]``. Returning None lets callers skip an
        indexing step, which is expected to be cheaper.
        """
        ...

    def get_batch_size(self) -> int:
        """Return the number of samples in the batch."""
        ...

    def device(self) -> torch.device:
        """Return the device the batch lives on."""
        ...

    def __getitem__(self, field_name: str) -> torch.Tensor:
        """Return the tensor stored under ``field_name``."""
        ...

    def to(self: T, device: torch.device) -> T:
        """Return the batch moved to ``device``."""
        ...

    def clone(self: T) -> T:
        """Return a copy of the batch in which every tensor is cloned."""
        ...

In [None]:
# Any batch type implementing the BatchedData protocol.
Diffusable = TypeVar("Diffusable", bound=BatchedData)
# (sample, mean) pair, the annotated return type of predictor/corrector steps.
SampleAndMean = Tuple[Diffusable, Diffusable]
# (sample, mean, recorded intermediate batches or None), the annotated return type of `denoise`.
SampleAndMeanAndMaybeRecords = Tuple[Diffusable, Diffusable, list[Diffusable] | None]
# Same as above, but the records list is always present.
SampleAndMeanAndRecords = Tuple[Diffusable, Diffusable, list[Diffusable]]
# Optional per-row sample index (which sample each row belongs to); None for dense data.
B = Optional[torch.LongTensor]

def apply(fns: Dict[str, Callable[..., R]], broadcast, **kwargs) -> Dict[str, R]:
    """Call one function per field, each with its own and with shared keyword arguments.

    Args:
        fns: ``{field_name: function_to_call}``.
        broadcast: keyword arguments passed identically to every function
            (None or an empty dict for none).
        **kwargs: ``{argument_name: {field_name: argument_value}}``. A field
            receives ``argument_name`` only if the field appears in that mapping.

    Returns:
        ``{field_name: function_result}`` for every field in ``fns``.
    """
    results = {}
    for name, func in fns.items():
        per_field_args = {
            arg_name: per_field[name]
            for arg_name, per_field in kwargs.items()
            if name in per_field
        }
        results[name] = func(**per_field_args, **(broadcast or dict()))
    return results


def broadcast_like(x, like):
    """Append trailing size-1 axes to ``x`` so it broadcasts against ``like``.

    If ``like`` is None, ``x`` is returned unchanged. Otherwise ``x`` is indexed
    with ``like.ndim - x.ndim`` trailing ``None`` entries.
    """
    if like is not None:
        missing_dims = like.ndim - x.ndim
        index = (...,) + (None,) * missing_dims
        return x[index]
    return x


def maybe_expand(x: torch.Tensor, batch: B, like: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Gather per-sample values onto per-row ("thingy") resolution.

    Args:
        x: shape (batch_size, ...), one entry per sample.
        batch: shape (num_thingies,) with integer entries in [0, batch_size),
            indicating which sample each thingy belongs to. None means there is
            exactly one thingy per sample, so no gather is needed.
        like: optional tensor; ``x`` gets trailing size-1 axes up to ``like``'s rank
            so the result broadcasts against it.

    Returns:
        ``x`` (padded to ``like``'s rank if given) when ``batch`` is None. Otherwise
        ``x[batch]``, i.e. the value of ``x`` for each thingy, shape (num_thingies, ...).
    """
    x = broadcast_like(x, like)
    if batch is None:
        return x
    if x.shape[0] == batch.shape[0]:
        # This is legitimate when every sample has exactly one thingy, so only warn.
        # warnings.warn (instead of print) can be filtered. Under the default filters
        # it is shown once per call site rather than on every sampling step.
        warnings.warn(
            "Warning: batch shape is == x shape, are you trying to expand something that is already expanded?",
            stacklevel=2,
        )
    return x[batch]
    

def wrap(x, wrapping_boundary):
    """Map ``x`` elementwise into the periodic range defined by ``wrapping_boundary``.

    Uses ``torch.remainder``, whose result has the sign of the divisor. For a
    positive boundary, negative inputs wrap around to the top of [0, boundary).
    """
    wrapped = torch.remainder(x, wrapping_boundary)
    return wrapped

In [None]:
from enum import Enum
from typing import Mapping, Union


class ModelTarget(Enum):
    """Specifies what the score model is trained to predict.

    Only relevant for fields that are corrupted with an SDE.
    """

    score_times_std = "score_times_std"  # Predict -z where z is gaussian noise with unit variance used to corrupt the data
    logits = "logits"  # Predict logits for a categorical variable


# Default target used by ``score_fn``. This must be the enum member, not the string
# "score_times_std": ``convert_model_out_to_score`` compares with ``==`` against
# ModelTarget members, and a plain Enum member never equals a str.
model_target = ModelTarget.score_times_std

In [None]:
def convert_model_out_to_score(
    *,
    model_target: ModelTarget,
    sde: "SDE",
    model_out: torch.Tensor,
    batch_idx: Optional[torch.LongTensor],
    t: torch.Tensor,
    batch: Any
) -> torch.Tensor:
    """
    Convert a model output to a score, according to the specified model_target.

    model_target: says what the model predicts, as a ``ModelTarget`` member or its
        string value (e.g. ``"score_times_std"``).
        For example, in RFDiffusion the model predicts clean coordinates;
        in EDM the model predicts the raw noise.
    sde: corruption process; only queried for ``ModelTarget.score_times_std``
    model_out: model output
    batch_idx: indicates which sample each row of model_out belongs to (None for dense data)
    t: diffusion timestep
    batch: noisy batch, ignored except by strange SDEs

    Raises:
        NotImplementedError: if ``model_target`` is not a ``ModelTarget`` value.
    """
    # `sde` is annotated with a string because SDE is defined in a later cell;
    # a bare name would raise NameError when this cell runs first.
    try:
        # ModelTarget(member) returns the member itself; ModelTarget("logits") looks it up.
        target = ModelTarget(model_target)
    except ValueError as err:
        raise NotImplementedError(f"Unsupported model_target: {model_target!r}") from err

    if target == ModelTarget.logits:
        # Not really a score, but logits will be handled downstream.
        return model_out

    # score_times_std: the model output is the score multiplied by the marginal std,
    # so divide by that std to recover the score.
    # Note the slack tolerances in test_model_utils.py: the choice of ModelTarget does make a difference.
    _, std = sde.marginal_prob(
        x=torch.ones_like(model_out),
        t=t,
        batch_idx=batch_idx,
        batch=batch,
    )
    return model_out / std

In [None]:
def score_fn(x: T, t: torch.Tensor) -> T:
    """Calculate the score of a batch of data at a given timestep.

    Args:
        x: batch of (noisy) data
        t: timestep for each sample

    Returns:
        score: a copy of the model output in which every SDE-corrupted field
        (``corruption.sdes``) is replaced by its score.
        Note: for ``ModelTarget.score_times_std`` the model is trained to output
        score * std; ``convert_model_out_to_score`` divides by std to undo this.
    """
    model_out: T = model(x, t)
    # One conversion function per SDE-corrupted field.
    fns = {k: convert_model_out_to_score for k in corruption.sdes.keys()}

    scores = apply(
        fns=fns,
        model_out=model_out,
        # model_target is the same for every field, so it goes in `broadcast`.
        # As a plain keyword, `apply` would treat it as a {field_name: value} mapping
        # and test `field_name in model_target`.
        broadcast=dict(t=t, batch=x, model_target=model_target),
        sde=corruption.sdes,
        batch_idx=corruption._get_batch_indices(x),
    )

    return model_out.replace(**scores)

In [None]:
import numpy as np
import abc

class Corruption(abc.ABC):
    """Base interface shared by every corruption (noising) process."""

    @property
    @abc.abstractmethod
    def T(self) -> float:
        """Time at which the corruption process ends."""
        ...

    @abc.abstractmethod
    def marginal_prob(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return the parameters of the marginal distribution $p_t(x)$."""
        ...  # mean: (num_nodes, num_features), std (num_nodes,)

    @abc.abstractmethod
    def prior_sampling(
        self,
        shape: Union[torch.Size, Tuple],
        conditioning_data: Optional[BatchedData] = None,
        batch_idx: B = None,  # Usually unused; needed for special cases such as sample-wise zero-centering.
    ) -> torch.Tensor:
        """Draw one sample of the given shape from the prior $p_T(x)$."""
        ...

    @abc.abstractmethod
    def prior_logp(
        self,
        z: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> torch.Tensor:
        """Return the log-density of ``z`` under the prior.

        Used, for example, when computing log-likelihoods with the probability-flow ODE.

        Args:
          z: latent code
        Returns:
          log probability density
        """
        ...  # prior_logp: (batch_size,)

    @abc.abstractmethod
    def sample_marginal(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> torch.Tensor:
        """Draw x(t) conditioned on x(0) = ``x``.

        Returns:
          sampled x(t), with the same shape as ``x``.
        """
        ...

class SDE(Corruption):
    """Corruption process defined by a stochastic differential equation."""

    @abc.abstractmethod
    def sde(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return drift f and diffusion g with dx = f * dt + g * sqrt(dt) * standard Gaussian."""
        ...  # drift: (nodes_per_sample * batch_size, num_features), diffusion (batch_size,)

    @abc.abstractmethod
    def marginal_prob(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return mean and standard deviation of the marginal $p_t(x)$."""
        ...  # mean: (nodes_per_sample * batch_size, num_features), std: (nodes_per_sample * batch_size, 1)

    def mean_coeff_and_std(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return (mean coefficient, std) of the marginal at time ``t``.

        The coefficient is the marginal mean evaluated at an all-ones input.
        """
        unit_input = torch.ones_like(x)
        # mean_coeff: same shape as x, std: same shape as x
        return self.marginal_prob(unit_input, t, batch_idx, batch)

    def sample_marginal(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> torch.Tensor:
        """Draw x(t) given x(0) = ``x`` as mean + std * standard Gaussian noise.

        Returns:
          sampled x(t)
        """
        mean, std = self.marginal_prob(x=x, t=t, batch_idx=batch_idx, batch=batch)
        noise = torch.randn_like(x)
        perturbed = mean + std * noise
        return perturbed

class VESDE(SDE):
    """Variance Exploding SDE: zero drift, noise std growing geometrically in t."""

    def __init__(self, sigma_min: float = 0.01, sigma_max: float = 50.0):
        """Construct a Variance Exploding SDE.

        The marginal standard deviation grows exponentially from sigma_min (t=0)
        to sigma_max (t=T=1).

        Args:
          sigma_min: smallest sigma.
          sigma_max: largest sigma.
        """
        super().__init__()
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    @property
    def T(self) -> float:
        return 1.0

    def sde(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Zero drift; diffusion sigma(t) * sqrt(2 * log(sigma_max / sigma_min)) per row."""
        sigma = self.sigma_min * (self.sigma_max / self.sigma_min) ** t
        drift = torch.zeros_like(x)
        diffusion = maybe_expand(
            sigma
            * torch.sqrt(
                torch.tensor(2 * (np.log(self.sigma_max) - np.log(self.sigma_min)), device=t.device)
            ),
            batch_idx,
            x,
        )
        return drift, diffusion

    def marginal_prob(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Mean is ``x`` itself; std is sigma_min * (sigma_max / sigma_min) ** t."""
        std = maybe_expand(self.sigma_min * (self.sigma_max / self.sigma_min) ** t, batch_idx, x)
        mean = x
        return mean, std

    def prior_sampling(
        self,
        shape: Union[torch.Size, Tuple],
        conditioning_data: Optional[BatchedData] = None,
        batch_idx: B = None,
    ) -> torch.Tensor:
        """Sample N(0, sigma_max^2) of the given shape (on the default device)."""
        return torch.randn(*shape) * self.sigma_max

    def prior_logp(
        self,
        z: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> torch.Tensor:
        """Log-density of ``z`` under the prior N(0, sigma_max^2 I).

        Args:
          z: latent code, shape (num_rows, ...).
          batch_idx: optional sample index of each row of ``z``. When given, the
            log-densities of all rows belonging to the same sample are summed.
        Returns:
          log probability density, shape (batch_size,)
        """
        shape = z.shape
        N = np.prod(shape[1:])  # scalar coordinates per row
        # Gaussian normalizer for one row.
        log_norm = -N / 2.0 * np.log(2 * np.pi * self.sigma_max**2)
        if batch_idx is None:
            return log_norm - torch.sum(z**2, dim=tuple(range(1, z.ndim))) / (2 * self.sigma_max**2)
        # Ragged batch: a sample with n rows has n * N coordinates, so its normalizer
        # is n * log_norm, not log_norm. Per-sample sums use native torch ops, so this
        # does not depend on torch_scatter.
        sq_norm = z.reshape(z.shape[0], -1).pow(2).sum(dim=1)
        num_samples = int(batch_idx.max()) + 1 if batch_idx.numel() > 0 else 0
        per_sample_sq = torch.zeros(
            num_samples, dtype=sq_norm.dtype, device=sq_norm.device
        ).index_add_(0, batch_idx, sq_norm)
        rows_per_sample = torch.bincount(batch_idx, minlength=num_samples).to(sq_norm.dtype)
        return rows_per_sample * log_norm - per_sample_sq / (2 * self.sigma_max**2)

class WrappedSDEMixin:
    """Mixin that wraps an SDE's samples into [0, wrapping_boundary) via ``torch.remainder``.

    List it before the concrete SDE in the bases (e.g. ``(WrappedSDEMixin, VESDE)``)
    so that ``super()`` resolves to that SDE. The subclass must set
    ``self.wrapping_boundary``.
    """

    def sample_marginal(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: Optional[BatchedData] = None,
    ) -> torch.Tensor:
        """Sample x(t) with the underlying SDE, then wrap it into the periodic domain."""
        _super = super()
        assert (
            isinstance(self, SDE)
            and hasattr(_super, "sample_marginal")
            and hasattr(self, "wrapping_boundary")
        )
        if (x > self.wrapping_boundary).any() or (x < 0).any():
            # Values outside the wrapping boundary are valid in principle, but could point to an issue in the data preprocessing,
            # as typically we assume that the input data is inside the wrapping boundary (e.g., angles between 0 and 2*pi).
            warnings.warn(
                "Warning: Wrapped SDE has received input outside of the wrapping boundary.",
                stacklevel=2,
            )
        noisy_x = _super.sample_marginal(x=x, t=t, batch_idx=batch_idx, batch=batch)
        return self.wrap(noisy_x)

    def prior_sampling(
        self,
        shape: Union[torch.Size, Tuple],
        conditioning_data: Optional[BatchedData] = None,
        batch_idx: B = None,
    ) -> torch.Tensor:
        """Sample the underlying prior and wrap it into the periodic domain."""
        _super = super()
        assert isinstance(self, SDE) and hasattr(_super, "prior_sampling")
        # Forward batch_idx: the Corruption interface uses it for special cases such as
        # sample-wise zero-centering.
        return self.wrap(
            _super.prior_sampling(
                shape=shape, conditioning_data=conditioning_data, batch_idx=batch_idx
            )
        )

    def wrap(self, x):
        """Apply the module-level ``wrap`` with this SDE's ``wrapping_boundary``."""
        assert isinstance(self, SDE) and hasattr(self, "wrapping_boundary")
        return wrap(x, self.wrapping_boundary)

class WrappedVESDE(WrappedSDEMixin, VESDE):
    """VESDE whose samples are wrapped into the periodic domain [0, wrapping_boundary)."""

    def __init__(
        self,
        wrapping_boundary: float = 1.0,
        sigma_min: float = 0.01,
        sigma_max: float = 50.0,
    ):
        """
        Args:
          wrapping_boundary: period of the wrapped domain.
          sigma_min: forwarded to VESDE.
          sigma_max: forwarded to VESDE.
        """
        vesde_kwargs = dict(sigma_min=sigma_min, sigma_max=sigma_max)
        super().__init__(**vesde_kwargs)
        self.wrapping_boundary = wrapping_boundary

class NumAtomsVarianceAdjustedWrappedVESDE(WrappedVESDE):
    """Wrapped VESDE with variance adjusted by number of atoms. We divide the standard deviation by the cubic root of the number of atoms.
    The goal is to reduce the influence by the cell size on the variance of the fractional coordinates.
    """

    def __init__(
        self,
        wrapping_boundary: float | torch.Tensor = 1.0,
        sigma_min: float = 0.01,
        sigma_max: float = 5.0,
        limit_info_key: str = "num_atoms",
    ):
        super().__init__(
            sigma_min=sigma_min, sigma_max=sigma_max, wrapping_boundary=wrapping_boundary
        )
        # Field of the batch that holds the atom count of each sample.
        self.limit_info_key = limit_info_key

    def std_scaling(self, batch: BatchedData) -> torch.Tensor:
        """Per-sample factor ``batch[limit_info_key] ** (-1/3)`` applied to the std."""
        return batch[self.limit_info_key] ** (-1 / 3)

    def marginal_prob(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: BatchedData | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """VESDE marginal with the std multiplied by ``std_scaling(batch)``; ``batch`` is required."""
        mean, std = super().marginal_prob(x, t, batch_idx, batch)
        assert (
            batch is not None
        ), "batch must be provided when using NumAtomsVarianceAdjustedWrappedVESDEMixin"
        std_scale = self.std_scaling(batch)
        std = std * maybe_expand(std_scale, batch_idx, like=std)
        return mean, std

    def prior_sampling(
        self,
        shape: torch.Size | tuple,
        conditioning_data: BatchedData | None = None,
        batch_idx=None,
    ) -> torch.Tensor:
        """Sample N(0, (sigma_max * num_atoms ** (-1/3))^2) per row, then wrap.

        ``conditioning_data[limit_info_key]`` must hold each sample's atom count.
        The ``batch_idx`` argument is ignored; it is rebuilt from those counts.
        """
        _super = super()
        assert isinstance(self, SDE) and hasattr(_super, "prior_sampling")
        assert (
            conditioning_data is not None
        ), "batch must be provided when using NumAtomsVarianceAdjustedWrappedVESDEMixin"
        num_atoms = conditioning_data[self.limit_info_key]
        batch_idx = torch.repeat_interleave(
            torch.arange(num_atoms.shape[0], device=num_atoms.device), num_atoms, dim=0
        )
        std_scale = self.std_scaling(conditioning_data)
        # prior sample is randn() * sigma_max, so we need additionally multiply by std_scale to get the correct variance.
        # We call VESDE.prior_sampling (a "grandparent" function) because the super() prior_sampling already does the wrapping,
        # which means we couldn't do the variance adjustment here anymore otherwise.
        prior_sample = VESDE.prior_sampling(self, shape=shape).to(num_atoms.device)
        return self.wrap(prior_sample * maybe_expand(std_scale, batch_idx, like=prior_sample))

    def sde(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
        batch: BatchedData | None = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Zero drift; diffusion uses the atom-scaled sigmas from ``marginal_prob``."""
        sigma = self.marginal_prob(x, t, batch_idx, batch)[1]
        sigma_min = self.marginal_prob(x, torch.zeros_like(t), batch_idx, batch)[1]
        sigma_max = self.marginal_prob(x, torch.ones_like(t), batch_idx, batch)[1]
        drift = torch.zeros_like(x)
        diffusion = sigma * torch.sqrt(2 * (sigma_max.log() - sigma_min.log()))
        return drift, diffusion

In [None]:
class AncestralSamplingPredictor:
    r"""Suitable for all linear SDEs.

    This predictor is derived by converting the score prediction to a prediction of x_0 given x_t, and then
    sampling from the conditional distribution of x_{t-dt} given x_0 and x_t according to the corruption process.
    It corresponds to equation (47) in Song et al. for VESDE (https://openreview.net/forum?id=PxTIG12RRHS)
    and equation (7) in Ho et al. for VPSDE (https://arxiv.org/abs/2006.11239)

    In more detail: suppose the SDE has marginals x_t ~ N(alpha_t * x_0, sigma_t**2)

    We estimate x_0 as follows:
    x_0 \approx (x_t + sigma_t^2 * score) / alpha_t

    For any s < t, the forward corruption process implies that
    x_t | x_s ~ N(alpha_t/alpha_s * x_s, sigma_t^2 - sigma_s^2 * alpha_t^2 / alpha_s^2)

    Working out the mean and variance of x_s given x_t and x_0 from these two facts
    gives the coefficients computed in `_get_coeffs` and applied in `update_given_score`.

    Attributes:
        corruption: SDE whose ``mean_coeff_and_std`` defines the update (read by ``_get_coeffs``).
        score_fn: called by ``update_fn``; not needed when calling ``update_given_score`` directly.
    """

    def __init__(
        self,
        corruption: Optional[Corruption] = None,
        score_fn: Optional[Callable[..., torch.Tensor]] = None,
    ):
        # Both default to None so an instance can be created first and configured later.
        self.corruption = corruption
        self.score_fn = score_fn

    def update_fn(
        self,
        *,
        x: torch.Tensor,
        t: torch.Tensor,
        dt: torch.Tensor,
        batch_idx: torch.LongTensor,
        batch: BatchedData | None,
    ) -> SampleAndMean:
        """One update of the predictor.

        Args:
          x: current state
          t: timesteps
          dt: (negative) time step
          batch_idx: indicates which sample each row of x belongs to
          batch: noisy batch, forwarded to the corruption

        Returns:
           (sampled next state, mean next state)
        """
        assert self.score_fn is not None
        score = self.score_fn(x=x, t=t, batch_idx=batch_idx)
        return self.update_given_score(
            x=x, t=t, dt=dt, batch_idx=batch_idx, score=score, batch=batch
        )

    def update_given_score(
        self,
        *,
        x: torch.Tensor,
        t: torch.Tensor,
        dt: torch.Tensor,
        batch_idx: torch.LongTensor,
        score: torch.Tensor,
        batch: BatchedData | None,
    ) -> SampleAndMean:
        """One ancestral step from t to t + dt using a precomputed ``score``.

        Returns:
           (sampled next state, mean next state)
        """
        x_coeff, score_coeff, std = self._get_coeffs(
            x=x,
            t=t,
            dt=dt,
            batch_idx=batch_idx,
            batch=batch,
        )
        # Sample random noise.
        z = torch.randn_like(x_coeff)

        mean = x_coeff * x + score_coeff * score
        sample = mean + std * z

        return sample, mean

    def _get_coeffs(self, x, t, dt, batch_idx, batch):
        """
        Compute coefficients for ancestral sampling.
        This is in a separate method to make it easier to test.

        Returns:
            (x_coeff, score_coeff, std); std is zero for samples whose previous time s <= 0.
        """
        sde = self.corruption
        assert isinstance(sde, SDE)

        # Previous timestep
        s = t + dt

        alpha_t, sigma_t = sde.mean_coeff_and_std(x=x, t=t, batch_idx=batch_idx, batch=batch)
        if batch_idx is None:
            is_time_zero = s <= 0
        else:
            is_time_zero = s[batch_idx] <= 0
        alpha_s, sigma_s = sde.mean_coeff_and_std(x=x, t=s, batch_idx=batch_idx, batch=batch)
        sigma_s[is_time_zero] = 0

        # If you are trying to match this up with algebra in papers, it may help to
        # notice that for VPSDE, sigma2_t_given_s == 1 - alpha_t_given_s**2, except
        # that alpha_t_given_s**2 is clipped.
        # Mathematically non-negative; clamp so floating-point round-off cannot make sqrt NaN.
        sigma2_t_given_s = (sigma_t**2 - sigma_s**2 * alpha_t**2 / alpha_s**2).clamp(min=0)
        sigma_t_given_s = torch.sqrt(sigma2_t_given_s)
        std = sigma_t_given_s * sigma_s / sigma_t

        # Clip alpha_t_given_s so that we do not divide by zero.
        min_alpha_t_given_s = 0.001
        alpha_t_given_s = alpha_t / alpha_s
        if torch.any(alpha_t_given_s < min_alpha_t_given_s):
            # If this warning is raised, you probably should change something: either modify your noise schedule
            # so that the diffusion coefficient does not blow up near sde.T, or only denoise from sde.T - eps,
            # rather than sde.T.
            warnings.warn(
                f"Clipping alpha_t_given_s to {min_alpha_t_given_s} to avoid dividing by zero; "
                "consider changing the noise schedule or denoising from sde.T - eps.",
                stacklevel=2,
            )
            alpha_t_given_s = torch.clip(alpha_t_given_s, min_alpha_t_given_s, 1)

        score_coeff = sigma2_t_given_s / alpha_t_given_s

        x_coeff = 1.0 / alpha_t_given_s

        std[is_time_zero] = 0

        return x_coeff, score_coeff, std
    
# Module-level ancestral-sampling predictor, created without arguments.
# update_fn reads self.score_fn and _get_coeffs reads self.corruption, so both must be
# assigned before this instance is used.
# NOTE(review): the name looks like a typo of "ASPredictor"; it is kept because it is
# referenced further down in this notebook.
ASPridictor = AncestralSamplingPredictor()

In [None]:
from torch_scatter import scatter_add

class ScoreFunction(Protocol):
    """Call signature expected of score functions used by predictors and correctors."""

    def __call__(
        self,
        x: torch.Tensor,
        t: torch.Tensor,
        batch_idx: B = None,
    ) -> torch.Tensor:
        """Return the score at ``x``.

        Args:
            x: Points where the score is evaluated. Shape [num_nodes, ...]
            t: Timestep of each sample. Shape [num_samples,]
            batch_idx: Sample index of each row of x. Shape [num_nodes,]
        """
        ...


class LangevinCorrector():
    def __init__(
        self,
        corruption: Corruption,
        score_fn: ScoreFunction | None,
        n_steps: int,
        snr: float = 0.2,
        max_step_size: float = 1.0,
    ):
        """The Langevin corrector.

        Args:
            corruption: corruption process
            score_fn: score function (may be None when only ``step_given_score`` is used)
            n_steps: number of Langevin steps at each noise level
            snr: signal-to-noise ratio
            max_step_size: largest coefficient that the score can be multiplied by for each Langevin step.
        """
        self.corruption = corruption
        self.score_fn = score_fn
        self.n_steps = n_steps
        self.snr = snr
        self.max_step_size = torch.tensor(max_step_size)

    @classmethod
    def is_compatible(cls, corruption: Corruption):
        """Whether this (unwrapped) corrector can be used with ``corruption``.

        Accepts VESDEs and, if a ``BaseVPSDE`` class exists in this module's namespace,
        variance-preserving SDEs. Wrapped SDEs are rejected.
        """
        # BaseVPSDE is not defined in this notebook; referencing it directly would raise
        # NameError, so look it up lazily. LangevinCorrector has no base class that
        # provides is_compatible, so there is no super() call here.
        compatible_types = tuple(
            kind for kind in (VESDE, globals().get("BaseVPSDE")) if kind is not None
        )
        return isinstance(corruption, compatible_types) and not isinstance(
            corruption, WrappedSDEMixin
        )

    def update_fn(self, *, x, t, batch_idx, dt: torch.Tensor) -> SampleAndMean:
        """Run ``n_steps`` Langevin steps using ``self.score_fn``."""
        assert self.score_fn is not None, "Did you mean to use step_given_score?"
        for _ in range(self.n_steps):
            score = self.score_fn(x, t, batch_idx)
            x, x_mean = self.step_given_score(x=x, batch_idx=batch_idx, score=score, t=t, dt=dt)

        return x, x_mean

    def get_alpha(self, t: torch.FloatTensor, dt: torch.FloatTensor) -> torch.Tensor:
        """Step-size factor: alpha_bar(t) / alpha_bar(t + dt) for a VPSDE, else ones."""
        sde = self.corruption
        # VPSDE is not defined in this notebook (see is_compatible); look it up lazily.
        vpsde_cls = globals().get("VPSDE")
        if vpsde_cls is not None and isinstance(sde, vpsde_cls):
            alpha_bar = sde._marginal_mean_coeff(t) ** 2
            alpha_bar_before = sde._marginal_mean_coeff(t + dt) ** 2
            alpha = alpha_bar / alpha_bar_before
        else:
            alpha = torch.ones_like(t)
        return alpha

    def step_given_score(
        self, *, x, batch_idx: torch.LongTensor | None, score, t: torch.Tensor, dt: torch.Tensor
    ) -> SampleAndMean:
        """One Langevin step with step size set from the score/noise norm ratio and ``snr``.

        Returns:
            (x + step_size * score + sqrt(2 * step_size) * noise, x + step_size * score)
        """
        alpha = self.get_alpha(t, dt=dt)
        snr = self.snr
        noise = torch.randn_like(score)
        grad_norm_square = torch.square(score).reshape(score.shape[0], -1).sum(dim=1)
        noise_norm_square = torch.square(noise).reshape(noise.shape[0], -1).sum(dim=1)
        if batch_idx is None:
            grad_norm = grad_norm_square.sqrt().mean()
            noise_norm = noise_norm_square.sqrt().mean()
        else:
            grad_norm = torch.sqrt(scatter_add(grad_norm_square, dim=-1, index=batch_idx)).mean()

            noise_norm = torch.sqrt(scatter_add(noise_norm_square, dim=-1, index=batch_idx)).mean()

        # If gradient is zero (i.e., we are sampling from an improper distribution that's flat over the whole of R^n)
        # the step_size blows up. Clip step_size to avoid this.
        # The EGNN reports zero scores when there are no edges between nodes.
        step_size = (snr * noise_norm / grad_norm) ** 2 * 2 * alpha
        step_size = torch.minimum(step_size, self.max_step_size)
        step_size[grad_norm == 0, :] = self.max_step_size

        # Expand step size to batch structure (score and noise have the same shape).
        step_size = maybe_expand(step_size, batch_idx, score)

        # Perform the Langevin update.
        mean = x + step_size * score
        x = mean + torch.sqrt(step_size * 2) * noise

        return x, mean

# Plain (unwrapped) Langevin corrector. LangevinCorrector.__init__ requires corruption,
# score_fn and n_steps. score_fn is None because `denoise` computes scores itself and
# calls step_given_score directly.
# NOTE(review): a default VESDE is assumed as the corruption; replace it with the SDE
# of the field this corrector is meant for.
LCorrector = LangevinCorrector(corruption=VESDE(), score_fn=None, n_steps=1)

In [None]:
class WrappedCorrectorMixin:
    """A mixin for wrapping the corrector in a WrappedSDE."""

    def step_given_score(
        self,
        *,
        x: torch.Tensor,
        batch_idx: torch.LongTensor,
        score: torch.Tensor,
        t: torch.Tensor,
        dt: torch.Tensor,
    ) -> SampleAndMean:
        """Run the parent corrector step, then wrap sample and mean into the SDE's domain."""
        # mypy
        assert isinstance(self, LangevinCorrector)
        _super = super()
        assert hasattr(_super, "step_given_score")
        assert hasattr(self, "corruption") and hasattr(self.corruption, "wrap")
        sample, mean = _super.step_given_score(x=x, score=score, t=t, batch_idx=batch_idx, dt=dt)
        # The corruption's own wrap() knows its wrapping_boundary; the module-level
        # wrap() would need the boundary as a second argument.
        return self.corruption.wrap(sample), self.corruption.wrap(mean)


class WrappedPredictorMixin:
    """Mixin that wraps a predictor's (sample, mean) output with ``self.corruption.wrap``."""

    def update_given_score(
        self,
        *,
        x: torch.Tensor,
        t: torch.Tensor,
        dt: torch.Tensor,
        batch_idx: torch.LongTensor,
        score: torch.Tensor,
        batch: Optional[BatchedData],
    ) -> SampleAndMean:
        parent = super()
        sample, mean = parent.update_given_score(
            x=x, t=t, dt=dt, batch_idx=batch_idx, score=score, batch=batch
        )
        wrap_fn = self.corruption.wrap
        return wrap_fn(sample), wrap_fn(mean)


class WrappedAncestralSamplingPredictor(
    WrappedPredictorMixin, AncestralSamplingPredictor
):
    """Ancestral-sampling predictor whose outputs are wrapped into the SDE's periodic domain."""

    @classmethod
    def is_compatible(cls, corruption: Corruption):
        """True only for VESDEs that also mix in ``WrappedSDEMixin``."""
        required_types = (VESDE, WrappedSDEMixin)
        return all(isinstance(corruption, kind) for kind in required_types)


class WrappedLangevinCorrector(WrappedCorrectorMixin, LangevinCorrector):
    """Langevin corrector whose outputs are wrapped into the SDE's periodic domain."""

    @classmethod
    def is_compatible(cls, corruption: Corruption):
        """True only for VESDEs that also mix in ``WrappedSDEMixin``."""
        required_types = (VESDE, WrappedSDEMixin)
        return all(isinstance(corruption, kind) for kind in required_types)
    
# Wrapped corrector/predictor for periodic fields. The Wrapped* classes get their wrapping
# from mixin inheritance and are configured through their own attributes (corruption,
# score_fn); they do not take another corrector/predictor instance to wrap.
# score_fn is None because `denoise` computes scores itself and calls
# step_given_score / update_given_score directly.
# NOTE(review): the SDE below uses NumAtomsVarianceAdjustedWrappedVESDE's defaults;
# make sure it matches the corruption the model was trained with.
_wrapped_periodic_sde = NumAtomsVarianceAdjustedWrappedVESDE()
WLCorrector = WrappedLangevinCorrector(
    corruption=_wrapped_periodic_sde, score_fn=None, n_steps=1
)
# Configured via attributes so this works whether or not the predictor class defines __init__.
WASPredictor = WrappedAncestralSamplingPredictor()
WASPredictor.corruption = _wrapped_periodic_sde
WASPredictor.score_fn = None

In [None]:
def mask_both(
    *, sample_and_mean: Tuple[torch.Tensor, torch.Tensor], old_x: torch.Tensor, mask: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    return tuple(_mask(old_x=old_x, new_x=x, mask=mask) for x in sample_and_mean)  # type: ignore


def mask_replace(
    samples_means: dict[str, Tuple[torch.Tensor, torch.Tensor]],
    batch: BatchedData,
    mean_batch: BatchedData,
    mask: dict[str, torch.Tensor | None],
) -> SampleAndMean:
    """Mask per-field (sample, mean) updates and write them back into the batches.

    Args:
        samples_means: ``{field_name: (new_sample, new_mean)}``.
        batch: current batch; supplies the old values used where a field is masked.
        mean_batch: batch of means that receives the new means.
        mask: ``{field_name: mask or None}``. Fields missing from it (or a None
            mapping) are treated as unmasked.

    Returns:
        (updated batch, updated mean_batch)
    """
    fields = list(samples_means)
    mask = mask or {}
    # Apply masks
    samples_means = apply(
        fns={k: mask_both for k in fields},
        broadcast={},
        sample_and_mean=samples_means,
        # mask_both requires `mask`, so give every field an explicit entry (None = unmasked).
        mask={k: mask.get(k) for k in fields},
        # BatchedData only guarantees __getitem__, so hand `apply` a plain dict instead of
        # relying on `field in batch` working on the batch object.
        old_x={k: batch[k] for k in fields},
    )

    # Put the updated values in `batch` and `mean_batch`
    batch = batch.replace(**{k: v[0] for k, v in samples_means.items()})
    mean_batch = mean_batch.replace(**{k: v[1] for k, v in samples_means.items()})
    return batch, mean_batch

def _get_batch_indices(self, batch: Diffusable) -> Dict[str, torch.Tensor]:
    """Return ``{field_name: batch.get_batch_idx(field_name)}`` for ``self.corrupted_fields``.

    This is a free function that takes ``self`` explicitly, so ``self`` must be an
    object with a ``corrupted_fields`` attribute. Values may be None for dense fields.
    """
    indices = {}
    for field_name in self.corrupted_fields:
        indices[field_name] = batch.get_batch_idx(field_name)
    return indices


@torch.no_grad()
def denoise(
    batch: Diffusable,
    record: bool = False,
    max_t=1000,
    eps_t=1,
    device=torch.device("cpu"),
    N=50,
    n_steps_corrector=1,
    mask: Optional[Dict[str, Optional[torch.Tensor]]] = None,
    fields: Optional[list[str]] = None,
) -> SampleAndMeanAndMaybeRecords:
    """Denoise from a prior sample to a t=eps_t sample with a predictor-corrector loop.

    Each of the N steps runs ``n_steps_corrector`` corrector updates (``WLCorrector``)
    followed by one predictor update (``WASPredictor``). The results are written back
    into ``batch`` / ``mean_batch`` via ``mask_replace``.

    Args:
        batch: prior sample to denoise.
        record: if True, collect a CPU clone of ``batch`` taken before every update.
        max_t: first (largest) timestep.
        eps_t: last (smallest) timestep.
        device: device for the timestep tensors.
        N: number of predictor steps, with timesteps evenly spaced from max_t to eps_t.
        n_steps_corrector: corrector updates before each predictor update.
        mask: optional ``{field_name: mask}`` forwarded to ``mask_replace``; fields not
            listed are unmasked.
        fields: fields to update. Defaults to every field of ``corruption.sdes`` whose
            SDE passes ``WrappedAncestralSamplingPredictor.is_compatible``. Other fields
            keep their input values.

    Returns:
        ``(batch, mean_batch, records)``, where ``records`` is None unless ``record`` is True.

    Raises:
        ValueError: if ``N < 2`` (the step size would divide by zero).
    """
    if N < 2:
        raise ValueError(f"N must be at least 2, got {N}")
    # NOTE(review): the SDEs above have T = 1.0 and exponentiate t, so the defaults
    # max_t=1000 / eps_t=1 look like discrete-step units; confirm they match what
    # `model` and the SDEs expect.
    if fields is None:
        fields = [
            name
            for name, sde in corruption.sdes.items()
            if WrappedAncestralSamplingPredictor.is_compatible(sde)
        ]
    # Every updated field gets an explicit mask entry (None = unmasked).
    masks = {name: None for name in fields}
    if mask is not None:
        masks.update(mask)

    mean_batch = batch.clone()
    recorded_samples = [] if record else None

    # Decreasing timesteps from max_t to eps_t; dt is negative.
    timesteps = torch.linspace(max_t, eps_t, N, device=device)
    dt = -torch.tensor((max_t - eps_t) / (N - 1)).to(device)

    corrector_fns = {name: WLCorrector.step_given_score for name in fields}
    predictor_fns = {name: WASPredictor.update_given_score for name in fields}

    def _per_field(data) -> Dict[str, torch.Tensor]:
        # BatchedData only guarantees __getitem__, so hand `apply` plain dicts.
        return {name: data[name] for name in fields}

    for i in tqdm.tqdm(range(N), miniters=50, mininterval=5):
        # Set the timestep
        t = torch.full((batch.get_batch_size(),), timesteps[i], device=device)

        # Corrector updates.
        for _ in range(n_steps_corrector):
            score = score_fn(batch, t)
            samples_means = apply(
                fns=corrector_fns,
                broadcast=dict(t=t, dt=dt),
                x=_per_field(batch),
                score=_per_field(score),
                batch_idx=corruption._get_batch_indices(batch),
            )
            if record:
                recorded_samples.append(batch.clone().to("cpu"))
            batch, mean_batch = mask_replace(
                samples_means=samples_means, batch=batch, mean_batch=mean_batch, mask=masks
            )

        # Predictor update.
        score = score_fn(batch, t)
        samples_means = apply(
            fns=predictor_fns,
            broadcast=dict(t=t, dt=dt, batch=batch),
            x=_per_field(batch),
            score=_per_field(score),
            batch_idx=corruption._get_batch_indices(batch),
        )
        if record:
            recorded_samples.append(batch.clone().to("cpu"))
        batch, mean_batch = mask_replace(
            samples_means=samples_means, batch=batch, mean_batch=mean_batch, mask=masks
        )

    return batch, mean_batch, recorded_samples