Skip to content

Commit

Permalink
Merge pull request #77 from x8lucas8x/add-split-and-regex-like-functions
Browse files Browse the repository at this point in the history
Add split and regexp_like functions for different dialects
  • Loading branch information
x8lucas8x committed Oct 26, 2017
2 parents 4593eac + 2b121e2 commit 916c9a2
Show file tree
Hide file tree
Showing 5 changed files with 171 additions and 14 deletions.
2 changes: 1 addition & 1 deletion pypika/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,4 +68,4 @@

__author__ = "Timothy Heys"
__email__ = "theys@kayak.com"
__version__ = "0.8.0"
__version__ = "0.9.0"
60 changes: 58 additions & 2 deletions pypika/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
"""
Package for SQL functions wrappers
"""
from pypika.enums import SqlTypes
from pypika.terms import Function, Star, AggregateFunction
from pypika.enums import (
SqlTypes,
Dialects,
)
from pypika.terms import (
Function,
Star,
AggregateFunction,
ValueWrapper,
)
from pypika.utils import builder

__author__ = "Timothy Heys"
Expand Down Expand Up @@ -179,6 +187,54 @@ def __init__(self, term, alias=None):
super(Trim, self).__init__('TRIM', term, alias=alias)


class SplitPart(Function):
def __init__(self, term, delimiter, index, alias=None):
super(SplitPart, self).__init__('SPLIT_PART', term, delimiter, index, alias=alias)

def get_name_for_dialect(self, dialect=None):
return {
Dialects.MYSQL: 'SUBSTRING_INDEX',
Dialects.POSTGRESQL: 'SPLIT_PART',
Dialects.REDSHIFT: 'SPLIT_PART',
Dialects.VERTICA: 'SPLIT_PART',
Dialects.ORACLE: 'REGEXP_SUBSTR',
}.get(dialect, None)

def get_args_for_dialect(self, dialect=None):
term, delimiter, index = self.args

return {
Dialects.MYSQL: (term, delimiter, index),
Dialects.POSTGRESQL: (term, delimiter, index),
Dialects.REDSHIFT: (term, delimiter, index),
Dialects.VERTICA: (term, delimiter, index),
Dialects.ORACLE: (term, ValueWrapper('[^{}]+'.format(delimiter.value)), 1, index)
}.get(dialect, None)


class RegexpLike(Function):
def __init__(self, term, pattern, modifiers, alias=None):
super(RegexpLike, self).__init__('REGEXP_LIKE', term, pattern, modifiers, alias=alias)

def get_name_for_dialect(self, dialect=None):
return {
Dialects.POSTGRESQL: 'REGEXP_MATCHES',
Dialects.REDSHIFT: 'REGEXP_MATCHES',
Dialects.VERTICA: 'REGEXP_LIKE',
Dialects.ORACLE: 'REGEXP_LIKE',
}.get(dialect, self.name)

def get_args_for_dialect(self, dialect=None):
term, pattern, modifiers = self.args

return {
Dialects.POSTGRESQL: (term, pattern, modifiers),
Dialects.REDSHIFT: (term, pattern, modifiers),
Dialects.VERTICA: (term, pattern, modifiers),
Dialects.ORACLE: (term, pattern, modifiers)
}.get(dialect, None)


# Date Functions
class Now(Function):
def __init__(self, alias=None):
Expand Down
36 changes: 33 additions & 3 deletions pypika/terms.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
)
from pypika.utils import (
CaseException,
DialectNotSupported,
alias_sql,
builder,
ignoredeepcopy,
Expand Down Expand Up @@ -654,19 +655,48 @@ def get_special_params_sql(self, **kwargs):
def get_function_sql(self, **kwargs):
special_params_sql = self.get_special_params_sql(**kwargs)

dialect = kwargs.get('dialect', None)
dialect_name = self.get_name_for_dialect(dialect=dialect)
dialect_args = self.get_args_for_dialect(dialect=dialect)

if dialect_name is None or dialect_args is None:
raise DialectNotSupported('The function {} has no support for {} dialect'.format(self.name, dialect))

return '{name}({args}{special})'.format(
name=self.name,
name=dialect_name,
args=','.join(p.get_sql(with_alias=False, **kwargs)
if hasattr(p, 'get_sql')
else str(p)
for p in self.args),
for p in dialect_args),
special=(' ' + special_params_sql) if special_params_sql else '',
)

