Skip to content

Commit

Permalink
[ETConverter] Update text converter to new schema
Browse files (browse the repository at this point in the history)
  • Loading branch information
changhai0109 committed Dec 7, 2023
1 parent e7d7a7c commit a92cbaa
Showing 1 changed file with 29 additions and 1 deletion.
30 changes: 29 additions & 1 deletion et_converter/text2chakra_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
ALL_TO_ALL,
ALL_GATHER,
REDUCE_SCATTER,
GlobalMetadata
)

class Layer:
Expand Down Expand Up @@ -66,6 +67,17 @@ def __init__(
self.num_passes = num_passes
self.logger = logger
self.next_node_id = 0

def get_global_metadata(self):
    """Build the GlobalMetadata message describing this conversion.

    The message carries two string attributes:

    * ``schema`` -- fixed to ``"text2chakra_converter"``.
    * ``input_file`` -- the complete text read from ``self.input_filename``.

    Returns:
        A ``GlobalMetadata`` instance populated with those two attributes.
    """
    # Read the whole input description so it can be embedded in the metadata.
    with open(self.input_filename, "r") as source:
        contents = source.read()

    schema_attr = ChakraAttr(name="schema", string_val="text2chakra_converter")
    input_attr = ChakraAttr(name="input_file", string_val=contents)
    return GlobalMetadata(attr=[schema_attr, input_attr])

def get_layers(
self,
Expand Down Expand Up @@ -126,14 +138,18 @@ def get_comm_coll_node(
node.attr.append(
ChakraAttr(name="comm_type",
int64_val=self.get_comm_type(comm_type)))
node.attr.append(
ChakraAttr(name="comm_size",
uint64_val = comm_size)
)
return node

def add_parent(
    self,
    child_node: Any,
    parent_node: Any
) -> None:
    """Record ``parent_node`` as a dependency of ``child_node``.

    Args:
        child_node: Node whose ``data_deps`` list receives the dependency.
        parent_node: Node whose ``id`` is appended to that list.
    """
    # The dependency is written exactly once, to ``data_deps``. The older
    # ``parent`` list is intentionally no longer written: appending to both
    # would duplicate the edge.
    # NOTE(review): assumes the updated node schema no longer exposes a
    # ``parent`` field -- confirm against the schema definition.
    child_node.data_deps.append(parent_node.id)

def convert(self) -> None:
with open(self.input_filename, "r") as f:
Expand Down Expand Up @@ -167,6 +183,8 @@ def convert_microbenchmark(
for npu_id in range(self.num_npus):
output_filename = "%s.%d.et" % (self.output_filename, npu_id)
with open(output_filename, "wb") as g:
global_metadata = self.get_global_metadata()
encode_message(g, global_metadata)
for i in range(self.num_passes):
for layer in layers:
bwd_wg_comm_node = self.get_comm_coll_node(
Expand All @@ -190,6 +208,8 @@ def convert_data_parallel(
for npu_id in range(self.num_npus):
output_filename = "%s.%d.et" % (self.output_filename, npu_id)
with open(output_filename, "wb") as g:
global_metadata = self.get_global_metadata()
encode_message(g, global_metadata)
for i in range(self.num_passes):
fwd_comp_node = None

Expand Down Expand Up @@ -252,6 +272,8 @@ def convert_model_parallel(
for npu_id in range(self.num_npus):
output_filename = "%s.%d.et" % (self.output_filename, npu_id)
with open(output_filename, "wb") as g:
global_metadata = self.get_global_metadata()
encode_message(g, global_metadata)
for i in range(self.num_passes):
fwd_comm_node = None

Expand Down Expand Up @@ -327,6 +349,8 @@ def convert_hybrid_data_model(
for npu_id in range(self.num_npus):
output_filename = "%s.%d.et" % (self.output_filename, npu_id)
with open(output_filename, "wb") as g:
global_metadata = self.get_global_metadata()
encode_message(g, global_metadata)
for i in range(self.num_passes):
fwd_comm_node = None

Expand Down Expand Up @@ -416,6 +440,8 @@ def convert_hybrid_model_data(
for npu_id in range(self.num_npus):
output_filename = "%s.%d.et" % (self.output_filename, npu_id)
with open(output_filename, "wb") as g:
global_metadata = self.get_global_metadata()
encode_message(g, global_metadata)
for i in range(self.num_passes):
fwd_comm_node = None

Expand Down Expand Up @@ -504,6 +530,8 @@ def convert_hybrid_dlrm(
for npu_id in range(self.num_npus):
output_filename = "%s.%d.et" % (self.output_filename, npu_id)
with open(output_filename, "wb") as g:
global_metadata = self.get_global_metadata()
encode_message(g, global_metadata)
for i in range(self.num_passes):
fwd_comp_node = None

Expand Down

0 comments on commit a92cbaa

Please sign in to comment.