Skip to content

Commit 86e57cb

Browse files
fix(spanner_dbapi): replace insecure pickle with json for partition deserialization (#17014)
This PR resolves a critical Insecure Deserialization vulnerability (potential Remote Code Execution) in the `spanner_dbapi` module [b/510871112](b/510871112) . Previously, the module utilized `pickle.loads()` to decode partition IDs provided by users via the `RUN PARTITION` statement, creating a possibility for arbitrary code execution attack payloads. We have fully eliminated `pickle` usage in this module and migrated to standard `json` serialization. --------- Co-authored-by: Knut Olav Løite <koloite@gmail.com>
1 parent 6b62cb6 commit 86e57cb

5 files changed

Lines changed: 616 additions & 7 deletions

File tree

packages/google-cloud-spanner/google/cloud/spanner_dbapi/partition_helper.py

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,23 +13,145 @@
1313
# limitations under the License.
1414

1515
import base64
16+
import copy
17+
import datetime
1618
import gzip
17-
import pickle
19+
import json
1820
from dataclasses import dataclass
1921
from typing import Any
2022

23+
from google.protobuf.json_format import MessageToDict, ParseDict
24+
from google.protobuf.message import Message
25+
from google.protobuf.struct_pb2 import Struct
26+
2127
from google.cloud.spanner_v1 import BatchTransactionId
28+
from google.cloud.spanner_v1._helpers import _make_value_pb
29+
from google.cloud.spanner_v1.types import DirectedReadOptions, ExecuteSqlRequest, Type
30+
31+
_PROTO_CLASS_MAP = {
32+
"QueryOptions": ExecuteSqlRequest.QueryOptions,
33+
"DirectedReadOptions": DirectedReadOptions,
34+
"Struct": Struct,
35+
"Type": Type,
36+
}
37+
38+
39+
def _serialize_value(val: Any) -> Any:
40+
if isinstance(val, bytes):
41+
return {"__type__": "bytes", "value": base64.b64encode(val).decode("utf-8")}
42+
elif isinstance(val, datetime.datetime):
43+
return {"__type__": "datetime", "value": val.isoformat()}
44+
elif hasattr(val, "_pb"):
45+
return {
46+
"__type__": "protobuf",
47+
"class": val.__class__.__name__,
48+
"value": MessageToDict(val._pb, preserving_proto_field_name=True),
49+
}
50+
elif isinstance(val, Message):
51+
return {
52+
"__type__": "protobuf",
53+
"class": val.__class__.__name__,
54+
"value": MessageToDict(val, preserving_proto_field_name=True),
55+
}
56+
elif isinstance(val, dict):
57+
return {k: _serialize_value(v) for k, v in val.items()}
58+
elif isinstance(val, list):
59+
return [_serialize_value(v) for v in val]
60+
elif isinstance(val, tuple):
61+
return {"__type__": "tuple", "value": [_serialize_value(v) for v in val]}
62+
return val
63+
64+
65+
def _deserialize_value(val: Any) -> Any:
66+
if isinstance(val, dict):
67+
if "__type__" in val:
68+
t = val["__type__"]
69+
if t == "bytes":
70+
return base64.b64decode(val["value"])
71+
elif t == "datetime":
72+
dt_str = val["value"]
73+
if dt_str.endswith("Z"):
74+
dt_str = dt_str[:-1] + "+00:00"
75+
return datetime.datetime.fromisoformat(dt_str)
76+
elif t == "tuple":
77+
return tuple(_deserialize_value(x) for x in val["value"])
78+
elif t == "protobuf":
79+
cls_name = val.get("class")
80+
dict_val = val["value"]
81+
if cls_name in _PROTO_CLASS_MAP:
82+
cls = _PROTO_CLASS_MAP[cls_name]
83+
msg = cls()._pb if hasattr(cls(), "_pb") else cls()
84+
ParseDict(dict_val, msg)
85+
return cls(msg) if hasattr(cls(), "_pb") else msg
86+
return _deserialize_value(dict_val)
87+
return {k: _deserialize_value(v) for k, v in val.items()}
88+
elif isinstance(val, list):
89+
return [_deserialize_value(v) for v in val]
90+
return val
91+
92+
93+
def _unpack_value_pb(value):
94+
which = value.WhichOneof("kind")
95+
if which == "null_value":
96+
return None
97+
elif which == "number_value":
98+
return value.number_value
99+
elif which == "string_value":
100+
return value.string_value
101+
elif which == "bool_value":
102+
return value.bool_value
103+
elif which == "struct_value":
104+
return {k: _unpack_value_pb(v) for k, v in value.struct_value.fields.items()}
105+
elif which == "list_value":
106+
return [_unpack_value_pb(v) for v in value.list_value.values]
107+
return None
22108

23109

24110
def decode_from_string(encoded_partition_id):
25111
gzip_bytes = base64.b64decode(bytes(encoded_partition_id, "utf-8"))
26112
partition_id_bytes = gzip.decompress(gzip_bytes)
27-
return pickle.loads(partition_id_bytes)
113+
114+
data = json.loads(partition_id_bytes.decode("utf-8"))
115+
btid_data = data["batch_transaction_id"]
116+
btid = BatchTransactionId(
117+
transaction_id=_deserialize_value(btid_data["transaction_id"]),
118+
session_id=btid_data["session_id"],
119+
read_timestamp=_deserialize_value(btid_data["read_timestamp"]),
120+
)
121+
partition_result = _deserialize_value(data["partition_result"])
122+
123+
# Post-process query params back from Protobuf Struct to Python primitives
124+
if "query" in partition_result and "params" in partition_result["query"]:
125+
params_pb = partition_result["query"]["params"]
126+
if params_pb:
127+
partition_result["query"]["params"] = {
128+
k: _unpack_value_pb(v) for k, v in params_pb.fields.items()
129+
}
130+
131+
return PartitionId(btid, partition_result)
28132

29133

30134
def encode_to_string(batch_transaction_id, partition_result):
31-
partition_id = PartitionId(batch_transaction_id, partition_result)
32-
partition_id_bytes = pickle.dumps(partition_id)
135+
# Copy to avoid modifying the caller's dictionary in connection.py
136+
partition_result = copy.deepcopy(partition_result)
137+
138+
# Pre-process query params into a Protobuf Struct
139+
if "query" in partition_result and "params" in partition_result["query"]:
140+
params = partition_result["query"]["params"]
141+
if params:
142+
params_pb = Struct(fields={k: _make_value_pb(v) for k, v in params.items()})
143+
partition_result["query"]["params"] = params_pb
144+
145+
data = {
146+
"batch_transaction_id": {
147+
"transaction_id": _serialize_value(batch_transaction_id.transaction_id),
148+
"session_id": batch_transaction_id.session_id,
149+
"read_timestamp": _serialize_value(batch_transaction_id.read_timestamp),
150+
},
151+
"partition_result": _serialize_value(partition_result),
152+
}
153+
154+
partition_id_bytes = json.dumps(data).encode("utf-8")
33155
gzip_bytes = gzip.compress(partition_id_bytes)
34156
return str(base64.b64encode(gzip_bytes), "utf-8")
35157

packages/google-cloud-spanner/google/cloud/spanner_v1/testing/mock_spanner.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,21 @@ class MockSpanner:
3636
def __init__(self):
3737
self.results = {}
3838
self.execute_streaming_sql_results = {}
39+
self.partition_results = {}
3940
self.errors = {}
4041

4142
def clear_results(self):
4243
self.results = {}
4344
self.execute_streaming_sql_results = {}
45+
self.partition_results = {}
4446
self.errors = {}
4547

4648
def add_result(self, sql: str, result: result_set.ResultSet):
4749
self.results[sql.lower().strip()] = result
4850

51+
def add_partition_result(self, sql: str, result: spanner.PartitionResponse):
52+
self.partition_results[sql.lower().strip()] = result
53+
4954
def add_execute_streaming_sql_results(
5055
self, sql: str, partial_result_sets: list[result_set.PartialResultSet]
5156
):
@@ -57,6 +62,12 @@ def get_result(self, sql: str) -> result_set.ResultSet:
5762
raise ValueError(f"No result found for {sql}")
5863
return result
5964

65+
def get_partition_result(self, sql: str) -> spanner.PartitionResponse:
66+
result = self.partition_results.get(sql.lower().strip())
67+
if result is None:
68+
return spanner.PartitionResponse()
69+
return result
70+
6071
def add_error(self, method: str, error: _Status):
6172
if not hasattr(self, "_errors_list"):
6273
self._errors_list = {}
@@ -300,11 +311,12 @@ def Rollback(self, request, context):
300311

301312
def PartitionQuery(self, request, context):
302313
self._requests.append(request)
303-
return spanner.PartitionResponse()
314+
return self.mock_spanner.get_partition_result(request.sql)
304315

305316
def PartitionRead(self, request, context):
306317
self._requests.append(request)
307-
return spanner.PartitionResponse()
318+
# For reads, look up by target table name
319+
return self.mock_spanner.get_partition_result(request.table)
308320

309321
def BatchWrite(self, request, context):
310322
self._requests.append(request)

packages/google-cloud-spanner/tests/_helpers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
from os import getenv
22
from unittest import IsolatedAsyncioTestCase
33

4-
import mock
4+
try:
5+
import mock
6+
except ImportError:
7+
import unittest.mock as mock
58

69
from google.cloud.spanner_v1 import gapic_version
710
from google.cloud.spanner_v1.database_sessions_manager import TransactionType
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright 2024 Google LLC All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
from google.cloud.spanner_dbapi.connection import Connection
17+
from google.cloud.spanner_dbapi.parsed_statement import ParsedStatement, Statement
18+
from google.cloud.spanner_v1 import TypeCode
19+
from google.cloud.spanner_v1.types import spanner as spanner_types
20+
from tests.mockserver_tests.mock_server_test_base import (
21+
MockServerTestBase,
22+
add_single_result,
23+
)
24+
25+
26+
class TestDbapiPartitionQuery(MockServerTestBase):
27+
def test_partition_query_and_run_partition(self):
28+
sql = "SELECT name FROM users WHERE active = true"
29+
30+
# 1. Set up mock results for PartitionQuery RPC in the mock servicer
31+
partition_response = spanner_types.PartitionResponse()
32+
partition_response.partitions.extend(
33+
[
34+
spanner_types.Partition(partition_token=b"mock-token-1"),
35+
spanner_types.Partition(partition_token=b"mock-token-2"),
36+
]
37+
)
38+
self.spanner_service.mock_spanner.add_partition_result(sql, partition_response)
39+
40+
# 2. Set up mock results for ExecuteSql when executing the partitions
41+
add_single_result(sql, "name", TypeCode.STRING, [("Alice",), ("Bob",)])
42+
43+
# 3. Connect via DB-API and mark connection as read-only (required for partitioning)
44+
connection = Connection(self.instance, self.database)
45+
connection._read_only = True
46+
47+
# Define partitioning parameters inside DB-API Statement
48+
from google.cloud.spanner_dbapi.parsed_statement import (
49+
ClientSideStatementType,
50+
StatementType,
51+
)
52+
53+
parsed = ParsedStatement(
54+
statement_type=StatementType.CLIENT_SIDE,
55+
statement=Statement(sql),
56+
client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
57+
client_side_statement_params=["SELECT name FROM users WHERE active = true"],
58+
)
59+
60+
# Generate serialized token strings (Base64 + GZip JSON)
61+
partition_ids = connection.partition_query(parsed)
62+
self.assertEqual(2, len(partition_ids))
63+
64+
# 4. Reconstruct & Execute the partitions by deserializing their tokens
65+
all_names = []
66+
for token in partition_ids:
67+
result_stream = connection.run_partition(token)
68+
for row in result_stream:
69+
all_names.append(row[0])
70+
71+
# Verify results are successfully round-tripped and parsed
72+
self.assertIn("Alice", all_names)
73+
self.assertIn("Bob", all_names)
74+
75+
def test_partition_query_with_complex_parameters(self):
76+
import datetime
77+
import decimal
78+
79+
sql = "SELECT name FROM users WHERE active = @active AND salary > @salary AND signup_time = @signup_time"
80+
81+
# Set up complex parameter values (bool, Decimal, datetime)
82+
params = {
83+
"active": True,
84+
"salary": decimal.Decimal("75000.50"),
85+
"signup_time": datetime.datetime(
86+
2026, 5, 10, 12, 34, 56, tzinfo=datetime.timezone.utc
87+
),
88+
}
89+
from google.cloud.spanner_v1 import Type
90+
91+
param_types = {
92+
"active": Type(code=TypeCode.BOOL),
93+
"salary": Type(code=TypeCode.NUMERIC),
94+
"signup_time": Type(code=TypeCode.TIMESTAMP),
95+
}
96+
97+
# 1. Mock results for the partition generation RPC
98+
partition_response = spanner_types.PartitionResponse()
99+
partition_response.partitions.extend(
100+
[spanner_types.Partition(partition_token=b"complex-mock-token-1")]
101+
)
102+
self.spanner_service.mock_spanner.add_partition_result(sql, partition_response)
103+
104+
# 2. Mock results for execution of partition streaming SQL
105+
add_single_result(sql, "name", TypeCode.STRING, [("Charlie",)])
106+
107+
# 3. Establish Connection
108+
connection = Connection(self.instance, self.database)
109+
connection._read_only = True
110+
111+
from google.cloud.spanner_dbapi.parsed_statement import (
112+
ClientSideStatementType,
113+
StatementType,
114+
)
115+
116+
parsed = ParsedStatement(
117+
statement_type=StatementType.CLIENT_SIDE,
118+
statement=Statement(sql, params=params, param_types=param_types),
119+
client_side_statement_type=ClientSideStatementType.PARTITION_QUERY,
120+
client_side_statement_params=[sql],
121+
)
122+
123+
# Execute partition generation - this serializes query parameters!
124+
partition_ids = connection.partition_query(parsed)
125+
self.assertEqual(1, len(partition_ids))
126+
127+
# 4. Reconstruct and run the partition E2E
128+
all_names = []
129+
for token in partition_ids:
130+
result_stream = connection.run_partition(token)
131+
for row in result_stream:
132+
all_names.append(row[0])
133+
134+
self.assertEqual(["Charlie"], all_names)

0 commit comments

Comments
 (0)