diff --git a/docs/source/orm/database.rst b/docs/source/orm/database.rst index ddbafed..fa0b6cb 100644 --- a/docs/source/orm/database.rst +++ b/docs/source/orm/database.rst @@ -254,3 +254,34 @@ corresponding to the projection you actually use:: datetime.timedelta(-1, 33857, 32595) >>> obj.birthday datetime.datetime(2015, 5, 1, 14, 35, 42, 967405) + + +Update +______ + +Updating is similar o insert but the main difference is that when you +commonly insert a single row, when you update a table, you can update +a lot of rows in a single query on the database. + +To reflect this, the syntax of update is where clause, updated column +and parameters for the where. For example, if you want to change all +the example_projection object where data is "hello" to goodbye, you +will write:: + + >>> example_projection.update("data=%s", {"data": "goodbye"}, ("hello",)) + +obviously you can use a Where object to make things more readable: + + >>> example_projection.update(Where("data", "=" "%s"), + ... {"data": "goodbye"}, ("hello",)) + +Last bt not least, like with insert you can ask the database for +returning:: + + >>> example_projection.update(Where("data", "=" "%s"), + ... {"data": "goodbye"}, ("hello",), returning="id, num, data") + +or + + >>> example_projection.update(Where("data", "=" "%s"), + ... {"data": "goodbye"}, ("hello",), returning="self") diff --git a/drunken_boat/db/postgresql/__init__.py b/drunken_boat/db/postgresql/__init__.py index fd25a02..cff6dbd 100644 --- a/drunken_boat/db/postgresql/__init__.py +++ b/drunken_boat/db/postgresql/__init__.py @@ -5,6 +5,14 @@ DropError) +def field_is_nullable(field): + """ + Parse a field dictionnary and check if a nullable value is + acceptable by the database + """ + return field["is_nullable"] != "NO" or field["column_default"] is not None + + class DB(DatabaseWrapper): def __init__(self, **conn_params): diff --git a/drunken_boat/db/postgresql/projections.py b/drunken_boat/db/postgresql/projections.py index cec7d12..229f0cc 100644 --- a/drunken_boat/db/postgresql/projections.py +++ b/drunken_boat/db/postgresql/projections.py @@ -1,6 +1,7 @@ import os from drunken_boat.db.postgresql.fields import Field from drunken_boat.db.postgresql.query import Query, Where +from drunken_boat.db.postgresql import field_is_nullable from drunken_boat.db.exceptions import NotFoundError @@ -49,31 +50,37 @@ def get_joins(self, table, *args, **kwargs): for f in field.projection(self.db).fields]) return " ".join(joins) + def get_query_from(self, *args, **kwargs): + where = kwargs.get("where") + if isinstance(where, Where): + where = where() + if not where: + where = self.get_where(*args, **kwargs) + table = self.get_table(*args, **kwargs) + joins = self.get_joins(table, *args, **kwargs) + return "{} {} {}".format( + table, + joins if joins else '', + "WHERE {}".format(where) if where else '' + ) + def select(self, lazy=False, *args, **kwargs): if kwargs.get("query"): # if a query is already given, just use this one query = kwargs["query"] else: - where = kwargs.get("where") - - if isinstance(where, Where): - where = where() - if not where: - where = self.get_where(*args, **kwargs) - table = self.get_table(*args, **kwargs) - joins = self.get_joins(table, *args, **kwargs) + query_from = self.get_query_from(self, *args, **kwargs) fields = [] for field in self.fields: select = field.get_select() if select: fields.append(select) select_query = ", ".join(fields) - query = "SELECT {} FROM {} {} {}".format( + query = "SELECT {} FROM {}".format( select_query, - table, - joins if joins else '', - "WHERE {}".format(where) if where else '' + query_from ) + Q = Query(self, query, params=kwargs.pop("params", None), **kwargs.get("related", {})) @@ -127,6 +134,23 @@ def __init__(self, DB): def get_where(self, *args, **kwargs): pass + def check_constrains(self, values): + """ + For each field on the table check if a value is provided. If no + value is provided, ensure the field is nullable or a default + value is provided. + """ + errors = [] + for field in self.get_table_fields: + if not values.get( + field["column_name"]) and not field_is_nullable(field): + errors.append("{} of type {} is required".format( + field["column_name"], + field["data_type"] + )) + if len(errors) != 0: + raise ValueError("\n".join(errors)) + def get_table(self, *args, **kwargs): if hasattr(self, "Meta"): if hasattr(self.Meta, "table"): @@ -136,6 +160,11 @@ def get_table(self, *args, **kwargs): @property def get_table_fields(self): + """ + Introspect the table to get all the fields of the table and + retreive column_name, data_type, is_nullable and + column_default. + """ result = [] if not hasattr(self, "schema"): schema = "public" @@ -152,35 +181,39 @@ def get_table_fields(self): result.append(dict(zip(fields, elem))) return result - def insert(self, values, returning=None): + def make_returning(self, sql_template, returning=None): + if returning: + returning_fields = [] + if returning == "self": + returning_fields = [f.get_select() for f in self.fields] + else: + returning_fields = [returning] + + if not len(returning_fields) == 0: + sql_template += "returning {}".format( + ",".join(returning_fields)) + return sql_template + + def insert(self, values, returning=None, **kwargs): + """ + Insert a new row into the table checking for constraints. If + returning is set, return the corresponding column(s). If the + special "self" is given to returning, return the + DatabaseObject used by this Projection + """ if not values: raise ValueError("values parameter cannot be an empty dict") - db_fields = self.get_table_fields - errors = [] - for fields in db_fields: - if not values.get(fields["column_name"]) and \ - fields["is_nullable"] == "NO" and \ - fields["column_default"] is None: - errors.append("{} of type {} is required".format( - fields["column_name"], - fields["data_type"] - )) - if len(errors) != 0: - raise ValueError("\n".join(errors)) + + self.check_constrains(values) keys = [] vals = [] - for k, v in values.items(): keys.append(k) vals.append(v) sql_template = "insert into {} ({}) VALUES ({})" - if returning: - if returning == "self": - sql_template += "returning {}".format( - ", ".join([f.db_name for f in self.fields])) - else: - sql_template += "returning {}".format(returning) + sql_template = self.make_returning(sql_template, returning) + sql = sql_template.format( self.Meta.table, ", ".join(tuple(keys)), @@ -188,6 +221,7 @@ def insert(self, values, returning=None): ) params = vals res = None + with self.db.cursor() as cur: try: cur.execute(sql, params) @@ -198,5 +232,52 @@ def insert(self, values, returning=None): res = cur.fetchone() self.db.conn.commit() if returning == "self": - return self.hydrate(res)[0] + results = [self.hydrate(res)[0]] + for field in self.fields: + if hasattr(field, "extra"): + results = field.extra(self, results) + return results[0] return res + + def update(self, where, values, where_params, returning=None, **kwargs): + args = [] + params = [] + for k, v in values.items(): + args.append(k) + params.append(v) + [params.append(p) for p in where_params] + if isinstance(where, Where): + where = where() + joins = self.get_joins(self.Meta.table) + sql_template = """ + UPDATE {table} SET {args} {joins} WHERE {where}""".format( + table=self.Meta.table, + joins=joins if joins else '', + args=", ".join(["{}=%s".format(arg) for arg in args]), + where=where + ) + sql_template = self.make_returning(sql_template, returning) + res = None + with self.db.cursor() as cur: + try: + cur.execute(sql_template, params) + except Exception as e: + self.db.conn.rollback() + raise e + if returning: + res = cur.fetchall() + self.db.conn.commit() + results = [] + if res: + if returning == "self": + for result in res: + results.append( + self.hydrate(result)[0] + ) + for field in self.fields: + if hasattr(field, "extra"): + results = field.extra(self, results) + else: + results = res + + return results diff --git a/drunken_boat/db/postgresql/tests/projections_fixtures.py b/drunken_boat/db/postgresql/tests/projections_fixtures.py index 5a9b121..8bca8db 100644 --- a/drunken_boat/db/postgresql/tests/projections_fixtures.py +++ b/drunken_boat/db/postgresql/tests/projections_fixtures.py @@ -58,6 +58,7 @@ class Meta: class AuthorProjectionReverse(Projection): + name = CharField() books = ReverseForeign( join=["id", "author_id"], projection=BookProjectionReverse @@ -65,3 +66,14 @@ class AuthorProjectionReverse(Projection): class Meta: table = "author" + + +class AuthorProjectionReverseEm(Projection): + + books = ReverseForeign( + join=["id", "author_id"], + projection=BookProjection + ) + + class Meta: + table = "author" diff --git a/drunken_boat/db/postgresql/tests/test_db_projection_integration.py b/drunken_boat/db/postgresql/tests/test_db_projection_integration.py index 5b2685e..3248456 100644 --- a/drunken_boat/db/postgresql/tests/test_db_projection_integration.py +++ b/drunken_boat/db/postgresql/tests/test_db_projection_integration.py @@ -1,4 +1,4 @@ -from psycopg2 import DataError +from psycopg2 import DataError, ProgrammingError import pytest import datetime from drunken_boat.db.exceptions import NotFoundError @@ -200,6 +200,19 @@ def test_projection_inserting(prepare_test): assert isinstance(doc, projections_fixtures.DataBaseObjectWithMeth) +def test_projection_update(prepare_test): + projection = projections_fixtures.TestProjectionWithVirtual(get_test_db()) + pytest.raises(ProgrammingError, projection.update, + Where("id", "=", "%s"), + {"godawa": "Kaboom"}, + (4,)) + p = projection.update(Where("id", "=", "%s"), {"title": "Kaboom"}, (4,)) + assert p == [] + p = projection.update(Where("id", "=", "%s"), + {"title": "hello"}, (4,), returning="id") + assert p == [(4,)] + + def test_projection_foreign(prepare_test): projection_author = projections_fixtures.AuthorProjection(get_test_db()) projection_author.insert( @@ -220,6 +233,35 @@ def test_projection_reverse(prepare_test): assert author.id == book.author_id +def test_projection_reverse_insert(prepare_test): + projection_author = projections_fixtures.AuthorProjectionReverse( + get_test_db()) + p = projection_author.insert({"name": "hello"}, returning="self") + assert p.name == "hello" + assert hasattr(p, "books") + + projection_author_empty = projections_fixtures.AuthorProjectionReverseEm( + get_test_db()) + p = projection_author_empty.insert({"name": "hello"}, returning="self") + assert isinstance(p, DataBaseObject) + + +def test_projection_reverse_update(prepare_test): + projection_author = projections_fixtures.AuthorProjectionReverse( + get_test_db()) + p = projection_author.update("name=%s", {"name": "champomy"}, ("hello",), + returning="self")[0] + assert p.name == "champomy" + + projection_author_empty = projections_fixtures.AuthorProjectionReverseEm( + get_test_db()) + + p = projection_author_empty.update( + "name=%s", {"name": "hello"}, ("champomy",), + returning="self") + assert p[0].id + + def test_projection_reverse_with_params(prepare_test): projection_author = projections_fixtures.AuthorProjectionReverse( get_test_db())