nn.Conv2d output mismatch with torch.nn.Conv2d #66
Could you directly modify this unittest to see if there is anything abnormal? https://github.com/facebookincubator/AITemplate/blob/main/tests/unittest/ops/test_conv.py |
When I change that UT like this, it passes:

# Imports as in tests/unittest/ops/test_conv.py
import unittest
import torch
from aitemplate.compiler import compile_model, ops
from aitemplate.frontend import Tensor
from aitemplate.testing import detect_target

class ConvTestCase(unittest.TestCase):
    def test_fp16(self, batch=1):
        target = detect_target()
        X = Tensor(
            shape=[1, 384, 384, 4],
            dtype="float16",
            name="input_0",
            is_input=True,
        )
        W = Tensor(
            shape=[256, 7, 7, 4], dtype="float16", name="input_1", is_input=True
        )
        OP = ops.conv2d(stride=4, pad=3, dilate=1)
        Y = OP(X, W)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./tmp", "conv2d")

        X_pt = torch.randn(1, 4, 384, 384).cuda().half()
        W_pt = torch.randn(256, 4, 7, 7).cuda().half()
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=3, stride=4)
        x = X_pt.permute((0, 2, 3, 1)).contiguous()
        w = W_pt.permute((0, 2, 3, 1)).contiguous()
        y = torch.empty([1, 96, 96, 256]).cuda().half()
        module.run_with_tensors({"input_0": x, "input_1": w}, [y])
        y_transpose = y.permute((0, 3, 1, 2))
        if target.name() == "cuda":
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1))

I'll have to look further to see what the difference is between the module and OP versions. |
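One thing worth double-checking in tests like this is the layout bookkeeping: the AITemplate tensors above are NHWC (shape [1, 384, 384, 4]) while torch.nn.functional.conv2d is NCHW, hence the permutes around run_with_tensors. A small numpy sketch (shapes shrunk, variable names mine) of that round trip:

```python
import numpy as np

# NCHW input, as PyTorch produces it (N, C, H, W)
x_nchw = np.random.rand(1, 4, 8, 8).astype(np.float32)

# NCHW -> NHWC: move channels last (matches X_pt.permute((0, 2, 3, 1)))
x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))
assert x_nhwc.shape == (1, 8, 8, 4)

# The inverse permutation (matches y.permute((0, 3, 1, 2))) restores
# the original layout exactly, so allclose comparisons are apples-to-apples.
roundtrip = np.transpose(x_nhwc, (0, 3, 1, 2))
assert np.array_equal(roundtrip, x_nchw)
```

Getting either permutation wrong silently scrambles channels rather than raising an error, which shows up exactly as an "output mismatch".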
This visualization tool is very helpful for investigating: https://facebookincubator.github.io/AITemplate/tutorial/how_to_visualize.html |
Yes, check the attributes, especially op_type. I suspect the frontend conv2d doesn't map to the correct ops.
On Tue, Nov 1, 2022 at 17:31, Ehsan Azar wrote:
> The visualization is pretty simple
> [image: <https://user-images.githubusercontent.com/873905/199366695-424afe4e-8a55-4929-bda9-01000ae35155.png>]
--
Bing Xu
|
The only attribute is op_type. For this ConvBiasTestCase:

class ConvBiasTestCase(unittest.TestCase):
    def test_fp16(self, batch=4):
        target = detect_target()
        X = Tensor(
            shape=[1, 384, 384, 4],
            dtype="float16",
            name="input_0",
            is_input=True,
        )
        W = Tensor(
            shape=[256, 7, 7, 4], dtype="float16", name="input_1", is_input=True
        )
        B = Tensor(shape=[256], dtype="float16", name="input_2", is_input=True)
        OP = ops.conv2d_bias(stride=4, pad=3, dilate=1)
        Y = OP(X, W, B)
        Y._attrs["name"] = "output_0"
        Y._attrs["is_output"] = True
        module = compile_model(Y, target, "./tmp", "conv2d_bias")

        X_pt = torch.randn(1, 4, 384, 384).cuda().half()
        W_pt = torch.randn(256, 4, 7, 7).cuda().half()
        B_pt = torch.randn(1, 256, 1, 1).cuda().half()
        Y_pt = torch.nn.functional.conv2d(X_pt, W_pt, padding=3, stride=4)
        Y_pt = Y_pt + B_pt
        x = X_pt.permute((0, 2, 3, 1)).contiguous()
        w = W_pt.permute((0, 2, 3, 1)).contiguous()
        inputs = {"input_0": x, "input_1": w, "input_2": B_pt.squeeze()}
        y = torch.empty([1, 96, 96, 256]).cuda().half()
        module.run_with_tensors(inputs, [y])
        y_transpose = y.permute((0, 3, 1, 2))
        if target.name() == "cuda":
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1e-2, rtol=1e-2))
        else:
            self.assertTrue(torch.allclose(Y_pt, y_transpose, atol=1.25e-1, rtol=1e-1)) |
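For reference, the tolerance checks at the end of these tests follow torch.allclose semantics: an element passes when |a − b| <= atol + rtol * |b|, so the looser atol=1.25e-1, rtol=1e-1 branch tolerates much larger absolute error on large outputs. A tiny numpy sketch of that rule (the example values are mine, purely illustrative):

```python
import numpy as np

a = np.array([1.000, 100.0])
b = np.array([1.009, 100.9])

atol, rtol = 1e-2, 1e-2  # the CUDA-branch tolerances from the tests above

# Elementwise allclose criterion: |a - b| <= atol + rtol * |b|
ok = np.abs(a - b) <= atol + rtol * np.abs(b)
assert ok.all()  # both elements pass: rtol scales the budget with magnitude
```

This is why fp16 conv mismatches that look large in absolute terms can still pass on big activations but fail near zero.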
I think this was because I had to delete the temp folder. |
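If stale artifacts in the build directory were the culprit, one way to guard against it (a sketch only; the helper name is mine, not an AITemplate API) is to wipe the workdir before each compile_model call, since "./tmp" may still hold objects compiled for different shapes:

```python
import os
import shutil

def fresh_workdir(path):
    """Remove a previous build directory (if any) and recreate it empty,
    so compile_model cannot pick up stale compiled artifacts."""
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)
    return path
```

Usage would be `module = compile_model(Y, target, fresh_workdir("./tmp"), "conv2d")` in place of the literal "./tmp" above.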
This is a full repro. I tried nn.Conv2dBias and nn.Conv2d, which did not help (the Conv has a bias, but it was not clear whether nn.Conv2dBias is the same as nn.Conv2d).
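On the nn.Conv2dBias vs nn.Conv2d question: with a correct implementation, a fused conv-plus-bias must equal an unfused convolution followed by a broadcast bias add, which is exactly what the ConvBiasTestCase above checks (Y_pt = conv2d(...) then Y_pt + B_pt). A minimal numpy reference illustrating that equivalence (conv2d_ref is my own naive direct convolution written for this sketch, not an AITemplate or PyTorch API):

```python
import numpy as np

def conv2d_ref(x, w, stride=1, pad=0, bias=None):
    """Naive NCHW direct convolution. x: (N, C, H, W), w: (K, C, R, S)."""
    n, c, h, width = x.shape
    k, _, r, s = w.shape
    x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    oh = (h + 2 * pad - r) // stride + 1
    ow = (width + 2 * pad - s) // stride + 1
    y = np.zeros((n, k, oh, ow), dtype=x.dtype)
    for i in range(oh):
        for j in range(ow):
            # Contract the (C, R, S) patch against every filter at once
            patch = x[:, :, i * stride:i * stride + r, j * stride:j * stride + s]
            y[:, :, i, j] = np.tensordot(patch, w, axes=([1, 2, 3], [1, 2, 3]))
    if bias is not None:
        y += bias.reshape(1, -1, 1, 1)  # broadcast bias over N, H, W
    return y

x = np.random.rand(1, 4, 8, 8).astype(np.float32)
w = np.random.rand(3, 4, 3, 3).astype(np.float32)
b = np.random.rand(3).astype(np.float32)

fused = conv2d_ref(x, w, stride=2, pad=1, bias=b)          # conv-with-bias
unfused = conv2d_ref(x, w, stride=2, pad=1) + b.reshape(1, -1, 1, 1)
assert np.allclose(fused, unfused)
```

So any divergence between the two frontends points at op mapping or layout handling, not at the bias math itself.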