Merge pull request #106 from dlt-hub/rfix/renames-sources-dlt-init

dlt init renames sources and resources

rudolfix committed Dec 5, 2022
2 parents cfcc6a6 + 1892aff, commit af4ce68

Showing 10 changed files with 166 additions and 92 deletions.
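In short: for non-variant templates, `dlt init` now prefixes the sources and resources in the generated script with the pipeline name (see `if not is_variant` in `init_command.py` below). A minimal illustration of the naming rule; the pipeline and source names here are hypothetical:

```python
# The rename computes: new_name = pipeline_name + "_" + source_or_resource_name
# (see _find_source_calls_to_replace in dlt/cli/init_command.py below).
pipeline_name = "chess"        # hypothetical pipeline created with `dlt init chess bigquery`
source_name = "players_games"  # hypothetical @dlt.source function in the template

print(pipeline_name + "_" + source_name)  # chess_players_games
```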
66 changes: 44 additions & 22 deletions dlt/cli/init_command.py
@@ -3,6 +3,7 @@
 import ast
 import click
 import shutil
+from astunparse import unparse
 from types import ModuleType
 from typing import Dict, List, Tuple
 from importlib.metadata import version as pkg_version
@@ -27,10 +28,10 @@
 from dlt.cli.exceptions import CliCommandException


-def _find_argument_nodes_to_replace(visitor: PipelineScriptVisitor, replace_nodes: List[Tuple[str, str]], init_script_name: str) -> List[Tuple[ast.Constant, str, str]]:
+def _find_call_arguments_to_replace(visitor: PipelineScriptVisitor, replace_nodes: List[Tuple[str, str]], init_script_name: str) -> List[Tuple[ast.AST, ast.AST, str]]:
     # the input tuple (call argument name, replacement value)
     # the returned tuple (node, replacement value, node type)
-    transformed_nodes: List[Tuple[ast.Constant, str, str]] = []
+    transformed_nodes: List[Tuple[ast.AST, ast.AST, str]] = []
     known_calls: Dict[str, List[inspect.BoundArguments]] = visitor.known_calls
     for arg_name, calls in known_calls.items():
         for args in calls:
@@ -40,7 +41,7 @@ def _find_argument_nodes_to_replace(visitor: PipelineScriptVisitor, replace_node
             if not isinstance(dn_node, ast.Constant) or not isinstance(dn_node.value, str):
                 raise CliCommandException("init", f"The pipeline script {init_script_name} must pass the {t_arg_name} as string to '{arg_name}' function in line {dn_node.lineno}")
             else:
-                transformed_nodes.append((dn_node, t_value, t_arg_name))
+                transformed_nodes.append((dn_node, ast.Constant(value=t_value, kind=None), t_arg_name))

     # there was at least one replacement
     for t_arg_name, _ in replace_nodes:
@@ -49,16 +50,34 @@
     return transformed_nodes


-def _detect_required_configs(visitor: PipelineScriptVisitor, script_module: ModuleType, init_script_name: str) -> Tuple[Dict[str, WritableConfigValue], Dict[str, WritableConfigValue]]:
+def _find_source_calls_to_replace(visitor: PipelineScriptVisitor, pipeline_name: str) -> List[Tuple[ast.AST, ast.AST, str]]:
+    transformed_nodes: List[Tuple[ast.AST, ast.AST, str]] = []
+    for source_def in visitor.known_sources_resources.values():
+        # recreate function name as an ast.Name with known source code location
+        func_name = ast.Name(source_def.name)
+        func_name.lineno = func_name.end_lineno = source_def.lineno
+        func_name.col_offset = visitor.source_lines[func_name.lineno - 1].index(source_def.name)  # find where function name starts
+        func_name.end_col_offset = func_name.col_offset + len(source_def.name)
+        # append function name to be replaced
+        transformed_nodes.append((func_name, ast.Name(id=pipeline_name + "_" + source_def.name), ""))
+
+    for calls in visitor.known_sources_resources_calls.values():
+        for call in calls:
+            transformed_nodes.append((call.func, ast.Name(id=pipeline_name + "_" + unparse(call.func)), ""))
+
+    return transformed_nodes
+
+
+def _detect_required_configs(visitor: PipelineScriptVisitor) -> Tuple[Dict[str, WritableConfigValue], Dict[str, WritableConfigValue]]:
     # all detected secrets with namespaces
     required_secrets: Dict[str, WritableConfigValue] = {}
     # all detected configs with namespaces
     required_config: Dict[str, WritableConfigValue] = {}

     # skip sources without spec. those are not imported and most probably are inner functions. also skip the sources that are not called
     # also skip the sources that are called from functions, the parent of call object to the source must be None (no outer function)
-    known_imported_sources = {name: _SOURCES[name] for name in visitor.known_sources
-        if name in _SOURCES and name in visitor.known_source_calls and any(call.parent is None for call in visitor.known_source_calls[name])}  # type: ignore
+    known_imported_sources = {name: _SOURCES[name] for name in visitor.known_sources_resources
+        if name in _SOURCES and name in visitor.known_sources_resources_calls and any(call.parent is None for call in visitor.known_sources_resources_calls[name])}  # type: ignore

     for source_name, source_info in known_imported_sources.items():
         source_config = source_info.SPEC()
@@ -80,8 +99,7 @@ def _detect_required_configs(visitor: PipelineScriptVisitor, script_module: Modu
     return required_secrets, required_config


-def _rewrite_script(script_source: str, transformed_nodes: List[Tuple[ast.Constant, str, str]]) -> str:
-    module_source_lines: List[str] = ast._splitlines_no_ff(script_source)  # type: ignore
+def _rewrite_script(source_script_lines: List[str], transformed_nodes: List[Tuple[ast.AST, ast.AST, str]]) -> str:
     script_lines: List[str] = []
     last_line = -1
     last_offset = -1
@@ -91,24 +109,24 @@ def _rewrite_script(script_source: str, transformed_nodes: List[Tuple[ast.Consta
         if last_line != node.lineno - 1:
             # add remainder from the previous line
             if last_offset >= 0:
-                script_lines.append(module_source_lines[last_line][last_offset:])
+                script_lines.append(source_script_lines[last_line][last_offset:])
             # add all new lines from previous line to current
-            script_lines.extend(module_source_lines[last_line+1:node.lineno-1])
+            script_lines.extend(source_script_lines[last_line+1:node.lineno-1])
             # add trailing characters until node in current line starts
-            script_lines.append(module_source_lines[node.lineno-1][:node.col_offset])
+            script_lines.append(source_script_lines[node.lineno-1][:node.col_offset])
         elif last_offset >= 0:
             # no line change, add the characters from the end of previous node to the current
-            script_lines.append(module_source_lines[last_line][last_offset:node.col_offset])
+            script_lines.append(source_script_lines[last_line][last_offset:node.col_offset])

         # replace node value
-        script_lines.append(f'"{t_value}"')
+        script_lines.append(unparse(t_value).strip())
         last_line = node.end_lineno - 1
         last_offset = node.end_col_offset

     # add all that was missing
     if last_offset >= 0:
-        script_lines.append(module_source_lines[last_line][last_offset:])
-        script_lines.extend(module_source_lines[last_line+1:])
+        script_lines.append(source_script_lines[last_line][last_offset:])
+        script_lines.extend(source_script_lines[last_line+1:])

     dest_script = "".join(script_lines)
     # validate by parsing
@@ -138,8 +156,8 @@ def init_command(pipeline_name: str, destination_name: str, use_generic_template

     # get init script variant or the default
     init_script_name = os.path.join("variants", pipeline_name + ".py")
-    if clone_storage.has_file(init_script_name):
-        # use variant
+    is_variant = clone_storage.has_file(init_script_name)
+    if is_variant:
         dest_pipeline_script = pipeline_name + ".py"
         click.echo(f"Using a verified pipeline {fmt.bold(dest_pipeline_script)}")
         if use_generic_template:
@@ -176,30 +194,34 @@ def init_command(pipeline_name: str, destination_name: str, use_generic_template
         raise CliCommandException("init", f"The pipeline script {init_script_name} does not seem to initialize pipeline with dlt.pipeline. Please initialize the pipeline explicitly in init scripts.")

     # find all arguments in all calls to replace
-    transformed_nodes = _find_argument_nodes_to_replace(
+    transformed_nodes = _find_call_arguments_to_replace(
         visitor,
         [("destination", destination_name), ("pipeline_name", pipeline_name), ("dataset_name", pipeline_name + "_data")],
         init_script_name
     )

     # inspect the script
-    script_module = inspect_pipeline_script(clone_storage.storage_path, clone_storage.to_relative_path(init_script_name))
+    inspect_pipeline_script(clone_storage.storage_path, clone_storage.to_relative_path(init_script_name))

     if len(_SOURCES) == 0:
         raise CliCommandException("init", f"The pipeline script {init_script_name} is not creating or importing any sources or resources")

     for source_q_name, source_config in _SOURCES.items():
-        if source_q_name not in visitor.known_sources:
+        if source_q_name not in visitor.known_sources_resources:
             raise CliCommandException("init", f"The pipeline script {init_script_name} imports a source/resource {source_config.f.__name__} from module {source_config.module.__name__}. In init scripts you must declare all sources and resources in a single file.")

+    # rename sources and resources
+    if not is_variant:
+        transformed_nodes.extend(_find_source_calls_to_replace(visitor, pipeline_name))
+
     # detect all the required secrets and configs that should go into toml files
-    required_secrets, required_config = _detect_required_configs(visitor, script_module, init_script_name)
+    required_secrets, required_config = _detect_required_configs(visitor)
     # add destination spec to required secrets
     credentials_type = destination_spec().get_resolvable_fields()["credentials"]
     required_secrets["destinations:" + destination_name] = WritableConfigValue("credentials", credentials_type, ("destination", destination_name))

     # modify the script
-    dest_script_source = _rewrite_script(visitor.source, transformed_nodes)
+    dest_script_source = _rewrite_script(visitor.source_lines, transformed_nodes)

     # welcome message
     click.echo()
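The rewriting machinery above splices `unparse`d replacement nodes back into the original source lines by line and column offset, instead of unparsing the whole module, which would destroy the user's formatting and comments. A minimal, self-contained sketch of that splice-by-offset idea, using the stdlib `ast.unparse` (Python 3.9+) where the commit uses the `astunparse` package; the one-line script being rewritten is hypothetical:

```python
import ast

# a hypothetical one-line pipeline script to rewrite
source = 'pipeline = dlt.pipeline(pipeline_name="pipeline", destination="postgres")\n'
tree = ast.parse(source)

# locate the string constant passed to pipeline_name
target = next(
    node for node in ast.walk(tree)
    if isinstance(node, ast.Constant) and node.value == "pipeline"
)
replacement = ast.Constant(value="chess", kind=None)

# splice the unparsed replacement between the node's column offsets,
# leaving every other character of the line untouched
lines = source.splitlines(keepends=True)
line = lines[target.lineno - 1]
rewritten = line[:target.col_offset] + ast.unparse(replacement) + line[target.end_col_offset:]
print(rewritten, end="")  # pipeline = dlt.pipeline(pipeline_name='chess', destination="postgres")
```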
2 changes: 1 addition & 1 deletion dlt/cli/utils.py
@@ -53,7 +53,7 @@ def parse_init_script(command: str, script_source: str, init_script_name: str) -
     tree = ast.parse(source=script_source)
     set_ast_parents(tree)
     visitor = PipelineScriptVisitor(script_source)
-    visitor.visit(tree)
+    visitor.visit_passes(tree)
     if len(visitor.mod_aliases) == 0:
         raise CliCommandException(command, f"The pipeline script {init_script_name} does not import dlt and does not seem to run any pipelines")

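The switch from `visit` to `visit_passes` suggests the visitor now walks the tree more than once, for instance collecting source/resource definitions in one pass so a later pass can resolve their call sites. The real `PipelineScriptVisitor` API is not shown in this diff, so the sketch below is only a generic illustration of such a two-pass visitor:

```python
import ast
from typing import Dict, List

class TwoPassVisitor(ast.NodeVisitor):
    def __init__(self) -> None:
        self.pass_no = 0
        self.known_defs: Dict[str, ast.FunctionDef] = {}
        self.known_calls: List[ast.Call] = []

    def visit_passes(self, tree: ast.AST) -> None:
        # pass 1 collects definitions so pass 2 can match calls against them
        for self.pass_no in (1, 2):
            self.visit(tree)

    def visit_FunctionDef(self, node: ast.FunctionDef) -> None:
        if self.pass_no == 1:
            self.known_defs[node.name] = node
        self.generic_visit(node)

    def visit_Call(self, node: ast.Call) -> None:
        if self.pass_no == 2 and isinstance(node.func, ast.Name) and node.func.id in self.known_defs:
            self.known_calls.append(node)
        self.generic_visit(node)

visitor = TwoPassVisitor()
visitor.visit_passes(ast.parse("def src():\n    ...\n\nsrc()\n"))
print(list(visitor.known_defs), len(visitor.known_calls))  # ['src'] 1
```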
4 changes: 2 additions & 2 deletions dlt/destinations/bigquery/sql_client.py
@@ -13,7 +13,7 @@
 from dlt.common.configuration.specs import GcpClientCredentials

 from dlt.destinations.typing import DBCursor
-from dlt.destinations.sql_client import SqlClientBase, raise_database_error
+from dlt.destinations.sql_client import SqlClientBase, raise_database_error, raise_open_connection_error

 # terminal reasons as returned in BQ gRPC error response
 # https://cloud.google.com/bigquery/docs/error-messages
@@ -30,6 +30,7 @@ def __init__(self, dataset_name: str, credentials: GcpClientCredentials) -> None
         self.default_retry = bigquery.DEFAULT_RETRY.with_deadline(credentials.retry_deadline)
         self.default_query = bigquery.QueryJobConfig(default_dataset=self.fully_qualified_dataset_name())

+    @raise_open_connection_error
     def open_connection(self) -> None:
         self._client = bigquery.Client(
             self.credentials.project_id,
@@ -84,7 +85,6 @@ def drop_dataset(self) -> None:
             timeout=self.credentials.http_timeout
         )

-    # @raise_database_error
     def execute_sql(self, sql: AnyStr, *args: Any, **kwargs: Any) -> Optional[Sequence[Sequence[Any]]]:
         with self.execute_query(sql, *args, **kwargs) as curr:
             if not curr.description:
14 changes: 11 additions & 3 deletions dlt/destinations/exceptions.py
@@ -31,10 +31,18 @@ def __init__(self, dbapi_exception: Exception) -> None:
         super().__init__(dbapi_exception)


-class LoadClientNoConnection(DestinationTransientException):
-    def __init__(self, client_type: str) -> None:
+class DestinationConnectionError(DestinationTransientException):
+    def __init__(self, client_type: str, dataset_name: str, reason: str, inner_exc: Exception) -> None:
         self.client_type = client_type
-        super().__init__(f"Connection in sql client {client_type} is closed. Open the connection with 'client.open_connection' or with the 'with client:' statement")
+        self.dataset_name = dataset_name
+        self.inner_exc = inner_exc
+        super().__init__(f"Connection with {client_type} to dataset name {dataset_name} failed. Please check if you configured the credentials at all and provided the right credentials values. You may also be denied access, or your internet connection may be down. The actual reason given is: {reason}")
+
+class LoadClientNotConnected(DestinationTransientException):
+    def __init__(self, client_type: str, dataset_name: str) -> None:
+        self.client_type = client_type
+        self.dataset_name = dataset_name
+        super().__init__(f"Connection with {client_type} to dataset {dataset_name} is closed. Open the connection with 'client.open_connection' or with the 'with client:' statement")


 class DestinationSchemaWillNotUpdate(DestinationTerminalException):
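The split gives connection failures (`DestinationConnectionError`) and use-before-open mistakes (`LoadClientNotConnected`) distinct types and messages, and both now carry the dataset name. A small sketch of the use-before-open guard, mirroring the `_ensure_native_conn` change further down in this diff; `DemoSqlClient` is a stand-in, not dlt's API:

```python
class LoadClientNotConnected(Exception):  # stand-in for the real DestinationTransientException subclass
    def __init__(self, client_type: str, dataset_name: str) -> None:
        super().__init__(
            f"Connection with {client_type} to dataset {dataset_name} is closed. "
            "Open the connection with 'client.open_connection' or with the 'with client:' statement"
        )

class DemoSqlClient:
    dataset_name = "demo_data"
    native_connection = None  # nothing opened yet

    def _ensure_native_conn(self) -> None:
        if not self.native_connection:
            raise LoadClientNotConnected(type(self).__name__, self.dataset_name)

try:
    DemoSqlClient()._ensure_native_conn()
except LoadClientNotConnected as e:
    print(e)
```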
3 changes: 2 additions & 1 deletion dlt/destinations/postgres/sql_client.py
@@ -14,7 +14,7 @@

 from dlt.destinations.exceptions import DatabaseTerminalException, DatabaseTransientException, DatabaseUndefinedRelation
 from dlt.destinations.typing import DBCursor
-from dlt.destinations.sql_client import SqlClientBase, raise_database_error
+from dlt.destinations.sql_client import SqlClientBase, raise_database_error, raise_open_connection_error


 class Psycopg2SqlClient(SqlClientBase["psycopg2.connection"]):
@@ -32,6 +32,7 @@ def open_connection(self) -> None:
         # we'll provide explicit transactions see _reset
         self._reset_connection()

+    @raise_open_connection_error
    def close_connection(self) -> None:
         if self._conn:
             self._conn.close()
16 changes: 14 additions & 2 deletions dlt/destinations/sql_client.py
@@ -5,7 +5,7 @@
 from types import TracebackType
 from typing import Any, ContextManager, Generic, Iterator, Optional, Sequence, Type, AnyStr
 from dlt.common.typing import TFun
-from dlt.destinations.exceptions import LoadClientNoConnection
+from dlt.destinations.exceptions import DestinationConnectionError, LoadClientNotConnected

 from dlt.destinations.typing import TNativeConn, DBCursor

@@ -75,7 +75,7 @@ def with_alternative_dataset_name(self, dataset_name: str) -> Iterator["SqlClien

     def _ensure_native_conn(self) -> None:
         if not self.native_connection:
-            raise LoadClientNoConnection(type(self).__name__ + ":" + self.dataset_name)
+            raise LoadClientNotConnected(type(self).__name__, self.dataset_name)

     @staticmethod
     @abstractmethod
@@ -111,3 +111,15 @@ def _wrap(self: SqlClientBase[Any], *args: Any, **kwargs: Any) -> Any:
         return _wrap_gen  # type: ignore
     else:
         return _wrap  # type: ignore
+
+
+def raise_open_connection_error(f: TFun) -> TFun:
+
+    @wraps(f)
+    def _wrap(self: SqlClientBase[Any], *args: Any, **kwargs: Any) -> Any:
+        try:
+            return f(self, *args, **kwargs)
+        except Exception as ex:
+            raise DestinationConnectionError(type(self).__name__, self.dataset_name, str(ex), ex)
+
+    return _wrap  # type: ignore
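The new decorator converts any exception escaping a connection-related method into a `DestinationConnectionError` that keeps the original exception and names the client and dataset. A runnable sketch with simplified stand-ins (only the decorator body mirrors the diff; the client and exception classes here are illustrative):

```python
from functools import wraps

class DestinationConnectionError(Exception):  # simplified stand-in for the real exception
    def __init__(self, client_type: str, dataset_name: str, reason: str, inner_exc: Exception) -> None:
        self.inner_exc = inner_exc
        super().__init__(f"Connection with {client_type} to dataset {dataset_name} failed: {reason}")

def raise_open_connection_error(f):
    @wraps(f)
    def _wrap(self, *args, **kwargs):
        try:
            return f(self, *args, **kwargs)
        except Exception as ex:
            # normalize any driver-level failure into one transient error type
            raise DestinationConnectionError(type(self).__name__, self.dataset_name, str(ex), ex)
    return _wrap

class DemoSqlClient:
    dataset_name = "demo_data"

    @raise_open_connection_error
    def open_connection(self) -> None:
        raise OSError("network is unreachable")  # simulate a failing driver

try:
    DemoSqlClient().open_connection()
except DestinationConnectionError as e:
    print(e)            # Connection with DemoSqlClient to dataset demo_data failed: ...
    print(e.inner_exc)  # network is unreachable
```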
2 changes: 1 addition & 1 deletion dlt/pipeline/pipeline.py
@@ -118,7 +118,7 @@ class Pipeline:
     pipeline_name: str
     default_schema_name: str = None
     schema_names: List[str] = []
-    first_run: bool = None
+    first_run: bool = False
     """Indicates a first run of the pipeline, where run ends with successful loading of data"""
     full_refresh: bool
     must_attach_to_local_pipeline: bool
(diffs for the remaining 3 changed files did not load and are not shown)
