Skip to content

Commit

Permalink
Fix more extension migration problems (#5626)
Browse files Browse the repository at this point in the history
 * Casts actually can get dropped now, so fix collection handling
 * We need to always create collections with if_not_exists and
   if_unused when computing deltas, since we can't see into the
   extension to know if they will do it.
 * Totally rework extension delete. Calling delta_objects from the
   actual DDL execution flow is *not* a good idea and we shouldn't
   do it.
 * Support putting extension modules in ext::
  • Loading branch information
msullivan committed Jun 9, 2023
1 parent 47dc6a0 commit 02d2119
Show file tree
Hide file tree
Showing 10 changed files with 108 additions and 73 deletions.
1 change: 1 addition & 0 deletions edb/lib/std/00-prelude.edgeql
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@


CREATE MODULE std;
CREATE MODULE ext;
3 changes: 2 additions & 1 deletion edb/pgsql/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6919,6 +6919,7 @@ def apply(
version = self.scls.get_version(schema)._asdict()
version['stage'] = version['stage'].name.lower()

ext_module = self.scls.get_ext_module(schema)
metadata = {
ext_id: {
'id': ext_id,
Expand All @@ -6928,7 +6929,7 @@ def apply(
'version': version,
'builtin': self.scls.get_builtin(schema),
'internal': self.scls.get_internal(schema),
'ext_module': str(self.scls.get_ext_module(schema)),
'ext_module': ext_module and str(ext_module),
'sql_extensions': list(self.scls.get_sql_extensions(schema)),
}
}
Expand Down
15 changes: 15 additions & 0 deletions edb/schema/casts.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,3 +457,18 @@ class AlterCast(CastCommand, sd.AlterObject[Cast]):

class DeleteCast(CastCommand, sd.DeleteObject[Cast]):
astnode = qlast.DropCast

def _delete_begin(
self,
schema: s_schema.Schema,
context: sd.CommandContext,
) -> s_schema.Schema:
schema = super()._delete_begin(schema, context)
if not context.canonical:
from_type = self.scls.get_from_type(schema)
if op := from_type.as_type_delete_if_dead(schema):
self.add_caused(op)
to_type = self.scls.get_to_type(schema)
if op := to_type.as_type_delete_if_dead(schema):
self.add_caused(op)
return schema
10 changes: 0 additions & 10 deletions edb/schema/delta.py
Original file line number Diff line number Diff line change
Expand Up @@ -2492,16 +2492,6 @@ def canonicalize_attributes(
"""
return schema

def canonicalize_attributes_recursively(
self,
schema: s_schema.Schema,
context: CommandContext,
) -> s_schema.Schema:
schema = self.canonicalize_attributes(schema, context)
for sub in self.get_subcommands(type=ObjectCommand):
schema = sub.canonicalize_attributes_recursively(schema, context)
return schema

def update_field_status(
self,
schema: s_schema.Schema,
Expand Down
67 changes: 31 additions & 36 deletions edb/schema/extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
from . import annos as s_anno
from . import casts as s_casts
from . import delta as sd
from . import functions as s_func
from . import modules as s_mod
from . import name as sn
from . import objects as so
from . import schema as s_schema
Expand Down Expand Up @@ -64,7 +66,7 @@ class ExtensionPackage(
)

ext_module = so.SchemaField(
str, default=None, coerce=True, compcoef=0.9)
str, default=None, compcoef=0.9)

@classmethod
def get_schema_class_displayname(cls) -> str:
Expand All @@ -88,6 +90,7 @@ class Extension(

package = so.SchemaField(
ExtensionPackage,
compcoef=0.0,
)


Expand Down Expand Up @@ -334,20 +337,20 @@ class DeleteExtension(

astnode = qlast.DropExtension

def _delete_begin(
def _canonicalize(
self,
schema: s_schema.Schema,
context: sd.CommandContext,
) -> s_schema.Schema:
module = self.scls.get_package(schema).get_ext_module(schema)
schema = super()._delete_begin(schema, context)
scls: Extension,
) -> List[sd.Command]:
commands = super()._canonicalize(schema, context, scls)

if context.canonical or not module:
return schema
module = scls.get_package(schema).get_ext_module(schema)

# If the extension included a module, delete everything in it.
from . import ddl as s_ddl
if not module:
return commands

# If the extension included a module, delete everything in it.
module_name = sn.UnqualName(module)

def _name_in_mod(name: sn.Name) -> bool:
Expand All @@ -366,32 +369,24 @@ def _name_in_mod(name: sn.Name) -> bool:
or _name_in_mod(obj.get_to_type(schema).get_name(schema))
):
drop = obj.init_delta_command(
schema,
sd.DeleteObject,
schema, sd.DeleteObject
)
self.add(drop)

def filt(schema: s_schema.Schema, obj: so.Object) -> bool:
return not _name_in_mod(obj.get_name(schema)) or obj == self.scls

# We handle deleting the module contents in a heavy-handed way:
# do a schema diff.
delta = s_ddl.delta_schemas(
schema, schema,
included_modules=[
sn.UnqualName(module),
],
schema_b_filters=[filt],
include_extensions=True,
linearize_delta=True,
)
# delta_schemas claims everything is canonical, because all
# objects should be present, and its main audience is dumping
# it out as an AST. But the results will be filled with
# shells, which we need to get rid of in order for certain
# reflection things to work (annotations, for one).
for sub in delta.get_subcommands(type=sd.ObjectCommand):
schema = sub.canonicalize_attributes_recursively(schema, context)
self.add(sub)
commands.append(drop)

return schema
# Delete everything in the module
for obj in schema.get_objects(
included_modules=(module_name,),
):
if not isinstance(obj, s_func.Parameter):
drop, _, _ = obj.init_delta_branch(
schema, context, sd.DeleteObject
)
commands.append(drop)

# We add the module delete directly as add_caused, since the sorting
# we do doesn't work.
module_obj = schema.get_global(s_mod.Module, module_name)

self.add_caused(module_obj.init_delta_command(schema, sd.DeleteObject))

return commands
12 changes: 10 additions & 2 deletions edb/schema/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,11 @@

from . import annos as s_anno
from . import delta as sd
from . import name as sn
from . import objects as so
from . import schema as s_schema

RESERVED_MODULE_NAMES = {
'ext',
'super',
}

Expand Down Expand Up @@ -66,6 +66,7 @@ def _validate_legal_command(
super()._validate_legal_command(schema, context)

last = str(self.classname)
enclosing = None
if '::' in str(self.classname):
enclosing, _, last = str(self.classname).rpartition('::')
if not schema.has_module(enclosing):
Expand All @@ -78,7 +79,14 @@ def _validate_legal_command(

if (
not context.stdmode and not context.testmode
and (modname := self.classname) in s_schema.STD_MODULES
and (
(modname := self.classname) in s_schema.STD_MODULES
or (
enclosing
and (modname := sn.UnqualName(enclosing))
in s_schema.STD_MODULES
)
)
):
raise errors.SchemaDefinitionError(
f'cannot {self._delta_action} {self.get_verbosename()}: '
Expand Down
1 change: 1 addition & 0 deletions edb/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
sn.UnqualName('pg'),
sn.UnqualName('std::_test'),
sn.UnqualName('fts'),
sn.UnqualName('ext'),
)

# Specifies the order of processing of files and directories in lib/
Expand Down
25 changes: 25 additions & 0 deletions edb/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,31 @@ def __init_subclass__(
_collection_impls[schema_name] = cls
cls._schema_name = schema_name

def as_create_delta(
self: CollectionTypeT,
schema: s_schema.Schema,
context: so.ComparisonContext,
) -> sd.ObjectCommand[CollectionTypeT]:
delta = super().as_create_delta(schema=schema, context=context)
assert isinstance(delta, sd.CreateObject)
if not isinstance(self, CollectionExprAlias):
delta.if_not_exists = True
return delta

def as_delete_delta(
self: CollectionTypeT,
*,
schema: s_schema.Schema,
context: so.ComparisonContext,
) -> sd.ObjectCommand[CollectionTypeT]:
delta = super().as_delete_delta(schema=schema, context=context)
assert isinstance(delta, sd.DeleteObject)
if not isinstance(self, CollectionExprAlias):
delta.if_exists = True
delta.if_unused = True
delta.canonical = False
return delta

@classmethod
def get_displayname_static(cls, name: s_name.Name) -> str:
if isinstance(name, s_name.QualName):
Expand Down
36 changes: 23 additions & 13 deletions tests/test_edgeql_ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15646,7 +15646,7 @@ async def _extension_test_02a(self):
''')

await self.con.execute('''
create scalar type vc5 extending varchar::varchar<5>;
create scalar type vc5 extending ext::varchar::varchar<5>;
create type X {
create property foo: vc5;
};
Expand All @@ -15656,13 +15656,16 @@ async def _extension_test_02a(self):
'''
describe scalar type vc5;
''',
['create scalar type default::vc5 extending varchar::varchar<5>;'],
[
'create scalar type default::vc5 '
'extending ext::varchar::varchar<5>;'
],
)
await self.assert_query_result(
'''
describe scalar type vc5 as sdl;
''',
['scalar type default::vc5 extending varchar::varchar<5>;'],
['scalar type default::vc5 extending ext::varchar::varchar<5>;'],
)

await self.assert_query_result(
Expand Down Expand Up @@ -15698,7 +15701,7 @@ async def _extension_test_02a(self):
"invalid scalar type argument",
):
await self.con.execute('''
create scalar type fail extending varchar::varchar<foo>;
create scalar type fail extending ext::varchar::varchar<foo>;
''')

async with self.assertRaisesRegexTx(
Expand All @@ -15714,12 +15717,12 @@ async def _extension_test_02a(self):
"incorrect number of arguments",
):
await self.con.execute('''
create scalar type yyy extending varchar::varchar<1, 2>;
create scalar type yyy extending ext::varchar::varchar<1, 2>;
''')

# If no params are specified, it just makes a normal scalar type
await self.con.execute('''
create scalar type vc extending varchar::varchar {
create scalar type vc extending ext::varchar::varchar {
create constraint expression on (false);
};
''')
Expand All @@ -15736,7 +15739,7 @@ async def _extension_test_02b(self):
START MIGRATION TO {
using extension varchar version "1.0";
module default {
scalar type vc5 extending varchar::varchar<5>;
scalar type vc5 extending ext::varchar::varchar<5>;
type X {
foo: vc5;
};
Expand Down Expand Up @@ -15784,27 +15787,34 @@ async def test_edgeql_ddl_extensions_02(self):
# Make an extension that wraps some of varchar
await self.con.execute('''
create extension package varchar VERSION '1.0' {
set ext_module := "varchar";
set ext_module := "ext::varchar";
set sql_extensions := [];
create module varchar;
create scalar type varchar::varchar {
create module ext::varchar;
create scalar type ext::varchar::varchar {
create annotation std::description := 'why are we doing this';
set id := <uuid>'26dc1396-0196-11ee-a005-ad0eaed0df03';
set sql_type := "varchar";
set sql_type_scheme := "varchar({__arg_0__})";
set num_params := 1;
};
create cast from varchar::varchar to std::str {
create cast from ext::varchar::varchar to std::str {
SET volatility := 'Immutable';
USING SQL CAST;
};
create cast from std::str to varchar::varchar {
create cast from std::str to ext::varchar::varchar {
SET volatility := 'Immutable';
USING SQL CAST;
};
# This is meaningless but I need to test having an array in a cast.
create cast from ext::varchar::varchar to array<std::float32> {
SET volatility := 'Immutable';
USING SQL $$
select array[0.0]
$$
};
create abstract index varchar::with_param(
create abstract index ext::varchar::with_param(
named only lists: int64
) {
set code := ' ((__col__) NULLS FIRST)';
Expand Down
11 changes: 0 additions & 11 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,17 +817,6 @@ def test_schema_module_reserved_01(self):
}
"""

@tb.must_fail(
errors.SchemaDefinitionError,
"module 'ext' is a reserved module name"
)
def test_schema_module_reserved_02(self):
"""
module foo {
module ext {}
}
"""

@tb.must_fail(
errors.InvalidFunctionDefinitionError,
r"cannot create the `test::foo\(VARIADIC bar: "
Expand Down

0 comments on commit 02d2119

Please sign in to comment.