From 7dd49832d9a7a8480ef176beeb810206541691b2 Mon Sep 17 00:00:00 2001 From: HansBug Date: Fri, 11 Aug 2023 11:11:00 +0800 Subject: [PATCH] dev(narugo): add buggy test --- test/tree/integration/test_torch.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/test/tree/integration/test_torch.py b/test/tree/integration/test_torch.py index 679977f701..1987156865 100644 --- a/test/tree/integration/test_torch.py +++ b/test/tree/integration/test_torch.py @@ -73,3 +73,24 @@ def foo(x, y, t): b = torch.randn(3, 4) c = torch.randn(3, 4) assert torch.isclose(foo(a, b, c), (a + b * 2000) / (c - 100)).all() + + @skipUnless(vpip('torch') >= '2.0.0' and OS.linux and vpython < '3.11', 'Torch 2 on linux platform required') + def test_torch_compile_buggy(self): + @torch.compile + def foox(x, y): + z = x + y + return z + + x = FastTreeValue({ + 'a': torch.randn(3, 4) + 200, + 'b': torch.randn(5) - 300, + }) + y = FastTreeValue({ + 'a': torch.rand(4) + 500, + 'b': torch.randn(4, 5) + 1000, + }) + + _t_isclose = FastTreeValue.func()(torch.isclose) + + assert _t_isclose(foox(x, y), x + y).all() == \ + FastTreeValue({'a': torch.tensor(True), 'b': torch.tensor(True)})