Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@
import mindnlp
from diffusers import DiffusionPipeline

pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", ms_dtype=mindspore.float16)
pipeline = DiffusionPipeline.from_pretrained("stable-diffusion-v1-5/stable-diffusion-v1-5", ms_dtype=mindspore.float16, device_map='cuda')
pipeline("An image of a squirrel in Picasso style").images[0]
```

Expand Down
6 changes: 6 additions & 0 deletions mindtorch/_apis/cpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1240,3 +1240,9 @@ def scatter_nd_update(input, indices, updates):

def triu_indices(row, col, offset, dtype):
return legacy.triu_indices(row, col, offset, dtype)

def cumprod(input, dim, dtype):
    """Cumulative product of `input` along dimension `dim` (CPU backend).

    Delegates to the legacy CumProd kernel (the two False flags are the
    kernel's extra mode switches — presumably exclusive/reverse; confirm
    against the legacy op signature) and, when `dtype` is given, casts
    the result to that dtype before returning.
    """
    result = legacy.cum_prod(input, dim, False, False)
    return result if dtype is None else cast(result, dtype)
6 changes: 6 additions & 0 deletions mindtorch/_apis/gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1247,3 +1247,9 @@ def fft(input, n=None, dim=-1, norm="backward"):

def triu_indices(row, col, offset, dtype):
return legacy.triu_indices(row, col, offset, dtype)

def cumprod(input, dim, dtype):
    """Cumulative product of `input` along dimension `dim` (GPU backend).

    Runs the legacy CumProd kernel (both extra kernel flags disabled —
    presumably exclusive/reverse; confirm against the legacy op) and
    optionally casts the output to `dtype`.
    """
    result = legacy.cum_prod(input, dim, False, False)
    if dtype is None:
        return result
    return cast(result, dtype)
6 changes: 6 additions & 0 deletions mindtorch/_apis/npu.py
Original file line number Diff line number Diff line change
Expand Up @@ -1656,3 +1656,9 @@ def repeat_interleave_tensor(input, repeats, dim, output_size):

def triu_indices(row, col, offset, dtype):
return legacy.triu_indices(row, col, offset, dtype)

def cumprod(input, dim, dtype):
    """Cumulative product of `input` along dimension `dim` (NPU backend).

    Wraps the legacy CumProd kernel (the two False arguments are the
    kernel's extra mode flags — presumably exclusive/reverse; verify
    against the legacy signature). A trailing cast is applied only when
    a target `dtype` is requested.
    """
    result = legacy.cum_prod(input, dim, False, False)
    if dtype is not None:
        result = cast(result, dtype)
    return result
3 changes: 3 additions & 0 deletions mindtorch/ops/other.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ def clone(input, *, memory_format=mindtorch.preserve_format):
# cummin

# cumprod
def cumprod(input, dim, *, dtype=None, out=None):
    """Return the cumulative product of the elements of `input` along `dim`.

    Mirrors ``torch.cumprod``: `dtype`, when given, is the desired dtype
    of the result; `out`, when given, receives the result in place.

    Fix: the previous version accepted ``out`` but silently ignored it,
    so callers passing an output tensor never saw the result written.
    """
    output = execute('cumprod', input, dim, dtype)
    if out is not None:
        # NOTE(review): assumes mindtorch tensors expose torch-style
        # in-place copy_ — confirm against the Tensor implementation.
        out.copy_(output)
        return out
    return output

# cumsum
def cumsum(input, dim=None, dtype=None, **kwargs):
Expand Down Expand Up @@ -1131,6 +1133,7 @@ def cosine_similarity(x1, x2, dim=1, eps=1e-8):
"clone",
"contains",
"cross",
"cumprod",
"cumsum",
"diag",
"diagonal",
Expand Down
19 changes: 16 additions & 3 deletions mindtorch/ops/pointwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,8 +200,15 @@ def div(input, other, *, rounding_mode=None):
rounding_mode
)
else:
if not isinstance(other, numbers.Number) and not isinstance(input, numbers.Number) and other.device != input.device:
other = other.to(input.device)
if not isinstance(other, numbers.Number) and not isinstance(input, numbers.Number):
if other.device != input.device:
device = max([input.device, other.device])
other = other.to(device)
input = input.to(device)
if other.dtype != input.dtype:
dtype = min([input.dtype, other.dtype])
other = other.to(dtype)
input = input.to(dtype)
output = execute("div", input, other)
return output

Expand Down Expand Up @@ -380,7 +387,13 @@ def logical_xor(input, other):
# mul
def mul(input, other):
    """Element-wise multiply of `input` and `other`.

    When both operands are tensors, aligns them before dispatch:
    - device: both moved to the "larger" of the two devices
      (max of the device objects — ordering defined by the project's
      device type; confirm it ranks accelerator above CPU),
    - dtype: both cast to the "smaller" of the two dtypes
      (min of the dtype objects — ordering defined by the project).

    Fix: the device/dtype alignment previously guarded only `other`
    against being a plain number, so `input.device` / `input.dtype`
    would raise AttributeError when `input` was a Python scalar — the
    matching `div` change added the symmetric guard; applied here too.
    """
    if not isinstance(other, numbers.Number) and not isinstance(input, numbers.Number):
        if other.device != input.device:
            device = max([input.device, other.device])
            other = other.to(device)
            input = input.to(device)
        if other.dtype != input.dtype:
            dtype = min([input.dtype, other.dtype])
            other = other.to(dtype)
            input = input.to(dtype)
    return execute("mul", input, other)
Expand Down
Loading