diff --git a/parallel-orch/partial_program_order.py b/parallel-orch/partial_program_order.py index 36fab2f8..afbb7786 100644 --- a/parallel-orch/partial_program_order.py +++ b/parallel-orch/partial_program_order.py @@ -23,43 +23,106 @@ def get_stdout_file(self): def __str__(self): return f'CompletedNodeInfo(ec:{self.get_exit_code()}, vf:{self.get_variable_file()}, stdout:{self.get_stdout_file()})' +## This class is used for both loop contexts and loop iters +## The indices go from inner to outer +class LoopStack: + def __init__(self, loop_contexts_or_iters=None): + if loop_contexts_or_iters is None: + self.loops = [] + else: + self.loops = loop_contexts_or_iters + + def is_empty(self): + return len(self.loops) == 0 + + def __len__(self): + return len(self.loops) + + ## Generates a new loop stack with the same length but 0s as values + def new_zeroed_loop_stack(self): + return [0 for i in self.loops] + + def get_outer(self): + return self.loops[-1] + + def pop_outer(self): + return self.loops.pop() + + def add_inner(self, loop_iter_id: int): + self.loops.insert(0, loop_iter_id) + + def outer_to_inner(self): + return self.loops[::-1] + + def index(self, loop_iter_id: int) -> int: + return self.loops.index(loop_iter_id) + + def get(self, index: int): + return self.loops[index] + + def __repr__(self): + ## TODO: Represent it using 'it', 'it0', 'it1', etc + ## or -(iters)- in front of it. 
+ output = "-".join([str(it) for it in self.loops]) + return output + + def __eq__(self, other): + if not len(self.loops) == len(other.loops): + return False + for i in range(len(self.loops)): + if not self.loops[i] == other.loops[i]: + return False + return True + + class NodeId: def __init__(self, id: int, loop_iters=None): self.id = id + if loop_iters is None: - self.loop_iters = [] + self.loop_iters = LoopStack() else: + assert(isinstance(loop_iters, LoopStack)) self.loop_iters = loop_iters def has_iters(self): - return len(self.loop_iters) > 0 + return not self.loop_iters.is_empty() + + def get_iters(self): + return copy.deepcopy(self.loop_iters) def get_non_iter_id(self): - return self.id + return NodeId(self.id) + + ## Returns a new NodeId + def generate_new_node_id_with_another_iter(self, new_iter: int): + ## This node already contains iterations for the outer loops potentially + ## so we just need to add another inner iteration + new_iters = copy.deepcopy(self.loop_iters) + new_iters.add_inner(new_iter) + + new_node_id = NodeId(self.id, new_iters) + return new_node_id def __repr__(self): - output = str(self.id) - if len(self.loop_iters) > 0: - output += f'+{"-".join([str(it) for it in self.loop_iters])}' + ## TODO: Represent it using n. 
+ output = f'{self.id}' + if not self.loop_iters.is_empty(): + output += f'+{self.loop_iters}' return output def __hash__(self): return hash(str(self)) def __eq__(self, other): - if not len(self.loop_iters) == len(other.loop_iters): - return False - for i in range(len(self.loop_iters)): - if not self.loop_iters[i] == other.loop_iters[i]: - return False - return self.id == other.id + return self.loop_iters == other.loop_iters and self.id == other.id def __ne__(self, other): # Not strictly necessary, but to avoid having both x==y and x!=y # True at the same time return not(self == other) - ## TODO: Maybe we need to make these better + ## TODO: Define this correctly if it is to be used for something other than dictionary indexing def __lt__(self, obj): return (str(self) < str(obj)) @@ -76,19 +139,19 @@ def parse_node_id(node_id_str: str) -> NodeId: if "+" in node_id_str: node_id_int, iters_str = node_id_str.split("+") iters = [int(it) for it in iters_str.split("-")] - return NodeId(int(node_id_int), iters) + return NodeId(int(node_id_int), LoopStack(iters)) else: - return NodeId(int(node_id_str)) + return NodeId(int(node_id_str), LoopStack()) class Node: - def __init__(self, id, cmd, loop_context): + def __init__(self, id, cmd, loop_context: LoopStack): self.cmd = cmd self.id = id self.cmd_no_redir = trace.remove_command_redir(self.cmd) self.loop_context = loop_context ## Keep track of how many iterations of this loop node we have unrolled - if len(loop_context) > 0: - self.current_iter = 0 + if not loop_context.is_empty(): + self.current_iters = loop_context.new_zeroed_loop_stack() def __str__(self): # return f"ID: {self.id}\nCMD: {self.cmd}\nR: {self.read_set}\nW: {self.write_set}" @@ -104,16 +167,19 @@ def get_cmd(self) -> str: def get_cmd_no_redir(self) -> str: return self.cmd_no_redir - def get_loop_context(self) -> "list[int]": + def get_loop_context(self) -> LoopStack: return self.loop_context def in_loop(self) -> bool: - return len(self.loop_context) > 0 + 
return not self.loop_context.is_empty() - def get_next_iter(self) -> int: + ## KK 2023-05-17 Does this generate the correct iteration even in nested loops? + def get_next_iter(self, loop_id: int) -> int: assert(self.in_loop()) - self.current_iter += 1 - return self.current_iter + assert(self.loop_context.get_outer() == loop_id) + loop_id_index_in_loop_context_stack = self.loop_context.index(loop_id) + self.current_iters[loop_id_index_in_loop_context_stack] += 1 + return self.current_iters[loop_id_index_in_loop_context_stack] ## Note: This information is valid only after a node is committed. ## It might be set even before that, but it should only be retrieved when @@ -199,6 +265,24 @@ def get_standard_source_nodes(self) -> list: source_nodes = self.get_source_nodes() return self.filter_standard_nodes(source_nodes) + ## This returns the minimum w.r.t. to the PO of a bunch of node_ids. + ## In a real partial order, this could be many, + def get_min(self, node_ids: "list[NodeId]") -> "list[NodeId]": + potential_minima = set(copy.deepcopy(node_ids)) + for node_id in node_ids: + tc = self.get_transitive_closure([node_id]) + ## Remove the node itself from its transitive closure + tc.remove(node_id) + ## If a node is found in the tc of another node, then + ## it is not a minimum + for nid in tc: + potential_minima.discard(nid) + ## KK 2023-05-22 This will be removed at some point but I keep it here + ## for now for easier bug finding. 
+ # logging.debug(f"Potential minima: {potential_minima}") + assert(len(potential_minima) == 1) + return list(potential_minima) + ## This returns all previous nodes of a sub partial order def get_sub_po_source_nodes(self, node_ids: "list[NodeId]") -> "list[NodeId]": # assert(self.is_closed_sub_partial_order(node_ids)) @@ -307,6 +391,9 @@ def valid(self): logging.debug("Checking partial order validity...") self.log_partial_program_order_info() valid1 = self.loop_nodes_valid() + ## TODO: Add a check that for x, y : NodeIds, x < y iff x is a predecessor to x + ## This is necessary due to the `hypothetical_before` method. + ## TODO: Fix the checks below because they do not work currently ## TODO: Check that committed is prefix closed w.r.t partial order # self.all_frontier_nodes_after_committed_nodes() @@ -345,7 +432,7 @@ def get_node(self, node_id:NodeId) -> Node: def is_node_id(self, node_id:NodeId) -> bool: return node_id in self.nodes - def get_node_loop_context(self, node_id: NodeId) -> "list[int]": + def get_node_loop_context(self, node_id: NodeId) -> LoopStack: return self.get_node(node_id).get_loop_context() def get_all_non_committed(self) -> "list[NodeId]": @@ -368,14 +455,23 @@ def is_loop_node(self, node_id:NodeId) -> bool: def filter_standard_nodes(self, node_ids: "list[NodeId]") -> "list[NodeId]": return [node_id for node_id in node_ids if not self.is_loop_node(node_id)] + + def filter_loop_nodes(self, node_ids: "list[NodeId]") -> "list[NodeId]": + return [node_id for node_id in node_ids + if self.is_loop_node(node_id)] ## This creates a new node_id and then creates a mapping from the node and iteration id to this node id ## TODO: Currently doesn't work with nested loops - def create_standard_id_from_loop_node(self, node_id: NodeId, loop_id: int) -> NodeId: + def create_node_id_with_one_less_loop_from_loop_node(self, node_id: NodeId, loop_id: int) -> NodeId: node = self.get_node(node_id) - new_iter = node.get_next_iter() - assert(not node_id.has_iters()) - 
return NodeId(node_id.id, [new_iter]) + logging.debug(f' >>> Node: {node}') + logging.debug(f' >>> its loops: {node.loop_context} --- {node.current_iters}') + + new_iter = node.get_next_iter(loop_id) + ## Creates a new node id where we have appended the new iter + new_node_id = node_id.generate_new_node_id_with_another_iter(new_iter) + logging.debug(f' >>> new node_id with another iter: {new_node_id}') + return new_node_id ## Returns all non committed non-loop nodes @@ -585,29 +681,109 @@ def rerun_stopped(self): self.to_be_resolved[cmd_id] = [] self.stopped = new_stopped - ## When the frontend sends a wait for a node, it means that execution in the frontend has - ## already surpassed all nodes prior to it. This is particularly important for loops, - ## since we can't always statically predict how many iterations they will do, so the only - ## definitive way to know that they are done is to receive a wait for a node after them. - def wait_received(self, node_id: NodeId): - ## Whenever we receive a wait for a node, we always need to check and "commit" all prior loop nodes - ## since we know that they won't have any more iterations (the JIT frontend has already passed them). - ## - ## TODO: This doesn't straightforwardly work for nested_loops, we need to figure out something else there - - ## Get inverse_transitive_closure to find all nodes that are before this one - inverse_tc_node_ids = self.get_inverse_transitive_closure([node_id]) + ## This method checks if nid1 would be before nid2 if nid2 was part of the PO. + ## + ## Therefore it does not just check edges, but rather computes if it would be before + ## based on ids and loop iterations. + ## + ## 1. Check if the loop ids of the two abstract parents of both nodes differ + ## thus showing that one is before the other + ## 2. If all loop ids are the same, now we can actually compare iterations. + ## If a node is in the same loop ids but in a later iteration then it is later. + ## 3. 
If all iterations are the same too, then we just compare node ids
+    ##
+    ## KK 2023-05-22 This is a complex procedure, I wonder if we can simplify it in some way
+    def hypothetical_before(self, nid1: NodeId, nid2: NodeId):
+        raw_id1 = nid1.get_non_iter_id()
+        ## Get all loop ids that nid1 could be in
+        loop_ids1 = self.get_node_loop_context(raw_id1)
+
+        raw_id2 = nid2.get_non_iter_id()
+        ## Get all loop ids that nid2 could be in
+        loop_ids2 = self.get_node_loop_context(raw_id2)
+
+        i = 0
+        while i < len(loop_ids1) and i < len(loop_ids2):
+            loop_id_1 = loop_ids1.get(len(loop_ids1) - 1 - i)
+            loop_id_2 = loop_ids2.get(len(loop_ids2) - 1 - i)
+            ## If the first node is in a previous loop than the second,
+            ## then we are done.
+            if loop_id_1 < loop_id_2:
+                return True
+            elif loop_id_1 > loop_id_2:
+                return False
+
+            ## We need to keep going
+            i += 1
+
+        ## If we reach this, we know that both nodes are in the same loops up to i
+        ## so we now compare iterations and node identifiers.
+
+        iters1 = nid1.get_iters()
+        iters2 = nid2.get_iters()
+
+        i = 0
+        while i < len(iters1) and i < len(iters2):
+            iter1 = iters1.get(len(iters1) - 1 - i)
+            iter2 = iters2.get(len(iters2) - 1 - i)
+            ## If the first node is in a previous iteration than the second,
+            ## then we are done.
+            if iter1 < iter2:
+                return True
+            elif iter1 > iter2:
+                return False
+            ## We need to keep going
+            i += 1
+
+        ## We now know that their common prefix of iterations is the same
+
+        ## Check if the node could potentially generate other nodes that are bigger
+        ## i.e., if it is more abstract. If so, then it is not smaller. 
+ common_loop_depth = min(len(loop_ids1), len(loop_ids2)) + abstract_depth1 = max(common_loop_depth - len(iters1), 0) + abstract_depth2 = max(common_loop_depth - len(iters2), 0) + if abstract_depth1 < abstract_depth2: + return True + elif abstract_depth1 > abstract_depth2: + return False + + return nid1.id < nid2.id + + + def progress_po_due_to_wait(self, node_id: NodeId): + logging.debug(f"Checking if we can progress the partial order after having received a wait for {node_id}") + ## The node might not be part of the partial order if it corresponds to + ## a loop node iteration. In this case, we just need to make sure that + ## we commit the right previous loop nodes that are relevant to it. + if not self.is_node_id(node_id): + ## TODO: This check is not correct currently, it works for now, but when we move to full partial orders it wont anymore, + ## due to the check happening with < in hypothetical before + logging.debug(f" > Node {node_id} is not part of the PO so we compute the nodes that would be before it...") + all_non_committed = self.get_all_non_committed() + all_non_committed_loop_nodes = self.filter_loop_nodes(all_non_committed) + non_committed_loop_nodes_that_would_be_predecessors = [n_id for n_id in all_non_committed_loop_nodes + if self.hypothetical_before(n_id, node_id)] + + new_committed_nodes = non_committed_loop_nodes_that_would_be_predecessors + + else: + logging.debug(f" > Node {node_id} is part of the PO so we just check its predecessors following the inverse edges...") + ## If the node is in the PO, then we can proceed normally and find its predecessors and commit them + + ## Get inverse_transitive_closure to find all nodes that are before this one + inverse_tc_node_ids = self.get_inverse_transitive_closure([node_id]) + + ## Out of those nodes, filter out the non-committed loop ones + non_committed_loop_nodes_in_inverse_tc = [node_id for node_id in inverse_tc_node_ids + if not self.is_committed(node_id) and + self.is_loop_node(node_id)] + 
logging.debug(f'Non committed loop nodes that are predecessors to {node_id} are: {non_committed_loop_nodes_in_inverse_tc}') + + new_committed_nodes = non_committed_loop_nodes_in_inverse_tc - ## Out of those nodes, filter out the non-committed loop ones - non_committed_loop_nodes_in_inverse_tc = [node_id for node_id in inverse_tc_node_ids - if not self.is_committed(node_id) and - self.is_loop_node(node_id)] - logging.debug(f'Non committed loop nodes that are predecessors to {node_id} are: {non_committed_loop_nodes_in_inverse_tc}') - ## And "close them" ## TODO: This is a hack here, we need to have a proper method that commits ## nodes and does whatever else is needed to do (e.g., add new nodes to frontier) - new_committed_nodes = non_committed_loop_nodes_in_inverse_tc logging.debug(f'Adding following loop nodes to committed: {new_committed_nodes}') for node_id in new_committed_nodes: self.commit_node(node_id) @@ -635,53 +811,71 @@ def wait_received(self, node_id: NodeId): ## since in many tests there is nothing new to resolve after a wait) self.resolve_commands_that_can_be_resolved_and_step_forward() - - def find_loop_sub_partial_order(self, loop_id: int) -> "list[NodeId]": + ## When the frontend sends a wait for a node, it means that execution in the frontend has + ## already surpassed all nodes prior to it. This is particularly important for loops, + ## since we can't always statically predict how many iterations they will do, so the only + ## definitive way to know that they are done is to receive a wait for a node after them. + def wait_received(self, node_id: NodeId): + ## Whenever we receive a wait for a node, we always need to check and "commit" all prior loop nodes + ## since we know that they won't have any more iterations (the JIT frontend has already passed them). 
+ + ## We first have to push and progress the PO due to the wait and then unroll + ## KK 2023-05-22 Currently this checks whether a still nonexistent node is + ## would be a successor of existing nodes to commit some of + ## them if needed. Unfortunately, to make this check for a non-existent + ## node is very complex and not elegant. + ## TODO: Could we swap unrolling and progressing so that we always + ## check if a node can be progressed by checking edges? + self.progress_po_due_to_wait(node_id) + + ## Unroll some nodes if needed. + if node_id.has_iters(): + ## TODO: This unrolling can also happen and be moved to speculation. + ## For now we are being conservative and that is why it only happens here + ## TODO: Move this to the scheduler.schedule_work() (if we have a loop node waiting for response and we are not unrolled, unroll to create work) + self.maybe_unroll(node_id) + + + def find_outer_loop_sub_partial_order(self, loop_id: int, nodes_subset: "list[NodeId]") -> "list[NodeId]": loop_node_ids = [] - for node_id in self.nodes: + for node_id in nodes_subset: loop_context = self.get_node_loop_context(node_id) - if loop_id in loop_context: + ## Note: this only checks for the nodes that have this loop id as their outer loop + if not loop_context.is_empty() and loop_id == loop_context.get_outer(): loop_node_ids.append(node_id) ## TODO: Assert that this is closed w.r.t. partial order return loop_node_ids - ## KK 2023-05-02: We should not be able to step/execute/speculate loop nodes, instead - ## the only action we should be able to do to them is to unroll them, - ## by creating iterations before them in the partial order. - ## The loop nodes then act as barriers that cannot be committed, executed (or put in the frontier) - ## and separate the already committed with the future partial order. 
- ## - ## Note: We have to be careful when unrolling loops to unroll a complete iteration to start with - ## (to not have to deal with partial order relations between commands of different iterations). - ## - ## - ## Concrete pseudocode: - ## def unroll(self, loop_id): - ## ## Finds all of the nodes in the same loop in the partial order - ## sub_po = self.find_loop_sub_partial_order(loop_id) - ## ## Find the previous node - ## previous_ids = self.find_prev_nodes(sub_po.first) - ## ## Create an iteration version (no loop nodes) of the sub_po - ## ## (be careful to not eliminate nested loop nodes) - ## sub_po_iter = create_iter(sub_po) - ## ## add the iteration between the loop and its previous node - ## self.add_po_between(sub_po_iter, previous_ids, sub_po.first) - ## - ## We need to determine when to call unroll. For now we can just do it if the frontier is empty - ## (which means that the next node of the frontier is a loop node). - def unroll_loop(self, loop_id: int): + + ## This function unrolls a single loop, by first finding all its nodes (they must be contiguous) and then creating new versions of them + ## that are concretized. Its second argument describes which subset of all partial order nodes we want to look at. + ## That is necessary because when unrolling nested loops, we might end up in a situation where we have unrolled the + ## outer loop, but some of the newly created nodes might still be loop nodes (so we might have loop nodes for the same loop in multiple locations). 
+ def unroll_single_loop(self, loop_id: int, nodes_subset: "list[NodeId]"): logging.info(f'Unrolling loop with id: {loop_id}') - loop_node_ids = self.find_loop_sub_partial_order(loop_id) + all_loop_node_ids = self.find_outer_loop_sub_partial_order(loop_id, nodes_subset) + + ## We don't want to unroll already committed nodes + loop_node_ids = [nid for nid in all_loop_node_ids + if not self.is_committed(nid)] + logging.debug(f'Node ids for loop: {loop_id} are: {loop_node_ids}') ## Create the new nodes and remap adjacencies accordingly node_mappings = {} for node_id in loop_node_ids: node = self.get_node(node_id) - new_loop_node_id = self.create_standard_id_from_loop_node(node_id, loop_id) + new_loop_node_id = self.create_node_id_with_one_less_loop_from_loop_node(node_id, loop_id) node_mappings[node_id] = new_loop_node_id + ## The new node has one less loop context than the previous one + node_loop_contexts = node.get_loop_context() + logging.debug(f'Node: {node_id} loop_contexts: {node_loop_contexts}') + assert(node_loop_contexts.get_outer() == loop_id) + new_node_loop_contexts = copy.deepcopy(node_loop_contexts) + new_node_loop_contexts.pop_outer() + ## Create the new node - self.nodes[new_loop_node_id] = Node(new_loop_node_id, node.cmd, []) + self.nodes[new_loop_node_id] = Node(new_loop_node_id, node.cmd, new_node_loop_contexts) self.executions[new_loop_node_id] = 0 logging.debug(f'New loop ids: {node_mappings}') @@ -726,15 +920,9 @@ def unroll_loop(self, loop_id: int): self.remove_edge(from_id=previous_id, to_id=old_nodes_source) - ## Add all new nodes to the workset (since they have to be tracked) - for _, new_node_id in node_mappings.items(): - self.workset.append(new_node_id) - - ## TODO: We need to correctly populate the resolved set of next commands - ## after unrolling the loop. 
- ## Return the new first node - return node_mappings[old_nodes_source] + ## Return the new first node and all node mappings + return node_mappings[old_nodes_source], node_mappings.values() ## Static method that just maps using a node mapping dictionary or leaves them as ## they are if not @@ -748,22 +936,76 @@ def map_using_mapping(node_ids: "list[NodeId]", mapping) -> "list[NodeId]": new_node_ids.append(new_id) return new_node_ids + ## This unrolls a sequence of loops by unrolling each loop outside-in + def unroll_loops(self, loop_contexts: LoopStack) -> NodeId: + logging.debug(f'Unrolling the following loops: {loop_contexts}') + + ## All new node_ids + all_new_node_ids = set() + relevant_node_ids = list(self.nodes.keys()) + for loop_ctx in loop_contexts.outer_to_inner(): + new_first_node_id, new_node_ids = self.unroll_single_loop(loop_ctx, relevant_node_ids) + logging.debug(f'New node ids after unrolling: {new_node_ids}') + ## Update all new nodes that we have added + all_new_node_ids.update(new_node_ids) + + ## Re-set the relevant node ids to only the new nodes (if we unrolled a big loop once, + ## we just want to look at those new unrolled nodes for the next unrolling). 
+ relevant_node_ids = new_node_ids - def unroll_loop_node(self, node_id: NodeId): - assert(self.is_loop_node(node_id)) - loop_context = self.get_node_loop_context(node_id) - ## TODO: First determine which exactly loop do we need to unroll - ## I am not sure if it is correct to just do the last one - ## I think it might be the difference between this and the previous node - ## - ## TODO: I actually think we have to unroll all loops - new_first_node_id = self.unroll_loop(loop_context[0]) + logging.debug(f' >>> Edges after unrolling : {self.adjacency}') + logging.debug(f' >>> Inv Edges after unrolling: {self.inverse_adjacency}') + + ## Add all new standard nodes to the workset (since they have to be tracked) + for new_node_id in all_new_node_ids: + if not self.is_loop_node(new_node_id): + self.workset.append(new_node_id) + + ## KK 2023-05-22 Do we need to correctly populate the resolved set of next commands + ## after unrolling the loop. + + return new_first_node_id + + ## This unrolls a loop given a target concrete node id + def unroll_loop_node(self, target_concrete_node_id: NodeId): + raw_node_id = target_concrete_node_id.get_non_iter_id() + assert(self.is_loop_node(raw_node_id)) + + logging.debug(f'Edges: {self.adjacency}') + + ## Find the closest non-committed successor with this node id + ## Note: This is necessary because we might need to unroll only a subset of the loops that a node is part of. + ## This is relevant when we have nested loops. 
+ all_non_committed = self.get_all_non_committed() + all_non_committed_loop_nodes = self.filter_loop_nodes(all_non_committed) + logging.debug(f'All non committed loop nodes: {all_non_committed_loop_nodes}') + source_node_ids = self.get_min(all_non_committed_loop_nodes) + ## Note: This assertion might not hold once we have actual partial orders + assert(len(source_node_ids) == 1) + node_id = source_node_ids[0] + logging.debug(f'Closest non-committed loop node successor with raw_id {raw_node_id} is: {node_id}') + loop_contexts = self.get_node_loop_context(node_id) + + + ## Unroll all loops that this node is in + new_first_node_id = self.unroll_loops(loop_contexts) ## TODO: This needs to change when we modify unrolling to happen speculatively too ## TODO: This needs to properly add the node to frontier and to resolve dictionary self.step_forward(self.get_committed()) self.frontier.append(new_first_node_id) + ## At the end of unrolling the target node must be part of the PO + assert(self.is_node_id(target_concrete_node_id)) + + + def maybe_unroll(self, node_id: NodeId) -> NodeId: + ## Only unrolls this node if it doesn't already exist in the PO + if not self.is_node_id(node_id): + self.unroll_loop_node(node_id) + + ## The node_id must be part of the PO after unrolling, otherwise we did something wrong + assert(self.is_node_id(node_id)) ## KK 2023-09-05 @Giorgo Do all of these steps need to be done at once, or are these methods ## meaningful even if called one by one? 
In general, I would like there to @@ -1129,7 +1371,7 @@ def parse_partial_program_order_from_file(file_path: str) -> PartialProgramOrder file_path = f'{cmds_directory}/{i}' cmd = parse_cmd_from_file(file_path) loop_ctx = loop_contexts[i] - nodes[NodeId(i)] = Node(NodeId(i), cmd, loop_ctx) + nodes[NodeId(i)] = Node(NodeId(i), cmd, LoopStack(loop_ctx)) edges = {NodeId(i) : [] for i in range(number_of_nodes)} for edge_line in edge_lines: diff --git a/parallel-orch/scheduler_server.py b/parallel-orch/scheduler_server.py index d2756cb1..d6f352c3 100644 --- a/parallel-orch/scheduler_server.py +++ b/parallel-orch/scheduler_server.py @@ -5,7 +5,7 @@ from util import * import config import sys -from partial_program_order import parse_partial_program_order_from_file, NodeId, parse_node_id +from partial_program_order import parse_partial_program_order_from_file, LoopStack, NodeId, parse_node_id ## ## A scheduler server @@ -70,16 +70,17 @@ def handle_init(self, input_cmd: str): self.partial_program_order = parse_partial_program_order_from_file(partial_order_file) self.partial_program_order.init_partial_order() - def __parse_wait(self, input_cmd: str): + def __parse_wait(self, input_cmd: str) -> NodeId: try: node_id_component, loop_iter_counter_component = input_cmd.rstrip().split("|") - node_id = NodeId(int(node_id_component.split(":")[1].rstrip())) + raw_node_id_int = int(node_id_component.split(":")[1].rstrip()) loop_counters_str = loop_iter_counter_component.split(":")[1].rstrip() if loop_counters_str == "None": - loop_counters = [] + node_id = NodeId(raw_node_id_int) else: loop_counters = [int(cnt) for cnt in loop_counters_str.split("-")] - return node_id, loop_counters + node_id = NodeId(raw_node_id_int, LoopStack(loop_counters)) + return node_id except: raise Exception(f'Parsing failure for line: {input_cmd}') @@ -87,19 +88,9 @@ def handle_wait(self, input_cmd: str, connection): assert(input_cmd.startswith("Wait")) ## We have received this message by the JIT, which waits 
for a node_id to ## finish execution. - raw_node_id, loop_counters = self.__parse_wait(input_cmd) - logging.debug(f'Scheduler: Received wait for node_id: {raw_node_id} with loop counters: {loop_counters}') - - if self.partial_program_order.is_loop_node(raw_node_id): - node_id = NodeId(raw_node_id.id, loop_counters) - if not self.partial_program_order.is_node_id(node_id): - ## TODO: This unrolling can also happen and be moved to speculation. - ## For now we are being conservative and that is why it only happens here - ## TODO: Move this to the scheduler.schedule_work() (if we have a loop node waiting for response and we are not unrolled, unroll to create work) - self.partial_program_order.unroll_loop_node(raw_node_id) - else: - ## If we are not in a loop, then the node id corresponds to the concrete node - node_id = raw_node_id + node_id = self.__parse_wait(input_cmd) + logging.debug(f'Scheduler: Received wait for node_id: {node_id}') + ## Inform the partial order that we received a wait for a node so that it can push loops ## forward and so on. diff --git a/test/test_scripts/test_loop.sh b/test/test_scripts/test_loop.sh index 075b752f..d8602fe5 100644 --- a/test/test_scripts/test_loop.sh +++ b/test/test_scripts/test_loop.sh @@ -1,10 +1,27 @@ echo hi -for i in 1 2 3; do +for i in 1 2; do echo hi1 - sleep 1 - echo hi2 + for j in 1 2 3; do + echo hi2 + sleep 0.5 + echo hi3 + done + echo hi4 done -echo hi3 +echo hi5 + +echo hi6 +for i in 1 2 3; do + echo hi7 + for j in 1 2; do + echo hi8 + sleep 0.5 + echo hi9 + done + echo hi10 +done +echo hi11 + ## Future loop tests must include: ## 1. A single loop with a single command without anything else