Skip to content

Commit

Permalink
Merge pull request #8 from giannisdoukas/optional-list-type
Browse files Browse the repository at this point in the history
Fix Optional-Lists
  • Loading branch information
giannisdoukas committed Jun 30, 2020
2 parents fac9981 + 7b858a0 commit 6fc8195
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 57 deletions.
1 change: 1 addition & 0 deletions .gitignore
Expand Up @@ -246,3 +246,4 @@ tmp.py
/html/
cwlbuild
/tests/repo-like/result.yaml
/tests/repo-like/messages.txt
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -42,4 +42,4 @@ jupyter repo2cwl https://github.com/giannisdoukas/cwl-annotated-jupyter-notebook

### Docs

[https://ipython2cwl.readthedocs.io/](https://readthedocs.org/projects/ipython2cwl/badge/?version=latest)
[https://ipython2cwl.readthedocs.io/](https://ipython2cwl.readthedocs.io/en/latest/)
83 changes: 53 additions & 30 deletions ipython2cwl/cwltoolextractor.py
Expand Up @@ -22,27 +22,35 @@
SETUP_TEMPLATE = f.read()


# TODO: does not support recursion if main function exists
# TODO: check whether recursion is supported when a main function exists

class AnnotatedVariablesExtractor(ast.NodeTransformer):
input_type_mapper = {
CWLFilePathInput.__name__: (
(CWLFilePathInput.__name__,): (
'File',
'pathlib.Path',
),
CWLBooleanInput.__name__: (
(CWLBooleanInput.__name__,): (
'boolean',
'lambda flag: flag.upper() == "TRUE"',
),
CWLIntInput.__name__: (
(CWLIntInput.__name__,): (
'int',
'int',
),
CWLStringInput.__name__: (
(CWLStringInput.__name__,): (
'string',
'str',
),
}
input_type_mapper = {**input_type_mapper, **{
('List', *(t for t in types_names)): (types[0] + "[]", types[1])
for types_names, types in input_type_mapper.items()
}, **{
('Optional', *(t for t in types_names)): (types[0] + "?", types[1])
for types_names, types in input_type_mapper.items()
}}

output_type_mapper = {
CWLFilePathOutput.__name__
}
Expand All @@ -51,30 +59,29 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.extracted_nodes = []

def __get_annotation__(self, type_annotation):
    """Flatten a type-annotation AST node into a tuple of type names.

    Supported shapes:

    * a bare name, e.g. ``Foo`` -> ``('Foo',)``
    * a string annotation, e.g. ``"Foo"`` -> ``('Foo',)``; when the string
      itself parses to a subscript (``"List[Foo]"``) the parsed subscript
      is flattened instead
    * a subscript whose base is a bare name, e.g. ``List[Foo]`` ->
      ``('List', 'Foo')``; the subscripted part is flattened recursively

    :param type_annotation: the ``annotation`` node of an ``ast.AnnAssign``
        (or a nested node reached while recursing).
    :return: a tuple of names, or ``None`` when the annotation has a shape
        that is not supported (for example ``typing.List[Foo]`` or a
        subscript over several types such as ``Union[A, B]``).
    """
    if isinstance(type_annotation, ast.Name):
        return (type_annotation.id,)

    # String annotations: ``ast.Constant`` on Python >= 3.8, ``ast.Str`` on
    # older interpreters. The legacy class is matched by name so the
    # deprecated ``ast.Str`` attribute is never accessed on new Pythons.
    text = None
    if isinstance(type_annotation, ast.Constant):
        if isinstance(type_annotation.value, str):
            text = type_annotation.value
    elif type(type_annotation).__name__ == 'Str':
        text = type_annotation.s
    if text is not None:
        try:
            statements = ast.parse(text.strip()).body
        except (SyntaxError, ValueError):
            # Not valid Python: keep the raw text so the lookup in the
            # type mappers simply fails instead of raising.
            statements = []
        quoted = getattr(statements[0], 'value', None) if statements else None
        if isinstance(quoted, ast.Subscript):
            return self.__get_annotation__(quoted)
        return (text,)

    if isinstance(type_annotation, ast.Subscript):
        container = type_annotation.value
        if not isinstance(container, ast.Name):
            return None
        inner = type_annotation.slice
        # Python < 3.9 wraps the subscripted expression in ``ast.Index``;
        # newer versions store the expression directly in ``slice``.
        if type(inner).__name__ == 'Index':
            inner = inner.value
        nested = self.__get_annotation__(inner)
        if nested is None:
            return None
        return (container.id, *nested)

    return None

def visit_AnnAssign(self, node):
try:
if (isinstance(node.annotation, ast.Name) and node.annotation.id in self.input_type_mapper) or \
(isinstance(node.annotation, ast.Str) and node.annotation.s in self.input_type_mapper):
mapper = self.input_type_mapper[node.annotation.id]
annotation = self.__get_annotation__(node.annotation)
if annotation in self.input_type_mapper:
mapper = self.input_type_mapper[annotation]
self.extracted_nodes.append(
(node, mapper[0], mapper[1], True, True, False)
(node, mapper[0], mapper[1], not mapper[0].endswith('?'), True, False)
)
return None
elif isinstance(node.annotation, ast.Subscript):
if node.annotation.value.id == "Optional" \
and node.annotation.slice.value.id in self.input_type_mapper:
mapper = self.input_type_mapper[node.annotation.slice.value.id]
self.extracted_nodes.append(
(node, mapper[0] + '?', mapper[1], False, True, False)
)
return None
elif node.annotation.value.id == "List" \
and node.annotation.slice.value.id in self.input_type_mapper:
mapper = self.input_type_mapper[node.annotation.slice.value.id]
self.extracted_nodes.append(
(node, mapper[0] + '[]', mapper[1], True, True, False)
)
return None

elif (isinstance(node.annotation, ast.Name) and node.annotation.id in self.output_type_mapper) or \
(isinstance(node.annotation, ast.Str) and node.annotation.s in self.output_type_mapper):
self.extracted_nodes.append(
Expand All @@ -87,7 +94,7 @@ def visit_AnnAssign(self, node):
targets=[node.target],
value=node.value
)
except AttributeError:
except Exception:
pass
return node

Expand Down Expand Up @@ -152,6 +159,7 @@ def from_jupyter_notebook_node(cls, node: NotebookNode) -> 'AnnotatedIPython2CWL

@classmethod
def _wrap_script_to_method(cls, tree, variables) -> str:
add_args = cls.__get_add_arguments__([v for v in variables if v.is_input])
main_template_code = os.linesep.join([
f"def main({','.join([v.name for v in variables if v.is_input])}):",
"\tpass",
Expand All @@ -160,19 +168,34 @@ def _wrap_script_to_method(cls, tree, variables) -> str:
"import argparse",
'import pathlib',
"parser = argparse.ArgumentParser()",
*[f'parser.add_argument("--{variable.name}", '
f'type={variable.argparse_typeof}, '
f'required={variable.required})'
for variable in variables],
*add_args,
"args = parser.parse_args()",
f"main({','.join([f'{v.name}=args.{v.name}' for v in variables if v.is_input])})"
f"main({','.join([f'{v.name}=args.{v.name} ' for v in variables if v.is_input])})"
]],
])
main_function = ast.parse(main_template_code)
[node for node in main_function.body if isinstance(node, ast.FunctionDef) and node.name == 'main'][0] \
.body = tree.body
return astor.to_source(main_function)

@classmethod
def __get_add_arguments__(cls, variables):
    """Render one ``parser.add_argument(...)`` source line per variable.

    Each variable is expected to expose ``name``, ``argparse_typeof``,
    ``required`` and ``cwl_typeof``. A CWL type ending in ``[]`` adds
    ``nargs="+"`` and one ending in ``?`` adds ``default=None``. Every
    keyword, including the last one, is followed by a comma in the
    generated text.

    :param variables: iterable of variable descriptors.
    :return: list of source-code strings, in the order of ``variables``.
    """

    def render(variable):
        takes_many = variable.cwl_typeof.endswith('[]')
        may_be_missing = variable.cwl_typeof.endswith('?')
        pieces = [
            f'parser.add_argument("--{variable.name}",',
            f'type={variable.argparse_typeof},',
            f'required={variable.required},',
        ]
        if takes_many:
            pieces.append('nargs="+",')
        if may_be_missing:
            pieces.append('default=None,')
        return ' '.join(pieces) + ')'

    return [render(variable) for variable in variables]

def cwl_command_line_tool(self, docker_image_id: str = 'jn2cwl:latest') -> Dict:
"""
Creates the description of the CWL Command Line Tool.
Expand Down
29 changes: 13 additions & 16 deletions tests/repo-like/example1.ipynb
Expand Up @@ -2,30 +2,24 @@
"cells": [
{
"cell_type": "code",
"execution_count": 3,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from ipython2cwl.iotypes import CWLFilePathInput, CWLFilePathOutput\n",
"import yaml"
"from ipython2cwl.iotypes import CWLFilePathInput, CWLStringInput, CWLFilePathOutput\n",
"from typing import List\n",
"import yaml\n",
"import os"
]
},
{
"cell_type": "code",
"execution_count": 4,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'entry1': 1, 'entry2': 'foo', 'entry3': 'bar'}\n"
]
}
],
"outputs": [],
"source": [
"datafilename: CWLFilePathInput = 'data.yaml'\n",
"\n",
"messages: List[CWLStringInput] = ['hello', 'world']\n",
"with open(datafilename) as fd: \n",
" data = yaml.safe_load(fd)\n",
"print(data)"
Expand All @@ -42,13 +36,16 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results_filename: CWLFilePathOutput = 'result.yaml'\n",
"with open(results_filename, 'w') as fd:\n",
" yaml.safe_dump(data, fd)"
" yaml.safe_dump(data, fd)\n",
"messages_outputs: CWLFilePathOutput = 'messages.txt'\n",
"with open(messages_outputs, 'w') as f:\n",
" f.write(' '.join(messages))"
]
}
],
Expand Down
3 changes: 1 addition & 2 deletions tests/repo-like/requirements.txt
@@ -1,2 +1 @@
PyYAML==5.3.1
ipython2cwl==0.0.1
PyYAML==5.3.1
38 changes: 31 additions & 7 deletions tests/simple.ipynb
Expand Up @@ -2,27 +2,30 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from typing import List, Optional\n",
"import matplotlib\n",
"from ipython2cwl.iotypes import CWLFilePathInput, CWLFilePathOutput"
"from ipython2cwl.iotypes import CWLFilePathInput, CWLFilePathOutput, CWLStringInput"
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"dataset: CWLFilePathInput = 'example.csv'"
"dataset: CWLFilePathInput = './data/data.csv'\n",
"messages: List[CWLStringInput] = ['hello', 'world']\n",
"optional_message: Optional[CWLStringInput] = \"Hello from optional\""
]
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -36,17 +39,38 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# transform data\n",
"data.sort_values(by='Random B', ascending=False, inplace=True, ignore_index=True)\n",
"data.sort_values(by='Y', ascending=False, inplace=True, ignore_index=True)\n",
"fig = data.plot()\n",
"\n",
"after_transform_data: CWLFilePathOutput = 'new_data.png'\n",
"fig.figure.savefig(after_transform_data)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"messages_filename = 'messages.txt'\n",
"with open(messages_filename, 'w') as f:\n",
" f.write(' '.join(messages))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"if optional_message is not None:\n",
" print(optional_message)"
]
}
],
"metadata": {
Expand Down
49 changes: 49 additions & 0 deletions tests/test_cwltool.py → tests/test_cwltoolextractor.py
Expand Up @@ -328,3 +328,52 @@ def test_AnnotatedIPython2CWLToolConverter_exclamation_mark_command(self):
exec(new_script_without_magics)
locals()['main']('original\n!ls -l')
self.assertEqual('original\n!ls -l', globals()['printed_message'])

def test_AnnotatedIPython2CWLToolConverter_optional_array_input(self):
    """Quoted and unquoted annotations must yield equal variables.

    Also checks that annotations which are not recognised produce no
    extracted variables instead of raising.
    """

    def first_variable(*source_lines):
        converter = AnnotatedIPython2CWLToolConverter(os.linesep.join(source_lines))
        return converter._variables[0]

    # Each reference spelling is compared against every alternative
    # spelling of the same annotation; the extracted variable must match.
    equivalent_spellings = [
        (
            'x1: CWLBooleanInput = True',
            ['x1: "CWLBooleanInput" = True'],
        ),
        (
            'x1: Optional[CWLBooleanInput] = True',
            [
                'x1: "Optional[CWLBooleanInput]" = True',
                'x1: Optional["CWLBooleanInput"] = True',
            ],
        ),
    ]
    for reference, alternatives in equivalent_spellings:
        for alternative in alternatives:
            self.assertEqual(
                first_variable(reference),
                first_variable(alternative),
            )

    # Unknown or malformed hints must not crash the converter and must
    # not be extracted as variables.
    unrecognised_sources = [
        'x1: RandomHint = True',
        'x1: List[RandomHint] = True',
        'x1: List["RandomHint"] = True',
        'x1: "List[List[Union[RandomHint, Foo]]]" = True',
        'x1: "RANDOM CHARACTERS!!!!!!" = True',
    ]
    for source in unrecognised_sources:
        converter = AnnotatedIPython2CWLToolConverter(os.linesep.join([source]))
        self.assertListEqual([], converter._variables)

0 comments on commit 6fc8195

Please sign in to comment.