Skip to content

Commit

Permalink
Merge 264fb0e into 4bf968c
Browse files Browse the repository at this point in the history
  • Loading branch information
twheys committed Jan 21, 2019
2 parents 4bf968c + 264fb0e commit 3947a4f
Show file tree
Hide file tree
Showing 6 changed files with 313 additions and 69 deletions.
1 change: 1 addition & 0 deletions fireant/database/__init__.py
Expand Up @@ -2,4 +2,5 @@
from .mysql import MySQLDatabase
from .postgresql import PostgreSQLDatabase
from .redshift import RedshiftDatabase
from .snowflake import SnowflakeDatabase
from .vertica import VerticaDatabase
102 changes: 102 additions & 0 deletions fireant/database/snowflake.py
@@ -0,0 +1,102 @@
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization

from pypika import (
VerticaQuery,
functions as fn,
terms,
)
from .base import Database

try:
from snowflake import connector as snowflake
except:
pass


class Trunc(terms.Function):
"""
Wrapper for Vertica TRUNC function for truncating dates.
"""

def __init__(self, field, date_format, alias=None):
super(Trunc, self).__init__('TRUNC', field, date_format, alias=alias)
# Setting the fields here means we can access the TRUNC args by name.
self.field = field
self.date_format = date_format
self.alias = alias


class SnowflakeDatabase(Database):
"""
Vertica client that uses the vertica_python driver.
"""

# The pypika query class to use for constructing queries
query_cls = VerticaQuery

DATETIME_INTERVALS = {
'hour': 'HH',
'day': 'DD',
'week': 'IW',
'month': 'MM',
'quarter': 'Q',
'year': 'Y'
}
_private_key = None

def __init__(self, user='snowflake', password=None,
account='snowflake', database='snowflake',
private_key_data=None, private_key_password=None,
region=None, warehouse=None,
max_processes=1, cache_middleware=None):
super(SnowflakeDatabase, self).__init__(database=database,
max_processes=max_processes,
cache_middleware=cache_middleware)
self.user = user
self.password = password
self.account = account
self.private_key_data = private_key_data
self.private_key_password = private_key_password
self.region = region
self.warehouse = warehouse

def connect(self):
import snowflake

return snowflake.connector.connect(database=self.database,
account=self.account,
user=self.user,
password=self.password,
private_key=self._get_private_key(),
region=self.region,
warehouse=self.warehouse)

def trunc_date(self, field, interval):
trunc_date_interval = self.DATETIME_INTERVALS.get(str(interval), 'DD')
return Trunc(field, trunc_date_interval)

def date_add(self, field, date_part, interval):
return fn.TimestampAdd(str(date_part), interval, field)

def _get_private_key(self):
if self._private_key is None:
self._private_key = self._load_private_key_data()

return self._private_key

def _load_private_key_data(self):
if self.private_key_data is None:
return None

private_key_password = None \
if self.private_key_password is None \
else self.private_key_password.encode()

pkey = serialization.load_pem_private_key(self.private_key_data.encode(),
private_key_password,
backend=default_backend())

return pkey.private_bytes(encoding=serialization.Encoding.DER,
format=serialization.PrivateFormat.PKCS8,
encryption_algorithm=serialization.NoEncryption())
136 changes: 136 additions & 0 deletions fireant/tests/database/test_snowflake.py
@@ -0,0 +1,136 @@
from unittest import TestCase
from unittest.mock import (
ANY,
Mock,
patch,
)

from fireant.database import SnowflakeDatabase
from fireant.slicer import *
from pypika import Field


class TestSnowflake(TestCase):
def test_defaults(self):
snowflake = SnowflakeDatabase()

self.assertEqual('snowflake', snowflake.account)
self.assertEqual('snowflake', snowflake.database)
self.assertEqual('snowflake', snowflake.user)
self.assertIsNone(snowflake.password)
self.assertIsNone(snowflake.private_key_data)
self.assertIsNone(snowflake.private_key_password)
self.assertIsNone(snowflake.region)
self.assertIsNone(snowflake.warehouse)

def test_connect_with_password(self):
mock_snowflake = Mock(name='mock_snowflake')
mock_connector = mock_snowflake.connector

# need to patch this here so it can be imported in the function scope
with patch.dict('sys.modules', snowflake=mock_snowflake):
mock_connector.connect.return_value = 'OK'

