Skip to content

Commit

Permalink
fix fuyu device_map compatibility (#29880)
Browse files Browse the repository at this point in the history
fix foward
  • Loading branch information
SunMarc committed Mar 27, 2024
1 parent 4d8427f commit 31c575b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion src/transformers/models/fuyu/modeling_fuyu.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,9 @@ def forward(
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
if image_patches is not None and past_key_values is None:
patch_embeddings = [
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0)
self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype))
.squeeze(0)
.to(inputs_embeds.device)
for patch in image_patches
]
inputs_embeds = self.gather_continuous_embeddings(
Expand Down

0 comments on commit 31c575b

Please sign in to comment.