Skip to content

Commit

Permalink
disallow insertion of invalid key names
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike Dirolf committed May 28, 2009
1 parent 92b4339 commit c59a09d
Show file tree
Hide file tree
Showing 5 changed files with 88 additions and 27 deletions.
64 changes: 49 additions & 15 deletions pymongo/_cbsonmodule.c
Expand Up @@ -25,7 +25,7 @@
#include <time.h>

static PyObject* CBSONError;
static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* dict);
static PyObject* InvalidName;
static PyObject* SON;
static PyObject* Binary;
static PyObject* Code;
Expand All @@ -48,7 +48,7 @@ typedef struct {
int position;
} bson_buffer;

static int write_dict(bson_buffer* buffer, PyObject* dict);
static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_dollar_sign);
static PyObject* elements_to_dict(const char* string, int max);

static bson_buffer* buffer_new(void) {
Expand Down Expand Up @@ -177,7 +177,7 @@ static int write_string(bson_buffer* buffer, PyObject* py_string) {

/* TODO our platform better be little-endian w/ 4-byte ints! */
/* returns 0 on failure */
static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* value) {
static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* value, unsigned char no_dollar_sign) {
/* TODO this isn't quite the same as the Python version:
* here we check for type equivalence, not isinstance in some
* places. */
Expand Down Expand Up @@ -206,7 +206,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
return 1;
} else if (PyDict_Check(value)) {
*(buffer->buffer + type_byte) = 0x03;
return write_dict(buffer, value);
return write_dict(buffer, value, no_dollar_sign);
} else if (PyList_CheckExact(value)) {
*(buffer->buffer + type_byte) = 0x04;
int start_position = buffer->position;
Expand Down Expand Up @@ -237,7 +237,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
free(name);

PyObject* item_value = PyList_GetItem(value, i);
if (!write_element_to_buffer(buffer, list_type_byte, item_value)) {
if (!write_element_to_buffer(buffer, list_type_byte, item_value, no_dollar_sign)) {
return 0;
}
}
Expand Down Expand Up @@ -303,7 +303,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
if (!scope) {
return 0;
}
if (!write_dict(buffer, scope)) {
if (!write_dict(buffer, scope, 0)) {
Py_DECREF(scope);
return 0;
}
Expand Down Expand Up @@ -421,7 +421,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
Py_DECREF(id_object);
return 0;
}
if (!write_element_to_buffer(buffer, type_pos, id_object)) {
if (!write_element_to_buffer(buffer, type_pos, id_object, no_dollar_sign)) {
Py_DECREF(id_object);
return 0;
}
Expand Down Expand Up @@ -527,7 +527,25 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
return 0;
}

static int write_son(bson_buffer* buffer, PyObject* dict, int start_position, int length_location) {
/* Validate a document key before it is written to the buffer.
 *
 * A key is rejected when it contains a '.' anywhere, or -- only if
 * no_dollar_sign is non-zero -- when its first byte is '$'. A '$' at any
 * later position is accepted.
 *
 * Only the first name_length bytes of name are examined.
 *
 * Returns 1 if the key is acceptable. Returns 0 with an InvalidName
 * exception set otherwise. */
static int check_key_name(const unsigned char no_dollar_sign,
                          const char* name,
                          const Py_ssize_t name_length) {
    /* Same type as name_length so the comparison below is never narrowed. */
    Py_ssize_t i;

    if (no_dollar_sign && name_length > 0 && name[0] == '$') {
        /* Only a leading '$' is rejected, so the message says "start with". */
        PyErr_SetString(InvalidName, "key must not start with '$'");
        return 0;
    }
    for (i = 0; i < name_length; i++) {
        if (name[i] == '.') {
            PyErr_SetString(InvalidName, "key must not contain '.'");
            return 0;
        }
    }
    return 1;
}

static int write_son(bson_buffer* buffer, PyObject* dict, int start_position,
int length_location, unsigned char no_dollar_sign) {
PyObject* keys = PyObject_CallMethod(dict, "keys", NULL);
if (!keys) {
return 0;
Expand Down Expand Up @@ -568,13 +586,18 @@ static int write_son(bson_buffer* buffer, PyObject* dict, int start_position, in
Py_DECREF(encoded);
return 0;
}
if (!check_key_name(no_dollar_sign, name, name_length)) {
Py_DECREF(keys);
Py_DECREF(encoded);
return 0;
}
if (!buffer_write_bytes(buffer, name, name_length + 1)) {
Py_DECREF(keys);
Py_DECREF(encoded);
return 0;
}
Py_DECREF(encoded);
if (!write_element_to_buffer(buffer, type_byte, value)) {
if (!write_element_to_buffer(buffer, type_byte, value, no_dollar_sign)) {
Py_DECREF(keys);
return 0;
}
Expand All @@ -584,7 +607,7 @@ static int write_son(bson_buffer* buffer, PyObject* dict, int start_position, in
}

/* returns 0 on failure */
static int write_dict(bson_buffer* buffer, PyObject* dict) {
static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_dollar_sign) {
int start_position = buffer->position;

// save space for length
Expand All @@ -594,7 +617,7 @@ static int write_dict(bson_buffer* buffer, PyObject* dict) {
}

if (PyObject_IsInstance(dict, SON)) {
if (!write_son(buffer, dict, start_position, length_location)) {
if (!write_son(buffer, dict, start_position, length_location, no_dollar_sign)) {
return 0;
}
} else if (PyDict_Check(dict)) {
Expand Down Expand Up @@ -622,12 +645,16 @@ static int write_dict(bson_buffer* buffer, PyObject* dict) {
Py_DECREF(encoded);
return 0;
}
if (!check_key_name(no_dollar_sign, name, name_length)) {
Py_DECREF(encoded);
return 0;
}
if (!buffer_write_bytes(buffer, name, name_length + 1)) {
Py_DECREF(encoded);
return 0;
}
Py_DECREF(encoded);
if (!write_element_to_buffer(buffer, type_byte, value)) {
if (!write_element_to_buffer(buffer, type_byte, value, no_dollar_sign)) {
return 0;
}
}
Expand Down Expand Up @@ -667,13 +694,19 @@ static int write_dict(bson_buffer* buffer, PyObject* dict) {
return 1;
}

static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* dict) {
static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
PyObject* dict;
unsigned char no_dollar_sign;
if (!PyArg_ParseTuple(args, "Ob", &dict, &no_dollar_sign)) {
return NULL;
}

bson_buffer* buffer = buffer_new();
if (!buffer) {
return NULL;
}

if (!write_dict(buffer, dict)) {
if (!write_dict(buffer, dict, no_dollar_sign)) {
buffer_free(buffer);
return NULL;
}
Expand Down Expand Up @@ -1000,7 +1033,7 @@ static PyObject* _cbson_bson_to_dict(PyObject* self, PyObject* bson) {
}

static PyMethodDef _CBSONMethods[] = {
{"_dict_to_bson", _cbson_dict_to_bson, METH_O,
{"_dict_to_bson", _cbson_dict_to_bson, METH_VARARGS,
"convert a dictionary to a string containing it's BSON representation."},
{"_bson_to_dict", _cbson_bson_to_dict, METH_O,
"convert a BSON string to a SON object."},
Expand All @@ -1017,6 +1050,7 @@ PyMODINIT_FUNC init_cbson(void) {

PyObject* errors_module = PyImport_ImportModule("pymongo.errors");
CBSONError = PyObject_GetAttrString(errors_module, "InvalidDocument");
InvalidName = PyObject_GetAttrString(errors_module, "InvalidName");
Py_DECREF(errors_module);

PyObject* son_module = PyImport_ImportModule("pymongo.son");
Expand Down
25 changes: 15 additions & 10 deletions pymongo/bson.py
Expand Up @@ -28,7 +28,7 @@
from objectid import ObjectId
from dbref import DBRef
from son import SON
from errors import InvalidBSON, InvalidDocument, UnsupportedTag
from errors import InvalidBSON, InvalidDocument, UnsupportedTag, InvalidName

try:
import _cbson
Expand Down Expand Up @@ -305,7 +305,12 @@ def _shuffle_oid(data):
return data[7::-1] + data[:7:-1]

_RE_TYPE = type(_valid_array_name)
def _element_to_bson(key, value):
def _element_to_bson(key, value, check_key_names):
if check_key_names and key.startswith("$"):
raise InvalidName("key %r must not start with '$'" % key)
if "." in key:
raise InvalidName("key %r must not contain '.'" % key)

name = _make_c_string(key)
if isinstance(value, float):
return "\x01" + name + struct.pack("<d", value)
Expand All @@ -316,7 +321,7 @@ def _element_to_bson(key, value):
return "\x05" + name + struct.pack("<i", len(value)) + chr(subtype) + value
if isinstance(value, Code):
cstring = _make_c_string(value)
scope = _dict_to_bson(value.scope)
scope = _dict_to_bson(value.scope, False)
full_length = struct.pack("<i", 8 + len(cstring) + len(scope))
length = struct.pack("<i", len(cstring))
return "\x0F" + name + full_length + length + cstring + scope
Expand All @@ -329,10 +334,10 @@ def _element_to_bson(key, value):
length = struct.pack("<i", len(cstring))
return "\x02" + name + length + cstring
if isinstance(value, dict):
return "\x03" + name + _dict_to_bson(value)
return "\x03" + name + _dict_to_bson(value, check_key_names)
if isinstance(value, (list, tuple)):
as_dict = SON(zip([str(i) for i in range(len(value))], value))
return "\x04" + name + _dict_to_bson(as_dict)
return "\x04" + name + _dict_to_bson(as_dict, check_key_names)
if isinstance(value, ObjectId):
return "\x07" + name + _shuffle_oid(str(value))
if value is True:
Expand Down Expand Up @@ -366,12 +371,12 @@ def _element_to_bson(key, value):
flags += "x"
return "\x0B" + name + _make_c_string(pattern) + _make_c_string(flags)
if isinstance(value, DBRef):
return _element_to_bson(key, SON([("$ref", value.collection), ("$id", value.id)]))
return _element_to_bson(key, SON([("$ref", value.collection), ("$id", value.id)]), False)
raise InvalidDocument("cannot convert value of type %s to bson" % type(value))

def _dict_to_bson(dict):
def _dict_to_bson(dict, check_key_names):
try:
elements = "".join([_element_to_bson(key, value) for (key, value) in dict.iteritems()])
elements = "".join([_element_to_bson(key, value, check_key_names) for (key, value) in dict.iteritems()])
except AttributeError:
raise TypeError("encoder expected a mapping type but got: %r" % dict)

Expand Down Expand Up @@ -431,7 +436,7 @@ def __new__(cls, bson):
"""
return str.__new__(cls, bson)

def from_dict(cls, dict):
def from_dict(cls, dict, check_key_names=False):
"""Create a new BSON object from a python mapping type (like dict).
Raises TypeError if the argument is not a mapping type, or contains keys
Expand All @@ -441,7 +446,7 @@ def from_dict(cls, dict):
:Parameters:
- `dict`: mapping type representing a Mongo document
"""
return cls(_dict_to_bson(dict))
return cls(_dict_to_bson(dict, check_key_names))
from_dict = classmethod(from_dict)

def to_dict(self):
Expand Down
2 changes: 1 addition & 1 deletion pymongo/collection.py
Expand Up @@ -173,7 +173,7 @@ def insert(self, doc_or_docs, manipulate=True, safe=False):
if manipulate:
docs = [self.__database._fix_incoming(doc, self) for doc in docs]

data = [bson.BSON.from_dict(doc) for doc in docs]
data = [bson.BSON.from_dict(doc, True) for doc in docs]
self._send_message(2002, "".join(data))

if safe:
Expand Down
2 changes: 1 addition & 1 deletion test/qcheck.py
Expand Up @@ -70,7 +70,7 @@ def gen_unichar():
return lambda: unichr(random.randint(1, 0xFFF))

def gen_unicode(gen_length):
return lambda: u"".join(gen_list(gen_unichar(), gen_length)())
return lambda: u"".join([x for x in gen_list(gen_unichar(), gen_length)() if x not in ".$"])

# Build a generator of lists: each call to the returned callable asks
# gen_length() for a length and then calls generator() that many times,
# collecting the results in a new list.
def gen_list(generator, gen_length):
return lambda: [generator() for _ in range(gen_length())]
Expand Down
22 changes: 22 additions & 0 deletions test/test_collection.py
Expand Up @@ -285,6 +285,28 @@ def iterate():

self.assertRaises(TypeError, iterate)

# Checks which key names insert() accepts and which it rejects with
# InvalidName, at both the top level and one level of nesting.
def test_invalid_key_names(self):
db = self.db
db.test.remove({})

# Plain keys are accepted, nested or not.
db.test.insert({"hello": "world"})
db.test.insert({"hello": {"hello": "world"}})

# A leading '$' is rejected at either nesting level.
self.assertRaises(InvalidName, db.test.insert, {"$hello": "world"})
self.assertRaises(InvalidName, db.test.insert, {"hello": {"$hello": "world"}})

# A '$' that is not the first character is accepted.
db.test.insert({"he$llo": "world"})
db.test.insert({"hello": {"hello$": "world"}})

# A '.' is rejected wherever it appears: start, end, or middle of the key.
self.assertRaises(InvalidName, db.test.insert, {".hello": "world"})
self.assertRaises(InvalidName, db.test.insert, {"hello": {".hello": "world"}})
self.assertRaises(InvalidName, db.test.insert, {"hello.": "world"})
self.assertRaises(InvalidName, db.test.insert, {"hello": {"hello.": "world"}})
self.assertRaises(InvalidName, db.test.insert, {"hel.lo": "world"})
self.assertRaises(InvalidName, db.test.insert, {"hello": {"hel.lo": "world"}})

# update() is called outside assertRaises, so a '$'-prefixed key in the
# update document must not raise here.
# NOTE(review): "$inc" is given a string rather than a {field: amount}
# document -- confirm whether the operand shape matters to this test.
db.test.update({"hello": "world"}, {"$inc": "hello"})

def test_insert_multiple(self):
db = self.db
db.drop_collection("test")
Expand Down

0 comments on commit c59a09d

Please sign in to comment.