Merged
5 changes: 3 additions & 2 deletions mindtorch/_apis/cpu.py
@@ -1089,7 +1089,7 @@ def leaky_relu(input, negative_slope):
     select_op = maximum
     if negative_slope > 1:
         select_op = minimum
-    return select_op(mul(negative_slope, input), input)
+    return select_op(mul(input, negative_slope), input)
 
 def ceil(input):
     return legacy.ceil(input)
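Swapping the operands here (and in the identical gpu.py hunk below) puts the tensor first, which matters if the underlying mul dispatches on its first argument; the selection logic itself is unchanged. A minimal NumPy sketch of the intended semantics, for reference only (leaky_relu_ref is a hypothetical name, not part of mindtorch):

    import numpy as np

    def leaky_relu_ref(x, negative_slope):
        # leaky_relu(x) = x for x > 0, negative_slope * x otherwise.
        # For slope <= 1, max(slope*x, x) picks the right branch;
        # for slope > 1 the inequality flips, hence min.
        scaled = x * negative_slope
        return np.minimum(scaled, x) if negative_slope > 1 else np.maximum(scaled, x)

    print(leaky_relu_ref(np.array([-2.0, 3.0]), 0.01))  # [-0.02  3.  ]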
@@ -1220,7 +1220,8 @@ def logsumexp(input, dim, keepdim=False):
     return add(input_logsumexp, input_max)
 
 def bernoulli(input, generator):
-    return legacy.bernoulli(input, seed, offset)
+    seed, offset = generator._step(12)  # pylint: disable=protected-access
+    return legacy.bernoulli(input, 0.5, seed.item(), offset.item())
 
 def right_shift(input, other):
     return legacy.right_shift(input, other)
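The old body referenced seed and offset that were never defined, so any call raised NameError. The fix draws them from the generator's state; the extra 0.5 argument appears to be the sampling probability the legacy kernel expects. A hedged usage sketch, assuming mindtorch exposes a torch-style Generator with manual_seed:

    g = mindtorch.Generator()
    g.manual_seed(42)
    mask = bernoulli(mindtorch.ones(4), g)  # elementwise Bernoulli draw, reproducible via g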
13 changes: 8 additions & 5 deletions mindtorch/_apis/gpu.py
@@ -746,9 +746,9 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi
     pad_mode = 'pad'
     pad = padding
     if isinstance(padding, tuple):
-        pad = (0, 0, padding[0], padding[0])
+        pad = (padding[0], padding[0], padding[1], padding[1])
     elif isinstance(padding, int):
-        pad = (0, 0) + (padding,) * 2
+        pad = (padding,) * 4
     if not isinstance(padding, (int, tuple)):
         pad_mode = padding
         pad = (0,) * 4
@@ -758,7 +758,6 @@ def conv_transpose2d(input, weight, bias=None, stride=1, padding=0, output_paddi

     in_channel, out_channels = weight.shape[0], weight.shape[1] * groups
     kernel_size = weight.shape[2:]
-
     n, _, h, w = input.shape
     h_add = _deconv_output_length(pad_mode, kernel_size[0], stride[0], dilation[0], pad[0] + pad[1])
     w_add = _deconv_output_length(pad_mode, kernel_size[1], stride[1], dilation[1], pad[2] + pad[3])
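Previously a tuple or int padding only populated the last two slots, and a 2-tuple's second element was dropped entirely, so width padding never reached the output-length math. The new 4-tuple is symmetric per axis; a quick check of the ordering the code assumes (pad[0:2] for height, pad[2:4] for width):

    padding = (1, 2)
    pad = (padding[0], padding[0], padding[1], padding[1])  # (1, 1, 2, 2)
    # h_add consumes pad[0] + pad[1] == 2; w_add consumes pad[2] + pad[3] == 4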
@@ -1004,7 +1003,7 @@ def leaky_relu(input, negative_slope):
     select_op = maximum
     if negative_slope > 1:
         select_op = minimum
-    return select_op(mul(negative_slope, input), input)
+    return select_op(mul(input, negative_slope), input)
 
 def ceil(input):
     return legacy.ceil(input)
@@ -1146,4 +1145,8 @@ def search_sorted(sorted_sequence, values, sorter, dtype, right):
     return legacy.search_sorted(sorted_sequence, values, sorter, dtype, right)
 
 def einsum(equation, operands):
-    return legacy.einsum(operands, equation)
\ No newline at end of file
+    return legacy.einsum(operands, equation)
+
+def unique2(input, sorted, return_inverse, return_counts):
+    outs = legacy.unique(input)
+    return outs + (None,)
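The einsum pair is byte-identical; the rewrite just adds the missing trailing newline. For unique2, legacy.unique evidently returns a tuple without counts (presumably values plus inverse indices), so the wrapper pads with None to satisfy the three-output unique2 contract; the sorted/return_inverse/return_counts flags are accepted but not forwarded. A hedged calling sketch with positional flags as in the signature:

    values, inverse, counts = unique2(x, True, True, False)
    # counts is None on this backend; values/inverse come from legacy.unique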
3 changes: 3 additions & 0 deletions mindtorch/nn/functional.py
@@ -47,6 +47,9 @@ def elu(input, alpha=1.0):
     return execute('elu', input, alpha)
 
 def glu(input, dim=-1):
+    if input.device.type == 'cuda':
+        x, y = input.chunk(2, dim)
+        return x * sigmoid(y)
     return execute('glu', input, dim)
 
 def softplus(input, beta=1, threshold=20):
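GLU halves the input along dim and gates one half with the sigmoid of the other: glu(x) = a * sigmoid(b). On CUDA devices the fused 'glu' kernel is evidently unavailable, so the branch above recomposes it from chunk and sigmoid. A NumPy reference sketch (glu_ref is illustrative, not mindtorch API):

    import numpy as np

    def glu_ref(x, dim=-1):
        a, b = np.split(x, 2, axis=dim)          # two equal halves along dim
        return a * (1.0 / (1.0 + np.exp(-b)))    # a * sigmoid(b)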
2 changes: 1 addition & 1 deletion mindtorch/ops/array.py
@@ -908,7 +908,7 @@ def range_(start, length):
     stacked_indices = indices[0]
     stacked_indices = stacked_indices.to(mindtorch.int32)
     stacked_indices = where(
-        stacked_indices < 0, stacked_indices + mindtorch.tensor(dim_sizes, device=stacked_indices.device), stacked_indices
+        stacked_indices < 0, stacked_indices + mindtorch.tensor(dim_sizes, dtype=stacked_indices.dtype, device=stacked_indices.device), stacked_indices
     )
     axis = dims[0]
     if len(dims) > 1:
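Without an explicit dtype, the freshly built dim_sizes tensor would take a default integer type and could promote, or outright reject, the add against the int32 indices; pinning it to stacked_indices.dtype keeps the negative-index wrap-around in int32. NumPy shows the promotion hazard:

    import numpy as np

    idx = np.array([-1, 2], dtype=np.int32)
    sizes = np.array([4, 4], dtype=np.int64)        # what an unpinned default could give
    wrapped = np.where(idx < 0, idx + sizes, idx)   # wrap negatives: [3, 2]
    print(wrapped.dtype)                            # int64: the int32 cast above is silently undone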
3 changes: 3 additions & 0 deletions mindtorch/ops/other.py
@@ -27,6 +27,9 @@ def broadcast_tensors(*tensors):

 # broadcast_to
 def broadcast_to(input, shape):
+    if input.shape == shape:
+        return input
+
     new_shape = ()
     for s in shape:
         if not isinstance(s, int):
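The early return is a fast path: when the target shape already matches, the function hands back the same tensor object and skips both the shape normalization below and the backend broadcast. A small usage sketch, assuming shapes compare equal against plain tuples as in torch:

    x = mindtorch.ones(2, 3)
    assert broadcast_to(x, (2, 3)) is x   # no copy, no dispatch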