Add Flax image captioning example #14864
Conversation
Great work @ydshieh!
The examples look nice. I left a few comments below; in particular, the initialization logic looks quite complex. It would be nice to keep it simple so the script is easier to read.
Maybe just don't support training from scratch if it's not useful and adds a lot of code. It's important to keep the script simple so users can understand it and modify it to their needs.
Thanks for your PR!
I think this example tries to do too much in one file and should be made easier to read. Examples are not supposed to do everything for every possible model, and here there are way too many lines just for the model creation, which makes it very hard to read. Users would benefit more from simpler code that they can understand and customize, IMO.
Thank you for the comments. I will make the example much simpler by only supporting loading a pre-trained vision encoder and text model. About casting the JAX array to a NumPy array: there is a significant slowdown (at least in the image examples) when using a JAX array as indices for accessing dataset elements.
Force-pushed from 807c92f to 650fb4a.
Hi @patil-suraj @sgugger, I simplified the config/model initialization parts (only supporting loading a pretrained encoder & decoder).

For @patil-suraj, about using `jax.numpy` vs `numpy` for indexing:
- For one line, indexing takes 30 seconds (for selecting elements).
- For another line (taking 256 elements, with image data): `jax.numpy`: 0.45 seconds / `numpy`: 0.10–0.15 seconds.

A single training step (global batch size: 256 images) takes < 0.5 seconds on TPU. Given these significant differences in processing speed, I think it is worth keeping `numpy` for indexing.
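The fast path described above can be sketched as follows. This is a minimal illustration, not the script's actual code: `data_loader` and the toy `train_dataset` are hypothetical stand-ins for the real `datasets.Dataset` split, but the trade-off (index with a plain `np.ndarray`, never a `jnp.ndarray`) is the same.

```python
import numpy as np

# Toy stand-in for a datasets.Dataset split (hypothetical; the real script
# indexes an Arrow-backed dataset, where the same trade-off applies).
train_dataset = [{"pixel_values": np.zeros((3, 2, 2)), "labels": i} for i in range(1024)]

def data_loader(dataset, batch_size, seed=0):
    """Yield shuffled batches using plain NumPy indices (the fast path)."""
    perm = np.random.default_rng(seed).permutation(len(dataset))  # np.ndarray, not jnp.ndarray
    for start in range(0, len(dataset) - batch_size + 1, batch_size):
        yield [dataset[int(i)] for i in perm[start : start + batch_size]]

batches = list(data_loader(train_dataset, batch_size=256))
```

With 1024 examples and a global batch size of 256, this yields four full batches per epoch, each built with cheap host-side indexing.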
```python
# Replicate the train state on each device
state = state.replicate()

if training_args.do_train:
```
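For context, replicating the train state amounts to stacking every leaf of the parameter pytree along a new leading device axis. The sketch below (hypothetical `params` dict and a hand-rolled `replicate` helper) mirrors what `flax.jax_utils.replicate` does under the hood; it is an illustration, not the library's implementation.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameter pytree standing in for the TrainState contents.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

def replicate(tree):
    """Stack every leaf n_devices times along a new leading axis."""
    n = jax.local_device_count()
    return jax.tree_util.tree_map(lambda x: jnp.stack([x] * n), tree)

replicated = replicate(params)
```

After replication each leaf has shape `(n_devices, *original_shape)`, ready for `jax.pmap`-ed train steps.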
I think we can reduce the lines quite significantly here by making use of something like `f"Num {dataset.keys()}"`.
Hi, the `keys()` method is available for `datasets.dataset_dict.DatasetDict` (the one with the different splits), and it gives `dict_keys(['train', 'validation', 'test'])`.

After taking the splits, like `train_dataset = dataset['train']`, it becomes `datasets.arrow_dataset.Dataset`, and there is no `keys()` method anymore. It's not clear to me how to use it at this place.
This looks already really nice. I think we should try to make the notebook a bit easier (maybe at the expense of not covering every edge case) and try to shorten some of the code a bit.
Co-authored-by: Suraj Patil <surajp815@gmail.com>
It's looking good now! I left a few more comments. It would be nice if we could just use `from_pretrained` instead of `from_encoder_decoder_pretrained`, as @patrickvonplaten suggested.
```python
def decay_mask_fn(params):
    flat_params = traverse_util.flatten_dict(params)
    layer_norm_params = [
        (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
    ]
    flat_mask = {path: (path[-1] != "bias" and path[-2:] not in layer_norm_params) for path in flat_params}
    return traverse_util.unflatten_dict(flat_mask)
```
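The mask above can be exercised without Flax installed. The sketch below re-implements the flatten step with plain dicts on a toy parameter tree (the parameter names are hypothetical, mimicking a BART-style decoder) to show which leaves receive weight decay; in practice such a mask function is typically passed as the `mask` argument of `optax.adamw`.

```python
def flatten_dict(d, prefix=()):
    """Plain-dict stand-in for flax.traverse_util.flatten_dict."""
    out = {}
    for k, v in d.items():
        if isinstance(v, dict):
            out.update(flatten_dict(v, prefix + (k,)))
        else:
            out[prefix + (k,)] = v
    return out

# Toy parameter tree (hypothetical names).
params = {
    "fc1": {"kernel": 1.0, "bias": 0.0},
    "self_attn_layer_norm": {"scale": 1.0, "bias": 0.0},
}

layer_norm_params = [
    (name, "scale") for name in ["self_attn_layer_norm", "layernorm_embedding", "final_layer_norm"]
]
# True = apply weight decay; False = exempt (biases and LayerNorm scales).
flat_mask = {
    path: (path[-1] != "bias" and path[-2:] not in layer_norm_params)
    for path in flatten_dict(params)
}
```

Only `("fc1", "kernel")` ends up marked for decay; all biases and the LayerNorm scale are excluded.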
What is the default decoder model? Is it BART or GPT-2?
I used GPT-2 for my image captioning training (there is no default value in the argument).
Co-authored-by: Suraj Patil <surajp815@gmail.com>
…m/ydshieh/transformers into add_flax_example_image_captioning
Hi, thanks for the reviews :-)
LGTM! Great work @ydshieh and thanks a lot for being patient with the review :)
@patrickvonplaten do you want to take another look?
""" | ||
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models. | ||
|
||
The cross-attention will be randomly initialized. | ||
""" |
Nice!
```python
# We use `numpy.ndarray` to interact with `datasets.Dataset`, since using `jax.numpy.array` to index into a
# dataset is significantly slower. Using a JAX array in the 1st place is only to keep JAX's PRNG generation
# mechanism, which works differently from NumPy/SciPy.
```
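A minimal sketch of that pattern: keep JAX's functional, splittable PRNG for reproducible shuffling, but cast the resulting index array to NumPy once before touching the dataset (the dataset size of 8 here is arbitrary, for illustration).

```python
import numpy as np
import jax

rng = jax.random.PRNGKey(42)
rng, input_rng = jax.random.split(rng)        # JAX's splittable PRNG mechanism
perm = jax.random.permutation(input_rng, 8)   # jnp.ndarray (device array)
perm = np.asarray(perm)                       # one host transfer; cheap NumPy indexing afterwards
```

The single `np.asarray` call replaces many per-element device-to-host transfers that indexing with the raw `jnp.ndarray` would trigger.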
Nice comment!
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
```python
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id

feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)
```
(nit) Is there no `...Processor` class? It would be nice to save with a processor class so that it can be loaded with `AutoProcessor`. But I'm fine with doing it in a follow-up PR.
We could probably add a generic processor class that can accept any tokenizer and feature extractor; we have one such processor for the VisionTextDualEncoder model: https://github.com/huggingface/transformers/blob/master/src/transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py
Not very familiar with the Processor class, but it seems to me that for composite models, e.g. (Vision or Text) Encoder-Decoder models (not standalone ones like BART or Marian), the encoder & decoder's feature extractors and/or tokenizers are not packed into a single class at this moment.
We can discuss whether it is good to create something like `VisionEncoderDecoderProcessor`, `EncoderDecoderProcessor`, etc.
Good for merge from my side. We could combine the tokenizer and feature extractor into a processor class, but I'm happy to do that in a follow-up PR.
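As a follow-up sketch, a generic processor in the spirit of `VisionTextDualEncoderProcessor` could simply delegate to whichever feature extractor and tokenizer it wraps. The class name `VisionTextProcessor` and the toy callables below are hypothetical, for illustration only.

```python
class VisionTextProcessor:
    """Hypothetical generic processor: wraps any feature extractor (images)
    and tokenizer (text) behind a single callable, mirroring the delegation
    pattern of VisionTextDualEncoderProcessor."""

    def __init__(self, feature_extractor, tokenizer):
        self.feature_extractor = feature_extractor
        self.tokenizer = tokenizer

    def __call__(self, images=None, text=None, **kwargs):
        if images is None and text is None:
            raise ValueError("You need to specify either `images` or `text`.")
        encoding = {}
        if images is not None:
            encoding.update(self.feature_extractor(images, **kwargs))
        if text is not None:
            encoding.update(self.tokenizer(text, **kwargs))
        return encoding

# Toy stand-ins for a real feature extractor / tokenizer.
fake_feature_extractor = lambda images, **kw: {"pixel_values": [[0.0] * 4 for _ in images]}
fake_tokenizer = lambda text, **kw: {"input_ids": [[1, 2, 3] for _ in text]}

processor = VisionTextProcessor(fake_feature_extractor, fake_tokenizer)
batch = processor(images=["img0", "img1"], text=["a cat", "a dog"])
```

The merged `batch` dict carries both the image features and the token ids, which is the single-object interface `AutoProcessor` users would expect.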
* add image captioning example
* update README
* fix style & quality
* simplify
* apply review suggestions
* Apply suggestions from code review (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Apply suggestions from code review (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* Apply review suggestions
* add comments about using np instead of jax array
* remove unused lines
* add model creation script
* only support from_pretrained
* fix style
* fix
* not use cache_dir when creating model
* fix tokenizer creation
* update README
* fix quality
* apply suggestion
* simplify some blocks
* Update examples/flax/image-captioning/README.md
* Update examples/flax/image-captioning/run_image_captioning_flax.py (Co-authored-by: Suraj Patil <surajp815@gmail.com>)
* apply suggestion

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
What does this PR do?
Add `run_image_captioning_flax.py` (modified from `run_summarization_flax.py`).

Who can review?
Examples: @patil-suraj + cc @patrickvonplaten @NielsRogge @sgugger for info