rewrite finetuning data preprocessing with static shape for Gaudi2. (#1212)

* rewrite finetuning data preprocessing with static shape for Gaudi2.

Co-authored-by: VincyZhang <wenxin.zhang@intel.com>
lkk12014402 and VincyZhang committed Jan 31, 2024
1 parent 5fd9566 commit 3f62ceb
Showing 12 changed files with 432 additions and 67 deletions.
@@ -70,5 +70,6 @@ New options to note:
- `--image_aspect_ratio pad`: this pads the non-square images to square, instead of cropping them; it slightly reduces hallucination.
- `--group_by_modality_length True`: this should only be used when your instruction tuning dataset contains both language (e.g. ShareGPT) and multimodal (e.g. LLaVA-Instruct). It makes the training sampler only sample a single modality (either image or language) during training, which we observe to speed up training by ~25%, and does not affect the final outcome.
- `--use_habana, --use_lazy_mode`: use these for the Intel Gaudi2 setting.
+- For the finetuning stage on Intel Gaudi2, `--pad_max True` should be set; it pads the input sequence length (text + image patches) to `--model_max_length` (see the sketch below).

**Note:** If `--use_habana, --use_lazy_mode` are not set, the code can also run on GPUs.
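To make the shape math behind `--pad_max True` concrete, here is a rough back-of-the-envelope sketch (not part of this commit). It assumes the `openai/clip-vit-large-patch14-336` vision tower used by the scripts below, where each image expands to (336/14)² = 576 patch positions, and an illustrative 120-token conversation:

```python
# Back-of-the-envelope check of the static shape enabled by --pad_max True.
# Assumed values (illustrative only): ViT-L/14 at 336px -> 576 image patches,
# --model_max_length 2048, a 120-token conversation.
model_max_length = 2048
num_image_patches = (336 // 14) ** 2             # 576 positions per image
num_text_tokens = 120                            # example text length

real_len = num_text_tokens + num_image_patches   # 696 occupied positions
pad_len = model_max_length - real_len            # 1352 trailing pad positions
assert real_len + pad_len == model_max_length    # one static shape for every sample
```

With every sample brought to one fixed length, the HPU graph can be compiled once for a single shape rather than recompiled per sequence length, which is presumably the point of the static-shape preprocessing on Gaudi2.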
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2023 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
A simple launcher script for distributed training on HPUs.
Single node:
::
    >>> python gaudi_spawn.py --world_size=NUM_CARDS_YOU_HAVE --use_mpi
        YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
        arguments of your training script)
Multi node:
::
    >>> python gaudi_spawn.py --hostfile=PATH_TO_HOSTFILE --use_deepspeed
        YOUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other
        arguments of your training script)
"""


import sys
from argparse import REMAINDER, ArgumentParser

from optimum.habana.distributed import DistributedRunner
from optimum.utils import logging


logger = logging.get_logger(__name__)


def parse_args():
    """
    Helper function parsing the command line options.
    @retval ArgumentParser
    """
    parser = ArgumentParser(
        description=(
            "Habana Gaudi distributed training launch helper utility that will spawn up multiple distributed"
            " processes."
        )
    )

    # Optional arguments for the launch helper
    parser.add_argument("--world_size", type=int, default=1, help="Number of HPUs to use (1 or 8)")
    parser.add_argument("--hostfile", type=str, default=None, help="Path to the file where hosts are specified.")
    parser.add_argument("--use_mpi", action="store_true", help="Use MPI for distributed training")
    parser.add_argument("--use_deepspeed", action="store_true", help="Use DeepSpeed for distributed training")

    # positional
    parser.add_argument(
        "training_script",
        type=str,
        help=(
            "The full path to the single HPU training "
            "program/script to be launched in parallel, "
            "followed by all the arguments for the "
            "training script."
        ),
    )

    # rest from the training program
    parser.add_argument("training_script_args", nargs=REMAINDER)

    return parser.parse_args()


def main():
    args = parse_args()

    if args.use_deepspeed:
        from transformers.deepspeed import is_deepspeed_available

        if not is_deepspeed_available():
            raise ImportError(
                "--use_deepspeed requires deepspeed: `pip install"
                " git+https://github.com/HabanaAI/DeepSpeed.git@1.10.0`."
            )

    # Patch sys.argv
    sys.argv = [args.training_script] + args.training_script_args
    # Handle the case where arguments contain whitespaces
    argv = ['"{}"'.format(arg) if " " in arg and arg[0] != '"' and arg[-1] != '"' else arg for arg in sys.argv]
    command_list = [" ".join(argv)]

    distributed_runner = DistributedRunner(
        command_list=command_list,
        world_size=args.world_size,
        hostfile=args.hostfile,
        use_mpi=args.use_mpi,
        use_deepspeed=args.use_deepspeed,
    )

    ret_code = distributed_runner.run()
    sys.exit(ret_code)


if __name__ == "__main__":
    main()

@@ -376,7 +376,7 @@ def preprocess_plain(
        tokenized_len = len(tokenizer_image_token(source[0]['value'], tokenizer))
        target[:tokenized_len] = IGNORE_INDEX

-    return dict(input_ids=input_ids, labels=targets)
+    return dict(input_ids=torch.stack(input_ids, dim=0), labels=torch.stack(targets, dim=0))


def preprocess(
@@ -496,9 +496,123 @@ def expand2square(pil_img, background_color):
            self.tokenizer,
            self.conversation_template,
            has_image=('image' in self.list_data_dict[i]))

+        data_dict["attention_mask"] = data_dict["input_ids"].ne(self.tokenizer.pad_token_id)

        if isinstance(i, int):
            data_dict = dict(input_ids=data_dict["input_ids"][0],
-                             labels=data_dict["labels"][0])
+                             labels=data_dict["labels"][0],
+                             attention_mask=data_dict["attention_mask"][0])

        # image exists in the data
        if 'image' in self.list_data_dict[i]:
            data_dict['images'] = image
        elif self.data_args.is_multimodal:
            # image does not exist in the data, but the model is multimodal
            crop_size = self.data_args.image_processor.crop_size
            data_dict['images'] = torch.zeros(3, crop_size['height'], crop_size['width'])
        return data_dict


+class LazySupervisedDatasetPadding(LazySupervisedDataset):
+    def __getitem__(self, i) -> Dict[str, torch.Tensor]:
+        sources = self.list_data_dict[i]
+        if isinstance(i, int):
+            sources = [sources]
+        assert len(sources) == 1, "Don't know why it is wrapped to a list"  # FIXME
+        if 'image' in sources[0]:
+            image_file = self.list_data_dict[i]['image']
+            image_folder = self.data_args.image_folder
+            processor = self.data_args.image_processor
+            image = Image.open(os.path.join(image_folder, image_file)).convert('RGB')
+            if self.data_args.image_aspect_ratio == 'pad':
+                def expand2square(pil_img, background_color):
+                    width, height = pil_img.size
+                    if width == height:
+                        return pil_img
+                    elif width > height:
+                        result = Image.new(pil_img.mode, (width, width), background_color)
+                        result.paste(pil_img, (0, (width - height) // 2))
+                        return result
+                    else:
+                        result = Image.new(pil_img.mode, (height, height), background_color)
+                        result.paste(pil_img, ((height - width) // 2, 0))
+                        return result
+                image = expand2square(image, tuple(int(x * 255) for x in processor.image_mean))
+                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            else:
+                image = processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
+            sources = preprocess_multimodal(
+                copy.deepcopy([e["conversations"] for e in sources]),
+                self.data_args, self.conversation_template)
+        else:
+            sources = copy.deepcopy([e["conversations"] for e in sources])
+        data_dict = preprocess(
+            sources,
+            self.tokenizer,
+            self.conversation_template,
+            has_image=('image' in self.list_data_dict[i]))
+
+        input_ids = data_dict["input_ids"].tolist()
+        labels = data_dict["labels"].tolist()
+
+        # fill image placeholder & pre-padding
+        padded_input_ids = []
+        padded_targets = []
+        images_mask = []
+        attention_mask = []
+        for inp, tar in zip(input_ids, labels):
+            new_inp = []
+            new_tar = []
+            image_mask = []
+            for ele_inp, ele_tar in zip(inp, tar):
+                if ele_inp == IMAGE_TOKEN_INDEX:
+                    # expand the single image placeholder to mm_im_patchs positions
+                    new_inp.extend([IMAGE_TOKEN_INDEX] * self.data_args.mm_im_patchs)
+                    new_tar.extend([IGNORE_INDEX] * self.data_args.mm_im_patchs)
+                    image_mask.extend([1] * self.data_args.mm_im_patchs)
+                else:
+                    new_inp.append(ele_inp)
+                    new_tar.append(ele_tar)
+                    image_mask.append(0)
+
+            attn_mask = [1] * len(new_inp)
+
+            if len(new_inp) >= self.tokenizer.model_max_length:
+                new_inp = new_inp[:self.tokenizer.model_max_length]
+                new_tar = new_tar[:self.tokenizer.model_max_length]
+                image_mask = image_mask[:self.tokenizer.model_max_length]
+                attn_mask = attn_mask[:self.tokenizer.model_max_length]
+            else:
+                # padding
+                inp_len = len(new_inp)
+                pad_len = self.tokenizer.model_max_length - inp_len
+                new_inp = new_inp + [self.tokenizer.pad_token_id] * pad_len
+                new_tar = new_tar + [IGNORE_INDEX] * pad_len
+                image_mask = image_mask + [0] * pad_len
+                attn_mask = attn_mask + [0] * pad_len
+
+            assert len(new_inp) == len(new_tar) == self.tokenizer.model_max_length
+
+            image_mask = torch.tensor(image_mask)

+            new_inp = torch.tensor(new_inp)[torch.where(image_mask != 1)].tolist()
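+            # ^ drops the image-placeholder positions from input_ids, so the
+            # text/pad tokens keep a static length; the vision features are
+            # presumably spliced back in at the positions flagged by images_mask.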

+            padded_input_ids.append(new_inp)
+            padded_targets.append(new_tar)
+            images_mask.append(image_mask)
+            attention_mask.append(attn_mask)
+
+        data_dict.update({"input_ids": torch.tensor(padded_input_ids),
+                          "labels": torch.tensor(padded_targets),
+                          "images_mask": images_mask,
+                          "attention_mask": torch.tensor(attention_mask)})
+
+        if isinstance(i, int):
+            data_dict = dict(input_ids=data_dict["input_ids"][0],
+                             labels=data_dict["labels"][0],
+                             images_mask=data_dict["images_mask"][0],
+                             attention_mask=data_dict["attention_mask"][0])
+
+        # image exists in the data
+        if 'image' in self.list_data_dict[i]:
@@ -517,21 +631,27 @@ class DataCollatorForSupervisedDataset(object):
    tokenizer: transformers.PreTrainedTokenizer

    def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
-        input_ids, labels = tuple([instance[key] for instance in instances]
-                                  for key in ("input_ids", "labels"))
+        input_ids, labels, attention_mask = tuple([instance[key] for instance in instances]
+                                                  for key in ("input_ids", "labels", "attention_mask"))
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids,
            batch_first=True,
            padding_value=self.tokenizer.pad_token_id)
        labels = torch.nn.utils.rnn.pad_sequence(labels,
                                                 batch_first=True,
                                                 padding_value=IGNORE_INDEX)
+        attention_mask = torch.nn.utils.rnn.pad_sequence(
+            attention_mask,
+            batch_first=True,
+            padding_value=0)

        input_ids = input_ids[:, :self.tokenizer.model_max_length]
        labels = labels[:, :self.tokenizer.model_max_length]
+        attention_mask = attention_mask[:, :self.tokenizer.model_max_length]
        batch = dict(
            input_ids=input_ids,
            labels=labels,
-            attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
+            attention_mask=attention_mask,
        )

        if 'images' in instances[0]:
@@ -541,17 +661,27 @@ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
            else:
                batch['images'] = images

+        if 'images_mask' in instances[0]:
+            images_mask = [instance['images_mask'] for instance in instances]
+            batch['images_mask'] = torch.stack(images_mask)

        return batch


def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
                                data_args) -> Dict:
    """Make dataset and collator for supervised fine-tuning."""
    conversation_template = conversation_utils.conv_templates[data_args.template]
-    train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
-                                          data_path=data_args.data_path,
-                                          data_args=data_args,
-                                          conversation_template=conversation_template)
+    if not data_args.pad_max:
+        train_dataset = LazySupervisedDataset(tokenizer=tokenizer,
+                                              data_path=data_args.data_path,
+                                              data_args=data_args,
+                                              conversation_template=conversation_template)
+    else:
+        train_dataset = LazySupervisedDatasetPadding(tokenizer=tokenizer,
+                                                     data_path=data_args.data_path,
+                                                     data_args=data_args,
+                                                     conversation_template=conversation_template)
    data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
    return dict(train_dataset=train_dataset,
                eval_dataset=None,
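To see what `LazySupervisedDatasetPadding` actually produces, here is a small standalone sketch (illustration only, not part of the commit) that mimics its inner loop on a toy sequence, with made-up values for `IMAGE_TOKEN_INDEX`, `IGNORE_INDEX`, `pad_token_id`, `mm_im_patchs`, and `model_max_length`:

```python
# Toy walk-through of the static-shape preprocessing above (illustration only).
import torch

IMAGE_TOKEN_INDEX, IGNORE_INDEX, PAD_ID = -200, -100, 0   # assumed values
MM_IM_PATCHS, MODEL_MAX_LENGTH = 4, 12                    # tiny, for readability

inp = [1, 5, IMAGE_TOKEN_INDEX, 9]                        # text + one image placeholder
tar = [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 9]

new_inp, new_tar, image_mask = [], [], []
for t, l in zip(inp, tar):
    if t == IMAGE_TOKEN_INDEX:                            # expand placeholder to patch slots
        new_inp += [IMAGE_TOKEN_INDEX] * MM_IM_PATCHS
        new_tar += [IGNORE_INDEX] * MM_IM_PATCHS
        image_mask += [1] * MM_IM_PATCHS
    else:
        new_inp.append(t)
        new_tar.append(l)
        image_mask.append(0)

attn_mask = [1] * len(new_inp)
pad_len = MODEL_MAX_LENGTH - len(new_inp)                 # pre-pad to the static length
new_inp += [PAD_ID] * pad_len
new_tar += [IGNORE_INDEX] * pad_len
image_mask += [0] * pad_len
attn_mask += [0] * pad_len

image_mask = torch.tensor(image_mask)
text_ids = torch.tensor(new_inp)[image_mask != 1]         # drop patch slots from input_ids

print(image_mask.tolist())  # [0, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
print(text_ids.tolist())    # [1, 5, 9, 0, 0, 0, 0, 0]
print(attn_mask)            # [1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0]
```

Note that `input_ids` ends up `mm_im_patchs` positions shorter than `labels` and `attention_mask`, since the placeholder slots are removed; presumably the model re-inserts the image features at the positions flagged by `images_mask` during the forward pass, restoring the full static length.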
@@ -15,7 +15,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-deepspeed train.py \
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 \
+python3 ./gaudi_spawn.py --use_deepspeed --world_size 4 \
+    train.py \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --template v1 \
@@ -28,7 +30,7 @@ deepspeed train.py \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
-    --group_by_modality_length True \
+    --group_by_modality_length False \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b \
    --num_train_epochs 1 \
@@ -45,6 +47,7 @@ deepspeed train.py \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --model_max_length 2048 \
+    --pad_max True \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
@@ -15,21 +15,23 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-deepspeed train.py \
+PT_HPU_MAX_COMPOUND_OP_SIZE=10 DEEPSPEED_HPU_ZERO3_SYNC_MARK_STEP_REQUIRED=1 \
+python3 ./gaudi_spawn.py --use_deepspeed --world_size 4 \
+    train.py \
    --lora_enable True --lora_r 128 --lora_alpha 256 --mm_projector_lr 2e-5 \
    --deepspeed ./scripts/zero3.json \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
    --template v1 \
-    --data_path ./playground/data/llava_v1_5_mix665k.json \
-    --image_folder ./playground/data \
+    --data_path finetuning_data/llava_v1_5_mix665k.json \
+    --image_folder ./finetuning_data/ \
    --vision_tower openai/clip-vit-large-patch14-336 \
    --pretrain_mm_mlp_adapter llava-v1.5-mistral-7b-pretrain/mm_projector.bin \
    --mm_projector_type mlp2x_gelu \
    --mm_vision_select_layer -2 \
    --mm_use_im_start_end False \
    --mm_use_im_patch_token False \
    --image_aspect_ratio pad \
-    --group_by_modality_length True \
+    --group_by_modality_length False \
    --bf16 True \
    --output_dir ./checkpoints/llava-v1.5-13b \
    --num_train_epochs 1 \
@@ -46,6 +48,7 @@ deepspeed train.py \
    --lr_scheduler_type "cosine" \
    --logging_steps 1 \
    --model_max_length 2048 \
+    --pad_max True \
    --gradient_checkpointing True \
    --dataloader_num_workers 4 \
    --lazy_preprocess True \
@@ -15,8 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-deepspeed --include localhost:0,1 \
-    --master_port 29501 \
+python3 ./gaudi_spawn.py --use_deepspeed --world_size 4 \
    train.py \
    --deepspeed scripts/zero2.json \
    --model_name_or_path mistralai/Mistral-7B-v0.1 \
