
adding CoCa #256

Closed
wants to merge 123 commits into from

Conversation

@gpucce (Contributor) commented Nov 25, 2022

This PR adds the CoCa model as implemented in https://github.com/lucidrains/CoCa-pytorch, reusing existing parts as much as possible.

Ideally it will also add the possibility to choose between the custom and non-custom Attention implementations, as is done for CLIP.

@rom1504 (Collaborator) commented Nov 25, 2022

It would be best to see whether both the tower support and the losses can be unified with the existing code, so we can train with a variety of text towers and benefit from the existing efficient loss implementations without too much duplicated code.

@gpucce (Contributor, Author) commented Nov 25, 2022

Sure, I will try to reuse as much as I can. For now it is mostly copied from the coca-pytorch repo; I will probably ask for some help as I move on :)

@gpucce (Contributor, Author) commented Nov 27, 2022

@rom1504 I will reuse the visual model from open_clip. However, in coca-pytorch the transformer layers for the text model differ from the regular ones: the feed-forward and attention branches run in parallel. Do you prefer it like that, or regular layers? I have no idea how much difference it makes.

Even if I use the regular attention, I think the current implementation doesn't allow cross-attention. Would you prefer a separate CrossAttention layer, or adding the cross-attention option to the regular attention block via kwargs to its forward, as in the sketch below?
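
For illustration, a minimal sketch of the kwargs option: the usual residual self-attention block, extended with optional key/value inputs that switch it into cross-attention. Names and signatures here are assumptions, not the existing open_clip block:

```python
import torch
from torch import nn

class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int):
        super().__init__()
        self.ln_1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_2 = nn.LayerNorm(d_model)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, 4 * d_model), nn.GELU(), nn.Linear(4 * d_model, d_model)
        )

    def forward(self, x, k_x=None, v_x=None, attn_mask=None):
        q = self.ln_1(x)
        # With k_x/v_x given, attend over those tokens (cross-attention);
        # otherwise this is plain self-attention.
        k = k_x if k_x is not None else q
        v = v_x if v_x is not None else q
        x = x + self.attn(q, k, v, attn_mask=attn_mask, need_weights=False)[0]
        x = x + self.mlp(self.ln_2(x))
        return x
```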

@rom1504 (Collaborator) commented Nov 27, 2022

I think we should bring options into the current text model so it supports CoCa:

  • Parallel feed-forward and self-attention can be added as an option; see the sketch after this list. I don't think it will make a huge difference for these relatively small models. GPT-J and PaLM used this mechanism to improve speed at large scale.
  • Cross-attention is indeed the important feature. Bringing it to the existing model would be great, and that should be possible for our text attention implementation. For HF encoders there's a chance it's already implemented, in which case we can just use their implementation.
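
A minimal sketch of that parallel block (GPT-J / PaLM style); all names here are illustrative, not the final implementation:

```python
import torch
from torch import nn

class ParallelBlock(nn.Module):
    """Transformer block where self-attention and the MLP both read the same
    normalized input and their outputs are summed, instead of running
    sequentially. A sketch under assumed names, not open_clip code."""

    def __init__(self, d_model: int, n_head: int, mlp_ratio: float = 4.0):
        super().__init__()
        self.ln = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_head, batch_first=True)
        hidden = int(d_model * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, d_model),
        )

    def forward(self, x: torch.Tensor, attn_mask=None) -> torch.Tensor:
        y = self.ln(x)  # one shared pre-norm instead of two sequential ones
        attn_out, _ = self.attn(y, y, y, attn_mask=attn_mask, need_weights=False)
        return x + attn_out + self.mlp(y)  # branches computed in parallel
```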

Thanks for working on this!

@gpucce (Contributor, Author) commented Dec 1, 2022

@rom1504 I am moving forward. If you have time, could you take a look at how the cross-attention and decoder are added to the existing models, to see whether the integration is going in a reasonable direction?

@rom1504 (Collaborator) commented Dec 18, 2022

Another bonus-feature idea (probably not for this PR): support many HF decoders for the "multimodal transformer" that got added here.

@iejMac (Contributor) commented Dec 18, 2022

@gpucce Do you think you could give me push access to your fork? I'd love to help out, but I don't want you to have to manually merge all of my suggested changes each time I make them.

@gpucce (Contributor, Author) commented Dec 18, 2022

> Do you think you could give me push access to your fork? I'd love to help out, but I don't want you to have to manually merge all of my suggested changes each time I make them.

Sure, I will do it as soon as I am at a computer.

@rwightman (Collaborator) commented

> would appreciate your review @rwightman if you think anything big needs to be done

Made some code review comments; the most important points:

  • Make the CoCa model dual tower (.visual, .text) from the start so builtin and HF text transformer models are handled the same way; there is no backward compatibility to worry about.
  • Revert the changes to Attention / CustomResidualAttentionBlock; we should avoid using them with the CoCa model.
  • Move all models to output dicts instead of tuples. That makes it more flexible to pass outputs through optional losses / filters without brittle indexing and tuple-length checks that are bound to fail with one more addition. The dict output could be made optional at construction time to prevent possible breaks for other downstream users (see the sketch after this comment).

and then test test test.
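
A rough sketch of the dict-output idea from the third point; the class shape, field names, and the output_dict constructor flag are assumptions for illustration, not the final API:

```python
import torch
from torch import nn
import torch.nn.functional as F

class CLIPWithDictOutput(nn.Module):
    # Returns a dict when output_dict=True, so losses and filters can select
    # outputs by name instead of relying on brittle tuple positions.
    def __init__(self, visual: nn.Module, text: nn.Module, output_dict: bool = False):
        super().__init__()
        self.visual = visual
        self.text = text
        self.logit_scale = nn.Parameter(torch.tensor(2.659))  # ~ln(1/0.07), as in CLIP
        self.output_dict = output_dict

    def forward(self, image, text):
        image_features = F.normalize(self.visual(image), dim=-1)
        text_features = F.normalize(self.text(text), dim=-1)
        if self.output_dict:
            # New outputs (e.g. CoCa caption logits) can be added later
            # without breaking callers that look fields up by name.
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp(),
            }
        return image_features, text_features, self.logit_scale.exp()  # legacy tuple
```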

@gpucce (Contributor, Author) commented Dec 18, 2022

Thanks for the review @rwightman. @iejMac had raised a point similar to the second one. To coordinate with everyone (@rom1504), since the list of todos is getting longer: I am planning to address all of them, but I won't be able to proceed too fast. If someone is already working on any of them to speed up the whole process, please let me know.

Otherwise I will implement everything as suggested, taking a bit of time, and in general will start with the generative part.

@rom1504 (Collaborator) commented Dec 18, 2022

Yes, I think starting with implementing captioning will give us confidence that things are working.

            else LayerNorm
        )

        text = _build_input_dependent_text_tower(embed_dim, text_cfg, quick_gelu, cast_dtype, multimodal=False)
Collaborator: could be self.text here

Contributor (Author): I think adding this would make the state_dict of the model you have just trained incompatible, is that fine?

Collaborator: yes, we'll retrain

        text_embs = text_embs.permute(1, 0, 2)    # NLD -> LND
        image_embs = image_embs.permute(1, 0, 2)  # NLD -> LND

        for r, ca in zip(self.resblocks, self.cross_attn):
Collaborator: could you name those resblock and cross_attn? r and ca are a bit confusing.
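
For reference, the loop with descriptive names might read as follows (the cross-attention call signature is an assumption):

```python
for resblock, cross_attn in zip(self.resblocks, self.cross_attn):
    text_embs = resblock(text_embs, attn_mask=attn_mask)               # causal self-attention
    text_embs = cross_attn(text_embs, k_x=image_embs, v_x=image_embs)  # attend over image tokens
```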

        mask.triu_(1)  # zero out the lower diagonal
        return mask

    def forward(self, image_embs, text_embs):
Collaborator: could you add a comment explaining the shapes of those? I think it'll help.
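
For example, the requested shape documentation, together with the mask builder above, could look like this (shapes inferred from the surrounding permutes; a sketch, not the PR's exact code):

```python
import torch

def build_attention_mask(num_tokens: int) -> torch.Tensor:
    # Additive causal mask: 0 on and below the diagonal, -inf above it,
    # so each text position attends only to itself and earlier tokens.
    mask = torch.empty(num_tokens, num_tokens)
    mask.fill_(float("-inf"))
    mask.triu_(1)  # zero out the lower diagonal
    return mask

# forward(self, image_embs, text_embs) would then document:
#   image_embs: [N, L_img, D] image tokens from the visual tower
#   text_embs:  [N, L_txt, D] text tokens to be decoded with cross-attention
```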

    def _repeat(self, t, N):
        return t.reshape(1, 1, -1).repeat(N, 1, 1)

    def encode_text(self, text, normalize=True, return_tokens=False):
Collaborator: same comment as for the visual tower, can we use the text tower much more?

Contributor (Author): This one is a bit harder than the visual one, I think.

Collaborator: What is missing? Don't we simply need to add that return-tokens option in the text encoder too?
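
One way the return-tokens option could look in a CLIP-style text encoder; the attribute names are assumed to mirror the usual open_clip text forward, not this PR's exact code:

```python
import torch
import torch.nn.functional as F

def encode_text(self, text, normalize: bool = True, return_tokens: bool = False):
    x = self.token_embedding(text) + self.positional_embedding
    x = x.permute(1, 0, 2)     # NLD -> LND
    x = self.transformer(x)
    x = x.permute(1, 0, 2)     # LND -> NLD
    tokens = self.ln_final(x)  # [N, L, D] per-token features
    # Pool at the eot token (highest token id in CLIP's vocabulary).
    pooled = tokens[torch.arange(tokens.shape[0]), text.argmax(dim=-1)] @ self.text_projection
    if normalize:
        pooled = F.normalize(pooled, dim=-1)
    return (pooled, tokens) if return_tokens else pooled
```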

@AwePhD commented Dec 20, 2022

Hello,

I would like to pre-train CoCa, and from the CoCa implementation's repo I found this PR/branch. I am not familiar with the OpenCLIP code base, but I think it could be a good opportunity for me to get my hands dirty. I cannot tell whether you would appreciate help developing this feature or not, so let me know; I would be glad to participate :)

@rom1504 (Collaborator) commented Dec 20, 2022

@AwePhD help would definitely be appreciated. See the comments above for what we need.

You can open PRs against the branch of this PR.

@Soonhwan-Kwon commented

I also want to help. I've done many experiments with the CoCa model on most public datasets, including caption generation (though without HF compatibility).

@rom1504 (Collaborator) commented Dec 20, 2022 via email

@rom1504 (Collaborator) commented Dec 20, 2022 via email

@gpucce (Contributor, Author) commented Dec 20, 2022

Nice that it gets the same performance!

If you can wait a moment before merging, I should shortly be able to simplify the logic for the visual part; for the text part, more things would need changing.

@gpucce (Contributor, Author) commented Dec 20, 2022

And if you can somehow share the model, that would be very useful.

@@ -160,19 +155,31 @@ def encode_image(self, images, normalize=True, return_tokens=False):
    def _repeat(self, t, N):
        return t.reshape(1, 1, -1).repeat(N, 1, 1)

    def _build_cls_mask(self, text, cast_dtype):
Collaborator: What should be the impact of this change?

Contributor (Author): I think right now the cls token at the end can attend to pad tokens in the sequence; this should not be possible with this extra mask.

Collaborator: sounds good!
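
For context, a minimal sketch of such a cls mask, assuming the cls token is appended after the sequence and a pad id of 0 (not necessarily the PR's exact code):

```python
import torch

def build_cls_mask(text: torch.Tensor, pad_id: int = 0) -> torch.Tensor:
    # text: [N, L] token ids. Returns an additive mask [N, L+1, L+1] with
    # -inf in the columns of pad positions, so no query position (including
    # the appended cls token) can attend to padding.
    N, L = text.shape
    valid = torch.ones(N, L + 1, dtype=torch.bool)  # extra slot for the cls token
    valid[:, :L] = text != pad_id                   # True on real tokens
    mask = torch.zeros(N, L + 1, L + 1)
    mask.masked_fill_(~valid[:, None, :], float("-inf"))
    return mask
```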

@gpucce (Contributor, Author) commented Dec 20, 2022

@rom1504 the visual part should be simpler now, so you can make the new branch now. I will keep working on the generative part as soon as I have time.

@@ -465,6 +465,9 @@ def forward(self, x: torch.Tensor):
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD

+       if output_tokens:
+           return x
+
        if self.global_average_pool:
            x = x.mean(dim=1)
Collaborator: I'm wondering if this can be done after the ln_post. If yes, it would make it possible to do the ln_post only here and not in CoCa.
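
In other words, something like this ordering (a sketch of the suggestion using the names from the diff above, not tested code):

```python
x = self.transformer(x)
x = x.permute(1, 0, 2)  # LND -> NLD
x = self.ln_post(x)     # normalize the full token sequence once, here

if output_tokens:
    return x            # callers such as CoCa receive already-normalized tokens

if self.global_average_pool:
    x = x.mean(dim=1)
```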

-       x = x.permute(1, 0, 2)  # NLD -> LND
-       x = self.visual.transformer(x)
-       x = x.permute(1, 0, 2)  # LND -> NLD
+       x = self.visual(images, output_tokens=True)
        x = self.visual.ln_post(x)
Collaborator: this ln_post call makes a big assumption about the API of the visual encoder.

@rom1504 mentioned this pull request Dec 20, 2022
@rom1504 (Collaborator) commented Dec 20, 2022

@gpucce I merged into the coca branch. I excluded the 2 last commits you added today to avoid discrepancies with our trained model.
Can you please open a PR with your 2 last commits, and with any further improvements, targeting the coca branch rather than main?
Thanks

All the comments made here remain valid.

@rom1504 closed this Dec 20, 2022
@rom1504 mentioned this pull request Dec 20, 2022
@rom1504 (Collaborator) commented Dec 20, 2022

Please refer to #308 for follow-ups.
