From 0a516324dae7c58b888c8fc2f530e50b81b2b491 Mon Sep 17 00:00:00 2001 From: Pedro Cuenca Date: Mon, 19 Sep 2022 17:29:53 +0000 Subject: [PATCH] PNDM: replace control flow with jax functions. Otherwise jitting/parallelization don't work properly as they don't know how to deal with traced objects. I temporarily removed `step_prk`. --- .../schedulers/scheduling_pndm_flax.py | 247 +++++++++++------- 1 file changed, 158 insertions(+), 89 deletions(-) diff --git a/src/diffusers/schedulers/scheduling_pndm_flax.py b/src/diffusers/schedulers/scheduling_pndm_flax.py index efc3858ca75a..7f6e8c6b38aa 100644 --- a/src/diffusers/schedulers/scheduling_pndm_flax.py +++ b/src/diffusers/schedulers/scheduling_pndm_flax.py @@ -19,6 +19,7 @@ from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp from ..configuration_utils import ConfigMixin, register_to_config @@ -150,7 +151,12 @@ def __init__( self.state = PNDMSchedulerState.create(num_train_timesteps=num_train_timesteps) - def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> PNDMSchedulerState: + def set_timesteps( + self, + state: PNDMSchedulerState, + shape: Tuple, + num_inference_steps: int + ) -> PNDMSchedulerState: """ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. @@ -191,8 +197,11 @@ def set_timesteps(self, state: PNDMSchedulerState, num_inference_steps: int) -> return state.replace( timesteps=jnp.concatenate([state.prk_timesteps, state.plms_timesteps]).astype(jnp.int64), - ets=jnp.array([]), counter=0, + # Will be zeros, not really empty + cur_model_output = jnp.empty(shape), + cur_sample = jnp.empty(shape), + ets = jnp.empty((4,) + shape), ) def step( @@ -222,73 +231,77 @@ def step( When returning a tuple, the first element is the sample tensor. """ - if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: - return self.step_prk( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict - ) - else: - return self.step_plms( - state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict - ) - - def step_prk( - self, - state: PNDMSchedulerState, - model_output: jnp.ndarray, - timestep: int, - sample: jnp.ndarray, - return_dict: bool = True, - ) -> Union[FlaxSchedulerOutput, Tuple]: - """ - Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the - solution to the differential equation. - - Args: - state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. - model_output (`jnp.ndarray`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`jnp.ndarray`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. - When returning a tuple, the first element is the sample tensor. - - """ - if state.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 - prev_timestep = timestep - diff_to_prev - timestep = state.prk_timesteps[state.counter // 4 * 4] - - if state.counter % 4 == 0: - state = state.replace( - cur_model_output=state.cur_model_output + 1 / 6 * model_output, - ets=state.ets.append(model_output), - cur_sample=sample, - ) - elif (self.counter - 1) % 4 == 0: - state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - elif (self.counter - 2) % 4 == 0: - state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) - elif (self.counter - 3) % 4 == 0: - model_output = state.cur_model_output + 1 / 6 * model_output - state = state.replace(cur_model_output=0) - - # cur_sample should not be `None` - cur_sample = state.cur_sample if state.cur_sample is not None else sample - - prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) - state = state.replace(counter=state.counter + 1) - - if not return_dict: - return (prev_sample, state) + return self.step_plms( + state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + ) - return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) + # if state.counter < len(state.prk_timesteps) and not self.config.skip_prk_steps: + # return self.step_prk( + # state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + # ) + # else: + # return self.step_plms( + # state=state, model_output=model_output, timestep=timestep, sample=sample, return_dict=return_dict + # ) + + # def step_prk( + # self, + # state: PNDMSchedulerState, + # model_output: jnp.ndarray, + # timestep: int, + # sample: jnp.ndarray, + # return_dict: bool = True, + # ) -> Union[FlaxSchedulerOutput, Tuple]: + # """ + # Step function propagating the sample with the Runge-Kutta method. RK takes 4 forward passes to approximate the + # solution to the differential equation. + + # Args: + # state (`PNDMSchedulerState`): the `FlaxPNDMScheduler` state data class instance. + # model_output (`jnp.ndarray`): direct output from learned diffusion model. + # timestep (`int`): current discrete timestep in the diffusion chain. + # sample (`jnp.ndarray`): + # current instance of sample being created by diffusion process. + # return_dict (`bool`): option for returning tuple rather than SchedulerOutput class + + # Returns: + # [`FlaxSchedulerOutput`] or `tuple`: [`FlaxSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. + # When returning a tuple, the first element is the sample tensor. + + # """ + # if state.num_inference_steps is None: + # raise ValueError( + # "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + # ) + + # diff_to_prev = 0 if state.counter % 2 else self.config.num_train_timesteps // state.num_inference_steps // 2 + # prev_timestep = timestep - diff_to_prev + # timestep = state.prk_timesteps[state.counter // 4 * 4] + + # if state.counter % 4 == 0: + # state = state.replace( + # cur_model_output=state.cur_model_output + 1 / 6 * model_output, + # ets=state.ets.append(model_output), + # cur_sample=sample, + # ) + # elif (self.counter - 1) % 4 == 0: + # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + # elif (self.counter - 2) % 4 == 0: + # state = state.replace(cur_model_output=state.cur_model_output + 1 / 3 * model_output) + # elif (self.counter - 3) % 4 == 0: + # model_output = state.cur_model_output + 1 / 6 * model_output + # state = state.replace(cur_model_output=0) + + # # cur_sample should not be `None` + # cur_sample = state.cur_sample if state.cur_sample is not None else sample + + # prev_sample = self._get_prev_sample(cur_sample, timestep, prev_timestep, model_output) + # state = state.replace(counter=state.counter + 1) + + # if not return_dict: + # return (prev_sample, state) + + # return FlaxSchedulerOutput(prev_sample=prev_sample, state=state) def step_plms( self, @@ -329,29 +342,85 @@ def step_plms( ) prev_timestep = timestep - self.config.num_train_timesteps // state.num_inference_steps + prev_timestep = jnp.where(prev_timestep > 0, prev_timestep, 0) + + # Reference: + # if state.counter != 1: + # state.ets.append(model_output) + # else: + # prev_timestep = timestep + # timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps + + prev_timestep = jnp.where(state.counter == 1, timestep, prev_timestep) + timestep = jnp.where(state.counter == 1, timestep + self.config.num_train_timesteps // state.num_inference_steps, timestep) + + # Reference: + # if len(state.ets) == 1 and state.counter == 0: + # model_output = model_output + # state.cur_sample = sample + # elif len(state.ets) == 1 and state.counter == 1: + # model_output = (model_output + state.ets[-1]) / 2 + # sample = state.cur_sample + # state.cur_sample = None + # elif len(state.ets) == 2: + # model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 + # elif len(state.ets) == 3: + # model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 + # else: + # model_output = (1 / 24) * (55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4]) + + def counter_0(state: PNDMSchedulerState): + ets = state.ets.at[0].set(model_output) + return state.replace( + ets = ets, + cur_sample = sample, + cur_model_output = jnp.array(model_output, dtype=jnp.float32), + ) - if state.counter != 1: - state = state.replace(ets=state.ets.append(model_output)) - else: - prev_timestep = timestep - timestep = timestep + self.config.num_train_timesteps // state.num_inference_steps - - if len(state.ets) == 1 and state.counter == 0: - model_output = model_output - state = state.replace(cur_sample=sample) - elif len(state.ets) == 1 and state.counter == 1: - model_output = (model_output + state.ets[-1]) / 2 - sample = state.cur_sample - state = state.replace(cur_sample=None) - elif len(state.ets) == 2: - model_output = (3 * state.ets[-1] - state.ets[-2]) / 2 - elif len(state.ets) == 3: - model_output = (23 * state.ets[-1] - 16 * state.ets[-2] + 5 * state.ets[-3]) / 12 - else: - model_output = (1 / 24) * ( - 55 * state.ets[-1] - 59 * state.ets[-2] + 37 * state.ets[-3] - 9 * state.ets[-4] + def counter_1(state: PNDMSchedulerState): + return state.replace( + cur_model_output = (model_output + state.ets[0]) / 2, ) + def counter_2(state: PNDMSchedulerState): + ets = state.ets.at[1].set(model_output) + return state.replace( + ets = ets, + cur_model_output = (3 * ets[1] - ets[0]) / 2, + cur_sample = sample, + ) + + def counter_3(state: PNDMSchedulerState): + ets = state.ets.at[2].set(model_output) + return state.replace( + ets = ets, + cur_model_output = (23 * ets[2] - 16 * ets[1] + 5 * ets[0]) / 12, + cur_sample = sample, + ) + + def counter_other(state: PNDMSchedulerState): + ets = state.ets.at[3].set(model_output) + next_model_output = (1 / 24) * (55 * ets[3] - 59 * ets[2] + 37 * ets[1] - 9 * ets[0]) + + ets = ets.at[0].set(ets[1]) + ets = ets.at[1].set(ets[2]) + ets = ets.at[2].set(ets[3]) + + return state.replace( + ets = ets, + cur_model_output = next_model_output, + cur_sample = sample, + ) + + counter = jnp.clip(state.counter, 0, 4) + state = jax.lax.switch( + counter, + [counter_0, counter_1, counter_2, counter_3, counter_other], + state, + ) + + sample = state.cur_sample + model_output = state.cur_model_output prev_sample = self._get_prev_sample(sample, timestep, prev_timestep, model_output) state = state.replace(counter=state.counter + 1) @@ -374,7 +443,7 @@ def _get_prev_sample(self, sample, timestep, prev_timestep, model_output): # model_output -> e_θ(x_t, t) # prev_sample -> x_(t−δ) alpha_prod_t = self.alphas_cumprod[timestep] - alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + alpha_prod_t_prev = jnp.where(prev_timestep >= 0, self.alphas_cumprod[prev_timestep], self.final_alpha_cumprod) beta_prod_t = 1 - alpha_prod_t beta_prod_t_prev = 1 - alpha_prod_t_prev