Skip to content

Commit

Permalink
Params4bit added to bnb classes in set_module_tensor_to_device() (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
poedator committed Jan 10, 2024
1 parent 0d2280d commit 456afd9
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion src/accelerate/utils/modeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def set_module_tensor_to_device(
param is not None
and param.device.type != "cuda"
and torch.device(device).type == "cuda"
and param_cls.__name__ in ["Int8Params", "FP4Params"]
and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
):
device_quantization = device
device = "cpu"
Expand Down

0 comments on commit 456afd9

Please sign in to comment.