
Add coca trained (#307) #308

Merged: rom1504 merged 34 commits into main from coca on Jan 29, 2023
Conversation

@rom1504 (Collaborator) commented Dec 20, 2022

Proposal to merge coca into main.

This will be merged only once we are happy with results.

Reminder of what is still needed, as mentioned in #256:

Priority:

Also important:

  • make sure normal CLIP is still working the same (Add non regression tests #198 can help, but otherwise just doing runs)
  • move all models to output dicts instead of tuples; this makes it more flexible to pass outputs through optional losses / filters without brittle indexing and tuple-length checks that are bound to fail with one more addition. Could make the dict output optional at construction time to prevent possible breaks for other downstream users (see the sketch below)
  • report the captioning loss and contrastive loss independently in wandb
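
To make the second point concrete, here is a minimal sketch (hypothetical class and field names, not the actual open_clip implementation) of an output-dict option selected at construction time, so existing tuple-based callers keep working:

```python
import torch
import torch.nn as nn


class ToyCLIP(nn.Module):
    """Illustrative stand-in for the real model classes, not the open_clip API."""

    def __init__(self, output_dict: bool = False):
        super().__init__()
        self.output_dict = output_dict
        self.logit_scale = nn.Parameter(torch.ones([]) * 2.659)

    def forward(self, image_features: torch.Tensor, text_features: torch.Tensor):
        if self.output_dict:
            # Named outputs: losses/filters pick what they need, and new fields
            # (e.g. caption logits) can be added without breaking positional indexing.
            return {
                "image_features": image_features,
                "text_features": text_features,
                "logit_scale": self.logit_scale.exp(),
            }
        # Legacy tuple output for downstream users that index positionally.
        return image_features, text_features, self.logit_scale.exp()
```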

@gpucce is the main contributor on this
@iejMac is helping

I will review along with @rwightman

@AwePhD and @Soonhwan-Kwon have offered to help

We can train some models to validate that things work.

This is a first model, https://huggingface.co/laion/CoCa-ViT-B-32-laion2B-s13B-b90k/blob/main/epoch_95.pt, that gets the same performance as CLIP (a bit better); see the report at https://wandb.ai/rom1504/open-clip/reports/CoCa-ViT-B-32-v1-vs-clip--VmlldzozMTgzNTEy (63.79% for CoCa vs 63.57% for a CLIP with the same settings).

For further improvements, please open PRs targeting the coca branch

This PR will be merged to main only when everything is good

* initial setup

* add coca loss

* remove loss from the model

* fix loss

* add underscores

* name changes

* add cross attention to Residual and CustomResidual

* fix if

* add transformer 'decoder'

* minor fix

* looks better

* initlize coca model structure

* clean

* typo and format

* checkpoint signature

* adjust multimodal decoder and add CoCaTransformer

* keep older logic

* remove chunk

* typo

* fix

* make chunk dim explicit

* adjust cfg names

* add attentionalpooling

* add attentional pooling to coca

* small change

* add cocatransformer variants and AttentionPooling

* remove older attention pooler

* adapt embed text to coca text transformer

* rm coca layers

* rename and remove useless CoCa models

* make attentionpooler pooler only

* refactor for one transformer only

* coca forward works

* separate context and n_queries

* add initial coca_base config

* remove config

* small loss change

* init training file

* make variable order right

* remove print

* uniform names

* renaming

* add coca funcs to init

* add coca config and exclude from testing

* add and comment simple test (no trained model)

* add L2 norm

* make L2 same as in clip

* remove unused temperature

* type

* clean

* fix config

* make rename and move cfg

* rename

* temptative add coca to factory

* fix config

* update config

* embed contrastive cls token in model

* remove unused arg

* import create_loss

* make factory accept coca

* make caption loss distributed

* make loss customizable

* pass loss through training_epoch

* add coca specific params to params

* removed decoder unused parameters

* remove unused attributes

* adjust coca_config

* fix config and remove unused parameters

* remove comment

* remove more comments

* rename attention pooler

* rename TransformerDecoder

* make AttentionalPooler clearer

* add local loss logic to cocaloss

* only create loss if train in data

* remove wrong file

* fix attentional pooler call

* not ready for testing

* really not ready for testing

* eof line

* uniform names

* add possible generative loss to evaluate

* change _build function names

* remove wrong import

* remove local_loss from captioning loss

* indexing error

* finish renaming

* adjust configs

* add training test for coca

* simplify captioning loss

* remove hf

* fix evaluate and loss

* remove print

* move projection

* add coca vit 32 config

* test on new config

* adjust coca_base config

* remove coca from test_inference

* maybe fix regression test

* make logits and labels contiguous

* simpler logic

* make contiguous after transpose

* last test

* try fix loss

* CoCa PR: loss fix + rename file

* wait for feedback on this

* cleanup

* CoCa PR: add set_grad_checkpointing + fix checkpoint API

* CoCa PR: fix eval (which uses encode_x instead of forward)

* move making space for CLS token into encode_text

* revert zs changes + fix

Co-authored-by: gpucce <g.puccetti92@gmail.com>
Co-authored-by: gpucce <g.puccetti@gmail.com>
Co-authored-by: iejmac <iejmac@ip-172-31-44-155.ec2.internal>
@rom1504 mentioned this pull request Dec 20, 2022

@rom1504 (Collaborator, Author) commented Dec 20, 2022

The thing that would be most important to work on right now is captioning.
Using https://github.com/lucidrains/x-transformers/blob/main/x_transformers/autoregressive_wrapper.py should make it possible while using this model https://huggingface.co/laion/CoCa-ViT-B-32-laion2B-s13B-b90k/blob/main/epoch_95.pt

We did some preliminary testing without proper AR sampling and it seems to do something, which is good.

Once this gets implemented, the basics are down and we can focus on the cleanup tasks outlined above (a rough sketch of the sampling loop is below).
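
To make this concrete, here is a minimal greedy autoregressive sampling loop. The `decode_step` hook is hypothetical (the real integration would go through the x-transformers AutoregressiveWrapper or the generate logic added in this PR), so treat it as a sketch of the idea rather than the actual API:

```python
import torch


@torch.no_grad()
def greedy_caption(model, image, sot_token: int, eot_token: int, max_len: int = 76):
    # `model.decode_step(image, tokens)` is a hypothetical hook returning
    # next-token logits of shape (1, vocab_size) given the image and the tokens so far.
    tokens = torch.tensor([[sot_token]], device=image.device)
    for _ in range(max_len):
        logits = model.decode_step(image, tokens)
        next_token = logits.argmax(dim=-1, keepdim=True)  # greedy pick
        tokens = torch.cat([tokens, next_token], dim=1)
        if next_token.item() == eot_token:
            break
    return tokens
```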

@rwightman (Collaborator) commented:

Also pinging @gmittal and @apsdehal from FLAVA (#218); it looks like this will be a big merge prior to FLAVA being ready, so some decisions here will impact that PR. One of those is a move to dict outputs from the model (I believe FLAVA has already done that).

With dict outputs we should be able to have a fairly clean custom loss + logging mechanism that doesn't pollute the main script and works for both (rough sketch below).
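
As a rough sketch of what that could look like (hypothetical dict keys, not the final open_clip loss API), the loss returns named components and the train script logs each one:

```python
import torch
import torch.nn.functional as F


def coca_style_loss(model_out: dict, labels: torch.Tensor, caption_weight: float = 2.0):
    # Hypothetical keys, following the dict-output convention discussed above.
    logits_per_image = model_out["logit_scale"] * (
        model_out["image_features"] @ model_out["text_features"].t()
    )
    targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)
    contrastive = (
        F.cross_entropy(logits_per_image, targets)
        + F.cross_entropy(logits_per_image.t(), targets)
    ) / 2
    # Caption logits assumed to be (batch, seq, vocab); labels are (batch, seq).
    caption = F.cross_entropy(model_out["logits"].permute(0, 2, 1), labels)
    losses = {"contrastive_loss": contrastive, "caption_loss": caption_weight * caption}
    losses["loss"] = sum(losses.values())
    return losses


# In the train loop each component can then be logged separately, e.g.:
#   wandb.log({k: v.item() for k, v in losses.items()}, step=step)
```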

rom1504 and others added 6 commits December 21, 2022 18:30
Co-authored-by: Romain Beaumont <romain.rom1@gmail.com>
* build_cls_mask

* add cls_mask to encode_text

* add model properties

Co-authored-by: Romain Beaumont <romain.rom1@gmail.com>
Co-authored-by: gpucce <g.puccetti@gmail.com>
* add ignore_index

* just need to pick right index

Co-authored-by: gpucce <g.puccetti@gmail.com>
* add initial generative support

* make generation context_length independent

* remove kwargs

* last positional embeddings for CLS

* typo

* fix mask len

* add comment

* remove unused args

* simpler logic for input shorter than context length

Co-authored-by: gpucce <g.puccetti@gmail.com>
@Soonhwan-Kwon commented:

@gpucce @rom1504 I'm in the middle of writing code for beam search and have a question: is https://huggingface.co/laion/CoCa-ViT-B-32-laion2B-s13B-b90k/blob/main/epoch_95.pt now outdated or not?

@gpucce (Contributor) commented Dec 23, 2022

> @gpucce @rom1504 I'm in the middle of writing code for beam search and have a question: is https://huggingface.co/laion/CoCa-ViT-B-32-laion2B-s13B-b90k/blob/main/epoch_95.pt now outdated or not?

@Soonhwan-Kwon it is still the only one there is, and it should work almost as well as when it was trained. The only change introduced since it was trained is attention masking of padded tokens, which only has a small effect on performance.

@gpucce (Contributor) commented Dec 23, 2022

@rom1504 I think adding https://github.com/sks3i/pycocoevalcap to the dependencies would make adding CIDEr and other captioning metrics very easy (typical usage sketched below). Is it fine to add it? Otherwise I can add similar code to an eval_generation_utils.py file.
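
For reference, the usual pycocoevalcap flow looks roughly like this, assuming the package exposes the standard COCOEvalCap interface and that generated captions have been written to a COCO-format results JSON (the paths below are placeholders):

```python
from pycocotools.coco import COCO
from pycocoevalcap.eval import COCOEvalCap

coco = COCO("annotations/captions_val2017.json")    # ground-truth annotations
results = coco.loadRes("generated_captions.json")   # [{"image_id": ..., "caption": ...}, ...]

coco_eval = COCOEvalCap(coco, results)
coco_eval.params["image_id"] = results.getImgIds()  # score only the generated images
coco_eval.evaluate()

for metric, score in coco_eval.eval.items():
    print(f"{metric}: {score:.4f}")                 # Bleu_1..4, METEOR, ROUGE_L, CIDEr, SPICE
```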

@gpucce (Contributor) commented Dec 23, 2022

Initial, super-naive generation evaluation on COCO (val2017):

| Bleu_1 | Bleu_2 | Bleu_3 | Bleu_4 | METEOR | ROUGE_L | CIDEr | SPICE |
|---|---|---|---|---|---|---|---|
| 0.254451 | 0.0968875 | 0.0283395 | 0.00957243 | 0.0941467 | 0.208789 | 0.108593 | 0.0592597 |

@Soonhwan-Kwon commented Dec 24, 2022

beam_search(num_beams=6) evaluation on COCO (val2017); a generic beam-search sketch follows the table:

| Bleu_1 | Bleu_2 | Bleu_3 | Bleu_4 | METEOR | ROUGE_L | CIDEr | SPICE |
|---|---|---|---|---|---|---|---|
| 0.213 | 0.121 | 0.069 | 0.039 | 0.094 | 0.225 | 0.204 | (will be updated) |
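
For illustration only, a generic beam-search loop over the same hypothetical decode_step(image, tokens) hook used in the greedy sketch above; this is not the PR's generate_beamsearch implementation:

```python
import torch
import torch.nn.functional as F


@torch.no_grad()
def beam_search(model, image, sot: int, eot: int, num_beams: int = 6, max_len: int = 76):
    device = image.device
    beams = [(torch.tensor([sot], device=device), 0.0)]  # (tokens, cumulative log-prob)
    finished = []
    for _ in range(max_len):
        candidates = []
        for tokens, score in beams:
            if tokens[-1].item() == eot:
                finished.append((tokens, score))
                continue
            logits = model.decode_step(image, tokens.unsqueeze(0))[0]  # (vocab,)
            log_probs = F.log_softmax(logits, dim=-1)
            top_lp, top_idx = log_probs.topk(num_beams)
            for lp, idx in zip(top_lp, top_idx):
                candidates.append((torch.cat([tokens, idx.view(1)]), score + lp.item()))
        if not candidates:  # every beam has emitted eot
            break
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = candidates[:num_beams]
    finished.extend(beams)
    # Return the highest-scoring sequence (length normalization is a common refinement).
    return max(finished, key=lambda c: c[1])[0]
```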

gpucce and others added 2 commits January 6, 2023 01:16
* use self.text in encode image

* unused var

* revert Attention and CustomResidualAttentionBlock

* remove whiteline

* add dict output

* integrate self.text attributes

* HF compatibility

* better config and minor fixes

* clean

* remove embed_cls option from HF

* use cls_token_position

* fix cls masking

* resize labels

* text -> self.text

* split loss logging

* add total loss

* minor logs formatting

* fix generate

* simpler logic

* disentangle proj for HF too

* adjust config

* only norm cls

* move attn_pool to VisionTransformer

* adjust coca_base config

* fix grad checkpointing in MultimodalTransformer

Co-authored-by: gpucce <g.puccetti@gmail.com>
Co-authored-by: iejMac <kilianmaciej6@gmail.com>
```python
        attn_mask = attn_mask.to(x.dtype) if attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=attn_mask)[0]

    def attention(
            self,
```
@rom1504 (Collaborator, Author) commented on the diff:

looks like this is not changing anything, let's revert it (same below)

Agreed, @gpucce @iejMac?

@rom1504 (Collaborator, Author) commented Jan 6, 2023

looks good now

I think we need this now:

@rom1504 (Collaborator, Author) commented Jan 6, 2023

would appreciate a second look @rwightman

@rwightman (Collaborator) commented:

> would appreciate a second look @rwightman

k, will take a closer look soon.

A higher-level question: are we satisfied with the CoCa text-gen results so far? I assume those text-gen numbers posted so far are on a 1/100 scale, so the CIDEr is 20.4? Is that on the low side for zero-shot, or good for a B/32 capacity model?

Do we have a target level of text gen performance to verify before merge? Or are we happy with current signs of life as long as it's cleaned and doesn't break CLIP training?

@rom1504 (Collaborator, Author) commented Jan 6, 2023

I think the zero-shot captioning is mostly terrible compared to https://paperswithcode.com/sota/image-captioning-on-coco-captions, where they get CIDEr numbers like 150, but using much larger decoders.

It would definitely be ideal to fix that first, but I'm still thinking about whether that's a blocker for merging.

@rom1504 (Collaborator, Author) commented Jan 7, 2023

This repo is moving fast these days, so we should try to either get this merged or get most of the refactoring merged in the next week.
Otherwise there will be merge conflicts everywhere and it'll be annoying.

@rwightman (Collaborator) commented:

@rom1504 this is still the most significant change, so I don't see the rush, but I will start picking through it; conflicts are small relative to the commitment to maintain this and continue getting it to work well. Is everyone involved definitely going to continue dev and get it to a usable/worthwhile point after merge? If it's merged and left with the current result, it'd be a net drag.

Without having dug in fully yet, my main concern is that this might break use of the models in other train scripts. We should ensure that any changes to the core modelling interface (i.e. dict outputs, tuple changes, etc.) are opt-in, so that notebooks and scripts out there just using 'open_clip' (modelling) won't all break ...

@rom1504 (Collaborator, Author) commented Jan 7, 2023

> Without having dug in fully yet, my main concern is that this might break use of the models in other train scripts. We should ensure that any changes to the core modelling interface (i.e. dict outputs, tuple changes, etc.) are opt-in, so that notebooks and scripts out there just using 'open_clip' (modelling) won't all break ...

We should add any method we consider supported/exposed to the tests to make sure they keep working; that's the only reasonable path.
That, and making it clear what is exposed and what is not (a sketch of such a test is below).
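
A sketch of the kind of non-regression test meant here, assuming the supported surface is roughly create_model_and_transforms / tokenize / encode_image / encode_text (adjust to whatever ends up declared as exposed):

```python
import pytest
import torch

import open_clip


@pytest.mark.parametrize("model_name", ["ViT-B-32"])
def test_exposed_encode_api_still_works(model_name):
    # Random weights are fine here; we only check that the exposed API keeps working.
    model, _, _ = open_clip.create_model_and_transforms(model_name, pretrained=None)
    text = open_clip.tokenize(["a diagram"])
    image = torch.randn(1, 3, 224, 224)

    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Shapes of the exposed outputs should not silently change.
    assert image_features.shape[0] == 1
    assert text_features.shape[0] == 1
```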

@rom1504 (Collaborator, Author) commented Jan 7, 2023

Regarding what other things it's going to conflict with: most likely text-text CLIP, FLAVA and FSDP. Those are less ready than CoCa, but it's better to get things into a good state early.

My remark was more in the direction of: we should either figure out how to merge this early, or, if we can't, then it's better to extract the non-CoCa-specific parts into another PR and merge those first, to make things easier. And I was saying that more for contributors than reviewers.

@gpucce (Contributor) commented Jan 7, 2023

@rom1504 @rwightman For info, I will continue working on this until it is at a good point, for as long as needed (hoping my effort gets the whole thing somewhere), regardless of the merge.

Besides this, I think there should be one run going; when it finishes, if the CLIP evaluation is OK, I will try to fine-tune on the COCO train dataset to replicate the evaluation as it is done in the CoCa paper and see if we get somewhere closer to the original.

EDIT: if you can share the model

@rwightman (Collaborator) commented:

If there's a commitment from @gpucce @iejMac to keep pushing this to a desired performance level once it's merged (and not just let it dangle), we can make a push to try to merge in the coming week. Need to figure out the bwd-compat concerns, and also take a closer look at the direction FLAVA is headed w.r.t. the dictionary outputs and loss fn, to ensure we have compatible directions there...

@rwightman (Collaborator) commented:

@gpucce thanks! Looks great and merged into this PR. About to merge the JIT tweaks before I do some more local testing; I think that JIT approach is good for now.

@rom1504 (Collaborator, Author) commented Jan 27, 2023

Are there any next steps here?
Any reason not to press merge?

@rwightman (Collaborator) commented Jan 27, 2023

@rom1504 I think it's pretty much ready; I suppose it should go through one more check of the JIT stuff, since I haven't re-tested since the last few additions.

I was watching the L/14 run that was active to see if there is any reason to be concerned

EDIT: I'll bring this up to date with the main branch and re-check JIT tomorrow, maybe merge sometime before Monday?

@gpucce (Contributor) commented Jan 27, 2023

> Are there any next steps here?
>
> Any reason not to press merge?

Maybe finish generate_beamsearch? I am using it for the evaluation of the model. @Soonhwan-Kwon, can I help somehow?

@Soonhwan-Kwon commented:

@gpucce It would be a pleasure if you could help. I've sent you a collaborator invitation to my repo.

@rom1504 (Collaborator, Author) commented Jan 27, 2023

Up to date with the main branch now.

@rwightman (Collaborator) commented:

Looking at this now, running into some sort of None / dtype issue in MHA in JIT, not quite sure what's going on just yet...

@gpucce (Contributor) commented Jan 28, 2023

> Looking at this now, running into some sort of None / dtype issue in MHA in JIT, not quite sure what's going on just yet...

Could be this pytorch/pytorch#92073 maybe?

@rwightman (Collaborator) commented Jan 28, 2023

> Looking at this now, running into some sort of None / dtype issue in MHA in JIT, not quite sure what's going on just yet...
>
> Could be this pytorch/pytorch#92073 maybe?

Yup, it's exactly that and it's been a known issue for over a year! pytorch/pytorch#71470

```python
    def __init__(
            self,
            embed_dim,
            multimodal_cfg: MultimodalCfg,
```
(Collaborator) commented on the diff:

I just noticed this: embed_dim isn't used for CoCa, as it's taken from multimodal_cfg.latent_dim ... a little bit weird to have the value in the cfg, and the arg, and then not use it... hmmm

(Collaborator) commented:

It doesn't look like the MultimodalTransformer tower uses the latent_dim itself, so should that just be determined by cfg['embed_dim'] like the other models, and multimodal_cfg['latent_dim'] removed?

@rom1504 (Collaborator, Author) commented:

Is it resolved? If not, can you create an issue for it?

@rwightman (Collaborator) commented:

For JIT, given that the only issue I see right now is a PyTorch bug (which should be easy to fix, and is CoCa-specific), I don't see any issues there... hopefully it gets fixed soon in PT.

@rwightman (Collaborator) commented:

I think we are close; I have another bit of logic to visit in TextTransformer, and that embed_dim above in CocaModel to resolve.

```python
        # for clip if it does not output_dict
        module = model.module if type(model) == DistributedDataParallel else model
        if is_clip(module) and not module.output_dict:
            model_out = postprocess_clip_output(model_out)
```
(Collaborator) commented on the diff:

That's the last thing of note on my list: I don't feel there is any need to support CLIP not outputting a dict (rough sketch of such a shim below).

  • allowing both output modes in the model was so that users in other training codebases, or in existing colab notebooks, etc., wouldn't have a full break in their code when they update open_clip
  • our training code can always operate in output_dict mode; I don't see a need to support both
  • if anyone is using the open_clip train code with their own models, they can adapt it easily enough (they'll be more advanced users)
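
For context, a tuple-to-dict shim along these lines is all the train loop needs; this is a hypothetical sketch, not necessarily the actual postprocess_clip_output in the PR:

```python
def postprocess_clip_output(model_out):
    # Sketch only: adapt a legacy (image_features, text_features, logit_scale)
    # tuple to the dict format the training loop expects.
    if isinstance(model_out, dict):
        return model_out
    image_features, text_features, logit_scale = model_out
    return {
        "image_features": image_features,
        "text_features": text_features,
        "logit_scale": logit_scale,
    }
```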

@rom1504 (Collaborator, Author) commented:

Makes sense, can you create an issue for it?

iejMac and others added 2 commits January 29, 2023 01:38
* add B/32 pretrained

* fix

* no capital

* slash
@rom1504 rom1504 merged commit 76c8f85 into main Jan 29, 2023
@rom1504 rom1504 deleted the coca branch January 29, 2023 00:41
@rom1504 (Collaborator, Author) commented Jan 29, 2023

Thanks for all the work on coding this @gpucce!

Thanks @iejMac for running the experiments!

And thanks @rwightman for the reviews!

Also thanks @lucidrains for the initial implementation at https://github.com/lucidrains/CoCa-pytorch

@rom1504 (Collaborator, Author) commented Jan 29, 2023

Please create dedicated issues for follow-ups.

@rom1504 mentioned this pull request Jan 29, 2023
@rom1504 (Collaborator, Author) commented Jan 29, 2023

Currently known follow-ups are listed in #390.
