diff --git a/autoparallel/activation_checkpointing.py b/autoparallel/activation_checkpointing.py
index 5cafd920..aafcc2ea 100644
--- a/autoparallel/activation_checkpointing.py
+++ b/autoparallel/activation_checkpointing.py
@@ -184,7 +184,7 @@ def _mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
     for node in must_save_nodes:
         if (
             node.meta.get("recompute", None) is not None
-            and node.meta["ac_graph_id"] != AP_AC_GRAPH_ID
+            and node.meta.get("ac_graph_id", -1) != AP_AC_GRAPH_ID
         ):
             # Let user annotations take precedence
             skipped_nodes[node] = node.meta["recompute"]
diff --git a/examples/example_local_map.py b/examples/example_local_map.py
index 5039218c..4c02ab2a 100644
--- a/examples/example_local_map.py
+++ b/examples/example_local_map.py
@@ -6,6 +6,7 @@
 import functools
 
 import torch
+import torch.fx.traceback as fx_traceback
 from torch import nn
 from torch.distributed._tensor.experimental import local_map
 from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -57,7 +58,8 @@ def policy_fn(ctx, op, *args, **kwargs):
     device_mesh=mesh,
 )
 def replicate_linear(w, x):
-    return torch.matmul(x, w.t())
+    with fx_traceback.annotate({"inside_local_map": 1}):
+        return torch.matmul(x, w.t())
 
 
 @local_map(
@@ -68,7 +70,8 @@ def replicate_linear(w, x):
     device_mesh=mesh,
 )
 def sharded_pointwise(x):
-    return x + 10
+    with fx_traceback.annotate({"inside_local_map": 0}):
+        return x + 10
 
 
 @local_map(
@@ -83,10 +86,11 @@ def sharded_pointwise(x):
     device_mesh=mesh,
 )
 def context_parallel_attention(query, key, value):
-    out = nn.functional.scaled_dot_product_attention(
-        query=query, key=key, value=value, is_causal=False
-    )
-    return out
+    with fx_traceback.annotate({"inside_local_map": 2}):
+        out = nn.functional.scaled_dot_product_attention(
+            query=query, key=key, value=value, is_causal=False
+        )
+        return out
 
 
 class Block(nn.Module):
@@ -108,35 +112,37 @@ def init_weights(self):
            torch.nn.init.normal_(lin.bias)
 
     def _compute_attention(self, x):
-        boosted_weight = sharded_pointwise(self.wq.weight)
-        q = replicate_linear(boosted_weight, x)
-        k = self.wk(x)
-        v = self.wv(x)
+        with fx_traceback.annotate({"inside_checkpoint": 0}):
+            boosted_weight = sharded_pointwise(self.wq.weight)
+            q = replicate_linear(boosted_weight, x)
+            k = self.wk(x)
+            v = self.wv(x)
 
-        q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
-        k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
-        v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
 
-        o = context_parallel_attention(q, k, v)
-        o = o.permute(0, 2, 1, 3).flatten(-2)
+            o = context_parallel_attention(q, k, v)
+            o = o.permute(0, 2, 1, 3).flatten(-2)
 
-        o = self.wo(o)
-        return o
+            o = self.wo(o)
+            return o
 
     def forward(self, x):
-        o = torch.utils.checkpoint.checkpoint(
-            self._compute_attention, x, use_reentrant=False, context_fn=context_fn
-        )
+        with fx_traceback.annotate({"outside_checkpoint": 0}):
+            o = torch.utils.checkpoint.checkpoint(
+                self._compute_attention, x, use_reentrant=False, context_fn=context_fn
+            )
 
-        o0 = o + x
+            o0 = o + x
 
-        o = self.w1(o0)
-        o = torch.nn.functional.relu(o)
-        o = self.w2(o)
+            o = self.w1(o0)
+            o = torch.nn.functional.relu(o)
+            o = self.w2(o)
 
-        o = o0 + o
+            o = o0 + o
 
-        return o
+            return o
 
 
 bs = 8 * mesh.shape[0]
@@ -160,7 +166,9 @@ def input_fn():
 mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
 # mp_policy = None
 
-with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
+with torch.fx.traceback.preserve_node_meta(), AutoParallel(
+    model, input_fn, mesh, mp_policy, compile=True
+) as autop:
     assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
     assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)
     autop.add_parameter_memory_constraint(low=None, high=None)
@@ -208,4 +216,23 @@ def input_fn():
     op="call_function", target=torch.ops.aten.mm.default
 )
 
+metas = [n.meta.get("custom", None) for n in autop.parallel_gm.graph.nodes]
+fwd_sdpa, bwd_sdpa = [
+    n
+    for n in autop.parallel_gm.graph.nodes
+    if "_scaled_dot_product_flash_attention" in n.name
+]
+# TODO: Dynamo HOP body is not preserving the fx_traceback.annotate
+# We should expect to also see the "inside_local_map" annotation
+assert fwd_sdpa.meta["custom"] == {
+    "inside_checkpoint": 0,
+    "inside_local_map": 2,
+    "outside_checkpoint": 0,
+}
+assert bwd_sdpa.meta["custom"] == {
+    "inside_checkpoint": 0,
+    "inside_local_map": 2,
+    "outside_checkpoint": 0,
+}
+
 print("All good!")
diff --git a/tests/test_api.py b/tests/test_api.py
index 04af2b80..2369ba97 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -5,8 +5,9 @@
 
 import pytest
 import torch
+import torch.fx.traceback as fx_traceback
 from torch import nn
-from torch.distributed.tensor.placement_types import Shard
+from torch.distributed.tensor.placement_types import Replicate, Shard
 from torch.testing._internal.distributed.fake_pg import FakeStore
 
 from autoparallel.api import AutoParallel
