Skip to content

Commit

Permalink
Fixed minor issues for bmm/mm decompositon (pytorch#109836)
Browse files Browse the repository at this point in the history
Summary:

* Fixed minor issues for bmm/mm decompositon
* enabled addmm for inductor

Test Plan: ci

Reviewed By: mikekgfb

Differential Revision: D49522332
  • Loading branch information
chenyang78 authored and facebook-github-bot committed Sep 22, 2023
1 parent cd99cdc commit 7111038
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions torch/_inductor/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ def all_dim(input, dim, keepdim=False):

@register_decomposition([aten.bmm])
def bmm(self, batch2):
if self.device == "cpu":
if self.device.type == "cpu":
if self.size(1) == 1 and batch2.size(-1) == 1:
return torch.sum(
self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True
Expand All @@ -209,7 +209,7 @@ def mm(self, input2):
if config.coordinate_descent_tuning:
if self.shape[0] == 1 or input2.shape[1] == 1:
return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1)
if self.device == "cpu":
if self.device.type == "cpu":
if (
self.size(-1) == 1
and input2.size(0) == 1
Expand Down

0 comments on commit 7111038

Please sign in to comment.