
Allowing other timestep_spacing heuristics for repaint scheduler. Leading in general is a problematic default. #12588

@MarkTension

Description


Is your feature request related to a problem? Please describe.
I don't like that the RePaint scheduler always uses the "leading" timestep_spacing, with no way to change it. "leading" is a problematic default for the average user: when the number of training timesteps isn't divisible by the number of inference steps, it produces counter-intuitive results. For example, with 500 training timesteps and 300 inference steps, the step ratio is 500 // 300 = 1, so the timesteps run from 0 to 299 and sampling starts at timestep 299 instead of near 499. This is rarely desirable.
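To make the arithmetic concrete, here is a minimal pure-Python sketch of the three spacing rules used in the set_timesteps override proposed in this issue. It assumes steps_offset is 0, and the helper name spaced_timesteps is hypothetical (not a diffusers API). It prints the first and last timestep of each ascending grid for 500 training timesteps and 300 inference steps.

```python
def spaced_timesteps(num_train_timesteps: int, num_inference_steps: int, spacing: str) -> list[int]:
    """Ascending inference timesteps for one spacing rule (hypothetical helper, steps_offset assumed 0)."""
    if spacing == "linspace":
        if num_inference_steps == 1:
            return [0]
        # Evenly spaced over [0, T - 1], both endpoints included.
        return [
            round(i * (num_train_timesteps - 1) / (num_inference_steps - 1))
            for i in range(num_inference_steps)
        ]
    if spacing == "leading":
        # Integer step ratio counted up from 0: the grid can stop well short of T - 1.
        step_ratio = num_train_timesteps // num_inference_steps
        return [i * step_ratio for i in range(num_inference_steps)]
    if spacing == "trailing":
        # Float step ratio counted down from T: the grid always ends at T - 1.
        step_ratio = num_train_timesteps / num_inference_steps
        descending = [round(num_train_timesteps - i * step_ratio) - 1 for i in range(num_inference_steps)]
        return descending[::-1]
    raise ValueError(f"unknown timestep_spacing: {spacing!r}")


for spacing in ("leading", "linspace", "trailing"):
    ts = spaced_timesteps(500, 300, spacing)
    print(spacing, ts[0], ts[-1])
# leading 0 299
# linspace 0 499
# trailing 1 499
```

Only "leading" caps the highest timestep at 299; the other two rules reach 499.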

Describe the solution you'd like.
An option to set timestep_spacing ("linspace", "leading" or "trailing"), as other schedulers already allow.

Describe alternatives you've considered.
Currently I work around this by subclassing RePaintScheduler, overriding the set_timesteps method, and defaulting to "trailing":

import numpy as np
import torch
from diffusers import RePaintScheduler


class RePaintSchedulerTSSpacing(RePaintScheduler):
    def set_timesteps(
        self,
        num_inference_steps: int,
        jump_length: int = 10,
        jump_n_sample: int = 10,
        device: str | torch.device | None = None,
    ):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            jump_length (`int`, defaults to 10):
                The number of steps taken forward in time before going backward in time for a single jump.
            jump_n_sample (`int`, defaults to 10):
                The number of times to make a forward time jump for a given chosen time sample.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved.
        """
        num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps)
        self.num_inference_steps = num_inference_steps

        # 1. Generate the correctly spaced base timesteps (ascending)
        # RePaintScheduler's config may not define timestep_spacing, so fall back to "trailing"
        timestep_spacing = getattr(self.config, "timestep_spacing", "trailing")
        if timestep_spacing == "linspace":
            base_timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps)
            base_timesteps = np.round(base_timesteps).astype(np.int64)
        elif timestep_spacing == "leading":
            step_ratio = self.config.num_train_timesteps // num_inference_steps
            base_timesteps = (np.arange(0, num_inference_steps) * step_ratio).round().astype(np.int64)
            steps_offset = getattr(self.config, "steps_offset", 0)
            base_timesteps += steps_offset
        elif timestep_spacing == "trailing":
            step_ratio = self.config.num_train_timesteps / num_inference_steps
            base_timesteps = np.round(np.arange(self.config.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            base_timesteps -= 1
            base_timesteps = base_timesteps[::-1].copy()  # Ensure ascending order for index mapping
        else:
            raise ValueError(
                f"timestep_spacing must be one of ['linspace', 'trailing', 'leading'], got {timestep_spacing}"
            )

        # 2. Generate the jumping sequence of indices (0 to num_inference_steps-1)
        # This is the original RePaint logic, but on indices instead of final timesteps
        jump_indices = []
        jumps = {}
        for j in range(0, num_inference_steps - jump_length, jump_length):
            jumps[j] = jump_n_sample - 1

        t_idx = num_inference_steps
        while t_idx >= 1:
            t_idx = t_idx - 1
            jump_indices.append(t_idx)

            if jumps.get(t_idx, 0) > 0:
                jumps[t_idx] = jumps[t_idx] - 1
                for _ in range(jump_length):
                    t_idx = t_idx + 1
                    jump_indices.append(t_idx)

        # 3. Map the jumping indices to the correctly spaced base_timesteps
        timesteps = base_timesteps[jump_indices]
        self.timesteps = torch.from_numpy(timesteps).to(device)

I think this should be the fix. I can open a PR if this approach is acceptable.
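For checking the index mapping outside of diffusers, the jump logic in step 2 can be reproduced in plain Python. This is a sketch: repaint_jump_indices, leading_grid and trailing_grid are hypothetical helper names, and leading_grid assumes the current scheduler's behaviour of scaling indices by num_train_timesteps // num_inference_steps with no offset. With 500 training timesteps and 300 inference steps, the first step lands on timestep 299 under "leading" and 499 under "trailing".

```python
def repaint_jump_indices(num_inference_steps: int, jump_length: int, jump_n_sample: int) -> list[int]:
    """RePaint resampling schedule expressed as indices into an ascending base grid."""
    # Every jump_length-th index below num_inference_steps - jump_length gets jump_n_sample - 1 jumps.
    jumps = {j: jump_n_sample - 1 for j in range(0, num_inference_steps - jump_length, jump_length)}
    indices: list[int] = []
    t_idx = num_inference_steps
    while t_idx >= 1:
        t_idx -= 1
        indices.append(t_idx)  # one step down the grid (denoise)
        if jumps.get(t_idx, 0) > 0:
            jumps[t_idx] -= 1
            for _ in range(jump_length):  # jump back up the grid (re-noise)
                t_idx += 1
                indices.append(t_idx)
    return indices


def leading_grid(num_train_timesteps: int, num_inference_steps: int) -> list[int]:
    step_ratio = num_train_timesteps // num_inference_steps
    return [i * step_ratio for i in range(num_inference_steps)]


def trailing_grid(num_train_timesteps: int, num_inference_steps: int) -> list[int]:
    step_ratio = num_train_timesteps / num_inference_steps
    return [round(num_train_timesteps - i * step_ratio) - 1 for i in range(num_inference_steps)][::-1]


idx = repaint_jump_indices(300, jump_length=10, jump_n_sample=10)
lead_base = leading_grid(500, 300)
trail_base = trailing_grid(500, 300)
leading = [lead_base[i] for i in idx]
trailing = [trail_base[i] for i in idx]
print(leading[0], trailing[0])  # 299 499
```

Because the jump pattern is computed on indices, only the base grid changes between spacings; the number and order of forward and backward jumps stay the same.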

Additional context.
Visual comparison.

