28 changes: 27 additions & 1 deletion autoparallel/api.py
@@ -13,6 +13,7 @@
from torch._inductor.decomposition import select_decomp_table
from torch._inductor.fx_passes.joint_graph import joint_graph_passes
from torch._inductor.fx_passes.post_grad import remove_assert_ops
+from torch._logging import trace_structured
from torch._subclasses import FakeTensorMode
from torch.distributed.tensor import DeviceMesh

@@ -287,6 +288,14 @@ def build_model_graph(self):
# give more room for optimizations
_add_alias(gm)
apply_node_renaming(gm, self.params_len, self.buffer_len, self.metadata)
+trace_structured(
+    "artifact",
+    metadata_fn=lambda: {
+        "name": "autoparallel_joint_graph",
+        "encoding": "string",
+    },
+    payload_fn=lambda: str(gm.graph),
Contributor review comment (on the payload_fn line above): Might be preferable to store the str(gm), as it would let it be a runnable representation of the graph.
+)

self.gm = gm
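The review comment above contrasts str(gm.graph) with str(gm). A minimal, hedged sketch of the difference, using a toy symbolically traced module rather than AutoParallel's joint graph:

# Hedged sketch (toy module, not AutoParallel's joint graph) of the two
# string forms discussed in the review comment above.
import torch
from torch.fx import symbolic_trace

class Toy(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1

gm = symbolic_trace(Toy())
print(gm.graph)  # FX IR: placeholder/call_function/output node listing
print(gm.code)   # generated Python source of forward(); per the review comment,
                 # str(gm) includes this runnable form rather than only the IR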

@@ -330,7 +339,16 @@ def optimize_placement(self, verbose=True):
self.sharding_placement = self.sharding_optimizer.get_solution(verbose=False)

if verbose:
-self.sharding_optimizer.print()
+print(self.sharding_optimizer.get_log())
+
+trace_structured(
+    "artifact",
+    metadata_fn=lambda: {
+        "name": "autoparallel_sharding_optimizer_log",
+        "encoding": "string",
+    },
+    payload_fn=lambda: self.sharding_optimizer.get_log(colored=False),
+)

if self.sharding_optimizer.prob.status == -1:
raise RuntimeError("Didn't find solution")
@@ -347,6 +365,14 @@ def apply_placement(self, sharding_placement=None):
# clean it up by removing the added aliases from previous pass
# as well as redundant views
parallel_gm = joint_graph_passes(parallel_gm)
+trace_structured(
Contributor review comment (on the trace_structured( line above): nit: it would maybe be more readable to store the parallel_gm after the apply_node_renaming, as the nodes would then have the same names as in the unsharded graph.

+    "artifact",
+    metadata_fn=lambda: {
+        "name": "autoparallel_parallel_graph",
+        "encoding": "string",
+    },
+    payload_fn=lambda: str(parallel_gm.graph),
+)
# now rename input/param/tangent/output/grad_param/grad_input nodes following
# our convention
apply_node_renaming(
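As context (not part of the diff): the calls above follow PyTorch's structured-trace artifact pattern. A hedged sketch of how such an artifact is emitted and later inspected; the helper name and the trace directory are illustrative, and the TORCH_TRACE/tlparse workflow is an assumption about the standard tooling rather than something stated in this PR:

# Hedged sketch; helper name and paths are illustrative, not from this PR.
from torch._logging import trace_structured

def log_text_artifact(name: str, text: str) -> None:
    # The lambdas keep this cheap: metadata and payload are only materialized
    # when structured tracing is enabled (e.g. via TORCH_TRACE).
    trace_structured(
        "artifact",
        metadata_fn=lambda: {"name": name, "encoding": "string"},
        payload_fn=lambda: text,
    )

# Typical collection/viewing workflow (assumed standard tooling):
#   TORCH_TRACE=/tmp/autoparallel_trace python train.py
#   tlparse /tmp/autoparallel_trace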
20 changes: 11 additions & 9 deletions autoparallel/optimize_sharding.py
@@ -313,15 +313,16 @@ def penalize_inefficient_collectives(self):
# penalize case P -> S(1) as there are additional compute cost
self.ds[(s_i, counter, oi, ii)]["cost"] *= 4

-def print_violated_constraints(self):
+def get_violated_constraints_log(self):
violated_constraints = [
(k, c) for k, c in self.prob.constraints.items() if not c.valid()
]
-print(f"Violated constraints: {[x[0] for x in violated_constraints]}")
+log_str = f"Violated constraints: {[x[0] for x in violated_constraints]}"
for cname, c in violated_constraints:
-print(f"========= {cname} =============")
+log_str += f"\n========= {cname} ============="
for cc, v in c.items():
-print(f"{cc}, coeff={v}, value={cc.value()}")
+log_str += f"\n{cc}, coeff={v}, value={cc.value()}"
+return log_str

def print_old(self):
ds = self.ds
@@ -335,9 +336,10 @@ def print_old(self):
)
total_cost = sum(ds[x]["cost"] for x in res)
print(f"total_cost: {total_cost:.2f}")
-self.print_violated_constraints()
+print(self.get_violated_constraints_log())
+
+def get_log(self, colored=False):

-def print(self, colored=False):
from torch.fx.graph import _color_fns, _identity

opt = {}
@@ -375,10 +377,10 @@ def print(self, colored=False):
code[l_id] += line
l_id += 1
code = "\n".join(code)
-print(code)
total_cost = sum(self.ds[x]["cost"] for x in self.res)
-print(f"total_cost: {total_cost:.2f}")
-self.print_violated_constraints()
+code += f"\ntotal_cost: {total_cost:.2f}"
+code += "\n" + self.get_violated_constraints_log()
+return code

def print_costs_for_node(self, node, arg=0, **kwargs):
from tabulate import tabulate # type: ignore
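A brief, hedged usage sketch of the refactored logging methods from the diff above (only the method names shown in the diff are assumed; construction of the optimizer is elided):

# Hedged usage sketch; `optimizer` stands for an already-built sharding optimizer
# instance, as used via self.sharding_optimizer in api.py.
log_text = optimizer.get_log(colored=False)       # plain string, safe to store as a trace artifact
print(optimizer.get_log(colored=True))            # colored variant for interactive terminals
print(optimizer.get_violated_constraints_log())   # constraint diagnostics on their own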