
Conversation

@jenkspt (Contributor) commented on Jul 21, 2022

This is a simple fix that removes unnecessary attention computation from the AttentionPool2d module.
In the existing version, self-attention is computed over the full spatial + average embedding sequence of shape [(HW+1), N, C]. In the proposed fix, attention is computed with the average embedding [1, N, C] as the query and the spatial + average embedding sequence [(HW+1), N, C] as the key/value.
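
As a rough sketch of the shape flow (simplified and only illustrative; the real AttentionPool2d builds on F.multi_head_attention_forward with its own projection parameters):

import torch
import torch.nn as nn

class SimpleAttentionPool2d(nn.Module):
    def __init__(self, spatial_dim: int, embed_dim: int, num_heads: int):
        super().__init__()
        self.positional_embedding = nn.Parameter(
            torch.randn(spatial_dim ** 2 + 1, embed_dim) / embed_dim ** 0.5)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [N, C, H, W] -> [(HW), N, C]
        x = x.flatten(start_dim=2).permute(2, 0, 1)
        # prepend the spatial average as an extra token -> [(HW+1), N, C]
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)
        x = x + self.positional_embedding[:, None, :]
        # query with only the average token instead of the full sequence
        x, _ = self.attn(x[:1], x, x, need_weights=False)
        return x.squeeze(0)  # [N, C]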

I created this gist: https://gist.github.com/jenkspt/3a09cc150ab531781c6084c166047639 to demonstrate the equivalence of the existing implementation and the proposed one. There is parity in both the computation and the parameter state -- so there shouldn't be any breaking changes introduced.
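
As a quick stand-in for the check in the gist (using a plain nn.MultiheadAttention in place of CLIP's functional attention call), the parity can be spot-checked like this:

import torch
import torch.nn as nn

HW, N, C, heads = 49, 2, 256, 8
attn = nn.MultiheadAttention(C, heads)
x = torch.randn(HW + 1, N, C)  # [average token, spatial tokens], sequence-first

# existing approach: attend over every position, then keep only position 0
pooled_old = attn(x, x, x, need_weights=False)[0][0]
# proposed approach: query with only the average token
pooled_new = attn(x[:1], x, x, need_weights=False)[0][0]

print(torch.allclose(pooled_old, pooled_new, atol=1e-6))  # expected: True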

I realize that AttentionPool2d is only used once in the CLIP model, so this fix will not have a huge impact. However, I arrived here from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L22-L51, which is based on the CLIP version (and has the same problem), so I think there is added benefit for posterity.

@jongwook merged commit f69a9bc into openai:main on Jul 21, 2022
@jongwook (Collaborator) commented

Thanks for the PR, a nice fix! There's a similar inefficiency in the last layer of the vision transformer, but it won't be as simple to fix as this PR.

@jenkspt deleted the fix-attention-pool2d branch on July 21, 2022 at 20:28
@jenkspt (Contributor, Author) commented on Jul 21, 2022

> Thanks for the PR, a nice fix! There's a similar inefficiency in the last layer of the vision transformer, but it won't be as simple to fix as this PR.

NP! -- are you referring to this?

CLIP/clip/model.py

Lines 185 to 187 in f69a9bc

def attention(self, x: torch.Tensor):
    self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
    return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

@jongwook (Collaborator) commented

That [0] was just to throw out the attn_output_weights returned by nn.MultiheadAttention. The inefficiency I mentioned is at:

CLIP/clip/model.py

Lines 231 to 235 in f69a9bc

x = x.permute(1, 0, 2)  # NLD -> LND
x = self.transformer(x)
x = x.permute(1, 0, 2)  # LND -> NLD

x = self.ln_post(x[:, 0, :])

where the vision encoder takes activations from just the CLS position as its output. But I don't think it needs a fix here anytime soon; was just noting!
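
For illustration, a minimal sketch of that observation (hypothetical standalone modules and a simplified pre-norm block, not CLIP's actual ResidualAttentionBlock): in the final block only the CLS position is consumed by ln_post, so its attention could in principle be computed with a single query.

import torch
import torch.nn as nn

C, heads, HW = 768, 12, 49
ln_1, ln_2 = nn.LayerNorm(C), nn.LayerNorm(C)
attn = nn.MultiheadAttention(C, heads)
mlp = nn.Sequential(nn.Linear(C, 4 * C), nn.GELU(), nn.Linear(4 * C, C))

x = torch.randn(HW + 1, 1, C)  # [CLS + patches, N, C], sequence-first
y = ln_1(x)

# current behaviour: update every position, then keep only the CLS output downstream
full = x + attn(y, y, y, need_weights=False)[0]
full = full + mlp(ln_2(full))

# query-only variant for the final block: attend with just the CLS token
cls = x[:1] + attn(y[:1], y, y, need_weights=False)[0]
cls = cls + mlp(ln_2(cls))

print(torch.allclose(full[0], cls[0], atol=1e-5))  # expected: True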

rom1504 pushed a commit to rom1504/CLIP that referenced this pull request on Jan 13, 2024:

* fix inefficient attention computation

* remove erroneous formatting

* simplified flatten

Co-authored-by: Jong Wook Kim <jongwook@nyu.edu>