Skip to content

Commit

Permalink
disallow cross connection replacement scans, add test
Browse files Browse the repository at this point in the history
  • Loading branch information
Tishj committed May 21, 2024
1 parent cb1291e commit ad0612d
Show file tree
Hide file tree
Showing 5 changed files with 54 additions and 12 deletions.
3 changes: 3 additions & 0 deletions src/include/duckdb/main/client_context_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ class ClientContextWrapper {
public:
explicit ClientContextWrapper(const shared_ptr<ClientContext> &context);
shared_ptr<ClientContext> GetContext();
operator bool() const {
return !client_context.expired();
}

private:
weak_ptr<ClientContext> client_context;
Expand Down
2 changes: 2 additions & 0 deletions tools/pythonpkg/src/include/duckdb_python/pyrelation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,8 @@ struct DuckDBPyRelation {
static bool IsRelation(const py::object &object);

bool CanBeRegisteredBy(Connection &con);
bool CanBeRegisteredBy(ClientContext &context);
bool CanBeRegisteredBy(shared_ptr<ClientContext> &context);

Relation &GetRel();

Expand Down
17 changes: 14 additions & 3 deletions tools/pythonpkg/src/pyrelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,23 @@ DuckDBPyRelation::DuckDBPyRelation(shared_ptr<Relation> rel_p) : rel(std::move(r
}

bool DuckDBPyRelation::CanBeRegisteredBy(Connection &con) {
if (!rel) {
return CanBeRegisteredBy(con.context);
}

bool DuckDBPyRelation::CanBeRegisteredBy(ClientContext &context) {
if (!rel || !rel->context) {
// PyRelation without an internal relation can not be registered
return false;
}
auto context = rel->context.GetContext();
return context == con.context;
auto this_context = rel->context.GetContext();
return &context == this_context.get();
}

bool DuckDBPyRelation::CanBeRegisteredBy(shared_ptr<ClientContext> &con) {
if (!con) {
return false;
}
return CanBeRegisteredBy(*con);
}

DuckDBPyRelation::~DuckDBPyRelation() {
Expand Down
22 changes: 13 additions & 9 deletions tools/pythonpkg/src/python_replacement_scan.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ static void CreateArrowScan(const string &name, py::object entry, TableFunctionR
table_function.external_dependency = std::move(dependency);
}

static unique_ptr<TableRef> TryReplacementObject(const py::object &entry, const string &name,
ClientProperties &client_properties) {
static unique_ptr<TableRef> TryReplacementObject(const py::object &entry, const string &name, ClientContext &context) {
auto client_properties = context.GetClientProperties();
auto table_function = make_uniq<TableFunctionRef>();
vector<unique_ptr<ParsedExpression>> children;
NumpyObjectType numpytype; // Identify the type of accepted numpy objects.
Expand All @@ -54,6 +54,12 @@ static unique_ptr<TableRef> TryReplacementObject(const py::object &entry, const
CreateArrowScan(name, entry, *table_function, children, client_properties);
} else if (DuckDBPyRelation::IsRelation(entry)) {
auto pyrel = py::cast<DuckDBPyRelation *>(entry);
if (!pyrel->CanBeRegisteredBy(context)) {
throw InvalidInputException(
"Python Object \"%s\" of type \"DuckDBPyRelation\" not suitable for replacement scan.\nThe object was "
"created by another Connection and can therefore not be used by this Connection.",
name);
}
// create a subquery from the underlying relation object
auto select = make_uniq<SelectStatement>();
select->node = pyrel->GetRel().GetQueryNode();
Expand Down Expand Up @@ -111,15 +117,15 @@ static unique_ptr<TableRef> TryReplacementObject(const py::object &entry, const
return std::move(table_function);
}

static unique_ptr<TableRef> TryReplacement(py::dict &dict, const string &name, ClientProperties &client_properties,
static unique_ptr<TableRef> TryReplacement(py::dict &dict, const string &name, ClientContext &context,
py::object &current_frame) {
auto table_name = py::str(name);
if (!dict.contains(table_name)) {
// not present in the globals
return nullptr;
}
const py::object &entry = dict[table_name];
auto result = TryReplacementObject(entry, name, client_properties);
auto result = TryReplacementObject(entry, name, context);
if (!result) {
std::string location = py::cast<py::str>(current_frame.attr("f_code").attr("co_filename"));
location += ":";
Expand All @@ -140,20 +146,19 @@ static unique_ptr<TableRef> ReplaceInternal(ClientContext &context, const string
py::gil_scoped_acquire acquire;
// Here we do an exhaustive search on the frame lineage
auto current_frame = py::module::import("inspect").attr("currentframe")();
auto client_properties = context.GetClientProperties();
while (hasattr(current_frame, "f_locals")) {
auto local_dict = py::reinterpret_borrow<py::dict>(current_frame.attr("f_locals"));
// search local dictionary
if (local_dict) {
auto result = TryReplacement(local_dict, table_name, client_properties, current_frame);
auto result = TryReplacement(local_dict, table_name, context, current_frame);
if (result) {
return result;
}
}
// search global dictionary
auto global_dict = py::reinterpret_borrow<py::dict>(current_frame.attr("f_globals"));
if (global_dict) {
auto result = TryReplacement(global_dict, table_name, client_properties, current_frame);
auto result = TryReplacement(global_dict, table_name, context, current_frame);
if (result) {
return result;
}
Expand All @@ -175,8 +180,7 @@ unique_ptr<TableRef> PythonReplacementScan::Replace(ClientContext &context, Repl
auto &python_dependency = dependency_item->Cast<PythonDependencyItem>();
auto &registered_object = *python_dependency.object;
auto &py_object = registered_object.obj;
auto client_properties = context.GetClientProperties();
auto result = TryReplacementObject(py_object, table_name, client_properties);
auto result = TryReplacementObject(py_object, table_name, context);
// This was cached, so it was successful before, it should be successfull now
D_ASSERT(result);
return std::move(result);
Expand Down
22 changes: 22 additions & 0 deletions tools/pythonpkg/tests/fast/test_replacement_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,25 @@ def test_replacement_scan_fail(self):
match=r'Python Object "random_object" of type "str" found on line .* not suitable for replacement scans.',
):
con.execute("select count(*) from random_object").fetchone()

def test_replacement_of_cross_connection_relation(self):
con1 = duckdb.connect(':memory:')
con2 = duckdb.connect(':memory:')
con1.query('create table integers(i int)')
con2.query('create table integers(v varchar)')
con1.query('insert into integers values (42)')
con2.query('insert into integers values (\'xxx\')')
rel1 = con1.query('select * from integers')
with pytest.raises(
duckdb.InvalidInputException,
match=r'The object was created by another Connection and can therefore not be used by this Connection.',
):
con2.query('from rel1')

del con1

with pytest.raises(
duckdb.InvalidInputException,
match=r'The object was created by another Connection and can therefore not be used by this Connection.',
):
con2.query('from rel1')

0 comments on commit ad0612d

Please sign in to comment.