Closed
Description
First, some good news: we're pretty close to getting Torch->TOSA working for the static form of ResNet18. Now the bad news: the primary blocker is a strange failure to legalize the `aten.max_pool2d` op. It can be reproduced with the following simple e2e test:
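```python
import torch

# Test-harness imports; these module paths follow the torch-mlir e2e test
# suite layout of this era and may differ in other versions.
from torch_mlir_e2e_test.torchscript.framework import TestUtils
from torch_mlir_e2e_test.torchscript.registry import register_test_case
from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export


class MaxPool2dStaticModule(torch.nn.Module):

    def __init__(self):
        super().__init__()
        # 3x3 window, stride 2, padding 1, no dilation.
        self.mp2d = torch.nn.MaxPool2d(kernel_size=[3, 3],
                                       stride=[2, 2],
                                       padding=[1, 1],
                                       dilation=[1, 1])

    @export
    @annotate_args([
        None,
        # Fully static input: shape [1, 64, 112, 112], float32.
        ([1, 64, 112, 112], torch.float32, True),
    ])
    def forward(self, x):
        return self.mp2d(x)


@register_test_case(module_factory=lambda: MaxPool2dStaticModule())
def MaxPool2dStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(1, 64, 112, 112))
```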
This emits the following Torch-dialect IR:
```mlir
%int3 = torch.constant.int 3 loc(#loc1)
%int1 = torch.constant.int 1 loc(#loc2)
%int2 = torch.constant.int 2 loc(#loc3)
%false = torch.constant.bool false loc(#loc4)
%0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%1 = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%3 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<!torch.int> loc(#loc5)
%4 = torch.aten.max_pool2d %arg0, %0, %1, %2, %3, %false : !torch.vtensor<[1,64,112,112],f32>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.list<!torch.int>, !torch.bool -> !torch.vtensor<[1,64,?,?],f32> loc(#loc6)
```
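Note that the result is typed `!torch.vtensor<[1,64,?,?],f32>` even though the input shape and every pooling parameter are constants. As a cross-check, the output-size formula from the PyTorch `MaxPool2d` documentation (floor mode, since `ceil_mode` is `%false` here) gives the static shape these operands imply. This is a minimal sketch; `max_pool2d_out_shape` is a hypothetical helper, not part of torch-mlir.

```python
def max_pool2d_out_shape(in_shape, kernel, stride, padding, dilation):
    """Static NCHW output shape of max_pool2d with ceil_mode=False.

    Per spatial dim: floor((size + 2*p - d*(k - 1) - 1) / s) + 1,
    the formula documented for torch.nn.MaxPool2d.
    """
    n, c, h, w = in_shape
    dims = [
        (size + 2 * p - d * (k - 1) - 1) // s + 1
        for size, k, s, p, d in zip((h, w), kernel, stride, padding, dilation)
    ]
    return [n, c, *dims]


# Operands from the IR above: kernel %0, stride %1, padding %2, dilation %3.
print(max_pool2d_out_shape([1, 64, 112, 112], [3, 3], [2, 2], [1, 1], [1, 1]))
# [1, 64, 56, 56]
```

So a fully static result type of `[1,64,56,56]` would be expected for this test.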
This is strange, since I expected #502 to fix it, and looking at the code, I see no reason why it would not. The conversion appears to bail when converting `op.stride()`, but it's not clear why, since I can legalize maxpool to TOSA on its own and read `stride()` without trouble in that case.
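To sanity-check that the stride operand really is a plain constant list, here is a toy Python model of what a constant-int-list match on `op.stride()` has to do: follow the `torch.prim.ListConstruct` back to its `torch.constant.int` elements and give up if any element is not a known constant. This is an illustrative sketch, not torch-mlir code; `resolve_int_lists` is hypothetical. Operands are read in the `aten::max_pool2d` order (kernel_size, stride, padding, dilation, ceil_mode).

```python
# The constant and list-construct lines from the IR above (types and locations dropped).
IR = """
%int3 = torch.constant.int 3
%int1 = torch.constant.int 1
%int2 = torch.constant.int 2
%0 = torch.prim.ListConstruct %int3, %int3
%1 = torch.prim.ListConstruct %int2, %int2
%2 = torch.prim.ListConstruct %int1, %int1
%3 = torch.prim.ListConstruct %int1, %int1
"""


def resolve_int_lists(ir):
    """Map each SSA name to an int (torch.constant.int) or a list of ints
    (torch.prim.ListConstruct whose elements are all known int constants)."""
    values = {}
    for line in ir.strip().splitlines():
        name, rhs = (part.strip() for part in line.split("=", 1))
        op, _, operands = rhs.partition(" ")
        if op == "torch.constant.int":
            values[name] = int(operands)
        elif op == "torch.prim.ListConstruct":
            elems = [values.get(o.strip()) for o in operands.split(",")]
            # Like a constant-list matcher: fail if any element is not a known int.
            if all(isinstance(e, int) for e in elems):
                values[name] = elems
    return values


vals = resolve_int_lists(IR)
kernel, stride, padding, dilation = (vals[n] for n in ("%0", "%1", "%2", "%3"))
print(stride)  # [2, 2]
```

In this IR every list resolves, which is consistent with the expectation that reading the stride should succeed.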
Even more interestingly, the problem seems to show up only with maxpool and not with conv2d. I'm digging into the code further, but any insights would help.