Commit
Add Flax image captioning example (#14864)
* add image captioning example
* update README
* fix style & quality
* simplify
* apply review suggestions
* Apply suggestions from code review
* Apply review suggestions
* add comments about using np instead of jax array
* remove unused lines
* add model creation script
* only support from_pretrained
* fix style
* fix
* not use cache_dir when creating model
* fix tokenizer creation
* update README
* fix quality
* apply suggestion
* simplify some blocks
* Update examples/flax/image-captioning/README.md
* Update examples/flax/image-captioning/run_image_captioning_flax.py
* apply suggestion

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
1 parent 2e9af29 · commit 9f89fa0 · 3 changed files with 1,388 additions and 0 deletions.
68 changes: 68 additions & 0 deletions
examples/flax/image-captioning/README.md
# Image Captioning (vision-encoder-text-decoder model) training example

The following example showcases how to fine-tune a vision-encoder-text-decoder model for image captioning
using the JAX/Flax backend, leveraging the 🤗 Transformers library's [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel).

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way, which enables simple and efficient model parallelism.

`run_image_captioning_flax.py` is a lightweight example of how to download and preprocess a dataset from the 🤗 Datasets
library or use your own files (jsonlines or csv), then fine-tune a vision-encoder-text-decoder model on it.

For custom datasets in `jsonlines` format, please see https://huggingface.co/docs/datasets/loading_datasets.html#json-files; a sketch of such a file is shown right below.
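
As an illustration, the snippet below writes a tiny jsonlines dataset and loads it back with 🤗 Datasets. The file name and the `image_path`/`caption` field names are placeholders; the field names must match whatever you pass to `--image_column` and `--caption_column` (as in the training command further down).

```python
# Hypothetical custom dataset: one JSON object per line, pairing an image
# path with a caption. File and field names here are illustrative only.
import json

from datasets import load_dataset

examples = [
    {"image_path": "data/train/0001.jpg", "caption": "A dog catching a frisbee."},
    {"image_path": "data/train/0002.jpg", "caption": "Two people riding bicycles."},
]
with open("train.json", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")

# 🤗 Datasets reads jsonlines files via the generic "json" loader.
dataset = load_dataset("json", data_files={"train": "train.json"})
print(dataset["train"][0])
```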

### Download the COCO dataset (2017)
This example uses the COCO dataset (2017) through a custom dataset script, which requires users to manually download the
COCO dataset before training.

```bash
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
cd ..
```
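
Before launching training, it can be worth verifying that all five archives actually landed in `data/`. A small sanity-check sketch, with file names taken from the `wget` commands above:

```python
# Optional sanity check: confirm the manually downloaded COCO archives exist.
from pathlib import Path

expected = [
    "train2017.zip",
    "val2017.zip",
    "test2017.zip",
    "annotations_trainval2017.zip",
    "image_info_test2017.zip",
]
missing = [name for name in expected if not (Path("data") / name).is_file()]
print("missing archives:", missing or "none")
```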

### Create a model from a vision encoder model and a text decoder model
Next, we create a [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel) instance from a pre-trained vision encoder ([ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.FlaxViTModel)) and a pre-trained text decoder ([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.FlaxGPT2Model)):

```bash
python3 create_model_from_encoder_decoder_models.py \
    --output_dir model \
    --encoder_model_name_or_path google/vit-base-patch16-224-in21k \
    --decoder_model_name_or_path gpt2
```
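
As a quick sanity check (a sketch, assuming the script above completed and wrote its output to `model/`), the saved directory can be loaded back with `from_pretrained`:

```python
# Load the freshly assembled encoder-decoder model back from disk.
from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

model = FlaxVisionEncoderDecoderModel.from_pretrained("model")
feature_extractor = AutoFeatureExtractor.from_pretrained("model")
tokenizer = AutoTokenizer.from_pretrained("model")

# The creation script wires these up so that generate() works (see the script below).
print(model.config.decoder_start_token_id, model.config.pad_token_id)
```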

### Train the model
Finally, we can run the example script to train the model:

```bash
python3 run_image_captioning_flax.py \
    --output_dir ./image-captioning-training-results \
    --model_name_or_path model \
    --dataset_name ydshieh/coco_dataset_script \
    --dataset_config_name=2017 \
    --data_dir $PWD/data \
    --image_column image_path \
    --caption_column caption \
    --do_train --do_eval --predict_with_generate \
    --num_train_epochs 1 \
    --eval_steps 500 \
    --learning_rate 3e-5 --warmup_steps 0 \
    --per_device_train_batch_size 32 \
    --per_device_eval_batch_size 32 \
    --overwrite_output_dir \
    --max_target_length 32 \
    --num_beams 8 \
    --preprocessing_num_workers 16 \
    --logging_steps 10 \
    --block_size 16384 \
    --push_to_hub
```

This should finish in about 1 hour 30 minutes on a Cloud TPU, reaching a validation loss of 2.0153 and a ROUGE-2 score of 14.64
after 1 epoch. Training statistics can be accessed on the [Models](https://huggingface.co/ydshieh/image-captioning-training-results/tensorboard) page.
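
Once training has finished (or using the published checkpoint linked above), captions can be generated along these lines. A minimal inference sketch, where the checkpoint path and image file are placeholders:

```python
# Generate a caption for a single image with the trained checkpoint.
from PIL import Image
from transformers import AutoFeatureExtractor, AutoTokenizer, FlaxVisionEncoderDecoderModel

checkpoint = "./image-captioning-training-results"  # placeholder path
model = FlaxVisionEncoderDecoderModel.from_pretrained(checkpoint)
feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

image = Image.open("some_image.jpg").convert("RGB")  # placeholder image
# The feature extractor returns NumPy arrays, which Flax consumes directly.
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# Beam search settings mirror --max_target_length 32 and --num_beams 8 above.
output = model.generate(pixel_values, max_length=32, num_beams=8)
caption = tokenizer.batch_decode(output.sequences, skip_special_tokens=True)[0]
print(caption)
```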
118 changes: 118 additions & 0 deletions
examples/flax/image-captioning/create_model_from_encoder_decoder_models.py
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.
The cross-attention will be randomly initialized.
"""

from dataclasses import dataclass, field
from typing import Optional

from transformers import (
    AutoConfig,
    AutoFeatureExtractor,
    AutoTokenizer,
    FlaxVisionEncoderDecoderModel,
    HfArgumentParser,
)


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
    """

    output_dir: str = field(
        metadata={"help": "The output directory where the model will be written."},
    )
    encoder_model_name_or_path: str = field(
        metadata={
            "help": "The encoder model checkpoint for weights initialization. "
            "Don't set if you want to train an encoder model from scratch."
        },
    )
    decoder_model_name_or_path: str = field(
        metadata={
            "help": "The decoder model checkpoint for weights initialization. "
            "Don't set if you want to train a decoder model from scratch."
        },
    )
    encoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
    )
    decoder_config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
    )


def main():
    parser = HfArgumentParser((ModelArguments,))
    (model_args,) = parser.parse_args_into_dataclasses()

    # Load pretrained model and tokenizer

    # Use the explicitly specified encoder config
    if model_args.encoder_config_name:
        encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name)
    # Otherwise, fall back to the pretrained encoder model's config
    else:
        encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path)

    # Use the explicitly specified decoder config
    if model_args.decoder_config_name:
        decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name)
    # Otherwise, fall back to the pretrained decoder model's config
    else:
        decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path)

    # Necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
    decoder_config.is_decoder = True
    decoder_config.add_cross_attention = True

    model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
        encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
        decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
        encoder_config=encoder_config,
        decoder_config=decoder_config,
    )

    # GPT2 only has bos/eos tokens but no dedicated decoder_start/pad tokens,
    # so fall back to bos for decoder_start and eos for pad when they are unset
    decoder_start_token_id = decoder_config.decoder_start_token_id
    pad_token_id = decoder_config.pad_token_id
    if decoder_start_token_id is None:
        decoder_start_token_id = decoder_config.bos_token_id
    if pad_token_id is None:
        pad_token_id = decoder_config.eos_token_id

    # This is necessary to make Flax's generate() work
    model.config.eos_token_id = decoder_config.eos_token_id
    model.config.decoder_start_token_id = decoder_start_token_id
    model.config.pad_token_id = pad_token_id

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)

    tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
    # Align the tokenizer's pad token with the model config (eos for GPT2)
    tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)

    model.save_pretrained(model_args.output_dir)
    feature_extractor.save_pretrained(model_args.output_dir)
    tokenizer.save_pretrained(model_args.output_dir)


if __name__ == "__main__":
    main()
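
After the script runs, the `--output_dir` directory should contain the Flax weights (`flax_model.msgpack`) together with `config.json`, the feature extractor's `preprocessor_config.json`, and the tokenizer files, which is what `run_image_captioning_flax.py` consumes via `--model_name_or_path model` in the README above.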