Skip to content

Commit

Permalink
Fix the TypeError for XPU Accelerator (#5531)
Browse files Browse the repository at this point in the history
Fixes the following error:
/datadisk2/wengshiy/llm.devkit/DeepSpeed/deepspeed/runtime/utils.py
    return get_accelerator().FloatTensor(float(v)).detach()
TypeError: new(): data must be a sequence (got float)

The CUDA accelerator previously modified this same interface to fix a warning; see commit:
177dc14

---------

Co-authored-by: Olatunji Ruwase <olruwase@microsoft.com>
  • Loading branch information
shiyang-weng and tjruwase committed May 20, 2024
1 parent 69af361 commit 1d81967
Showing 1 changed file with 8 additions and 7 deletions.
15 changes: 8 additions & 7 deletions accelerator/xpu_accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from deepspeed.accelerator.abstract_accelerator import DeepSpeedAccelerator
import intel_extension_for_pytorch as ipex # noqa: F401 # type: ignore
import oneccl_bindings_for_pytorch # noqa: F401 # type: ignore
import functools


class XPU_Accelerator(DeepSpeedAccelerator):
Expand Down Expand Up @@ -191,31 +192,31 @@ def supported_dtypes(self):

@property
def BFloat16Tensor(self):
    """Callable that constructs a bfloat16 tensor on this accelerator's device.

    Uses functools.partial(torch.tensor, ...) rather than the legacy
    torch.xpu.BFloat16Tensor constructor: torch.tensor accepts Python
    scalars as well as sequences, whereas the legacy constructor raised
    ``TypeError: new(): data must be a sequence (got float)``.
    """
    return functools.partial(torch.tensor, dtype=torch.bfloat16, device=self._name)

@property
def ByteTensor(self):
    """Callable that constructs a uint8 tensor on this accelerator's device.

    functools.partial(torch.tensor, ...) accepts scalar inputs, unlike the
    legacy torch.xpu.ByteTensor constructor which required a sequence.
    """
    return functools.partial(torch.tensor, dtype=torch.uint8, device=self._name)

@property
def DoubleTensor(self):
    """Callable that constructs a float64 tensor on this accelerator's device.

    functools.partial(torch.tensor, ...) accepts scalar inputs, unlike the
    legacy torch.xpu.DoubleTensor constructor which required a sequence.
    """
    return functools.partial(torch.tensor, dtype=torch.double, device=self._name)

@property
def FloatTensor(self):
    """Callable that constructs a float32 tensor on this accelerator's device.

    This is the method whose legacy form triggered
    ``TypeError: new(): data must be a sequence (got float)`` in
    deepspeed/runtime/utils.py; torch.tensor accepts Python scalars
    as well as sequences, so the partial is a drop-in, more general
    replacement for torch.xpu.FloatTensor.
    """
    return functools.partial(torch.tensor, dtype=torch.float, device=self._name)

@property
def HalfTensor(self):
    """Callable that constructs a float16 tensor on this accelerator's device.

    functools.partial(torch.tensor, ...) accepts scalar inputs, unlike the
    legacy torch.xpu.HalfTensor constructor which required a sequence.
    """
    return functools.partial(torch.tensor, dtype=torch.half, device=self._name)

@property
def IntTensor(self):
    """Callable that constructs an int32 tensor on this accelerator's device.

    functools.partial(torch.tensor, ...) accepts scalar inputs, unlike the
    legacy torch.xpu.IntTensor constructor which required a sequence.
    """
    return functools.partial(torch.tensor, dtype=torch.int, device=self._name)

@property
def LongTensor(self):
    """Callable that constructs an int64 tensor on this accelerator's device.

    functools.partial(torch.tensor, ...) accepts scalar inputs, unlike the
    legacy torch.xpu.LongTensor constructor which required a sequence.
    """
    return functools.partial(torch.tensor, dtype=torch.long, device=self._name)

def pin_memory(self, tensor, align_bytes=1):
if align_bytes == 1:
Expand Down

0 comments on commit 1d81967

Please sign in to comment.