Skip to content

Commit

Permalink
add typing + fix small issues (#91)
Browse files Browse the repository at this point in the history
* add typing

* fix build
  • Loading branch information
leo-schick committed Feb 24, 2023
1 parent 3a9f74a commit 1743ec9
Show file tree
Hide file tree
Showing 25 changed files with 153 additions and 131 deletions.
8 changes: 4 additions & 4 deletions mara_pipelines/commands/bash.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""Commands for running bash scripts"""

from typing import Union, Callable
from typing import Union, Callable, List, Tuple

from mara_page import html
from .. import pipelines
Expand All @@ -18,13 +18,13 @@ def __init__(self, command: Union[str, Callable]) -> None:
self._command = command

@property
def command(self):
def command(self) -> str:
return (self._command() if callable(self._command) else self._command).strip()

def shell_command(self):
def shell_command(self) -> str:
return self.command

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [
('command', html.highlight_syntax(self.shell_command(), 'bash'))
]
23 changes: 12 additions & 11 deletions mara_pipelines/commands/files.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pathlib
import shlex
import sys
from typing import List, Tuple, Dict

import enum

Expand Down Expand Up @@ -76,7 +77,7 @@ def shell_command(self):
def mapper_file_path(self):
return self.parent.parent.base_path() / self.mapper_script_file_name

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [('file name', _.i[self.file_name]),
('compression', _.tt[self.compression]),
('mapper script file name', _.i[self.mapper_script_file_name]),
Expand All @@ -89,17 +90,17 @@ def html_doc_items(self) -> [(str, str)]:
('csv format', _.tt[self.csv_format]),
('skip header', _.tt[self.skip_header]),
('delimiter char',
_.tt[json.dumps(self.delimiter_char) if self.delimiter_char != None else None]),
('quote char', _.tt[json.dumps(self.quote_char) if self.quote_char != None else None]),
_.tt[json.dumps(self.delimiter_char) if self.delimiter_char is not None else None]),
('quote char', _.tt[json.dumps(self.quote_char) if self.quote_char is not None else None]),
('null value string',
_.tt[json.dumps(self.null_value_string) if self.null_value_string != None else None]),
_.tt[json.dumps(self.null_value_string) if self.null_value_string is not None else None]),
('time zone', _.tt[self.timezone]),
(_.i['shell command'], html.highlight_syntax(self.shell_command(), 'bash'))]


class ReadSQLite(sql._SQLCommand):
def __init__(self, sqlite_file_name: str, target_table: str,
sql_statement: str = None, sql_file_name: str = None, replace: {str: str} = None,
sql_statement: str = None, sql_file_name: str = None, replace: Dict[str, str] = None,
db_alias: str = None, timezone: str = None) -> None:
sql._SQLCommand.__init__(self, sql_statement, sql_file_name, replace)
self.sqlite_file_name = sqlite_file_name
Expand All @@ -115,10 +116,10 @@ def db_alias(self):
def shell_command(self):
return (sql._SQLCommand.shell_command(self)
+ ' | ' + mara_db.shell.copy_command(
mara_db.dbs.SQLiteDB(file_name=config.data_dir().absolute() / self.sqlite_file_name),
mara_db.dbs.SQLiteDB(file_name=pathlib.Path(config.data_dir()).absolute() / self.sqlite_file_name),
self.db_alias, self.target_table, timezone=self.timezone))

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [('sqlite file name', _.i[self.sqlite_file_name])] \
+ sql._SQLCommand.html_doc_items(self, None) \
+ [('target_table', _.tt[self.target_table]),
Expand Down Expand Up @@ -161,7 +162,7 @@ def shell_command(self):
def file_path(self):
return self.parent.parent.base_path() / self.file_name

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [('file name', _.i[self.file_name]),
(_.i['content'], html.highlight_syntax(self.file_path().read_text().strip('\n')
if self.file_name and self.file_path().exists()
Expand All @@ -170,9 +171,9 @@ def html_doc_items(self) -> [(str, str)]:
('target_table', _.tt[self.target_table]),
('db alias', _.tt[self.db_alias()]),
('delimiter char',
_.tt[json.dumps(self.delimiter_char) if self.delimiter_char != None else None]),
('quote char', _.tt[json.dumps(self.quote_char) if self.quote_char != None else None]),
_.tt[json.dumps(self.delimiter_char) if self.delimiter_char is not None else None]),
('quote char', _.tt[json.dumps(self.quote_char) if self.quote_char is not None else None]),
('null value string',
_.tt[json.dumps(self.null_value_string) if self.null_value_string != None else None]),
_.tt[json.dumps(self.null_value_string) if self.null_value_string is not None else None]),
('time zone', _.tt[self.timezone]),
(_.i['shell command'], html.highlight_syntax(self.shell_command(), 'bash'))]
6 changes: 4 additions & 2 deletions mara_pipelines/commands/http.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
"""Commands for interacting with HTTP"""

from typing import List, Tuple, Dict

from mara_page import html, _
from .. import pipelines
from ..shell import http_request_command


class HttpRequest(pipelines.Command):
def __init__(self, url: str, headers: {str: str} = None, method: str = None, body: str = None) -> None:
def __init__(self, url: str, headers: Dict[str, str] = None, method: str = None, body: str = None) -> None:
"""
Executes a HTTP request
Expand All @@ -25,7 +27,7 @@ def __init__(self, url: str, headers: {str: str} = None, method: str = None, bod
def shell_command(self):
return http_request_command(self.url, self.headers, self.method, self.body)

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [
('method', _.tt[self.method or 'GET']),
('url', _.tt[self.url]),
Expand Down
8 changes: 4 additions & 4 deletions mara_pipelines/commands/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import sys
import json
from html import escape
from typing import Union, Callable, List
from typing import Union, Callable, List, Optional, Tuple
from ..incremental_processing import file_dependencies
from ..logging import logger

Expand All @@ -24,7 +24,7 @@ class RunFunction(pipelines.Command):
Note:
if you want to pass arguments, then use a lambda function
"""
def __init__(self, function: Callable = None, args: [str] = None, file_dependencies: [str] = None) -> None:
def __init__(self, function: Optional[Callable] = None, args: Optional[List[str]] = None, file_dependencies: Optional[List[str]] = None) -> None:
self.function = function
self.args = args or []
self.file_dependencies = file_dependencies or []
Expand All @@ -48,7 +48,7 @@ def run(self) -> bool:

return True

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [('function', _.pre[escape(str(self.function))]),
('args', _.tt[repr(self.args)]),
(_.i['implementation'], html.highlight_syntax(inspect.getsource(self.function), 'python')),
Expand All @@ -65,7 +65,7 @@ class ExecutePython(pipelines.Command):
file_dependencies: Run triggered based on whether a list of files changed since the last pipeline run
"""
def __init__(self, file_name: Union[Callable, str],
args: Union[Callable, List[str]] = None, file_dependencies: [str] = None) -> None:
args: Optional[Union[Callable, List[str]]] = None, file_dependencies: Optional[List[str]] = None) -> None:
self._file_name = file_name
self._args = args or []
self.file_dependencies = file_dependencies or []
Expand Down
44 changes: 22 additions & 22 deletions mara_pipelines/commands/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import json
import pathlib
import shlex
from typing import Callable, Union
from typing import Callable, Union, Dict, Optional, List, Tuple

import mara_db.dbs
import mara_db.shell
Expand All @@ -24,8 +24,8 @@ class _SQLCommand(pipelines.Command):
sql_file_name: The name of the file to run (relative to the directory of the parent pipeline)
replace: A set of replacements to perform against the sql query `{'replace`: 'with', ..}`
"""
def __init__(self, sql_statement: Union[Callable, str] = None, sql_file_name: str = None,
replace: {str: str} = None) -> None:
def __init__(self, sql_statement: Optional[Union[Callable, str]] = None, sql_file_name: Optional[str] = None,
replace: Optional[Dict[str, str]] = None) -> None:
if (not (sql_statement or sql_file_name)) or (sql_statement and sql_file_name):
raise ValueError('Please provide either sql_statement or sql_file_name (but not both)')

Expand Down Expand Up @@ -87,9 +87,9 @@ class ExecuteSQL(_SQLCommand):
sql_file_name: The name of the file to run (relative to the directory of the parent pipeline)
replace: A set of replacements to perform against the sql query `{'replace`: 'with', ..}`
"""
def __init__(self, sql_statement: str = None, sql_file_name: Union[str, Callable] = None,
replace: {str: str} = None, file_dependencies=None, db_alias: str = None,
echo_queries: bool = None, timezone: str = None) -> None:
def __init__(self, sql_statement: Optional[Union[str, Callable]] = None, sql_file_name: Optional[str] = None,
replace: Optional[Dict[str, str]] = None, file_dependencies: Optional[List[str]] = None, db_alias: Optional[str] = None,
echo_queries: Optional[bool] = None, timezone: Optional[str] = None) -> None:
_SQLCommand.__init__(self, sql_statement, sql_file_name, replace)

self._db_alias = db_alias
Expand Down Expand Up @@ -146,10 +146,10 @@ def html_doc_items(self):
class Copy(_SQLCommand):
"""Loads data from an external database"""

def __init__(self, source_db_alias: str, target_table: str, target_db_alias: str = None,
sql_statement: str = None, sql_file_name: Union[Callable, str] = None, replace: {str: str} = None,
timezone: str = None, csv_format: bool = None, delimiter_char: str = None,
file_dependencies=None) -> None:
def __init__(self, source_db_alias: str, target_table: str, target_db_alias: Optional[str] = None,
sql_statement: Optional[Union[Callable, str]] = None, sql_file_name: Optional[str] = None, replace: Dict[str, str] = None,
timezone: Optional[str] = None, csv_format: Optional[bool] = None, delimiter_char: Optional[str] = None,
file_dependencies: Optional[List[str]]=None) -> None:
_SQLCommand.__init__(self, sql_statement, sql_file_name, replace)
self.source_db_alias = source_db_alias
self.target_table = target_table
Expand All @@ -160,7 +160,7 @@ def __init__(self, source_db_alias: str, target_table: str, target_db_alias: str
self.file_dependencies = file_dependencies or []

@property
def target_db_alias(self):
def target_db_alias(self) -> str:
return self._target_db_alias or config.default_db_alias()

def file_path(self) -> pathlib.Path:
Expand Down Expand Up @@ -193,12 +193,12 @@ def run(self) -> bool:
file_dependencies.update(self.node_path(), dependency_type, pipeline_base_path, self.file_dependencies)
return True

def shell_command(self):
def shell_command(self) -> str:
return _SQLCommand.shell_command(self) \
+ ' | ' + mara_db.shell.copy_command(self.source_db_alias, self.target_db_alias, self.target_table,
self.timezone, self.csv_format, self.delimiter_char)

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [('source db', _.tt[self.source_db_alias])] \
+ _SQLCommand.html_doc_items(self, self.source_db_alias) \
+ [('target db', _.tt[self.target_db_alias]),
Expand Down Expand Up @@ -237,12 +237,12 @@ class CopyIncrementally(_SQLCommand):
"""
def __init__(self, source_db_alias: str, source_table: str,
modification_comparison: str, comparison_value_placeholder: str,
target_table: str, primary_keys: [str],
sql_file_name: Union[str, Callable] = None, sql_statement: Union[str, Callable] = None,
target_db_alias: str = None, timezone: str = None, replace: {str: str} = None,
target_table: str, primary_keys: List[str],
sql_file_name: Optional[str] = None, sql_statement: Optional[Union[str, Callable]] = None,
target_db_alias: Optional[str] = None, timezone: Optional[str] = None, replace: Dict[str, str] = None,
use_explicit_upsert: bool = False,
csv_format: bool = None, delimiter_char: str = None,
modification_comparison_type: str = None) -> None:
csv_format: Optional[bool] = None, delimiter_char: Optional[str] = None,
modification_comparison_type: Optional[str] = None) -> None:
_SQLCommand.__init__(self, sql_statement, sql_file_name, replace)
self.source_db_alias = source_db_alias
self.source_table = source_table
Expand All @@ -259,7 +259,7 @@ def __init__(self, source_db_alias: str, source_table: str,
self.delimiter_char = delimiter_char

@property
def target_db_alias(self):
def target_db_alias(self) -> str:
return self._target_db_alias or config.default_db_alias()

def run(self) -> bool:
Expand Down Expand Up @@ -399,7 +399,7 @@ def _copy_command(self, target_table, replace):
target_table, timezone=self.timezone,
csv_format=self.csv_format, delimiter_char=self.delimiter_char))

