Skip to content

Commit

Permalink
Merge pull request #5 from lawinsider/fix-memory-leak
Browse files Browse the repository at this point in the history
Fix memory leak and add tests
  • Loading branch information
honnibal committed Jun 5, 2019
2 parents adccf2f + 41a3a51 commit b20c2af
Show file tree
Hide file tree
Showing 2 changed files with 91 additions and 15 deletions.
64 changes: 64 additions & 0 deletions srsly/tests/ujson/test_ujson.py
Original file line number Diff line number Diff line change
Expand Up @@ -813,6 +813,70 @@ def test_sortKeys(self):
sortedKeys = ujson.dumps(data, sort_keys=True)
self.assertEqual(sortedKeys, '{"a":1,"b":1,"c":1,"d":1,"e":1,"f":1}')

def test_does_not_leak_dictionary_values(self):
import gc
gc.collect()
value = ["abc"]
data = {"1": value}
ref_count = sys.getrefcount(value)
ujson.dumps(data)
self.assertEqual(ref_count, sys.getrefcount(value))

def test_does_not_leak_dictionary_keys(self):
import gc
gc.collect()
key1 = "1"
key2 = "1"
value1 = ["abc"]
value2 = [1, 2, 3]
data = {key1: value1, key2: value2}
ref_count1 = sys.getrefcount(key1)
ref_count2 = sys.getrefcount(key2)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))
self.assertEqual(ref_count2, sys.getrefcount(key2))

def test_does_not_leak_dictionary_string_key(self):
import gc
gc.collect()
key1 = "1"
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))

def test_does_not_leak_dictionary_tuple_key(self):
import gc
gc.collect()
key1 = ("a",)
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))

def test_does_not_leak_dictionary_bytes_key(self):
import gc
gc.collect()
key1 = b"1"
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))

def test_does_not_leak_dictionary_None_key(self):
import gc
gc.collect()
key1 = None
value1 = 1
data = {key1: value1}
ref_count1 = sys.getrefcount(key1)
ujson.dumps(data)
self.assertEqual(ref_count1, sys.getrefcount(key1))


"""
def test_decodeNumericIntFrcOverflow(self):
input = "X.Y"
Expand Down
42 changes: 27 additions & 15 deletions srsly/ujson/objToJSON.c
Original file line number Diff line number Diff line change
Expand Up @@ -449,9 +449,7 @@ char *List_iterGetName(JSOBJ obj, JSONTypeContext *tc, size_t *outLen)

int Dict_iterNext(JSOBJ obj, JSONTypeContext *tc)
{
#if PY_MAJOR_VERSION >= 3
PyObject* itemNameTmp;
#endif

if (GET_TC(tc)->itemName)
{
Expand All @@ -465,41 +463,48 @@ int Dict_iterNext(JSOBJ obj, JSONTypeContext *tc)
return 0;
}

if (!(GET_TC(tc)->itemValue = PyObject_GetItem(GET_TC(tc)->dictObj, GET_TC(tc)->itemName)))
{
if (GET_TC(tc)->itemValue) {
Py_DECREF(GET_TC(tc)->itemValue);
GET_TC(tc)->itemValue = NULL;
}

if (!(GET_TC(tc)->itemValue = PyObject_GetItem(GET_TC(tc)->dictObj, GET_TC(tc)->itemName))) {
PRINTMARK();
return 0;
}

if (PyUnicode_Check(GET_TC(tc)->itemName))
{
GET_TC(tc)->itemName = PyUnicode_AsUTF8String (GET_TC(tc)->itemName);
itemNameTmp = GET_TC(tc)->itemName;
GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp);
Py_DECREF(itemNameTmp);
}
else
if (!PyString_Check(GET_TC(tc)->itemName))
{
GET_TC(tc)->itemName = PyObject_Str(GET_TC(tc)->itemName);
itemNameTmp = GET_TC(tc)->itemName;
GET_TC(tc)->itemName = PyObject_Str(itemNameTmp);
Py_DECREF(itemNameTmp);
#if PY_MAJOR_VERSION >= 3
itemNameTmp = GET_TC(tc)->itemName;
GET_TC(tc)->itemName = PyUnicode_AsUTF8String (GET_TC(tc)->itemName);
GET_TC(tc)->itemName = PyUnicode_AsUTF8String (itemNameTmp);
Py_DECREF(itemNameTmp);
#endif
}
else
{
Py_INCREF(GET_TC(tc)->itemName);
}
PRINTMARK();
return 1;
}

void Dict_iterEnd(JSOBJ obj, JSONTypeContext *tc)
{
if (GET_TC(tc)->itemName)
{
if (GET_TC(tc)->itemName) {
Py_DECREF(GET_TC(tc)->itemName);
GET_TC(tc)->itemName = NULL;
}
if (GET_TC(tc)->itemValue) {
Py_DECREF(GET_TC(tc)->itemValue);
GET_TC(tc)->itemValue = NULL;
}
Py_CLEAR(GET_TC(tc)->iterator);
Py_DECREF(GET_TC(tc)->dictObj);
PRINTMARK();
Expand Down Expand Up @@ -943,6 +948,10 @@ void Object_endTypeContext(JSOBJ obj, JSONTypeContext *tc)
{
Py_XDECREF(GET_TC(tc)->newObj);

if (tc->type == JT_RAW)
{
Py_XDECREF(GET_TC(tc)->rawJSONValue);
}
PyObject_Free(tc->prv);
tc->prv = NULL;
}
Expand Down Expand Up @@ -1112,6 +1121,7 @@ PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs)
PyObject *string;
PyObject *write;
PyObject *argtuple;
PyObject *write_result;

PRINTMARK();

Expand Down Expand Up @@ -1154,13 +1164,15 @@ PyObject* objToJSONFile(PyObject* self, PyObject *args, PyObject *kwargs)
Py_XDECREF(write);
return NULL;
}
if (PyObject_CallObject (write, argtuple) == NULL)
write_result = PyObject_CallObject (write, argtuple);
if (write_result == NULL)
{
Py_XDECREF(write);
Py_XDECREF(argtuple);
return NULL;
}


Py_DECREF(write_result);
Py_XDECREF(write);
Py_DECREF(argtuple);
Py_XDECREF(string);
Expand Down

0 comments on commit b20c2af

Please sign in to comment.