2 changes: 1 addition & 1 deletion autoparallel/activation_checkpointing.py
@@ -184,7 +184,7 @@ def _mark_nodes_as_must_save(must_save_nodes: list[torch.fx.Node]) -> None:
    for node in must_save_nodes:
        if (
            node.meta.get("recompute", None) is not None
-            and node.meta["ac_graph_id"] != AP_AC_GRAPH_ID
+            and node.meta.get("ac_graph_id", -1) != AP_AC_GRAPH_ID
        ):
            # Let user annotations take precedence
            skipped_nodes[node] = node.meta["recompute"]
81 changes: 54 additions & 27 deletions examples/example_local_map.py
@@ -6,6 +6,7 @@
import functools

import torch
+import torch.fx.traceback as fx_traceback
from torch import nn
from torch.distributed._tensor.experimental import local_map
from torch.distributed.fsdp import MixedPrecisionPolicy
@@ -57,7 +58,8 @@ def policy_fn(ctx, op, *args, **kwargs):
    device_mesh=mesh,
)
def replicate_linear(w, x):
-    return torch.matmul(x, w.t())
+    with fx_traceback.annotate({"inside_local_map": 1}):
+        return torch.matmul(x, w.t())


@local_map(
@@ -68,7 +70,8 @@ def replicate_linear(w, x):
    device_mesh=mesh,
)
def sharded_pointwise(x):
-    return x + 10
+    with fx_traceback.annotate({"inside_local_map": 0}):
+        return x + 10


@local_map(
@@ -83,10 +86,11 @@ def sharded_pointwise(x):
    device_mesh=mesh,
)
def context_parallel_attention(query, key, value):
-    out = nn.functional.scaled_dot_product_attention(
-        query=query, key=key, value=value, is_causal=False
-    )
-    return out
+    with fx_traceback.annotate({"inside_local_map": 2}):
+        out = nn.functional.scaled_dot_product_attention(
+            query=query, key=key, value=value, is_causal=False
+        )
+        return out


class Block(nn.Module):
@@ -108,35 +112,37 @@ def init_weights(self):
torch.nn.init.normal_(lin.bias)

    def _compute_attention(self, x):
-        boosted_weight = sharded_pointwise(self.wq.weight)
-        q = replicate_linear(boosted_weight, x)
-        k = self.wk(x)
-        v = self.wv(x)
+        with fx_traceback.annotate({"inside_checkpoint": 0}):
+            boosted_weight = sharded_pointwise(self.wq.weight)
+            q = replicate_linear(boosted_weight, x)
+            k = self.wk(x)
+            v = self.wv(x)

-        q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
-        k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
-        v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            q = q.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            k = k.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)
+            v = v.unflatten(-1, (self.nheads, -1)).permute(0, 2, 1, 3)

-        o = context_parallel_attention(q, k, v)
-        o = o.permute(0, 2, 1, 3).flatten(-2)
+            o = context_parallel_attention(q, k, v)
+            o = o.permute(0, 2, 1, 3).flatten(-2)

-        o = self.wo(o)
-        return o
+            o = self.wo(o)
+            return o

    def forward(self, x):
-        o = torch.utils.checkpoint.checkpoint(
-            self._compute_attention, x, use_reentrant=False, context_fn=context_fn
-        )
+        with fx_traceback.annotate({"outside_checkpoint": 0}):
+            o = torch.utils.checkpoint.checkpoint(
+                self._compute_attention, x, use_reentrant=False, context_fn=context_fn
+            )

-        o0 = o + x
+            o0 = o + x

-        o = self.w1(o0)
-        o = torch.nn.functional.relu(o)
-        o = self.w2(o)
+            o = self.w1(o0)
+            o = torch.nn.functional.relu(o)
+            o = self.w2(o)

-        o = o0 + o
+            o = o0 + o

-        return o
+            return o


