DPO Llava 1.5 and PaliGemma support #1797

Merged · 7 commits · Jul 9, 2024
56 changes: 29 additions & 27 deletions examples/scripts/dpo_visual.py
@@ -16,12 +16,13 @@
accelerate launch examples/scripts/dpo_visual.py \
--dataset_name HuggingFaceH4/rlaif-v_formatted \
--model_name_or_path HuggingFaceM4/idefics2-8b \
--per_device_train_batch_size 1 \
--gradient_accumulation_steps 16 \
--per_device_train_batch_size 2 \
--gradient_accumulation_steps 32 \
--dataset_num_proc 32 \
--output_dir dpo_idefics_rlaif-v \
--bf16 \
--torch_dtype bfloat16 \
--gradient_checkpointing \
Member Author:

Allows the training to fit in memory without needing to resize the images (the resizing has been removed as well; see below).

--use_peft \
--lora_target_modules=all-linear
"""
@@ -82,21 +83,40 @@

model_kwargs = dict(
revision=model_config.model_revision,
trust_remote_code=model_config.trust_remote_code,
Member Author:

See #1806.

attn_implementation=model_config.attn_implementation,
torch_dtype=torch_dtype,
use_cache=False if training_args.gradient_checkpointing else True,
Member Author:

PaliGemma doesn't accept use_cache as an init argument.

device_map=get_kbit_device_map() if quantization_config is not None else None,
quantization_config=quantization_config,
)
model = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
model = AutoModelForVision2Seq.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
peft_config = get_peft_config(model_config)
if peft_config is None:
model_ref = AutoModelForVision2Seq.from_pretrained(model_config.model_name_or_path, **model_kwargs)
model_ref = AutoModelForVision2Seq.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
**model_kwargs,
)
else:
model_ref = None
processor = AutoProcessor.from_pretrained(model_config.model_name_or_path, do_image_splitting=False)
processor = AutoProcessor.from_pretrained(
model_config.model_name_or_path,
trust_remote_code=model_config.trust_remote_code,
do_image_splitting=False,
)
tokenizer = processor.tokenizer

# Set up the chat template
if model.config.model_type == "idefics2":
pass # the processor already has a valid chat template
elif model.config.model_type == "paligemma":
processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}<|im_start|>{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] if item['type'] == 'text' %}{{ item['text'] }}<|im_end|>{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
elif model.config.model_type == "llava":
processor.chat_template = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""
Comment on lines +113 to +118
Member Author:

We expect the chat templating to support this data structure:

[{"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "foo"}]}]

but idefics2 is the only model whose processor supports it, so we need to adapt the chat templates of the other models.
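For illustration, a minimal sketch of how the Llava template above renders a message with that structure (the checkpoint name and the message text are assumptions, not part of this PR):

```python
from transformers import AutoProcessor

LLAVA_CHAT_TEMPLATE = """{% if not add_generation_prompt is defined %}{% set add_generation_prompt = false %}{% endif %}{% for message in messages %}{% if message['role'] == 'user' %}USER: {% else %}ASSISTANT: {% endif %}{% for item in message['content'] %}{% if item['type'] == 'text' %}{{ item['text'] }}{% elif item['type'] == 'image' %}<image>{% endif %}{% endfor %}{% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}{% if add_generation_prompt %}ASSISTANT: {% endif %}"""

# Example checkpoint; any Llava 1.5 checkpoint should behave the same way here.
processor = AutoProcessor.from_pretrained("llava-hf/llava-1.5-7b-hf")
processor.chat_template = LLAVA_CHAT_TEMPLATE

messages = [
    {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is in this image?"}]},
]
prompt = processor.apply_chat_template(messages, tokenize=False)
# Expected to render roughly as: "USER: <image>What is in this image? "
```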

@zucchini-nlp (Member) commented on Jul 9, 2024:

By the way, llava models have different chat templates depending on which LLM backbone is used. We'll add chat templates directly to the processors, but probably not very soon.

I'm wondering whether we need to support a different template here for each checkpoint.

(Ah, I see: this PR doesn't seem to support llava-1.6, which is the one with differing templates. The question then becomes whether llava-1.6 will be supported soon.)

Member Author:

We can definitely work on its support, yes.


if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
if args.ignore_bias_buffers:
@@ -124,27 +144,9 @@
ds[key] = ds[key].select(range(50))

def process(row):
# The prompt can be either a string or a list. In some datasets, the prompt is just a common string
# for both rejected and chosen (already included in chosen and rejected) and is not meant to be used
# separately. In other datasets, the prompt is intended to be used as a prefix for rejected and chosen,
# and in such cases, it is properly formatted as a list with keys "role" and "content".
# Example 1:
# row = {"prompt": "What does detox mean?",
# "chosen": [{"content": "What does detox mean?", "role": "user"}, {"content": "It means to get rid of the toxins.", "role": "assistant"}],
# "rejected": [{"content": "What does detox mean?", "role": "assistant"}, {"content": "I don't know.", "role": "user"}]}
# Example 2:
# row = {"prompt": [{"content": "What does detox mean?", "role": "user"}],
# "chosen": [{"content": "It means to get rid of the toxins.", "role": "assistant"}],
# "rejected": [{"content": "I don't know.", "role": "user"}]}
if "prompt" in row and isinstance(row["prompt"], list):
row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)

row["prompt"] = processor.apply_chat_template(row["prompt"], tokenize=False)
Comment on lines -127 to +147
Member Author:

It turns out that this isn't the case for visual data: we always have a prompt, because it's the prompt that contains the image (the image can't be part of the output).
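For context, a sketch of the row shape this relies on (field values are invented; the structure follows the comment above and the rlaif-v style formatting):

```python
from PIL import Image

pil_image = Image.new("RGB", (64, 64))  # placeholder image

row = {
    # The prompt always exists and is the part of the example that carries the image.
    "prompt": [
        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "What is shown here?"}]},
    ],
    # chosen / rejected are text-only completions.
    "chosen": [{"role": "assistant", "content": [{"type": "text", "text": "A cat sitting on a chair."}]}],
    "rejected": [{"role": "assistant", "content": [{"type": "text", "text": "A dog."}]}],
    "images": [pil_image],
}
```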

row["chosen"] = processor.apply_chat_template(row["chosen"], tokenize=False)
row["rejected"] = processor.apply_chat_template(row["rejected"], tokenize=False)

if "images" in row:
for img in row["images"]: # Resize each image so the largest side is 640 pixels
img.thumbnail((640, 640)) # Resize the image to at most 640x640 pixels
Comment on lines -144 to -147
@qgallouedec (Member Author) commented on Jul 8, 2024:

Resizing is not required: all processors handle images of any size (they resize if necessary). I previously needed it to avoid an OOM error, but that's now fixed with gradient_checkpointing.

(Reply)

Yep, we shouldn't do it, or else it breaks.
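As an aside, a minimal sketch of the gradient-checkpointing setup that replaces the resizing, configured in Python rather than via the CLI flags in the example command (values mirror that command; the use_reentrant kwarg is an extra assumption, not something this PR sets):

```python
from trl import DPOConfig

training_args = DPOConfig(
    output_dir="dpo_idefics_rlaif-v",
    per_device_train_batch_size=2,
    gradient_accumulation_steps=32,
    bf16=True,
    gradient_checkpointing=True,
    # Forwarded to model.gradient_checkpointing_enable(); non-reentrant checkpointing
    # tends to interact better with PEFT adapters, but this kwarg is optional.
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
```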

return row

with PartialState().local_main_process_first():
@@ -168,6 +170,6 @@ def process(row):
)

trainer.train()
trainer.push_to_hub

with save_context:
trainer.save_model(training_args.output_dir)
65 changes: 43 additions & 22 deletions trl/trainer/dpo_trainer.py
@@ -723,9 +723,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
if self.is_vision_model:
if answer.count("<image>") > 0:
raise NotImplementedError("Answer contains <image> token, which is not supported yet.")
full_tokenized = self.processor(prompt + answer, images=images, add_special_tokens=False)
if "add_special_tokens" in inspect.signature(self.processor).parameters:
processor_kwargs = {"add_special_tokens": False}
else:
processor_kwargs = {}
full_tokenized = self.processor(prompt + answer, images=images, **processor_kwargs)
Comment on lines +726 to +730
Member Author:

add_special_tokens is not an argument of every processor, so the only way is to check the signature.
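A standalone sketch of that check (the checkpoint name is just an example; the point is that inspect.signature on the processor instance reflects its __call__ signature):

```python
import inspect

from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained("HuggingFaceM4/idefics2-8b")  # example checkpoint

# Only pass add_special_tokens if the processor's __call__ actually accepts it;
# some processors (e.g. Llava's) do not expose that argument.
if "add_special_tokens" in inspect.signature(processor).parameters:
    processor_kwargs = {"add_special_tokens": False}
else:
    processor_kwargs = {}
```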

full_tokenized = {k: v[0] for k, v in full_tokenized.items()} # Unbatch, not done when using idefics
prompt_input_ids = self.processor(prompt, images=images, add_special_tokens=False)["input_ids"][0]
if not isinstance(full_tokenized["input_ids"], list): # llava processor returns tensors
full_tokenized["input_ids"] = full_tokenized["input_ids"].tolist()
full_tokenized["attention_mask"] = full_tokenized["attention_mask"].tolist()
prompt_input_ids = self.processor(prompt, images=images, **processor_kwargs)["input_ids"][0]
if not isinstance(prompt_input_ids, list): # llava processor returns tensors
prompt_input_ids = prompt_input_ids.tolist()
Comment on lines +736 to +737
@qgallouedec (Member Author) commented on Jul 8, 2024:

Make sure that input_ids is a list and not a tensor (the Llava processor returns a tensor).

else:
full_tokenized = self.tokenizer(prompt + answer, add_special_tokens=False)
prompt_input_ids = self.tokenizer(prompt, add_special_tokens=False)["input_ids"]
@@ -762,22 +771,18 @@ def build_tokenized_answer(self, prompt, answer, images=None):
answer_input_ids = full_tokenized["input_ids"][response_token_ids_start_idx:]
answer_attention_mask = full_tokenized["attention_mask"][response_token_ids_start_idx:]

return_dict = dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
if "pixel_values" in full_tokenized:
return dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
prompt_pixel_values=full_tokenized["pixel_values"],
prompt_pixel_attention_mask=full_tokenized["pixel_attention_mask"],
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
else:
return dict(
prompt_input_ids=prompt_input_ids,
prompt_attention_mask=prompt_attention_mask,
input_ids=answer_input_ids,
attention_mask=answer_attention_mask,
)
return_dict["prompt_pixel_values"] = full_tokenized["pixel_values"]
if "pixel_attention_mask" in full_tokenized:
return_dict["prompt_pixel_attention_mask"] = full_tokenized["pixel_attention_mask"]

return return_dict
Comment on lines +774 to +785
Member Author:

Equivalent, but simpler. The main motivation is that it allows full_tokenized not to have a "pixel_attention_mask" entry (which is the case for Llava).


def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module]] = None) -> Dict:
"""Tokenize a single row from a DPO specific dataset.
@@ -805,8 +810,15 @@ def tokenize_row(self, feature, model: Optional[Union[PreTrainedModel, nn.Module
if not isinstance(prompt, str):
raise ValueError(f"prompt should be an str but got {type(prompt)}")
if self.is_vision_model:
prompt_tokens = self.processor(prompt, images=images, add_special_tokens=False)
if "add_special_tokens" in inspect.signature(self.processor).parameters:
processor_kwargs = {"add_special_tokens": False}
else:
processor_kwargs = {}
prompt_tokens = self.processor(prompt, images=images, **processor_kwargs)
prompt_tokens = {k: v[0] for k, v in prompt_tokens.items()} # Unbatch, not done when using idefics
if not isinstance(prompt_tokens["input_ids"], list): # llava processor returns tensors
prompt_tokens["input_ids"] = prompt_tokens["input_ids"].tolist()
prompt_tokens["attention_mask"] = prompt_tokens["attention_mask"].tolist()
else:
prompt_tokens = self.tokenizer(prompt, add_special_tokens=False)

@@ -1037,10 +1049,13 @@ def concatenated_inputs(
)

if is_vision_model:
concatenated_batch["pixel_values"] = batch["prompt_pixel_values"].repeat(2, 1, 1, 1, 1).to(device=device)
concatenated_batch["pixel_attention_mask"] = (
batch["prompt_pixel_attention_mask"].repeat(2, 1, 1, 1).to(device=device)
concatenated_batch["pixel_values"] = torch.cat(
[batch["prompt_pixel_values"], batch["prompt_pixel_values"]], dim=0
)
if "prompt_pixel_attention_mask" in batch:
concatenated_batch["pixel_attention_mask"] = torch.cat(
[batch["prompt_pixel_attention_mask"], batch["prompt_pixel_attention_mask"]], dim=0
)
kashif marked this conversation as resolved.
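For reference, a small sketch of why torch.cat along dim 0 is used here instead of a fixed-rank repeat: Idefics2 pixel values are 5-D while Llava's are 4-D, and concatenation works for either rank (shapes below are illustrative):

```python
import torch

x5 = torch.randn(2, 1, 3, 8, 8)  # Idefics2-style pixel_values: (batch, num_images, C, H, W)
x4 = torch.randn(2, 3, 8, 8)     # Llava-style pixel_values: (batch, C, H, W)

# For 5-D inputs, concatenating along the batch dim matches the old repeat(2, 1, 1, 1, 1)...
assert torch.equal(torch.cat([x5, x5], dim=0), x5.repeat(2, 1, 1, 1, 1))

# ...but unlike repeat(2, 1, 1, 1, 1), it also works for 4-D inputs.
doubled = torch.cat([x4, x4], dim=0)
assert doubled.shape == (4, 3, 8, 8)
```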
return concatenated_batch

def dpo_loss(
@@ -1262,7 +1277,8 @@ def concatenated_forward(

if self.is_vision_model:
model_kwargs["pixel_values"] = concatenated_batch["pixel_values"]
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]
if "pixel_attention_mask" in concatenated_batch:
model_kwargs["pixel_attention_mask"] = concatenated_batch["pixel_attention_mask"]

if self.aux_loss_enabled:
model_kwargs["output_router_logits"] = True
Expand All @@ -1275,6 +1291,11 @@ def concatenated_forward(
)
all_logits = outputs.logits

if all_logits.shape[:2] != concatenated_batch["concatenated_labels"].shape[:2]:
# for llava, the model returns logits for the entire sequence, including the image tokens (placed before the text tokens)
seq_len = concatenated_batch["concatenated_labels"].shape[1]
all_logits = all_logits[:, -seq_len:]

all_logps, size_completion = self.get_batch_logps(
all_logits,
concatenated_batch["concatenated_labels"],