def html_doc_items(self) -> [(str, str)]:
def html_doc_items(self) -> List[Tuple[str, str]]:
return [('source db', _.tt[self.source_db_alias]),
('source table', _.tt[self.source_table]),
('modification comparison', _.tt[self.modification_comparison])] \
Expand All @@ -415,13 +415,13 @@ def html_doc_items(self) -> [(str, str)]:
('use explicit upsert', _.tt[repr(self.use_explicit_upsert)])]


def _expand_pattern_substitution(replace: {str: str}) -> {str: str}:
def _expand_pattern_substitution(replace: Dict[str, str]) -> Dict[str, str]:
"""Helper function for replacing callables with their value in a dictionary"""
return {k: (str(v()) if callable(v) else str(v)) for k, v in replace.items()}


@functools.singledispatch
def _sql_syntax_higlighting_lexter(db):
def _sql_syntax_higlighting_lexter(db) -> str:
"""Returns the best lexer from http://pygments.org/docs/lexers/ for a database dialect"""
return 'sql'

Expand Down
4 changes: 2 additions & 2 deletions mara_pipelines/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def last_date() -> datetime.date:
return datetime.date(3000, 1, 1)


def max_number_of_parallel_tasks():
def max_number_of_parallel_tasks() -> int:
"""How many tasks can run in parallel at maximum"""
return multiprocessing.cpu_count()

