[NeuralChat] Enabled image2text finetuning and added an example (#1372)
* Enabled image2text finetuning and added an example.

Signed-off-by: Ye, Xinyu <xinyu.ye@intel.com>
XinyuYe-Intel committed Mar 14, 2024
1 parent 7539c35 commit ef94aea
Showing 6 changed files with 225 additions and 5 deletions.
21 changes: 20 additions & 1 deletion intel_extension_for_transformers/llm/finetuning/data_utils.py
@@ -19,6 +19,7 @@
import datasets
import re
from itertools import chain
from transformers import AutoProcessor
from intel_extension_for_transformers.neural_chat.prompts.prompt import PromptTemplate

IGNORE_INDEX = -100
@@ -398,7 +399,21 @@ def preprocess_function(examples):
return preprocess_function


def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args):
class ImageCaptioningDataPreprocess:
def tokenize_func(self, tokenizer, data_args, finetune_args, model_args):
processor = AutoProcessor.from_pretrained(model_args.model_name_or_path)
def preprocess_function(examples):
encodings = processor(
images=examples[data_args.image_column], text=examples[data_args.caption_column],
padding=True, return_tensors="pt", truncation=True
)
encodings.update({"labels": encodings["input_ids"]})
return encodings

return preprocess_function


def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args, model_args=None):

if data_args.dataset_name == "Intel/orca_dpo_pairs":
preprocess = IntelDpoDataPreprocess(
@@ -417,6 +432,10 @@ def preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args):

preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args)

elif finetune_args.task == "image2text":
preprocess = ImageCaptioningDataPreprocess()
preprocess_fn = preprocess.tokenize_func(tokenizer, data_args, finetune_args, model_args)

elif finetune_args.task == "chat":
preprocess = ChatDataPreprocess(tokenizer.eos_token)
new_datasets = datasets.DatasetDict()
31 changes: 30 additions & 1 deletion intel_extension_for_transformers/llm/finetuning/finetuning.py
@@ -161,6 +161,33 @@ def load_dataset(self, data_args, model_args, training_args):
token=model_args.token,
streaming=data_args.streaming,
)
elif data_args.train_dir is not None:
data_files = {}
if data_args.train_dir is not None:
data_files["train"] = os.path.join(data_args.train_dir, "**")
if data_args.validation_dir is not None:
data_files["validation"] = os.path.join(data_args.validation_dir, "**")
raw_datasets = load_dataset(
"imagefolder",
data_files=data_files,
cache_dir=model_args.cache_dir,
)

            # If there is no validation data, validation_split_percentage will be used to split the dataset.
if "validation" not in raw_datasets.keys() and training_args.do_eval and \
data_args.validation_split_percentage > 0:
raw_datasets["validation"] = load_dataset(
"imagefolder",
data_files=data_files,
split=f"train[:{data_args.validation_split_percentage}%]",
cache_dir=model_args.cache_dir,
)
raw_datasets["train"] = load_dataset(
"imagefolder",
data_files=data_files,
split=f"train[{data_args.validation_split_percentage}%:]",
cache_dir=model_args.cache_dir,
)
else:
data_files = {}
dataset_args = {}
@@ -437,7 +464,9 @@ def finetune_clm(self, model_args, data_args, training_args, finetune_args, conf
if tokenizer.pad_token_id is None:
tokenizer.pad_token_id = tokenizer.eos_token_id

raw_datasets, preprocess_function = preprocess_dataset(raw_datasets, tokenizer, data_args, finetune_args)
raw_datasets, preprocess_function = preprocess_dataset(
raw_datasets, tokenizer, data_args, finetune_args, model_args
)
column_names = list(raw_datasets["train"].features)

with training_args.main_process_first(desc="dataset map pre-processing"):
17 changes: 14 additions & 3 deletions intel_extension_for_transformers/neural_chat/config.py
@@ -144,6 +144,16 @@ class DataArguments:
"help": "An optional input evaluation data file to evaluate the perplexity on (a text file)."
},
)
train_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the training data."})
validation_dir: Optional[str] = field(default=None, metadata={"help": "A folder containing the validation data."})
image_column: Optional[str] = field(
default="image",
metadata={"help": "The column of the dataset containing an image or a list of images."}
)
caption_column: Optional[str] = field(
default="text",
metadata={"help": "The column of the dataset containing a caption or a list of captions."}
)
max_seq_length: Optional[int] = field(
default=512,
metadata={
@@ -248,8 +258,9 @@ def __post_init__(self):
if self.streaming:
require_version("datasets>=2.0.0", "The streaming feature requires `datasets>=2.0.0`")

if self.dataset_name is None and self.train_file is None and self.validation_file is None:
raise ValueError("Need either a dataset name or a training/validation file.")
if self.dataset_name is None and self.train_file is None and self.validation_file is None and \
self.train_dir is None:
raise ValueError("Need either a dataset name, a training/validation file or a train_dir.")
else:
if self.train_file is not None:
extension = self.train_file.split(".")[-1]
@@ -326,7 +337,7 @@ class FinetuningArguments:
task: Optional[str] = field(
default="completion",
metadata={"help": "task name, different task means different data format.",
"choices": ["completion", "chat", "summarization", "code-generation"]
"choices": ["completion", "chat", "summarization", "code-generation", "image2text"]
},
)
eval_ppl: bool = field(
@@ -0,0 +1,78 @@
NeuralChat Fine-tuning
============

This example demonstrates how to finetune a pretrained generative image-to-text model on a customized dataset.

# Prerequisite

## 1. Environment
### Bare Metal
Python 3.9 or a higher version is recommended.
```shell
pip install -r requirements.txt
pip install transformers==4.34.1
# Using ccl as the distributed backend for distributed training on CPU requires installing the package below.
python -m pip install oneccl_bind_pt==2.2.0 -f https://developer.intel.com/ipex-whl-stable-cpu
```
>**Note**: It is suggested to use a transformers version no higher than 4.34.1.
### Docker
Pick one of the options below to set up the docker environment.
#### Option 1 : Build Docker image from scratch
Please refer to this section: [How to build docker images for NeuralChat FineTuning](../../../docker/finetuning/README.md#21-build-docker-image) to build the docker image from scratch.

#### Option 2: Pull existing Docker image
Please follow the section [itrex docker setup](../../../docker/finetuning/README.md#22-docker-pull-from-docker-hub) and use the docker pull command to pull the itrex docker image.


Once you have the docker image ready, please follow the [run docker image](../../../docker/finetuning/README.md#3-create-docker-container) section to launch a docker instance from the image.


## 2. Prepare the Model

#### microsoft/git-base
To acquire the checkpoints and tokenizer, the user can download them from [microsoft/git-base](https://huggingface.co/microsoft/git-base).
Users can run the commands below to clone the checkpoints from the Hugging Face repository.
```bash
git lfs install
git clone https://huggingface.co/microsoft/git-base
```
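
After cloning, a quick way to confirm that the checkpoint and processor load correctly is a short Python check. This is a minimal sketch; the `./git-base` path simply refers to the folder created by the clone above, and `microsoft/git-base` can be used directly instead.
```python
from transformers import AutoProcessor, AutoModelForCausalLM

# Load the processor (image processor + tokenizer) and the base GIT model
# from the locally cloned folder; "microsoft/git-base" also works directly.
processor = AutoProcessor.from_pretrained("./git-base")
model = AutoModelForCausalLM.from_pretrained("./git-base")
print(model.config.model_type)  # expected: "git"
```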

## 3. Prepare Dataset

For datasets hosted on the Hugging Face Hub, users can use the `dataset_name` argument to pass in the needed dataset.
For local datasets, users can follow this [guide](https://huggingface.co/docs/datasets/v2.18.0/en/image_dataset#image-captioning) from the datasets library's official documentation to create a metadata file that contains image and text pairs, then use `train_dir` and optionally `validation_dir` to pass in the paths to the needed dataset, as shown in the sketch below.
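
For illustration, a minimal local layout and loading check might look like the following. The `my_dataset` folder and file names are hypothetical; the finetuning script performs the equivalent `imagefolder` loading internally when `train_dir` is passed.
```python
from datasets import load_dataset

# Hypothetical layout:
#   my_dataset/train/metadata.jsonl  -> one JSON object per line,
#                                       e.g. {"file_name": "0001.jpg", "text": "a caption ..."}
#   my_dataset/train/0001.jpg, 0002.jpg, ...
raw_datasets = load_dataset("imagefolder", data_files={"train": "my_dataset/train/**"})

sample = raw_datasets["train"][0]
print(sample["image"], sample["text"])  # PIL image and its caption
```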

### Dataset related arguments
- **dataset_name**: The name of the dataset to use (via the datasets library).
- **dataset_config_name**: The configuration name of the dataset to use (via the datasets library).
- **train_dir**: A folder containing the training data.
- **validation_dir**: A folder containing the validation data.
- **image_column**: The column of the dataset containing an image or a list of images.
- **caption_column**: The column of the dataset containing a caption or a list of captions.
- **validation_split_percentage**: The percentage of the train set used as validation set in case there's no validation split.

# Finetune

Use the command line below to finetune the `microsoft/git-base` model on the `gaodrew/roco-65k-256px` dataset.

```bash
python finetune_clm.py \
--model_name_or_path "microsoft/git-base" \
--bf16 True \
--dataset_name "gaodrew/roco-65k-256px" \
--per_device_train_batch_size 8 \
--per_device_eval_batch_size 8 \
--gradient_accumulation_steps 1 \
--do_train \
--learning_rate 1e-4 \
--num_train_epochs 3 \
--logging_steps 100 \
--save_total_limit 2 \
--overwrite_output_dir \
--log_level info \
--save_strategy epoch \
--output_dir ./git-base_finetuned_model \
--task image2text \
--full_finetune \
        --bits 16
```
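
After finetuning completes, the saved checkpoint can be used for caption generation. The snippet below is an illustrative sketch only: the `example.jpg` path is hypothetical, and if the processor files are not saved alongside the finetuned weights, the processor can be loaded from the base `microsoft/git-base` as shown.
```python
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM

processor = AutoProcessor.from_pretrained("microsoft/git-base")
model = AutoModelForCausalLM.from_pretrained("./git-base_finetuned_model")

# Encode an image and generate a caption with the finetuned model.
image = Image.open("example.jpg")  # hypothetical input image
inputs = processor(images=image, return_tensors="pt")
generated_ids = model.generate(pixel_values=inputs.pixel_values, max_length=50)
print(processor.batch_decode(generated_ids, skip_special_tokens=True)[0])
```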
@@ -0,0 +1,68 @@
# !/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
from transformers import TrainingArguments, HfArgumentParser
from intel_extension_for_transformers.neural_chat.config import (
ModelArguments,
DataArguments,
FinetuningArguments,
BaseFinetuningConfig,
)
from intel_extension_for_transformers.neural_chat.chatbot import finetune_model
from intel_extension_for_transformers.utils.device_utils import is_hpu_available

def main():
# See all possible arguments in src/transformers/training_args.py
# or by passing the --help flag to this script.
# We now keep distinct sets of args, for a cleaner separation of concerns.
if not is_hpu_available:
parser = HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments, FinetuningArguments)
)
else:
from optimum.habana import GaudiTrainingArguments

parser = HfArgumentParser(
(ModelArguments, DataArguments, GaudiTrainingArguments, FinetuningArguments)
)

if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
# If we pass only one argument to the script and it's the path to a json file,
# let's parse it to get our arguments.
model_args, data_args, training_args, finetune_args = parser.parse_json_file(
json_file=os.path.abspath(sys.argv[1])
)
else:
(
model_args,
data_args,
training_args,
finetune_args,
) = parser.parse_args_into_dataclasses()

finetune_cfg = BaseFinetuningConfig(
model_args=model_args,
data_args=data_args,
training_args=training_args,
finetune_args=finetune_args,
)
finetune_model(finetune_cfg)

if __name__ == "__main__":
main()
@@ -0,0 +1,15 @@
datasets
einops
evaluate
fastapi
nltk
peft
pydub
python-multipart
rouge_score
sentencepiece
shortuuid
torch==2.2.0
transformers
uvicorn
yacs
