Skip to content
This repository has been archived by the owner on Mar 28, 2019. It is now read-only.

Commit

Permalink
Use SQLAlchemy instead of raw psycopg2
Browse files Browse the repository at this point in the history
SQLA presents several advantages:

* It has advanced pooling features (like connection refresh/recycle)
* It will allow per-request transactions, using pyramid-tm
* It will allow cliquet applications to implement custom storage backends using SQLAlchemy tools

However, it **does not** introduce:

* SQL abstraction or ORM
* Heterogeneous backends like SQLite or whatever
  • Loading branch information
leplatrem committed Oct 21, 2015
1 parent ac87071 commit 97c2478
Show file tree
Hide file tree
Showing 14 changed files with 201 additions and 201 deletions.
4 changes: 2 additions & 2 deletions cliquet/__init__.py
Expand Up @@ -27,8 +27,8 @@
'backoff': None,
'batch_max_requests': 25,
'cache_backend': '',
'cache_pool_size': 10,
'cache_url': '',
'cache_pool_size': 10,
'cors_origins': '*',
'cors_max_age_seconds': 3600,
'eos': None,
Expand Down Expand Up @@ -69,9 +69,9 @@
'statsd_prefix': 'cliquet',
'statsd_url': None,
'storage_backend': '',
'storage_url': '',
'storage_max_fetch_size': 10000,
'storage_pool_size': 10,
'storage_url': '',
'userid_hmac_secret': '',
'version_prefix_redirect_enabled': True,
'trailing_slash_redirect_enabled': True,
Expand Down
12 changes: 6 additions & 6 deletions cliquet/cache/postgresql/__init__.py
Expand Up @@ -74,9 +74,9 @@ def ttl(self, key):
AND ttl IS NOT NULL;
"""
with self.client.connect() as cursor:
cursor.execute(query, (key,))
if cursor.rowcount > 0:
return cursor.fetchone()['ttl']
result = cursor.execute(query, (key,))
if result.rowcount > 0:
return result.fetchone()['ttl']
return -1

def expire(self, key, ttl):
Expand Down Expand Up @@ -105,9 +105,9 @@ def get(self, key):
query = "SELECT value FROM cache WHERE key = %s;"
with self.client.connect() as cursor:
cursor.execute(purge)
cursor.execute(query, (key,))
if cursor.rowcount > 0:
value = cursor.fetchone()['value']
result = cursor.execute(query, (key,))
if result.rowcount > 0:
value = result.fetchone()['value']
return json.loads(value)

def delete(self, key):
Expand Down
28 changes: 14 additions & 14 deletions cliquet/permission/postgresql/__init__.py
Expand Up @@ -95,8 +95,8 @@ def user_principals(self, user_id):
FROM user_principals
WHERE user_id = %(user_id)s;"""
with self.client.connect() as cursor:
cursor.execute(query, dict(user_id=user_id))
results = cursor.fetchall()
result = cursor.execute(query, dict(user_id=user_id))
results = result.fetchall()
return set([r['principal'] for r in results])

def add_principal_to_ace(self, object_id, permission, principal):
Expand Down Expand Up @@ -133,9 +133,9 @@ def object_permission_principals(self, object_id, permission):
WHERE object_id = %(object_id)s
AND permission = %(permission)s;"""
with self.client.connect() as cursor:
cursor.execute(query, dict(object_id=object_id,
permission=permission))
results = cursor.fetchall()
result = cursor.execute(query, dict(object_id=object_id,
permission=permission))
results = result.fetchall()
return set([r['principal'] for r in results])

def object_permission_authorized_principals(self, object_id, permission,
Expand All @@ -159,8 +159,8 @@ def object_permission_authorized_principals(self, object_id, permission,
ON (object_id = column1 AND permission = column2);
""" % perms_values
with self.client.connect() as cursor:
cursor.execute(query)
results = cursor.fetchall()
result = cursor.execute(query)
results = result.fetchall()
return set([r['principal'] for r in results])