Expand Down Expand Up @@ -90,7 +90,7 @@ def slack_token() -> typing.Optional[str]:


@functools.lru_cache(maxsize=None)
def event_handlers() -> [events.EventHandler]:
def event_handlers() -> typing.List[events.EventHandler]:
"""
Configure additional event handlers that listen to pipeline events, e.g. chat bots that announce failed runs
Expand Down
19 changes: 10 additions & 9 deletions mara_pipelines/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,17 @@
from multiprocessing import queues
from multiprocessing.context import BaseContext
from queue import Empty
from typing import Set, List, Dict, Optional

from . import pipelines, config
from .logging import logger, pipeline_events, system_statistics, run_log, node_cost
from . import events


def run_pipeline(pipeline: pipelines.Pipeline, nodes: {pipelines.Node} = None,
def run_pipeline(pipeline: pipelines.Pipeline, nodes: Optional[Set[pipelines.Node]] = None,
with_upstreams: bool = False,
interactively_started: bool = False
) -> [events.Event]:
) -> List[events.Event]:
"""
Runs a pipeline in a forked sub process. Acts as a generator that yields events from the sub process.
Expand Down Expand Up @@ -72,20 +73,20 @@ def run():
logger.redirect_output(event_queue, pipeline.path())

# all nodes that have not run yet, ordered by priority
node_queue: [pipelines.Node] = []
node_queue: List[pipelines.Node] = []

