From 0fba4cc6a5f4bb8bba99dae4e663dac93b1226ae Mon Sep 17 00:00:00 2001
From: Lin Jia
Date: Sat, 30 Nov 2024 18:20:58 -0800
Subject: [PATCH 1/4] wip

---
 .../models/controlnets/controlnet_sd3.py      | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index 2a5fcf35498e..fa45a4dde7c8 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -379,13 +379,18 @@ def custom_forward(*inputs):
                     return custom_forward
 
                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    **ckpt_kwargs,
-                )
+                if self.context_embedder is not None:
+                    encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block),
+                        hidden_states,
+                        encoder_hidden_states,
+                        temb,
+                        **ckpt_kwargs,
+                    )
+                else:
+                    hidden_states = torch.utils.checkpoint.checkpoint(
+                        create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
+                    )
 
             else:
                 if self.context_embedder is not None:

From afaebbd0cdcf333bc9fba62c1341b129d1d54b64 Mon Sep 17 00:00:00 2001
From: Lin Jia
Date: Sat, 30 Nov 2024 18:31:01 -0800
Subject: [PATCH 2/4] wip

---
 src/diffusers/models/controlnets/controlnet_sd3.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/diffusers/models/controlnets/controlnet_sd3.py b/src/diffusers/models/controlnets/controlnet_sd3.py
index fa45a4dde7c8..dd7672c5a62a 100644
--- a/src/diffusers/models/controlnets/controlnet_sd3.py
+++ b/src/diffusers/models/controlnets/controlnet_sd3.py
@@ -388,6 +388,7 @@ def custom_forward(*inputs):
                         **ckpt_kwargs,
                     )
                 else:
+                    # SD3.5 8b controlnet uses a single transformer block, which does not use `encoder_hidden_states`
                     hidden_states = torch.utils.checkpoint.checkpoint(
                         create_custom_forward(block), hidden_states, temb, **ckpt_kwargs
                     )

From 8c36c6623d204461fb2f3861e16d85164dcfc7c9 Mon Sep 17 00:00:00 2001
From: Lin Jia
Date: Sat, 30 Nov 2024 23:03:25 -0800
Subject: [PATCH 3/4] wip

---
 src/diffusers/models/transformers/transformer_sd3.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index a1ce9a2412c5..aef6ddbf08c2 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -424,8 +424,7 @@ def custom_forward(*inputs):
             # controlnet residual
             if block_controlnet_hidden_states is not None and block.context_pre_only is False:
                 interval_control = len(self.transformer_blocks) / len(block_controlnet_hidden_states)
-                interval_control = int(np.ceil(interval_control))
-                hidden_states = hidden_states + block_controlnet_hidden_states[index_block // interval_control]
+                hidden_states = hidden_states + block_controlnet_hidden_states[int(index_block / interval_control)]
 
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)

From 38b0c9e679b0894d0483df9da80bc103aa1bff52 Mon Sep 17 00:00:00 2001
From: Lin Jia
Date: Tue, 3 Dec 2024 22:45:14 -0800
Subject: [PATCH 4/4] make style changes

---
 src/diffusers/models/transformers/transformer_sd3.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_sd3.py b/src/diffusers/models/transformers/transformer_sd3.py
index aef6ddbf08c2..887e8afd2106 100644
--- a/src/diffusers/models/transformers/transformer_sd3.py
+++ b/src/diffusers/models/transformers/transformer_sd3.py
@@ -15,7 +15,6 @@
 from typing import Any, Dict, List, Optional, Tuple, Union
 
-import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
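
Note on PATCH 1/4 and 2/4: the gradient-checkpointing path now branches on `self.context_embedder`, because the SD3.5 8b controlnet's single transformer blocks take only `hidden_states` and `temb`, while the joint (MMDiT) blocks also consume and return `encoder_hidden_states`. A minimal standalone sketch of that branching, with stub modules standing in for the real diffusers blocks (not the actual classes):

```python
import torch
import torch.nn as nn


class JointBlockStub(nn.Module):
    """Stand-in for a joint MMDiT block: consumes and returns both streams."""

    def forward(self, hidden_states, encoder_hidden_states, temb):
        return encoder_hidden_states + temb, hidden_states + temb


class SingleBlockStub(nn.Module):
    """Stand-in for an SD3.5 8b controlnet single block: no text stream."""

    def forward(self, hidden_states, temb):
        return hidden_states + temb


def checkpointed_step(block, hidden_states, encoder_hidden_states, temb, context_embedder):
    # use_reentrant=False mirrors the ckpt_kwargs chosen for torch >= 1.11
    if context_embedder is not None:
        # joint block: both streams pass through the checkpointed call
        encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
            block, hidden_states, encoder_hidden_states, temb, use_reentrant=False
        )
    else:
        # single block: only hidden_states and temb are checkpointed
        hidden_states = torch.utils.checkpoint.checkpoint(
            block, hidden_states, temb, use_reentrant=False
        )
    return hidden_states, encoder_hidden_states
```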
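
Note on PATCH 3/4: dropping the `np.ceil` rounding changes how controlnet residuals map onto transformer blocks when the two counts do not divide evenly; the old ceil-then-floor-divide mapping starves the last controlnet residual, while the new float-division mapping spreads each residual over a roughly equal span of blocks. A small sketch with assumed depths (38 transformer blocks, 10 controlnet blocks; the real counts depend on the model config):

```python
import math

num_transformer_blocks = 38  # assumed depth, for illustration only
num_controlnet_blocks = 10   # assumed controlnet depth

interval = num_transformer_blocks / num_controlnet_blocks  # 3.8

# old mapping: round the interval up to an int, then floor-divide
old_map = [i // int(math.ceil(interval)) for i in range(num_transformer_blocks)]
# new mapping: divide by the float interval, then truncate
new_map = [int(i / interval) for i in range(num_transformer_blocks)]

# old: indices 0-8 each cover 4 blocks, index 9 covers only the last 2
# new: every controlnet residual covers 3-4 blocks, spread evenly
print(old_map)
print(new_map)
```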