Skip to content

Commit

Permalink
Account for values that are arrays/lists (#2607)
Browse files Browse the repository at this point in the history
when checking if keys are valid.
  • Loading branch information
ttmc committed Nov 25, 2018
1 parent 5de2fef commit dcfe23f
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 7 deletions.
6 changes: 4 additions & 2 deletions bigchaindb/backend/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import bigchaindb
from bigchaindb.backend.connection import connect
from bigchaindb.common.exceptions import ValidationError
from bigchaindb.common.utils import validate_all_values_for_key
from bigchaindb.common.utils import validate_all_values_for_key_in_obj, validate_all_values_for_key_in_list

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -101,7 +101,9 @@ def validate_language_key(obj, key):
if backend == 'localmongodb':
data = obj.get(key, {})
if isinstance(data, dict):
validate_all_values_for_key(data, 'language', validate_language)
validate_all_values_for_key_in_obj(data, 'language', validate_language)
elif isinstance(data, list):
validate_all_values_for_key_in_list(data, 'language', validate_language)


def validate_language(value):
Expand Down
32 changes: 27 additions & 5 deletions bigchaindb/common/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,20 @@ def validate_txn_obj(obj_name, obj, key, validation_fun):
if backend == 'localmongodb':
data = obj.get(key, {})
if isinstance(data, dict):
validate_all_keys(obj_name, data, validation_fun)
validate_all_keys_in_obj(obj_name, data, validation_fun)
elif isinstance(data, list):
validate_all_items_in_list(obj_name, data, validation_fun)


def validate_all_keys(obj_name, obj, validation_fun):
def validate_all_items_in_list(obj_name, data, validation_fun):
for item in data:
if isinstance(item, dict):
validate_all_keys_in_obj(obj_name, item, validation_fun)
elif isinstance(item, list):
validate_all_items_in_list(obj_name, item, validation_fun)


def validate_all_keys_in_obj(obj_name, obj, validation_fun):
"""Validate all (nested) keys in `obj` by using `validation_fun`.
Args:
Expand All @@ -97,10 +107,12 @@ def validate_all_keys(obj_name, obj, validation_fun):
for key, value in obj.items():
validation_fun(obj_name, key)
if isinstance(value, dict):
validate_all_keys(obj_name, value, validation_fun)
validate_all_keys_in_obj(obj_name, value, validation_fun)
elif isinstance(value, list):
validate_all_items_in_list(obj_name, value, validation_fun)


def validate_all_values_for_key(obj, key, validation_fun):
def validate_all_values_for_key_in_obj(obj, key, validation_fun):
"""Validate value for all (nested) occurrence of `key` in `obj`
using `validation_fun`.
Expand All @@ -117,7 +129,17 @@ def validate_all_values_for_key(obj, key, validation_fun):
if vkey == key:
validation_fun(value)
elif isinstance(value, dict):
validate_all_values_for_key(value, key, validation_fun)
validate_all_values_for_key_in_obj(value, key, validation_fun)
elif isinstance(value, list):
validate_all_values_for_key_in_list(value, key, validation_fun)


def validate_all_values_for_key_in_list(input_list, key, validation_fun):
for item in input_list:
if isinstance(item, dict):
validate_all_values_for_key_in_obj(item, key, validation_fun)
elif isinstance(item, list):
validate_all_values_for_key_in_list(item, key, validation_fun)


def validate_key(obj_name, key):
Expand Down

0 comments on commit dcfe23f

Please sign in to comment.