Skip to content

Commit

Permalink
Merge pull request #280 from kayak/pypika_immutable
Browse files Browse the repository at this point in the history
Changed fireant to use pypika immutability
  • Loading branch information
twheys committed Jan 27, 2020
2 parents b3c0c12 + 034abe0 commit 73a5123
Show file tree
Hide file tree
Showing 5 changed files with 139 additions and 112 deletions.
39 changes: 26 additions & 13 deletions fireant/database/postgresql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class DateTrunc(terms.Function):
"""

def __init__(self, field, date_format, alias=None):
super(DateTrunc, self).__init__('DATE_TRUNC', date_format, field, alias=alias)
super(DateTrunc, self).__init__("DATE_TRUNC", date_format, field, alias=alias)
# Setting the fields here means we can access the TRUNC args by name.
self.field = field
self.date_format = date_format
Expand All @@ -29,17 +29,29 @@ class PostgreSQLDatabase(Database):
# The pypika query class to use for constructing queries
query_cls = PostgreSQLQuery

def __init__(self, host='localhost', port=5432, database=None,
user=None, password=None, **kwags):
def __init__(
self,
host="localhost",
port=5432,
database=None,
user=None,
password=None,
**kwags
):
super(PostgreSQLDatabase, self).__init__(host, port, database, **kwags)
self.user = user
self.password = password

def connect(self):
import psycopg2

return psycopg2.connect(host=self.host, port=self.port, dbname=self.database,
user=self.user, password=self.password)
return psycopg2.connect(
host=self.host,
port=self.port,
dbname=self.database,
user=self.user,
password=self.password,
)

def trunc_date(self, field, interval):
return DateTrunc(field, str(interval))
Expand All @@ -48,14 +60,15 @@ def date_add(self, field, date_part, interval):
return fn.DateAdd(str(date_part), interval, field)

def get_column_definitions(self, schema, table, connection=None):
columns = Table('columns', schema='INFORMATION_SCHEMA')

columns_query = PostgreSQLQuery.from_(columns) \
.select(columns.column_name, columns.data_type) \
.where(columns.table_schema == schema) \
.where(columns.field('table_name') == table) \
.distinct() \
columns = Table("columns", schema="INFORMATION_SCHEMA")

columns_query = (
PostgreSQLQuery.from_(columns, immutable=False)
.select(columns.column_name, columns.data_type)
.where(columns.table_schema == schema)
.where(columns.field("table_name") == table)
.distinct()
.orderby(columns.column_name)
)

return self.fetch(str(columns_query), connection=connection)

192 changes: 105 additions & 87 deletions fireant/database/vertica.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,28 +4,26 @@
functions as fn,
terms,
)

from .base import Database
from .type_engine import TypeEngine

from .sql_types import (
Char,
VarChar,
Text,
Boolean,
Integer,
SmallInt,
BigInt,
Boolean,
Char,
Date,
DateTime,
Decimal,
Numeric,
DoublePrecision,
Float,
Integer,
Numeric,
Real,
DoublePrecision,
Date,
SmallInt,
Text,
Time,
DateTime,
Timestamp,
VarChar,
)
from .type_engine import TypeEngine


class Trunc(terms.Function):
Expand All @@ -34,7 +32,7 @@ class Trunc(terms.Function):
"""

def __init__(self, field, date_format, alias=None):
super(Trunc, self).__init__('TRUNC', field, date_format, alias=alias)
super(Trunc, self).__init__("TRUNC", field, date_format, alias=alias)
# Setting the fields here means we can access the TRUNC args by name.
self.field = field
self.date_format = date_format
Expand All @@ -50,16 +48,24 @@ class VerticaDatabase(Database):
query_cls = VerticaQuery

DATETIME_INTERVALS = {
'hour': 'HH',
'day': 'DD',
'week': 'IW',
'month': 'MM',
'quarter': 'Q',
'year': 'Y'
"hour": "HH",
"day": "DD",
"week": "IW",
"month": "MM",
"quarter": "Q",
"year": "Y",
}

