[DirectML] Add float64 power CPU fallback. (#436)
lshqqytiger committed Apr 17, 2024
1 parent 1743100 commit e51a849
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions modules/dml/hijack/torch.py
@@ -40,3 +40,11 @@ def lerp(*args, **kwargs) -> torch.Tensor:
         return _lerp(*args, **kwargs).to(rep.device).type(rep.dtype)
     return _lerp(*args, **kwargs)
 torch.lerp = lerp
+
+# https://github.com/lshqqytiger/stable-diffusion-webui-directml/issues/436
+_pow_ = torch.Tensor.pow_
+def pow_(self: torch.Tensor, *args, **kwargs):
+    if self.dtype == torch.float64:
+        return _pow_(self.cpu(), *args, **kwargs)
+    return _pow_(self, *args, **kwargs)
+torch.Tensor.pow_ = pow_
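
For context, here is a minimal, self-contained sketch of the same monkey-patch running on stock PyTorch. No DirectML device is needed to exercise the wrapper; the float64 branch is what sidesteps the missing float64 pow kernel reported in issue #436. The patch itself mirrors the committed code; the demonstration tensor at the end is added purely for illustration.

import torch

# Keep a handle to the original in-place power op before patching.
_pow_ = torch.Tensor.pow_

def pow_(self: torch.Tensor, *args, **kwargs):
    # DirectML lacks a float64 pow kernel, so route float64 tensors
    # through the CPU implementation; all other dtypes pass through.
    if self.dtype == torch.float64:
        return _pow_(self.cpu(), *args, **kwargs)
    return _pow_(self, *args, **kwargs)

# Install the wrapper so every Tensor.pow_ call goes through it.
torch.Tensor.pow_ = pow_

x = torch.tensor([2.0, 3.0], dtype=torch.float64)
print(x.pow_(2))  # tensor([4., 9.], dtype=torch.float64)

One behavioral detail worth noting: when the tensor actually lives on a DirectML device, self.cpu() makes a copy, so the fallback returns a new CPU tensor rather than mutating the original in place, and (unlike the lerp hijack above) the result is not moved back to the source device. Callers that use the return value still get correct results.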
