Add Flax image captioning example #14864

Merged

Conversation

ydshieh
Collaborator

@ydshieh ydshieh commented Dec 21, 2021

What does this PR do?

Add run_image_captioning_flax.py (modified from run_summarization_flax.py).

Who can review

Examples: @patil-suraj + cc @patrickvonplaten @NielsRogge @sgugger for info

@ydshieh ydshieh changed the title Add Flax image captioning example [WIP] Add Flax image captioning example Dec 21, 2021
@ydshieh ydshieh changed the title [WIP] Add Flax image captioning example Add Flax image captioning example Dec 21, 2021
Contributor

@patil-suraj patil-suraj left a comment


Great work @ydshieh !
The examples look nice. I left a few comments below; in particular, the initialization logic looks quite complex, and it would be nice if we could keep it simple so the script is easier to read.

Maybe just don't support training from scratch if it's not useful and adds a lot of code. It's important to keep the script simple so users can understand it and modify it to their needs.

examples/flax/image-captioning/README.md (5 resolved review comments)
Collaborator

@sgugger sgugger left a comment


Thanks for your PR!

I think this example tries to do too much in one file and should be made easier to read. Examples are not supposed to do everything for every possible model, and here there are way too many lines just for the model creation, which makes it super hard to read. Users would benefit more from simpler code that they can understand and customize, IMO.

@ydshieh
Collaborator Author

ydshieh commented Dec 23, 2021

@patil-suraj , @sgugger

Thank you for the comments. I will make the example much simpler by only supporting loading pretrained vision encoder and text decoder models.

About casting the jax array to an np array: there is a significant slowdown (at least in the image examples) when using a jax array as indices for accessing a datasets.Dataset. I will reproduce the timing comparison.

@ydshieh ydshieh marked this pull request as draft December 23, 2021 21:06
@ydshieh ydshieh changed the title Add Flax image captioning example [WIP] Add Flax image captioning example Dec 23, 2021
@ydshieh ydshieh force-pushed the add_flax_example_image_captioning branch from 807c92f to 650fb4a Compare December 31, 2021 12:47
@ydshieh
Collaborator Author

ydshieh commented Dec 31, 2021

Hi, @patil-suraj @sgugger

I simplified the config/model initialization parts (only support loading pretrained encoder & decoder).


For @patil-suraj

About using a numpy array instead of a jnp array when indexing into datasets:

For this line

selecting 16384 elements takes 30 seconds using jax.numpy, while using numpy takes only 0.005 seconds.

For this line (taking 256 elements, with image data)

jax.numpy: 0.45 seconds / numpy: 0.10 to 0.15 seconds

A single training step (global batch size: 256 images) takes < 0.5 seconds on TPU.

Due to these significant differences in processing speed, I think it is worth keeping numpy when dealing with datasets.
Let me know if you have a different opinion about this :-)
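For reference, the batching pattern under discussion can be sketched with plain numpy (the helper name below is illustrative, not from the PR; in the PR itself the permutation comes from JAX's PRNG and is converted to a numpy array before indexing):

```python
import numpy as np

def batch_index_generator(rng: np.random.Generator, num_samples: int, batch_size: int):
    """Yield shuffled per-batch index arrays as plain numpy arrays.

    Indexing a `datasets.Dataset` with numpy indices avoids the slow path
    that jax.numpy index arrays were observed to trigger.
    """
    perm = rng.permutation(num_samples)
    for i in range(num_samples // batch_size):  # drop the last partial batch
        yield perm[i * batch_size : (i + 1) * batch_size]

batches = list(batch_index_generator(np.random.default_rng(0), num_samples=10, batch_size=4))
```

Each yielded array can then be passed directly to `dataset[batch_indices]`-style selection.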

@ydshieh ydshieh marked this pull request as ready for review December 31, 2021 13:14
@ydshieh ydshieh changed the title [WIP] Add Flax image captioning example Add Flax image captioning example Dec 31, 2021
# Replicate the train state on each device
state = state.replicate()

if training_args.do_train:
Contributor


I think we can reduce the lines quite significantly here by making use of something like f"Num {dataset.keys()}"

Collaborator Author


Hi, keys() is available for 'datasets.dataset_dict.DatasetDict' (the one with the different splits), and it gives dict_keys(['train', 'validation', 'test']).

After taking a split, e.g. train_dataset = dataset['train'], it becomes a 'datasets.arrow_dataset.Dataset', which no longer has a keys() method.

It's not clear to me how to use it in this place.
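A sketch of the reviewer's suggestion, using a plain dict to stand in for a `datasets.DatasetDict` (split name mapped to split); the split sizes here are illustrative only:

```python
# A DatasetDict behaves like a mapping of split name -> split, so the
# per-split logging lines can be collapsed into one loop over .keys()/.items().
dataset = {"train": list(range(100)), "validation": list(range(10)), "test": list(range(5))}

summary = ", ".join(f"num {name} examples = {len(split)}" for name, split in dataset.items())
print(summary)
```

This only works before individual splits are taken out, since a single split no longer exposes keys().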

Contributor

@patrickvonplaten patrickvonplaten left a comment


This already looks really nice. I think we should try to make the notebook a bit easier (maybe at the expense of not covering every edge case) and try to shorten some of the code a bit.

ydshieh and others added 2 commits January 3, 2022 12:17
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Contributor

@patil-suraj patil-suraj left a comment


It's looking good now! I left a few more comments. It would be nice if we could just use from_pretrained instead of from_encoder_decoder_pretrained as @patrickvonplaten suggested.

examples/flax/image-captioning/README.md (resolved review comment)
Comment on lines +944 to +950
def decay_mask_fn(params):
flat_params = traverse_util.flatten_dict(params)
layer_norm_params = [
(name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
]
flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
return traverse_util.unflatten_dict(flat_mask)
Contributor


What is the default decoder model, is it bart or gpt2?

Collaborator Author


I used GPT2 for my image captioning training. (There is no default value in the argument.)
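The masking logic quoted above can be illustrated in pure Python; the minimal `flatten` helper below stands in for `flax.traverse_util.flatten_dict`, and the unflatten step is omitted for brevity. The toy parameter tree is illustrative, not from the PR:

```python
def flatten(d, prefix=()):
    """Flatten a nested dict into {path_tuple: leaf}."""
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(flatten(v, prefix + (k,)))
        else:
            out[prefix + (k,)] = v
    return out

def decay_mask_fn(params):
    # True -> parameter receives weight decay; biases and layer-norm scales do not.
    flat_params = flatten(params)
    layer_norm_params = [
        (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
    ]
    return {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}

params = {
    "dense": {"kernel": 1.0, "bias": 0.0},
    "final_layer_norm": {"scale": 1.0, "bias": 0.0},
}
mask = decay_mask_fn(params)
```

So only the dense kernel is decayed, while all biases and the listed layer-norm scales are excluded.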

@ydshieh
Collaborator Author

ydshieh commented Jan 5, 2022

Hi

  • support from_pretrained (the model creation from encoder/decoder is done in another script)
  • README updated
  • rename to coco_dataset_script to avoid confusion
  • other nits applied

Thanks for the reviews :-)

Contributor

@patil-suraj patil-suraj left a comment


LGTM! Great work @ydshieh and thanks a lot for being patient with the review :)

@patrickvonplaten do you want to take another look?

examples/flax/image-captioning/README.md (resolved review comment)
Comment on lines +16 to +20
"""
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.

The cross-attention will be randomly initialized.
"""
Contributor


Nice!

Comment on lines +303 to +305
# We use `numpy.ndarray` to interact with `datasets.Dataset`, since using `jax.numpy.array` to index into a
# dataset is significantly slow. Using JAX array at the 1st place is only to keep JAX's PRNGs generation
# mechanism, which works differently from NumPy/SciPy.
Contributor


Nice comment!

ydshieh and others added 3 commits January 5, 2022 17:24
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id

feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)
Contributor


(nit) Is there no ...Processor class? Would be nice to save with a Processor class so that it can be loaded with AutoProcessor. But I'm fine with doing it in a follow-up PR

Contributor


We could probably add a generic processor class that can accept any tokenizer and feature extractor; we have one such processor for the VisionTextDualEncoder model: https://github.com/huggingface/transformers/blob/master/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py

Collaborator Author

@ydshieh ydshieh Jan 6, 2022


I am not very familiar with the Processor class, but it seems to me that for composite models, e.g. (Vision or Text) Encoder-Decoder models (as opposed to standalone models like Bart or Marian), the encoder's and decoder's feature extractors and/or tokenizers are not packed into a single class at the moment.

We can discuss whether it would be good to create something like VisionEncoderDecoderProcessor, EncoderDecoderProcessor, etc.
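A hypothetical sketch of the generic processor idea discussed above: a thin wrapper pairing any feature extractor (for images) with any tokenizer (for text). The class name and the toy callables below are illustrative only, not an existing transformers API:

```python
class VisionTextProcessor:
    """Minimal pairing of a feature extractor and a tokenizer into one object."""

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

    def __call__(self, images=None, text=None):
        encoding = {}
        if images is not None:
            encoding["pixel_values"] = self.feature_extractor(images)
        if text is not None:
            encoding["labels"] = self.tokenizer(text)
        return encoding

# Toy stand-ins: a "feature extractor" producing fixed-size vectors and a
# whitespace "tokenizer".
processor = VisionTextProcessor(
    feature_extractor=lambda imgs: [[0.0] * 4 for _ in imgs],
    tokenizer=lambda texts: [t.split() for t in texts],
)
batch = processor(images=["img0"], text=["a cat sits"])
```

A real implementation would also delegate save_pretrained/from_pretrained to both components, as the VisionTextDualEncoder processor linked above does.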

Contributor

@patrickvonplaten patrickvonplaten left a comment


Good for me to merge. We could combine the tokenizer and feature extractor into a processor class, but I'm happy to do that in a follow-up PR.

@patil-suraj patil-suraj merged commit 9f89fa0 into huggingface:master Jan 6, 2022
Albertobegue pushed a commit to Albertobegue/transformers that referenced this pull request Jan 27, 2022
* add image captioning example

* update README

* fix style & quality

* simplify

* apply review suggestions

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply review suggestions

* add comments about using np instead jax array

* remove unused lines

* add model creation script

* only support from_pretrained

* fix style

* fix

* not use cache_dir when creating model

* fix tokenizer creation

* update README

* fix quality

* apply suggestion

* simplify some blocks

* Update examples/flax/image-captioning/README.md


* Update examples/flax/image-captioning/run_image_captioning_flax.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* apply suggestion

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
@ydshieh ydshieh deleted the add_flax_example_image_captioning branch May 5, 2022 10:35

4 participants