DPO Llava 1.5 and PaliGemma support #1797
Changes from all commits: 94cf1a9, 1a51600, d259cee, 543989c, 61ce5f1, 39082ce, 406dee5
examples/scripts/dpo_visual.py
@@ -16,12 +16,13 @@
accelerate launch examples/scripts/dpo_visual.py \
    --dataset_name HuggingFaceH4/rlaif-v_formatted \
    --model_name_or_path HuggingFaceM4/idefics2-8b \
-   --per_device_train_batch_size 1 \
-   --gradient_accumulation_steps 16 \
+   --per_device_train_batch_size 2 \
+   --gradient_accumulation_steps 32 \
+   --dataset_num_proc 32 \
    --output_dir dpo_idefics_rlaif-v \
-   --bf16 \
+   --torch_dtype bfloat16 \
    --gradient_checkpointing \
    --use_peft \
    --lora_target_modules=all-linear
"""
@@ -82,21 +83,40 @@
model_kwargs = dict(
    revision=model_config.model_revision,
-   trust_remote_code=model_config.trust_remote_code,
Review comment: See #1806
    attn_implementation=model_config.attn_implementation,
    torch_dtype=torch_dtype,
-   use_cache=False if training_args.gradient_checkpointing else True,
Review comment: PaliGemma doesn't take `use_cache`.
    device_map=get_kbit_device_map() if quantization_config is not None else None,
    quantization_config=quantization_config,
)
-   model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+   model = AutoModelForVision2Seq.from_pretrained(
+       model_config.model_name_or_path,
+       trust_remote_code=model_config.trust_remote_code,
+       **model_kwargs,
+   )
peft_config = get_peft_config(model_config)
if peft_config is None:
-       model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
+       model_ref = AutoModelForVision2Seq.from_pretrained(
+           model_config.model_name_or_path,
+           trust_remote_code=model_config.trust_remote_code,
+           **model_kwargs,
+       )
else:
    model_ref = None
-   processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
+   processor = AutoProcessor.from_pretrained(
+       model_config.model_name_or_path,
+       trust_remote_code=model_config.trust_remote_code,
+       do_image_splitting=False,
+   )
tokenizer = processor.tokenizer

+   # Set up the chat template
+   if model.config.model_type == "idefics2":
+       pass  # the processor already has a valid chat template
+   elif model.config.model_type == "paligemma":
+       processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}<|im_end|>{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
+   elif model.config.model_type == "llava":
+       processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
Comment on lines +113 to +118
Review comment: We expect the chat templating to support the message structure sketched below, but idefics2 is the only model whose processor supports it out of the box, so we need to adapt the chat template of the other models.
Review comment: By the way, llava models have different chat templates depending on which LLM backbone is used. We'll add chat templates directly to the processor, but probably not very soon. Wondering whether we need to support different templates here for each checkpoint. (Ah, I see: this PR doesn't support llava-1.6, which is the one with different templates. Then the question becomes whether llava-1.6 will be supported soon?)
Review comment: We can definitely work on its support, yes.
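A minimal sketch of the message structure referred to above, inferred from the chat templates in this diff (the texts themselves are purely illustrative):

example_messages = [
    {
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "What does the image show?"},
        ],
    },
    {
        "role": "assistant",
        "content": [{"type": "text", "text": "A red traffic light."}],
    },
]
# With the llava template above, processor.apply_chat_template(example_messages, tokenize=False)
# would render roughly: "USER: <image>What does the image show? ASSISTANT: A red traffic light.</s>"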
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
if args.ignore_bias_buffers:
@@ -124,27 +144,9 @@
ds[key] = ds[key].select(range(50))
def process(row):
-   # The prompt can be either a string or a list. In some datasets, the prompt is just a common string
-   # for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
-   # separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
-   # and in such cases, it is properly formatted as a list with keys "role" and "content".
-   # Example 1:
-   # row = {"prompt": "What does detox mean?",
-   #        "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
-   #        "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "user"}]}
-   # Example 2:
-   # row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
-   #        "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
-   #        "rejected": [{"content": "I don't know.", "role": "user"}]}
-   if "prompt" in row and isinstance(row["prompt"], list):
-       row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
+   row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
Comment on lines -127 to +147
Review comment: It turns out that this isn't the case for visual data: we always have a prompt. The reason is that it's the prompt that contains the image (the image can't be part of the output).
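For illustration, a hypothetical vision row in that always-has-a-prompt form (the values and exact schema are assumptions; the point is that the image sits in the prompt while chosen/rejected are text-only completions):

from PIL import Image

image = Image.new("RGB", (640, 480))  # placeholder image
row = {
    "prompt": [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in the photo?"}]},
    ],
    "chosen": [
        {"role": "assistant", "content": [{"type": "text", "text": "A dog playing in a park."}]},
    ],
    "rejected": [
        {"role": "assistant", "content": [{"type": "text", "text": "A cat sleeping on a couch."}]},
    ],
    "images": [image],  # the actual image objects live in a separate column
}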
    row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
    row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)
-   if "images" in row:
-       for img in row["images"]:  # Resize each image so the largest side is 640 pixels
-           img.thumbnail((640, 640))  # Resize the image to at most 640x640 pixels
Comment on lines -144 to -147
Review comment: Resizing is not required: all processors handle images of any size (they resize if necessary); see the sketch below. I previously needed it to avoid an OOM error, but I've fixed that.
Review comment: Yep, we shouldn't do it, or else it breaks.
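A small illustrative check of that point (the checkpoint name and prompt are examples, not part of this PR): the processor's image processor resizes and normalizes internally, so the dataset code doesn't need to thumbnail images first.

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
image = Image.new("RGB", (1920, 1080))  # arbitrary input size
inputs = processor(text="USER: <image> Describe the image. ASSISTANT:", images=image, return_tensors="pt")
print(inputs["pixel_values"].shape)  # already resized to the model's fixed input resolution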
    return row

with PartialState().local_main_process_first():
@@ -168,6 +170,6 @@ def process(row):
    )

    trainer.train()
    trainer.push_to_hub
    with save_context:
        trainer.save_model(training_args.output_dir)
trl/trainer/dpo_trainer.py
@@ -723,9 +723,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
if self.is_vision_model:
    if answer.count("<image>") > 0:
        raise NotImplementedError("Answer contains <image> token, which is not supported yet.")
-   full_tokenized = self.processor(prompt + answer, images=images, add_special_tokens=False)
+   if "add_special_tokens" in inspect.signature(self.processor).parameters:
+       processor_kwargs = {"add_special_tokens": False}
+   else:
+       processor_kwargs = {}
+   full_tokenized = self.processor(prompt + answer, images=images, **processor_kwargs)
Comment on lines +726 to +730
    full_tokenized = {k: v[0] for k, v in full_tokenized.items()}  # Unbatch, not done when using idefics
-   prompt_input_ids = self.processor(prompt, images=images, add_special_tokens=False)["input_ids"][0]
+   if not isinstance(full_tokenized["input_ids"], list):  # llava processor returns tensors
+       full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
+       full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
+   prompt_input_ids = self.processor(prompt, images=images, **processor_kwargs)["input_ids"][0]
+   if not isinstance(prompt_input_ids, list):  # llava processor returns tensors
+       prompt_input_ids = prompt_input_ids.tolist()
Comment on lines +736 to +737
Review comment: Make sure that the input_ids is a list and not a tensor (llava returns a tensor).
else:
    full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
    prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
@@ -762,22 +771,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

+   return_dict = dict(
+       prompt_input_ids=prompt_input_ids,
+       prompt_attention_mask=prompt_attention_mask,
+       input_ids=answer_input_ids,
+       attention_mask=answer_attention_mask,
+   )
    if "pixel_values" in full_tokenized:
-       return dict(
-           prompt_input_ids=prompt_input_ids,
-           prompt_attention_mask=prompt_attention_mask,
-           prompt_pixel_values=full_tokenized["pixel_values"],
-           prompt_pixel_attention_mask=full_tokenized["pixel_attention_mask"],
-           input_ids=answer_input_ids,
-           attention_mask=answer_attention_mask,
-       )
-   else:
-       return dict(
-           prompt_input_ids=prompt_input_ids,
-           prompt_attention_mask=prompt_attention_mask,
-           input_ids=answer_input_ids,
-           attention_mask=answer_attention_mask,
-       )
+       return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"]
+       if "pixel_attention_mask" in full_tokenized:
+           return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"]

+   return return_dict
Comment on lines +774 to +785
Review comment: Equivalent, but simpler. The only motivation is that it allows `pixel_attention_mask` to be optional.
def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
    """Tokenize a single row from a DPO specific dataset.

@@ -805,8 +810,15 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
if not isinstance(prompt, str):
    raise ValueError(f"prompt should be an str but got {type(prompt)}")
if self.is_vision_model:
-   prompt_tokens = self.processor(prompt, images=images, add_special_tokens=False)
+   if "add_special_tokens" in inspect.signature(self.processor).parameters:
+       processor_kwargs = {"add_special_tokens": False}
+   else:
+       processor_kwargs = {}
+   prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
    prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()}  # Unbatch, not done when using idefics
+   if not isinstance(prompt_tokens["input_ids"], list):  # llava processor returns tensors
+       prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
+       prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
else:
    prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)
@@ -1037,10 +1049,13 @@ def concatenated_inputs(
)

if is_vision_model:
-   concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device)
-   concatenated_batch["pixel_attention_mask"] = (
-       batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device)
+   concatenated_batch["pixel_values"] = torch.cat(
+       [batch["prompt_pixel_values"], batch["prompt_pixel_values"]], dim=0
+   )
+   if "prompt_pixel_attention_mask" in batch:
+       concatenated_batch["pixel_attention_mask"] = torch.cat(
+           [batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0
+       )
kashif marked this conversation as resolved.

    return concatenated_batch
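Presumably the point of the torch.cat form above is that it duplicates the batch for pixel_values of any rank, whereas .repeat(2, 1, 1, 1, 1) hard-codes the 5-D layout produced by idefics2 and would mishandle llava's 4-D pixel_values. A small sketch with made-up shapes:

import torch

pixel_values_idefics2 = torch.randn(2, 3, 3, 224, 224)  # (batch, num_images, channels, height, width), illustrative
pixel_values_llava = torch.randn(2, 3, 336, 336)        # (batch, channels, height, width), illustrative

for pixel_values in (pixel_values_idefics2, pixel_values_llava):
    doubled = torch.cat([pixel_values, pixel_values], dim=0)  # chosen and rejected share the same prompt images
    assert doubled.shape[0] == 2 * pixel_values.shape[0]
    assert doubled.shape[1:] == pixel_values.shape[1:]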
def dpo_loss(

@@ -1262,7 +1277,8 @@ def concatenated_forward(
if self.is_vision_model:
    model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
-   model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
+   if "pixel_attention_mask" in concatenated_batch:
+       model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]

if self.aux_loss_enabled:
    model_kwargs["output_router_logits"] = True
@@ -1275,6 +1291,11 @@ def concatenated_forward(
)
all_logits = outputs.logits
+   if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]:
+       # for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
+       seq_len = concatenated_batch["concatenated_labels"].shape[1]
+       all_logits = all_logits[:, -seq_len:]

all_logps, size_completion = self.get_batch_logps(
    all_logits,
    concatenated_batch["concatenated_labels"],
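A sketch of the shape mismatch the slicing above handles, with made-up sizes (llava-style models prepend the expanded image tokens to the text sequence, while the labels cover only the text tokens):

import torch

batch_size, text_len, num_image_tokens, vocab_size = 4, 32, 576, 32000  # illustrative numbers
all_logits = torch.randn(batch_size, num_image_tokens + text_len, vocab_size)
concatenated_labels = torch.full((batch_size, text_len), -100)

if all_logits.shape[:2] != concatenated_labels.shape[:2]:
    seq_len = concatenated_labels.shape[1]
    all_logits = all_logits[:, -seq_len:]  # keep only the positions aligned with the text labels

assert all_logits.shape[:2] == concatenated_labels.shape[:2]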
Review comment: Allows the training to fit in memory without the need to resize images (resizing has been removed as well; see the `process` function changes above).