Skip to content

Commit

Permalink
add torch tensordot
Browse files Browse the repository at this point in the history
  • Loading branch information
jcmgray committed Mar 27, 2021
1 parent 8d331bb commit b62baa9
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions autoray/autoray.py
Original file line number Diff line number Diff line change
Expand Up @@ -989,6 +989,13 @@ def torch_linalg_eigvalsh(x):
return do("symeig", x, eigenvectors=False, like="torch")[0]


def torch_tensordot_wrap(fn):
@functools.wraps(fn)
def numpy_like(a, b, axes=2):
return fn(a, b, dims=axes)
return numpy_like


def torch_pad(array, pad_width, mode="constant", constant_values=0):
if mode != "constant":
raise NotImplementedError
Expand Down Expand Up @@ -1095,6 +1102,7 @@ def numpy_like(N, M=None, dtype=None, **kwargs):
_CUSTOM_WRAPPERS["torch", "random.normal"] = scale_random_normal_manually
_CUSTOM_WRAPPERS["torch", "random.uniform"] = scale_random_uniform_manually
_CUSTOM_WRAPPERS["torch", "split"] = torch_split_wrap
_CUSTOM_WRAPPERS["torch", "tensordot"] = torch_tensordot_wrap
_CUSTOM_WRAPPERS["torch", "stack"] = make_translator(
[
("arrays", ("tensors",)),
Expand Down

0 comments on commit b62baa9

Please sign in to comment.