DPO Llava 1.5 and PaliGemma support (#1797)
* llava support dpo

* add_special_tokens=False only when possible

* format

* pali gemma

* refactor size

* remove image resize

---------

Co-authored-by: Quentin Gallouédec <quentin.gallouedec@huggingface.co>
qgallouedec and Quentin Gallouédec committed Jul 9, 2024
1 parent 30e33bd commit 2860ce5
Showing 2 changed files with 72 additions and 49 deletions.
56 changes: 29 additions & 27 deletions examples/scripts/dpo_visual.py
@@ -16,12 +16,13 @@
 accelerate launch examples/scripts/dpo_visual.py \
     --dataset_name HuggingFaceH4/rlaif-v_formatted \
     --model_name_or_path HuggingFaceM4/idefics2-8b \
-    --per_device_train_batch_size 1 \
-    --gradient_accumulation_steps 16 \
+    --per_device_train_batch_size 2 \
+    --gradient_accumulation_steps 32 \
+    --dataset_num_proc 32 \
     --output_dir dpo_idefics_rlaif-v \
     --bf16 \
     --torch_dtype bfloat16 \
     --gradient_checkpointing \
     --use_peft \
     --lora_target_modules=all-linear
 """
@@ -82,21 +83,40 @@
 
     model_kwargs = dict(
         revision=model_config.model_revision,
-        trust_remote_code=model_config.trust_remote_code,
         attn_implementation=model_config.attn_implementation,
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
         quantization_config=quantization_config,
     )
-    model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+    model = AutoModelForVision2Seq.from_pretrained(
+        model_config.model_name_or_path,
+        trust_remote_code=model_config.trust_remote_code,
+        **model_kwargs,
+    )
     peft_config = get_peft_config(model_config)
     if peft_config is None:
-        model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+        model_ref = AutoModelForVision2Seq.from_pretrained(
+            model_config.model_name_or_path,
+            trust_remote_code=model_config.trust_remote_code,
+            **model_kwargs,
+        )
     else:
         model_ref = None
-    processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
+    processor = AutoProcessor.from_pretrained(
+        model_config.model_name_or_path,
+        trust_remote_code=model_config.trust_remote_code,
+        do_image_splitting=False,
+    )
     tokenizer = processor.tokenizer
 
+    # Set up the chat template
+    if model.config.model_type == "idefics2":
+        pass  # the processor already has a valid chat template
+    elif model.config.model_type == "paligemma":
+        processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}<|im_end|>{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+    elif model.config.model_type == "llava":
+        processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+
     if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token
     if args.ignore_bias_buffers:
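
For reference, a minimal sketch (not part of this commit) of what the llava template above produces. It renders the template with jinja2 directly so it runs without downloading a model; processor.apply_chat_template does the equivalent under the hood:

from jinja2 import Template

# The llava template from the diff above, unchanged, as a plain string.
llava_template = (
    "{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}"
    "{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}"
    "{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}"
    "{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}"
    "{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"
    "{% if add_generation_prompt %}ASSISTANT: {% endif %}"
)

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in the image?"}]},
]
rendered = Template(llava_template).render(messages=messages, eos_token="</s>", add_generation_prompt=True)
print(rendered)  # USER: <image>What is in the image? ASSISTANT: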
@@ -124,27 +144,9 @@
             ds[key] = ds[key].select(range(50))
 
     def process(row):
-        # The prompt can be either a string or a list. In some datasets, the prompt is just a common string
-        # for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
-        # separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
-        # and in such cases, it is properly formatted as a list with keys "role" and "content".
-        # Example 1:
-        # row = {"prompt": "What does detox mean?",
-        #        "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
-        #        "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "user"}]}
-        # Example 2:
-        # row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
-        #        "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
-        #        "rejected": [{"content": "I don't know.", "role": "user"}]}
-        if "prompt" in row and isinstance(row["prompt"], list):
-            row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
-
+        row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
         row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
         row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
-
-        if "images" in row:
-            for img in row["images"]:  # Resize each image so the largest side is 640 pixels
-                img.thumbnail((640, 640))  # Resize the image to at most 640x640 pixels
         return row
 
     with PartialState().local_main_process_first():
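
For context, an illustrative sketch (values invented, not taken from the dataset) of the shape process expects and produces on an rlaif-v_formatted row, assuming the llava-style template above:

# Before: conversational fields are lists of {"role", "content"} turns.
row = {
    "prompt": [{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the photo."}]}],
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "A dog on a beach."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "A cat."}]}],
    "images": ["<PIL.Image>"],  # now passed through untouched, since the 640px resize was removed
}
# After process(row): each conversational field is a flat string, e.g.
# row["prompt"]   == "USER: <image>Describe the photo. "
# row["chosen"]   == "ASSISTANT: A dog on a beach.</s>"
# row["rejected"] == "ASSISTANT: A cat.</s>"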
@@ -168,6 +170,6 @@ def process(row):
     )
 
     trainer.train()
-    trainer.push_to_hub
 
     with save_context:
         trainer.save_model(training_args.output_dir)
65 changes: 43 additions & 22 deletions trl/trainer/dpo_trainer.py
@@ -723,9 +723,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
         if self.is_vision_model:
             if answer.count("<image>") > 0:
                 raise NotImplementedError("Answer contains <image> token, which is not supported yet.")
-            full_tokenized = self.processor(prompt + answer, images=images, add_special_tokens=False)
+            if "add_special_tokens" in inspect.signature(self.processor).parameters:
+                processor_kwargs = {"add_special_tokens": False}
+            else:
+                processor_kwargs = {}
+            full_tokenized = self.processor(prompt + answer, images=images, **processor_kwargs)
             full_tokenized = {k: v[0] for k, v in full_tokenized.items()}  # Unbatch, not done when using idefics
-            prompt_input_ids = self.processor(prompt, images=images, add_special_tokens=False)["input_ids"][0]
+            if not isinstance(full_tokenized["input_ids"], list):  # llava processor returns tensors
+                full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
+                full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
+            prompt_input_ids = self.processor(prompt, images=images, **processor_kwargs)["input_ids"][0]
+            if not isinstance(prompt_input_ids, list):  # llava processor returns tensors
+                prompt_input_ids = prompt_input_ids.tolist()
         else:
             full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
             prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
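
The inspect.signature guard is what the "add_special_tokens=False only when possible" commit bullet refers to: the kwarg is forwarded only if the processor's call signature accepts it, since some VLM processors do not. A standalone sketch of the pattern, with a hypothetical stand-in processor:

import inspect

def fake_processor(text, images=None, add_special_tokens=True):
    """Hypothetical stand-in for a processor __call__ signature."""
    return {"input_ids": [[1, 2, 3]], "attention_mask": [[1, 1, 1]]}

# Feature-detect the kwarg instead of hard-coding it; inspect.signature on a
# callable instance (like self.processor) inspects its __call__ the same way.
if "add_special_tokens" in inspect.signature(fake_processor).parameters:
    processor_kwargs = {"add_special_tokens": False}
else:
    processor_kwargs = {}

print(processor_kwargs)  # {'add_special_tokens': False}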
@@ -762,22 +771,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
         answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
         answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]
 
+        return_dict = dict(
+            prompt_input_ids=prompt_input_ids,
+            prompt_attention_mask=prompt_attention_mask,
+            input_ids=answer_input_ids,
+            attention_mask=answer_attention_mask,
+        )
         if "pixel_values" in full_tokenized:
-            return dict(
-                prompt_input_ids=prompt_input_ids,
-                prompt_attention_mask=prompt_attention_mask,
-                prompt_pixel_values=full_tokenized["pixel_values"],
-                prompt_pixel_attention_mask=full_tokenized["pixel_attention_mask"],
-                input_ids=answer_input_ids,
-                attention_mask=answer_attention_mask,
-            )
-        else:
-            return dict(
-                prompt_input_ids=prompt_input_ids,
-                prompt_attention_mask=prompt_attention_mask,
-                input_ids=answer_input_ids,
-                attention_mask=answer_attention_mask,
-            )
+            return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"]
+        if "pixel_attention_mask" in full_tokenized:
+            return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"]
+
+        return return_dict
 
     def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
         """Tokenize a single row from a DPO specific dataset.
@@ -805,8 +810,15 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
         if not isinstance(prompt, str):
             raise ValueError(f"prompt should be an str but got {type(prompt)}")
         if self.is_vision_model:
-            prompt_tokens = self.processor(prompt, images=images, add_special_tokens=False)
+            if "add_special_tokens" in inspect.signature(self.processor).parameters:
+                processor_kwargs = {"add_special_tokens": False}
+            else:
+                processor_kwargs = {}
+            prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
             prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}  # Unbatch, not done when using idefics
+            if not isinstance(prompt_tokens["input_ids"], list):  # llava processor returns tensors
+                prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
+                prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
         else:
             prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
 
@@ -1037,10 +1049,13 @@ def concatenated_inputs(
             )
 
         if is_vision_model:
-            concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device)
-            concatenated_batch["pixel_attention_mask"] = (
-                batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device)
+            concatenated_batch["pixel_values"] = torch.cat(
+                [batch["prompt_pixel_values"], batch["prompt_pixel_values"]], dim=0
             )
+            if "prompt_pixel_attention_mask" in batch:
+                concatenated_batch["pixel_attention_mask"] = torch.cat(
+                    [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0
+                )
         return concatenated_batch
 
     def dpo_loss(
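
The switch from .repeat(2, 1, 1, 1, 1) to torch.cat duplicates the prompt pixels along the batch dimension without assuming a fixed rank, which matters because llava and paligemma pixel tensors are 4-D while idefics2's are 5-D. A quick sketch (shapes are illustrative):

import torch

pv5 = torch.randn(2, 1, 3, 8, 8)  # idefics2-style: (batch, num_images, channels, height, width)
pv4 = torch.randn(2, 3, 8, 8)     # llava/paligemma-style: (batch, channels, height, width)

# torch.cat doubles the batch for any rank:
assert torch.equal(torch.cat([pv5, pv5], dim=0), pv5.repeat(2, 1, 1, 1, 1))
assert torch.cat([pv4, pv4], dim=0).shape == (4, 3, 8, 8)

# The old rank-5 repeat silently prepends a dimension on 4-D tensors:
assert pv4.repeat(2, 1, 1, 1, 1).shape == (2, 2, 3, 8, 8)  # not (4, 3, 8, 8)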
@@ -1262,7 +1277,8 @@ def concatenated_forward(
 
         if self.is_vision_model:
             model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
-            model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+            if "pixel_attention_mask" in concatenated_batch:
+                model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
 
         if self.aux_loss_enabled:
             model_kwargs["output_router_logits"] = True
@@ -1275,6 +1291,11 @@
         )
         all_logits = outputs.logits
 
+        if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]:
+            # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
+            seq_len = concatenated_batch["concatenated_labels"].shape[1]
+            all_logits = all_logits[:, -seq_len:]
+
         all_logps, size_completion = self.get_batch_logps(
             all_logits,
             concatenated_batch["concatenated_labels"],
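
The trimming above accounts for processors that prepend expanded image tokens to the text sequence: llava's logits cover image patches plus text, while the labels built in tokenize_row cover only the text, so the logits are right-aligned to the label length before computing log-probs. A shape-only sketch (sizes are illustrative; 576 is the llava-1.5 patch count):

import torch

batch, text_len, num_image_tokens, vocab = 4, 32, 576, 32000
all_logits = torch.randn(batch, num_image_tokens + text_len, vocab)
labels = torch.full((batch, text_len), -100)

if all_logits.shape[:2] != labels.shape[:2]:
    seq_len = labels.shape[1]
    all_logits = all_logits[:, -seq_len:]  # keep only positions aligned with the text labels

assert all_logits.shape == (batch, text_len, vocab)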