From 50bc0a6560748f9facae91073c54d98e4c93afa8 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 30 Oct 2023 14:00:21 +0100
Subject: [PATCH 1/2] finish

---
 .../pipelines/blip_diffusion/modeling_ctx_clip.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
index 53d57188743d..69a46e9596d7 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
@@ -19,11 +19,12 @@
 from transformers import CLIPPreTrainedModel
 from transformers.modeling_outputs import BaseModelOutputWithPooling
 from transformers.models.clip.configuration_clip import CLIPTextConfig
-from transformers.models.clip.modeling_clip import (
-    CLIPEncoder,
-    _expand_mask,
-)
+from transformers.models.clip.modeling_clip import CLIPEncoder
+try:
+    from transformers.models.clip.modeling_clip import _expand_mask
+except ImportError:
+    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask as _expand_mask


 # This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
 # Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer

From f4364201066b0f92eeeb6251bc34f84cb9dd2b67 Mon Sep 17 00:00:00 2001
From: Patrick von Platen
Date: Mon, 30 Oct 2023 14:01:48 +0100
Subject: [PATCH 2/2] finish

---
 .../blip_diffusion/modeling_ctx_clip.py | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
index 69a46e9596d7..19f62e789e2d 100644
--- a/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
+++ b/src/diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py
@@ -21,10 +21,20 @@
 from transformers.models.clip.configuration_clip import CLIPTextConfig
 from transformers.models.clip.modeling_clip import CLIPEncoder
-try:
-    from transformers.models.clip.modeling_clip import _expand_mask
-except ImportError:
-    from transformers.modeling_attn_mask_utils import _prepare_4d_attention_mask as _expand_mask
+
+def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
+    """
+    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
+    """
+    bsz, src_len = mask.size()
+    tgt_len = tgt_len if tgt_len is not None else src_len
+
+    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
+
+    inverted_mask = 1.0 - expanded_mask
+
+    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
+

 # This is a modified version of the CLIPTextModel from transformers.models.clip.modeling_clip
 # Which allows for an extra input of "context embeddings", which are the query embeddings used in Qformer
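
Not part of the patch series: below is a minimal usage sketch of the `_expand_mask` helper that PATCH 2/2 vendors into modeling_ctx_clip.py (newer transformers releases dropped it from modeling_clip in favor of `_prepare_4d_attention_mask`). The function body is copied verbatim from the patch; the example batch and shapes are illustrative assumptions only.

import torch
from typing import Optional


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    # Same body as the helper added in PATCH 2/2.
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len
    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)
    inverted_mask = 1.0 - expanded_mask
    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# Two sequences of length 4; zeros mark padding tokens (illustrative values).
attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])
additive_mask = _expand_mask(attention_mask, torch.float32)
print(additive_mask.shape)  # torch.Size([2, 1, 4, 4])
# Attended positions hold 0.0; padded positions hold torch.finfo(torch.float32).min,
# so they contribute nothing after the softmax when added to attention scores.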