Description
Introduction
Clip skip is a trick that feeds early-stopped features encoded by `CLIPTextModel` into the cross-attention layers. If `clip_skip = 2`, it means we want to use the features from the layer before the last layer of the CLIP text encoder to guide our image generation. Our current diffusion pipelines can be regarded as `clip_skip = 1`, i.e. they simply use the features from the last layer of the CLIP text encoder.
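To make the indexing concrete, here is a minimal sketch (not the proposed implementation) of what `clip_skip` selects, using a plain `transformers` `CLIPTextModel`; the checkpoint id and variable names below are only illustrative:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

tokens = tokenizer("a photo of an astronaut", padding="max_length", return_tensors="pt")
with torch.no_grad():
    out = text_encoder(tokens.input_ids, output_hidden_states=True)

clip_skip = 2
# hidden_states[-1] is the last encoder layer (clip_skip = 1, the current default behaviour);
# hidden_states[-2] is the layer before the last (clip_skip = 2), and so on.
skipped = out.hidden_states[-clip_skip]
# The final layer norm still has to be applied before feeding the features into
# the UNet cross-attention, as done in the implementation below.
prompt_features = text_encoder.text_model.final_layer_norm(skipped)
```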
Here is a brief introduction to clip skip in the webui-wiki and a related discussion link.
A large majority of community models need `clip_skip = 2` to reach a more aesthetic generation, so I think adding this feature would give people more choices to optimize their generations.
Implementation
Adding `clip_skip` to diffusers is both simple and difficult. The main idea of `clip_skip` is simple; however, since the text encoder is imported from transformers, it is not easy to hack `CLIPTextModel` from within diffusers.
To do so, we need to override `CLIPTextModel` and `CLIPTextTransformer`. Here is my implementation:
```python
from typing import Optional, Tuple, Union

import torch
from transformers import CLIPTextConfig, CLIPTextModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
# _expand_mask is a private helper of the older transformers versions this snippet
# was written against; newer releases expose different mask utilities.
from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask


class MyCLIPTextTransformer(CLIPTextTransformer):
    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        clip_skip: Optional[int] = 1,  # <-- newly added: take the N-th-from-last encoder layer as output
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        bsz, seq_len = input_shape
        # CLIP's text model uses a causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,  # always collect hidden states so clip_skip can index into them
            return_dict=return_dict,
        )

        last_hidden_state = encoder_outputs.hidden_states[-clip_skip]  # <-- newly added: take the N-th-from-last encoder layer as output
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask


class MyCLIPTextModel(CLIPTextModel):
    config_class = CLIPTextConfig
    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = MyCLIPTextTransformer(config)  # <-- newly added: use the overridden text transformer
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        clip_skip: Optional[int] = 1,  # <-- newly added: take the N-th-from-last encoder layer as output
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            clip_skip=clip_skip,  # <-- newly added: forward clip_skip to the text transformer
        )
```
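For reference, the subclass above can be loaded in place of a pipeline's stock text encoder. Below is a hedged usage sketch; the checkpoint id is only an example, and the pipeline still needs the `_encode_prompt` change shown next to actually forward `clip_skip`:

```python
from diffusers import StableDiffusionPipeline

# Load the overridden text encoder from a Stable Diffusion checkpoint
# (the model id is illustrative only).
text_encoder = MyCLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", text_encoder=text_encoder
)
```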
We can then use the overridden CLIP text encoder in the `_encode_prompt` function of any diffusers pipeline. For example, in `pipeline_stable_diffusion.py`:
```python
def _encode_prompt(
    self,
    prompt,
    device,
    num_images_per_prompt,
    do_classifier_free_guidance,
    negative_prompt=None,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
    clip_skip: Optional[int] = 1,  # <-- newly added
):
    # ...... # unchanged code omitted
    prompt_embeds = self.text_encoder(
        text_input_ids.to(device),
        clip_skip=clip_skip,  # <-- newly added
        attention_mask=attention_mask,
    )
    prompt_embeds = prompt_embeds[0]
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
    # ...... # unchanged code omitted
    # this trick is usually applied to the prompt embedding only, not the negative prompt embedding
    negative_prompt_embeds = self.text_encoder(
        uncond_input.input_ids.to(device),
        attention_mask=attention_mask,
    )
    negative_prompt_embeds = negative_prompt_embeds[0]
    # ...... # unchanged code omitted
    return prompt_embeds
```
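With both changes in place, usage could look like the sketch below. This assumes the pipeline's `__call__` is also updated to accept a `clip_skip` argument and pass it through to `_encode_prompt`, which is not shown in this issue:

```python
# Hypothetical call, assuming __call__ forwards clip_skip to _encode_prompt.
image = pipe("a fantasy landscape, trending on artstation", clip_skip=2).images[0]
image.save("clip_skip_2.png")
```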
Implementing `CLIPTextTransformer` and `CLIPTextModel` to support `clip_skip` cleanly and nicely is difficult for me, so I'd like to leave this issue to the diffusers team.