Skip to content

Commit

Permalink
CollectionsMarshal.GetValueRef(Dictionary) (#49388)
Browse files Browse the repository at this point in the history
  • Loading branch information
benaadams committed Mar 17, 2021
1 parent 604ea07 commit 46127ea
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -352,7 +352,7 @@ public virtual void GetObjectData(SerializationInfo info, StreamingContext conte
}
}

private ref TValue FindValue(TKey key)
internal ref TValue FindValue(TKey key)
{
if (key == null)
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,20 @@ public static class CollectionsMarshal
/// Get a <see cref="Span{T}"/> view over a <see cref="List{T}"/>'s data.
/// Items should not be added or removed from the <see cref="List{T}"/> while the <see cref="Span{T}"/> is in use.
/// </summary>
/// <param name="list">The list to get the data view over.</param>
public static Span<T> AsSpan<T>(List<T>? list)
=> list is null ? default : new Span<T>(list._items, 0, list._size);

/// <summary>
/// Gets either a ref to a <typeparamref name="TValue"/> in the <see cref="Dictionary{TKey, TValue}"/> or a ref null if it does not exist in the <paramref name="dictionary"/>.
/// </summary>
/// <param name="dictionary">The dictionary to get the ref to <typeparamref name="TValue"/> from.</param>
/// <param name="key">The key used for lookup.</param>
/// <remarks>
/// Items should not be added or removed from the <see cref="Dictionary{TKey, TValue}"/> while the ref <typeparamref name="TValue"/> is in use.
/// The ref null can be detected using System.Runtime.CompilerServices.Unsafe.IsNullRef
/// </remarks>
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull
=> ref dictionary.FindValue(key);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ public sealed partial class CoClassAttribute : System.Attribute
public static partial class CollectionsMarshal
{
public static System.Span<T> AsSpan<T>(System.Collections.Generic.List<T>? list) { throw null; }
public static ref TValue GetValueRefOrNullRef<TKey, TValue>(System.Collections.Generic.Dictionary<TKey, TValue> dictionary, TKey key) where TKey : notnull { throw null; }
}
[System.AttributeUsageAttribute(System.AttributeTargets.Field | System.AttributeTargets.Parameter | System.AttributeTargets.Property | System.AttributeTargets.ReturnValue, Inherited=false)]
public sealed partial class ComAliasNameAttribute : System.Attribute
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;

using Xunit;

namespace System.Runtime.InteropServices.Tests
Expand Down Expand Up @@ -142,9 +144,171 @@ public void ListAsSpanLinkBreaksOnResize()
}
}

[Fact]
public void GetValueRefOrNullRefValueType()
{
var dict = new Dictionary<int, Struct>
{
{ 1, default },
{ 2, default }
};

Assert.Equal(2, dict.Count);

Assert.Equal(0, dict[1].Value);
Assert.Equal(0, dict[1].Property);

var itemVal = dict[1];
itemVal.Value = 1;
itemVal.Property = 2;

// Does not change values in dictionary
Assert.Equal(0, dict[1].Value);
Assert.Equal(0, dict[1].Property);

CollectionsMarshal.GetValueRefOrNullRef(dict, 1).Value = 3;
CollectionsMarshal.GetValueRefOrNullRef(dict, 1).Property = 4;

Assert.Equal(3, dict[1].Value);
Assert.Equal(4, dict[1].Property);

ref var itemRef = ref CollectionsMarshal.GetValueRefOrNullRef(dict, 2);

Assert.Equal(0, itemRef.Value);
Assert.Equal(0, itemRef.Property);

itemRef.Value = 5;
itemRef.Property = 6;

Assert.Equal(5, itemRef.Value);
Assert.Equal(6, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

itemRef = new() { Value = 7, Property = 8 };

Assert.Equal(7, itemRef.Value);
Assert.Equal(8, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

// Check for null refs

Assert.True(Unsafe.IsNullRef(ref CollectionsMarshal.GetValueRefOrNullRef(dict, 3)));
Assert.Throws<NullReferenceException>(() => CollectionsMarshal.GetValueRefOrNullRef(dict, 3).Value = 9);

Assert.Equal(2, dict.Count);
}

[Fact]
public void GetValueRefOrNullRefClass()
{
var dict = new Dictionary<int, IntAsObject>
{
{ 1, new() },
{ 2, new() }
};

Assert.Equal(2, dict.Count);

Assert.Equal(0, dict[1].Value);
Assert.Equal(0, dict[1].Property);

var itemVal = dict[1];
itemVal.Value = 1;
itemVal.Property = 2;

// Does change values in dictionary
Assert.Equal(1, dict[1].Value);
Assert.Equal(2, dict[1].Property);

CollectionsMarshal.GetValueRefOrNullRef(dict, 1).Value = 3;
CollectionsMarshal.GetValueRefOrNullRef(dict, 1).Property = 4;

Assert.Equal(3, dict[1].Value);
Assert.Equal(4, dict[1].Property);

ref var itemRef = ref CollectionsMarshal.GetValueRefOrNullRef(dict, 2);

Assert.Equal(0, itemRef.Value);
Assert.Equal(0, itemRef.Property);

itemRef.Value = 5;
itemRef.Property = 6;

Assert.Equal(5, itemRef.Value);
Assert.Equal(6, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

itemRef = new() { Value = 7, Property = 8 };

Assert.Equal(7, itemRef.Value);
Assert.Equal(8, itemRef.Property);
Assert.Equal(dict[2].Value, itemRef.Value);
Assert.Equal(dict[2].Property, itemRef.Property);

// Check for null refs

Assert.True(Unsafe.IsNullRef(ref CollectionsMarshal.GetValueRefOrNullRef(dict, 3)));
Assert.Throws<NullReferenceException>(() => CollectionsMarshal.GetValueRefOrNullRef(dict, 3).Value = 9);

Assert.Equal(2, dict.Count);
}

[Fact]
public void GetValueRefOrNullRefLinkBreaksOnResize()
{
var dict = new Dictionary<int, Struct>
{
{ 1, new() }
};

Assert.Equal(1, dict.Count);

ref var itemRef = ref CollectionsMarshal.GetValueRefOrNullRef(dict, 1);

Assert.Equal(0, itemRef.Value);
Assert.Equal(0, itemRef.Property);

itemRef.Value = 1;
itemRef.Property = 2;

Assert.Equal(1, itemRef.Value);
Assert.Equal(2, itemRef.Property);
Assert.Equal(dict[1].Value, itemRef.Value);
Assert.Equal(dict[1].Property, itemRef.Property);

// Resize
dict.EnsureCapacity(100);
for (int i = 2; i <= 50; i++)
{
dict.Add(i, new());
}

itemRef.Value = 3;
itemRef.Property = 4;

Assert.Equal(3, itemRef.Value);
Assert.Equal(4, itemRef.Property);

// Check connection broken
Assert.NotEqual(dict[1].Value, itemRef.Value);
Assert.NotEqual(dict[1].Property, itemRef.Property);

Assert.Equal(50, dict.Count);
}

private struct Struct
{
public int Value;
public int Property { get; set; }
}

private class IntAsObject
{
public int Value;
public int Property { get; set; }
}
}
}

0 comments on commit 46127ea

Please sign in to comment.