Skip to content

HTTPS clone URL

Subversion checkout URL

You can clone with
or
.
Download ZIP
Browse files

Fixes issue #67; uses primary key instead of hardcoded 'id' column to…

… access instances of models.
  • Loading branch information...
commit 0ded7acd577e62ef78b62007b509a53769ca2790 1 parent e690857
@jfinkels authored
View
2  CHANGES
@@ -13,6 +13,8 @@ Version 0.6
Not yet released.
+- Added support for accessing model instances via arbitrary primary keys,
+ instead of requiring an integer column named ``id``.
- Added support for pagination of responses.
- Fixed issue due to symbolic link from :file:`README` to :file:`README.md`
when running ``pip bundle foobar Flask-Restless``.
View
8 docs/basicusage.rst
@@ -13,7 +13,8 @@ The basic setup for Flask-SQLAlchemy is the same. First, create your
and model classes as usual but with the following two (reasonable) restrictions
on models:
-1. They must have an ``id`` column of type :class:`sqlalchemy.Integer`.
+1. They must have a primary key column of type :class:`sqlalchemy.Integer` or
+ type :class:`sqlalchemy.Unicode`.
2. They must have an ``__init__`` method which accepts keyword arguments for
all columns (the constructor in
:class:`flask.ext.sqlalchemy.SQLAlchemy.Model` supplies such a method, so
@@ -169,3 +170,8 @@ the URL is the value of ``Person.__tablename__``::
"computers": [],
"id": 1
}
+
+If the primary key is a :class:`~sqlalchemy.Unicode` instead of an
+:class:`~sqlalchemy.Integer`, the instances will be accesible at URL endpoints
+like ``http://<host>:<port>/api/person/foo`` instead of
+``http://<host>:<port>/api/person/1``.
View
10 flask_restless/manager.py
@@ -322,7 +322,6 @@ def create_api_blueprint(self, model, methods=READONLY_METHODS,
methods & frozenset(('GET', 'PATCH', 'DELETE', 'PUT'))
# the base URL of the endpoints on which requests will be made
collection_endpoint = '/%s' % collection_name
- instance_endpoint = collection_endpoint + '/<int:instid>'
# the name of the API, for use in creating the view and the blueprint
apiname = APIManager.APINAME_FORMAT % collection_name
# the view function for the API for this model
@@ -344,8 +343,13 @@ def create_api_blueprint(self, model, methods=READONLY_METHODS,
blueprint.add_url_rule(collection_endpoint, defaults={'instid': None},
methods=possibly_empty_instance_methods,
view_func=api_view)
- blueprint.add_url_rule(instance_endpoint, methods=instance_methods,
- view_func=api_view)
+ # the per-instance endpoints will allow both integer and string primary
+ # key accesses
+ for converter in ('int', 'string'):
+ instance_endpoint = '%s/<%s:instid>' % (collection_endpoint,
+ converter)
+ blueprint.add_url_rule(instance_endpoint, methods=instance_methods,
+ view_func=api_view)
# if function evaluation is allowed, add an endpoint at /api/eval/...
# which responds only to GET requests and responds with the result of
# evaluating functions on all instances of the specified model
View
55 flask_restless/views.py
@@ -32,6 +32,7 @@
from sqlalchemy import Date
from sqlalchemy import DateTime
from sqlalchemy.exc import OperationalError
+from sqlalchemy.orm import class_mapper
from sqlalchemy.orm import ColumnProperty
from sqlalchemy.orm import object_mapper
from sqlalchemy.orm import RelationshipProperty
@@ -124,6 +125,23 @@ def _get_relations(model):
return [k for k in cols if isinstance(cols[k].property, RelProperty)]
+def _primary_key_name(model_or_instance):
+ """Returns the name of the primary key of the specified model or instance
+ of a model, as a string.
+
+ If `model_or_instance` specifies multiple primary keys and ``'id'`` is one
+ of them, ``'id'`` is returned. If `model_or_instance` specifies multiple
+ primary keys and ``'id'`` is not one of them, only the name of the first
+ one in the list of primary keys is returned.
+
+ """
+ its_a_model = isinstance(model_or_instance, type)
+ mapper = class_mapper if its_a_model else object_mapper
+ mapped = mapper(model_or_instance)
+ primary_key_names = [key.name for key in mapped.primary_key]
+ return 'id' if 'id' in primary_key_names else primary_key_names[0]
+
+
# This code was adapted from :meth:`elixir.entity.Entity.to_dict` and
# http://stackoverflow.com/q/1958219/108197.
#
@@ -450,8 +468,7 @@ def _add_to_relation(self, query, relationname, toadd=None):
submodel = _get_related_model(self.model, relationname)
for dictionary in toadd or []:
if 'id' in dictionary:
- filtered = self.query(submodel).filter_by(id=dictionary['id'])
- subinst = filtered.first()
+ subinst = self._get_by(dictionary['id'], submodel)
else:
kw = unicode_keys_to_strings(dictionary)
subinst = _get_or_create(self.session, submodel, **kw)[0]
@@ -486,8 +503,7 @@ def _remove_from_relation(self, query, relationname, toremove=None):
for dictionary in toremove or []:
remove = dictionary.pop('__delete__', False)
if 'id' in dictionary:
- filtered = self.query(submodel).filter_by(id=dictionary['id'])
- subinst = filtered.first()
+ subinst = self._get_by(dictionary['id'], submodel)
else:
kw = unicode_keys_to_strings(dictionary)
# TODO document that we use .first() here
@@ -750,6 +766,27 @@ def _check_authentication(self):
and not self.authentication_function()):
abort(401)
+ def _query_by_primary_key(self, primary_key_value, model=None):
+ """Returns a SQLAlchemy query object containing the result of querying
+ `model` (or ``self.model`` if not specified) for instances whose
+ primary key has the value `primary_key_value`.
+
+ Presumably, the returned query should have at most one element.
+
+ """
+ the_model = model or self.model
+ # force unicode primary key name to string; see unicode_keys_to_strings
+ pk_name = str(_primary_key_name(the_model))
+ return self.query(the_model).filter_by(**{pk_name: primary_key_value})
+
+ def _get_by(self, primary_key_value, model=None):
+ """Returns the single instance of `model` (or ``self.model`` if not
+ specified) whose primary key has the value `primary_key_value`, or
+ ``None`` if no such instance exists.
+
+ """
+ return self._query_by_primary_key(primary_key_value, model).first()
+
def get(self, instid):
"""Returns a JSON representation of an instance of model with the
specified name.
@@ -767,7 +804,7 @@ def get(self, instid):
self._check_authentication()
if instid is None:
return self._search()
- inst = self.query().filter_by(id=instid).first()
+ inst = self._get_by(instid)
if inst is None:
abort(404)
relations = _get_relations(self.model)
@@ -785,7 +822,7 @@ def delete(self, instid):
"""
self._check_authentication()
- inst = self.query().filter_by(id=instid).first()
+ inst = self._get_by(instid)
if inst is not None:
self.session.delete(inst)
self.session.commit()
@@ -849,7 +886,9 @@ def post(self):
self.session.add(instance)
self.session.commit()
- return jsonify_status_code(201, id=instance.id)
+ pk_name = str(_primary_key_name(instance))
+ pk_value = getattr(instance, pk_name)
+ return jsonify_status_code(201, **{pk_name: pk_value})
except self.validation_exceptions, exception:
return self._handle_validation_exception(exception)
@@ -888,7 +927,7 @@ def patch(self, instid):
message='Unable to construct query')
else:
# create a SQLAlchemy Query which has exactly the specified row
- query = self.query().filter_by(id=instid)
+ query = self._query_by_primary_key(instid)
assert query.count() == 1, 'Multiple rows with same ID'
relations = self._update_relations(query, data)
View
6 tests/helpers.py
@@ -127,8 +127,14 @@ class Person(self.Base):
birth_date = Column(Date)
computers = relationship('Computer',
backref=backref('owner', lazy='dynamic'))
+
+ class Planet(self.Base):
+ __tablename__ = 'planet'
+ name = Column(Unicode, primary_key=True)
+
self.Person = Person
self.Computer = Computer
+ self.Planet = Planet
# create all the tables required for the models
self.Base.metadata.create_all()
View
17 tests/test_views.py
@@ -814,6 +814,23 @@ def test_pagination(self):
self.assertEqual(loads(response.data)['page'], 1)
self.assertEqual(len(loads(response.data)['objects']), 25)
+ def test_alternate_primary_key(self):
+ """Tests that models with primary keys which are not ``id`` columns are
+ accessible via their primary keys.
+
+ """
+ self.manager.create_api(self.Planet, methods=['GET', 'POST'])
+ response = self.app.post('/api/planet', data=dumps(dict(name='Earth')))
+ self.assertEqual(response.status_code, 201)
+ response = self.app.get('/api/planet/1')
+ self.assertEqual(response.status_code, 404)
+ response = self.app.get('/api/planet')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(len(loads(response.data)['objects']), 1)
+ response = self.app.get('/api/planet/Earth')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(loads(response.data), dict(name='Earth'))
+
def load_tests(loader, standard_tests, pattern):
"""Returns the test suite for this module."""
Please sign in to comment.
Something went wrong with that request. Please try again.