From 7111038731e7155b41392dbcaa75c9e0712a6d94 Mon Sep 17 00:00:00 2001 From: Yang Chen Date: Thu, 21 Sep 2023 23:16:06 -0700 Subject: [PATCH] Fixed minor issues for bmm/mm decompositon (#109836) Summary: * Fixed minor issues for bmm/mm decompositon * enabled addmm for inductor Test Plan: ci Reviewed By: mikekgfb Differential Revision: D49522332 --- torch/_inductor/decomposition.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py index da407f65ebb75..391e823e3fe6e 100644 --- a/torch/_inductor/decomposition.py +++ b/torch/_inductor/decomposition.py @@ -194,7 +194,7 @@ def all_dim(input, dim, keepdim=False): @register_decomposition([aten.bmm]) def bmm(self, batch2): - if self.device == "cpu": + if self.device.type == "cpu": if self.size(1) == 1 and batch2.size(-1) == 1: return torch.sum( self.squeeze(1) * batch2.squeeze(-1), dim=1, keepdim=True @@ -209,7 +209,7 @@ def mm(self, input2): if config.coordinate_descent_tuning: if self.shape[0] == 1 or input2.shape[1] == 1: return (self.unsqueeze(2) * input2.unsqueeze(0)).sum(dim=1) - if self.device == "cpu": + if self.device.type == "cpu": if ( self.size(-1) == 1 and input2.size(0) == 1