Skip to content

Commit

Permalink
Add Flax image captioning example (#14864)
Browse files Browse the repository at this point in the history
* add image captioning example

* update README

* fix style & quality

* simplify

* apply review suggestions

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply suggestions from code review

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Apply review suggestions

* add comments about using np instead jax array

* remove unused lines

* add model creation script

* only support from_pretrained

* fix style

* fix

* not use cache_dir when creating model

* fix tokenizer creation

* update README

* fix quality

* apply suggestion

* simplify some blocks

* Update examples/flax/image-captioning/README.md


* Update examples/flax/image-captioning/run_image_captioning_flax.py

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* apply suggestion

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
  • Loading branch information
3 people committed Jan 6, 2022
1 parent 2e9af29 commit 9f89fa0
Show file tree
Hide file tree
Showing 3 changed files with 1,388 additions and 0 deletions.
68 changes: 68 additions & 0 deletions examples/flax/image-captioning/README.md
@@ -0,0 +1,68 @@
# Image Captioning (vision-encoder-text-decoder model) training example

The following example showcases how to finetune a vision-encoder-text-decoder model for image captioning
using the JAX/Flax backend, leveraging 馃 Transformers library's [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel).

JAX/Flax allows you to trace pure functions and compile them into efficient, fused accelerator code on both GPU and TPU.
Models written in JAX/Flax are **immutable** and updated in a purely functional
way which enables simple and efficient model parallelism.

`run_image_captioning_flax.py` is a lightweight example of how to download and preprocess a dataset from the 馃 Datasets
library or use your own files (jsonlines or csv), then fine-tune one of the architectures above on it.

For custom datasets in `jsonlines` format please see: https://huggingface.co/docs/datasets/loading_datasets.html#json-files and you also will find examples of these below.

### Download COCO dataset (2017)
This example uses COCO dataset (2017) through a custom dataset script, which requires users to manually download the
COCO dataset before training.

```bash
mkdir data
cd data
wget http://images.cocodataset.org/zips/train2017.zip
wget http://images.cocodataset.org/zips/val2017.zip
wget http://images.cocodataset.org/zips/test2017.zip
wget http://images.cocodataset.org/annotations/annotations_trainval2017.zip
wget http://images.cocodataset.org/annotations/image_info_test2017.zip
cd ..
```

### Create a model from a vision encoder model and a text decoder model
Next, we create a [FlaxVisionEncoderDecoderModel](https://huggingface.co/docs/transformers/model_doc/visionencoderdecoder#transformers.FlaxVisionEncoderDecoderModel) instance from a pre-trained vision encoder ([ViT](https://huggingface.co/docs/transformers/model_doc/vit#transformers.FlaxViTModel)) and a pre-trained text decoder ([GPT2](https://huggingface.co/docs/transformers/model_doc/gpt2#transformers.FlaxGPT2Model)):

```bash
python3 create_model_from_encoder_decoder_models.py \
--output_dir model \
--encoder_model_name_or_path google/vit-base-patch16-224-in21k \
--decoder_model_name_or_path gpt2
```

### Train the model
Finally, we can run the example script to train the model:

```bash
python3 run_image_captioning_flax.py \
--output_dir ./image-captioning-training-results \
--model_name_or_path model \
--dataset_name ydshieh/coco_dataset_script \
--dataset_config_name=2017 \
--data_dir $PWD/data \
--image_column image_path \
--caption_column caption \
--do_train --do_eval --predict_with_generate \
--num_train_epochs 1 \
--eval_steps 500 \
--learning_rate 3e-5 --warmup_steps 0 \
--per_device_train_batch_size 32 \
--per_device_eval_batch_size 32 \
--overwrite_output_dir \
--max_target_length 32 \
--num_beams 8 \
--preprocessing_num_workers 16 \
--logging_steps 10 \
--block_size 16384 \
--push_to_hub
```

This should finish in about 1h30 on Cloud TPU, with validation loss and ROUGE2 score of 2.0153 and 14.64 respectively
after 1 epoch. Training statistics can be accessed on [Models](https://huggingface.co/ydshieh/image-captioning-training-results/tensorboard).
@@ -0,0 +1,118 @@
#!/usr/bin/env python
# coding=utf-8
# Copyright 2022 The HuggingFace Team All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Create a VisionEncoderDecoderModel instance from pretrained encoder/decoder models.
The cross-attention will be randomly initialized.
"""

from dataclasses import dataclass, field
from typing import Optional

from transformers import (
AutoConfig,
AutoFeatureExtractor,
AutoTokenizer,
FlaxVisionEncoderDecoderModel,
HfArgumentParser,
)


@dataclass
class ModelArguments:
"""
Arguments pertaining to which model/config/tokenizer we are going to fine-tune, or train from scratch.
"""

output_dir: str = field(
metadata={"help": "The output directory where the model will be written."},
)
encoder_model_name_or_path: str = field(
metadata={
"help": "The encoder model checkpoint for weights initialization."
"Don't set if you want to train an encoder model from scratch."
},
)
decoder_model_name_or_path: str = field(
metadata={
"help": "The decoder model checkpoint for weights initialization."
"Don't set if you want to train a decoder model from scratch."
},
)
encoder_config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained encoder config name or path if not the same as encoder_model_name"}
)
decoder_config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained decoder config name or path if not the same as decoder_model_name"}
)


def main():
parser = HfArgumentParser((ModelArguments,))
(model_args,) = parser.parse_args_into_dataclasses()

# Load pretrained model and tokenizer

# Use explicit specified encoder config
if model_args.encoder_config_name:
encoder_config = AutoConfig.from_pretrained(model_args.encoder_config_name)
# Use pretrained encoder model's config
else:
encoder_config = AutoConfig.from_pretrained(model_args.encoder_model_name_or_path)

# Use explicit specified decoder config
if model_args.decoder_config_name:
decoder_config = AutoConfig.from_pretrained(model_args.decoder_config_name)
# Use pretrained decoder model's config
else:
decoder_config = AutoConfig.from_pretrained(model_args.decoder_model_name_or_path)

# necessary for `from_encoder_decoder_pretrained` when `decoder_config` is passed
decoder_config.is_decoder = True
decoder_config.add_cross_attention = True

model = FlaxVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
encoder_pretrained_model_name_or_path=model_args.encoder_model_name_or_path,
decoder_pretrained_model_name_or_path=model_args.decoder_model_name_or_path,
encoder_config=encoder_config,
decoder_config=decoder_config,
)

# GPT2 only has bos/eos tokens but not decoder_start/pad tokens
decoder_start_token_id = decoder_config.decoder_start_token_id
pad_token_id = decoder_config.pad_token_id
if decoder_start_token_id is None:
decoder_start_token_id = decoder_config.bos_token_id
if pad_token_id is None:
pad_token_id = decoder_config.eos_token_id

# This is necessary to make Flax's generate() work
model.config.eos_token_id = decoder_config.eos_token_id
model.config.decoder_start_token_id = decoder_start_token_id
model.config.pad_token_id = pad_token_id

feature_extractor = AutoFeatureExtractor.from_pretrained(model_args.encoder_model_name_or_path)

tokenizer = AutoTokenizer.from_pretrained(model_args.decoder_model_name_or_path)
tokenizer.pad_token = tokenizer.convert_ids_to_tokens(model.config.pad_token_id)

model.save_pretrained(model_args.output_dir)
feature_extractor.save_pretrained(model_args.output_dir)
tokenizer.save_pretrained(model_args.output_dir)


if __name__ == "__main__":
main()

0 comments on commit 9f89fa0

Please sign in to comment.