From 8145bb831ebd334b56d2f11c76fbca56ba480ed4 Mon Sep 17 00:00:00 2001 From: Yat Long Poon Date: Wed, 12 Nov 2025 15:15:47 +0000 Subject: [PATCH] Simplify UTF-16 validation Vector128 codepath Combine the SSE2 codepath with a more generic Vector128 algorithm. AdvSimd is handled slightly differently to avoid using Vector128 ExtractMostSignificantBits, because there is no such equivalent instruction on Arm so the performance would be very slow otherwise. --- .../Text/Unicode/Utf16Utility.Validation.cs | 386 ++++++------------ 1 file changed, 128 insertions(+), 258 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/Text/Unicode/Utf16Utility.Validation.cs b/src/libraries/System.Private.CoreLib/src/System/Text/Unicode/Utf16Utility.Validation.cs index 4843b66101fe21..66604610a9ff63 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Text/Unicode/Utf16Utility.Validation.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Text/Unicode/Utf16Utility.Validation.cs @@ -1,18 +1,66 @@ // Licensed to the .NET Foundation under one or more agreements. // The .NET Foundation licenses this file to you under the MIT license. -using System.Buffers.Text; using System.Diagnostics; using System.Numerics; using System.Runtime.CompilerServices; using System.Runtime.Intrinsics; using System.Runtime.Intrinsics.Arm; -using System.Runtime.Intrinsics.X86; namespace System.Text.Unicode { internal static unsafe partial class Utf16Utility { + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static nuint GetSurrogateMask(Vector128 cmp) + { + // Convert the comparison result to a scalar surrogate mask. + // The elements in 'cmp' should be either all bits set or zero. + + if (AdvSimd.Arm64.IsSupported) + { + // Since ExtractMostSignificantBits is very slow on AdvSimd, + // we use a 64-bit value to encode the mask, where each byte represents one element: + // 0x01 for all bits set, 0x00 for zero. 
+ ulong mask = AdvSimd.Arm64.UnzipOdd(cmp.AsByte(), cmp.AsByte()).AsUInt64().ToScalar(); + return (nuint)(mask & 0x0101010101010101u); + } + + // Otherwise, encode the mask with 8-bits (one byte), where each bit represents one element. + return cmp.ExtractMostSignificantBits(); + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsSurrogatesMatch(nuint maskHigh, nuint maskLow) + { + // Make sure that each high surrogate is followed by a low surrogate character, + // and each low surrogate follows a high surrogate character. + // The last character is discarded as it will be checked by 'IsLastCharHighSurrogate'. + // The first character must not be a low surrogate. This is checked by matching + // 'maskLow' against the zeros inserted after shifting 'maskHigh' to the left. + + if (AdvSimd.Arm64.IsSupported) + { + // Each surrogate character is 8 bits apart. + return (maskHigh << 8) == maskLow; + } + // Each surrogate character is 1 bit apart. + return (byte)(maskHigh << 1) == (byte)maskLow; + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static bool IsLastCharHighSurrogate(nuint maskHigh) + { + if (AdvSimd.Arm64.IsSupported) + { + // Check if the top byte is not zero. + return (maskHigh >>> 56) != 0; + } + // Check if the top bit (of a byte) is not zero. + return ((byte)maskHigh >>> 7) != 0; + } + // Returns &inputBuffer[inputLength] if the input buffer is valid. /// /// Given an input buffer of char length , @@ -59,17 +107,14 @@ internal static unsafe partial class Utf16Utility int tempScalarCountAdjustment = 0; char* pEndOfInputBuffer = pInputBuffer + (uint)inputLength; - // Per https://github.com/dotnet/runtime/issues/41699, temporarily disabling - // ARM64-intrinsicified code paths. ARM64 platforms may still use the vectorized - // non-intrinsicified 'else' block below. 
- - if (/* (AdvSimd.Arm64.IsSupported && BitConverter.IsLittleEndian) || */ Sse2.IsSupported) + if (Vector128.IsHardwareAccelerated) { if (inputLength >= Vector128.Count) { - Vector128 vector0080 = Vector128.Create((ushort)0x0080); - Vector128 vector7800 = Vector128.Create((ushort)0x7800); - Vector128 vectorA000 = Vector128.Create((ushort)0xA000); + Vector128 vector0080 = Vector128.Create(0x0080); + Vector128 vector0400 = Vector128.Create(0x0400); + Vector128 vector0800 = Vector128.Create(0x0800); + Vector128 vectorD800 = Vector128.Create(0xD800); char* pHighestAddressWhereCanReadOneVector = pEndOfInputBuffer - Vector128.Count; Debug.Assert(pHighestAddressWhereCanReadOneVector >= pInputBuffer); @@ -78,286 +123,113 @@ internal static unsafe partial class Utf16Utility { Vector128 utf16Data = Vector128.Load((ushort*)pInputBuffer); - pInputBuffer += Vector128.Count; // eagerly bump this now in preparation for next loop, will adjust later if necessary - - // Sets the 0x0080 bit of each element in 'charIsNonAscii' if the corresponding - // input was 0x0080 <= [value]. (i.e., [value] is non-ASCII.) - - Vector128 charIsNonAscii = Vector128.Min(utf16Data, vector0080); - -#if DEBUG - // Quick check to ensure we didn't accidentally set the 0x8000 bit of any element. - uint debugMask = charIsNonAscii.AsByte().ExtractMostSignificantBits(); - Debug.Assert((debugMask & 0b_1010_1010_1010_1010) == 0, "Shouldn't have set the 0x8000 bit of any element in 'charIsNonAscii'."); -#endif // DEBUG - - // Sets the 0x8080 bits of each element in 'charIsNonAscii' if the corresponding - // input was 0x0800 <= [value]. This also handles the missing range a few lines above. - - // Since 3-byte elements have a value >= 0x0800, we'll perform a saturating add of 0x7800 in order to - // get all 3-byte elements to have their 0x8000 bits set. A saturating add will not set the 0x8000 - // bit for 1-byte or 2-byte elements. 
The 0x0080 bit will already have been set for non-ASCII (2-byte - // and 3-byte) elements. - - Vector128 charIsThreeByteUtf8Encoded = Vector128.AddSaturate(utf16Data, vector7800); - uint mask = (charIsNonAscii | charIsThreeByteUtf8Encoded).AsByte().ExtractMostSignificantBits(); - - // Each even bit of mask will be 1 only if the char was >= 0x0080, - // and each odd bit of mask will be 1 only if the char was >= 0x0800. - // - // Example for UTF-16 input "[ 0123 ] [ 1234 ] ...": - // - // ,-- set if char[1] is >= 0x0800 - // | ,-- set if char[0] is >= 0x0800 - // v v - // mask = ... 1 1 0 1 - // ^ ^-- set if char[0] is non-ASCII - // `-- set if char[1] is non-ASCII - // - // This means we can popcnt the number of set bits, and the result is the - // number of *additional* UTF-8 bytes that each UTF-16 code unit requires as - // it expands. This results in the wrong count for UTF-16 surrogate code - // units (we just counted that each individual code unit expands to 3 bytes, - // but in reality a well-formed UTF-16 surrogate pair expands to 4 bytes). - // We'll handle this in just a moment. + // Calculate the popcnt for UTF-8 adjustments, which is the number of *additional* + // UTF-8 bytes that each UTF-16 code unit requires as it expands. + // This results in the wrong count for UTF-16 surrogate code units (we just counted + // that each individual code unit expands to 3 bytes, but in reality a well-formed + // UTF-16 surrogate pair expands to 4 bytes). We'll handle this in just a moment. // // For now, compute the popcnt but squirrel it away. We'll fold it in to the // cumulative UTF-8 adjustment factor once we determine that there are no // unpaired surrogates in our data. (Unpaired surrogates would invalidate // our computed result and we'd have to throw it away.) 
- nuint popcnt = (uint)BitOperations.PopCount(mask); // on x64, perform zero-extension for free - - // Surrogates need to be special-cased for two reasons: (a) we need - // to account for the fact that we over-counted in the addition above; - // and (b) they require separate validation. - // - // Since surrogate code points are [D800..DFFF], adding {A000} to each element moves surrogate - // code points to [7800..7FFF], which allows performing a single signed comparison. - - mask = Vector128.LessThan((utf16Data + vectorA000).AsInt16(), vector7800.AsInt16()).AsByte().ExtractMostSignificantBits(); + uint popcnt; - FinishIteration: - - // Note: mask bits are set when the corresponding element is NOT a surrogate. - // We'll invert this before entering the "validate surrogate pairs" logic below. - - if (mask == 0xFFFF) - { - // Put this logic up top since it's predicted-taken (surrogate pairs are uncommon). - // Either we saw no surrogates or we already handled them below. + // On AdvSimd ExtractMostSignificantBits is very slow, so a different algorithm is used to avoid + // the poor performance. - tempUtf8CodeUnitCountAdjustment += (long)popcnt; - if (pInputBuffer > pHighestAddressWhereCanReadOneVector) - { - goto NonVectorizedLoop; // can no longer read a vector's worth of data - } - } - else + if (AdvSimd.Arm64.IsSupported) { - mask = ~mask; - - // There's at least one UTF-16 surrogate code unit present. - // Since we performed a pmovmskb operation on the result of a 16-bit pcmpgtw, - // the resulting bits of 'mask' will occur in pairs: - // - 00 if the corresponding UTF-16 char was not a surrogate code unit; - // - 11 if the corresponding UTF-16 char was a surrogate code unit. + // The 'twoOrMoreUtf8Bytes' and 'threeOrMoreUtf8Bytes' vectors will contain + // elements whose values are 0xFFFF (-1 as signed word) iff the corresponding + // UTF-16 code unit was >= 0x0080 and >= 0x0800, respectively. 
By summing these + // vectors, each element of the sum will contain one of three values: // - // A UTF-16 high/low surrogate code unit has the bit pattern [ 11011q## ######## ], - // where # is any bit; q = 0 represents a high surrogate, and q = 1 represents - // a low surrogate. Right-shifting each surrogate char by 3 bits, we end up with - // [ 00011011 q####### ], which means that we can immediately use pmovmskb to - // determine whether a given char was a high or a low surrogate. + // 0x0000 ( 0) = original char was 0000..007F + // 0xFFFF (-1) = original char was 0080..07FF + // 0xFFFE (-2) = original char was 0800..FFFF // - // Therefore the resulting bits of 'mask2' will occur in pairs: - // - 00 if the corresponding UTF-16 char was a high surrogate code unit; - // - 01 if the corresponding UTF-16 char was a low surrogate code unit; - // - ## (garbage) if the corresponding UTF-16 char was not a surrogate code unit. - // Since 'mask' already has 00 in these positions (since the corresponding char - // wasn't a surrogate), "mask AND mask2 == 00" holds for these positions. - - uint mask2 = Vector128.ShiftRightLogical(utf16Data, 3).AsByte().ExtractMostSignificantBits(); - - // 'lowSurrogatesMask' has its bits occur in pairs: - // - 01 if the corresponding char was a low surrogate char, - // - 00 if the corresponding char was a high surrogate char or not a surrogate at all. - - uint lowSurrogatesMask = mask2 & mask; - - // 'highSurrogatesMask' has its bits occur in pairs: - // - 01 if the corresponding char was a high surrogate char, - // - 00 if the corresponding char was a low surrogate char or not a surrogate at all. 
- - uint highSurrogatesMask = (mask2 ^ 0b_0101_0101_0101_0101u /* flip all even-numbered bits 00 <-> 01 */) & mask; - - Debug.Assert((highSurrogatesMask & lowSurrogatesMask) == 0, - "A char cannot simultaneously be both a high and a low surrogate char."); - - Debug.Assert(((highSurrogatesMask | lowSurrogatesMask) & 0b_1010_1010_1010_1010u) == 0, - "Only even bits (no odd bits) of the masks should be set."); - - // Now check that each high surrogate is followed by a low surrogate and that each - // low surrogate follows a high surrogate. We make an exception for the case where - // the final char of the vector is a high surrogate, since we can't perform validation - // on it until the next iteration of the loop when we hope to consume the matching - // low surrogate. - - highSurrogatesMask <<= 2; - if ((ushort)highSurrogatesMask != lowSurrogatesMask) - { - break; // error: mismatched surrogate pair; break out of vectorized logic - } - - if (highSurrogatesMask > ushort.MaxValue) - { - // There was a standalone high surrogate at the end of the vector. - // We'll adjust our counters so that we don't consider this char consumed. - - highSurrogatesMask = (ushort)highSurrogatesMask; // don't allow stray high surrogate to be consumed by popcnt - popcnt -= 2; // the '0xC000_0000' bits in the original mask are shifted out and discarded, so account for that here - pInputBuffer--; // don't consume this char (pointer has already been bumped at start of loop) - } - - // If we're 64-bit, we can perform the zero-extension of the surrogate pairs count for - // free right now, saving the extension step a few lines below. If we're 32-bit, the - // conversion to nuint immediately below is a no-op, and we'll pay the cost of the real - // 64 -bit extension a few lines below. 
- nuint surrogatePairsCountNuint = (uint)BitOperations.PopCount(highSurrogatesMask); - - // 2 UTF-16 chars become 1 Unicode scalar - - tempScalarCountAdjustment -= (int)surrogatePairsCountNuint; - - // Since each surrogate code unit was >= 0x0800, we eagerly assumed - // it'd be encoded as 3 UTF-8 code units, so our earlier popcnt computation - // assumes that the pair is encoded as 6 UTF-8 code units. Since each - // pair is in reality only encoded as 4 UTF-8 code units, we need to - // perform this adjustment now. - - if (IntPtr.Size == 8) - { - // Since we've already zero-extended surrogatePairsCountNuint, we can directly - // sub + sub. It's more efficient than shl + sub. - tempUtf8CodeUnitCountAdjustment -= (long)surrogatePairsCountNuint; - tempUtf8CodeUnitCountAdjustment -= (long)surrogatePairsCountNuint; - } - else - { - // Take the hit of the 64-bit extension now. - tempUtf8CodeUnitCountAdjustment -= 2 * (uint)surrogatePairsCountNuint; - } - - mask = 0xFFFF; // mark "no surrogates require processing" - goto FinishIteration; // jump backward to continue the main loop + // We'll negate them to produce a value 0..2 for each element, then sum all the + // elements together to produce the number of *additional* UTF-8 code units + // required to represent this UTF-16 data. + + Vector128 twoOrMoreUtf8Bytes = Vector128.GreaterThanOrEqual(utf16Data, vector0080); + Vector128 threeOrMoreUtf8Bytes = Vector128.GreaterThanOrEqual(utf16Data, vector0800); + Vector128 sumVector = Vector128.Zero - twoOrMoreUtf8Bytes - threeOrMoreUtf8Bytes; + popcnt = Vector128.Sum(sumVector); } - } while (true); - - // If we reached this point, we saw truly invalid data within the loop. - // Need to undo the eager "bump pInputBuffer" adjustment that took place at start of loop. 
- - pInputBuffer -= Vector128.Count; - } - } - else if (Vector128.IsHardwareAccelerated) - { - if (inputLength >= Vector128.Count) - { - Vector128 vector0080 = Vector128.Create(0x0080); - Vector128 vector0400 = Vector128.Create(0x0400); - Vector128 vector0800 = Vector128.Create(0x0800); - Vector128 vectorD800 = Vector128.Create(0xD800); + else + { + Vector128 vector7800 = Vector128.Create(0x7800); - char* pHighestAddressWhereCanReadOneVector = pEndOfInputBuffer - Vector128.Count; - Debug.Assert(pHighestAddressWhereCanReadOneVector >= pInputBuffer); + // Sets the 0x0080 bit of each element in 'charIsNonAscii' if the corresponding + // input was 0x0080 <= [value]. (i.e., [value] is non-ASCII.) - do - { - // The 'twoOrMoreUtf8Bytes' and 'threeOrMoreUtf8Bytes' vectors will contain - // elements whose values are 0xFFFF (-1 as signed word) iff the corresponding - // UTF-16 code unit was >= 0x0080 and >= 0x0800, respectively. By summing these - // vectors, each element of the sum will contain one of three values: - // - // 0x0000 ( 0) = original char was 0000..007F - // 0xFFFF (-1) = original char was 0080..07FF - // 0xFFFE (-2) = original char was 0800..FFFF - // - // We'll negate them to produce a value 0..2 for each element, then sum all the - // elements together to produce the number of *additional* UTF-8 code units - // required to represent this UTF-16 data. This is similar to the popcnt step - // performed by the SSE2 code path. This will overcount surrogates, but we'll - // handle that shortly. 
+ Vector128 charIsNonAscii = Vector128.Min(utf16Data, vector0080); - Vector128 utf16Data = Vector128.Load((ushort*)pInputBuffer); - Vector128 twoOrMoreUtf8Bytes = Vector128.GreaterThanOrEqual(utf16Data, vector0080); - Vector128 threeOrMoreUtf8Bytes = Vector128.GreaterThanOrEqual(utf16Data, vector0800); - Vector128 sumVector = (Vector128.Zero - twoOrMoreUtf8Bytes - threeOrMoreUtf8Bytes).AsNUInt(); +#if DEBUG + // Quick check to ensure we didn't accidentally set the 0x8000 bit of any element. + uint debugMask = charIsNonAscii.AsByte().ExtractMostSignificantBits(); + Debug.Assert((debugMask & 0b_1010_1010_1010_1010) == 0, "Shouldn't have set the 0x8000 bit of any element in 'charIsNonAscii'."); +#endif // DEBUG - // We'll try summing by a natural word (rather than a 16-bit word) at a time, - // which should halve the number of operations we must perform. + // Sets the 0x8080 bits of each element in 'charIsNonAscii' if the corresponding + // input was 0x0800 <= [value]. This also handles the missing range a few lines above. + // Since 3-byte elements have a value >= 0x0800, we'll perform a saturating add of 0x7800 in order to + // get all 3-byte elements to have their 0x8000 bits set. A saturating add will not set the 0x8000 + // bit for 1-byte or 2-byte elements. The 0x0080 bit will already have been set for non-ASCII (2-byte + // and 3-byte) elements. - nuint popcnt = 0; - for (int i = 0; i < Vector128.Count; i++) - { - popcnt += (nuint)sumVector[i]; - } + Vector128 charIsThreeByteUtf8Encoded = Vector128.AddSaturate(utf16Data, vector7800); - uint popcnt32 = (uint)popcnt; - if (IntPtr.Size == 8) - { - popcnt32 += (uint)(popcnt >> 32); + // Each even bit of mask will be 1 only if the char was >= 0x0080, + // and each odd bit of mask will be 1 only if the char was >= 0x0800. + // + // Example for UTF-16 input "[ 0123 ] [ 1234 ] ...": + // + // ,-- set if char[1] is >= 0x0800 + // | ,-- set if char[0] is >= 0x0800 + // v v + // mask = ... 
1 1 0 1 + // ^ ^-- set if char[0] is non-ASCII + // `-- set if char[1] is non-ASCII + + uint mask = (charIsNonAscii | charIsThreeByteUtf8Encoded).AsByte().ExtractMostSignificantBits(); + popcnt = (uint)BitOperations.PopCount(mask); // on x64, perform zero-extension for free } - // As in the SSE4.1 paths, compute popcnt but don't fold it in until we - // know there aren't any unpaired surrogates in the input data. - - popcnt32 = (ushort)popcnt32 + (popcnt32 >> 16); - // Now check for surrogates. utf16Data -= vectorD800; - Vector128 surrogateChars = Vector128.LessThan(utf16Data, vector0800); - if (surrogateChars != Vector128.Zero) + nuint maskSurr = GetSurrogateMask(Vector128.LessThan(utf16Data, vector0800)); + if (maskSurr != 0) { - // There's at least one surrogate (high or low) UTF-16 code unit in - // the vector. We'll build up additional vectors: 'highSurrogateChars' - // and 'lowSurrogateChars', where the elements are 0xFFFF iff the original - // UTF-16 code unit was a high or low surrogate, respectively. - - Vector128 highSurrogateChars = Vector128.LessThan(utf16Data, vector0400); - Vector128 lowSurrogateChars = Vector128.AndNot(surrogateChars, highSurrogateChars); - - // We want to make sure that each high surrogate code unit is followed by - // a low surrogate code unit and each low surrogate code unit follows a - // high surrogate code unit. Since we don't have an equivalent of pmovmskb - // or palignr available to us, we'll do this as a loop. We won't look at - // the very last high surrogate char element since we don't yet know if - // the next vector read will have a low surrogate char element. - - if (lowSurrogateChars[0] != 0) - { - goto Error; // error: start of buffer contains standalone low surrogate char - } + // Get the surrogate masks for high and low surrogates. + // A high surrogate will be less than 0x0400 after subtracting by 0xD800. + // A low surrogate is a surrogate that is not a high surrogate. 
+ + nuint maskHigh = GetSurrogateMask(Vector128.LessThan(utf16Data, vector0400)); + nuint maskLow = ~maskHigh & maskSurr; - ushort surrogatePairsCount = 0; - for (int i = 0; i < Vector128.Count - 1; i++) + if (!IsSurrogatesMatch(maskHigh, maskLow)) { - surrogatePairsCount -= highSurrogateChars[i]; // turns into +1 or +0 - if (highSurrogateChars[i] != lowSurrogateChars[i + 1]) - { - goto NonVectorizedLoop; // error: mismatched surrogate pair; break out of vectorized logic - } + break; // error: mismatched surrogate pair; break out of vectorized logic } - if (highSurrogateChars[Vector128.Count - 1] != 0) + if (IsLastCharHighSurrogate(maskHigh)) { // There was a standalone high surrogate at the end of the vector. // We'll adjust our counters so that we don't consider this char consumed. pInputBuffer--; - popcnt32 -= 2; + popcnt -= 2; } - nint surrogatePairsCountNint = (nint)surrogatePairsCount; // zero-extend to native int size + // If all the surrogate pairs are valid, then the number of surrogate pairs + // is equal to the number of low surrogates. + + nint surrogatePairsCountNint = (nint)BitOperations.PopCount(maskLow); // 2 UTF-16 chars become 1 Unicode scalar @@ -372,14 +244,12 @@ internal static unsafe partial class Utf16Utility tempUtf8CodeUnitCountAdjustment -= surrogatePairsCountNint; } - tempUtf8CodeUnitCountAdjustment += popcnt32; + tempUtf8CodeUnitCountAdjustment += popcnt; pInputBuffer += Vector128.Count; } while (pInputBuffer <= pHighestAddressWhereCanReadOneVector); } } - NonVectorizedLoop: - // Vectorization isn't supported on our current platform, or the input was too small to benefit // from vectorization, or we saw invalid UTF-16 data in the vectorized code paths and need to // drain remaining valid chars before we report failure.