Skip to content

Commit

Permalink
Merge pull request #545 from great-expectations/feature/cli_add_datas…
Browse files Browse the repository at this point in the history
…ource_resilience_201907

Improved error detection and handling in CLI "add datasource" feature
  • Loading branch information
eugmandel committed Jul 16, 2019
2 parents 5ce4d49 + a23e3df commit cb01acc
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 53 deletions.
4 changes: 4 additions & 0 deletions great_expectations/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,10 @@ def init(target_directory):
)

data_source_name = add_datasource(context)

if not data_source_name: # no datasource was created
return

cli_message(
"""
========== Profiling ==========
Expand Down
80 changes: 53 additions & 27 deletions great_expectations/cli/datasource.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import os
import click

from .util import cli_message
from great_expectations.render import DefaultJinjaPageView
from great_expectations.exceptions import DatasourceInitializationError

from great_expectations.version import __version__ as __version__


Expand All @@ -27,7 +28,7 @@ def add_datasource(context):
msg_prompt_filesys_enter_base_path,
# default='/data/',
type=click.Path(
exists=False,
exists=True,
file_okay=False,
dir_okay=True,
readable=True
Expand Down Expand Up @@ -56,34 +57,56 @@ def add_datasource(context):
data_source_name = click.prompt(
msg_prompt_datasource_name, default="mydb", show_default=True)

cli_message(msg_sqlalchemy_config_connection.format(
data_source_name))
while True:
cli_message(msg_sqlalchemy_config_connection.format(
data_source_name))

drivername = click.prompt("What is the driver for the sqlalchemy connection?", default="postgres",
show_default=True)
host = click.prompt("What is the host for the sqlalchemy connection?", default="localhost",
show_default=True)
port = click.prompt("What is the port for the sqlalchemy connection?", default="5432",
show_default=True)
username = click.prompt("What is the username for the sqlalchemy connection?", default="postgres",
drivername = click.prompt("What is the driver for the sqlalchemy connection?", default="postgres",
show_default=True)
host = click.prompt("What is the host for the sqlalchemy connection?", default="localhost",
show_default=True)
password = click.prompt("What is the password for the sqlalchemy connection?", default="",
show_default=False, hide_input=True)
database = click.prompt("What is the database name for the sqlalchemy connection?", default="postgres",
port = click.prompt("What is the port for the sqlalchemy connection?", default="5432",
show_default=True)
username = click.prompt("What is the username for the sqlalchemy connection?", default="postgres",
show_default=True)
password = click.prompt("What is the password for the sqlalchemy connection?", default="",
show_default=False, hide_input=True)
database = click.prompt("What is the database name for the sqlalchemy connection?", default="postgres",
show_default=True)

credentials = {
"drivername": drivername,
"host": host,
"port": port,
"username": username,
"password": password,
"database": database
}
context.add_profile_credentials(data_source_name, **credentials)

try:
context.add_datasource(
data_source_name, "sqlalchemy", profile=data_source_name)
break
except (DatasourceInitializationError, ModuleNotFoundError) as de:
cli_message(
"""
Cannot connect to the database. Please check your environment and the configuration you provided.
credentials = {
"drivername": drivername,
"host": host,
"port": port,
"username": username,
"password": password,
"database": database
}
context.add_profile_credentials(data_source_name, **credentials)
<red>Actual error: {0:s}</red>
""".format(str(de)))
if not click.confirm(
"""
Enter the credentials again?
""".format(str(de)),
default=True):
cli_message(
"""
Exiting datasource configuration.
You can add a datasource later by editing the great_expectations.yml file.
""")
return None

context.add_datasource(
data_source_name, "sqlalchemy", profile=data_source_name)

elif data_source_selection == "3": # Spark
path = click.prompt(
Expand Down Expand Up @@ -114,8 +137,11 @@ def add_datasource(context):
# context.add_datasource("dbt", "dbt", profile=dbt_profile)
if data_source_selection == "4": # None of the above
cli_message(msg_unknown_data_source)
print("Skipping datasource configuration. "
"You can add a datasource later by editing the great_expectations.yml file.")
cli_message(
"""
Skipping datasource configuration.
You can add a datasource later by editing the great_expectations.yml file.
""")
return None

return data_source_name
Expand Down
5 changes: 5 additions & 0 deletions great_expectations/cli/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,10 @@ def cli_message(string):
colored("\g<1>", "yellow"),
mod_string
)
mod_string = re.sub(
"<red>(.*?)</red>",
colored("\g<1>", "red"),
mod_string
)

six.print_(colored(mod_string))
2 changes: 1 addition & 1 deletion great_expectations/data_context/data_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def __init__(self, context_root_dir=None, expectation_explorer=False, data_asset

self._project_config = self._load_project_config()

if "datasources" not in self._project_config:
if not self._project_config.get("datasources"):
self._project_config["datasources"] = {}
for datasource in self._project_config["datasources"].keys():
self.get_datasource(datasource)
Expand Down
41 changes: 26 additions & 15 deletions great_expectations/datasource/sqlalchemy_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from .datasource import Datasource
from great_expectations.dataset.sqlalchemy_dataset import SqlAlchemyDataset
from .generator.query_generator import QueryGenerator
from great_expectations.exceptions import DatasourceInitializationError

logger = logging.getLogger(__name__)

Expand All @@ -29,6 +30,9 @@ class SqlAlchemyDatasource(Datasource):
"""

def __init__(self, name="default", data_context=None, profile=None, generators=None, **kwargs):
if not sqlalchemy:
raise DatasourceInitializationError(name, "ModuleNotFoundError: No module named 'sqlalchemy'")

if generators is None:
generators = {
"default": {"type": "queries"}
Expand All @@ -41,21 +45,27 @@ def __init__(self, name="default", data_context=None, profile=None, generators=N
self._datasource_config.update({
"profile": profile
})
# if an engine was provided, use that
if "engine" in kwargs:
self.engine = kwargs.pop("engine")

# if a connection string or url was provided, use that
elif "connection_string" in kwargs:
connection_string = kwargs.pop("connection_string")
self.engine = create_engine(connection_string, **kwargs)
elif "url" in kwargs:
url = kwargs.pop("url")
self.engine = create_engine(url, **kwargs)

# Otherwise, connect using remaining kwargs
else:
self._connect(self._get_sqlalchemy_connection_options(**kwargs))

try:
# if an engine was provided, use that
if "engine" in kwargs:
self.engine = kwargs.pop("engine")

# if a connection string or url was provided, use that
elif "connection_string" in kwargs:
connection_string = kwargs.pop("connection_string")
self.engine = create_engine(connection_string, **kwargs)
self.engine.connect()
elif "url" in kwargs:
url = kwargs.pop("url")
self.engine = create_engine(url, **kwargs)
self.engine.connect()

# Otherwise, connect using remaining kwargs
else:
self._connect(self._get_sqlalchemy_connection_options(**kwargs))
except sqlalchemy.exc.OperationalError as sqlalchemy_error:
raise DatasourceInitializationError(self._name, str(sqlalchemy_error))

self._build_generators()

Expand All @@ -74,6 +84,7 @@ def _get_sqlalchemy_connection_options(self, **kwargs):

def _connect(self, options):
    """Create the SQLAlchemy engine for ``options`` and check it is reachable.

    Args:
        options: connection target handed unchanged to ``create_engine``.

    Side effects:
        Sets ``self.engine`` to the new engine and ``self.meta`` to a fresh
        ``MetaData`` instance.

    Raises:
        Whatever ``engine.connect()`` raises when the database cannot be
        reached; the error is not handled here.
    """
    self.engine = create_engine(options)
    # The connection is opened only to fail fast on bad settings. Close it
    # right away so it is released instead of staying checked out until
    # the connection object happens to be garbage collected.
    self.engine.connect().close()
    self.meta = MetaData()

def _get_generator_class(self, type_):
Expand Down
7 changes: 6 additions & 1 deletion great_expectations/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,9 @@ def __init__(self, data_asset_name):
class BatchKwargsError(DataContextError):
    """Data context error that carries the batch_kwargs it relates to.

    Args:
        message: description of the problem.
        batch_kwargs: the batch_kwargs associated with the error.

    Attributes:
        message: the description passed in.
        batch_kwargs: the batch_kwargs passed in.
    """

    def __init__(self, message, batch_kwargs):
        self.message = message
        # Assigned once; the repeated assignment in the original was redundant.
        self.batch_kwargs = batch_kwargs

class DatasourceInitializationError(GreatExpectationsError):
    """Raised when a datasource cannot be initialized.

    Args:
        datasource_name: name of the datasource that failed to initialize.
        message: description of the underlying failure.

    Attributes:
        message: the combined, human-readable error text.
    """

    def __init__(self, datasource_name, message):
        self.message = "Cannot initialize datasource %s, error: %s" % (datasource_name, message)
        # Forward the formatted text to the base exception. Without this,
        # str(error) falls back to rendering the raw constructor arguments
        # (a 2-tuple) rather than the message shown to the user.
        super(DatasourceInitializationError, self).__init__(self.message)

18 changes: 9 additions & 9 deletions tests/datasource/test_datasources.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,18 @@ def test_create_sqlalchemy_datasource(data_context):
type_ = "sqlalchemy"
connection_kwargs = {
"drivername": "postgresql",
"username": "user",
"password": "pass",
"host": "host",
"port": 1234,
"database": "db",
"username": "",
"password": "",
"host": "localhost",
"port": 5432,
"database": "test_ci",
}

# It should be possible to create a sqlalchemy source using these params without
# saving a profile
data_context.add_datasource(name, type_, **connection_kwargs)
data_context_config = data_context.get_config()
assert name in data_context_config["datasources"]
assert name in data_context_config["datasources"]
assert data_context_config["datasources"][name]["type"] == type_

# We should be able to get it in this session even without saving the config
Expand All @@ -131,9 +131,9 @@ def test_create_sqlalchemy_datasource(data_context):
# But we should be able to add a source using a profile
name = "second_source"
data_context.add_datasource(name, type_, profile="test_sqlalchemy_datasource")

data_context_config = data_context.get_config()
assert name in data_context_config["datasources"]
assert name in data_context_config["datasources"]
assert data_context_config["datasources"][name]["type"] == type_
assert data_context_config["datasources"][name]["profile"] == profile_name

Expand All @@ -143,7 +143,7 @@ def test_create_sqlalchemy_datasource(data_context):
# Finally, we should be able to confirm that the folder structure is as expected
with open(os.path.join(data_context.root_directory, "uncommitted/credentials/profiles.yml"), "r") as profiles_file:
profiles = yaml.load(profiles_file)

assert profiles == {
profile_name: dict(**connection_kwargs)
}
Expand Down

0 comments on commit cb01acc

Please sign in to comment.