diff --git a/pygit2/index.py b/pygit2/index.py index 45b37a61b..56ed31917 100644 --- a/pygit2/index.py +++ b/pygit2/index.py @@ -94,6 +94,11 @@ def __getitem__(self, key): def __iter__(self): return IndexIterator(self) + def __bool__(self): + return True + + __nonzero__ = __bool__ + def read(self, force=True): """Update the contents the Index diff --git a/src/object.c b/src/object.c index e169a3367..507b81237 100644 --- a/src/object.c +++ b/src/object.c @@ -184,6 +184,12 @@ Object_peel(Object *self, PyObject *py_type) return wrap_object(peeled, self->repo); } +int +Object___nonzero__(PyObject *self) +{ + return 1; +} + PyGetSetDef Object_getseters[] = { GETTER(Object, oid), GETTER(Object, id), @@ -199,6 +205,36 @@ PyMethodDef Object_methods[] = { {NULL} }; +#if PY_MAJOR_VERSION == 2 || defined(PYPY_VERSION) +static PyNumberMethods Object_as_number = { + 0, /* nb_add */ + 0, /* nb_subtract */ + 0, /* nb_multiply */ + 0, /* nb_divide */ + 0, /* nb_remainder */ + 0, /* nb_divmod */ + 0, /* nb_power */ + 0, /* nb_negative */ + 0, /* nb_positive */ + 0, /* nb_absolute */ + Object___nonzero__, /* nb_nonzero */ + /* There are a lot more, we we don't need any of them */ +}; +#else +static PyNumberMethods Object_as_number = { + 0, /* nb_add */ + 0, /* nb_subtract */ + 0, /* nb_divide */ + 0, /* nb_remainder */ + 0, /* nb_divmod */ + 0, /* nb_power */ + 0, /* nb_negative */ + 0, /* nb_positive */ + 0, /* nb_absolute */ + Object___nonzero__, /* nb_nonzero */ + /* There are a lot more, we we don't need any of them */ +}; +#endif PyDoc_STRVAR(Object__doc__, "Base class for Git objects."); @@ -213,7 +249,7 @@ PyTypeObject ObjectType = { 0, /* tp_setattr */ 0, /* tp_compare */ 0, /* tp_repr */ - 0, /* tp_as_number */ + &Object_as_number, /* tp_as_number */ 0, /* tp_as_sequence */ 0, /* tp_as_mapping */ 0, /* tp_hash */ diff --git a/test/test_index.py b/test/test_index.py index f6ec53466..1b4ce6eb0 100644 --- a/test/test_index.py +++ b/test/test_index.py @@ -50,6 +50,9 @@ class IndexTest(utils.RepoTestCase): def test_index(self): self.assertNotEqual(None, self.repo.index) + def test_nonzero(self): + self.assertTrue(Index()) + def test_read(self): index = self.repo.index self.assertEqual(len(index), 2) diff --git a/test/test_object.py b/test/test_object.py index 663040418..3f76447dd 100644 --- a/test/test_object.py +++ b/test/test_object.py @@ -33,7 +33,7 @@ import unittest import pygit2 -from pygit2 import GIT_OBJ_TREE, GIT_OBJ_TAG, Tree, Tag +from pygit2 import GIT_OBJ_TREE, GIT_OBJ_TAG, GIT_OBJ_BLOB, Tree, Tag from . import utils @@ -78,5 +78,12 @@ def test_invalid_type(self): self.assertRaises(ValueError, commit.peel, Tag) + def test_nonzero(self): + empty_tree_id = self.repo.TreeBuilder().write() + self.assertTrue(self.repo[empty_tree_id]) + + empty_blob_id = self.repo.write(GIT_OBJ_BLOB, '') + self.assertTrue(self.repo[empty_tree_id]) + if __name__ == '__main__': unittest.main()