def get_name_for_dialect(self, dialect=None):
"""
This function will transform the original function name into the equivalent for different dialects.
In practice this method should be overriden on subclasses whenever different dialects support is
required. Otherwise the original name will be used.
:param dialect: one of the options in the Dialects enum.
:return: the function name that should be used by the get_function_sql method when serializing.
"""
return self.name

def get_args_for_dialect(self, dialect=None):
"""
This function will transform the original function args into the equivalent for different dialects.
In practice this method should be overriden on subclasses whenever different dialects support is
required. Otherwise the original arguments will be used.
:param dialect: one of the options in the Dialects enum.
:return: the function args that should be used by the get_function_sql method when serializing.
"""
return self.args

def get_sql(self, with_alias=False, with_namespace=False, quote_char=None, **kwargs):
# FIXME escape

function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char)
function_sql = self.get_function_sql(with_namespace=with_namespace, quote_char=quote_char, **kwargs)

if not with_alias or self.alias is None:
return function_sql
Expand Down
83 changes: 75 additions & 8 deletions pypika/tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,24 @@
# coding: utf8
import unittest

from pypika import (Query as Q,
Table as T,
Field as F,
functions as fn,
CaseException,
Case,
Interval,
DatePart)
from pypika import (
Query as Q,
Table as T,
Field as F,
functions as fn,
CaseException,
Case,
Interval,
DatePart,
MySQLQuery,
VerticaQuery,
PostgreSQLQuery,
RedshiftQuery,
OracleQuery,
)
from pypika.enums import (SqlTypes,
Dialects)
from pypika.utils import DialectNotSupported

__author__ = "Timothy Heys"
__email__ = "theys@kayak.com"
Expand Down Expand Up @@ -358,6 +366,65 @@ def test__length__field(self):
self.assertEqual("SELECT LENGTH(\"foo\") FROM \"abc\"", str(q))


class SplitPartFunctionTests(unittest.TestCase):
t = T('abc')

def test__split_part__field_with_vertica_dialect(self):
q = VerticaQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SPLIT_PART(\"foo\",\'|\',3) FROM \"abc\"", str(q))

def test__split_part__field_with_mysql_dialect(self):
q = MySQLQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SUBSTRING_INDEX(`foo`,\'|\',3) FROM `abc`", str(q))

def test__split_part__field_with_postgresql_dialect(self):
q = PostgreSQLQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SPLIT_PART(\"foo\",\'|\',3) FROM \"abc\"", str(q))

def test__split_part__field_with_redshift_dialect(self):
q = RedshiftQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT SPLIT_PART(\"foo\",\'|\',3) FROM \"abc\"", str(q))

def test__split_part__field_with_oracle_dialect(self):
q = OracleQuery.from_(self.t).select(fn.SplitPart(self.t.foo, '|', 3))

self.assertEqual("SELECT REGEXP_SUBSTR(\"foo\",\'[^|]+\',1,3) FROM \"abc\"", str(q))


class RegexpLikeFunctionTests(unittest.TestCase):
t = T('abc')

def test__regexp_like__field_with_vertica_dialect(self):
q = VerticaQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_LIKE(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))

def test__regexp_like__field_with_mysql_dialect(self):
q = MySQLQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

with self.assertRaises(DialectNotSupported):
str(q)

def test__regexp_like__field_with_postgresql_dialect(self):
q = PostgreSQLQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_MATCHES(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))

def test__regexp_like__field_with_redshift_dialect(self):
q = RedshiftQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_MATCHES(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))

def test__regexp_like__field_with_oracle_dialect(self):
q = OracleQuery.from_(self.t).select(fn.RegexpLike(self.t.foo, '^a', 'x'))

self.assertEqual("SELECT REGEXP_LIKE(\"foo\",\'^a\',\'x\') FROM \"abc\"", str(q))


class CastTests(unittest.TestCase):
t = T('abc')

Expand Down
4 changes: 4 additions & 0 deletions pypika/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@ class RollupException(Exception):
pass


class DialectNotSupported(Exception):
pass


def builder(func):
"""
Decorator for wrapper "builder" functions. These are functions on the Query class or other classes used for
Expand Down

0 comments on commit 916c9a2

Please sign in to comment.