From 12372cc5a7201420b277bec3c6a5cc9d350e6e11 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Fri, 6 Oct 2023 15:03:24 -0700 Subject: [PATCH] replace references to deprecated KeyArray & PRNGKeyArray --- setup.py | 4 ++-- src/diffusers/dependency_versions_table.py | 4 ++-- src/diffusers/models/controlnet_flax.py | 2 +- src/diffusers/models/modeling_flax_utils.py | 2 +- src/diffusers/models/unet_2d_condition_flax.py | 2 +- src/diffusers/models/vae_flax.py | 2 +- .../pipelines/controlnet/pipeline_flax_controlnet.py | 6 +++--- .../stable_diffusion/pipeline_flax_stable_diffusion.py | 4 ++-- .../pipeline_flax_stable_diffusion_img2img.py | 6 +++--- .../pipeline_flax_stable_diffusion_inpaint.py | 4 ++-- .../pipelines/stable_diffusion/safety_checker_flax.py | 2 +- .../pipeline_flax_stable_diffusion_xl.py | 4 ++-- src/diffusers/schedulers/scheduling_ddpm_flax.py | 4 ++-- src/diffusers/schedulers/scheduling_karras_ve_flax.py | 3 ++- src/diffusers/schedulers/scheduling_sde_ve_flax.py | 5 +++-- 15 files changed, 28 insertions(+), 26 deletions(-) diff --git a/setup.py b/setup.py index a2201ac5b3b1..6cb5eac17b37 100644 --- a/setup.py +++ b/setup.py @@ -102,8 +102,8 @@ "importlib_metadata", "invisible-watermark>=0.2.0", "isort>=5.5.4", - "jax>=0.2.8,!=0.3.2", - "jaxlib>=0.1.65", + "jax>=0.4.1", + "jaxlib>=0.4.1", "Jinja2", "k-diffusion>=0.0.12", "torchsde", diff --git a/src/diffusers/dependency_versions_table.py b/src/diffusers/dependency_versions_table.py index d4b94ba6d4ed..970013c31a20 100644 --- a/src/diffusers/dependency_versions_table.py +++ b/src/diffusers/dependency_versions_table.py @@ -15,8 +15,8 @@ "importlib_metadata": "importlib_metadata", "invisible-watermark": "invisible-watermark>=0.2.0", "isort": "isort>=5.5.4", - "jax": "jax>=0.2.8,!=0.3.2", - "jaxlib": "jaxlib>=0.1.65", + "jax": "jax>=0.4.1", + "jaxlib": "jaxlib>=0.4.1", "Jinja2": "Jinja2", "k-diffusion": "k-diffusion>=0.0.12", "torchsde": "torchsde", diff --git a/src/diffusers/models/controlnet_flax.py b/src/diffusers/models/controlnet_flax.py index a826df48e41a..076e6183211b 100644 --- a/src/diffusers/models/controlnet_flax.py +++ b/src/diffusers/models/controlnet_flax.py @@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin): controlnet_conditioning_channel_order: str = "rgb" conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256) - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/modeling_flax_utils.py b/src/diffusers/models/modeling_flax_utils.py index 97f7b43bc64e..ea4d1bfea548 100644 --- a/src/diffusers/models/modeling_flax_utils.py +++ b/src/diffusers/models/modeling_flax_utils.py @@ -192,7 +192,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None): ```""" return self._cast_floating_to(params, jnp.float16, mask) - def init_weights(self, rng: jax.random.KeyArray) -> Dict: + def init_weights(self, rng: jax.Array) -> Dict: raise NotImplementedError(f"init_weights method has to be implemented for {self}") @classmethod diff --git a/src/diffusers/models/unet_2d_condition_flax.py b/src/diffusers/models/unet_2d_condition_flax.py index a3aebde7bf16..77ff08e40a37 100644 --- a/src/diffusers/models/unet_2d_condition_flax.py +++ b/src/diffusers/models/unet_2d_condition_flax.py @@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin): addition_embed_type_num_heads: int = 64 projection_class_embeddings_input_dim: Optional[int] = None - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/models/vae_flax.py b/src/diffusers/models/vae_flax.py index b8f5b1d0e399..d2dde2ba197b 100644 --- a/src/diffusers/models/vae_flax.py +++ b/src/diffusers/models/vae_flax.py @@ -817,7 +817,7 @@ def setup(self): dtype=self.dtype, ) - def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict: + def init_weights(self, rng: jax.Array) -> FrozenDict: # init input tensors sample_shape = (1, self.in_channels, self.sample_size, self.sample_size) sample = jnp.zeros(sample_shape, dtype=jnp.float32) diff --git a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py index b2c8871aa0d6..b57e776e49eb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_flax_controlnet.py @@ -241,7 +241,7 @@ def _generate( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, guidance_scale: float, latents: Optional[jnp.array] = None, @@ -351,7 +351,7 @@ def __call__( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, guidance_scale: Union[float, jnp.array] = 7.5, latents: jnp.array = None, @@ -370,7 +370,7 @@ def __call__( Array representing the ControlNet input condition to provide guidance to the `unet` for generation. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. - prng_seed (`jax.random.KeyArray` or `jax.Array`): + prng_seed (`jax.Array` or `jax.Array`): Array containing random number generator key. num_inference_steps (`int`, *optional*, defaults to 50): The number of denoising steps. More denoising steps usually lead to a higher quality image at the diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py index 131a7c7bc2bd..a847cd15c6ce 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py @@ -215,7 +215,7 @@ def _generate( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, @@ -312,7 +312,7 @@ def __call__( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py index a9717533fa93..42a79db6b2b2 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py @@ -235,7 +235,7 @@ def _generate( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, start_timestep: int, num_inference_steps: int, height: int, @@ -340,7 +340,7 @@ def __call__( prompt_ids: jnp.array, image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, strength: float = 0.8, num_inference_steps: int = 50, height: Optional[int] = None, @@ -361,7 +361,7 @@ def __call__( Array representing an image batch to be used as the starting point. params (`Dict` or `FrozenDict`): Dictionary containing the model parameters/weights. - prng_seed (`jax.random.KeyArray` or `jax.Array`): + prng_seed (`jax.Array` or `jax.Array`): Array containing random number generator key. strength (`float`, *optional*, defaults to 0.8): Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py index b43fa3837062..153267da1067 100644 --- a/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +++ b/src/diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py @@ -270,7 +270,7 @@ def _generate( mask: jnp.array, masked_image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, @@ -398,7 +398,7 @@ def __call__( mask: jnp.array, masked_image: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, height: Optional[int] = None, width: Optional[int] = None, diff --git a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py index 3a8c31679540..5966600462bf 100644 --- a/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py +++ b/src/diffusers/pipelines/stable_diffusion/safety_checker_flax.py @@ -87,7 +87,7 @@ def __init__( module = self.module_class(config=config, dtype=dtype, **kwargs) super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init) - def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: + def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict: # init input tensor clip_input = jax.random.normal(rng, input_shape) diff --git a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py index 3acb5ae538a4..8f043c7c6657 100644 --- a/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +++ b/src/diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py @@ -89,7 +89,7 @@ def __call__( self, prompt_ids: jax.Array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int = 50, guidance_scale: Union[float, jax.Array] = 7.5, height: Optional[int] = None, @@ -170,7 +170,7 @@ def _generate( self, prompt_ids: jnp.array, params: Union[Dict, FrozenDict], - prng_seed: jax.random.KeyArray, + prng_seed: jax.Array, num_inference_steps: int, height: int, width: int, diff --git a/src/diffusers/schedulers/scheduling_ddpm_flax.py b/src/diffusers/schedulers/scheduling_ddpm_flax.py index 529d2bd03a75..ab7d70f466e6 100644 --- a/src/diffusers/schedulers/scheduling_ddpm_flax.py +++ b/src/diffusers/schedulers/scheduling_ddpm_flax.py @@ -198,7 +198,7 @@ def step( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: Optional[jax.random.KeyArray] = None, + key: Optional[jax.Array] = None, return_dict: bool = True, ) -> Union[FlaxDDPMSchedulerOutput, Tuple]: """ @@ -211,7 +211,7 @@ def step( timestep (`int`): current discrete timestep in the diffusion chain. sample (`jnp.ndarray`): current instance of sample being created by diffusion process. - key (`jax.random.KeyArray`): a PRNG key. + key (`jax.Array`): a PRNG key. return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class Returns: diff --git a/src/diffusers/schedulers/scheduling_karras_ve_flax.py b/src/diffusers/schedulers/scheduling_karras_ve_flax.py index 45c0dbddf7ef..4a8606007d5f 100644 --- a/src/diffusers/schedulers/scheduling_karras_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_karras_ve_flax.py @@ -17,6 +17,7 @@ from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp from jax import random @@ -139,7 +140,7 @@ def add_noise_to_input( state: KarrasVeSchedulerState, sample: jnp.ndarray, sigma: float, - key: random.KeyArray, + key: jax.Array, ) -> Tuple[jnp.ndarray, float]: """ Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a diff --git a/src/diffusers/schedulers/scheduling_sde_ve_flax.py b/src/diffusers/schedulers/scheduling_sde_ve_flax.py index b6240559fc88..935f972a9bdb 100644 --- a/src/diffusers/schedulers/scheduling_sde_ve_flax.py +++ b/src/diffusers/schedulers/scheduling_sde_ve_flax.py @@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union import flax +import jax import jax.numpy as jnp from jax import random @@ -169,7 +170,7 @@ def step_pred( model_output: jnp.ndarray, timestep: int, sample: jnp.ndarray, - key: random.KeyArray, + key: jax.Array, return_dict: bool = True, ) -> Union[FlaxSdeVeOutput, Tuple]: """ @@ -228,7 +229,7 @@ def step_correct( state: ScoreSdeVeSchedulerState, model_output: jnp.ndarray, sample: jnp.ndarray, - key: random.KeyArray, + key: jax.Array, return_dict: bool = True, ) -> Union[FlaxSdeVeOutput, Tuple]: """