Skip to content

Commit

Permalink
Merge 7a2e8d1 into 259af57
Browse files Browse the repository at this point in the history
  • Loading branch information
boblefrag committed May 11, 2015
2 parents 259af57 + 7a2e8d1 commit ee456d2
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 34 deletions.
31 changes: 31 additions & 0 deletions docs/source/orm/database.rst
Original file line number Diff line number Diff line change
Expand Up @@ -254,3 +254,34 @@ corresponding to the projection you actually use::
datetime.timedelta(-1, 33857, 32595)
>>> obj.birthday
datetime.datetime(2015, 5, 1, 14, 35, 42, 967405)


Update
______

Updating is similar o insert but the main difference is that when you
commonly insert a single row, when you update a table, you can update
a lot of rows in a single query on the database.

To reflect this, the syntax of update is where clause, updated column
and parameters for the where. For example, if you want to change all
the example_projection object where data is "hello" to goodbye, you
will write::

>>> example_projection.update("data=%s", {"data": "goodbye"}, ("hello",))

obviously you can use a Where object to make things more readable:

>>> example_projection.update(Where("data", "=" "%s"),
... {"data": "goodbye"}, ("hello",))

Last bt not least, like with insert you can ask the database for
returning::

>>> example_projection.update(Where("data", "=" "%s"),
... {"data": "goodbye"}, ("hello",), returning="id, num, data")

or

>>> example_projection.update(Where("data", "=" "%s"),
... {"data": "goodbye"}, ("hello",), returning="self")
8 changes: 8 additions & 0 deletions drunken_boat/db/postgresql/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,14 @@
DropError)


def field_is_nullable(field):
"""
Parse a field dictionnary and check if a nullable value is
acceptable by the database
"""
return field["is_nullable"] != "NO" or field["column_default"] is not None


class DB(DatabaseWrapper):

def __init__(self, **conn_params):
Expand Down
147 changes: 114 additions & 33 deletions drunken_boat/db/postgresql/projections.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from drunken_boat.db.postgresql.fields import Field
from drunken_boat.db.postgresql.query import Query, Where
from drunken_boat.db.postgresql import field_is_nullable
from drunken_boat.db.exceptions import NotFoundError


Expand Down Expand Up @@ -49,31 +50,37 @@ def get_joins(self, table, *args, **kwargs):
for f in field.projection(self.db).fields])
return " ".join(joins)

def get_query_from(self, *args, **kwargs):
where = kwargs.get("where")
if isinstance(where, Where):
where = where()
if not where:
where = self.get_where(*args, **kwargs)
table = self.get_table(*args, **kwargs)
joins = self.get_joins(table, *args, **kwargs)
return "{} {} {}".format(
table,
joins if joins else '',
"WHERE {}".format(where) if where else ''
)

def select(self, lazy=False, *args, **kwargs):
if kwargs.get("query"):
# if a query is already given, just use this one
query = kwargs["query"]
else:
where = kwargs.get("where")

