Skip to content

Commit

Permalink
Merge 1358c0f into 3fc1245
Browse files Browse the repository at this point in the history
  • Loading branch information
twheys committed Oct 29, 2018
2 parents 3fc1245 + 1358c0f commit 5e00798
Show file tree
Hide file tree
Showing 31 changed files with 1,679 additions and 1,823 deletions.
80 changes: 70 additions & 10 deletions fireant/database/base.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,100 @@
from functools import wraps

import pandas as pd
import time

from pypika import (
Query,
enums,
functions as fn,
terms,
)
from .logger import (
query_logger,
slow_query_logger,
)


def log(func):
@wraps(func)
def wrapper(database, query):
start_time = time.time()
query_logger.debug(query)

result = func(database, query)

duration = round(time.time() - start_time, 4)
query_log_msg = '[{duration} seconds]: {query}'.format(duration=duration,
query=query)
query_logger.info(query_log_msg)

if duration >= database.slow_query_log_min_seconds:
slow_query_logger.warning(query_log_msg)

return result

return wrapper


@log
def fetch(database, query):
with database.connect() as connection:
cursor = connection.cursor()
cursor.execute(query)
return cursor.fetchall()


@log
def fetch_data(database, query):
with database.connect() as connection:
return pd.read_sql(query, connection, coerce_float=True, parse_dates=True)


class Database(object):
"""
WRITEME
This is a abstract base class used for interfacing with a database platform.
"""
# The pypika query class to use for constructing queries
query_cls = Query

slow_query_log_min_seconds = 15

def __init__(self, host=None, port=None, database=None, max_processes=2, cache_middleware=None):
self.host = host
self.port = port
self.database = database
self.max_processes = max_processes
self.cache_middleware = cache_middleware

def connect(self):
"""
This function must establish a connection to the database platform and return it.
"""
raise NotImplementedError

def trunc_date(self, field, interval):
"""
This function must create a Pypika function which truncates a Date or DateTime object to a specific interval.
"""
raise NotImplementedError

def date_add(self, field: terms.Term, date_part: str, interval: int):
""" Database specific function for adding or subtracting dates """
"""
This function must add/subtract a Date or Date/Time object.
"""
raise NotImplementedError

def fetch(self, query):
with self.connect() as connection:
cursor = connection.cursor()
cursor.execute(query)
return cursor.fetchall()

def to_char(self, definition):
return fn.Cast(definition, enums.SqlTypes.VARCHAR)

def fetch(self, query):
if self.cache_middleware is not None:
return self.cache_middleware(fetch)(self, query)
return fetch(self, query)

def fetch_data(self, query):
with self.connect() as connection:
return pd.read_sql(query, connection, coerce_float=True, parse_dates=True)
if self.cache_middleware is not None:
return self.cache_middleware(fetch_data)(self, query)

return fetch_data(self, query)
File renamed without changes.
20 changes: 6 additions & 14 deletions fireant/database/mysql.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import pandas as pd

from pypika import (
Dialects,
MySQLQuery,
enums,
functions as fn,
terms,
)

from .base import Database


Expand Down Expand Up @@ -40,11 +40,11 @@ class MySQLDatabase(Database):
# The pypika query class to use for constructing queries
query_cls = MySQLQuery

def __init__(self, database=None, host='localhost', port=3306,
user=None, password=None, charset='utf8mb4'):
self.host = host
self.port = port
self.database = database
def __init__(self, host='localhost', port=3306, database=None,
user=None, password=None, charset='utf8mb4', max_processes=1, cache_middleware=None):
super(MySQLDatabase, self).__init__(host, port, database,
max_processes=max_processes,
cache_middleware=cache_middleware)
self.user = user
self.password = password
self.charset = charset
Expand All @@ -57,14 +57,6 @@ def connect(self):
charset=self.charset,
cursorclass=pymysql.cursors.Cursor)

def fetch(self, query):
with self.connect().cursor() as cursor:
cursor.execute(query)
return cursor.fetchall()

def fetch_data(self, query):
return pd.read_sql(query, self.connect())

def trunc_date(self, field, interval):
return Trunc(field, str(interval))

Expand Down
20 changes: 5 additions & 15 deletions fireant/database/postgresql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pandas as pd

from pypika import (
PostgreSQLQuery,
functions as fn,
Expand Down Expand Up @@ -29,11 +27,11 @@ class PostgreSQLDatabase(Database):
# The pypika query class to use for constructing queries
query_cls = PostgreSQLQuery

def __init__(self, database=None, host='localhost', port=5432,
user=None, password=None):
self.host = host
self.port = port
self.database = database
def __init__(self, host='localhost', port=5432, database=None,
user=None, password=None, max_processes=1, cache_middleware=None):
super(PostgreSQLDatabase, self).__init__(host, port, database,
max_processes=max_processes,
cache_middleware=cache_middleware)
self.user = user
self.password = password

Expand All @@ -43,14 +41,6 @@ def connect(self):
return psycopg2.connect(host=self.host, port=self.port, dbname=self.database,
user=self.user, password=self.password)

def fetch(self, query):
with self.connect().cursor() as cursor:
cursor.execute(query)
return cursor.fetchall()

def fetch_data(self, query):
return pd.read_sql(query, self.connect())

def trunc_date(self, field, interval):
return DateTrunc(field, str(interval))

Expand Down
8 changes: 5 additions & 3 deletions fireant/database/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ class RedshiftDatabase(PostgreSQLDatabase):
# The pypika query class to use for constructing queries
query_cls = RedshiftQuery

def __init__(self, database=None, host=None, port=5439, user=None, password=None):
super(RedshiftDatabase, self).__init__(database=database, host=host, port=port,
user=user, password=password)
def __init__(self, host='localhost', port=5439, database=None,
user=None, password=None, max_processes=1, cache_middleware=None):
super(RedshiftDatabase, self).__init__(host, port, database, user, password,
max_processes=max_processes,
cache_middleware=cache_middleware)
11 changes: 4 additions & 7 deletions fireant/database/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,10 @@ class VerticaDatabase(Database):
}

def __init__(self, host='localhost', port=5433, database='vertica', user='vertica', password=None,
read_timeout=None):
self.host = host
self.port = port
self.database = database
read_timeout=None, max_processes=1, cache_middleware=None):
super(VerticaDatabase, self).__init__(host, port, database,
max_processes=max_processes,
cache_middleware=cache_middleware)
self.user = user
self.password = password
self.read_timeout = read_timeout
Expand All @@ -54,9 +54,6 @@ def connect(self):
read_timeout=self.read_timeout,
unicode_error='replace')

def fetch(self, query):
return super(VerticaDatabase, self).fetch(query)

def trunc_date(self, field, interval):
trunc_date_interval = self.DATETIME_INTERVALS.get(str(interval), 'DD')
return Trunc(field, trunc_date_interval)
Expand Down

0 comments on commit 5e00798

Please sign in to comment.