Skip to content
Permalink
Browse files Browse the repository at this point in the history
Fix handling of invalid union data in table-based serializer
Summary:
Fix handling of invalid union data in the table-based serializer. Previously if the input contained duplicate union data, previous active member of the union was overwritten without calling the destructor of the old object, potentially causing a memory leak. In addition to that, if the second piece of data was incomplete the wrong destructor would be called during stack unwinding causing a segfault, data corruption or other undesirable effects.

Fix the issue by clearing the union if there is an active member.

Also fix the type of the data member that holds the active field id (it's `int`, not `FieldID`).

Reviewed By: yfeldblum

Differential Revision: D26440248

fbshipit-source-id: fae9ab96566cf07e14dabe9663b2beb680a01bb4
  • Loading branch information
vitaut authored and facebook-github-bot committed Feb 19, 2021
1 parent cd365bc commit bfda1ef
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 17 deletions.
34 changes: 17 additions & 17 deletions thrift/lib/cpp2/protocol/TableBasedSerializer.cpp
Expand Up @@ -186,9 +186,11 @@ const FieldInfo* FOLLY_NULLABLE findFieldInfo(
return nullptr;
}

const FieldID& activeUnionMemberId(const void* object, ptrdiff_t offset) {
return *reinterpret_cast<const FieldID*>(
offset + static_cast<const char*>(object));
// Returns a reference to the data member that holds the active field id for a
// Thrift union object.
const int& getActiveId(const void* object, const StructInfo& info) {
return *reinterpret_cast<const int*>(
static_cast<const char*>(object) + info.unionExt->unionTypeOffset);
}

const bool& fieldIsSet(const void* object, ptrdiff_t offset) {
Expand Down Expand Up @@ -591,16 +593,15 @@ void read(Protocol_* iprot, const StructInfo& structInfo, void* object) {
readState.readStructEnd(iprot);
return;
}
const auto* fieldInfo = findFieldInfo(iprot, readState, structInfo);
// Found it.
if (fieldInfo) {
void* unionVal = getMember(*fieldInfo, object);
// Default construct and placement new into the member union.
structInfo.unionExt->initMember[fieldInfo - structInfo.fieldInfos](
unionVal);
read(iprot, *fieldInfo->typeInfo, readState, unionVal);
const_cast<FieldID&>(activeUnionMemberId(
object, structInfo.unionExt->unionTypeOffset)) = fieldInfo->id;
if (const auto* fieldInfo = findFieldInfo(iprot, readState, structInfo)) {
auto& activeId = const_cast<int&>(getActiveId(object, structInfo));
if (activeId != 0) {
structInfo.unionExt->clear(object);
}
void* value = getMember(*fieldInfo, object);
structInfo.unionExt->initMember[fieldInfo - structInfo.fieldInfos](value);
read(iprot, *fieldInfo->typeInfo, readState, value);
activeId = fieldInfo->id;
} else {
skip(iprot, readState);
}
Expand Down Expand Up @@ -671,14 +672,13 @@ write(Protocol_* iprot, const StructInfo& structInfo, const void* object) {
size_t written = iprot->writeStructBegin(structInfo.name);
if (UNLIKELY(structInfo.unionExt != nullptr)) {
const FieldInfo* end = structInfo.fieldInfos + structInfo.numFields;
const auto& unionId =
activeUnionMemberId(object, structInfo.unionExt->unionTypeOffset);
const auto& activeId = getActiveId(object, structInfo);
const FieldInfo* found = std::lower_bound(
structInfo.fieldInfos,
end,
unionId,
activeId,
[](const FieldInfo& lhs, FieldID rhs) { return lhs.id < rhs; });
if (found < end && found->id == unionId) {
if (found < end && found->id == activeId) {
const OptionalThriftValue value = getValue(*found->typeInfo, object);
if (value.hasValue()) {
written += writeField(iprot, *found, value.value());
Expand Down
22 changes: 22 additions & 0 deletions thrift/test/tablebased/SerializerTest.cpp
Expand Up @@ -356,3 +356,25 @@ TEST(SerializerTest, UnionValueOffsetIsZero) {
u.set_fieldB({});
EXPECT_EQ(static_cast<void*>(&u), &*u.fieldB_ref());
}

TEST(SerializerTest, DuplicateUnionData) {
// Test that we can handle invalid serialized input with duplicate and
// incomplete union data.
const char data[] =
"\x0c" // type = TType::T_STRUCT
"\x00\x01" // fieldId = 1 (unionField)
"\x0b" // type = TType::T_STRING
"\x00\x01" // fieldId = 1 (stringField)
"\x00\x00\x00\x00" // size = 0
"\x00" // end of unionField

"\x0c" // type = TType::T_STRUCT
"\x00\x01" // fieldId = 1 (unionField)
"\x13" // type = TType::T_FLOAT
"\x00\x02"; // fieldId = 2 (floatField), value is missing

EXPECT_THROW(
BinarySerializer::deserialize<tablebased::TestStructWithUnion>(
folly::StringPiece(data, sizeof(data))),
std::out_of_range);
}
9 changes: 9 additions & 0 deletions thrift/test/tablebased/thrift_tablebased.thrift
Expand Up @@ -79,3 +79,12 @@ union UnionWithRef {
3: StructA fieldC (cpp2.ref_type = "shared_const");
4: StructA fieldD (cpp2.ref_type = "shared");
}

union TestUnion {
1: string stringField;
2: float floatField;
}

struct TestStructWithUnion {
1: TestUnion unionField;
}

0 comments on commit bfda1ef

Please sign in to comment.