bs = 8 * mesh.shape[0]
Expand All @@ -160,7 +166,9 @@ def input_fn():
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16)
# mp_policy = None

-with AutoParallel(model, input_fn, mesh, mp_policy, compile=True) as autop:
+with torch.fx.traceback.preserve_node_meta(), AutoParallel(
+    model, input_fn, mesh, mp_policy, compile=True
+) as autop:
    assert any(n.meta.get("nn_module_stack") for n in autop.gm.graph.nodes)
    assert any(n.meta.get("fwd_nn_module_stack") for n in autop.gm.graph.nodes)
    autop.add_parameter_memory_constraint(low=None, high=None)
@@ -208,4 +216,23 @@ def input_fn():
op="call_function", target=torch.ops.aten.mm.default
)

metas = [n.meta.get("custom", None) for n in autop.parallel_gm.graph.nodes]
fwd_sdpa, bwd_sdpa = [
n
for n in autop.parallel_gm.graph.nodes
if "_scaled_dot_product_flash_attention" in n.name
]
# TODO: Dynamo HOP body is not preserving the fx_traceback.annotate
# We should expect to also see the "inside_local_map" annotation
assert fwd_sdpa.meta["custom"] == {
"inside_checkpoint": 0,
"inside_local_map": 2,
"outside_checkpoint": 0,
}
assert bwd_sdpa.meta["custom"] == {
"inside_checkpoint": 0,
"inside_local_map": 2,
"outside_checkpoint": 0,
}

print("All good!")
193 changes: 192 additions & 1 deletion tests/test_api.py
@@ -5,8 +5,9 @@

import pytest
import torch
+import torch.fx.traceback as fx_traceback
from torch import nn
-from torch.distributed.tensor.placement_types import Shard
+from torch.distributed.tensor.placement_types import Replicate, Shard
from torch.testing._internal.distributed.fake_pg import FakeStore

from autoparallel.api import AutoParallel
@@ -114,3 +115,193 @@ def input_fn():
assert torch.equal(
parallel_mod.get_buffer("buf").full_tensor(), torch.arange(dim, device="cuda")
)


def test_fx_graph_annotate(device_mesh_1d):
dim = 128

class Model(nn.Module):
def __init__(self, dim):
super().__init__()
self.a = nn.Linear(dim, dim, bias=False)
self.b = nn.Linear(dim, dim, bias=False)
self.c = nn.Linear(dim, dim, bias=False)
self.d = nn.Linear(dim, dim, bias=False)

def forward(self, x):
with fx_traceback.annotate({"outer": 0}):
with fx_traceback.annotate({"inner": 0}):
a = self.a(x)
with fx_traceback.annotate({"inner": 1}):
b = self.b(a)
with fx_traceback.annotate({"inner": 2}):
c = self.c(b)
with fx_traceback.annotate({"inner": 3}):
d = self.d(c)
return d

def input_fn():
b = 512
inputs = (torch.rand(b, dim, device="cuda"),)
return inputs

with torch.device("meta"):
model = Model(dim)

with fx_traceback.preserve_node_meta(), AutoParallel(
model,
input_fn,
device_mesh_1d,
) as autop:
x_sharding = (Shard(0),)
autop.add_input_constraints([x_sharding])
sharding_placement = autop.optimize_placement()

# AutoParallel produces a module with meta-DTensor parameters that need to be initialized
_ = autop.apply_placement(sharding_placement)

graph = autop.parallel_gm.graph

# 4 linear -> 4 mm ops
fw_seen_annotations = set()
bw_seen_annotations = set()
for mm in [n for n in graph.nodes if "mm" in n.name]:
assert mm.meta["custom"]["outer"] == 0
assert "inner" in mm.meta["custom"]
if mm.meta.get("partitioner_tag", "") == "is_backward":
bw_seen_annotations.add(mm.meta["custom"]["inner"])
else:
fw_seen_annotations.add(mm.meta["custom"]["inner"])
assert fw_seen_annotations == bw_seen_annotations == {0, 1, 2, 3}

