Skip to content

Commit

Permalink
remove sliced empty image features (#1199)
Browse files Browse the repository at this point in the history
Signed-off-by: Dillon Laird <dillonalaird@gmail.com>
  • Loading branch information
dillonalaird committed Jan 28, 2024
1 parent b8ee438 commit 9cff14a
Showing 1 changed file with 3 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,9 @@ def prepare_inputs_labels_for_multimodal(
for batch_idx, cur_input_ids in enumerate(input_ids):
num_images = (cur_input_ids == IMAGE_TOKEN_INDEX).sum()
if num_images == 0:
cur_image_features = image_features[cur_image_idx]
cur_input_embeds_1 = self.get_model().embed_tokens(cur_input_ids)
cur_input_embeds = torch.cat([cur_input_embeds_1, cur_image_features[0:0]], dim=0)
new_input_embeds.append(cur_input_embeds)
# Concatenating the cur_image_features[0:0], like in the original implementation,
# is removed as it causes the backpropogation to crash on the hpu.
new_input_embeds.append(self.get_model().embed_tokens(cur_input_ids))
new_labels.append(labels[batch_idx])
cur_image_idx += 1
continue
Expand Down

0 comments on commit 9cff14a

Please sign in to comment.