Skip to content

Commit

Permalink
Merge pull request #22 from jojoduquartier/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
jojoduquartier committed Feb 1, 2020
2 parents 307b858 + f8531c8 commit 78a6f35
Show file tree
Hide file tree
Showing 7 changed files with 130 additions and 48 deletions.
5 changes: 4 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,7 @@ First release deployed to PyPI

## [Version 1.0.1]
There was a bug with the way credentials were stored. This caused the whole credential branch
for a dialect to be replaced. It has been fixed
for a dialect to be replaced. It has been fixed

## [Version 1.0.2]
The snowflake dialect has been added
37 changes: 21 additions & 16 deletions dsdbmanager/__init__.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,38 @@
import json
from .dbobject import DsDbManager, DbMiddleware
from .configuring import ConfigFilesManager
from .dbobject import DsDbManager, DbMiddleware

__version__ = '1.0.1'
configurer = ConfigFilesManager()
__version__ = '1.0.2'
__configurer__ = ConfigFilesManager()

# first initialize empty files
# if there are no host files, create an empty json file #
if not configurer.host_location.exists():
if not __configurer__.host_location.exists():
try:
configurer.host_location.touch(exist_ok=True)
with configurer.host_location.open('w') as f:
__configurer__.host_location.touch(exist_ok=True)
with __configurer__.host_location.open('w') as f:
json.dump({}, f)
except OSError as e:
raise Exception("Could not write at host file location", e)

# if there are credential files create an empty json file #
if not configurer.credential_location.exists():
configurer.credential_location.touch(exist_ok=True)
if not __configurer__.credential_location.exists():
__configurer__.credential_location.touch(exist_ok=True)
try:
with configurer.credential_location.open('w') as f:
with __configurer__.credential_location.open('w') as f:
json.dump({}, f)
except OSError as e:
raise Exception("Could not write at credential file location", e)

# if there are no keys, create one and store it at key location #
if not configurer.key_location.exists():
configurer.key_location.touch(exist_ok=True)
configurer.key_location.write_bytes(configurer.generate_key())
if not __configurer__.key_location.exists():
__configurer__.key_location.touch(exist_ok=True)
__configurer__.key_location.write_bytes(__configurer__.generate_key())

# functions for users to use
add_database = configurer.add_new_database_info
remove_database = configurer.remove_database
reset_credentials = configurer.reset_credentials
add_database = __configurer__.add_new_database_info
remove_database = __configurer__.remove_database
reset_credentials = __configurer__.reset_credentials


# easy access for databases
Expand All @@ -52,6 +52,10 @@ def mssql():
return DsDbManager('mssql')


def snowflake():
return DsDbManager('snowflake')


def from_engine(engine, schema: str = None):
"""
Main objective is to use this to create DbMiddleware objects on sqlite engines for quick testing purposes
Expand All @@ -76,7 +80,8 @@ def from_engine(engine, schema: str = None):
>>> df1: pd.DataFrame = db.test()
>>> df1.equals(df)
True
>>>
>>> engine.dispose()
"""
return DbMiddleware(engine, False, schema)
2 changes: 1 addition & 1 deletion dsdbmanager/configuring.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def add_new_database_info(self):
return None

# additional_infos
host = click.prompt("Host/Database Address", type=str)
host = click.prompt("Host/Database Address or Snowflake Account", type=str)
schema = click.prompt("Schema - Enter if none", default='', type=str)
sid = click.prompt("SID - Enter if none", default='', type=str)
service_name = click.prompt("Service Name - Enter if none", default='', type=str)
Expand Down
2 changes: 1 addition & 1 deletion dsdbmanager/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
KEY_PATH = config_folder / ".configkey"

# database flavors
FLAVORS_FOR_CONFIG = ('oracle', 'mysql', 'mssql', 'teradata',)
FLAVORS_FOR_CONFIG = ('oracle', 'mysql', 'mssql', 'teradata', 'snowflake')

CACHE_SIZE = 64
CHUNK_SIZE = 30000
68 changes: 40 additions & 28 deletions dsdbmanager/dbobject.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .mysql_ import Mysql
from .oracle_ import Oracle
from .teradata_ import Teradata
from .snowflake_ import Snowflake
from sqlalchemy.engine import reflection
from .configuring import ConfigFilesManager
from .utils import d_frame, inspect_table, filter_maker
Expand All @@ -30,7 +31,8 @@
Oracle,
Teradata,
Mysql,
Mssql
Mssql,
Snowflake
]


Expand Down Expand Up @@ -66,21 +68,22 @@ def insert_into_table(df: pd.DataFrame, table_name: str, engine: sa.engine.Engin

# insert
count, last_successful_insert = 0, None
for group in groups:
try:
result = engine.execute(tbl.insert(), group)
last_successful_insert = group[-1]
count += result.rowcount
except exc.OperationalError as _:
"Try Again"
time.sleep(2)

with engine.connect() as connection:
for group in groups:
try:
result = engine.execute(tbl.insert(), group)
result = connection.execute(tbl.insert(), group)
last_successful_insert = group[-1]
count += result.rowcount
except exc.OperationalError as e:
raise OperationalError(f"Failed to insert records. Last successful{last_successful_insert}", e)
except exc.OperationalError as _:
"Try Again"
time.sleep(2)

try:
result = connection.execute(tbl.insert(), group)
last_successful_insert = group[-1]
count += result.rowcount
except exc.OperationalError as e:
raise OperationalError(f"Failed to insert records. Last successful{last_successful_insert}", e)

return count

Expand Down Expand Up @@ -130,21 +133,24 @@ def update_on_table(df: pd.DataFrame, keys: update_key_type, values: update_key_

# update
count, last_successful_update = 0, None
for group in groups:
try:
result = engine.execute(update_statement, group)
last_successful_update = group[-1]
count += result.rowcount
except exc.OperationalError as _:
# try again
time.sleep(2)

with engine.connect() as connection:
for group in groups:
try:
result = engine.execute(update_statement, group)
result = connection.execute(update_statement, group)
last_successful_update = group[-1]
count += result.rowcount
except exc.OperationalError as e:
raise OperationalError(f"Failed to update records. Last successful update: {last_successful_update}", e)
except exc.OperationalError as _:
# try again
time.sleep(2)

try:
result = connection.execute(update_statement, group)
last_successful_update = group[-1]
count += result.rowcount
except exc.OperationalError as e:
raise OperationalError(
f"Failed to update records. Last successful update: {last_successful_update}", e
)

return count

Expand Down Expand Up @@ -199,7 +205,8 @@ def wrapped(
query = query.where(sa.and_(*filters))

# execute
results = engine.execute(query)
with engine.connect() as connection:
results = connection.execute(query)

# fetch
if rows is not None:
Expand Down Expand Up @@ -299,7 +306,7 @@ def __exit__(self, exc_type, exc_val, exc_tb):
@toolz.curry
def db_middleware(config_manager: ConfigFilesManager, flavor: str, db_name: str,
connection_object: connection_object_type, config_schema: str, connect_only: bool,
schema: str = None) -> DbMiddleware:
schema: str = None, **engine_kwargs) -> DbMiddleware:
"""
Try connecting to the database. Write credentials on success. Using a function only so that the connection
is only attempted when function is called.
Expand All @@ -310,6 +317,7 @@ def db_middleware(config_manager: ConfigFilesManager, flavor: str, db_name: str,
:param config_schema: the schema provided when adding database
:param connect_only: True if all we want is connect and not inspect for tables or views
:param schema: if user wants to specify a different schema than the one supplied when adding database
:param engine_kwargs: engine arguments, like echo, or warehouse, schema and role for snowflake
:return:
"""

Expand All @@ -323,7 +331,8 @@ def db_middleware(config_manager: ConfigFilesManager, flavor: str, db_name: str,

engine: sa.engine.base.Engine = connection_object.create_engine(
config_manager.encrypt_decrypt(username, encrypt=False).decode("utf-8"),
config_manager.encrypt_decrypt(password, encrypt=False).decode("utf-8")
config_manager.encrypt_decrypt(password, encrypt=False).decode("utf-8"),
**engine_kwargs
)

try:
Expand Down Expand Up @@ -401,6 +410,9 @@ def _connection_object_creator(self, db_name: str):
if self._flavor.lower() == 'mysql':
return Mysql(db_name, self._host_dict)

if self._flavor.lower() == 'snowflake':
return Snowflake(db_name, self._host_dict)

def __getitem__(self, item):
return self.__dict__[item]

Expand Down
52 changes: 52 additions & 0 deletions dsdbmanager/snowflake_.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import typing
import sqlalchemy as sa
from .configuring import ConfigFilesManager
from .exceptions_ import MissingFlavor, MissingDatabase, MissingPackage

host_type = typing.Dict[str, typing.Dict[str, typing.Dict[str, str]]]


class Snowflake:
def __init__(self, db_name: str, host_dict: host_type = None):
"""
:param db_name: database name
:param host_dict: optional database info with host, ports etc
"""
self.db_name = db_name
self.host_dict: host_type = ConfigFilesManager().get_hosts() if not host_dict else host_dict

if not self.host_dict or 'snowflake' not in self.host_dict:
raise MissingFlavor("No databases available for snowflake", None)

self.host_dict = self.host_dict.get('snowflake').get(self.db_name, {})

if not self.host_dict:
raise MissingDatabase(f"{self.db_name} has not been added for snowflake", None)

def create_engine(self, user: str = None, pwd: str = None, **kwargs):
"""
:param user: username
:param pwd: password
:param kwargs: for compatibility/additional sqlalchemy create_engine kwargs or things like role, warehouse etc.
:return: sqlalchemy engine
"""
try:
from snowflake.sqlalchemy import URL
except ImportError as e:
raise MissingPackage("You need the snowflake-sqlalchemy package to initiate connection", e)

host = self.host_dict.get('host')

url = URL(
account=host,
user=user,
password=pwd,
database=self.db_name,
**kwargs
)

# TODO - find a way to identify kwrags consumed by URL and pull them out of kwargs and pass the rest below

return sa.create_engine(url)
12 changes: 11 additions & 1 deletion test/test_connectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dsdbmanager.mysql_ import Mysql
from dsdbmanager.oracle_ import Oracle
from dsdbmanager.teradata_ import Teradata
from dsdbmanager.snowflake_ import Snowflake
from dsdbmanager.exceptions_ import MissingFlavor, MissingDatabase, MissingPackage


Expand Down Expand Up @@ -36,7 +37,14 @@ def setUpClass(cls):
'host': 'somehost',
'port': 0000
}
}
},
'snowflake': {
'database1': {
'name': 'database1',
'host': 'somehost',
'port': 0000
}
},
}

@classmethod
Expand All @@ -50,12 +58,14 @@ def test_connectors(self):
'mysql',
'mssql',
'teradata',
'snowflake'
],
[
Oracle,
Mysql,
Mssql,
Teradata,
Snowflake
]):
with self.subTest(flavor=name):
# test with host improper host file. should raise MissingFlavor
Expand Down

0 comments on commit 78a6f35

Please sign in to comment.