for ph in graph.find_nodes(op="placeholder"):
assert (
"custom" not in ph.meta
), "Placeholders didn't have have custom metadata before"
for out in graph.find_nodes(op="output"):
assert (
"custom" not in out.meta
), "Output didn't have have custom metadata before"

# NOTE: The tests below are just to prevent semantics from changing silently.
# Currently, custom metadata is not set for:
# - graph inputs
# - graph outputs
# - collectives/waits added by AP
for node in graph.nodes:
if node.meta.get("custom", None) is None:
assert (
node.op == "placeholder"
or node.op == "output"
or node.target.namespace == "_c10d_functional"
)


def test_fx_graph_annotate_overlap_pass(device_mesh_1d):
class DummyOp(torch.autograd.Function):
@staticmethod
def forward(ctx, x, scalar):
ctx.save_for_backward(x)
return x + scalar

@staticmethod
def backward(ctx, grad_out):
return grad_out, None

def mock_fw_compute(x):
with fx_traceback.annotate({"compute": 0}):
return DummyOp.apply(x, 10)

def mock_bw_comm(x):
with fx_traceback.annotate({"comm": 0}):
return DummyOp.apply(x, 20)

def mock_bw_compute(x):
return DummyOp.apply(x, 30)

class Model(nn.Module):
def forward(self, fw_in, bw_in):
fw_out = mock_fw_compute(fw_in)
# bw_in blocks bw_out
bw_in = mock_bw_comm(bw_in)
bw_out = mock_bw_compute(bw_in)
return fw_out, bw_out

def input_fn():
inputs = (torch.rand(2, 128, device="cuda", requires_grad=True),)
grad_ins = (torch.rand(2, 128, device="cuda"),)
return (
*inputs,
*grad_ins,
)

with torch.device("meta"):
model = Model()

with fx_traceback.preserve_node_meta(), AutoParallel(
model,
input_fn,
device_mesh_1d,
) as autop:
autop.add_input_constraints(
[
(Replicate(),),
(Replicate(),),
]
)
autop.add_output_constraints(
[
(Replicate(),),
(Replicate(),),
]
)
sharding_placement = autop.optimize_placement()

# AutoParallel produces a module with meta-DTensor parameters that need to be initialized
_ = autop.apply_placement(sharding_placement)

graph = autop.parallel_gm.graph

# At this point, the graph looks like:
# graph():
# %primals_1 : [num_users=1] = placeholder[target=primals_1]
# %primals_2 : [num_users=1] = placeholder[target=primals_2]
# %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
# %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {})
# %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
# return ((add, add_2), (tangents_1, None))

compute_nodes = {
n for n in graph.nodes if n.meta.get("custom", {}).get("compute", None) == 0
}
comm_nodes = [
n for n in graph.nodes if n.meta.get("custom", {}).get("comm", None) == 0
]
assert len(compute_nodes) == 1
assert len(comm_nodes) == 1

# move comm nodes before compute nodes
first_compute_node = None
for n in graph.nodes:
if n in compute_nodes:
first_compute_node = n
break

assert first_compute_node is not None
for node in reversed(comm_nodes):
first_compute_node.prepend(node)

# After pass, add_1 (comm) should be before add (compute)
node_names = [n.name for n in graph.nodes]
assert node_names.index("add_1") == node_names.index("add") - 1

# The graph looks like:
# graph():
# %primals_1 : [num_users=1] = placeholder[target=primals_1]
# %primals_2 : [num_users=1] = placeholder[target=primals_2]
# %tangents_1 : [num_users=1] = placeholder[target=tangents_1]
# %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_2, 20), kwargs = {})
# %add : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%primals_1, 10), kwargs = {})
# %add_2 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%add_1, 30), kwargs = {})
# return ((add, add_2), (tangents_1, None))