Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update mul(self, other): in _tt_base.py to allow multiplication with complex scalars and fix multiplication with torch.Tensor scalars #23

Merged
merged 1 commit into from
Jul 9, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions torchtt/_tt_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,16 +689,16 @@ def __mul__(self, other):
Following are supported:
- TT tensor and TT tensor
- TT matrix and TT matrix
- TT tensor and scalar(int, float or torch.tensor scalar)
- TT tensor and scalar(int, float, complex or torch.Tensor scalar)
The broadcasting rules are the same as in torch (see [here](https://pytorch.org/docs/stable/notes/broadcasting.html)).

Args:
other (torchtt.TT | float | int | torch.tensor): the second operand. If a `torch.tensor` is provided, it must have 1 element.
other (torchtt.TT | float | int | complex | torch.Tensor): the second operand. If a `torch.Tensor` is provided, it must have 1 element.

Raises:
ShapeMismatch: Shapes are incompatible (see the broadcasting rules).
IncompatibleTypes: Second operand must be the same type as the fisrt (both should be either TT matrices or TT tensors).
InvalidArguments: Second operand must be of type: torchtt.TT, float, int of torch.tensor.
InvalidArguments: Second operand must be of type: torchtt.TT, float, int, complex or torch.Tensor.

Returns:
torchtt.TT: the result.
Expand Down Expand Up @@ -762,7 +762,7 @@ def __mul__(self, other):
'Second operand must be the same type as the fisrt (both should be either TT matrices or TT tensors).')
result = TT(cores_new)

elif isinstance(other, int) or isinstance(other, float) or isinstance(other, tn.tensor):
elif isinstance(other, int) or isinstance(other, float) or isinstance(other, complex) or isinstance(other, tn.Tensor):
if other != 0:
cores_new = [c+0 for c in self.cores]
cores_new[0] *= other
Expand All @@ -773,7 +773,7 @@ def __mul__(self, other):
# result = zeros([(m,n) for m,n in zip(self.M,self.N)] if self.is_ttm else self.N, device=self.cores[0].device)
else:
raise InvalidArguments(
'Second operand must be of type: TT, float, int of tensorflow Tensor.')
'Second operand must be of type: TT, float, int, complex or tensorflow Tensor.')

return result

Expand Down