snowflake = SnowflakeDatabase(user='test_user',
password='test_pass',
account='test_account',
database='test_database')
result = snowflake.connect()

self.assertEqual('OK', result)
mock_connector.connect.assert_called_once_with(user='test_user',
password='test_pass',
account='test_account',
database='test_database',
private_key=None,
region=None,
warehouse=None)

@patch('fireant.database.snowflake.serialization')
def test_connect_with_pkey(self, mock_serialization):
mock_snowflake = Mock(name='mock_snowflake')
mock_connector = mock_snowflake.connector
mock_pkey = mock_serialization.load_pem_private_key.return_value = Mock(name='pkey')

# need to patch this here so it can be imported in the function scope
with patch.dict('sys.modules', snowflake=mock_snowflake):
mock_connector.connect.return_value = 'OK'

snowflake = SnowflakeDatabase(user='test_user',
private_key_data='abcdefg',
private_key_password='1234',
account='test_account',
database='test_database')
result = snowflake.connect()

with self.subTest('returns connection'):
self.assertEqual('OK', result)

with self.subTest('connects with credentials'):
mock_serialization.load_pem_private_key.assert_called_once_with(b'abcdefg',
b'1234',
backend=ANY)

with self.subTest('connects with credentials'):
mock_connector.connect.assert_called_once_with(user='test_user',
password=None,
account='test_account',
database='test_database',
private_key=mock_pkey.private_bytes.return_value,
region=None,
warehouse=None)

def test_trunc_hour(self):
result = SnowflakeDatabase().trunc_date(Field('date'), hourly)

self.assertEqual('TRUNC("date",\'HH\')', str(result))

def test_trunc_day(self):
result = SnowflakeDatabase().trunc_date(Field('date'), daily)

self.assertEqual('TRUNC("date",\'DD\')', str(result))

def test_trunc_week(self):
result = SnowflakeDatabase().trunc_date(Field('date'), weekly)

self.assertEqual('TRUNC("date",\'IW\')', str(result))

def test_trunc_quarter(self):
result = SnowflakeDatabase().trunc_date(Field('date'), quarterly)

self.assertEqual('TRUNC("date",\'Q\')', str(result))

def test_trunc_year(self):
result = SnowflakeDatabase().trunc_date(Field('date'), annually)

self.assertEqual('TRUNC("date",\'Y\')', str(result))

def test_date_add_hour(self):
result = SnowflakeDatabase().date_add(Field('date'), 'hour', 1)

self.assertEqual('TIMESTAMPADD(\'hour\',1,"date")', str(result))

def test_date_add_day(self):
result = SnowflakeDatabase().date_add(Field('date'), 'day', 1)

self.assertEqual('TIMESTAMPADD(\'day\',1,"date")', str(result))

def test_date_add_week(self):
result = SnowflakeDatabase().date_add(Field('date'), 'week', 1)

self.assertEqual('TIMESTAMPADD(\'week\',1,"date")', str(result))

def test_date_add_month(self):
result = SnowflakeDatabase().date_add(Field('date'), 'month', 1)

self.assertEqual('TIMESTAMPADD(\'month\',1,"date")', str(result))

def test_date_add_quarter(self):
result = SnowflakeDatabase().date_add(Field('date'), 'quarter', 1)

self.assertEqual('TIMESTAMPADD(\'quarter\',1,"date")', str(result))

def test_date_add_year(self):
result = SnowflakeDatabase().date_add(Field('date'), 'year', 1)

self.assertEqual('TIMESTAMPADD(\'year\',1,"date")', str(result))
7 changes: 3 additions & 4 deletions requirements-dev.txt
@@ -1,12 +1,11 @@
-r requirements.txt

mock
matplotlib
pymysql==0.8.0
vertica-python==0.7.3
psycopg2==2.7.3.2
matplotlib
snowflake-connector-python==1.7.2
bumpversion==0.5.3
wheel==0.30.0
watchdog==0.8.3
flake8==3.5.0

mock
3 changes: 1 addition & 2 deletions requirements.txt
Expand Up @@ -4,5 +4,4 @@ pypika==0.20.1
toposort==1.5
typing==3.6.2
python-dateutil==2.7.3

mock
cryptography==2.4.2

0 comments on commit 3947a4f

Please sign in to comment.