From 80d61657dd4d00c3b917ad5e9b3a4847f12aae71 Mon Sep 17 00:00:00 2001 From: Kaihui-intel Date: Thu, 6 Nov 2025 03:03:50 -0500 Subject: [PATCH] fix mllm device_map ut Signed-off-by: Kaihui-intel --- test/test_cuda/test_multiple_card.py | 18 ++++++++---------- 1 file changed, 8 insertions(+), 10 deletions(-) diff --git a/test/test_cuda/test_multiple_card.py b/test/test_cuda/test_multiple_card.py index f2f1685be..5dac584fe 100644 --- a/test/test_cuda/test_multiple_card.py +++ b/test/test_cuda/test_multiple_card.py @@ -362,24 +362,22 @@ def test_mllm_device_map(self): device_map = "0,1" ar = AutoRoundMLLM(model_name, device_map=device_map) self.assertEqual(ar.device, "cuda:0") - self.assertEqual(ar.device_map, "auto") - self.assertEqual(ar.device_list, [0, 1]) + self.assertEqual(ar.device_map, device_map) device_map = 1 - ar = AutoRoundMLLM(ar.model, ar.tokenizer, ar.processor, device_map=device_map) + ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) self.assertEqual(ar.device, "cuda:1") - self.assertEqual(ar.device_map, None) - self.assertFalse(hasattr(ar, "device_list")) + self.assertEqual(ar.device_map, device_map) device_map = "auto" - ar = AutoRoundMLLM(ar.model, ar.tokenizer, ar.processor, device_map=device_map) + ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) self.assertEqual(ar.device, "cuda") - self.assertEqual(ar.device_map, "auto") + self.assertEqual(ar.device_map, device_map) device_map = {"model.language_model.layers": 0, "model.visual.blocks": 1} - ar = AutoRoundMLLM(ar.model, ar.tokenizer, ar.processor, device_map=device_map) - self.assertEqual(ar.model.model.language_model.layers.tuning_device, "cuda:0") - self.assertEqual(ar.model.model.visual.blocks.tuning_device, "cuda:1") + ar = AutoRoundMLLM(ar.model, ar.tokenizer, processor=ar.processor, device_map=device_map) + self.assertEqual(ar.model.model.language_model.layers[0].self_attn.q_proj.tuning_device, "cuda:0") + self.assertEqual(ar.model.model.visual.blocks[0].mlp.fc1.tuning_device, "cuda:1") if __name__ == "__main__":