def principals_accessible_objects(self, principals, permission,
Expand Down Expand Up @@ -202,8 +202,8 @@ def principals_accessible_objects(self, principals, permission,
principals=principals_values)

with self.client.connect() as cursor:
cursor.execute(query, placeholders)
results = cursor.fetchall()
result = cursor.execute(query, placeholders)
results = result.fetchall()
return set([r['object_id'] for r in results])

def check_permission(self, object_id, permission, principals,
Expand Down Expand Up @@ -236,9 +236,9 @@ def check_permission(self, object_id, permission, principals,
""" % dict(perms=perms_values, principals=principals_values)

with self.client.connect() as cursor:
cursor.execute(query)
result = cursor.fetchone()
return result['matched'] > 0
result = cursor.execute(query)
total = result.fetchone()
return total['matched'] > 0

def object_permissions(self, object_id, permissions=None):
query = """
Expand All @@ -252,8 +252,8 @@ def object_permissions(self, object_id, permissions=None):
AND permission IN %(permissions)s;"""
placeholders["permissions"] = tuple(permissions)
with self.client.connect() as cursor:
cursor.execute(query, placeholders)
results = cursor.fetchall()
result = cursor.execute(query, placeholders)
results = result.fetchall()
permissions = defaultdict(set)
for r in results:
permissions[r['permission']].add(r['principal'])
Expand Down
92 changes: 46 additions & 46 deletions cliquet/storage/postgresql/__init__.py
Expand Up @@ -105,9 +105,9 @@ def _check_database_timezone(self):
# Make sure database has UTC timezone.
query = "SELECT current_setting('TIMEZONE') AS timezone;"
with self.client.connect() as cursor:
cursor.execute(query)
result = cursor.fetchone()
timezone = result['timezone'].upper()
result = cursor.execute(query)
record = result.fetchone()
timezone = record['timezone'].upper()
if timezone != 'UTC': # pragma: no cover
msg = 'Database timezone is not UTC (%s)' % timezone
warnings.warn(msg)
Expand All @@ -121,18 +121,18 @@ def _check_database_encoding(self):
WHERE datname = current_database();
"""
with self.client.connect() as cursor:
cursor.execute(query)
result = cursor.fetchone()
encoding = result['encoding'].lower()
result = cursor.execute(query)
record = result.fetchone()
encoding = record['encoding'].lower()
assert encoding == 'utf8', 'Unexpected database encoding %s' % encoding

def _get_installed_version(self):
"""Return current version of schema or None if not any found.
"""
query = "SELECT tablename FROM pg_tables WHERE tablename = 'metadata';"
with self.client.connect() as cursor:
cursor.execute(query)
tables_exist = cursor.rowcount > 0
result = cursor.execute(query)
tables_exist = result.rowcount > 0

if not tables_exist:
return
Expand All @@ -144,14 +144,14 @@ def _get_installed_version(self):
ORDER BY value DESC;
"""
with self.client.connect() as cursor:
cursor.execute(query)
if cursor.rowcount > 0:
return int(cursor.fetchone()['version'])
result = cursor.execute(query)
if result.rowcount > 0:
return int(result.fetchone()['version'])
else:
# Guess current version.
query = "SELECT COUNT(*) FROM metadata;"
cursor.execute(query)
was_flushed = int(cursor.fetchone()[0]) == 0
result = cursor.execute(query)
was_flushed = int(result.fetchone()[0]) == 0
if was_flushed:
error_msg = 'Missing schema history: consider version %s.'
logger.warning(error_msg % self.schema_version)
Expand Down Expand Up @@ -180,9 +180,9 @@ def collection_timestamp(self, collection_id, parent_id, auth=None):
"""
placeholders = dict(parent_id=parent_id, collection_id=collection_id)
with self.client.connect(readonly=True) as cursor:
cursor.execute(query, placeholders)
result = cursor.fetchone()
return result['last_modified']
result = cursor.execute(query, placeholders)
record = result.fetchone()
return record['last_modified']

def create(self, collection_id, parent_id, record, id_generator=None,
unique_fields=None, id_field=DEFAULT_ID_FIELD,
Expand All @@ -207,8 +207,8 @@ def create(self, collection_id, parent_id, record, id_generator=None,
self._check_unicity(cursor, collection_id, parent_id, record,
unique_fields, id_field, modified_field,
for_creation=True)
cursor.execute(query, placeholders)
inserted = cursor.fetchone()
result = cursor.execute(query, placeholders)
inserted = result.fetchone()

record[modified_field] = inserted['last_modified']
return record
Expand All @@ -228,15 +228,15 @@ def get(self, collection_id, parent_id, object_id,
parent_id=parent_id,
collection_id=collection_id)
with self.client.connect(readonly=True) as cursor:
cursor.execute(query, placeholders)
if cursor.rowcount == 0:
result = cursor.execute(query, placeholders)
if result.rowcount == 0:
raise exceptions.RecordNotFoundError(object_id)
else:
result = cursor.fetchone()
existing = result.fetchone()

record = result['data']
record = existing['data']
record[id_field] = object_id
record[modified_field] = result['last_modified']
record[modified_field] = existing['last_modified']
return record

def update(self, collection_id, parent_id, object_id, record,
Expand Down Expand Up @@ -276,13 +276,13 @@ def update(self, collection_id, parent_id, object_id, record,
AND parent_id = %(parent_id)s
AND collection_id = %(collection_id)s;
"""
cursor.execute(query, placeholders)
query = query_update if cursor.rowcount > 0 else query_create
result = cursor.execute(query, placeholders)
query = query_update if result.rowcount > 0 else query_create

cursor.execute(query, placeholders)
result = cursor.fetchone()
result = cursor.execute(query, placeholders)
updated = result.fetchone()

record[modified_field] = result['last_modified']
record[modified_field] = updated['last_modified']
return record

def delete(self, collection_id, parent_id, object_id,
Expand Down Expand Up @@ -319,10 +319,10 @@ def delete(self, collection_id, parent_id, object_id,
collection_id=collection_id)

with self.client.connect() as cursor:
cursor.execute(query, placeholders)
if cursor.rowcount == 0:
result = cursor.execute(query, placeholders)
if result.rowcount == 0:
raise exceptions.RecordNotFoundError(object_id)
inserted = cursor.fetchone()
inserted = result.fetchone()

record = {}
record[modified_field] = inserted['last_modified']
Expand Down Expand Up @@ -375,11 +375,11 @@ def delete_all(self, collection_id, parent_id, filters=None,
placeholders.update(**holders)

with self.client.connect() as cursor:
cursor.execute(query % safeholders, placeholders)
results = cursor.fetchmany(self._max_fetch_size)
result = cursor.execute(query % safeholders, placeholders)
deleted = result.fetchmany(self._max_fetch_size)

records = []
for result in results:
for result in deleted:
record = {}
record[id_field] = result['id']
record[modified_field] = result['last_modified']
Expand Down Expand Up @@ -412,9 +412,9 @@ def purge_deleted(self, collection_id, parent_id, before=None,
placeholders['before'] = before

with self.client.connect() as cursor:
cursor.execute(query % safeholders, placeholders)
result = cursor.execute(query % safeholders, placeholders)

return cursor.rowcount
return result.rowcount

def get_all(self, collection_id, parent_id, filters=None, sorting=None,
pagination_rules=None, limit=None, include_deleted=False,
Expand Down Expand Up @@ -504,16 +504,16 @@ def get_all(self, collection_id, parent_id, filters=None, sorting=None,
safeholders['pagination_limit'] = 'LIMIT %s' % limit

with self.client.connect(readonly=True) as cursor:
cursor.execute(query % safeholders, placeholders)
results = cursor.fetchmany(self._max_fetch_size)
result = cursor.execute(query % safeholders, placeholders)
retrieved = result.fetchmany(self._max_fetch_size)

if not len(results):
if not len(retrieved):
return [], 0

count_total = results[0]['count_total']
count_total = retrieved[0]['count_total']

records = []
for result in results:
for result in retrieved:
record = result['data']
record[id_field] = result['id']
record[modified_field] = result['last_modified']
Expand Down Expand Up @@ -699,11 +699,11 @@ def _check_unicity(self, cursor, collection_id, parent_id, record,
else:
safeholders['condition_record'] = 'TRUE'

cursor.execute(query % safeholders, placeholders)
if cursor.rowcount > 0:
result = cursor.fetchone()
existing = self.get(collection_id, parent_id, result['id'])
raise exceptions.UnicityError(unique_fields[0], existing)
result = cursor.execute(query % safeholders, placeholders)
if result.rowcount > 0:
existing = result.fetchone()
record = self.get(collection_id, parent_id, existing['id'])
raise exceptions.UnicityError(unique_fields[0], record)


def load_from_config(config):
Expand Down

0 comments on commit 97c2478

Please sign in to comment.