Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,8 @@
"importlib_metadata",
"invisible-watermark>=0.2.0",
"isort>=5.5.4",
"jax>=0.2.8,!=0.3.2",
"jaxlib>=0.1.65",
"jax>=0.4.1",
"jaxlib>=0.4.1",
"Jinja2",
"k-diffusion>=0.0.12",
"torchsde",
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/dependency_versions_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
"importlib_metadata": "importlib_metadata",
"invisible-watermark": "invisible-watermark>=0.2.0",
"isort": "isort>=5.5.4",
"jax": "jax>=0.2.8,!=0.3.2",
"jaxlib": "jaxlib>=0.1.65",
"jax": "jax>=0.4.1",
"jaxlib": "jaxlib>=0.4.1",
"Jinja2": "Jinja2",
"k-diffusion": "k-diffusion>=0.0.12",
"torchsde": "torchsde",
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/controlnet_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ class FlaxControlNetModel(nn.Module, FlaxModelMixin, ConfigMixin):
controlnet_conditioning_channel_order: str = "rgb"
conditioning_embedding_out_channels: Tuple[int] = (16, 32, 96, 256)

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/modeling_flax_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def to_fp16(self, params: Union[Dict, FrozenDict], mask: Any = None):
```"""
return self._cast_floating_to(params, jnp.float16, mask)

def init_weights(self, rng: jax.random.KeyArray) -> Dict:
def init_weights(self, rng: jax.Array) -> Dict:
raise NotImplementedError(f"init_weights method has to be implemented for {self}")

@classmethod
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/unet_2d_condition_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
addition_embed_type_num_heads: int = 64
projection_class_embeddings_input_dim: Optional[int] = None

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
2 changes: 1 addition & 1 deletion src/diffusers/models/vae_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,7 +817,7 @@ def setup(self):
dtype=self.dtype,
)

def init_weights(self, rng: jax.random.KeyArray) -> FrozenDict:
def init_weights(self, rng: jax.Array) -> FrozenDict:
# init input tensors
sample_shape = (1, self.in_channels, self.sample_size, self.sample_size)
sample = jnp.zeros(sample_shape, dtype=jnp.float32)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -241,7 +241,7 @@ def _generate(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
guidance_scale: float,
latents: Optional[jnp.array] = None,
Expand Down Expand Up @@ -351,7 +351,7 @@ def __call__(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jnp.array] = 7.5,
latents: jnp.array = None,
Expand All @@ -370,7 +370,7 @@ def __call__(
Array representing the ControlNet input condition to provide guidance to the `unet` for generation.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.random.KeyArray` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
Expand Down Expand Up @@ -312,7 +312,7 @@ def __call__(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ def _generate(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
start_timestep: int,
num_inference_steps: int,
height: int,
Expand Down Expand Up @@ -340,7 +340,7 @@ def __call__(
prompt_ids: jnp.array,
image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
strength: float = 0.8,
num_inference_steps: int = 50,
height: Optional[int] = None,
Expand All @@ -361,7 +361,7 @@ def __call__(
Array representing an image batch to be used as the starting point.
params (`Dict` or `FrozenDict`):
Dictionary containing the model parameters/weights.
prng_seed (`jax.random.KeyArray` or `jax.Array`):
prng_seed (`jax.Array` or `jax.Array`):
Array containing random number generator key.
strength (`float`, *optional*, defaults to 0.8):
Indicates extent to transform the reference `image`. Must be between 0 and 1. `image` is used as a
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ def _generate(
mask: jnp.array,
masked_image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
Expand Down Expand Up @@ -398,7 +398,7 @@ def __call__(
mask: jnp.array,
masked_image: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
height: Optional[int] = None,
width: Optional[int] = None,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(
module = self.module_class(config=config, dtype=dtype, **kwargs)
super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype, _do_init=_do_init)

def init_weights(self, rng: jax.random.KeyArray, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
def init_weights(self, rng: jax.Array, input_shape: Tuple, params: FrozenDict = None) -> FrozenDict:
# init input tensor
clip_input = jax.random.normal(rng, input_shape)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __call__(
self,
prompt_ids: jax.Array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int = 50,
guidance_scale: Union[float, jax.Array] = 7.5,
height: Optional[int] = None,
Expand Down Expand Up @@ -170,7 +170,7 @@ def _generate(
self,
prompt_ids: jnp.array,
params: Union[Dict, FrozenDict],
prng_seed: jax.random.KeyArray,
prng_seed: jax.Array,
num_inference_steps: int,
height: int,
width: int,
Expand Down
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_ddpm_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def step(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: Optional[jax.random.KeyArray] = None,
key: Optional[jax.Array] = None,
return_dict: bool = True,
) -> Union[FlaxDDPMSchedulerOutput, Tuple]:
"""
Expand All @@ -211,7 +211,7 @@ def step(
timestep (`int`): current discrete timestep in the diffusion chain.
sample (`jnp.ndarray`):
current instance of sample being created by diffusion process.
key (`jax.random.KeyArray`): a PRNG key.
key (`jax.Array`): a PRNG key.
return_dict (`bool`): option for returning tuple rather than FlaxDDPMSchedulerOutput class

Returns:
Expand Down
3 changes: 2 additions & 1 deletion src/diffusers/schedulers/scheduling_karras_ve_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from typing import Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
from jax import random

Expand Down Expand Up @@ -139,7 +140,7 @@ def add_noise_to_input(
state: KarrasVeSchedulerState,
sample: jnp.ndarray,
sigma: float,
key: random.KeyArray,
key: jax.Array,
) -> Tuple[jnp.ndarray, float]:
"""
Explicit Langevin-like "churn" step of adding noise to the sample according to a factor gamma_i ≥ 0 to reach a
Expand Down
5 changes: 3 additions & 2 deletions src/diffusers/schedulers/scheduling_sde_ve_flax.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Optional, Tuple, Union

import flax
import jax
import jax.numpy as jnp
from jax import random

Expand Down Expand Up @@ -169,7 +170,7 @@ def step_pred(
model_output: jnp.ndarray,
timestep: int,
sample: jnp.ndarray,
key: random.KeyArray,
key: jax.Array,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Expand Down Expand Up @@ -228,7 +229,7 @@ def step_correct(
state: ScoreSdeVeSchedulerState,
model_output: jnp.ndarray,
sample: jnp.ndarray,
key: random.KeyArray,
key: jax.Array,
return_dict: bool = True,
) -> Union[FlaxSdeVeOutput, Tuple]:
"""
Expand Down