Skip to content
Merged
2 changes: 2 additions & 0 deletions src/transformers/utils/import_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -638,6 +638,8 @@ def is_torch_bf16_gpu_available() -> bool:
if is_torch_mps_available():
# Note: Emulated in software by Metal using fp32 for hardware without native support (like M1/M2)
return torch.backends.mps.is_macos_or_newer(14, 0)
if is_torch_musa_available():
return torch.musa.is_bf16_supported()
return False


Expand Down