@@ -114,3 +115,193 @@ def input_fn():
     assert torch.equal(
         parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda")
     )
+
+
+def test_fx_graph_annotate(device_mesh_1d):
+    dim = 128
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.a = nn.Linear(dim, dim, bias=False)
+            self.b = nn.Linear(dim, dim, bias=False)
+            self.c = nn.Linear(dim, dim, bias=False)
+            self.d = nn.Linear(dim, dim, bias=False)
+
+        def forward(self, x):
+            with fx_traceback.annotate({"outer": 0}):
+                with fx_traceback.annotate({"inner": 0}):
+                    a = self.a(x)
+                with fx_traceback.annotate({"inner": 1}):
+                    b = self.b(a)
+                with fx_traceback.annotate({"inner": 2}):
+                    c = self.c(b)
+                with fx_traceback.annotate({"inner": 3}):
+                    d = self.d(c)
+            return d
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    with fx_traceback.preserve_node_meta(), AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        _ = autop.apply_placement(sharding_placement)
+
+    graph = autop.parallel_gm.graph
+
+    # 4 linear -> 4 mm ops
+    fw_seen_annotations = set()
+    bw_seen_annotations = set()
+    for mm in [n for n in graph.nodes if "mm" in n.name]:
+        assert mm.meta["custom"]["outer"] == 0
+        assert "inner" in mm.meta["custom"]
+        if mm.meta.get("partitioner_tag", "") == "is_backward":
+            bw_seen_annotations.add(mm.meta["custom"]["inner"])
+        else:
+            fw_seen_annotations.add(mm.meta["custom"]["inner"])
+    assert fw_seen_annotations == bw_seen_annotations == {0, 1, 2, 3}
+
+    for ph in graph.find_nodes(op="placeholder"):
+        assert (
+            "custom" not in ph.meta
+        ), "Placeholders didn't have custom metadata before"
+    for out in graph.find_nodes(op="output"):
+        assert (
+            "custom" not in out.meta
+        ), "Output didn't have custom metadata before"
+
+    # NOTE: The tests below are just to prevent semantics from changing silently.
+    # Currently, custom metadata is not set for:
+    # - graph inputs
+    # - graph outputs
+    # - collectives/waits added by AP
+    for node in graph.nodes:
+        if node.meta.get("custom", None) is None:
+            assert (
+                node.op == "placeholder"
+                or node.op == "output"
+                or node.target.namespace == "_c10d_functional"
+            )
+
+
+def test_fx_graph_annotate_overlap_pass(device_mesh_1d):
+    class DummyOp(torch.autograd.Function):
+        @staticmethod
+        def forward(ctx, x, scalar):
+            ctx.save_for_backward(x)
+            return x + scalar
+
+        @staticmethod
+        def backward(ctx, grad_out):
+            return grad_out, None
+
+    def mock_fw_compute(x):
+        with fx_traceback.annotate({"compute": 0}):
+            return DummyOp.apply(x, 10)
+
+    def mock_bw_comm(x):
+        with fx_traceback.annotate({"comm": 0}):
+            return DummyOp.apply(x, 20)
+
+    def mock_bw_compute(x):
+        return DummyOp.apply(x, 30)
+
+    class Model(nn.Module):
+        def forward(self, fw_in, bw_in):
+            fw_out = mock_fw_compute(fw_in)
+            # bw_in blocks bw_out
+            bw_in = mock_bw_comm(bw_in)
+            bw_out = mock_bw_compute(bw_in)
+            return fw_out, bw_out
+
+    def input_fn():
+        inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
+        grad_ins = (torch.rand(2, 128, device="cuda"),)
+        return (
+            *inputs,
+            *grad_ins,
+        )
+
+    with torch.device("meta"):
+        model = Model()
+
+    with fx_traceback.preserve_node_meta(), AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        autop.add_input_constraints(
+            [
+                (Replicate(),),
+                (Replicate(),),
+            ]
+        )
+        autop.add_output_constraints(
+            [
+                (Replicate(),),
+                (Replicate(),),
+            ]
+        )
+        sharding_placement = autop.optimize_placement()
+
+        # AutoParallel produces a module with meta-DTensor parameters that need to be initialized
+        _ = autop.apply_placement(sharding_placement)
+
+    graph = autop.parallel_gm.graph
+
+    # At this point, the graph looks like:
+    # graph():
+    #     %primals_1 : [num_users=1] = placeholder[target=primals_1]
+    #     %primals_2 : [num_users=1] = placeholder[target=primals_2]
+    #     %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
+    #     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
+    #     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {})
+    #     %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
+    #     return ((add, add_2), (tangents_1, None))
+
+    compute_nodes = {
+        n for n in graph.nodes if n.meta.get("custom", {}).get("compute", None) == 0
+    }
+    comm_nodes = [
+        n for n in graph.nodes if n.meta.get("custom", {}).get("comm", None) == 0
+    ]
+    assert len(compute_nodes) == 1
+    assert len(comm_nodes) == 1
+
+    # move comm nodes before compute nodes
+    first_compute_node = None
+    for n in graph.nodes:
+        if n in compute_nodes:
+            first_compute_node = n
+            break
+
+    assert first_compute_node is not None
+    for node in reversed(comm_nodes):
+        first_compute_node.prepend(node)
+
+    # After pass, add_1 (comm) should be before add (compute)
+    node_names = [n.name for n in graph.nodes]
+    assert node_names.index("add_1") == node_names.index("add") - 1
+
+    # The graph looks like:
+    # graph():
+    #     %primals_1 : [num_users=1] = placeholder[target=primals_1]
+    #     %primals_2 : [num_users=1] = placeholder[target=primals_2]
+    #     %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
+    #     %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {})
+    #     %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
+    #     %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
+    #     return ((add, add_2), (tangents_1, None))