Skip to content

Commit

Permalink
Properly track implicit dependance on __type__ in function calls (#…
Browse files Browse the repository at this point in the history
…5715)

When passing objects to functions, there is an implicit dependence on
the `__type__` field (which may be used for function overloading).
This dependence isn't tracked, so in inheritance cases a computed
that depends on it may be processed before `__type__` is inherited.

Fix this by making sure to get `__type__` through proper channels.

Fixes #5661.

In addition to actually fixing it, I added a hack that will make
`__type__` always be processed first, to work around not easily being
able to repair the expr refs in existing databases. We can drop the hack
on our dev branch immediately, though.
(I tested this flow manually.)
  • Loading branch information
msullivan committed Jul 6, 2023
1 parent 7a13070 commit becd952
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 6 deletions.
5 changes: 3 additions & 2 deletions edb/edgeql/compiler/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -837,8 +837,9 @@ def finalize_args(
):
arg_type_path_id = pathctx.extend_path_id(
arg.path_id,
ptrcls=arg_type.getptr(
ctx.env.schema, sn.UnqualName('__type__')),
ptrcls=setgen.resolve_ptr(
arg_type, '__type__', track_ref=None, ctx=ctx
),
ctx=ctx,
)
else:
Expand Down
16 changes: 13 additions & 3 deletions edb/schema/inheriting.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,11 +223,21 @@ def get_inherited_ref_layout(
rev_refs = {v: k for k, v in base_refs.items()}
base_refs = {
rev_refs[v]: v
for v in reversed(
sd.sort_by_cross_refs(schema, base_refs.values()))
for v in sd.sort_by_cross_refs(schema, base_refs.values())
}

for k, v in base_refs.items():
# HACK: Because of issue #5661, we previously did not always
# properly discover dependencies on __type__ in computeds.
# This was fixed, but it may persist in existing databases.
# Currently, expr refs are not compared when diffing schemas,
# so a schema repair can't fix this. Thus, in addition to
# actually fixing the bug, we hack around it by forcing
# __type__ to sort to the front.
# TODO: Drop this after cherry-picking.
if (tname := sn.UnqualName('__type__')) in base_refs:
base_refs[tname] = base_refs.pop(tname)

for k, v in reversed(base_refs.items()):
if not v.should_propagate(schema):
continue
if base == self.scls and not v.get_owned(schema):
Expand Down
43 changes: 43 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8080,6 +8080,49 @@ def test_schema_migrations_rewrites_01(self):
""",
])

def test_schema_migrations_implicit_type_01(self):
self._assert_migration_equivalence([
r"""
abstract type Pinnable {
property pinned := __source__ in <Pinnable>{};
}
""",
r"""
abstract type Pinnable {
property pinned := __source__ in <Pinnable>{};
}
type Foo extending Pinnable {}
""",
])

def test_schema_migrations_implicit_type_02(self):
self._assert_migration_equivalence([
r"""
abstract type Person {
multi link friends : Person{
constraint expression on (
__subject__ != __subject__@source
);
};
}
""",
r"""
abstract type Person {
multi link friends : Person{
constraint expression on (
__subject__ != __subject__@source
);
};
}
type Employee extending Person{
department: str;
}
""",
])


class TestDescribe(tb.BaseSchemaLoadTest):
"""Test the DESCRIBE command."""
Expand Down
6 changes: 5 additions & 1 deletion tests/test_sql_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,11 @@ async def test_sql_query_copy_01(self):
"Movie", output=out, format="csv", delimiter="\t"
)
out = io.StringIO(out.getvalue().decode("utf-8"))
names = set(row[6] for row in csv.reader(out, delimiter="\t"))
# FIXME(#5716): Once COPY and information_schema are
# harmonized to agree on the order of columns, we should query
# information_schema to get the column number instead of
# hardcoding it.
names = set(row[7] for row in csv.reader(out, delimiter="\t"))
self.assertEqual(names, {"Forrest Gump", "Saving Private Ryan"})

async def test_sql_query_error_01(self):
Expand Down

0 comments on commit becd952

Please sign in to comment.