Skip to content

Commit

Permalink
Add optional default to generated params (#426)
Browse files Browse the repository at this point in the history
Co-authored-by: Paul Colomiets <paul@colomiets.name>
  • Loading branch information
fantix and tailhook authored May 26, 2023
1 parent bb7522c commit 21b024a
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 80 deletions.
5 changes: 4 additions & 1 deletion edgedb/codegen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def _generate(
kw_only = True
for el_name, el in dr.input_type.elements.items():
args[el_name] = self._generate_code_with_cardinality(
el.type, el_name, el.cardinality
el.type, el_name, el.cardinality, keyword_argument=True
)

if self._async:
Expand Down Expand Up @@ -502,6 +502,7 @@ def _generate_code_with_cardinality(
type_: typing.Optional[describe.AnyType],
name_hint: str,
cardinality: edgedb.Cardinality,
keyword_argument: bool = False,
):
rv = self._generate_code(type_, name_hint)
if cardinality == edgedb.Cardinality.AT_MOST_ONE:
Expand All @@ -510,6 +511,8 @@ def _generate_code_with_cardinality(
else:
self._imports.add("typing")
rv = f"typing.Optional[{rv}]"
if keyword_argument:
rv = f"{rv} = None"
return rv

def _find_name(self, name: str) -> str:
Expand Down
50 changes: 25 additions & 25 deletions tests/codegen/test-project2/generated_async_edgeql.py.assert
Original file line number Diff line number Diff line change
Expand Up @@ -165,55 +165,55 @@ async def my_query(
executor: edgedb.AsyncIOExecutor,
*,
a: uuid.UUID,
b: uuid.UUID | None,
b: uuid.UUID | None = None,
c: str,
d: str | None,
d: str | None = None,
e: bytes,
f: bytes | None,
f: bytes | None = None,
g: int,
h: int | None,
h: int | None = None,
i: int,
j: int | None,
j: int | None = None,
k: int,
l: int | None,
l: int | None = None,
m: float,
n: float | None,
n: float | None = None,
o: float,
p: float | None,
p: float | None = None,
q: bool,
r: bool | None,
r: bool | None = None,
s: datetime.datetime,
t: datetime.datetime | None,
t: datetime.datetime | None = None,
u: datetime.datetime,
v: datetime.datetime | None,
v: datetime.datetime | None = None,
w: datetime.date,
x: datetime.date | None,
x: datetime.date | None = None,
y: datetime.time,
z: datetime.time | None,
z: datetime.time | None = None,
aa: datetime.timedelta,
ab: datetime.timedelta | None,
ab: datetime.timedelta | None = None,
ac: int,
ad: int | None,
ad: int | None = None,
ae: edgedb.RelativeDuration,
af: edgedb.RelativeDuration | None,
af: edgedb.RelativeDuration | None = None,
ag: edgedb.DateDuration,
ah: edgedb.DateDuration | None,
ah: edgedb.DateDuration | None = None,
ai: edgedb.ConfigMemory,
aj: edgedb.ConfigMemory | None,
aj: edgedb.ConfigMemory | None = None,
ak: edgedb.Range[int],
al: edgedb.Range[int] | None,
al: edgedb.Range[int] | None = None,
am: edgedb.Range[int],
an: edgedb.Range[int] | None,
an: edgedb.Range[int] | None = None,
ao: edgedb.Range[float],
ap: edgedb.Range[float] | None,
ap: edgedb.Range[float] | None = None,
aq: edgedb.Range[float],
ar: edgedb.Range[float] | None,
ar: edgedb.Range[float] | None = None,
as_: edgedb.Range[datetime.datetime],
at: edgedb.Range[datetime.datetime] | None,
at: edgedb.Range[datetime.datetime] | None = None,
au: edgedb.Range[datetime.datetime],
av: edgedb.Range[datetime.datetime] | None,
av: edgedb.Range[datetime.datetime] | None = None,
aw: edgedb.Range[datetime.date],
ax: edgedb.Range[datetime.date] | None,
ax: edgedb.Range[datetime.date] | None = None,
) -> MyQueryResult:
return await executor.query_single(
"""\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,55 +91,55 @@ async def my_query(
executor: edgedb.AsyncIOExecutor,
*,
a: uuid.UUID,
b: typing.Optional[uuid.UUID],
b: typing.Optional[uuid.UUID] = None,
c: str,
d: typing.Optional[str],
d: typing.Optional[str] = None,
e: bytes,
f: typing.Optional[bytes],
f: typing.Optional[bytes] = None,
g: int,
h: typing.Optional[int],
h: typing.Optional[int] = None,
i: int,
j: typing.Optional[int],
j: typing.Optional[int] = None,
k: int,
l: typing.Optional[int],
l: typing.Optional[int] = None,
m: float,
n: typing.Optional[float],
n: typing.Optional[float] = None,
o: float,
p: typing.Optional[float],
p: typing.Optional[float] = None,
q: bool,
r: typing.Optional[bool],
r: typing.Optional[bool] = None,
s: datetime.datetime,
t: typing.Optional[datetime.datetime],
t: typing.Optional[datetime.datetime] = None,
u: datetime.datetime,
v: typing.Optional[datetime.datetime],
v: typing.Optional[datetime.datetime] = None,
w: datetime.date,
x: typing.Optional[datetime.date],
x: typing.Optional[datetime.date] = None,
y: datetime.time,
z: typing.Optional[datetime.time],
z: typing.Optional[datetime.time] = None,
aa: datetime.timedelta,
ab: typing.Optional[datetime.timedelta],
ab: typing.Optional[datetime.timedelta] = None,
ac: int,
ad: typing.Optional[int],
ad: typing.Optional[int] = None,
ae: edgedb.RelativeDuration,
af: typing.Optional[edgedb.RelativeDuration],
af: typing.Optional[edgedb.RelativeDuration] = None,
ag: edgedb.DateDuration,
ah: typing.Optional[edgedb.DateDuration],
ah: typing.Optional[edgedb.DateDuration] = None,
ai: edgedb.ConfigMemory,
aj: typing.Optional[edgedb.ConfigMemory],
aj: typing.Optional[edgedb.ConfigMemory] = None,
ak: edgedb.Range[int],
al: typing.Optional[edgedb.Range[int]],
al: typing.Optional[edgedb.Range[int]] = None,
am: edgedb.Range[int],
an: typing.Optional[edgedb.Range[int]],
an: typing.Optional[edgedb.Range[int]] = None,
ao: edgedb.Range[float],
ap: typing.Optional[edgedb.Range[float]],
ap: typing.Optional[edgedb.Range[float]] = None,
aq: edgedb.Range[float],
ar: typing.Optional[edgedb.Range[float]],
ar: typing.Optional[edgedb.Range[float]] = None,
as_: edgedb.Range[datetime.datetime],
at: typing.Optional[edgedb.Range[datetime.datetime]],
at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
au: edgedb.Range[datetime.datetime],
av: typing.Optional[edgedb.Range[datetime.datetime]],
av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
aw: edgedb.Range[datetime.date],
ax: typing.Optional[edgedb.Range[datetime.date]],
ax: typing.Optional[edgedb.Range[datetime.date]] = None,
) -> MyQueryResult:
return await executor.query_single(
"""\
Expand Down
50 changes: 25 additions & 25 deletions tests/codegen/test-project2/parpkg/subpkg/my_query_edgeql.py.assert
Original file line number Diff line number Diff line change
Expand Up @@ -82,55 +82,55 @@ def my_query(
executor: edgedb.Executor,
*,
a: uuid.UUID,
b: typing.Optional[uuid.UUID],
b: typing.Optional[uuid.UUID] = None,
c: str,
d: typing.Optional[str],
d: typing.Optional[str] = None,
e: bytes,
f: typing.Optional[bytes],
f: typing.Optional[bytes] = None,
g: int,
h: typing.Optional[int],
h: typing.Optional[int] = None,
i: int,
j: typing.Optional[int],
j: typing.Optional[int] = None,
k: int,
l: typing.Optional[int],
l: typing.Optional[int] = None,
m: float,
n: typing.Optional[float],
n: typing.Optional[float] = None,
o: float,
p: typing.Optional[float],
p: typing.Optional[float] = None,
q: bool,
r: typing.Optional[bool],
r: typing.Optional[bool] = None,
s: datetime.datetime,
t: typing.Optional[datetime.datetime],
t: typing.Optional[datetime.datetime] = None,
u: datetime.datetime,
v: typing.Optional[datetime.datetime],
v: typing.Optional[datetime.datetime] = None,
w: datetime.date,
x: typing.Optional[datetime.date],
x: typing.Optional[datetime.date] = None,
y: datetime.time,
z: typing.Optional[datetime.time],
z: typing.Optional[datetime.time] = None,
aa: datetime.timedelta,
ab: typing.Optional[datetime.timedelta],
ab: typing.Optional[datetime.timedelta] = None,
ac: int,
ad: typing.Optional[int],
ad: typing.Optional[int] = None,
ae: edgedb.RelativeDuration,
af: typing.Optional[edgedb.RelativeDuration],
af: typing.Optional[edgedb.RelativeDuration] = None,
ag: edgedb.DateDuration,
ah: typing.Optional[edgedb.DateDuration],
ah: typing.Optional[edgedb.DateDuration] = None,
ai: edgedb.ConfigMemory,
aj: typing.Optional[edgedb.ConfigMemory],
aj: typing.Optional[edgedb.ConfigMemory] = None,
ak: edgedb.Range[int],
al: typing.Optional[edgedb.Range[int]],
al: typing.Optional[edgedb.Range[int]] = None,
am: edgedb.Range[int],
an: typing.Optional[edgedb.Range[int]],
an: typing.Optional[edgedb.Range[int]] = None,
ao: edgedb.Range[float],
ap: typing.Optional[edgedb.Range[float]],
ap: typing.Optional[edgedb.Range[float]] = None,
aq: edgedb.Range[float],
ar: typing.Optional[edgedb.Range[float]],
ar: typing.Optional[edgedb.Range[float]] = None,
as_: edgedb.Range[datetime.datetime],
at: typing.Optional[edgedb.Range[datetime.datetime]],
at: typing.Optional[edgedb.Range[datetime.datetime]] = None,
au: edgedb.Range[datetime.datetime],
av: typing.Optional[edgedb.Range[datetime.datetime]],
av: typing.Optional[edgedb.Range[datetime.datetime]] = None,
aw: edgedb.Range[datetime.date],
ax: typing.Optional[edgedb.Range[datetime.date]],
ax: typing.Optional[edgedb.Range[datetime.date]] = None,
) -> MyQueryResult:
return executor.query_single(
"""\
Expand Down
17 changes: 13 additions & 4 deletions tests/test_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,13 @@ async def test_codegen(self):
for project in container.iterdir():
if project.name == "linked":
continue
cwd = td_path / project.name
shutil.copytree(project, cwd)
await self._test_codegen(env, cwd)
with self.subTest(msg=project.name):
cwd = td_path / project.name
shutil.copytree(project, cwd)
try:
await self._test_codegen(env, cwd)
except subprocess.CalledProcessError as e:
self.fail("Codegen failed: " + e.stdout.decode())

async def _test_codegen(self, env, cwd: pathlib.Path):
async def run(*args, extra_env=None):
Expand All @@ -67,6 +71,11 @@ async def run(*args, extra_env=None):
p.terminate()
await p.wait()
raise
else:
if p.returncode:
raise subprocess.CalledProcessError(
p.returncode, args, output=await p.stdout.read(),
)

cmd = env.get("EDGEDB_PYTHON_TEST_CODEGEN_CMD", "edgedb-py")
await run(
Expand All @@ -90,7 +99,7 @@ async def run(*args, extra_env=None):

for f in cwd.rglob("*.py"):
a = f.with_suffix(".py.assert")
self.assertEqual(f.read_text(), a.read_text())
self.assertEqual(f.read_text(), a.read_text(), msg=f.name)
for a in cwd.rglob("*.py.assert"):
f = a.with_suffix("")
self.assertTrue(f.exists(), f"{f} doesn't exist")

0 comments on commit 21b024a

Please sign in to comment.