Permalink
Browse files

Merge "add database string field length check"

  • Loading branch information...
2 parents 57f1e30 + 9c2c4ec commit e1abe0fca3ccff8fee425796de23899658403b0b Jenkins committed with openstack-gerrit Jan 15, 2013
Showing with 94 additions and 2 deletions.
  1. +40 −0 keystone/common/sql/core.py
  2. +5 −0 keystone/exception.py
  3. +19 −1 tests/test_backend.py
  4. +20 −0 tests/test_backend_sql.py
  5. +1 −1 tests/test_v3.py
  6. +9 −0 tests/test_v3_catalog.py
@@ -23,10 +23,12 @@
import sqlalchemy.orm
import sqlalchemy.pool
from sqlalchemy import types as sql_types
+from sqlalchemy.orm.attributes import InstrumentedAttribute
from keystone.common import logging
from keystone import config
from keystone.openstack.common import jsonutils
+from keystone import exception
CONF = config.CONF
@@ -49,6 +51,44 @@
Text = sql.Text
+def initialize_decorator(init):
+ """Ensure that the length of string field do not exceed the limit.
+
+ This decorator check the initialize arguments, to make sure the
+ length of string field do not exceed the length limit, or raise a
+ 'StringLengthExceeded' exception.
+
+ Use decorator instead of inheritance, because the metaclass will
+ check the __tablename__, primary key columns, etc. at the class
+ definition.
+
+ """
+ def initialize(self, *args, **kwargs):
+ cls = type(self)
+ for k, v in kwargs.items():
+ if hasattr(cls, k):
+ attr = getattr(cls, k)
+ if isinstance(attr, InstrumentedAttribute):
+ column = attr.property.columns[0]
+ if isinstance(column.type, String):
+ if column.type.length and \
+ column.type.length < len(str(v)):
+ #if signing.token_format == 'PKI', the id will
+ #store it's public key which is very long.
+ if config.CONF.signing.token_format == 'PKI' and \
+ self.__tablename__ == 'token' and \
+ k == 'id':
+ continue
+
+ raise exception.StringLengthExceeded(
+ string=v, type=k, length=column.type.length)
+
+ init(self, *args, **kwargs)
+ return initialize
+
+ModelBase.__init__ = initialize_decorator(ModelBase.__init__)
+
+
def set_global_engine(engine):
global GLOBAL_ENGINE
GLOBAL_ENGINE = engine
View
@@ -79,6 +79,11 @@ class ValidationError(Error):
title = 'Bad Request'
+class StringLengthExceeded(ValidationError):
+ """The length of string "%(string)s" exceeded the limit of column
+ %(type)s(CHAR(%(length)d))."""
+
+
class SecurityError(Error):
"""Avoids exposing details of security failures, unless in debug mode."""
View
@@ -1192,7 +1192,7 @@ def test_delete_service_with_endpoint(self):
endpoint = {
'id': uuid.uuid4().hex,
'region': uuid.uuid4().hex,
- 'interface': uuid.uuid4().hex,
+ 'interface': uuid.uuid4().hex[:8],
'url': uuid.uuid4().hex,
'service_id': service['id'],
}
@@ -1240,6 +1240,24 @@ def test_delete_endpoint_404(self):
{},
uuid.uuid4().hex)
+ def test_create_endpoint(self):
+ service = {
+ 'id': uuid.uuid4().hex,
+ 'type': uuid.uuid4().hex,
+ 'name': uuid.uuid4().hex,
+ 'description': uuid.uuid4().hex,
+ }
+ self.catalog_api.create_service(service['id'], service.copy())
+
+ endpoint = {
+ 'id': uuid.uuid4().hex,
+ 'region': "0" * 255,
+ 'service_id': service['id'],
+ 'interface': 'public',
+ 'url': uuid.uuid4().hex,
+ }
+ self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
+
class PolicyTests(object):
def _new_policy_ref(self):
View
@@ -264,6 +264,26 @@ def test_get_catalog_with_empty_public_url(self):
self.assertIsNone(catalog_endpoint.get('adminURL'))
self.assertIsNone(catalog_endpoint.get('internalURL'))
+ def test_create_endpoint_400(self):
+ service = {
+ 'id': uuid.uuid4().hex,
+ 'type': uuid.uuid4().hex,
+ 'name': uuid.uuid4().hex,
+ 'description': uuid.uuid4().hex,
+ }
+ self.catalog_api.create_service(service['id'], service.copy())
+
+ endpoint = {
+ 'id': uuid.uuid4().hex,
+ 'region': "0" * 256,
+ 'service_id': service['id'],
+ 'interface': 'public',
+ 'url': uuid.uuid4().hex,
+ }
+
+ with self.assertRaises(exception.StringLengthExceeded):
+ self.catalog_api.create_endpoint(endpoint['id'], endpoint.copy())
+
class SqlPolicy(SqlTests, test_backend.PolicyTests):
pass
View
@@ -42,7 +42,7 @@ def new_service_ref(self):
def new_endpoint_ref(self, service_id):
ref = self.new_ref()
- ref['interface'] = uuid.uuid4().hex
+ ref['interface'] = uuid.uuid4().hex[:8]
ref['service_id'] = service_id
ref['url'] = uuid.uuid4().hex
return ref
View
@@ -119,6 +119,15 @@ def test_create_endpoint(self):
body={'endpoint': ref})
self.assertValidEndpointResponse(r, ref)
+ def assertValidErrorResponse(self, response):
+ self.assertTrue(response.status in [400])
+
+ def test_create_endpoint_400(self):
+ """POST /endpoints"""
+ ref = self.new_endpoint_ref(service_id=self.service_id)
+ ref["region"] = "0" * 256
+ self.post('/endpoints', body={'endpoint': ref}, expected_status=400)
+
def test_get_endpoint(self):
"""GET /endpoints/{endpoint_id}"""
r = self.get(

0 comments on commit e1abe0f

Please sign in to comment.