if isinstance(where, Where):
where = where()
if not where:
where = self.get_where(*args, **kwargs)
table = self.get_table(*args, **kwargs)
joins = self.get_joins(table, *args, **kwargs)
query_from = self.get_query_from(self, *args, **kwargs)
fields = []
for field in self.fields:
select = field.get_select()
if select:
fields.append(select)
select_query = ", ".join(fields)
query = "SELECT {} FROM {} {} {}".format(
query = "SELECT {} FROM {}".format(
select_query,
table,
joins if joins else '',
"WHERE {}".format(where) if where else ''
query_from
)

Q = Query(self, query,
params=kwargs.pop("params", None),
**kwargs.get("related", {}))
Expand Down Expand Up @@ -127,6 +134,23 @@ def __init__(self, DB):
def get_where(self, *args, **kwargs):
pass

def check_constrains(self, values):
"""
For each field on the table check if a value is provided. If no
value is provided, ensure the field is nullable or a default
value is provided.
"""
errors = []
for field in self.get_table_fields:
if not values.get(
field["column_name"]) and not field_is_nullable(field):
errors.append("{} of type {} is required".format(
field["column_name"],
field["data_type"]
))
if len(errors) != 0:
raise ValueError("\n".join(errors))

def get_table(self, *args, **kwargs):
if hasattr(self, "Meta"):
if hasattr(self.Meta, "table"):
Expand All @@ -136,6 +160,11 @@ def get_table(self, *args, **kwargs):

@property
def get_table_fields(self):
"""
Introspect the table to get all the fields of the table and
retreive column_name, data_type, is_nullable and
column_default.
"""
result = []
if not hasattr(self, "schema"):
schema = "public"
Expand All @@ -152,42 +181,47 @@ def get_table_fields(self):
result.append(dict(zip(fields, elem)))
return result

def insert(self, values, returning=None):
def make_returning(self, sql_template, returning=None):
if returning:
returning_fields = []
if returning == "self":
returning_fields = [f.get_select() for f in self.fields]
else:
returning_fields = [returning]

if not len(returning_fields) == 0:
sql_template += "returning {}".format(
",".join(returning_fields))
return sql_template

def insert(self, values, returning=None, **kwargs):
"""
Insert a new row into the table checking for constraints. If
returning is set, return the corresponding column(s). If the
special "self" is given to returning, return the
DatabaseObject used by this Projection
"""
if not values:
raise ValueError("values parameter cannot be an empty dict")
db_fields = self.get_table_fields
errors = []
for fields in db_fields:
if not values.get(fields["column_name"]) and \
fields["is_nullable"] == "NO" and \
fields["column_default"] is None:
errors.append("{} of type {} is required".format(
fields["column_name"],
fields["data_type"]
))
if len(errors) != 0:
raise ValueError("\n".join(errors))

self.check_constrains(values)
keys = []
vals = []

for k, v in values.items():
keys.append(k)
vals.append(v)

sql_template = "insert into {} ({}) VALUES ({})"
if returning:
if returning == "self":
sql_template += "returning {}".format(
", ".join([f.db_name for f in self.fields]))
else:
sql_template += "returning {}".format(returning)
sql_template = self.make_returning(sql_template, returning)

sql = sql_template.format(
self.Meta.table,
", ".join(tuple(keys)),
", ".join(["%s" for k in keys])
)
params = vals
res = None

with self.db.cursor() as cur:
try:
cur.execute(sql, params)
Expand All @@ -198,5 +232,52 @@ def insert(self, values, returning=None):
res = cur.fetchone()
self.db.conn.commit()
if returning == "self":
return self.hydrate(res)[0]
results = [self.hydrate(res)[0]]
for field in self.fields:
if hasattr(field, "extra"):
results = field.extra(self, results)
return results[0]
return res

def update(self, where, values, where_params, returning=None, **kwargs):
args = []
params = []
for k, v in values.items():
args.append(k)
params.append(v)
[params.append(p) for p in where_params]
if isinstance(where, Where):
where = where()
joins = self.get_joins(self.Meta.table)
sql_template = """
UPDATE {table} SET {args} {joins} WHERE {where}""".format(
table=self.Meta.table,
joins=joins if joins else '',
args=", ".join(["{}=%s".format(arg) for arg in args]),
where=where
)
sql_template = self.make_returning(sql_template, returning)
res = None
with self.db.cursor() as cur:
try:
cur.execute(sql_template, params)
except Exception as e:
self.db.conn.rollback()
raise e
if returning:
res = cur.fetchall()
self.db.conn.commit()
results = []
if res:
if returning == "self":
for result in res:
results.append(
self.hydrate(result)[0]
)
for field in self.fields:
if hasattr(field, "extra"):
results = field.extra(self, results)
else:
results = res

return results
12 changes: 12 additions & 0 deletions drunken_boat/db/postgresql/tests/projections_fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,22 @@ class Meta:


class AuthorProjectionReverse(Projection):
name = CharField()
books = ReverseForeign(
join=["id", "author_id"],
projection=BookProjectionReverse
)

class Meta:
table = "author"


class AuthorProjectionReverseEm(Projection):

books = ReverseForeign(
join=["id", "author_id"],
projection=BookProjection
)

class Meta:
table = "author"
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from psycopg2 import DataError
from psycopg2 import DataError, ProgrammingError
import pytest
import datetime
from drunken_boat.db.exceptions import NotFoundError
Expand Down Expand Up @@ -200,6 +200,19 @@ def test_projection_inserting(prepare_test):
assert isinstance(doc, projections_fixtures.DataBaseObjectWithMeth)


def test_projection_update(prepare_test):
projection = projections_fixtures.TestProjectionWithVirtual(get_test_db())
pytest.raises(ProgrammingError, projection.update,
Where("id", "=", "%s"),
{"godawa": "Kaboom"},
(4,))
p = projection.update(Where("id", "=", "%s"), {"title": "Kaboom"}, (4,))
assert p == []
p = projection.update(Where("id", "=", "%s"),
{"title": "hello"}, (4,), returning="id")
assert p == [(4,)]


def test_projection_foreign(prepare_test):
projection_author = projections_fixtures.AuthorProjection(get_test_db())
projection_author.insert(
Expand All @@ -220,6 +233,35 @@ def test_projection_reverse(prepare_test):
assert author.id == book.author_id


def test_projection_reverse_insert(prepare_test):
projection_author = projections_fixtures.AuthorProjectionReverse(
get_test_db())
p = projection_author.insert({"name": "hello"}, returning="self")
assert p.name == "hello"
assert hasattr(p, "books")

projection_author_empty = projections_fixtures.AuthorProjectionReverseEm(
get_test_db())
p = projection_author_empty.insert({"name": "hello"}, returning="self")
assert isinstance(p, DataBaseObject)


def test_projection_reverse_update(prepare_test):
projection_author = projections_fixtures.AuthorProjectionReverse(
get_test_db())
p = projection_author.update("name=%s", {"name": "champomy"}, ("hello",),
returning="self")[0]
assert p.name == "champomy"

projection_author_empty = projections_fixtures.AuthorProjectionReverseEm(
get_test_db())

p = projection_author_empty.update(
"name=%s", {"name": "hello"}, ("champomy",),
returning="self")
assert p[0].id


def test_projection_reverse_with_params(prepare_test):
projection_author = projections_fixtures.AuthorProjectionReverse(
get_test_db())
Expand Down

0 comments on commit ee456d2

Please sign in to comment.