From 31c575bcf13c2b85b65d652dd1b5b401f99be999 Mon Sep 17 00:00:00 2001 From: Marc Sun <57196510+SunMarc@users.noreply.github.com> Date: Wed, 27 Mar 2024 10:18:48 +0100 Subject: [PATCH] fix fuyu device_map compatibility (#29880) fix foward --- src/transformers/models/fuyu/modeling_fuyu.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/transformers/models/fuyu/modeling_fuyu.py b/src/transformers/models/fuyu/modeling_fuyu.py index 94d9704631fba..f94bac569fc9b 100644 --- a/src/transformers/models/fuyu/modeling_fuyu.py +++ b/src/transformers/models/fuyu/modeling_fuyu.py @@ -290,7 +290,9 @@ def forward( inputs_embeds = self.language_model.get_input_embeddings()(input_ids) if image_patches is not None and past_key_values is None: patch_embeddings = [ - self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)).squeeze(0) + self.vision_embed_tokens(patch.to(self.vision_embed_tokens.weight.dtype)) + .squeeze(0) + .to(inputs_embeds.device) for patch in image_patches ] inputs_embeds = self.gather_continuous_embeddings(