Skip to content

Commit

Permalink
Spec loop support (#670)
Browse files Browse the repository at this point in the history
* Add some notes on how to handle spec-loops

* Add some infrastructure for spec loop support

* some refactoring of spec support

* fix a bug

* fix a bug

* refactor po serialization

* nit

* nit

* nit

* serialization

* A comment on how to address loops

* Add wrappers to track loop id and iteration in the runtime

* checkpoint

* Pass loop iteration counters to scheduler from compiler

* Reset the loop iter variables once we exit the loop

* whitespace to trigger ci
  • Loading branch information
angelhof committed May 11, 2023
1 parent 4835c2e commit 8e48fb6
Show file tree
Hide file tree
Showing 6 changed files with 242 additions and 57 deletions.
10 changes: 10 additions & 0 deletions compiler/env_var_names.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@

##
## Variable names used in the pash runtime
##

def loop_iters_var() -> str:
return 'pash_loop_iters'

def loop_iter_var(loop_id: int) -> str:
return f'pash_loop_{loop_id}_iter'
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,13 @@ pash_redir_output echo "$$: (1) Bash variables saved in: $pash_runtime_shell_var
## Once the scheduler determines if there are environment changes, it can then
## decide to rerun or not the speculated commands with the new environment.


## Determine all current loop iterations and send them to the scheduler
pash_loop_iter_counters=${pash_loop_iters:-None}
pash_redir_output echo "$$: Loop node iteration counters: $pash_loop_iter_counters"

## Send and receive from daemon
msg="Wait:${pash_speculative_command_id}"
msg="Wait:${pash_speculative_command_id}|Loop iters:${pash_loop_iter_counters}"
daemon_response=$(pash_spec_communicate_scheduler "$msg") # Blocking step, daemon will not send response until it's safe to continue

## Receive an exit code
Expand Down
4 changes: 2 additions & 2 deletions compiler/preprocessor/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def preprocess_asts(ast_objects, args):
po_file=args.partial_order_file)
util_spec.initialize(trans_options)
else:
trans_options = ast_to_ast.TransformationOptions(mode=trans_mode)
trans_options = ast_to_ast.TransformationState(mode=trans_mode)

## Preprocess ASTs by replacing AST regions with calls to PaSh's runtime.
## Then the runtime will do the compilation and optimization with additional
Expand All @@ -54,7 +54,7 @@ def preprocess_asts(ast_objects, args):
## TODO: We could stream the partial_order_file to the scheduler
if trans_mode is ast_to_ast.TransformationType.SPECULATIVE:
## First complete the partial_order file
util_spec.save_number_of_nodes(trans_options)
util_spec.serialize_partial_order(trans_options)

## Then inform the scheduler that it can read it
unix_socket_file = os.getenv("PASH_SPEC_SCHEDULER_SOCKET")
Expand Down
138 changes: 116 additions & 22 deletions compiler/shell_ast/ast_to_ast.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from enum import Enum
import copy
import pickle

import config

from env_var_names import *
from shell_ast.ast_util import *
from parse import from_ast_objects_to_shell
from speculative import util_spec
Expand All @@ -14,27 +16,79 @@ class TransformationType(Enum):

## Use this object to pass options inside the preprocessing
## trasnformation.
class TransformationOptions:
class TransformationState:
def __init__(self, mode: TransformationType):
self.mode = mode
self.node_counter = 0
self.loop_counter = 0
self.loop_contexts = []

def get_mode(self):
return self.mode

## Node id related
def get_next_id(self):
new_id = self.node_counter
self.node_counter += 1
return new_id

def get_current_id(self):
return self.node_counter - 1

def get_number_of_ids(self):
return self.node_counter

## Loop id related
def get_next_loop_id(self):
new_id = self.loop_counter
self.loop_counter += 1
return new_id

def get_current_loop_context(self):
## We want to copy that
return self.loop_contexts[:]

def get_current_loop_id(self):
if len(self.loop_contexts) == 0:
return None
else:
return self.loop_contexts[0]

def enter_loop(self):
new_loop_id = self.get_next_loop_id()
self.loop_contexts.insert(0, new_loop_id)
return new_loop_id

def exit_loop(self):
self.loop_contexts.pop(0)


## TODO: Turn it into a Transformation State class, and make a subclass for
## each of the two transformations. It is important for it to be state, because
## it will need to be passed around while traversing the tree.
class SpeculativeTransformationState(TransformationOptions):
class SpeculativeTransformationState(TransformationState):
def __init__(self, mode: TransformationType, po_file: str):
super().__init__(mode)
assert(self.mode is TransformationType.SPECULATIVE)
self.partial_order_file = po_file
self.partial_order_edges = []
self.partial_order_node_loop_contexts = {}

def get_partial_order_file(self):
assert(self.mode is TransformationType.SPECULATIVE)
return self.partial_order_file

def add_edge(self, from_id: int, to_id: int):
self.partial_order_edges.append((from_id, to_id))