# data needed for computing cost
node_durations_and_run_times = node_cost.node_durations_and_run_times(pipeline) if use_historical_node_cost else {}

# Putting nodes into the node queue
def queue(nodes: [pipelines.Node]):
def queue(nodes: List[pipelines.Node]):
for node in nodes:
node_cost.compute_cost(node, node_durations_and_run_times)
node_queue.append(node)
node_queue.sort(key=lambda node: node.cost, reverse=True)

if nodes: # only run a set of child nodes
def with_all_upstreams(nodes: {pipelines.Node}):
def with_all_upstreams(nodes: Set[pipelines.Node]):
"""recursively find all upstreams of a list of nodes"""
return functools.reduce(set.union, [with_all_upstreams(node.upstreams) for node in nodes], nodes)

Expand All @@ -110,11 +111,11 @@ def with_all_upstreams(nodes: {pipelines.Node}):
# book keeping
run_start_time = datetime.datetime.now(tz.utc)
# all nodes that already ran or that won't be run anymore
processed_nodes: {pipelines.Node} = set()
processed_nodes: Set[pipelines.Node] = set()
# running pipelines with start times and number of running children
running_pipelines: {pipelines.Pipeline: [datetime.datetime, int]} = {}
failed_pipelines: {pipelines.Pipeline} = set() # pipelines with failed tasks
running_task_processes: {pipelines.Task: TaskProcess} = {}
running_pipelines: Dict[pipelines.Pipeline, [datetime.datetime, int]] = {}
failed_pipelines: Set[pipelines.Pipeline] = set() # pipelines with failed tasks
running_task_processes: Dict[pipelines.Task, TaskProcess] = {}

# make sure any running tasks are killed when this executor process is shutdown
executor_pid = os.getpid()
Expand Down
Loading

0 comments on commit 1743ec9

Please sign in to comment.