2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -265,6 +265,8 @@
title: LDM3D Text-to-(RGB, Depth)
- local: api/pipelines/stable_diffusion/adapter
title: Stable Diffusion T2I-adapter
- local: api/pipelines/stable_diffusion/gligen
title: GLIGEN (Grounded Language-to-Image Generation)
title: Stable Diffusion
- local: api/pipelines/stable_unclip
title: Stable unCLIP
46 changes: 46 additions & 0 deletions docs/source/en/api/pipelines/stable_diffusion/gligen.md
@@ -0,0 +1,46 @@
<!--Copyright 2023 The GLIGEN Authors and The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# GLIGEN (Grounded Language-to-Image Generation)

The GLIGEN model was created by researchers and engineers from the [University of Wisconsin-Madison, Columbia University, and Microsoft](https://github.com/gligen/GLIGEN). The [`StableDiffusionGLIGENPipeline`] can generate photorealistic images conditioned on grounding inputs. If an input image is provided along with text and bounding boxes, the pipeline inserts the objects described by the text into the regions defined by the bounding boxes. Otherwise, it generates an image described by the caption/prompt and places the objects described by the text in the regions defined by the bounding boxes. The model was trained on the COCO2014D and COCO2014CD datasets and uses a frozen CLIP ViT-L/14 text encoder to condition itself on grounding inputs.
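
A minimal usage sketch for grounded generation is shown below. The checkpoint name and the `gligen_phrases` / `gligen_boxes` / `gligen_scheduled_sampling_beta` argument names are assumptions for illustration; refer to the [`StableDiffusionGLIGENPipeline`] reference below for the authoritative call signature.

```python
import torch
from diffusers import StableDiffusionGLIGENPipeline

# Hypothetical checkpoint name -- browse the gligen Hub organization for official weights.
pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

prompt = "a waterfall and a modern high speed train in a beautiful forest with fall foliage"
# One phrase per bounding box; boxes are (xmin, ymin, xmax, ymax) normalized to [0, 1].
phrases = ["a waterfall", "a modern high speed train"]
boxes = [[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]]

image = pipe(
    prompt=prompt,
    gligen_phrases=phrases,
    gligen_boxes=boxes,
    gligen_scheduled_sampling_beta=1,  # fraction of denoising steps that use the grounding tokens
    num_inference_steps=50,
).images[0]
image.save("gligen_generation.png")
```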

The abstract from the [paper](https://huggingface.co/papers/2301.07093) is:

*Large-scale text-to-image diffusion models have made amazing advances. However, the status quo is to use text input alone, which can impede controllability. In this work, we propose GLIGEN, Grounded-Language-to-Image Generation, a novel approach that builds upon and extends the functionality of existing pre-trained text-to-image diffusion models by enabling them to also be conditioned on grounding inputs. To preserve the vast concept knowledge of the pre-trained model, we freeze all of its weights and inject the grounding information into new trainable layers via a gated mechanism. Our model achieves open-world grounded text2img generation with caption and bounding box condition inputs, and the grounding ability generalizes well to novel spatial configurations and concepts. GLIGEN’s zeroshot performance on COCO and LVIS outperforms existing supervised layout-to-image baselines by a large margin.*
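
The gated mechanism mentioned in the abstract keeps every pre-trained weight frozen and blends the new grounding tokens into the visual tokens through zero-initialized, learnable gates. The sketch below is a simplified standalone illustration of that idea; it is not the exact module added in this PR, which uses the library's own `Attention` block inside `GatedSelfAttentionDense`.

```python
import torch
from torch import nn


class GatedGroundingInjection(nn.Module):
    """Simplified sketch of gated self-attention: the new layer starts as an identity
    (gate initialized at zero) so the frozen base model's behavior is preserved."""

    def __init__(self, query_dim: int, context_dim: int, n_heads: int):
        super().__init__()
        self.proj = nn.Linear(context_dim, query_dim)  # project grounding tokens to the visual width
        self.attn = nn.MultiheadAttention(query_dim, n_heads, batch_first=True)
        self.norm = nn.LayerNorm(query_dim)
        self.alpha = nn.Parameter(torch.tensor(0.0))  # zero init -> tanh(0) = 0, no effect at the start

    def forward(self, visual_tokens: torch.Tensor, grounding_tokens: torch.Tensor) -> torch.Tensor:
        n_visual = visual_tokens.shape[1]
        tokens = self.norm(torch.cat([visual_tokens, self.proj(grounding_tokens)], dim=1))
        attn_out, _ = self.attn(tokens, tokens, tokens)
        # Keep only the visual positions; the gate decides how much grounding information leaks in.
        return visual_tokens + self.alpha.tanh() * attn_out[:, :n_visual, :]


# Example: 64 visual tokens of width 320, 4 grounding tokens of width 768.
fuser = GatedGroundingInjection(query_dim=320, context_dim=768, n_heads=8)
out = fuser(torch.randn(1, 64, 320), torch.randn(1, 4, 768))
print(out.shape)  # torch.Size([1, 64, 320])
```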

<Tip>

Make sure to check out the Stable Diffusion [Tips](https://huggingface.co/docs/diffusers/en/api/pipelines/stable_diffusion/overview#tips) section to learn how to explore the tradeoff between scheduler speed and quality and how to reuse pipeline components efficiently!

If you want to use one of the official checkpoints for a task, explore the [gligen](https://huggingface.co/gligen) Hub organization!

</Tip>

This pipeline was contributed by [Nikhil Gajendrakumar](https://github.com/nikhil-masterful).

## StableDiffusionGLIGENPipeline

[[autodoc]] StableDiffusionGLIGENPipeline
- all
- __call__
- enable_vae_slicing
- disable_vae_slicing
- enable_vae_tiling
- disable_vae_tiling
- enable_model_cpu_offload
- prepare_latents
- enable_fuser

## StableDiffusionPipelineOutput

[[autodoc]] pipelines.stable_diffusion.StableDiffusionPipelineOutput
1 change: 1 addition & 0 deletions src/diffusers/__init__.py
@@ -171,6 +171,7 @@
StableDiffusionControlNetPipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
46 changes: 45 additions & 1 deletion src/diffusers/models/attention.py
@@ -24,6 +24,38 @@
from .lora import LoRACompatibleLinear


@maybe_allow_in_graph
class GatedSelfAttentionDense(nn.Module):
def __init__(self, query_dim, context_dim, n_heads, d_head):
super().__init__()

# we need a linear projection since we have to concatenate the visual features and object features
self.linear = nn.Linear(context_dim, query_dim)

self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
self.ff = FeedForward(query_dim, activation_fn="geglu")

self.norm1 = nn.LayerNorm(query_dim)
self.norm2 = nn.LayerNorm(query_dim)

self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))

self.enabled = True

def forward(self, x, objs):
if not self.enabled:
return x

n_visual = x.shape[1]
objs = self.linear(objs)

x = x + self.alpha_attn.tanh() * self.attn(self.norm1(torch.cat([x, objs], dim=1)))[:, :n_visual, :]
x = x + self.alpha_dense.tanh() * self.ff(self.norm2(x))

return x


@maybe_allow_in_graph
class BasicTransformerBlock(nn.Module):
r"""
@@ -62,6 +94,7 @@ def __init__(
norm_elementwise_affine: bool = True,
norm_type: str = "layer_norm",
final_dropout: bool = False,
attention_type: str = "default",
):
super().__init__()
self.only_cross_attention = only_cross_attention
@@ -120,6 +153,10 @@ def __init__(
self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)

# 4. Fuser
if attention_type == "gated":
self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

# let chunk size default to None
self._chunk_size = None
self._chunk_dim = 0
@@ -150,7 +187,9 @@ def forward(
else:
norm_hidden_states = self.norm1(hidden_states)

cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
# 0. Prepare GLIGEN inputs
cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

attn_output = self.attn1(
norm_hidden_states,
@@ -162,6 +201,11 @@
attn_output = gate_msa.unsqueeze(1) * attn_output
hidden_states = attn_output + hidden_states

# 1.5 GLIGEN Control
if gligen_kwargs is not None:
hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
# 1.5 ends

# 2. Cross-Attention
if self.attn2 is not None:
norm_hidden_states = (
56 changes: 56 additions & 0 deletions src/diffusers/models/embeddings.py
@@ -544,3 +544,59 @@ def shape(x):
a = a.reshape(bs, -1, 1).transpose(1, 2)

return a[:, 0, :] # cls_token


class FourierEmbedder(nn.Module):
def __init__(self, num_freqs=64, temperature=100):
super().__init__()

self.num_freqs = num_freqs
self.temperature = temperature

freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
freq_bands = freq_bands[None, None, None]
self.register_buffer("freq_bands", freq_bands, persistent=False)

def __call__(self, x):
x = self.freq_bands * x.unsqueeze(-1)
return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)


class PositionNet(nn.Module):
def __init__(self, positive_len, out_dim, fourier_freqs=8):
super().__init__()
self.positive_len = positive_len
self.out_dim = out_dim

self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy

if isinstance(out_dim, tuple):
out_dim = out_dim[0]
self.linears = nn.Sequential(
nn.Linear(self.positive_len + self.position_dim, 512),
nn.SiLU(),
nn.Linear(512, 512),
nn.SiLU(),
nn.Linear(512, out_dim),
)

self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

def forward(self, boxes, masks, positive_embeddings):
masks = masks.unsqueeze(-1)

# embed box positions (the boxes may include padding as a placeholder)
xyxy_embedding = self.fourier_embedder(boxes) # B*N*4 -> B*N*C

# learnable null embedding
positive_null = self.null_positive_feature.view(1, 1, -1)
xyxy_null = self.null_position_feature.view(1, 1, -1)

# replace padding with learnable null embedding
positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
return objs
2 changes: 2 additions & 0 deletions src/diffusers/models/transformer_2d.py
@@ -91,6 +91,7 @@ def __init__(
upcast_attention: bool = False,
norm_type: str = "layer_norm",
norm_elementwise_affine: bool = True,
attention_type: str = "default",
):
super().__init__()
self.use_linear_projection = use_linear_projection
@@ -183,6 +184,7 @@ def __init__(
upcast_attention=upcast_attention,
norm_type=norm_type,
norm_elementwise_affine=norm_elementwise_affine,
attention_type=attention_type,
)
for d in range(num_layers)
]
10 changes: 10 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
@@ -49,6 +49,7 @@ def get_down_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
attention_type="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
@@ -129,6 +130,7 @@ def get_down_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
)
elif down_block_type == "SimpleCrossAttnDownBlock2D":
if cross_attention_dim is None:
@@ -244,6 +246,7 @@ def get_up_block(
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
attention_type="default",
resnet_skip_time_act=False,
resnet_out_scale_factor=1.0,
cross_attention_norm=None,
@@ -307,6 +310,7 @@ def get_up_block(
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
)
elif up_block_type == "SimpleCrossAttnUpBlock2D":
if cross_attention_dim is None:
@@ -556,6 +560,7 @@ def __init__(
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()

@@ -592,6 +597,7 @@ def __init__(
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
@@ -934,6 +940,7 @@ def __init__(
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
resnets = []
@@ -970,6 +977,7 @@ def __init__(
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
@@ -2068,6 +2076,7 @@ def __init__(
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
attention_type="default",
):
super().__init__()
resnets = []
@@ -2106,6 +2115,7 @@ def __init__(
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
)
else:
19 changes: 19 additions & 0 deletions src/diffusers/models/unet_2d_condition.py
@@ -28,6 +28,7 @@
ImageHintTimeEmbedding,
ImageProjection,
ImageTimeEmbedding,
PositionNet,
TextImageProjection,
TextImageTimeEmbedding,
TextTimeEmbedding,
@@ -198,6 +199,7 @@ def __init__(
conv_in_kernel: int = 3,
conv_out_kernel: int = 3,
projection_class_embeddings_input_dim: Optional[int] = None,
attention_type: str = "default",
class_embeddings_concat: bool = False,
mid_block_only_cross_attention: Optional[bool] = None,
cross_attention_norm: Optional[str] = None,
@@ -446,6 +448,7 @@ def __init__(
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
@@ -469,6 +472,7 @@ def __init__(
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
attention_type=attention_type,
)
elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
self.mid_block = UNetMidBlock2DSimpleCrossAttn(
@@ -535,6 +539,7 @@ def __init__(
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
@@ -560,6 +565,14 @@ def __init__(
block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
)

if attention_type == "gated":
positive_len = 768
if isinstance(cross_attention_dim, int):
positive_len = cross_attention_dim
elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
positive_len = cross_attention_dim[0]
self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)

@property
def attn_processors(self) -> Dict[str, AttentionProcessor]:
r"""
@@ -895,6 +908,12 @@ def forward(
# 2. pre-process
sample = self.conv_in(sample)

# 2.5 GLIGEN position net
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
cross_attention_kwargs = cross_attention_kwargs.copy()
gligen_args = cross_attention_kwargs.pop("gligen")
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

# 3. down

is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
1 change: 1 addition & 0 deletions src/diffusers/pipelines/__init__.py
@@ -90,6 +90,7 @@
StableDiffusionAttendAndExcitePipeline,
StableDiffusionDepth2ImgPipeline,
StableDiffusionDiffEditPipeline,
StableDiffusionGLIGENPipeline,
StableDiffusionImageVariationPipeline,
StableDiffusionImg2ImgPipeline,
StableDiffusionInpaintPipeline,
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion/__init__.py
@@ -45,6 +45,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
from .pipeline_cycle_diffusion import CycleDiffusionPipeline
from .pipeline_stable_diffusion import StableDiffusionPipeline
from .pipeline_stable_diffusion_attend_and_excite import StableDiffusionAttendAndExcitePipeline
from .pipeline_stable_diffusion_gligen import StableDiffusionGLIGENPipeline
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy