Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
313 additions
and
69 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,102 @@ | ||
from cryptography.hazmat.backends import default_backend | ||
from cryptography.hazmat.primitives import serialization | ||
|
||
from pypika import ( | ||
VerticaQuery, | ||
functions as fn, | ||
terms, | ||
) | ||
from .base import Database | ||
|
||
try: | ||
from snowflake import connector as snowflake | ||
except: | ||
pass | ||
|
||
|
||
class Trunc(terms.Function): | ||
""" | ||
Wrapper for Vertica TRUNC function for truncating dates. | ||
""" | ||
|
||
def __init__(self, field, date_format, alias=None): | ||
super(Trunc, self).__init__('TRUNC', field, date_format, alias=alias) | ||
# Setting the fields here means we can access the TRUNC args by name. | ||
self.field = field | ||
self.date_format = date_format | ||
self.alias = alias | ||
|
||
|
||
class SnowflakeDatabase(Database): | ||
""" | ||
Vertica client that uses the vertica_python driver. | ||
""" | ||
|
||
# The pypika query class to use for constructing queries | ||
query_cls = VerticaQuery | ||
|
||
DATETIME_INTERVALS = { | ||
'hour': 'HH', | ||
'day': 'DD', | ||
'week': 'IW', | ||
'month': 'MM', | ||
'quarter': 'Q', | ||
'year': 'Y' | ||
} | ||
_private_key = None | ||
|
||
def __init__(self, user='snowflake', password=None, | ||
account='snowflake', database='snowflake', | ||
private_key_data=None, private_key_password=None, | ||
region=None, warehouse=None, | ||
max_processes=1, cache_middleware=None): | ||
super(SnowflakeDatabase, self).__init__(database=database, | ||
max_processes=max_processes, | ||
cache_middleware=cache_middleware) | ||
self.user = user | ||
self.password = password | ||
self.account = account | ||
self.private_key_data = private_key_data | ||
self.private_key_password = private_key_password | ||
self.region = region | ||
self.warehouse = warehouse | ||
|
||
def connect(self): | ||
import snowflake | ||
|
||
return snowflake.connector.connect(database=self.database, | ||
account=self.account, | ||
user=self.user, | ||
password=self.password, | ||
private_key=self._get_private_key(), | ||
region=self.region, | ||
warehouse=self.warehouse) | ||
|
||
def trunc_date(self, field, interval): | ||
trunc_date_interval = self.DATETIME_INTERVALS.get(str(interval), 'DD') | ||
return Trunc(field, trunc_date_interval) | ||
|
||
def date_add(self, field, date_part, interval): | ||
return fn.TimestampAdd(str(date_part), interval, field) | ||
|
||
def _get_private_key(self): | ||
if self._private_key is None: | ||
self._private_key = self._load_private_key_data() | ||
|
||
return self._private_key | ||
|
||
def _load_private_key_data(self): | ||
if self.private_key_data is None: | ||
return None | ||
|
||
private_key_password = None \ | ||
if self.private_key_password is None \ | ||
else self.private_key_password.encode() | ||
|
||
pkey = serialization.load_pem_private_key(self.private_key_data.encode(), | ||
private_key_password, | ||
backend=default_backend()) | ||
|
||
return pkey.private_bytes(encoding=serialization.Encoding.DER, | ||
format=serialization.PrivateFormat.PKCS8, | ||
encryption_algorithm=serialization.NoEncryption()) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
from unittest import TestCase | ||
from unittest.mock import ( | ||
ANY, | ||
Mock, | ||
patch, | ||
) | ||
|
||
from fireant.database import SnowflakeDatabase | ||
from fireant.slicer import * | ||
from pypika import Field | ||
|
||
|
||
class TestSnowflake(TestCase): | ||
def test_defaults(self): | ||
snowflake = SnowflakeDatabase() | ||
|
||
self.assertEqual('snowflake', snowflake.account) | ||
self.assertEqual('snowflake', snowflake.database) | ||
self.assertEqual('snowflake', snowflake.user) | ||
self.assertIsNone(snowflake.password) | ||
self.assertIsNone(snowflake.private_key_data) | ||
self.assertIsNone(snowflake.private_key_password) | ||
self.assertIsNone(snowflake.region) | ||
self.assertIsNone(snowflake.warehouse) | ||
|
||
def test_connect_with_password(self): | ||
mock_snowflake = Mock(name='mock_snowflake') | ||
mock_connector = mock_snowflake.connector | ||
|
||
# need to patch this here so it can be imported in the function scope | ||
with patch.dict('sys.modules', snowflake=mock_snowflake): | ||
mock_connector.connect.return_value = 'OK' | ||
|
||
snowflake = SnowflakeDatabase(user='test_user', | ||
password='test_pass', | ||
account='test_account', | ||
database='test_database') | ||
result = snowflake.connect() | ||
|
||
self.assertEqual('OK', result) | ||
mock_connector.connect.assert_called_once_with(user='test_user', | ||
password='test_pass', | ||
account='test_account', | ||
database='test_database', | ||
private_key=None, | ||
region=None, | ||
warehouse=None) | ||
|
||
@patch('fireant.database.snowflake.serialization') | ||
def test_connect_with_pkey(self, mock_serialization): | ||
mock_snowflake = Mock(name='mock_snowflake') | ||
mock_connector = mock_snowflake.connector | ||
mock_pkey = mock_serialization.load_pem_private_key.return_value = Mock(name='pkey') | ||
|
||
# need to patch this here so it can be imported in the function scope | ||
with patch.dict('sys.modules', snowflake=mock_snowflake): | ||
mock_connector.connect.return_value = 'OK' | ||
|
||
snowflake = SnowflakeDatabase(user='test_user', | ||
private_key_data='abcdefg', | ||
private_key_password='1234', | ||
account='test_account', | ||
database='test_database') | ||
result = snowflake.connect() | ||
|
||
with self.subTest('returns connection'): | ||
self.assertEqual('OK', result) | ||
|
||
with self.subTest('connects with credentials'): | ||
mock_serialization.load_pem_private_key.assert_called_once_with(b'abcdefg', | ||
b'1234', | ||
backend=ANY) | ||
|
||
with self.subTest('connects with credentials'): | ||
mock_connector.connect.assert_called_once_with(user='test_user', | ||
password=None, | ||
account='test_account', | ||
database='test_database', | ||
private_key=mock_pkey.private_bytes.return_value, | ||
region=None, | ||
warehouse=None) | ||
|
||
def test_trunc_hour(self): | ||
result = SnowflakeDatabase().trunc_date(Field('date'), hourly) | ||
|
||
self.assertEqual('TRUNC("date",\'HH\')', str(result)) | ||
|
||
def test_trunc_day(self): | ||
result = SnowflakeDatabase().trunc_date(Field('date'), daily) | ||
|
||
self.assertEqual('TRUNC("date",\'DD\')', str(result)) | ||
|
||
def test_trunc_week(self): | ||
result = SnowflakeDatabase().trunc_date(Field('date'), weekly) | ||
|
||
self.assertEqual('TRUNC("date",\'IW\')', str(result)) | ||
|
||
def test_trunc_quarter(self): | ||
result = SnowflakeDatabase().trunc_date(Field('date'), quarterly) | ||
|
||
self.assertEqual('TRUNC("date",\'Q\')', str(result)) | ||
|
||
def test_trunc_year(self): | ||
result = SnowflakeDatabase().trunc_date(Field('date'), annually) | ||
|
||
self.assertEqual('TRUNC("date",\'Y\')', str(result)) | ||
|
||
def test_date_add_hour(self): | ||
result = SnowflakeDatabase().date_add(Field('date'), 'hour', 1) | ||
|
||
self.assertEqual('TIMESTAMPADD(\'hour\',1,"date")', str(result)) | ||
|
||
def test_date_add_day(self): | ||
result = SnowflakeDatabase().date_add(Field('date'), 'day', 1) | ||
|
||
self.assertEqual('TIMESTAMPADD(\'day\',1,"date")', str(result)) | ||
|
||
def test_date_add_week(self): | ||
result = SnowflakeDatabase().date_add(Field('date'), 'week', 1) | ||
|
||
self.assertEqual('TIMESTAMPADD(\'week\',1,"date")', str(result)) | ||
|
||
def test_date_add_month(self): | ||
result = SnowflakeDatabase().date_add(Field('date'), 'month', 1) | ||
|
||
self.assertEqual('TIMESTAMPADD(\'month\',1,"date")', str(result)) | ||
|
||
def test_date_add_quarter(self): | ||
result = SnowflakeDatabase().date_add(Field('date'), 'quarter', 1) | ||
|
||
self.assertEqual('TIMESTAMPADD(\'quarter\',1,"date")', str(result)) | ||
|
||
def test_date_add_year(self): | ||
result = SnowflakeDatabase().date_add(Field('date'), 'year', 1) | ||
|
||
self.assertEqual('TIMESTAMPADD(\'year\',1,"date")', str(result)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,11 @@ | ||
-r requirements.txt | ||
|
||
mock | ||
matplotlib | ||
pymysql==0.8.0 | ||
vertica-python==0.7.3 | ||
psycopg2==2.7.3.2 | ||
matplotlib | ||
snowflake-connector-python==1.7.2 | ||
bumpversion==0.5.3 | ||
wheel==0.30.0 | ||
watchdog==0.8.3 | ||
flake8==3.5.0 | ||
|
||
mock |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,5 +4,4 @@ pypika==0.20.1 | |
toposort==1.5 | ||
typing==3.6.2 | ||
python-dateutil==2.7.3 | ||
|
||
mock | ||
cryptography==2.4.2 |
Oops, something went wrong.