Skip to content

Commit

Permalink
Update RecordMap handling (#11306)
Browse files Browse the repository at this point in the history
- Manage record map insertions more centrally.
- Remove the array of records in BinaryFormattedObject (it wasn't really useful)
- Expose the root record on BinaryFormattedObject.
- Manually grab the header first, before iterating.
  • Loading branch information
JeremyKuhne committed May 3, 2024
1 parent 1e92b29 commit c926856
Show file tree
Hide file tree
Showing 29 changed files with 164 additions and 523 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,21 +24,13 @@ public ArraySingleObject(Id objectId, IReadOnlyList<object?> arrayObjects)
: base(new ArrayInfo(objectId, arrayObjects.Count), arrayObjects)
{ }

static ArraySingleObject IBinaryFormatParseable<ArraySingleObject>.Parse(
BinaryFormattedObject.IParseState state)
{
ArraySingleObject record = new(
ArrayInfo.Parse(state.Reader, out Count length),
ReadObjectArrayValues(state, length));

state.RecordMap[record.ObjectId] = record;
return record;
}
static ArraySingleObject IBinaryFormatParseable<ArraySingleObject>.Parse(BinaryFormattedObject.IParseState state) =>
new(ArrayInfo.Parse(state.Reader, out Count length), ReadObjectArrayValues(state, length));

public override void Write(BinaryWriter writer)
{
writer.Write((byte)RecordType);
ArrayInfo.Write(writer);
WriteRecords(writer, ArrayObjects);
WriteRecords(writer, ArrayObjects, coalesceNulls: true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,7 @@ static ArrayRecord IBinaryFormatParseable<ArrayRecord>.Parse(
PrimitiveType primitiveType = (PrimitiveType)state.Reader.ReadByte();
Debug.Assert(typeof(T) == primitiveType.GetPrimitiveTypeType());

ArraySinglePrimitive<T> record = new(id, state.Reader.ReadPrimitiveArray<T>(length));
state.RecordMap[record.ObjectId] = record;
return record;
return new ArraySinglePrimitive<T>(id, state.Reader.ReadPrimitiveArray<T>(length));
}

public override void Write(BinaryWriter writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,21 +21,13 @@ public ArraySingleString(Id objectId, IReadOnlyList<object?> arrayObjects)
: base(new ArrayInfo(objectId, arrayObjects.Count), arrayObjects)
{ }

static ArraySingleString IBinaryFormatParseable<ArraySingleString>.Parse(
BinaryFormattedObject.IParseState state)
{
ArraySingleString record = new(
ArrayInfo.Parse(state.Reader, out Count length),
ReadObjectArrayValues(state, length));

state.RecordMap[record.ObjectId] = record;
return record;
}
static ArraySingleString IBinaryFormatParseable<ArraySingleString>.Parse(BinaryFormattedObject.IParseState state) =>
new(ArrayInfo.Parse(state.Reader, out Count length), ReadObjectArrayValues(state, length));

public override void Write(BinaryWriter writer)
{
writer.Write((byte)RecordType);
ArrayInfo.Write(writer);
WriteRecords(writer, ArrayObjects);
WriteRecords(writer, ArrayObjects, coalesceNulls: true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ internal static IBinaryArray Parse(BinaryFormattedObject.IParseState state)
_ => throw new SerializationException($"Invalid primitive type '{(PrimitiveType)info}'"),
};

state.RecordMap[objectId] = array;
return array;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ Type ITypeResolver.GetType(string typeName, Id libraryId)
? assembly.GetType(typeName)
: GetSimplyNamedTypeFromAssembly(assembly, typeName);

_lastType = type ?? throw new SerializationException($"Could not find type '{typeName}'.");
_types[(typeName, libraryId)] = type ?? throw new SerializationException($"Could not find type '{typeName}'.");
_lastType = type;
return _lastType;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,14 +27,15 @@ internal sealed partial class BinaryFormattedObject
#pragma warning restore SYSLIB0050

private static readonly Options s_defaultOptions = new();
private readonly Options _options;

private readonly List<IRecord> _records = [];
private readonly RecordMap _recordMap = new();

private readonly Options _options;
private ITypeResolver? _typeResolver;
private ITypeResolver TypeResolver => _typeResolver ??= new DefaultTypeResolver(_options, _recordMap);

private readonly Id _rootRecord;

/// <summary>
/// Creates <see cref="BinaryFormattedObject"/> by parsing <paramref name="stream"/>.
/// </summary>
Expand All @@ -48,25 +49,32 @@ public BinaryFormattedObject(Stream stream, Options? options = null)
ParseState state = new(reader, this);

IRecord? currentRecord;
do

try
{
try
{
currentRecord = Record.ReadBinaryFormatRecord(state);
}
catch (Exception ex) when (ex is ArgumentException or InvalidCastException or ArithmeticException or IOException)
currentRecord = Record.ReadBinaryFormatRecord(state);
if (currentRecord is not SerializationHeader header)
{
// Make the exception easier to catch, but retain the original stack trace.
throw ex.ConvertToSerializationException();
throw new SerializationException("Did not find serialization header.");
}
catch (TargetInvocationException ex)

_rootRecord = header.RootId;

do
{
throw ExceptionDispatchInfo.Capture(ex.InnerException!).SourceException.ConvertToSerializationException();
currentRecord = Record.ReadBinaryFormatRecord(state);
}

_records.Add(currentRecord);
while (currentRecord is not MessageEnd);
}
catch (Exception ex) when (ex is ArgumentException or InvalidCastException or ArithmeticException or IOException)
{
// Make the exception easier to catch, but retain the original stack trace.
throw ex.ConvertToSerializationException();
}
catch (TargetInvocationException ex)
{
throw ExceptionDispatchInfo.Capture(ex.InnerException!).SourceException.ConvertToSerializationException();
}
while (currentRecord is not MessageEnd);
}

/// <summary>
Expand All @@ -77,8 +85,7 @@ public object Deserialize()
{
try
{
Id rootId = ((SerializationHeader)_records[0]).RootId;
return Deserializer.Deserializer.Deserialize(rootId, _recordMap, TypeResolver, _options);
return Deserializer.Deserializer.Deserialize(RootRecord.Id, _recordMap, TypeResolver, _options);
}
catch (Exception ex) when (ex is ArgumentException or InvalidCastException or ArithmeticException or IOException)
{
Expand All @@ -92,14 +99,9 @@ public object Deserialize()
}

/// <summary>
/// Total count of top-level records.
/// </summary>
public int RecordCount => _records.Count;

/// <summary>
/// Gets a record by it's index.
/// The Id of the root record of the object graph.
/// </summary>
public IRecord this[int index] => _records[index];
public IRecord RootRecord => _recordMap[_rootRecord];

/// <summary>
/// Gets a record by it's identifier. Not all records have identifiers, only ones that
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,9 @@ static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? va
{
value = default;

if (format.RecordCount < 4
|| format[1] is not BinaryLibrary binaryLibrary
if (format.RootRecord is not ClassWithMembersAndTypes classInfo
|| format[classInfo.LibraryId] is not BinaryLibrary binaryLibrary
|| binaryLibrary.LibraryName != TypeInfo.SystemDrawingAssemblyName
|| format[2] is not ClassWithMembersAndTypes classInfo
|| classInfo.Name != typeof(PointF).FullName
|| classInfo.MemberValues.Count != 2)
{
Expand All @@ -85,10 +84,9 @@ static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? va
{
value = default;

if (format.RecordCount < 4
|| format[1] is not BinaryLibrary binaryLibrary
if (format.RootRecord is not ClassWithMembersAndTypes classInfo
|| format[classInfo.LibraryId] is not BinaryLibrary binaryLibrary
|| binaryLibrary.LibraryName != TypeInfo.SystemDrawingAssemblyName
|| format[2] is not ClassWithMembersAndTypes classInfo
|| classInfo.Name != typeof(RectangleF).FullName
|| classInfo.MemberValues.Count != 4)
{
Expand Down Expand Up @@ -116,18 +114,14 @@ public static bool TryGetPrimitiveType(this BinaryFormattedObject format, [NotNu
static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? value)
{
value = default;
if (format.RecordCount < 3)
{
return false;
}

if (format[1] is BinaryObjectString binaryString)
if (format.RootRecord is BinaryObjectString binaryString)
{
value = binaryString.Value;
return true;
}

if (format[1] is not SystemClassWithMembersAndTypes systemClass)
if (format.RootRecord is not SystemClassWithMembersAndTypes systemClass)
{
return false;
}
Expand Down Expand Up @@ -190,10 +184,10 @@ static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? li

const string ListTypeName = "System.Collections.Generic.List`1[[";

if (format.RecordCount != 4
|| format[1] is not SystemClassWithMembersAndTypes classInfo
if (format.RootRecord is not SystemClassWithMembersAndTypes classInfo
|| !classInfo.Name.StartsWith(ListTypeName, StringComparison.Ordinal)
|| format[2] is not ArrayRecord array)
|| classInfo["_items"] is not MemberReference reference
|| format[reference] is not ArrayRecord array)
{
return false;
}
Expand Down Expand Up @@ -278,8 +272,7 @@ static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? va
{
value = null;

if (format.RecordCount != 4
|| format[1] is not SystemClassWithMembersAndTypes classInfo
if (format.RootRecord is not SystemClassWithMembersAndTypes classInfo
|| classInfo.Name != typeof(ArrayList).FullName
|| format[2] is not ArraySingleObject array)
{
Expand Down Expand Up @@ -326,18 +319,18 @@ public static bool TryGetPrimitiveArray(this BinaryFormattedObject format, [NotN
static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? value)
{
value = null;
if (format.RecordCount != 3 || format[1] is not ArrayRecord)
if (format.RootRecord is not ArrayRecord array)
{
return false;
}

if (format[1] is ArraySingleString stringArray)
if (array is ArraySingleString stringArray)
{
value = stringArray.GetStringValues(format.RecordMap).ToArray();
return true;
}

if (format[1] is not IPrimitiveTypeRecord primitiveArray)
if (array is not IPrimitiveTypeRecord primitiveArray)
{
return false;
}
Expand Down Expand Up @@ -389,8 +382,7 @@ static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? ha

// Note that hashtables with custom comparers and/or hash code providers will have that information before
// the value pair arrays.
if (format.RecordCount != 5
|| format[1] is not SystemClassWithMembersAndTypes classInfo
if (format.RootRecord is not SystemClassWithMembersAndTypes classInfo
|| classInfo.Name != TypeInfo.HashtableType
|| format[2] is not ArraySingleObject keys
|| format[3] is not ArraySingleObject values
Expand Down Expand Up @@ -466,8 +458,7 @@ static bool Get(BinaryFormattedObject format, [NotNullWhen(true)] out object? ex
{
exception = null;

if (format.RecordCount < 3
|| format[1] is not SystemClassWithMembersAndTypes classInfo
if (format.RootRecord is not SystemClassWithMembersAndTypes classInfo
|| classInfo.Name != TypeInfo.NotSupportedExceptionType)
{
return false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ internal sealed class BinaryLibrary : IRecord<BinaryLibrary>, IBinaryFormatParse
{
public Id LibraryId { get; }
public string LibraryName { get; }
Id IRecord.Id => LibraryId;

public BinaryLibrary(Id libraryId, string libraryName)
{
Expand All @@ -26,16 +27,8 @@ public BinaryLibrary(Id libraryId, string libraryName)

public static RecordType RecordType => RecordType.BinaryLibrary;

static BinaryLibrary IBinaryFormatParseable<BinaryLibrary>.Parse(
BinaryFormattedObject.IParseState state)
{
BinaryLibrary record = new(
state.Reader.ReadInt32(),
state.Reader.ReadString());

state.RecordMap[record.LibraryId] = record;
return record;
}
static BinaryLibrary IBinaryFormatParseable<BinaryLibrary>.Parse(BinaryFormattedObject.IParseState state) =>
new(state.Reader.ReadInt32(), state.Reader.ReadString());

public void Write(BinaryWriter writer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ internal sealed class BinaryObjectString : IRecord<BinaryObjectString>, IBinaryF
{
public Id ObjectId { get; }
public string Value { get; }
Id IRecord.Id => ObjectId;

public static RecordType RecordType => RecordType.BinaryObjectString;

Expand All @@ -26,14 +27,8 @@ public BinaryObjectString(Id objectId, string value)
Value = value;
}

static BinaryObjectString IBinaryFormatParseable<BinaryObjectString>.Parse(
BinaryFormattedObject.IParseState state)
{
BinaryObjectString record = new(state.Reader.ReadInt32(), state.Reader.ReadString());

state.RecordMap[record.ObjectId] = record;
return record;
}
static BinaryObjectString IBinaryFormatParseable<BinaryObjectString>.Parse(BinaryFormattedObject.IParseState state) =>
new(state.Reader.ReadInt32(), state.Reader.ReadString());

public void Write(BinaryWriter writer)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,16 +52,13 @@ static ClassWithId IBinaryFormatParseable<ClassWithId>.Parse(

if (state.RecordMap[metadataId] is not ClassRecord referencedRecord)
{
throw new SerializationException();
throw new SerializationException("Invalid referenced record type.");
}

ClassWithId record = new(
return new(
objectId,
referencedRecord,
ReadObjectMemberValues(state, referencedRecord.MemberTypeInfo));

state.RecordMap[record.ObjectId] = record;
return record;
}

public override void Write(BinaryWriter writer)
Expand All @@ -79,7 +76,7 @@ public override void Write(BinaryWriter writer)
WriteValuesFromMemberTypeInfo(writer, systemClassWithMembersAndTypes.MemberTypeInfo, MemberValues);
break;
case ClassWithMembers or SystemClassWithMembers:
WriteRecords(writer, MemberValues);
WriteRecords(writer, MemberValues, coalesceNulls: false);
break;
default:
throw new SerializationException();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,11 @@ static ClassWithMembers IBinaryFormatParseable<ClassWithMembers>.Parse(
ClassInfo classInfo = ClassInfo.Parse(state.Reader, out _);
Id libraryId = state.Reader.ReadInt32();
MemberTypeInfo memberTypeInfo = MemberTypeInfo.CreateFromClassInfoAndLibrary(state, classInfo, libraryId);
ClassWithMembers record = new(
return new(
classInfo,
libraryId,
memberTypeInfo,
ReadObjectMemberValues(state, memberTypeInfo));

// Index this record by the id of the embedded ClassInfo's object id.
state.RecordMap[classInfo.ObjectId] = record;
return record;
}

public override void Write(BinaryWriter writer)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,11 @@ static ClassWithMembersAndTypes IBinaryFormatParseable<ClassWithMembersAndTypes>
ClassInfo classInfo = ClassInfo.Parse(state.Reader, out Count memberCount);
MemberTypeInfo memberTypeInfo = MemberTypeInfo.Parse(state.Reader, memberCount);

ClassWithMembersAndTypes record = new(
return new(
classInfo,
state.Reader.ReadInt32(),
memberTypeInfo,
ReadObjectMemberValues(state, memberTypeInfo));

// Index this record by the id of the embedded ClassInfo's object id.
state.RecordMap[record.ClassInfo.ObjectId] = record;
return record;
}

public override void Write(BinaryWriter writer)
Expand Down
Loading

0 comments on commit c926856

Please sign in to comment.