diff --git a/autoparallel/api.py b/autoparallel/api.py
index 387228a6..74a75a0c 100644
--- a/autoparallel/api.py
+++ b/autoparallel/api.py
@@ -13,6 +13,7 @@
 from torch._inductor.decomposition import select_decomp_table
 from torch._inductor.fx_passes.joint_graph import joint_graph_passes
 from torch._inductor.fx_passes.post_grad import remove_assert_ops
+from torch._logging import trace_structured
 from torch._subclasses import FakeTensorMode
 from torch.distributed.tensor import DeviceMesh
 
@@ -287,6 +288,14 @@ def build_model_graph(self):
         # give more room for optimizations
         _add_alias(gm)
         apply_node_renaming(gm, self.params_len, self.buffer_len, self.metadata)
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "autoparallel_joint_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: str(gm.graph),
+        )
 
         self.gm = gm
 
@@ -330,7 +339,16 @@ def optimize_placement(self, verbose=True):
         self.sharding_placement = self.sharding_optimizer.get_solution(verbose=False)
 
         if verbose:
-            self.sharding_optimizer.print()
+            print(self.sharding_optimizer.get_log())
+
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "autoparallel_sharding_optimizer_log",
+                "encoding": "string",
+            },
+            payload_fn=lambda: self.sharding_optimizer.get_log(colored=False),
+        )
 
         if self.sharding_optimizer.prob.status == -1:
             raise RuntimeError("Didn't find solution")
@@ -347,6 +365,14 @@ def apply_placement(self, sharding_placement=None):
         # clean it up by removing the added aliases from previous pass
         # as well as redundant views
         parallel_gm = joint_graph_passes(parallel_gm)
+        trace_structured(
+            "artifact",
+            metadata_fn=lambda: {
+                "name": "autoparallel_parallel_graph",
+                "encoding": "string",
+            },
+            payload_fn=lambda: str(parallel_gm.graph),
+        )
         # now rename input/param/tangent/output/grad_param/grad_input nodes following
         # our convention
         apply_node_renaming(
diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py
index 60de41d8..aa054ad0 100644
--- a/autoparallel/optimize_sharding.py
+++ b/autoparallel/optimize_sharding.py
@@ -313,15 +313,16 @@ def penalize_inefficient_collectives(self):
                         # penalize case P -> S(1) as there are additional compute cost
                         self.ds[(s_i, counter, oi, ii)]["cost"] *= 4
 
-    def print_violated_constraints(self):
+    def get_violated_constraints_log(self):
         violated_constraints = [
             (k, c) for k, c in self.prob.constraints.items() if not c.valid()
         ]
-        print(f"Violated constraints: {[x[0] for x in violated_constraints]}")
+        log_str = f"Violated constraints: {[x[0] for x in violated_constraints]}"
         for cname, c in violated_constraints:
-            print(f"========= {cname} =============")
+            log_str += f"\n========= {cname} ============="
             for cc, v in c.items():
-                print(f"{cc}, coeff={v}, value={cc.value()}")
+                log_str += f"\n{cc}, coeff={v}, value={cc.value()}"
+        return log_str
 
     def print_old(self):
         ds = self.ds
@@ -335,9 +336,10 @@ def print_old(self):
         )
         total_cost = sum(ds[x]["cost"] for x in res)
         print(f"total_cost: {total_cost:.2f}")
-        self.print_violated_constraints()
+        print(self.get_violated_constraints_log())
+
+    def get_log(self, colored=False):
 
-    def print(self, colored=False):
         from torch.fx.graph import _color_fns, _identity
 
         opt = {}
@@ -375,10 +377,10 @@ def print(self, colored=False):
                 code[l_id] += line
                 l_id += 1
         code = "\n".join(code)
-        print(code)
         total_cost = sum(self.ds[x]["cost"] for x in self.res)
-        print(f"total_cost: {total_cost:.2f}")
-        self.print_violated_constraints()
+        code += f"\ntotal_cost: {total_cost:.2f}"
+        code += "\n" + self.get_violated_constraints_log()
+        return code
 
     def print_costs_for_node(self, node, arg=0, **kwargs):
         from tabulate import tabulate  # type: ignore
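
For reference, a minimal standalone sketch of the torch._logging structured-trace call this diff introduces. The artifact name is taken from build_model_graph above; graph_text is an assumed stand-in for str(gm.graph). When structured tracing is enabled (e.g. via TORCH_TRACE pointing at an output directory) the artifact is recorded and can be inspected with tlparse; otherwise the call is a no-op.

    # Sketch only: graph_text is a placeholder for str(gm.graph).
    from torch._logging import trace_structured

    graph_text = "graph():\n    %x : [num_users=1] = placeholder[target=x]"

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "autoparallel_joint_graph",  # artifact name used in the diff
            "encoding": "string",
        },
        payload_fn=lambda: graph_text,
    )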