def __init__(self, host='localhost', port=5433, database='vertica', user='vertica', password=None,
read_timeout=None, **kwags):
def __init__(
self,
host="localhost",
port=5433,
database="vertica",
user="vertica",
password=None,
read_timeout=None,
**kwags
):
super(VerticaDatabase, self).__init__(host, port, database, **kwags)
self.user = user
self.password = password
Expand All @@ -69,25 +75,35 @@ def __init__(self, host='localhost', port=5433, database='vertica', user='vertic
def connect(self):
import vertica_python

return vertica_python.connect(host=self.host, port=self.port, database=self.database,
user=self.user, password=self.password,
read_timeout=self.read_timeout,
unicode_error='replace')
return vertica_python.connect(
host=self.host,
port=self.port,
database=self.database,
user=self.user,
password=self.password,
read_timeout=self.read_timeout,
unicode_error="replace",
)

def trunc_date(self, field, interval):
trunc_date_interval = self.DATETIME_INTERVALS.get(str(interval), 'DD')
trunc_date_interval = self.DATETIME_INTERVALS.get(str(interval), "DD")
return Trunc(field, trunc_date_interval)

def date_add(self, field, date_part, interval):
return fn.TimestampAdd(str(date_part), interval, field)

def get_column_definitions(self, schema, table, connection=None):
table_columns = Table('columns')

table_query = VerticaQuery.from_(table_columns) \
.select(table_columns.column_name, table_columns.data_type) \
.where((table_columns.table_schema == schema) & (table_columns.field('table_name') == table)) \
table_columns = Table("columns")

table_query = (
VerticaQuery.from_(table_columns, immutable=False)
.select(table_columns.column_name, table_columns.data_type)
.where(
(table_columns.table_schema == schema)
& (table_columns.field("table_name") == table)
)
.distinct()
)

return self.fetch(str(table_query), connection=connection)

Expand All @@ -99,9 +115,7 @@ def import_csv(self, table, file_path, connection=None):
:param file_path: The path of the file to be imported.
:param connection: (Optional) The connection to execute this query with.
"""
import_query = VerticaQuery \
.from_file(file_path) \
.copy_(table)
import_query = VerticaQuery.from_file(file_path).copy_(table)

self.execute(str(import_query), connection=connection)

Expand All @@ -113,12 +127,13 @@ def create_temporary_table_from_columns(self, table, columns, connection=None):
:param columns: The columns of the new temporary table.
:param connection: (Optional) The connection to execute this query with.
"""
create_query = VerticaQuery \
.create_table(table) \
.temporary() \
.local() \
.preserve_rows() \
create_query = (
VerticaQuery.create_table(table)
.temporary()
.local()
.preserve_rows()
.columns(*columns)
)

self.execute(str(create_query), connection=connection)

Expand All @@ -130,63 +145,66 @@ def create_temporary_table_from_select(self, table, select_query, connection=Non
:param select_query: The query to be used for selecting data of an existing table for the new temporary table.
:param connection: (Optional) The connection to execute this query with.
"""
create_query = VerticaQuery \
.create_table(table) \
.temporary() \
.local() \
.preserve_rows() \
create_query = (
VerticaQuery.create_table(table)
.temporary()
.local()
.preserve_rows()
.as_select(select_query)
)

self.execute(str(create_query), connection=connection)


class VerticaTypeEngine(TypeEngine):
vertica_to_ansi_mapper = {
'char': Char,
'varchar': VarChar,
'varchar2': VarChar,
'longvarchar': Text,
'boolean': Boolean,
'int': Integer,
'integer': Integer,
'int8': Integer,
'smallint': SmallInt,
'tinyint': SmallInt,
'bigint': BigInt,
'decimal': Decimal,
'numeric': Numeric,
'number': Numeric,
'float': Float,
'float8': Float,
'real': Real,
'double': DoublePrecision,
'date': Date,
'time': Time,
'timetz': Time,
'datetime': DateTime,
'smalldatetime': DateTime,
'timestamp': Timestamp,
'timestamptz': Timestamp,
"char": Char,
"varchar": VarChar,
"varchar2": VarChar,
"longvarchar": Text,
"boolean": Boolean,
"int": Integer,
"integer": Integer,
"int8": Integer,
"smallint": SmallInt,
"tinyint": SmallInt,
"bigint": BigInt,
"decimal": Decimal,
"numeric": Numeric,
"number": Numeric,
"float": Float,
"float8": Float,
"real": Real,
"double": DoublePrecision,
"date": Date,
"time": Time,
"timetz": Time,
"datetime": DateTime,
"smalldatetime": DateTime,
"timestamp": Timestamp,
"timestamptz": Timestamp,
}

ansi_to_vertica_mapper = {
'CHAR': 'char',
'VARCHAR': 'varchar',
'TEXT': 'longvarchar',
'BOOLEAN': 'boolean',
'INTEGER': 'integer',
'SMALLINT': 'smallint',
'BIGINT': 'bigint',
'DECIMAL': 'decimal',
'NUMERIC': 'numeric',
'FLOAT': 'float',
'REAL': 'real',
'DOUBLEPRECISION': 'double',
'DATE': 'date',
'TIME': 'time',
'DATETIME': 'datetime',
'TIMESTAMP': 'timestamp',
"CHAR": "char",
"VARCHAR": "varchar",
"TEXT": "longvarchar",
"BOOLEAN": "boolean",
"INTEGER": "integer",
"SMALLINT": "smallint",
"BIGINT": "bigint",
"DECIMAL": "decimal",
"NUMERIC": "numeric",
"FLOAT": "float",
"REAL": "real",
"DOUBLEPRECISION": "double",
"DATE": "date",
"TIME": "time",
"DATETIME": "datetime",
"TIMESTAMP": "timestamp",
}

def __init__(self):
super(VerticaTypeEngine, self).__init__(self.vertica_to_ansi_mapper, self.ansi_to_vertica_mapper)
super(VerticaTypeEngine, self).__init__(
self.vertica_to_ansi_mapper, self.ansi_to_vertica_mapper
)
2 changes: 1 addition & 1 deletion fireant/queries/builder/dataset_blender_query_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ def _blend_query(dimensions, metrics, orders, field_maps, queries):
base_query, *join_queries = queries
base_field_map, *join_field_maps = field_maps

blender_query = Query.from_(base_query)
blender_query = Query.from_(base_query, immutable=False)
for join_sql, join_field_map in zip(join_queries, join_field_maps):
criteria = _blender_join_criteria(
base_query, join_sql, dimensions, base_field_map, join_field_map
Expand Down
16 changes: 6 additions & 10 deletions fireant/queries/sql_transformer.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,5 @@
from typing import Iterable

from pypika import (
Table,
functions as fn,
)

from fireant.database import Database
from fireant.dataset.fields import Field
from fireant.dataset.filters import Filter
Expand All @@ -14,10 +9,11 @@
alias_selector,
flatten,
)
from .field_helper import (
make_term_for_field,
make_term_for_field,
from pypika import (
Table,
functions as fn,
)
from .field_helper import make_term_for_field
from .finders import (
find_and_group_references_for_dimensions,
find_joins_for_tables,
Expand Down Expand Up @@ -155,7 +151,7 @@ def make_slicer_query(
:return:
"""
query = database.query_cls.from_(base_table)
query = database.query_cls.from_(base_table, immutable=False)
elements = flatten([metrics, dimensions, filters])

# Add joins
Expand Down Expand Up @@ -203,7 +199,7 @@ def make_latest_query(
joins: Iterable[Join] = (),
dimensions: Iterable[Field] = (),
):
query = database.query_cls.from_(base_table)
query = database.query_cls.from_(base_table, immutable=False)

# Add joins
join_tables_needed_for_query = find_required_tables_to_join(dimensions, base_table)
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
six
pandas==0.23.4
pypika==0.35.20
pypika==0.35.21
toposort==1.5
typing==3.6.2
python-dateutil==2.8.0
Expand Down

0 comments on commit 73a5123

Please sign in to comment.