Add clip skip for diffusion pipeline #3212

@NormXU

Description

Introduction

Clip skip is a trick that feeds features from an earlier layer of the CLIPTextModel into the cross-attention, instead of the features from the final layer. If clip_skip = 2, we use the features from the second-to-last layer of the CLIP text encoder to guide image generation. The current diffusion pipeline can be regarded as clip_skip = 1, i.e. it uses the features from the last layer of the CLIP text encoder.

Here is a brief introduction to clip skip in the webui wiki and a related discussion link.
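To make the idea concrete, here is a minimal sketch (not part of the proposed change) using a plain CLIPTextModel from transformers; the model id is the text encoder used by Stable Diffusion v1 and the prompt is just a placeholder. With output_hidden_states=True, clip_skip simply selects which hidden state is handed to the UNet:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

clip_skip = 2  # 1 = last layer (current pipeline behavior), 2 = second-to-last layer, ...

tokens = tokenizer("a photo of an astronaut", padding="max_length", return_tensors="pt")
with torch.no_grad():
    outputs = text_encoder(tokens.input_ids, output_hidden_states=True)

# hidden_states[0] is the embedding output, hidden_states[-1] is the last encoder layer
prompt_embeds = outputs.hidden_states[-clip_skip]
# the text model's final layer norm still has to be applied, as the override below does
prompt_embeds = text_encoder.text_model.final_layer_norm(prompt_embeds)

With clip_skip = 1 this reduces to outputs.last_hidden_state, which is what the pipeline uses today.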

A large majority of community models need clip_skip=2 to reach more aesthetically pleasing generations. I think adding this feature would give people more options for optimizing their generations.

Implementation

Adding clip_skip to diffusers is both simple and difficult. The main idea of clip_skip is simple; however, since our text encoder is imported from transformers, it is not easy to hack the CLIPTextModel from within diffusers.

To do so, we need to override CLIPTextModel and CLIPTextTransformer. Here is my implementation:

from typing import Optional, Tuple, Union

import torch

from transformers import CLIPTextConfig, CLIPTextModel
from transformers.modeling_outputs import BaseModelOutputWithPooling
from transformers.models.clip.modeling_clip import CLIPTextTransformer, _expand_mask


class MyCLIPTextTransformer(CLIPTextTransformer):

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        clip_skip: Optional[int] = 1,  # <-- newly added: take the clip_skip-th hidden state from the end of the encoder
    ) -> Union[Tuple, BaseModelOutputWithPooling]:

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if input_ids is None:
            raise ValueError("You have to specify input_ids")

        input_shape = input_ids.size()
        input_ids = input_ids.view(-1, input_shape[-1])

        hidden_states = self.embeddings(input_ids=input_ids, position_ids=position_ids)

        bsz, seq_len = input_shape
        # CLIP's text model uses causal mask, prepare it here.
        # https://github.com/openai/CLIP/blob/cfcffb90e69f37bf2ff1e988237a0fbe41f33c04/clip/model.py#L324
        causal_attention_mask = self._build_causal_attention_mask(bsz, seq_len, hidden_states.dtype).to(
            hidden_states.device
        )
        # expand attention_mask
        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            attention_mask = _expand_mask(attention_mask, hidden_states.dtype)

        # always request a dict from the encoder so that encoder_outputs.hidden_states
        # is available even when the caller asked for a tuple output
        encoder_outputs = self.encoder(
            inputs_embeds=hidden_states,
            attention_mask=attention_mask,
            causal_attention_mask=causal_attention_mask,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        last_hidden_state = encoder_outputs.hidden_states[-clip_skip]  # <-- newly added: take the clip_skip-th hidden state from the end
        last_hidden_state = self.final_layer_norm(last_hidden_state)

        # text_embeds.shape = [batch_size, sequence_length, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        # casting to torch.int for onnx compatibility: argmax doesn't support int64 inputs with opset 14
        pooled_output = last_hidden_state[
            torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
            input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
        ]

        if not return_dict:
            return (last_hidden_state, pooled_output) + encoder_outputs[1:]

        return BaseModelOutputWithPooling(
            last_hidden_state=last_hidden_state,
            pooler_output=pooled_output,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )

    def _build_causal_attention_mask(self, bsz, seq_len, dtype):
        # lazily create causal attention mask, with full attention between the vision tokens
        # pytorch uses additive attention mask; fill with -inf
        mask = torch.empty(bsz, seq_len, seq_len, dtype=dtype)
        mask.fill_(torch.tensor(torch.finfo(dtype).min))
        mask.triu_(1)  # zero out the lower diagonal
        mask = mask.unsqueeze(1)  # expand mask
        return mask


class MyCLIPTextModel(CLIPTextModel):
    config_class = CLIPTextConfig

    _no_split_modules = ["CLIPEncoderLayer"]

    def __init__(self, config: CLIPTextConfig):
        super().__init__(config)
        self.text_model = MyCLIPTextTransformer(config)  # <-- newly added: use the overridden CLIPTextTransformer
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        clip_skip: Optional[int] = 1,  # <-- newly added: take the clip_skip-th hidden state from the end of the encoder
    ) -> Union[Tuple, BaseModelOutputWithPooling]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        return self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            clip_skip=clip_skip,  # <-- newly added: forward clip_skip to the overridden text transformer
        )
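For context, here is a hedged sketch of how the overridden encoder could be loaded and plugged into a pipeline; MyCLIPTextModel is the class above, while the model id and the text_encoder override in from_pretrained follow the usual diffusers component-swapping pattern and are only illustrative:

import torch
from diffusers import StableDiffusionPipeline

# load the original text encoder weights into the overridden class defined above
text_encoder = MyCLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder", torch_dtype=torch.float16
)

# swap it into the pipeline; every other component stays unchanged
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", text_encoder=text_encoder, torch_dtype=torch.float16
).to("cuda")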

We can then use the overridden CLIP text encoder in the _encode_prompt function of any diffusers pipeline, for example in pipeline_stable_diffusion.py:

    def _encode_prompt(
        self,
        prompt,
        device,
        num_images_per_prompt,
        do_classifier_free_guidance,
        negative_prompt=None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        clip_skip: Optional[int] = 1,
    ):
        # ......  # unchanged code omitted

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device),
            clip_skip=clip_skip,  
            attention_mask=attention_mask,
        )
        prompt_embeds = prompt_embeds[0]

        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)

        # ......  # unchanged code omitted
        # this trick is usually applied to the prompt embedding only,
        # not to the negative prompt embedding
        negative_prompt_embeds = self.text_encoder(
            uncond_input.input_ids.to(device),
            attention_mask=attention_mask,
        )
        negative_prompt_embeds = negative_prompt_embeds[0]

        # ......  # unchanged code omitted

        return prompt_embeds
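For completeness, the pipeline's __call__ would also need to accept and forward the new argument. The snippet below is a hypothetical sketch of that plumbing (clip_skip is the proposed parameter, not an existing diffusers argument):

        # hypothetical: inside StableDiffusionPipeline.__call__, accept a clip_skip
        # parameter and forward it when encoding the prompt
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            clip_skip=clip_skip,  # <-- newly added
        )

The user-facing call would then look like pipe(prompt, clip_skip=2).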

Implementing CLIPTextTransformer and CLIPTextModel to support clip_skip cleanly and elegantly is difficult for me. I'd like to leave this issue to the diffusers team.
