diff --git a/keras_hub/src/models/image_to_image.py b/keras_hub/src/models/image_to_image.py index d3194a5815..5684fc64d6 100644 --- a/keras_hub/src/models/image_to_image.py +++ b/keras_hub/src/models/image_to_image.py @@ -284,8 +284,8 @@ def generate( self, inputs, num_steps, - guidance_scale, strength, + guidance_scale=None, seed=None, ): """Generate image based on the provided `inputs`. @@ -313,30 +313,36 @@ def generate( - A `tf.data.Dataset` with `"images"`, `"prompts"` and/or `"negative_prompts"` keys. num_steps: int. The number of diffusion steps to take. - guidance_scale: float. The classifier free guidance scale defined in - [Classifier-Free Diffusion Guidance]( - https://arxiv.org/abs/2207.12598). A higher scale encourages - generating images more closely related to the prompts, typically - at the cost of lower image quality. strength: float. Indicates the extent to which the reference `images` are transformed. Must be between `0.0` and `1.0`. When `strength=1.0`, `images` is essentially ignore and added noise is maximum and the denoising process runs for the full number of iterations specified in `num_steps`. + guidance_scale: Optional float. The classifier free guidance scale + defined in [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. Note that some models don't + utilize classifier-free guidance. seed: optional int. Used as a random seed. """ num_steps = int(num_steps) - guidance_scale = float(guidance_scale) strength = float(strength) + guidance_scale = ( + float(guidance_scale) if guidance_scale is not None else None + ) if strength < 0.0 or strength > 1.0: raise ValueError( "`strength` must be between `0.0` and `1.0`. " f"Received strength={strength}." ) + if guidance_scale is not None and guidance_scale > 1.0: + guidance_scale = ops.convert_to_tensor(float(guidance_scale)) + else: + guidance_scale = None starting_step = int(num_steps * (1.0 - strength)) starting_step = ops.convert_to_tensor(starting_step, "int32") - num_steps = ops.convert_to_tensor(num_steps, "int32") - guidance_scale = ops.convert_to_tensor(guidance_scale) + num_steps = ops.convert_to_tensor(int(num_steps), "int32") # Check `inputs` format. required_keys = ["images", "prompts"] diff --git a/keras_hub/src/models/inpaint.py b/keras_hub/src/models/inpaint.py index 40bcc7ad15..f0ccb3a5b2 100644 --- a/keras_hub/src/models/inpaint.py +++ b/keras_hub/src/models/inpaint.py @@ -376,8 +376,8 @@ def generate( self, inputs, num_steps, - guidance_scale, strength, + guidance_scale=None, seed=None, ): """Generate image based on the provided `inputs`. @@ -406,26 +406,33 @@ def generate( - A `tf.data.Dataset` with `"images"`, `"masks"`, `"prompts"` and/or `"negative_prompts"` keys. num_steps: int. The number of diffusion steps to take. - guidance_scale: float. The classifier free guidance scale defined in - [Classifier-Free Diffusion Guidance]( - https://arxiv.org/abs/2207.12598). A higher scale encourages - generating images more closely related to the prompts, typically - at the cost of lower image quality. strength: float. Indicates the extent to which the reference `images` are transformed. Must be between `0.0` and `1.0`. When `strength=1.0`, `images` is essentially ignore and added noise is maximum and the denoising process runs for the full number of iterations specified in `num_steps`. + guidance_scale: Optional float. 
The classifier free guidance scale + defined in [Classifier-Free Diffusion Guidance]( + https://arxiv.org/abs/2207.12598). A higher scale encourages + generating images more closely related to the prompts, typically + at the cost of lower image quality. Note that some models don't + utilize classifier-free guidance. seed: optional int. Used as a random seed. """ num_steps = int(num_steps) - guidance_scale = float(guidance_scale) strength = float(strength) + guidance_scale = ( + float(guidance_scale) if guidance_scale is not None else None + ) if strength < 0.0 or strength > 1.0: raise ValueError( "`strength` must be between `0.0` and `1.0`. " f"Received strength={strength}." ) + if guidance_scale is not None and guidance_scale > 1.0: + guidance_scale = ops.convert_to_tensor(guidance_scale) + else: + guidance_scale = None starting_step = int(num_steps * (1.0 - strength)) starting_step = ops.convert_to_tensor(starting_step, "int32") num_steps = ops.convert_to_tensor(num_steps, "int32") diff --git a/keras_hub/src/models/stable_diffusion_3/mmdit.py b/keras_hub/src/models/stable_diffusion_3/mmdit.py index 546d56f13a..083e4a359a 100644 --- a/keras_hub/src/models/stable_diffusion_3/mmdit.py +++ b/keras_hub/src/models/stable_diffusion_3/mmdit.py @@ -354,6 +354,8 @@ class DismantledBlock(layers.Layer): mlp_ratio: float. The expansion ratio of `MLP`. use_projection: bool. Whether to use an attention projection layer at the end of the block. + qk_norm: Optional str. Whether to normalize the query and key tensors. + Available options are `None` and `"rms_norm"`. Defaults to `None`. **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `dtype` etc. """ @@ -364,6 +366,7 @@ def __init__( hidden_dim, mlp_ratio=4.0, use_projection=True, + qk_norm=None, **kwargs, ): super().__init__(**kwargs) @@ -371,6 +374,7 @@ def __init__( self.hidden_dim = hidden_dim self.mlp_ratio = mlp_ratio self.use_projection = use_projection + self.qk_norm = qk_norm head_dim = hidden_dim // num_heads self.head_dim = head_dim @@ -391,6 +395,18 @@ def __init__( self.attention_qkv = layers.Dense( hidden_dim * 3, dtype=self.dtype_policy, name="attention_qkv" ) + if qk_norm is not None and qk_norm == "rms_norm": + self.q_norm = layers.LayerNormalization( + epsilon=1e-6, rms_scaling=True, dtype="float32", name="q_norm" + ) + self.k_norm = layers.LayerNormalization( + epsilon=1e-6, rms_scaling=True, dtype="float32", name="k_norm" + ) + elif qk_norm is not None: + raise NotImplementedError( + "Supported `qk_norm` are `'rms_norm'` and `None`. " + f"Received: qk_norm={qk_norm}." 
+ ) if use_projection: self.attention_proj = layers.Dense( hidden_dim, dtype=self.dtype_policy, name="attention_proj" @@ -413,6 +429,10 @@ def __init__( def build(self, inputs_shape, timestep_embedding): self.ada_layer_norm.build(inputs_shape, timestep_embedding) self.attention_qkv.build(inputs_shape) + if self.qk_norm is not None: + # [batch_size, sequence_length, num_heads, head_dim] + self.q_norm.build([None, None, self.num_heads, self.head_dim]) + self.k_norm.build([None, None, self.num_heads, self.head_dim]) if self.use_projection: self.attention_proj.build(inputs_shape) self.norm2.build(inputs_shape) @@ -435,6 +455,9 @@ def _compute_pre_attention(self, inputs, timestep_embedding, training=None): qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) q, k, v = ops.unstack(qkv, 3, axis=2) + if self.qk_norm is not None: + q = self.q_norm(q, training=training) + k = self.k_norm(k, training=training) return (q, k, v), (inputs, gate_msa, shift_mlp, scale_mlp, gate_mlp) else: x = self.ada_layer_norm( @@ -445,6 +468,9 @@ def _compute_pre_attention(self, inputs, timestep_embedding, training=None): qkv, (batch_size, -1, 3, self.num_heads, self.head_dim) ) q, k, v = ops.unstack(qkv, 3, axis=2) + if self.qk_norm is not None: + q = self.q_norm(q, training=training) + k = self.k_norm(k, training=training) return (q, k, v) def _compute_post_attention( @@ -494,6 +520,7 @@ def get_config(self): "hidden_dim": self.hidden_dim, "mlp_ratio": self.mlp_ratio, "use_projection": self.use_projection, + "qk_norm": self.qk_norm, } ) return config @@ -513,6 +540,8 @@ class MMDiTBlock(layers.Layer): mlp_ratio: float. The expansion ratio of `MLP`. use_context_projection: bool. Whether to use an attention projection layer at the end of the context block. + qk_norm: Optional str. Whether to normalize the query and key tensors. + Available options are `None` and `"rms_norm"`. Defaults to `None`. **kwargs: other keyword arguments passed to `keras.layers.Layer`, including `name`, `dtype` etc. @@ -527,6 +556,7 @@ def __init__( hidden_dim, mlp_ratio=4.0, use_context_projection=True, + qk_norm=None, **kwargs, ): super().__init__(**kwargs) @@ -534,6 +564,7 @@ def __init__( self.hidden_dim = hidden_dim self.mlp_ratio = mlp_ratio self.use_context_projection = use_context_projection + self.qk_norm = qk_norm head_dim = hidden_dim // num_heads self.head_dim = head_dim @@ -544,6 +575,7 @@ def __init__( hidden_dim=hidden_dim, mlp_ratio=mlp_ratio, use_projection=True, + qk_norm=qk_norm, dtype=self.dtype_policy, name="x_block", ) @@ -552,6 +584,7 @@ def __init__( hidden_dim=hidden_dim, mlp_ratio=mlp_ratio, use_projection=use_context_projection, + qk_norm=qk_norm, dtype=self.dtype_policy, name="context_block", ) @@ -629,6 +662,7 @@ def get_config(self): "hidden_dim": self.hidden_dim, "mlp_ratio": self.mlp_ratio, "use_context_projection": self.use_context_projection, + "qk_norm": self.qk_norm, } ) return config @@ -705,6 +739,9 @@ class MMDiT(Backbone): latent_shape: tuple. The shape of the latent image. context_shape: tuple. The shape of the context. pooled_projection_shape: tuple. The shape of the pooled projection. + qk_norm: Optional str. Whether to normalize the query and key tensors in + the intermediate blocks. Available options are `None` and + `"rms_norm"`. Defaults to `None`. data_format: `None` or str. If specified, either `"channels_last"` or `"channels_first"`. The ordering of the dimensions in the inputs. 
`"channels_last"` corresponds to inputs with shape @@ -729,6 +766,7 @@ def __init__( latent_shape=(64, 64, 16), context_shape=(None, 4096), pooled_projection_shape=(2048,), + qk_norm=None, data_format=None, dtype=None, **kwargs, @@ -782,6 +820,7 @@ def __init__( hidden_dim, mlp_ratio, use_context_projection=not (i == num_layers - 1), + qk_norm=qk_norm, dtype=dtype, name=f"joint_block_{i}", ) @@ -851,6 +890,7 @@ def __init__( self.latent_shape = latent_shape self.context_shape = context_shape self.pooled_projection_shape = pooled_projection_shape + self.qk_norm = qk_norm def get_config(self): config = super().get_config() @@ -865,6 +905,7 @@ def get_config(self): "latent_shape": self.latent_shape, "context_shape": self.context_shape, "pooled_projection_shape": self.pooled_projection_shape, + "qk_norm": self.qk_norm, } ) return config diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py index 4dd3e4403d..65ba12a549 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone.py @@ -202,6 +202,10 @@ class StableDiffusion3Backbone(Backbone): transformer in MMDiT. mmdit_position_size: int. The size of the height and width for the position embedding in MMDiT. + mmdit_qk_norm: Optional str. Whether to normalize the query and key + tensors for each transformer in MMDiT. Available options are `None` + and `"rms_norm"`. Typically, this is set to `None` for 3.0 version + and to `"rms_norm" for 3.5 version. vae: The VAE used for transformations between pixel space and latent space. clip_l: The CLIP text encoder for encoding the inputs. @@ -248,6 +252,7 @@ class StableDiffusion3Backbone(Backbone): mmdit_hidden_dim=256, mmdit_depth=4, mmdit_position_size=192, + mmdit_qk_norm=None, vae=vae, clip_l=clip_l, clip_g=clip_g, @@ -262,6 +267,7 @@ def __init__( mmdit_num_layers, mmdit_num_heads, mmdit_position_size, + mmdit_qk_norm, vae, clip_l, clip_g, @@ -312,6 +318,7 @@ def __init__( latent_shape=latent_shape, context_shape=context_shape, pooled_projection_shape=pooled_projection_shape, + qk_norm=mmdit_qk_norm, data_format=data_format, dtype=dtype, name="diffuser", @@ -446,6 +453,7 @@ def __init__( self.mmdit_num_layers = mmdit_num_layers self.mmdit_num_heads = mmdit_num_heads self.mmdit_position_size = mmdit_position_size + self.mmdit_qk_norm = mmdit_qk_norm self.latent_channels = latent_channels self.output_channels = output_channels self.num_train_timesteps = num_train_timesteps @@ -532,7 +540,7 @@ def denoise_step( embeddings, step, num_steps, - guidance_scale, + guidance_scale=None, ): step = ops.convert_to_tensor(step) next_step = ops.add(step, 1) @@ -540,9 +548,15 @@ def denoise_step( next_sigma, _ = self.scheduler(next_step, num_steps) # Concatenation for classifier-free guidance. - concated_latents, contexts, pooled_projs, timesteps = self.cfg_concat( - latents, *embeddings, timestep - ) + if guidance_scale is not None: + concated_latents, contexts, pooled_projs, timesteps = ( + self.cfg_concat(latents, *embeddings, timestep) + ) + else: + timesteps = ops.broadcast_to(timestep, ops.shape(latents)[:1]) + concated_latents = latents + contexts = embeddings[0] + pooled_projs = embeddings[2] # Diffusion. predicted_noise = self.diffuser( @@ -556,7 +570,8 @@ def denoise_step( ) # Classifier-free guidance. 
- predicted_noise = self.cfg(predicted_noise, guidance_scale) + if guidance_scale is not None: + predicted_noise = self.cfg(predicted_noise, guidance_scale) # Euler step. return self.euler_step(latents, predicted_noise, sigma, next_sigma) @@ -574,6 +589,7 @@ def get_config(self): "mmdit_num_layers": self.mmdit_num_layers, "mmdit_num_heads": self.mmdit_num_heads, "mmdit_position_size": self.mmdit_position_size, + "mmdit_qk_norm": self.mmdit_qk_norm, "vae": layers.serialize(self.vae), "clip_l": layers.serialize(self.clip_l), "clip_g": layers.serialize(self.clip_g), @@ -620,4 +636,9 @@ def from_config(cls, config, custom_objects=None): config["t5"] = layers.deserialize( config["t5"], custom_objects=custom_objects ) + + # To maintain backward compatibility, we need to ensure that + # `mmdit_qk_norm` is included in the config. + if "mmdit_qk_norm" not in config: + config["mmdit_qk_norm"] = None return cls(**config) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py index 77415a6eec..500836368f 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_backbone_test.py @@ -34,6 +34,7 @@ def setUp(self): "mmdit_num_layers": 2, "mmdit_num_heads": 2, "mmdit_position_size": 192, + "mmdit_qk_norm": None, "vae": vae, "clip_l": clip_l, "clip_g": clip_g, @@ -66,6 +67,22 @@ def test_backbone_basics(self): run_quantization_check=False, ) + # Test `mmdit_qk_norm="rms_norm"`. + self.run_backbone_test( + cls=StableDiffusion3Backbone, + init_kwargs={**self.init_kwargs, "mmdit_qk_norm": "rms_norm"}, + input_data=self.input_data, + expected_output_shape={ + "images": (2, 64, 64, 3), + "latents": (2, 8, 8, 16), + }, + # Since `clip_l` and `clip_g` were instantiated outside of + # `StableDiffusion3Backbone`, the mixed precision and + # quantization checks will fail. 
+ run_mixed_precision_check=False, + run_quantization_check=False, + ) + @pytest.mark.large def test_saved_model(self): self.run_model_saving_test( diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py index 285ba834b4..dfd26d54c2 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image.py @@ -158,14 +158,14 @@ def generate( self, inputs, num_steps=50, - guidance_scale=7.0, strength=0.8, + guidance_scale=7.0, seed=None, ): return super().generate( inputs, num_steps=num_steps, - guidance_scale=guidance_scale, strength=strength, + guidance_scale=guidance_scale, seed=seed, ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py index 8fa5b167ab..16f40caaa4 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py @@ -40,6 +40,7 @@ def setUp(self): mmdit_num_layers=2, mmdit_num_heads=2, mmdit_position_size=192, + mmdit_qk_norm=None, vae=VAEBackbone( [32, 32, 32, 32], [1, 1, 1, 1], diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py index 8d5ed7c6af..f6ec50058f 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint.py @@ -181,14 +181,14 @@ def generate( self, inputs, num_steps=50, - guidance_scale=7.0, strength=0.6, + guidance_scale=7.0, seed=None, ): return super().generate( inputs, num_steps=num_steps, - guidance_scale=guidance_scale, strength=strength, + guidance_scale=guidance_scale, seed=seed, ) diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py index 5e8ddd32c6..4b01af40d3 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py @@ -40,6 +40,7 @@ def setUp(self): mmdit_num_layers=2, mmdit_num_heads=2, mmdit_position_size=192, + mmdit_qk_norm=None, vae=VAEBackbone( [32, 32, 32, 32], [1, 1, 1, 1], diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py index a7756fc645..5ce7a904ce 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_presets.py @@ -11,8 +11,38 @@ "params": 2987080931, "official_name": "StableDiffusion3", "path": "stable_diffusion_3", - "model_card": "https://arxiv.org/abs/2110.00476", + "model_card": "https://huggingface.co/stabilityai/stable-diffusion-3-medium", }, "kaggle_handle": "kaggle://keras/stablediffusion3/keras/stable_diffusion_3_medium/3", - } + }, + "stable_diffusion_3.5_large": { + "metadata": { + "description": ( + "9 billion parameter, including CLIP L and CLIP G text " + "encoders, MMDiT generative model, and VAE autoencoder. " + "Developed by Stability AI." 
+ ), + "params": 9048410595, + "official_name": "StableDiffusion3", + "path": "stable_diffusion_3", + "model_card": "https://huggingface.co/stabilityai/stable-diffusion-3.5-large", + }, + "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3.5_large/1", + }, + "stable_diffusion_3.5_large_turbo": { + "metadata": { + "description": ( + "9 billion parameter, including CLIP L and CLIP G text " + "encoders, MMDiT generative model, and VAE autoencoder. " + "A timestep-distilled version that eliminates classifier-free " + "guidance and uses fewer steps for generation. " + "Developed by Stability AI." + ), + "params": 9048410595, + "official_name": "StableDiffusion3", + "path": "stable_diffusion_3", + "model_card": "https://huggingface.co/stabilityai/stable-diffusion-3.5-large-turbo", + }, + "kaggle_handle": "kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3.5_large_turbo/1", + }, } diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py index 69d30de834..2f11f43339 100644 --- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py +++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py @@ -40,6 +40,7 @@ def setUp(self): mmdit_num_layers=2, mmdit_num_heads=2, mmdit_position_size=192, + mmdit_qk_norm=None, vae=VAEBackbone( [32, 32, 32, 32], [1, 1, 1, 1], diff --git a/keras_hub/src/models/text_to_image.py b/keras_hub/src/models/text_to_image.py index 54b8dcdae2..1e1f677b3b 100644 --- a/keras_hub/src/models/text_to_image.py +++ b/keras_hub/src/models/text_to_image.py @@ -249,7 +249,7 @@ def generate( self, inputs, num_steps, - guidance_scale, + guidance_scale=None, seed=None, ): """Generate image based on the provided `inputs`. @@ -283,15 +283,23 @@ def generate( - A `tf.data.Dataset` with "prompts" and/or "negative_prompts" keys num_steps: int. The number of diffusion steps to take. - guidance_scale: float. The classifier free guidance scale defined in - [Classifier-Free Diffusion Guidance]( + guidance_scale: Optional float. The classifier free guidance scale + defined in [Classifier-Free Diffusion Guidance]( https://arxiv.org/abs/2207.12598). A higher scale encourages generating images more closely related to the prompts, typically - at the cost of lower image quality. + at the cost of lower image quality. Note that some models don't + utilize classifier-free guidance. seed: optional int. Used as a random seed. """ + num_steps = int(num_steps) + guidance_scale = ( + float(guidance_scale) if guidance_scale is not None else None + ) num_steps = ops.convert_to_tensor(num_steps, "int32") - guidance_scale = ops.convert_to_tensor(guidance_scale) + if guidance_scale is not None and guidance_scale > 1.0: + guidance_scale = ops.convert_to_tensor(guidance_scale) + else: + guidance_scale = None # Setup our three main passes. # 1. Preprocessing strings to dense integer tensors. 
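As a usage sketch (illustrative, assuming the presets above have been published under the Kaggle handles shown): with `guidance_scale` now optional, callers disable classifier-free guidance by passing `None`, which is what the timestep-distilled turbo model expects.

import keras_hub

# Load the turbo preset (assumes the upload referenced in the presets file).
text_to_image = keras_hub.models.StableDiffusion3TextToImage.from_preset(
    "stable_diffusion_3.5_large_turbo"
)
# Turbo is distilled to run without CFG and with very few steps, so
# `guidance_scale=None` skips the CFG concat/split in `denoise_step`.
image = text_to_image.generate(
    "A cat holding a sign that says hello world",
    num_steps=4,
    guidance_scale=None,
    seed=42,
)
# The base `stable_diffusion_3.5_large` preset still uses CFG; passing a
# `guidance_scale` greater than 1.0 enables it (e.g. 4.5 with 40 steps, as in
# `validate_output` in the conversion script below).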
diff --git a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py index 38e19cf107..b98ea77386 100644 --- a/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py +++ b/tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py @@ -5,6 +5,10 @@ python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \ --preset stable_diffusion_3_medium --upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3_medium +python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \ + --preset stable_diffusion_3.5_large --upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3.5_large --dtype bfloat16 +python tools/checkpoint_conversion/convert_stable_diffusion_3_checkpoints.py \ + --preset stable_diffusion_3.5_large_turbo --upload_uri kaggle://kerashub/stablediffusion3/keras/stable_diffusion_3.5_large_turbo --dtype bfloat16 """ import os @@ -46,7 +50,29 @@ "vae": "sd3_medium.safetensors", # Tokenizer "clip_tokenizer": "hf://openai/clip-vit-large-patch14", - } + }, + "stable_diffusion_3.5_large": { + # HF root + "root": "hf://stabilityai/stable-diffusion-3.5-large", + # Model <-> Path + "clip_l": "text_encoder/model.safetensors", + "clip_g": "text_encoder_2/model.safetensors", + "diffuser": "sd3.5_large.safetensors", + "vae": "sd3.5_large.safetensors", + # Tokenizer + "clip_tokenizer": "hf://openai/clip-vit-large-patch14", + }, + "stable_diffusion_3.5_large_turbo": { + # HF root + "root": "hf://stabilityai/stable-diffusion-3.5-large-turbo", + # Model <-> Path + "clip_l": "text_encoder/model.safetensors", + "clip_g": "text_encoder_2/model.safetensors", + "diffuser": "sd3.5_large_turbo.safetensors", + "vae": "sd3.5_large_turbo.safetensors", + # Tokenizer + "clip_tokenizer": "hf://openai/clip-vit-large-patch14", + }, } flags.DEFINE_string( @@ -61,6 +87,12 @@ "The generated image will be saved here.", required=False, ) +flags.DEFINE_string( + "dtype", + "float16", + "The variable and compute dtype of the converted checkpoint.", + required=False, +) flags.DEFINE_string( "upload_uri", None, @@ -110,12 +142,32 @@ def convert_model(preset, height, width): 24, 24, 192, + None, # qk_norm vae, clip_l, clip_g, image_shape=(height, width, 3), name="stable_diffusion_3_backbone", ) + elif preset in ( + "stable_diffusion_3.5_large", + "stable_diffusion_3.5_large_turbo", + ): + backbone = StableDiffusion3Backbone( + 2, + 64 * 38, + 38, + 38, + 192, + "rms_norm", # qk_norm + vae, + clip_l, + clip_g, + image_shape=(height, width, 3), + name="stable_diffusion_3.5_backbone", + ) + else: + raise ValueError(f"Unknown preset={preset}.") return backbone @@ -234,7 +286,8 @@ def port_mha(loader, keras_variable, hf_weight_key, num_heads, hidden_dim): def port_ln_or_gn(loader, keras_variable, hf_weight_key): loader.port_weight(keras_variable.gamma, f"{hf_weight_key}.weight") - loader.port_weight(keras_variable.beta, f"{hf_weight_key}.bias") + if keras_variable.beta is not None: + loader.port_weight(keras_variable.beta, f"{hf_weight_key}.bias") def port_clip(preset, filename, model, projection_layer): with SafetensorLoader(preset, prefix="", fname=filename) as loader: @@ -343,6 +396,13 @@ def port_diffuser(preset, filename, model): port_dense( loader, block.attention_qkv, f"{prefix}.attn.qkv" ) + if block.qk_norm is not None: + port_ln_or_gn( + loader, block.q_norm, f"{prefix}.attn.ln_q" + ) + port_ln_or_gn( + loader, block.k_norm, f"{prefix}.attn.ln_k" + ) if 
block_name == "context_block" and (i == num_layers - 1): continue @@ -493,30 +553,44 @@ def port_attention(loader, keras_variable, hf_weight_key): port_vae(config["root"], config["vae"], keras_model.vae) -def validate_output(keras_model, keras_preprocessor, output_dir): +def validate_output(preset, keras_model, keras_preprocessor, output_dir): + if preset == "stable_diffusion_3_medium": + num_steps = 28 + guidance_scale = 7.0 + elif preset == "stable_diffusion_3.5_large": + num_steps = 40 + guidance_scale = 4.5 + elif preset == "stable_diffusion_3.5_large_turbo": + num_steps = 4 + guidance_scale = None # No CFG in turbo. + # TODO: Verify the numerics. prompt = "A cat holding a sign that says hello world" text_to_image = StableDiffusion3TextToImage(keras_model, keras_preprocessor) - image = text_to_image.generate(prompt, seed=42) + image = text_to_image.generate( + prompt, + num_steps=num_steps, + guidance_scale=guidance_scale, + seed=42, + ) image = Image.fromarray(image) - image.save(os.path.join(output_dir, "test.png")) + image.save(os.path.join(output_dir, f"{preset}.png")) def main(_): preset = FLAGS.preset output_dir = FLAGS.output_dir + dtype = FLAGS.dtype if os.path.exists(preset): shutil.rmtree(preset) - if os.path.exists(output_dir): - shutil.rmtree(output_dir) - os.makedirs(preset) - os.makedirs(output_dir) + os.makedirs(preset, exist_ok=True) + os.makedirs(output_dir, exist_ok=True) print(f"🏃 Coverting {preset}") - # Currently SD3 weights are float16 (and have much faster download - # times for it). We follow suit with Keras weights. - keras.config.set_dtype_policy("float16") + # Currently SD3 weights are float16 or bfloat16 (and have much faster + # download times for it). We follow suit with Keras weights. + keras.config.set_dtype_policy(dtype) height, width = 800, 800 # Use a smaller image size to speed up generation. keras_preprocessor = convert_preprocessor() @@ -526,7 +600,7 @@ def main(_): convert_weights(preset, keras_model) print("✅ Weights converted.") - validate_output(keras_model, keras_preprocessor, output_dir) + validate_output(preset, keras_model, keras_preprocessor, output_dir) print("✅ Output validated.") keras_preprocessor.save_to_preset(preset)
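For clarity on the `qk_norm="rms_norm"` option wired through the changes above: it is implemented with `keras.layers.LayerNormalization(rms_scaling=True)` applied to the per-head query and key tensors right after the qkv unstack, i.e. an RMS-style normalization over `head_dim` with a learned scale and no centering or bias term (consistent with `port_ln_or_gn` now tolerating a missing `beta`). A rough, illustrative sketch of the equivalent computation; the `rms_norm` helper and the toy shapes are hypothetical:

from keras import ops

def rms_norm(x, gamma, epsilon=1e-6):
    # x: [batch_size, sequence_length, num_heads, head_dim]; normalize over
    # the last axis. gamma corresponds to the ported `attn.ln_q`/`attn.ln_k`
    # weights. Unlike standard layer norm there is no mean subtraction and
    # no beta term.
    mean_square = ops.mean(ops.square(x), axis=-1, keepdims=True)
    return x / ops.sqrt(mean_square + epsilon) * gamma

# Toy shapes mirroring DismantledBlock._compute_pre_attention after unstack.
q = ops.ones((2, 8, 2, 32))  # [batch, seq_len, num_heads, head_dim]
gamma = ops.ones((32,))
q_normed = rms_norm(q, gamma)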