Skip to content

Commit

Permalink
Merge pull request #29 from gisce/imp_custom_joins
Browse files Browse the repository at this point in the history
Support to use custom joins in select and where.
  • Loading branch information
ecarreras committed Sep 17, 2019
2 parents b0ecf23 + 64554a8 commit 170d72c
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 10 deletions.
46 changes: 46 additions & 0 deletions ooquery/operators.py
Expand Up @@ -27,3 +27,49 @@ class Not(OOOperator, operators.Not):
'&': And,
'!': Not
}


class JoinType(object):
type_ = None
__slots__ = ('field', )

def __init__(self, field):
self.field = field

def __repr__(self):
return '{}({})'.format(self.type_, self.field)

def __str__(self):
return self.field


class InnerJoin(JoinType):
type_ = 'INNER'


class LeftJoin(JoinType):
type_ = 'LEFT'


class LeftOuterJoin(JoinType):
type_ = 'LEFT OUTER'


class RightJoin(JoinType):
type_ = 'RIGHT'


class RightOuterJoin(JoinType):
type_ = 'RIGHT OUTER'


class FullJoin(JoinType):
type_ = 'FULL'


class FullOuterJoin(JoinType):
type_ = 'FULL OUTER'


class CrossJoin(JoinType):
type_ = 'CROSS'
26 changes: 19 additions & 7 deletions ooquery/parser.py
Expand Up @@ -10,6 +10,7 @@


class Parser(object):

def __init__(self, table, foreign_key=None):
self.operators = OPERATORS_MAP
self.table = table
Expand All @@ -31,37 +32,43 @@ def join_on(self):
def get_field_from_table(self, table, field):
return getattr(table, field)

def get_field_from_related_table(self, join_path_list, field_name):
self.parse_join(join_path_list)
def get_field_from_related_table(self, join_path_list, field_name, join_type='INNER'):
self.parse_join(join_path_list, join_type)
path = '.'.join(join_path_list)
join = self.joins_map.get(path)
if join:
table = join.right
return self.get_field_from_table(table, field_name)

def get_table_field(self, table, field):
if isinstance(field, JoinType):
join_type = field.type_
field = field.field
else:
join_type = 'INNER'
if '.' in field:
return self.get_field_from_related_table(
field.split('.')[:-1], field.split('.')[-1]
field.split('.')[:-1], field.split('.')[-1],
join_type
)
else:
return self.get_field_from_table(table, field)

def parse_join(self, fields_join):
def parse_join(self, fields_join, join_type):
table = self.table
self.join_path = []
for field_join in fields_join:
self.join_path.append(field_join)
fk = self.foreign_key(table._name, field_join)
table_join = Table(fk['foreign_table_name'])
join = Join(self.join_on, table_join)
join = Join(self.join_on, table_join, type_=join_type)
column = getattr(table, fk['column_name'])
fk_col = getattr(join.right, fk['foreign_column_name'])
join.condition = Equal(column, fk_col)
dotted_path = '.'.join(self.join_path)
join = self.get_join(dotted_path)
if not join:
join = self.join_on.join(table_join)
join = self.join_on.join(table_join, type_=join_type)
join.condition = Equal(column, fk_col)
self.joins_map[dotted_path] = join
self.joins.append(join)
Expand All @@ -88,10 +95,15 @@ def get_expressions(self, expression):

for idx, field in enumerate(fields):
columns.append(self.get_table_field(self.table, field))
if isinstance(field, JoinType):
join_type = expression[0].type_
field = field.field
else:
join_type = 'INNER'
if '.' in field:
fields_join = field.split('.')[:-1]
field_join = field.split('.')[-1]
self.parse_join(fields_join)
self.parse_join(fields_join, join_type)
join = self.joins_map['.'.join(field.split('.')[:-1])]
columns[idx] = self.get_table_field(join.right, field_join)

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Expand Up @@ -2,7 +2,7 @@

setup(
name='ooquery',
version='0.19.0',
version='0.20.0-rc1',
packages=find_packages(),
url='https://github.com/gisce/ooquery',
license='MIT',
Expand Down
45 changes: 43 additions & 2 deletions spec/ooquery_spec.py
@@ -1,6 +1,7 @@
# coding=utf-8
from ooquery import OOQuery
from ooquery.expression import Field
from ooquery.operators import *
from sql import Table, Literal, NullsFirst, NullsLast
from sql.operators import And, Concat
from sql.aggregate import Max
Expand Down Expand Up @@ -28,7 +29,7 @@
sel.where = And((t.field3 == 4,))
expect(tuple(sql)).to(equal(tuple(sel)))

with it('should have where mehtod and compare two fields of the table'):
with it('should have where method and compare two fields of the table'):
q = OOQuery('table')
sql = q.select(['field1', 'field2']).where([('field3', '>', Field('field4'))])
t = Table('table')
Expand Down Expand Up @@ -61,7 +62,6 @@ def dummy_fk(table, field):
sel.where = And((join.left.field1 == join.right.name,))
expect(tuple(sql)).to(equal(tuple(sel)))


with it('must support joins'):
def dummy_fk(table, field):
fks = {
Expand Down Expand Up @@ -456,3 +456,44 @@ def dummy_fk(table, field):
('parent_id.ean13', '=', '3020178572427')
])
expect(q.parser).to(not_(equal(parser)))

with it('must support different joins'):

def dummy_fk(table, field):
fks = {
'table_2': {
'constraint_name': 'fk_contraint_name_3',
'table_name': 'table',
'column_name': 'table_2',
'foreign_table_name': 'table2',
'foreign_column_name': 'id'
},
'table_3': {
'constraint_name': 'fk_contraint_name_3',
'table_name': 'table',
'column_name': 'table_3',
'foreign_table_name': 'table3',
'foreign_column_name': 'id'
}
}
return fks[field]

q = OOQuery('table', dummy_fk)
sql = q.select(
['field1', 'field2', LeftJoin('table_2.name')],
).where([
(LeftOuterJoin('table_3.code'), '=', 'XXX')
])
t = Table('table')
t2 = Table('table2')
t3 = Table('table3')

join = t.join(t2, type_='LEFT')
join.condition = join.left.table_2 == join.right.id

join2 = join.join(t3, type_='LEFT OUTER')
join2.condition = join.left.table_3 == join2.right.id

sel = join2.select(t.field1.as_('field1'), t.field2.as_('field2'), t2.name.as_('table_2.name'))
sel.where = And((join2.right.code == 'XXX',))
expect(tuple(sql)).to(equal(tuple(sel)))

0 comments on commit 170d72c

Please sign in to comment.