Skip to content

Commit

Permalink
don't check key names on index creation - allow creating indexes on s…
Browse files Browse the repository at this point in the history
…ub-objects
  • Loading branch information
Mike Dirolf committed Jun 2, 2009
1 parent 61d9597 commit 1e9f815
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 38 deletions.
50 changes: 26 additions & 24 deletions pymongo/_cbsonmodule.c
Expand Up @@ -48,7 +48,7 @@ typedef struct {
int position;
} bson_buffer;

static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_dollar_sign);
static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char check_keys);
static PyObject* elements_to_dict(const char* string, int max);

static bson_buffer* buffer_new(void) {
Expand Down Expand Up @@ -177,7 +177,7 @@ static int write_string(bson_buffer* buffer, PyObject* py_string) {

/* TODO our platform better be little-endian w/ 4-byte ints! */
/* returns 0 on failure */
static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* value, unsigned char no_dollar_sign) {
static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject* value, unsigned char check_keys) {
/* TODO this isn't quite the same as the Python version:
* here we check for type equivalence, not isinstance in some
* places. */
Expand Down Expand Up @@ -206,7 +206,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
return 1;
} else if (PyDict_Check(value)) {
*(buffer->buffer + type_byte) = 0x03;
return write_dict(buffer, value, no_dollar_sign);
return write_dict(buffer, value, check_keys);
} else if (PyList_CheckExact(value)) {
*(buffer->buffer + type_byte) = 0x04;
int start_position = buffer->position;
Expand Down Expand Up @@ -237,7 +237,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
free(name);

PyObject* item_value = PyList_GetItem(value, i);
if (!write_element_to_buffer(buffer, list_type_byte, item_value, no_dollar_sign)) {
if (!write_element_to_buffer(buffer, list_type_byte, item_value, check_keys)) {
return 0;
}
}
Expand Down Expand Up @@ -421,7 +421,7 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
Py_DECREF(id_object);
return 0;
}
if (!write_element_to_buffer(buffer, type_pos, id_object, no_dollar_sign)) {
if (!write_element_to_buffer(buffer, type_pos, id_object, check_keys)) {
Py_DECREF(id_object);
return 0;
}
Expand Down Expand Up @@ -527,25 +527,27 @@ static int write_element_to_buffer(bson_buffer* buffer, int type_byte, PyObject*
return 0;
}

static int check_key_name(const unsigned char no_dollar_sign,
static int check_key_name(const unsigned char check_keys,
const char* name,
const Py_ssize_t name_length) {
if (no_dollar_sign && name_length > 0 && name[0] == '$') {
PyErr_SetString(InvalidName, "key must not start with '$'");
return 0;
}
int i;
for (i = 0; i < name_length; i++) {
if (name[i] == '.') {
PyErr_SetString(InvalidName, "key must not contain '.'");
if (check_keys) {
if (name_length > 0 && name[0] == '$') {
PyErr_SetString(InvalidName, "key must not start with '$'");
return 0;
}
int i;
for (i = 0; i < name_length; i++) {
if (name[i] == '.') {
PyErr_SetString(InvalidName, "key must not contain '.'");
return 0;
}
}
}
return 1;
}

static int write_son(bson_buffer* buffer, PyObject* dict, int start_position,
int length_location, unsigned char no_dollar_sign) {
int length_location, unsigned char check_keys) {
PyObject* keys = PyObject_CallMethod(dict, "keys", NULL);
if (!keys) {
return 0;
Expand Down Expand Up @@ -586,7 +588,7 @@ static int write_son(bson_buffer* buffer, PyObject* dict, int start_position,
Py_DECREF(encoded);
return 0;
}
if (!check_key_name(no_dollar_sign, name, name_length)) {
if (!check_key_name(check_keys, name, name_length)) {
Py_DECREF(keys);
Py_DECREF(encoded);
return 0;
Expand All @@ -597,7 +599,7 @@ static int write_son(bson_buffer* buffer, PyObject* dict, int start_position,
return 0;
}
Py_DECREF(encoded);
if (!write_element_to_buffer(buffer, type_byte, value, no_dollar_sign)) {
if (!write_element_to_buffer(buffer, type_byte, value, check_keys)) {
Py_DECREF(keys);
return 0;
}
Expand All @@ -607,7 +609,7 @@ static int write_son(bson_buffer* buffer, PyObject* dict, int start_position,
}

/* returns 0 on failure */
static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_dollar_sign) {
static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char check_keys) {
int start_position = buffer->position;

// save space for length
Expand All @@ -617,7 +619,7 @@ static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_doll
}

if (PyObject_IsInstance(dict, SON)) {
if (!write_son(buffer, dict, start_position, length_location, no_dollar_sign)) {
if (!write_son(buffer, dict, start_position, length_location, check_keys)) {
return 0;
}
} else if (PyDict_Check(dict)) {
Expand Down Expand Up @@ -645,7 +647,7 @@ static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_doll
Py_DECREF(encoded);
return 0;
}
if (!check_key_name(no_dollar_sign, name, name_length)) {
if (!check_key_name(check_keys, name, name_length)) {
Py_DECREF(encoded);
return 0;
}
Expand All @@ -654,7 +656,7 @@ static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_doll
return 0;
}
Py_DECREF(encoded);
if (!write_element_to_buffer(buffer, type_byte, value, no_dollar_sign)) {
if (!write_element_to_buffer(buffer, type_byte, value, check_keys)) {
return 0;
}
}
Expand Down Expand Up @@ -696,8 +698,8 @@ static int write_dict(bson_buffer* buffer, PyObject* dict, unsigned char no_doll

static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
PyObject* dict;
unsigned char no_dollar_sign;
if (!PyArg_ParseTuple(args, "Ob", &dict, &no_dollar_sign)) {
unsigned char check_keys;
if (!PyArg_ParseTuple(args, "Ob", &dict, &check_keys)) {
return NULL;
}

Expand All @@ -706,7 +708,7 @@ static PyObject* _cbson_dict_to_bson(PyObject* self, PyObject* args) {
return NULL;
}

if (!write_dict(buffer, dict, no_dollar_sign)) {
if (!write_dict(buffer, dict, check_keys)) {
buffer_free(buffer);
return NULL;
}
Expand Down
25 changes: 14 additions & 11 deletions pymongo/bson.py
Expand Up @@ -305,11 +305,12 @@ def _shuffle_oid(data):
return data[7::-1] + data[:7:-1]

_RE_TYPE = type(_valid_array_name)
def _element_to_bson(key, value, check_key_names):
if check_key_names and key.startswith("$"):
raise InvalidName("key %r must not start with '$'" % key)
if "." in key:
raise InvalidName("key %r must not contain '.'" % key)
def _element_to_bson(key, value, check_keys):
if check_keys:
if key.startswith("$"):
raise InvalidName("key %r must not start with '$'" % key)
if "." in key:
raise InvalidName("key %r must not contain '.'" % key)

name = _make_c_string(key)
if isinstance(value, float):
Expand All @@ -334,10 +335,10 @@ def _element_to_bson(key, value, check_key_names):
length = struct.pack("<i", len(cstring))
return "\x02" + name + length + cstring
if isinstance(value, dict):
return "\x03" + name + _dict_to_bson(value, check_key_names)
return "\x03" + name + _dict_to_bson(value, check_keys)
if isinstance(value, (list, tuple)):
as_dict = SON(zip([str(i) for i in range(len(value))], value))
return "\x04" + name + _dict_to_bson(as_dict, check_key_names)
return "\x04" + name + _dict_to_bson(as_dict, check_keys)
if isinstance(value, ObjectId):
return "\x07" + name + _shuffle_oid(str(value))
if value is True:
Expand Down Expand Up @@ -374,9 +375,9 @@ def _element_to_bson(key, value, check_key_names):
return _element_to_bson(key, SON([("$ref", value.collection), ("$id", value.id)]), False)
raise InvalidDocument("cannot convert value of type %s to bson" % type(value))

def _dict_to_bson(dict, check_key_names):
def _dict_to_bson(dict, check_keys):
try:
elements = "".join([_element_to_bson(key, value, check_key_names) for (key, value) in dict.iteritems()])
elements = "".join([_element_to_bson(key, value, check_keys) for (key, value) in dict.iteritems()])
except AttributeError:
raise TypeError("encoder expected a mapping type but got: %r" % dict)

Expand Down Expand Up @@ -436,7 +437,7 @@ def __new__(cls, bson):
"""
return str.__new__(cls, bson)

def from_dict(cls, dict, check_key_names=False):
def from_dict(cls, dict, check_keys=False):
"""Create a new BSON object from a python mapping type (like dict).
Raises TypeError if the argument is not a mapping type, or contains keys
Expand All @@ -445,8 +446,10 @@ def from_dict(cls, dict, check_key_names=False):
:Parameters:
- `dict`: mapping type representing a Mongo document
- `check_keys`: check if keys start with '$' or contain '.',
raising `pymongo.errors.InvalidName` in either case
"""
return cls(_dict_to_bson(dict, check_key_names))
return cls(_dict_to_bson(dict, check_keys))
from_dict = classmethod(from_dict)

def to_dict(self):
Expand Down
9 changes: 6 additions & 3 deletions pymongo/collection.py
Expand Up @@ -148,7 +148,7 @@ def save(self, to_save, manipulate=True, safe=False):
self.update({"_id": to_save["_id"]}, to_save, True, manipulate, safe)
return to_save.get("_id", None)

def insert(self, doc_or_docs, manipulate=True, safe=False):
def insert(self, doc_or_docs, manipulate=True, safe=False, check_keys=True):
"""Insert a document(s) into this collection.
If manipulate is set the document(s) are manipulated using any
Expand All @@ -162,6 +162,8 @@ def insert(self, doc_or_docs, manipulate=True, safe=False):
- `doc_or_docs`: a SON object or list of SON objects to be inserted
- `manipulate` (optional): monipulate the objects before inserting?
- `safe` (optional): check that the insert succeeded?
- `check_keys` (optional): check if keys start with '$' or contain '.',
raising `pymongo.errors.InvalidName` in either case
"""
docs = doc_or_docs
if isinstance(docs, types.DictType):
Expand All @@ -173,7 +175,7 @@ def insert(self, doc_or_docs, manipulate=True, safe=False):
if manipulate:
docs = [self.__database._fix_incoming(doc, self) for doc in docs]

data = [bson.BSON.from_dict(doc, True) for doc in docs]
data = [bson.BSON.from_dict(doc, check_keys) for doc in docs]
self._send_message(2002, "".join(data))

if safe:
Expand Down Expand Up @@ -351,7 +353,8 @@ def create_index(self, key_or_list, direction=None, unique=False, ttl=300):
self.name(),
name, ttl)

self.database().system.indexes.save(to_save, False)
self.database().system.indexes.insert(to_save, manipulate=False,
check_keys=False)
return to_save["name"]

def ensure_index(self, key_or_list, direction=None, unique=False, ttl=300):
Expand Down

0 comments on commit 1e9f815

Please sign in to comment.