
aten.maxpool2d shape resolution failure #545

@sjarus

First, some good news: we're pretty close to getting Torch->TOSA working for the static form of ResNet18. Now the bad news: the primary blocker is a strange failure to resolve the result shape of the aten.max_pool2d op. It can be reproduced with the following simple e2e test:

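import torch

from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


class MaxPool2dStaticModule(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.mp2d = torch.nn.MaxPool2d(kernel_size=[3, 3],
                                       stride=[2, 2],
                                       padding=[1, 1],
                                       dilation=[1, 1])

    @export
    @annotate_args([
        None,
        # Fully static input: shape, dtype, has value semantics.
        ([1, 64, 112, 112], torch.float32, True),
    ])
    def forward(self, x):
        return self.mp2d(x)


@register_test_case(module_factory=lambda: MaxPool2dStaticModule())
def MaxPool2dStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 64, 112, 112))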

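This emits the following IR. Note that the result type is !torch.vtensor<[1,64,?,?],f32>, i.e. the spatial dimensions are not resolved even though every input to the op is static: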

%int3 = torch.constant.int 3 loc(#loc1)
%int1 = torch.constant.int 1 loc(#loc2)
%int2 = torch.constant.int 2 loc(#loc3)
%false = torch.constant.bool false loc(#loc4)
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112,112],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor<[1,64,?,?],f32> loc(#loc6)
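For reference, the spatial dims marked ? above should be computable from the constant operands alone. A minimal sketch in plain Python, assuming PyTorch's documented pooling output-size rule (the helper names here are hypothetical, not torch-mlir APIs), gives the expected static result:

```python
def pool_output_size(size, kernel, stride, padding, dilation, ceil_mode=False):
    """Output length of one spatial dim, following PyTorch's pooling shape rule."""
    numerator = size + 2 * padding - dilation * (kernel - 1) - 1
    if ceil_mode:
        numerator += stride - 1  # round the division up instead of down
    out = numerator // stride + 1  # // floors, matching round-toward-negative-infinity
    # In ceil mode, drop a last window that would start inside the right padding.
    if ceil_mode and (out - 1) * stride >= size + padding:
        out -= 1
    return out


def max_pool2d_result_shape(input_shape, kernel_size, stride, padding, dilation,
                            ceil_mode=False):
    """NCHW result shape of max_pool2d; batch and channel dims pass through."""
    n, c, h, w = input_shape
    spatial = [
        pool_output_size(s, k, st, p, d, ceil_mode)
        for s, k, st, p, d in zip((h, w), kernel_size, stride, padding, dilation)
    ]
    return (n, c, *spatial)


# Operands of the torch.aten.max_pool2d op in the IR above (ceil_mode = %false).
print(max_pool2d_result_shape((1, 64, 112, 112), [3, 3], [2, 2], [1, 1], [1, 1]))
# (1, 64, 56, 56)
```

So the refined type should be !torch.vtensor<[1,64,56,56],f32>: (112 + 2*1 - 1*(3 - 1) - 1) // 2 + 1 = 56.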

This is strange, since I expected #502 to fix it, and looking at the code, there's no reason why it shouldn't. It appears to bail out when converting op.stride(), but it's not clear why, since I'm able to independently legalize maxpool to TOSA and read stride() during that conversion.
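To illustrate why a bail-out on a single list operand is enough to produce [1,64,?,?], here is a simplified Python model of a constant-list-driven shape refinement. This is an assumption-laden sketch, not the actual torch-mlir C++ code from #502: match_constant_int_list and refine_max_pool2d_shape are hypothetical names, None stands for a list element that could not be proven to be a constant int, and only ceil_mode=False is modeled.

```python
from typing import List, Optional, Sequence

UNKNOWN = "?"  # stands in for a dynamic dim, as in !torch.vtensor<[1,64,?,?],f32>


def match_constant_int_list(elements: Sequence[Optional[int]]) -> Optional[List[int]]:
    """Hypothetical helper: None models an element not proven to be a constant int."""
    if any(e is None for e in elements):
        return None
    return list(elements)


def refine_max_pool2d_shape(input_shape, kernel, stride, padding, dilation):
    """Simplified model (ceil_mode=False only) of a refinement that can bail out."""
    n, c, h, w = input_shape
    lists = [match_constant_int_list(x) for x in (kernel, stride, padding, dilation)]
    if any(lst is None for lst in lists):
        # Bail out: batch and channel are kept, spatial dims stay unknown.
        return (n, c, UNKNOWN, UNKNOWN)
    k, s, p, d = lists
    spatial = tuple(
        (size + 2 * pp - dd * (kk - 1) - 1) // ss + 1
        for size, kk, ss, pp, dd in zip((h, w), k, s, p, d)
    )
    return (n, c, *spatial)


print(refine_max_pool2d_shape((1, 64, 112, 112), [3, 3], [2, 2], [1, 1], [1, 1]))
# (1, 64, 56, 56)
print(refine_max_pool2d_shape((1, 64, 112, 112), [3, 3], [None, None], [1, 1], [1, 1]))
# (1, 64, '?', '?')
```

In this model, failing to match only the stride list reproduces exactly the partially refined type seen in the IR, which is consistent with the observation that the refinement gives up at op.stride().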

Even more interestingly, the problem seems to manifest only in maxpool, not in conv2d. I'm digging into the code further, but any insights would help.

@ramiro050 @cathyzhyi @ljfitz
