Skip to content

Commit

Permalink
fix hasadna/open-bus#387 - ensure closing of DB sessions
Browse files Browse the repository at this point in the history
  • Loading branch information
OriHoch committed Mar 11, 2022
1 parent 9161cf0 commit 7bba3db
Showing 1 changed file with 33 additions and 33 deletions.
66 changes: 33 additions & 33 deletions open_bus_stride_api/routers/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,45 +5,45 @@


def get_list(*args, convert_to_dict=None, **kwargs):
if convert_to_dict is None:
return [obj.__dict__ for obj in get_list_query(*args, **kwargs)]
else:
return [convert_to_dict(obj) for obj in get_list_query(*args, **kwargs)]
with get_session() as session:
if convert_to_dict is None:
return [obj.__dict__ for obj in get_list_query(session, *args, **kwargs)]
else:
return [convert_to_dict(obj) for obj in get_list_query(session, *args, **kwargs)]


def get_list_query(db_model, limit, offset, filters=None, max_limit=1000,
def get_list_query(session, db_model, limit, offset, filters=None, max_limit=1000,
order_by=None, allowed_order_by_fields=None,
post_session_query_hook=None):
if filters is None:
filters = []
with get_session() as session:
session_query = session.query(db_model)
if post_session_query_hook:
session_query = post_session_query_hook(session_query)
for filter in filters:
session_query = globals()['get_list_query_filter_{}'.format(filter['type'])](session_query, filters, filter)
if order_by:
order_by_args = []
for ob in order_by.split(','):
ob = ob.strip()
if not ob:
continue
ob = ob.split()
if len(ob) == 1:
field_name = ob[0]
direction = None
else:
field_name, direction = ob
assert not allowed_order_by_fields or field_name in allowed_order_by_fields, 'field name is not in allowed order_by fields: {}'.format(field_name)
order_by_args.append((sqlalchemy.desc if direction == 'desc' else sqlalchemy.asc)(getattr(db_model, field_name)))
session_query = session_query.order_by(*order_by_args)
if not limit and max_limit:
limit = max_limit
if limit:
session_query = session_query.limit(limit)
if offset:
session_query = session_query.offset(offset)
return session_query
session_query = session.query(db_model)
if post_session_query_hook:
session_query = post_session_query_hook(session_query)
for filter in filters:
session_query = globals()['get_list_query_filter_{}'.format(filter['type'])](session_query, filters, filter)
if order_by:
order_by_args = []
for ob in order_by.split(','):
ob = ob.strip()
if not ob:
continue
ob = ob.split()
if len(ob) == 1:
field_name = ob[0]
direction = None
else:
field_name, direction = ob
assert not allowed_order_by_fields or field_name in allowed_order_by_fields, 'field name is not in allowed order_by fields: {}'.format(field_name)
order_by_args.append((sqlalchemy.desc if direction == 'desc' else sqlalchemy.asc)(getattr(db_model, field_name)))
session_query = session_query.order_by(*order_by_args)
if not limit and max_limit:
limit = max_limit
if limit:
session_query = session_query.limit(limit)
if offset:
session_query = session_query.offset(offset)
return session_query


def get_list_query_filter_equals(session_query, filters, filter):
Expand Down

0 comments on commit 7bba3db

Please sign in to comment.