Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions changelog.d/53.improve.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
Harden `Structure` class against memory leak<ISSUES_LIST>.
The extensions' implementation of packstream `Structure` could leak memory when being part of a reference cycle.
In reality this doesn't matter because the driver never constructs cyclic `Structure`s.
Every packstream value is a tree in terms of references (both directions: packing and unpacking).
This change is meant to harden the extensions against introducing effective memory leaks in the driver should the driver's usage of `Structure` change in the future.
13 changes: 12 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use pyo3::basic::CompareOp;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyTuple};
use pyo3::IntoPyObjectExt;
use pyo3::{IntoPyObjectExt, PyTraverseError, PyVisit};

#[pymodule(gil_used = false)]
#[pyo3(name = "_rust")]
Expand Down Expand Up @@ -114,4 +114,15 @@ impl Structure {
}
Ok(fields_hash.wrapping_add(self.tag.into()))
}

fn __traverse__(&self, visit: PyVisit<'_>) -> Result<(), PyTraverseError> {
for field in &self.fields {
visit.call(field)?;
}
Ok(())
}

fn __clear__(&mut self) {
self.fields.clear();
}
}
6 changes: 2 additions & 4 deletions src/v1/unpack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,7 @@ impl<'a> PackStreamDecoder<'a> {
_ => {
// raise ValueError("Unknown PackStream marker %02X" % marker)
return Err(PyErr::new::<PyValueError, _>(format!(
"Unknown PackStream marker {:02X}",
marker
"Unknown PackStream marker {marker:02X}",
)));
}
})
Expand Down Expand Up @@ -243,8 +242,7 @@ impl<'a> PackStreamDecoder<'a> {
STRING_16 => self.read_u16(),
STRING_32 => self.read_u32(),
_ => Err(PyErr::new::<PyValueError, _>(format!(
"Invalid string length marker: {}",
marker
"Invalid string length marker: {marker}",
))),
}
}
Expand Down
55 changes: 55 additions & 0 deletions tests/test_structure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright (c) "Neo4j"
# Neo4j Sweden AB [https://neo4j.com]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from __future__ import annotations

import gc
from contextlib import contextmanager

from neo4j._codec.packstream import Structure


@contextmanager
def gc_disabled():
try:
gc.disable()
yield
finally:
gc.enable()
gc.collect()


class StructureHolder:
s: Structure | None = None


def test_memory_leak() -> None:
iterations = 10_000

gc.collect()
with gc_disabled():
for _ in range(iterations):
# create a reference cycle
holder1 = StructureHolder()
structure1 = Structure(b"\x00", [holder1])
holder2 = StructureHolder()
structure2 = Structure(b"\x01", [holder2])
holder1.s = structure2
holder2.s = structure1
del structure1, structure2, holder1, holder2

cleaned = gc.collect()
assert cleaned >= 4 * iterations