fix: use eos token in target tensor for instruction-tuning (#3945)
geoffreyangus committed Feb 27, 2024
1 parent d347063 commit 021a099
Showing 4 changed files with 27 additions and 23 deletions.
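Why this matters for instruction tuning: the token appended to the end of every target sequence is the token the model is trained to produce once the response is complete. With the pad token there, a fine-tuned model tends to learn to emit padding instead of stopping; with the eos token there, it learns to terminate generation cleanly. A minimal sketch of the before/after difference, using hypothetical token ids (pad=1, eos=2, chosen only to match the updated test expectations further down):

import torch

PAD_TOKEN_ID = 1  # hypothetical pad id
EOS_TOKEN_ID = 2  # hypothetical eos id

target = torch.tensor([9, 10, 11])  # a tokenized target with left padding already removed

# Before this commit, the target fed to the loss ended in a pad token...
old_target = torch.cat([target, torch.tensor([PAD_TOKEN_ID])], dim=-1)  # tensor([ 9, 10, 11,  1])

# ...after it, the target ends in the eos token, so the model is explicitly
# trained to emit eos when the response is finished.
new_target = torch.cat([target, torch.tensor([EOS_TOKEN_ID])], dim=-1)  # tensor([ 9, 10, 11,  2])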
22 changes: 13 additions & 9 deletions ludwig/models/llm.py
@@ -267,9 +267,11 @@ def forward(
)

# Wrap with flash attention backend for faster generation
-with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False) if (
-    torch.cuda.is_available() and self.curr_device.type == "cuda"
-) else contextlib.nullcontext():
+with (
+    torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
+    if (torch.cuda.is_available() and self.curr_device.type == "cuda")
+    else contextlib.nullcontext()
+):
# TODO (jeffkinnison): Determine why the 8-bit `SCB` and `CB` matrices are deleted in the forward pass
model_outputs = self.model(input_ids=self.model_inputs, attention_mask=self.attention_masks).get(LOGITS)
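The `sdp_kernel` hunk above (and the matching hunk in `generate` below) only restyles an existing pattern rather than changing behavior: a context manager is selected at runtime, using the flash-attention kernel when running on CUDA and a no-op context otherwise. A self-contained sketch of that pattern, where `fast_backend` is a hypothetical stand-in for `torch.backends.cuda.sdp_kernel(...)`:

import contextlib

@contextlib.contextmanager
def fast_backend():
    # Stand-in for an accelerator-specific backend such as the flash-attention SDP kernel.
    print("fast backend enabled")
    yield

def run(fn, use_fast_path: bool):
    # Same shape as the Ludwig code: pick the context manager with a conditional
    # expression, falling back to a no-op context when the fast path is unavailable.
    with fast_backend() if use_fast_path else contextlib.nullcontext():
        return fn()

run(lambda: print("forward pass"), use_fast_path=False)  # runs without the fast backend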

@@ -330,9 +332,11 @@ def generate(
input_lengths.append(input_ids_sample_no_padding.shape[1])

# Wrap with flash attention backend for faster generation
-with torch.backends.cuda.sdp_kernel(
-    enable_flash=True, enable_math=False, enable_mem_efficient=False
-) if (torch.cuda.is_available() and self.curr_device.type == "cuda") else contextlib.nullcontext():
+with (
+    torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False)
+    if (torch.cuda.is_available() and self.curr_device.type == "cuda")
+    else contextlib.nullcontext()
+):
# Generate text using the model
model_outputs = self.model.generate(
input_ids=input_ids_sample_no_padding,
@@ -656,7 +660,7 @@ def _update_target_tensor_for_finetuning(
) -> Dict[str, torch.Tensor]:
"""Update target tensor for fine-tuning.
-This method removes left padding from target tensors, adds a pad token to the end of the target tensors,
+This method removes left padding from target tensors, adds a eos token to the end of the target tensors,
and pads the target tensors with -100 to ensure equal length for loss computation. It then realigns the
target tensors with the prediction tensors.
@@ -674,10 +678,10 @@ def _update_target_tensor_for_finetuning(
targets_without_padding = []
lengths = []

-pad_token_tensor = torch.tensor([self.tokenizer.pad_token_id])
+eos_token_tensor = torch.tensor([self.tokenizer.eos_token_id])
for target in targets[of_name]:
target = remove_left_padding(target, self.tokenizer)[0]
-target = torch.cat([target, pad_token_tensor.to(device=target.device)], dim=-1).unsqueeze(0)
+target = torch.cat([target, eos_token_tensor.to(device=target.device)], dim=-1).unsqueeze(0)
targets_without_padding.append(target)
lengths.append(target.shape[1])
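For context on the surrounding code (only part of `_update_target_tensor_for_finetuning` appears in this diff): once the eos token is appended, the per-sample targets have unequal lengths, and per the docstring they are padded with -100 so the padded positions are ignored by the loss. A minimal sketch of that step with made-up token ids (eos=2):

import torch

EOS_TOKEN_ID = 2      # hypothetical eos id
IGNORE_INDEX = -100   # label value ignored by torch.nn.CrossEntropyLoss

targets = [torch.tensor([9, 10, 11]), torch.tensor([12, 13])]

# Append eos to every target, then right-pad with -100 so all rows share one length
# and the padded positions contribute nothing to the loss.
with_eos = [torch.cat([t, torch.tensor([EOS_TOKEN_ID])]) for t in targets]
max_len = max(t.shape[0] for t in with_eos)
padded = torch.stack(
    [torch.nn.functional.pad(t, (0, max_len - t.shape[0]), value=IGNORE_INDEX) for t in with_eos]
)
# padded == tensor([[   9,   10,   11,    2],
#                   [  12,   13,    2, -100]])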

4 changes: 2 additions & 2 deletions ludwig/utils/llm_utils.py
@@ -471,14 +471,14 @@ def generate_merged_ids(
merged_input_and_targets = []
lengths = []

-pad_tensor = torch.tensor([tokenizer.pad_token_id]).to(target_ids[0].device)
+eos_tensor = torch.tensor([tokenizer.eos_token_id]).to(target_ids[0].device)

# Merge input_ids and target_ids by concatenating them together.
# We remove the left padding from both input_ids and target_ids before concatenating them.
for input_id_sample, target_id_sample in zip(input_ids, target_ids):
input_id_sample_no_padding = remove_left_padding(input_id_sample, tokenizer)[0]
target_id_sample_no_padding = remove_left_padding(target_id_sample, tokenizer)[0]
-target_id_sample_no_padding = torch.cat((target_id_sample_no_padding, pad_tensor), dim=-1)
+target_id_sample_no_padding = torch.cat((target_id_sample_no_padding, eos_tensor), dim=-1)

merged_sample_ids = torch.cat((input_id_sample_no_padding, target_id_sample_no_padding), dim=-1)
# If the merged tensor is longer than the maximum sequence length, we truncate it.
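The merge step shown in this hunk concatenates each prompt with its target so the model trains on the full sequence, and the target segment now ends in eos rather than pad. A stripped-down sketch of just that step, using the same made-up ids as the updated unit test below (truncation and attention-mask handling in the real function are omitted):

import torch

EOS_TOKEN_ID = 2  # hypothetical eos id

input_ids = torch.tensor([3, 4, 5])      # prompt, left padding already removed
target_ids = torch.tensor([9, 10, 11])   # target, left padding already removed

eos_tensor = torch.tensor([EOS_TOKEN_ID])
target_with_eos = torch.cat((target_ids, eos_tensor), dim=-1)

merged = torch.cat((input_ids, target_with_eos), dim=-1)
# merged == tensor([ 3,  4,  5,  9, 10, 11,  2]), matching the expectation in
# test_generate_merged_ids_with_target below.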
20 changes: 10 additions & 10 deletions tests/integration_tests/test_llm.py
@@ -1339,22 +1339,22 @@ def test_llm_used_tokens(tmpdir):
) as f:
progress_tracker = json.load(f)

assert progress_tracker["cumulative_step_token_usage"]["11"] == progress_tracker["total_tokens_used"] == 612
assert progress_tracker["cumulative_step_token_usage"]["11"] == progress_tracker["total_tokens_used"] == 621
assert progress_tracker["checkpoint_to_epoch"] == {"1": 1, "2": 1, "3": 2, "4": 2, "5": 3, "6": 3}
assert progress_tracker["checkpoint_to_step"] == {"1": 4, "2": 4, "3": 8, "4": 8, "5": 12, "6": 12}
assert progress_tracker["cumulative_checkpoint_token_usage"] == {
"1": 204,
"2": 204,
"3": 408,
"4": 408,
"5": 612,
"6": 612,
"1": 207,
"2": 207,
"3": 414,
"4": 414,
"5": 621,
"6": 621,
}
assert progress_tracker["incremental_checkpoint_token_usage"] == {
"1": 204,
"1": 207,
"2": 0,
"3": 204,
"3": 207,
"4": 0,
"5": 204,
"5": 207,
"6": 0,
}
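The updated token counts are internally consistent: checkpoints 1, 3, and 5 each accrue 207 tokens (up from 204) and the even checkpoints add none, so the run total moves from 3 × 204 = 612 to 3 × 207 = 621. The commit does not state why usage rises by 3 per checkpoint; presumably the appended eos tokens are counted toward usage where the previously appended pad tokens were not. A quick arithmetic check:

# Pure arithmetic, no Ludwig required: verify the new expectations hang together.
old_per_checkpoint, new_per_checkpoint = 204, 207
accruing_checkpoints = 3  # checkpoints "1", "3", "5"; "2", "4", "6" add nothing
assert accruing_checkpoints * old_per_checkpoint == 612
assert accruing_checkpoints * new_per_checkpoint == 621
assert new_per_checkpoint - old_per_checkpoint == 3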
4 changes: 2 additions & 2 deletions tests/ludwig/utils/test_llm_utils.py
@@ -156,7 +156,7 @@ def test_find_last_matching_index(tensor_a, tensor_b, expected_index):
def test_generate_merged_ids_with_target(tokenizer, input_ids, target_ids):
# Test case when target_ids is not None
merged_ids, attention_masks = generate_merged_ids(input_ids, target_ids, tokenizer)
-assert torch.equal(merged_ids, torch.tensor([[3, 4, 5, 9, 10, 11, 1], [6, 7, 8, 12, 13, 14, 1]]))
+assert torch.equal(merged_ids, torch.tensor([[3, 4, 5, 9, 10, 11, 2], [6, 7, 8, 12, 13, 14, 2]]))
assert merged_ids.shape == (2, 7) # Check the shape of merged_ids
assert attention_masks.shape == (2, 7) # Check the shape of attention_masks

@@ -186,7 +186,7 @@ def test_generate_merged_ids_padding_removal(tokenizer, input_ids, target_ids):

assert torch.equal(merged_ids[0][:3], input_ids[0]) # Check the input_ids part without padding
assert torch.equal(merged_ids[0][3:-1], target_ids[0]) # Check the target_ids part without padding
-assert torch.equal(merged_ids[0][-1], torch.tensor(tokenizer.pad_token_id))  # Check the padding tokens
+assert torch.equal(merged_ids[0][-1], torch.tensor(tokenizer.eos_token_id))  # Check the padding tokens

assert torch.all(attention_masks == 1)

