diff --git a/examples/example_local_map.py b/examples/example_local_map.py
index 687b48f1..5039218c 100644
--- a/examples/example_local_map.py
+++ b/examples/example_local_map.py
@@ -61,20 +61,14 @@ def replicate_linear(w, x):
 
 
 @local_map(
-    out_placements=(
-        (Shard(0), Shard(0), Replicate()),
-        None,
-    ),
-    in_placements=(
-        (Shard(0), Shard(0), Replicate()),
-        None,
-    ),
+    out_placements=((Shard(0), Shard(0), Replicate()),),
+    in_placements=((Shard(0), Shard(0), Replicate()),),
     redistribute_inputs=True,
     in_grad_placements=None,
     device_mesh=mesh,
 )
-def sharded_pointwise(x, scalar):
-    return x + scalar, scalar
+def sharded_pointwise(x):
+    return x + 10
 
 
 @local_map(
@@ -114,7 +108,7 @@ def init_weights(self):
             torch.nn.init.normal_(lin.bias)
 
     def _compute_attention(self, x):
-        boosted_weight, scalar = sharded_pointwise(self.wq.weight, 10)
+        boosted_weight = sharded_pointwise(self.wq.weight)
         q = replicate_linear(boosted_weight, x)
         k = self.wk(x)
         v = self.wv(x)
diff --git a/tests/test_optimize_placement.py b/tests/test_optimize_placement.py
index dc6d1522..c93b868c 100644
--- a/tests/test_optimize_placement.py
+++ b/tests/test_optimize_placement.py
@@ -372,7 +372,7 @@ def input_fn():
 
 
 class LocalMapTransformerBlock(nn.Module):
-    def __init__(self, nheads, dim1, dim2):
+    def __init__(self, nheads, dim1, dim2, mesh):
         super().__init__()
         self.nheads = nheads
         bias = False
@@ -382,6 +382,7 @@ def __init__(self, nheads, dim1, dim2):
         self.wo = nn.Linear(dim1, dim1, bias=bias)
         self.w1 = nn.Linear(dim1, dim2, bias=bias)
         self.w2 = nn.Linear(dim2, dim1, bias=bias)
+        self.mesh = mesh
 
     def forward(self, x):
         @local_map(
@@ -393,7 +394,7 @@ def forward(self, x):
             ),
             redistribute_inputs=True,
             in_grad_placements=None,
-            device_mesh=None,
+            device_mesh=self.mesh,
         )
         def _context_parallel_attention(query, key, value):
             out = F.scaled_dot_product_attention(
@@ -435,7 +436,7 @@ def test_local_map_placement_respected(device_mesh_local_map, device="cuda"):
     seq_len = 256
 
     def model_fn():
-        return LocalMapTransformerBlock(nheads, dim1, dim2)
+        return LocalMapTransformerBlock(nheads, dim1, dim2, device_mesh_local_map)
 
     def input_fn():
         return torch.randn(bs, seq_len, dim1, device=device, requires_grad=True)
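
For context, a minimal standalone sketch (not part of the diff) of the pattern these hunks converge on: passing `device_mesh` explicitly to `local_map` rather than leaving it to be inferred. This assumes the `torch.distributed.tensor.experimental.local_map` API with the same keyword arguments used above, a 1-D mesh for simplicity (the example file uses a 3-D one), and a process group initialized via `torchrun`; the function name `shifted` is illustrative only.

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import Shard, distribute_tensor
from torch.distributed.tensor.experimental import local_map

# Hypothetical 1-D mesh, one rank per local GPU; run under torchrun.
mesh = init_device_mesh("cuda", (torch.cuda.device_count(),))

@local_map(
    out_placements=((Shard(0),),),  # one output, sharded along dim 0
    in_placements=((Shard(0),),),   # one input, sharded along dim 0
    redistribute_inputs=True,
    in_grad_placements=None,
    device_mesh=mesh,  # explicit mesh, nothing inferred from the inputs
)
def shifted(x):
    # Runs on each rank's local shard; the inlined constant mirrors the
    # diff's simplification of sharded_pointwise, which previously took
    # a scalar argument and returned it alongside the result.
    return x + 10

x = distribute_tensor(torch.randn(8, 4, device="cuda"), mesh, [Shard(0)])
y = shifted(x)  # a DTensor sharded on dim 0, like its input
```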