def get_all_edges(self):
return self.partial_order_edges

def add_node_loop_context(self, node_id: int, loop_contexts):
self.partial_order_node_loop_contexts[node_id] = loop_contexts

def get_all_loop_contexts(self):
return self.partial_order_node_loop_contexts


##
Expand Down Expand Up @@ -247,7 +301,6 @@ def preprocess_node_command(ast_node, trans_options, last_object=False):
def preprocess_node_redir(ast_node, trans_options, last_object=False):
preprocessed_node, something_replaced = preprocess_close_node(ast_node.node, trans_options,
last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.node = preprocessed_node
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
Expand Down Expand Up @@ -279,7 +332,6 @@ def preprocess_node_background(ast_node, trans_options, last_object=False):
def preprocess_node_subshell(ast_node, trans_options, last_object=False):
preprocessed_body, something_replaced = preprocess_close_node(ast_node.body, trans_options,
last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.body = preprocessed_body
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
Expand All @@ -293,20 +345,57 @@ def preprocess_node_subshell(ast_node, trans_options, last_object=False):
## TODO: This is not efficient at all since it calls the PaSh runtime everytime the loop is entered.
## We have to find a way to improve that.
def preprocess_node_for(ast_node, trans_options, last_object=False):
## If we are in a loop, we push the loop identifier into the loop context
loop_id = trans_options.enter_loop()
preprocessed_body, something_replaced = preprocess_close_node(ast_node.body, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.body = preprocessed_body
preprocessed_ast_object = PreprocessedAST(ast_node,

## TODO: Then send this iteration identifier when talking to the spec scheduler
## TODO: After running checks put this behind a check to only run under speculation

## Create a new variable that tracks loop iterations
var_name = loop_iter_var(loop_id)
export_node = make_export_var_constant_string(var_name, '0')
increment_node = make_increment_var(var_name)

## Also store the whole sequence of loop iters in a file
all_loop_ids = trans_options.get_current_loop_context()

## export pash_loop_iters="$pash_loop_XXX_iter $pash_loop_YYY_iter ..."
save_loop_iters_node = export_pash_loop_iters_for_current_context(all_loop_ids)

## Prepend the increment in the body
ast_node.body = make_semi_sequence([increment_node, save_loop_iters_node, copy.deepcopy(preprocessed_body)])

## We pop the loop identifier from the loop context.
##
## KK 2023-04-27: Could this exit happen before the replacement leading to wrong
## results? I think not because we use the _close_node preprocessing variant.
## A similar issue might happen for while
trans_options.exit_loop()

## reset the loop iters after we exit the loop
out_of_loop_loop_ids = trans_options.get_current_loop_context()
reset_loop_iters_node = export_pash_loop_iters_for_current_context(out_of_loop_loop_ids)

## Prepend the export in front of the loop
# new_node = ast_node
new_node = AstNode(make_semi_sequence([export_node, ast_node, reset_loop_iters_node]))
# print(new_node)

preprocessed_ast_object = PreprocessedAST(new_node,
replace_whole=False,
non_maximal=False,
something_replaced=something_replaced,
last_ast=last_object)

return preprocessed_ast_object

def preprocess_node_while(ast_node, trans_options, last_object=False):
## If we are in a loop, we push the loop identifier into the loop context
trans_options.enter_loop()

preprocessed_test, sth_replaced_test = preprocess_close_node(ast_node.test, trans_options, last_object=last_object)
preprocessed_body, sth_replaced_body = preprocess_close_node(ast_node.body, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.test = preprocessed_test
ast_node.body = preprocessed_body
something_replaced = sth_replaced_test or sth_replaced_body
Expand All @@ -315,13 +404,15 @@ def preprocess_node_while(ast_node, trans_options, last_object=False):
non_maximal=False,
something_replaced=something_replaced,
last_ast=last_object)

## We pop the loop identifier from the loop context.
trans_options.exit_loop()
return preprocessed_ast_object

## This is the same as the one for `For`
def preprocess_node_defun(ast_node, trans_options, last_object=False):
## TODO: For now we don't want to compile function bodies
# preprocessed_body = preprocess_close_node(ast_node.body)
## TODO: Could there be a problem with the in-place update
# ast_node.body = preprocessed_body
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
Expand All @@ -337,7 +428,6 @@ def preprocess_node_semi(ast_node, trans_options, last_object=False):
## TODO: Is it valid that only the right one is considered the last command?
preprocessed_left, sth_replaced_left = preprocess_close_node(ast_node.left_operand, trans_options, last_object=False)
preprocessed_right, sth_replaced_right = preprocess_close_node(ast_node.right_operand, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.left_operand = preprocessed_left
ast_node.right_operand = preprocessed_right
sth_replaced = sth_replaced_left or sth_replaced_right
Expand All @@ -354,7 +444,6 @@ def preprocess_node_and(ast_node, trans_options, last_object=False):
# preprocessed_left, should_replace_whole_ast, is_non_maximal = preprocess_node(ast_node.left, irFileGen, config)
preprocessed_left, sth_replaced_left = preprocess_close_node(ast_node.left_operand, trans_options, last_object=last_object)
preprocessed_right, sth_replaced_right = preprocess_close_node(ast_node.right_operand, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.left_operand = preprocessed_left
ast_node.right_operand = preprocessed_right
sth_replaced = sth_replaced_left or sth_replaced_right
Expand All @@ -369,7 +458,6 @@ def preprocess_node_or(ast_node, trans_options, last_object=False):
# preprocessed_left, should_replace_whole_ast, is_non_maximal = preprocess_node(ast_node.left, irFileGen, config)
preprocessed_left, sth_replaced_left = preprocess_close_node(ast_node.left_operand, trans_options, last_object=last_object)
preprocessed_right, sth_replaced_right = preprocess_close_node(ast_node.right_operand, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.left_operand = preprocessed_left
ast_node.right_operand = preprocessed_right
sth_replaced = sth_replaced_left or sth_replaced_right
Expand All @@ -383,7 +471,6 @@ def preprocess_node_or(ast_node, trans_options, last_object=False):
def preprocess_node_not(ast_node, trans_options, last_object=False):
# preprocessed_left, should_replace_whole_ast, is_non_maximal = preprocess_node(ast_node.left)
preprocessed_body, sth_replaced = preprocess_close_node(ast_node.body, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.body = preprocessed_body
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
Expand All @@ -398,7 +485,6 @@ def preprocess_node_if(ast_node, trans_options, last_object=False):
preprocessed_cond, sth_replaced_cond = preprocess_close_node(ast_node.cond, trans_options, last_object=last_object)
preprocessed_then, sth_replaced_then = preprocess_close_node(ast_node.then_b, trans_options, last_object=last_object)
preprocessed_else, sth_replaced_else = preprocess_close_node(ast_node.else_b, trans_options, last_object=last_object)
## TODO: Could there be a problem with the in-place update
ast_node.cond = preprocessed_cond
ast_node.then_b = preprocessed_then
ast_node.else_b = preprocessed_else
Expand All @@ -418,7 +504,6 @@ def preprocess_case(case, trans_options, last_object=False):
def preprocess_node_case(ast_node, trans_options, last_object=False):
preprocessed_cases_replaced = [preprocess_case(case, trans_options, last_object=last_object) for case in ast_node.cases]
preprocessed_cases, sth_replaced_cases = list(zip(*preprocessed_cases_replaced))
## TODO: Could there be a problem with the in-place update
ast_node.cases = preprocessed_cases
preprocessed_ast_object = PreprocessedAST(ast_node,
replace_whole=False,
Expand Down Expand Up @@ -469,12 +554,14 @@ def replace_df_region(asts, trans_options, disable_parallel_pipelines=False, ast
script_file.write(text_to_output)
replaced_node = make_call_to_pash_runtime(ir_filename, sequential_script_file_name, disable_parallel_pipelines)
elif transformation_mode is TransformationType.SPECULATIVE:
## TODO: This currently writes each command on its own line,
## though it should be improved to better serialize each command in its own file
## and then only saving the ids of each command in the partial order file.
text_to_output = get_shell_from_ast(asts, ast_text=ast_text)
## Generate an ID
df_region_id = util_spec.get_next_id()
df_region_id = trans_options.get_next_id()

## Get the current loop id and save it so that the runtime knows
## which loop it is in.
loop_id = trans_options.get_current_loop_id()

## Determine its predecessors
## TODO: To make this properly work, we should keep some state
## in the AST traversal to be able to determine predecessors.
Expand All @@ -485,7 +572,7 @@ def replace_df_region(asts, trans_options, disable_parallel_pipelines=False, ast
## Write to a file indexed by its ID
util_spec.save_df_region(text_to_output, trans_options, df_region_id, predecessors)
## TODO: Add an entry point to spec through normal PaSh
replaced_node = make_call_to_spec_runtime(df_region_id)
replaced_node = make_call_to_spec_runtime(df_region_id, loop_id)
else:
## Unreachable
assert(False)
Expand Down Expand Up @@ -545,10 +632,17 @@ def make_call_to_pash_runtime(ir_filename, sequential_script_file_name,
return runtime_node

## TODO: Make that an actual call to the spec runtime
def make_call_to_spec_runtime(command_id: int) -> AstNode:

def make_call_to_spec_runtime(command_id: int, loop_id) -> AstNode:
assignments = [["pash_spec_command_id",
string_to_argument(str(command_id))]]
if loop_id is None:
loop_id_str = ""
else:
loop_id_str = str(loop_id)

assignments.append(["pash_spec_loop_id",
string_to_argument(loop_id_str)])

## Call the runtime
arguments = [string_to_argument("source"),
string_to_argument(config.RUNTIME_EXECUTABLE)]
Expand Down

0 comments on commit 8e48fb6

Please sign in to comment.