Skip to content

Commit

Permalink
Merge pull request #332 from illusional/optional-codegen-args
Browse files Browse the repository at this point in the history
Optional codegen args
  • Loading branch information
mr-c committed May 21, 2020
2 parents 8ac0a15 + 09b805e commit ea0518b
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 13 deletions.
13 changes: 11 additions & 2 deletions schema_salad/codegen.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
"""Generate langauge specific loaders for a particular SALAD schema."""
import sys
from io import open
from typing import Any, Dict, List, MutableMapping, Optional
from typing import Any, Dict, List, MutableMapping, MutableSequence, Optional

from . import schema
from .codegen_base import CodeGenBase
Expand Down Expand Up @@ -63,8 +63,16 @@ def codegen(
document_roots.append(rec["name"])

field_names = []
optional_fields = set()
for field in rec.get("fields", []):
field_names.append(shortname(field["name"]))
field_name = shortname(field["name"])
field_names.append(field_name)
tp = field["type"]
if (
isinstance(tp, MutableSequence)
and tp[0] == "https://w3id.org/cwl/salad#null"
):
optional_fields.add(field_name)

idfield = ""
for field in rec.get("fields", []):
Expand All @@ -78,6 +86,7 @@ def codegen(
rec.get("abstract", False),
field_names,
idfield,
optional_fields,
)
gen.add_vocab(shortname(rec["name"]), rec["name"])

Expand Down
3 changes: 2 additions & 1 deletion schema_salad/codegen_base.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Base class for the generation of loaders from schema-salad definitions."""
import collections
from typing import Any, Dict, List, MutableSequence, Optional, Union
from typing import Any, Dict, List, MutableSequence, Optional, Union, Set

from . import schema

Expand Down Expand Up @@ -76,6 +76,7 @@ def begin_class(
abstract: bool,
field_names: MutableSequence[str],
idfield: str,
optional_fields: Set[str],
) -> None:
"""Produce the header for the given class."""
raise NotImplementedError()
Expand Down
12 changes: 11 additions & 1 deletion schema_salad/java_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,16 @@
import string
from io import StringIO
from io import open as io_open
from typing import Any, Dict, List, MutableMapping, MutableSequence, Optional, Union
from typing import (
Any,
Dict,
List,
MutableMapping,
MutableSequence,
Optional,
Union,
Set,
)
from urllib.parse import urlsplit

import pkg_resources
Expand Down Expand Up @@ -154,6 +163,7 @@ def begin_class(
abstract: bool,
field_names: MutableSequence[str],
idfield: str,
optional_fields: Set[str],
) -> None:
cls = self.interface_name(classname)
self.current_class = cls
Expand Down
41 changes: 32 additions & 9 deletions schema_salad/python_codegen.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
"""Python code generator for a given schema salad definition."""
from io import StringIO
from typing import IO, Any, Dict, List, MutableMapping, MutableSequence, Optional, Union
from typing import (
IO,
Any,
Dict,
List,
MutableMapping,
MutableSequence,
Optional,
Union,
Set,
)

from pkg_resources import resource_stream

Expand Down Expand Up @@ -74,12 +84,13 @@ def prologue(self):

def begin_class(
self, # pylint: disable=too-many-arguments
classname, # type: str
extends, # type: MutableSequence[str]
doc, # type: str
abstract, # type: bool
field_names, # type: MutableSequence[str]
idfield, # type: str
classname: str,
extends: MutableSequence[str],
doc: str,
abstract: bool,
field_names: MutableSequence[str],
idfield: str,
optional_fields: Set[str],
): # type: (...) -> None
classname = self.safe_name(classname)

Expand All @@ -102,11 +113,21 @@ def begin_class(
self.out.write(" pass\n\n\n")
return

required_field_names = [f for f in field_names if f not in optional_fields]
optional_field_names = [f for f in field_names if f in optional_fields]

safe_inits = [" self,"] # type: List[str]
safe_inits.extend(
[
" {}, # type: Any".format(self.safe_name(f))
for f in field_names
for f in required_field_names
if f != "class"
]
)
safe_inits.extend(
[
" {}=None, # type: Any".format(self.safe_name(f))
for f in optional_field_names
if f != "class"
]
)
Expand Down Expand Up @@ -237,10 +258,12 @@ def end_class(self, classname, field_names):
" attrs = frozenset({attrs})\n".format(attrs=field_names)
)

safe_inits = [
safe_init_fields = [
self.safe_name(f) for f in field_names if f != "class"
] # type: List[str]

safe_inits = [f + "=" + f for f in safe_init_fields]

safe_inits.extend(
["extension_fields=extension_fields", "loadingOptions=loadingOptions"]
)
Expand Down

0 comments on commit ea0518b

Please sign in to comment.