[LLGA JIT fuser] Unary aten::max would have two outputs (#2491)
* [Change 1/2] aten::max may be unary & may have two outputs

* [Change 2/2] Add UT

* Fix style

---------

Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
sanchitintel and chunyuan-w committed Jan 19, 2024
1 parent 7a7ba23 commit ac613a7
Showing 2 changed files with 19 additions and 4 deletions.
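
For context, a minimal sketch (not part of this commit) of the torch.max behavior the fix targets: calling torch.max with a dim argument still takes a single tensor input, so the corresponding aten::max node is unary, but it produces two outputs (values and indices).

import torch

x = torch.rand(8, 12, 12, 12)
global_max = torch.max(x)              # one output: a 0-dim tensor holding the global maximum
values, indices = torch.max(x, dim=1)  # two outputs: max values along dim 1 and their indices
print(global_max.shape, values.shape, indices.shape)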
8 changes: 4 additions & 4 deletions csrc/cpu/jit/codegen/onednn/utils.cpp
@@ -157,14 +157,14 @@ void convertInputTo0DTensor(

void modifyDtypeOfNode(torch::jit::Node* node, at::ScalarType dtype) {
  auto existingDtype =
-      node->output()->type()->expect<TensorType>()->scalarType();
+      node->outputs()[0]->type()->expect<TensorType>()->scalarType();
  if (existingDtype.has_value()) {
    switch (existingDtype.value()) {
      case at::ScalarType::Float:
      case at::ScalarType::BFloat16:
      case at::kInt:
-        node->output()->setType(
-            node->output()->type()->expect<TensorType>()->withScalarType(
+        node->outputs()[0]->setType(
+            node->outputs()[0]->type()->expect<TensorType>()->withScalarType(
                dtype));
        break;
      default:
@@ -189,7 +189,7 @@ void insertTypeCast(
}

void mayModifyOutputDtype(torch::jit::Node* node) {
-  if (node->output()->type()->isSubtypeOf(TensorType::get())) {
+  if (node->outputs()[0]->type()->isSubtypeOf(TensorType::get())) {
    if (node->hasAttributeS("was_float")) {
      modifyDtypeOfNode(node, at::ScalarType::Float);
      node->removeAttributeS("was_float");
15 changes: 15 additions & 0 deletions tests/cpu/test_jit_llga_fuser.py
@@ -405,6 +405,21 @@ def forward(self, x, y):
        graph, _ = self.checkTrace(m, [x, y])
        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 1)

+    @llga_fp32_bf16_test_env
+    def test_max_two_outputs(self):
+        class M(nn.Module):
+            def __init__(self):
+                super(M, self).__init__()
+
+            def forward(self, x):
+                # max is unary, and would have 2 outputs
+                return torch.max(x, dim=1)
+
+        m = M()
+        x = torch.rand(8, 12, 12, 12)
+        graph, _ = self.checkTrace(m, [x])
+        self.assertGraphContainsExactly(graph, LLGA_FUSION_GROUP, 0)
+
    @llga_fp32_bf16_test_env
    def test_bmm_div(self):
        class M(nn.Module):
