diff --git a/src/diffusers/loaders/lora_pipeline.py b/src/diffusers/loaders/lora_pipeline.py index a38b38774932..293b461ceabc 100644 --- a/src/diffusers/loaders/lora_pipeline.py +++ b/src/diffusers/loaders/lora_pipeline.py @@ -1978,6 +1978,16 @@ def _maybe_expand_transformer_param_shape_or_error_( "The checkpoint seems to have been quantized with `bitsandbytes` (4bits). Install `bitsandbytes` to load quantized checkpoints." ) elif is_bnb_4bit_quantized: + weight_on_cpu = False + if not module.weight.is_cuda: + weight_on_cpu = True + module_weight = dequantize_bnb_weight( + module.weight.cuda() if weight_on_cpu else module.weight, + state=module.weight.quant_state, + dtype=transformer.dtype, + ).data + if weight_on_cpu: + module_weight = module_weight.cpu() module_weight = dequantize_bnb_weight(module.weight, state=module.weight.quant_state).data else: module_weight = module.weight.data diff --git a/tests/quantization/bnb/test_4bit.py b/tests/quantization/bnb/test_4bit.py index 9b1f78acb795..7da118ee9f2c 100644 --- a/tests/quantization/bnb/test_4bit.py +++ b/tests/quantization/bnb/test_4bit.py @@ -21,8 +21,15 @@ import pytest import safetensors.torch from huggingface_hub import hf_hub_download - -from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel +from PIL import Image + +from diffusers import ( + BitsAndBytesConfig, + DiffusionPipeline, + FluxControlPipeline, + FluxTransformer2DModel, + SD3Transformer2DModel, +) from diffusers.utils import is_accelerate_version, logging from diffusers.utils.testing_utils import ( CaptureLogger, @@ -702,10 +709,7 @@ def setUp(self) -> None: gc.collect() torch.cuda.empty_cache() - self.pipeline_4bit = DiffusionPipeline.from_pretrained( - "eramth/flux-4bit", - torch_dtype=torch.float16, - ) + self.pipeline_4bit = FluxControlPipeline.from_pretrained("eramth/flux-4bit", torch_dtype=torch.float16) self.pipeline_4bit.enable_model_cpu_offload() def tearDown(self): @@ -719,6 +723,7 @@ def test_lora_loading(self): output = self.pipeline_4bit( prompt=self.prompt, + control_image=Image.new(mode="RGB", size=(256, 256)), height=256, width=256, max_sequence_length=64, @@ -727,8 +732,7 @@ def test_lora_loading(self): generator=torch.Generator().manual_seed(42), ).images out_slice = output[0, -3:, -3:, -1].flatten() - # TODO: update slice - expected_slice = np.array([0.5347, 0.5342, 0.5283, 0.5093, 0.4988, 0.5093, 0.5044, 0.5015, 0.4946]) + expected_slice = np.array([0.1636, 0.1675, 0.1982, 0.1743, 0.1809, 0.1936, 0.1743, 0.2095, 0.2139]) max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice) self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")