Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Hail] Fix several CSE bugs #7479

Merged
merged 12 commits into from Nov 14, 2019
@@ -85,10 +85,10 @@ def _eq(self, other):
def __hash__(self):
return 31 + hash(str(self))

@abc.abstractmethod
def new_block(self, i: int) -> bool:
...
return self.renderable_new_block(self.renderable_idx_of_child(i))

@abc.abstractmethod
def renderable_new_block(self, i: int) -> bool:
return self.new_block(i)

@@ -105,25 +105,46 @@ def bindings(self, i: int, default_value=None):
mapping from bound variables to 'default_value', if provided,
otherwise to their types
"""
return self.renderable_bindings(self.renderable_idx_of_child(i), default_value)

def renderable_bindings(self, i: int, default_value=None):
return {}

def agg_bindings(self, i: int, default_value=None):
return self.renderable_agg_bindings(self.renderable_idx_of_child(i), default_value)

def renderable_agg_bindings(self, i: int, default_value=None):
return {}

def scan_bindings(self, i: int, default_value=None):
return self.renderable_scan_bindings(self.renderable_idx_of_child(i), default_value)

def renderable_scan_bindings(self, i: int, default_value=None):
return {}

def uses_agg_context(self, i: int) -> bool:
return False
return self.renderable_uses_agg_context(self.renderable_idx_of_child(i))

def renderable_uses_agg_context(self, i: int) -> bool:
return self.uses_agg_context(i)
return False

def uses_scan_context(self, i: int) -> bool:
return False
return self.renderable_uses_scan_context(self.renderable_idx_of_child(i))

def renderable_uses_scan_context(self, i: int) -> bool:
return self.uses_scan_context(i)
return False

def renderable_idx_of_child(self, i: int) -> int:
return i

# Used as a variable, bound by any node which defines the meaning of
# aggregations (e.g. MatrixMapRows, AggFilter, etc.), and "referenced" by
# any node which performs aggregations (e.g. AggFilter, ApplyAggOp, etc.).
agg_capability = 'agg_capability'

@classmethod
def uses_agg_capability(cls) -> bool:
return False

def child_context_without_bindings(self, i: int, parent_context):
(eval_c, agg_c, scan_c) = parent_context
@@ -135,6 +156,9 @@ def child_context_without_bindings(self, i: int, parent_context):
return parent_context

def child_context(self, i: int, parent_context, default_value=None):
return self.renderable_child_context(self.renderable_idx_of_child(i), parent_context, default_value)

def renderable_child_context(self, i: int, parent_context, default_value=None):
base = self.child_context_without_bindings(i, parent_context)
eval_b = self.bindings(i, default_value)
agg_b = self.agg_bindings(i, default_value)
@@ -145,11 +169,26 @@ def child_context(self, i: int, parent_context, default_value=None):
else:
return base

@property
def free_vars(self):
return set()

@property
def free_agg_vars(self):
return set()

@property
def free_scan_vars(self):
return set()


class IR(BaseIR):
def __init__(self, *children):
super().__init__(*children)
self._aggregations = None
self._free_vars = None
self._free_agg_vars = None
self._free_scan_vars = None

@property
def aggregations(self):
@@ -191,7 +230,7 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
def renderable_new_block(self, i: int) -> bool:
return False

@abc.abstractmethod
@@ -204,6 +243,41 @@ def parse(self, code, ref_map={}, ir_map={}):
{k: t._parsable_string() for k, t in ref_map.items()},
ir_map)

@property
def free_vars(self):
def vars_from_child(i):
if self.uses_agg_context(i):
return self.children[i].free_agg_vars

This comment has been minimized.

Copy link
@tpoterba

tpoterba Nov 12, 2019

Collaborator

(let's dig into what's going on here, as we talked about)

if self.uses_scan_context(i):
return self.children[i].free_scan_vars
return self.children[i].free_vars.difference(self.bindings(i, 0).keys())

if self._free_vars is None:
self._free_vars = {
var for i in range(len(self.children))
for var in vars_from_child(i)}
if self.uses_agg_capability():
self._free_vars.add(BaseIR.agg_capability)
return self._free_vars

@property
def free_agg_vars(self):
if self._free_agg_vars is None:
self._free_agg_vars = {
var for i in range(len(self.children))
for var in self.children[i].free_agg_vars.difference(
self.agg_bindings(i, 0).keys())}
return self._free_agg_vars

@property
def free_scan_vars(self):
if self._free_scan_vars is None:
self._free_scan_vars = {
var for i in range(len(self.children))
for var in self.children[i].free_scan_vars.difference(
self.scan_bindings(i, 0).keys())}
return self._free_scan_vars


class TableIR(BaseIR):
def __init__(self, *children):
@@ -220,7 +294,7 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
def renderable_new_block(self, i: int) -> bool:
return True

def parse(self, code, ref_map={}, ir_map={}):
@@ -245,7 +319,7 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
def renderable_new_block(self, i: int) -> bool:
return True

def parse(self, code, ref_map={}, ir_map={}):
@@ -272,7 +346,7 @@ def typ(self):
assert self._type is not None, self
return self._type

def new_block(self, i: int) -> bool:
def renderable_new_block(self, i: int) -> bool:
return True

def parse(self, code, ref_map={}, ir_map={}):
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.