From 8e20332597d8b341ced21997a15982c338896fe5 Mon Sep 17 00:00:00 2001
From: Mengni Wang
Date: Sun, 9 Nov 2025 22:38:41 -0500
Subject: [PATCH] Fix param missing bug

---
 auto_round/compressors/diffusion/compressor.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/auto_round/compressors/diffusion/compressor.py b/auto_round/compressors/diffusion/compressor.py
index b57e927b6..e17692532 100644
--- a/auto_round/compressors/diffusion/compressor.py
+++ b/auto_round/compressors/diffusion/compressor.py
@@ -191,6 +191,7 @@ def _get_current_q_output(
         input_others: dict,
         indices: list[int],
         device: str,
+        cache_device: str = "cpu",
     ) -> torch.Tensor:
         output_config = output_configs.get(block.__class__.__name__, [])
         idx = None if "hidden_states" not in output_config else output_config.index("hidden_states")
@@ -207,7 +208,7 @@ def _get_current_q_output(
         current_input_others.update(current_input_ids)
         current_input_ids = hidden_states
         output_q = block_forward(block, current_input_ids, current_input_others, self.amp, self.amp_dtype, device, idx)
-        return output_q
+        return output_q.to(cache_device)
 
     @torch.no_grad()
     def _get_block_outputs(
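
For context, a minimal sketch (not part of the patch) of the compute-then-offload pattern the change follows: the block runs on the compute device and its output is immediately moved to cache_device (CPU by default), so cached block outputs do not accumulate in GPU memory. The helper name forward_and_cache and the toy Linear block below are illustrative assumptions, not auto_round APIs.

    import torch

    def forward_and_cache(block: torch.nn.Module, x: torch.Tensor,
                          device: str = "cuda", cache_device: str = "cpu") -> torch.Tensor:
        # Run the block on the compute device, then offload the result to the
        # cache device so it does not stay resident in GPU memory.
        with torch.no_grad():
            out = block(x.to(device))
        return out.to(cache_device)

    # Toy usage: a Linear layer stands in for a diffusion transformer block.
    dev = "cuda" if torch.cuda.is_available() else "cpu"
    blk = torch.nn.Linear(16, 16).to(dev)
    y = forward_and_cache(blk, torch.randn(4, 16), device=dev)
    print(y.device)  # cpu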