
Commit ab949ee

Merge remote-tracking branch 'origin/sd_xl' into dreambooth/sd-xl

2 parents: 9a45d7f + 13107bb

File tree

11 files changed: +1841 -40 lines


scripts/convert_original_stable_diffusion_to_diffusers.py

Lines changed: 8 additions & 0 deletions
@@ -126,6 +126,13 @@
         "--controlnet", action="store_true", default=None, help="Set flag if this is a controlnet checkpoint."
     )
     parser.add_argument("--half", action="store_true", help="Save weights in half precision.")
+    parser.add_argument(
+        "--vae_path",
+        type=str,
+        default=None,
+        required=False,
+        help="Set to a path, hub id to an already converted vae to not convert it again."
+    )
     args = parser.parse_args()

     pipe = download_from_original_stable_diffusion_ckpt(
@@ -144,6 +151,7 @@
         stable_unclip_prior=args.stable_unclip_prior,
         clip_stats_path=args.clip_stats_path,
         controlnet=args.controlnet,
+        vae_path=args.vae_path,
     )

     if args.half:
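For programmatic use, the new flag maps straight onto the underlying conversion function. A minimal sketch, assuming `download_from_original_stable_diffusion_ckpt` is importable from `diffusers.pipelines.stable_diffusion.convert_from_ckpt` as the script does; the checkpoint path and VAE id are placeholders:

# A sketch only: import path and file names are assumptions, not part of
# this diff; `vae_path` is the argument added above.
from diffusers.pipelines.stable_diffusion.convert_from_ckpt import (
    download_from_original_stable_diffusion_ckpt,
)

pipe = download_from_original_stable_diffusion_ckpt(
    checkpoint_path="./checkpoint.safetensors",  # placeholder checkpoint
    from_safetensors=True,
    # Reuse an already converted VAE (local path or hub id) instead of
    # converting the one baked into the checkpoint.
    vae_path="stabilityai/sd-vae-ft-mse",
)
pipe.save_pretrained("./converted-pipeline")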

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@
         StableDiffusionPix2PixZeroPipeline,
         StableDiffusionSAGPipeline,
         StableDiffusionUpscalePipeline,
+        StableDiffusionXLPipeline,
+        StableDiffusionXLImg2ImgPipeline,
         StableUnCLIPImg2ImgPipeline,
         StableUnCLIPPipeline,
         TextToVideoSDPipeline,

src/diffusers/models/unet_2d_blocks.py

Lines changed: 10 additions & 3 deletions
@@ -38,6 +38,7 @@ def get_down_block(
     add_downsample,
     resnet_eps,
     resnet_act_fn,
+    num_transformer_blocks=1,
     num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
@@ -106,6 +107,7 @@ def get_down_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
         return CrossAttnDownBlock2D(
             num_layers=num_layers,
+            num_transformer_blocks=num_transformer_blocks,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -227,6 +229,7 @@ def get_up_block(
     add_upsample,
     resnet_eps,
     resnet_act_fn,
+    num_transformer_blocks=1,
     num_attention_heads=None,
     resnet_groups=None,
     cross_attention_dim=None,
@@ -281,6 +284,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
         return CrossAttnUpBlock2D(
             num_layers=num_layers,
+            num_transformer_blocks=num_transformer_blocks,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
@@ -506,6 +510,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        num_transformer_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -548,7 +553,7 @@ def __init__(
                         num_attention_heads,
                         in_channels // num_attention_heads,
                         in_channels=in_channels,
-                        num_layers=1,
+                        num_layers=num_transformer_blocks,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -829,6 +834,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        num_transformer_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -873,7 +879,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=num_transformer_blocks,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
@@ -1939,6 +1945,7 @@ def __init__(
         temb_channels: int,
         dropout: float = 0.0,
         num_layers: int = 1,
+        num_transformer_blocks: int = 1,
         resnet_eps: float = 1e-6,
         resnet_time_scale_shift: str = "default",
         resnet_act_fn: str = "swish",
@@ -1984,7 +1991,7 @@ def __init__(
                         num_attention_heads,
                         out_channels // num_attention_heads,
                         in_channels=out_channels,
-                        num_layers=1,
+                        num_layers=num_transformer_blocks,
                         cross_attention_dim=cross_attention_dim,
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
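The net effect of these block changes: `num_transformer_blocks` now controls how many `BasicTransformerBlock`s each `Transformer2DModel` stacks, replacing the previously hard-coded single block. A minimal sketch of configuring it per stage through the constructor argument added in the next file's diff; the sizes are illustrative, not the SD-XL production config:

import torch
from diffusers import UNet2DConditionModel

# Illustrative toy config: 1, 2 and 4 transformer blocks per attention layer
# in the three stages; the mid block reuses the last entry (4).
unet = UNet2DConditionModel(
    sample_size=32,
    block_out_channels=(320, 640, 1280),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "UpBlock2D"),
    num_transformer_blocks=(1, 2, 4),
    cross_attention_dim=1280,
)

sample = torch.randn(1, 4, 32, 32)
encoder_hidden_states = torch.randn(1, 77, 1280)
out = unet(sample, timestep=0, encoder_hidden_states=encoder_hidden_states).sample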

src/diffusers/models/unet_2d_condition.py

Lines changed: 34 additions & 2 deletions
@@ -96,6 +96,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
         cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
             The dimension of the cross attention features.
+        num_transformer_blocks (`int` or `Tuple[int]`, *optional*, defaults to 1):
+            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
         encoder_hid_dim (`int`, *optional*, defaults to None):
             If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
             dimension to `cross_attention_dim`.
@@ -168,6 +170,7 @@ def __init__(
         norm_num_groups: Optional[int] = 32,
         norm_eps: float = 1e-5,
         cross_attention_dim: Union[int, Tuple[int]] = 1280,
+        num_transformer_blocks: Union[int, Tuple[int]] = 1,
         encoder_hid_dim: Optional[int] = None,
         encoder_hid_dim_type: Optional[str] = None,
         attention_head_dim: Union[int, Tuple[int]] = 8,
@@ -176,6 +179,7 @@ def __init__(
         use_linear_projection: bool = False,
         class_embed_type: Optional[str] = None,
         addition_embed_type: Optional[str] = None,
+        addition_time_embed_dim: Optional[int] = None,
         num_class_embeds: Optional[int] = None,
         upcast_attention: bool = False,
         resnet_time_scale_shift: str = "default",
@@ -349,6 +353,10 @@ def __init__(
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
+        elif addition_embed_type == "text_time":
+            self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+            self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
         elif addition_embed_type is not None:
             raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")

@@ -381,6 +389,9 @@ def __init__(
         if isinstance(layers_per_block, int):
             layers_per_block = [layers_per_block] * len(down_block_types)

+        if isinstance(num_transformer_blocks, int):
+            num_transformer_blocks = [num_transformer_blocks] * len(down_block_types)
+
         if class_embeddings_concat:
             # The time embeddings are concatenated with the class embeddings. The dimension of the
             # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
@@ -399,6 +410,7 @@ def __init__(
             down_block = get_down_block(
                 down_block_type,
                 num_layers=layers_per_block[i],
+                num_transformer_blocks=num_transformer_blocks[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 temb_channels=blocks_time_embed_dim,
@@ -424,6 +436,7 @@ def __init__(
         # mid
         if mid_block_type == "UNetMidBlock2DCrossAttn":
             self.mid_block = UNetMidBlock2DCrossAttn(
+                num_transformer_blocks=num_transformer_blocks[-1],
                 in_channels=block_out_channels[-1],
                 temb_channels=blocks_time_embed_dim,
                 resnet_eps=norm_eps,
@@ -465,6 +478,7 @@ def __init__(
         reversed_num_attention_heads = list(reversed(num_attention_heads))
         reversed_layers_per_block = list(reversed(layers_per_block))
         reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+        reversed_num_transformer_blocks = list(reversed(num_transformer_blocks))
         only_cross_attention = list(reversed(only_cross_attention))

         output_channel = reversed_block_out_channels[0]
@@ -485,6 +499,7 @@ def __init__(
             up_block = get_up_block(
                 up_block_type,
                 num_layers=reversed_layers_per_block[i] + 1,
+                num_transformer_blocks=reversed_num_transformer_blocks[i],
                 in_channels=input_channel,
                 out_channels=output_channel,
                 prev_output_channel=prev_output_channel,
@@ -779,7 +794,6 @@ def forward(

         if self.config.addition_embed_type == "text":
             aug_emb = self.add_embedding(encoder_hidden_states)
-            emb = emb + aug_emb
         elif self.config.addition_embed_type == "text_image":
             # Kadinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
@@ -791,7 +805,25 @@
             text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)

             aug_emb = self.add_embedding(text_embs, image_embs)
-            emb = emb + aug_emb
+        elif self.config.addition_embed_type == "text_time":
+            if "text_embeds" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+                )
+            text_embeds = added_cond_kwargs.get("text_embeds")
+            if "time_ids" not in added_cond_kwargs:
+                raise ValueError(
+                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+                )
+            time_ids = added_cond_kwargs.get("time_ids")
+            time_embeds = self.add_time_proj(time_ids.flatten())
+            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+            add_embeds = add_embeds.to(emb.dtype)
+            aug_emb = self.add_embedding(add_embeds)
+
+        emb = emb + aug_emb

         if self.time_embed_act is not None:
             emb = self.time_embed_act(emb)
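The `text_time` branch is the hook SD-XL uses to fold pooled text embeddings plus size/crop conditioning into the time embedding. A standalone sketch of just that computation, assuming `Timesteps` and `TimestepEmbedding` from `diffusers.models.embeddings` and SDXL-style six-entry `time_ids` (both assumptions; the diff itself does not pin the shape):

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

batch = 2
addition_time_embed_dim = 256   # illustrative
pooled_text_dim = 1280          # illustrative pooled text-embedding width
time_embed_dim = 1280

# Assumed six micro-conditioning ids per sample, e.g.
# (orig_h, orig_w, crop_top, crop_left, target_h, target_w).
time_ids = torch.tensor([[1024, 1024, 0, 0, 1024, 1024]] * batch).float()
text_embeds = torch.randn(batch, pooled_text_dim)

add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos=True, downscale_freq_shift=0)
# projection_class_embeddings_input_dim must equal
# pooled_text_dim + 6 * addition_time_embed_dim so the concat below fits.
add_embedding = TimestepEmbedding(pooled_text_dim + 6 * addition_time_embed_dim, time_embed_dim)

time_embeds = add_time_proj(time_ids.flatten())     # (batch * 6, 256)
time_embeds = time_embeds.reshape((batch, -1))      # (batch, 6 * 256)
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
aug_emb = add_embedding(add_embeds)                 # (batch, time_embed_dim)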

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -89,6 +89,7 @@
         StableUnCLIPPipeline,
     )
     from .stable_diffusion_safe import StableDiffusionPipelineSafe
+    from .stable_diffusion_xl import StableDiffusionXLPipeline, StableDiffusionXLImg2ImgPipeline
     from .text_to_video_synthesis import TextToVideoSDPipeline, TextToVideoZeroPipeline
     from .unclip import UnCLIPImageVariationPipeline, UnCLIPPipeline
     from .unidiffuser import ImageTextPipelineOutput, UniDiffuserModel, UniDiffuserPipeline, UniDiffuserTextDecoder
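With these exports in place, both pipelines are importable from the top-level package. A minimal usage sketch; the checkpoint id is a placeholder for wherever converted SD-XL weights in diffusers format live:

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "path/to/sd-xl-diffusers",  # placeholder: local path or hub id
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")

image = pipe(prompt="an astronaut riding a horse, photograph").images[0]
image.save("astronaut.png")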
