From c066442e426b83b2bb5504658a82bb41faf36852 Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 24 Jan 2024 12:06:38 -0800
Subject: [PATCH 1/3] write tests

---
 test/test_compile.py | 77 ++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 77 insertions(+)

diff --git a/test/test_compile.py b/test/test_compile.py
index 9b88811a..2d75fb21 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -3,6 +3,7 @@
 #
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
+import copy
 import random
 import unittest
 
@@ -11,6 +12,7 @@ import torch
 import torch.nn as nn
 
 from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
+from float8_experimental.float8_tensor import Float8Tensor
 
 # Setting to unblock for calling contiguous in backwards
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
@@ -76,5 +78,80 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp
     _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)
 
 
+class TestGraphBreaks:
+    class MockLinear(torch.nn.Module):
+        def __init__(self, graph_break: bool):
+            super().__init__()
+            self.register_buffer("fp8_amax_x", torch.tensor(1.0))
+            self.register_buffer("fp8_scale_x", torch.tensor(1.0))
+            self.graph_break = graph_break
+
+        def forward(self, x):
+            x_fp8 = Float8Tensor.to_float8(
+                x,
+                self.fp8_scale_x,
+                torch.float8_e4m3fn,
+                self.fp8_amax_x,
+                emulate=True,  # TODO: I set this to True so that people on A100 can test, but once fix is in, set to False
+            )
+            if self.graph_break:
+                torch._dynamo.graph_break()
+                x_hp = x_fp8.to_original_precision()
+                return x_hp
+            return x_fp8
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_with_graph_break_in_the_middle(self):
+        """Test that having a Float8Tensor object at the boundary of a subgraph works"""
+        mod = self.MockLinear(graph_break=True).cuda()
+        compiled_mod = copy.deepcopy(mod)
+        compiled_mod = torch.compile(compiled_mod)
+        x = torch.randn(16, 16, device="cuda")
+        y_eager = mod(x)
+        y_compiled = compiled_mod(x)
+        torch.testing.assert_close(y_eager, y_compiled)
+
+    def test_float8_graph_input(self):
+        """Test that having a Float8Tensor object as a graph input works"""
+
+        def to_float(x):
+            return x.to_original_precision()
+
+        to_float = torch.compile(to_float)
+
+        mod = self.MockLinear(graph_break=False).cuda()
+        x = torch.randn(2, 2, device="cuda")
+        compiled_to_float = torch.compile(to_float)
+        y = mod(x)
+        y2_eager = to_float(y)
+        y2_compiled = compiled_to_float(y)
+        torch.testing.assert_close(y2_eager, y2_compiled)
+
+    @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
+    def test_float8_graph_output(self):
+        """Test that having a Float8Tensor object as a graph output works"""
+        mod = self.MockLinear(graph_break=False).cuda()
+        compiled_mod = torch.compile(mod)
+        x = torch.randn(16, 16, device="cuda")
+        y_compiled = compiled_mod(x)
+
+        assert not isinstance(
+            y_compiled._data, torch._subclasses.fake_tensor.FakeTensor
+        ), "Float8Tensor._data should not be a FakeTensor!"
+        assert not isinstance(
+            y_compiled._scale, torch._subclasses.fake_tensor.FakeTensor
+        ), "Float8Tensor._scale should not be a FakeTensor!"
+        assert isinstance(
+            y_compiled._orig_dtype, torch.dtype
+        ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(
+            type(y_compiled._orig_dtype)
+        )
+        assert isinstance(
+            y_compiled._emulate, bool
+        ), "Float8Tensor._emulate should be a bool but got {}".format(
+            type(y_compiled._emulate)
+        )
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

From cd5f9805c7821482304613a18689f63316eb980b Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 24 Jan 2024 12:16:18 -0800
Subject: [PATCH 2/3] use tensor flatten in test

---
 test/test_compile.py | 11 +++++------
 1 file changed, 5 insertions(+), 6 deletions(-)

diff --git a/test/test_compile.py b/test/test_compile.py
index 2d75fb21..ed791bf7 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -135,12 +135,11 @@ def test_float8_graph_output(self):
         x = torch.randn(16, 16, device="cuda")
         y_compiled = compiled_mod(x)
 
-        assert not isinstance(
-            y_compiled._data, torch._subclasses.fake_tensor.FakeTensor
-        ), "Float8Tensor._data should not be a FakeTensor!"
-        assert not isinstance(
-            y_compiled._scale, torch._subclasses.fake_tensor.FakeTensor
-        ), "Float8Tensor._scale should not be a FakeTensor!"
+        tensors, ctx = y_compiled.__tensor_flatten__()
+        for tensor in tensors:
+            assert not isinstance(
+                getattr(y_compiled, tensor), torch._subclasses.fake_tensor.FakeTensor
+            ), "Float8Tensor should not contain any FakeTensors!"
         assert isinstance(
             y_compiled._orig_dtype, torch.dtype
         ), "Float8Tensor._orig_dtype should be a dtype but got {}".format(

From 7cdb54eee15e33f66a8b5db4b60f8a08aa2a881e Mon Sep 17 00:00:00 2001
From: drisspg
Date: Wed, 24 Jan 2024 13:03:35 -0800
Subject: [PATCH 3/3] add frame count checks, thanks voz

---
 test/test_compile.py | 23 +++++++++++++++++------
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/test/test_compile.py b/test/test_compile.py
index ed791bf7..d39b7400 100644
--- a/test/test_compile.py
+++ b/test/test_compile.py
@@ -14,6 +14,9 @@
 from float8_experimental.float8_linear_utils import get_float8_linear, LinearType
 from float8_experimental.float8_tensor import Float8Tensor
 
+from torch._dynamo.test_case import TestCase as DynamoTestCase
+from torch._dynamo.testing import CompileCounterWithBackend
+
 # Setting to unblock for calling contiguous in backwards
 is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (9, 0)
 
@@ -78,7 +81,7 @@ def test_inductor(fullgraph, emulate: bool, linear_type: bool, dtype: torch.dtyp
     _test_compile_base("inductor", fullgraph, emulate, linear_type, dtype)
 
 
-class TestGraphBreaks:
+class TestGraphBreaks(DynamoTestCase):
     class MockLinear(torch.nn.Module):
         def __init__(self, graph_break: bool):
             super().__init__()
@@ -103,12 +106,14 @@ def forward(self, x):
     @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
     def test_float8_with_graph_break_in_the_middle(self):
         """Test that having a Float8Tensor object at the boundary of a subgraph works"""
+        cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=True).cuda()
         compiled_mod = copy.deepcopy(mod)
-        compiled_mod = torch.compile(compiled_mod)
+        compiled_mod = torch.compile(compiled_mod, backend=cnts)
         x = torch.randn(16, 16, device="cuda")
         y_eager = mod(x)
         y_compiled = compiled_mod(x)
+        self.assertEqual(cnts.frame_count, 2, "Compiled graph should have 2 frames!")
         torch.testing.assert_close(y_eager, y_compiled)
 
     def test_float8_graph_input(self):
@@ -117,24 +122,30 @@
         def to_float(x):
             return x.to_original_precision()
 
-        to_float = torch.compile(to_float)
-
+        cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=False).cuda()
         x = torch.randn(2, 2, device="cuda")
-        compiled_to_float = torch.compile(to_float)
+        compiled_to_float = torch.compile(to_float, backend=cnts)
         y = mod(x)
         y2_eager = to_float(y)
         y2_compiled = compiled_to_float(y)
+        self.assertEqual(
+            cnts.frame_count,
+            1,
+            "to_float was not compiled into 1 frame and likely encountered a skip!",
+        )
         torch.testing.assert_close(y2_eager, y2_compiled)
 
     @pytest.mark.xfail(reason="TODO: Fix this test, see TODO in MockLinear")
     def test_float8_graph_output(self):
         """Test that having a Float8Tensor object as a graph output works"""
+        cnts = CompileCounterWithBackend("inductor")
         mod = self.MockLinear(graph_break=False).cuda()
-        compiled_mod = torch.compile(mod)
+        compiled_mod = torch.compile(mod, backend=cnts)
         x = torch.randn(16, 16, device="cuda")
         y_compiled = compiled_mod(x)
+        self.assertEqual(cnts.frame_count, 1, "Compiled graph should have 1 frame!")
 
         tensors, ctx = y_compiled.__tensor_flatten__()
         for tensor in tensors:
             assert not isinstance(