diff --git a/changelog.d/53.improve.md b/changelog.d/53.improve.md new file mode 100644 index 0000000..1660ee8 --- /dev/null +++ b/changelog.d/53.improve.md @@ -0,0 +1,5 @@ +Harden `Structure` class against memory leak. +The extensions' implementation of packstream `Structure` could leak memory when being part of a reference cycle. +In reality this doesn't matter because the driver never constructs cyclic `Structure`s. +Every packstream value is a tree in terms of references (both directions: packing and unpacking). +This change is meant to harden the extensions against introducing effective memory leaks in the driver should the driver's usage of `Structure` change in the future. diff --git a/src/lib.rs b/src/lib.rs index 5030228..4d547ce 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -19,7 +19,7 @@ use pyo3::basic::CompareOp; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use pyo3::types::{PyBytes, PyTuple}; -use pyo3::IntoPyObjectExt; +use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit}; #[pymodule(gil_used = false)] #[pyo3(name = "_rust")] @@ -114,4 +114,15 @@ impl Structure { } Ok(fields_hash.wrapping_add(self.tag.into())) } + + fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> { + for field in &self.fields { + visit.call(field)?; + } + Ok(()) + } + + fn __clear__(&mut self) { + self.fields.clear(); + } } diff --git a/src/v1/unpack.rs b/src/v1/unpack.rs index 4f5380d..a258ed9 100644 --- a/src/v1/unpack.rs +++ b/src/v1/unpack.rs @@ -135,8 +135,7 @@ impl<'a> PackStreamDecoder<'a> { _ => { // raise ValueError("Unknown PackStream marker %02X" % marker) return Err(PyErr::new::(format!( - "Unknown PackStream marker {:02X}", - marker + "Unknown PackStream marker {marker:02X}", ))); } }) @@ -243,8 +242,7 @@ impl<'a> PackStreamDecoder<'a> { STRING_16 => self.read_u16(), STRING_32 => self.read_u32(), _ => Err(PyErr::new::(format!( - "Invalid string length marker: {}", - marker + "Invalid string length marker: {marker}", ))), } } diff --git a/tests/test_structure.py b/tests/test_structure.py new file mode 100644 index 0000000..1a25e6e --- /dev/null +++ b/tests/test_structure.py @@ -0,0 +1,55 @@ +# Copyright (c) "Neo4j" +# Neo4j Sweden AB [https://neo4j.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from __future__ import annotations + +import gc +from contextlib import contextmanager + +from neo4j._codec.packstream import Structure + + +@contextmanager +def gc_disabled(): + try: + gc.disable() + yield + finally: + gc.enable() + gc.collect() + + +class StructureHolder: + s: Structure | None = None + + +def test_memory_leak() -> None: + iterations = 10_000 + + gc.collect() + with gc_disabled(): + for _ in range(iterations): + # create a reference cycle + holder1 = StructureHolder() + structure1 = Structure(b"\x00", [holder1]) + holder2 = StructureHolder() + structure2 = Structure(b"\x01", [holder2]) + holder1.s = structure2 + holder2.s = structure1 + del structure1, structure2, holder1, holder2 + + cleaned = gc.collect() + assert cleaned >= 4 * iterations