Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 51 additions & 43 deletions src/libraries/Common/src/System/Security/Cryptography/MLKem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1170,6 +1170,7 @@ public string ExportEncryptedPkcs8PrivateKeyPem(string password, PbeParameters p
/// </exception>
public static MLKem ImportSubjectPublicKeyInfo(ReadOnlySpan<byte> source)
{
ThrowIfTrailingData(source);
ThrowIfNotSupported();

unsafe
Expand Down Expand Up @@ -1690,62 +1691,75 @@ private static void MLKemKeyReader(
MLKemAlgorithm algorithm = GetAlgorithmIdentifier(in algorithmIdentifier);
MLKemPrivateKeyAsn kemKey = MLKemPrivateKeyAsn.Decode(privateKeyContents, AsnEncodingRules.BER);

try
if (kemKey.Seed is ReadOnlyMemory<byte> seed)
{
if (kemKey.Seed is ReadOnlyMemory<byte> seed)
if (seed.Length != algorithm.PrivateSeedSizeInBytes)
{
kem = ImportPrivateSeed(algorithm, seed.Span);
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
}
else if (kemKey.ExpandedKey is ReadOnlyMemory<byte> expandedKey)

kem = MLKemImplementation.ImportPrivateSeedImpl(algorithm, seed.Span);
}
else if (kemKey.ExpandedKey is ReadOnlyMemory<byte> expandedKey)
{
if (expandedKey.Length != algorithm.DecapsulationKeySizeInBytes)
{
kem = ImportDecapsulationKey(algorithm, expandedKey.Span);
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
}
else if (kemKey.Both is MLKemPrivateKeyBothAsn both)

kem = MLKemImplementation.ImportDecapsulationKeyImpl(algorithm, expandedKey.Span);
}
else if (kemKey.Both is MLKemPrivateKeyBothAsn both)
{
int decapsulationKeySize = algorithm.DecapsulationKeySizeInBytes;

if (both.Seed.Length != algorithm.PrivateSeedSizeInBytes ||
both.ExpandedKey.Length != decapsulationKeySize)
{
MLKem key = ImportPrivateSeed(algorithm, both.Seed.Span);
int decapsulationKeySize = key.Algorithm.DecapsulationKeySizeInBytes;
byte[] rent = CryptoPool.Rent(decapsulationKeySize);
Span<byte> buffer = rent.AsSpan(0, decapsulationKeySize);
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
}

try
{
key.ExportDecapsulationKey(buffer);
MLKem key = MLKemImplementation.ImportPrivateSeedImpl(algorithm, both.Seed.Span);
byte[] rent = CryptoPool.Rent(decapsulationKeySize);
Span<byte> buffer = rent.AsSpan(0, decapsulationKeySize);

if (CryptographicOperations.FixedTimeEquals(buffer, both.ExpandedKey.Span))
{
kem = key;
}
else
{
throw new CryptographicException(SR.Cryptography_KemPkcs8KeyMismatch);
}
}
catch
try
{
key.ExportDecapsulationKey(buffer);

if (CryptographicOperations.FixedTimeEquals(buffer, both.ExpandedKey.Span))
{
key.Dispose();
throw;
kem = key;
}
finally
else
{
CryptoPool.Return(rent, decapsulationKeySize);
throw new CryptographicException(SR.Cryptography_KemPkcs8KeyMismatch);
}
}
else
catch
{
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
key.Dispose();
throw;
}
finally
{
CryptoPool.Return(rent, decapsulationKeySize);
}
}
catch (ArgumentException ae)
else
{
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding, ae);
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
}
}

private static void ThrowIfTrailingData(ReadOnlySpan<byte> data)
{
AsnDecoder.ReadEncodedValue(data, AsnEncodingRules.BER, out _, out _, out int bytesRead);
// The only thing we are checking here is that TryReadEncodedValue was able to decode it and that, given
// the length of the data, that it the same length as the span. The encoding rules don't matter for length
// checking, so just use BER.
bool success = AsnDecoder.TryReadEncodedValue(data, AsnEncodingRules.BER, out _, out _, out _, out int bytesRead);

if (bytesRead != data.Length)
if (!success || bytesRead != data.Length)
{
throw new CryptographicException(SR.Cryptography_Der_Invalid_Encoding);
}
Expand Down Expand Up @@ -1801,32 +1815,26 @@ private TResult ExportPkcs8PrivateKeyCallback<TResult>(ExportPkcs8PrivateKeyFunc
// Decapsulation keys are always larger than the seed, so if we end up with a seed export it should
// fit in the initial buffer.
int size = Algorithm.DecapsulationKeySizeInBytes + 32;
byte[] buffer = ArrayPool<byte>.Shared.Rent(size); // Released to callers, do not use CryptoPool.
byte[] buffer = CryptoPool.Rent(size); // Only passed out as span, callees can't keep a reference to it
int written;

while (!TryExportPkcs8PrivateKeyCore(buffer, out written))
{
ClearAndReturnToPool(buffer, written);
CryptoPool.Return(buffer);
size = checked(size * 2);
buffer = ArrayPool<byte>.Shared.Rent(size);
}

if (written > buffer.Length)
if (written < 0 || written > buffer.Length)
{
// We got a nonsense value written back. Clear the buffer, but don't put it back in the pool.
CryptographicOperations.ZeroMemory(buffer);
throw new CryptographicException();
}

TResult result = func(buffer.AsSpan(0, written));
ClearAndReturnToPool(buffer, written);
CryptoPool.Return(buffer, written);
return result;

static void ClearAndReturnToPool(byte[] buffer, int clearSize)
{
CryptographicOperations.ZeroMemory(buffer.AsSpan(0, clearSize));
ArrayPool<byte>.Shared.Return(buffer);
}
}

private static string EncodeAsnWriterToPem(string label, AsnWriter writer, bool clear = true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ public static void ExportPkcs8PrivateKey_ExpandAndRetry()
}

[Fact]
public static void ExportPkcs8PrivateKey_MisbehavingBytesWritten()
public static void ExportPkcs8PrivateKey_MisbehavingBytesWritten_Oversized()
{
MLKemContract kem = new(MLKemAlgorithm.MLKem512)
{
Expand All @@ -700,6 +700,21 @@ public static void ExportPkcs8PrivateKey_MisbehavingBytesWritten()
Assert.Throws<CryptographicException>(() => kem.ExportPkcs8PrivateKey());
}

[Fact]
public static void ExportPkcs8PrivateKey_MisbehavingBytesWritten_Negative()
{
MLKemContract kem = new(MLKemAlgorithm.MLKem512)
{
OnTryExportPkcs8PrivateKeyCore = (Span<byte> destination, out int bytesWritten) =>
{
bytesWritten = -1;
return true;
}
};

Assert.Throws<CryptographicException>(() => kem.ExportPkcs8PrivateKey());
}

[Fact]
public static void ExportPkcs8PrivateKey_Disposed()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ public static void ImportPrivateSeed_NotSupported(MLKemAlgorithm algorithm)
public static void ImportSubjectPublicKeyInfo_NotSupported()
{
Assert.Throws<PlatformNotSupportedException>(() =>
MLKem.ImportSubjectPublicKeyInfo(Array.Empty<byte>()));
MLKem.ImportSubjectPublicKeyInfo(MLKemTestData.IetfMlKem512Spki));

Assert.Throws<PlatformNotSupportedException>(() =>
MLKem.ImportSubjectPublicKeyInfo(ReadOnlySpan<byte>.Empty));
MLKem.ImportSubjectPublicKeyInfo(new ReadOnlySpan<byte>(MLKemTestData.IetfMlKem512Spki)));
}

[Theory]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,13 @@ public static void ImportSubjectPublicKeyInfo_WrongAlgorithm()
Assert.Throws<CryptographicException>(() => MLKem.ImportSubjectPublicKeyInfo(ecP256Spki));
}

[Fact]
public static void ImportSubjectPublicKeyInfo_NotAsn()
{
Assert.Throws<CryptographicException>(() => MLKem.ImportSubjectPublicKeyInfo("potatoes"u8));
Assert.Throws<CryptographicException>(() => MLKem.ImportSubjectPublicKeyInfo("potatoes"u8.ToArray()));
}

[Fact]
public static void ImportSubjectPublicKeyInfo_WrongParameters()
{
Expand Down Expand Up @@ -214,6 +221,15 @@ public static void ImportSubjectPublicKeyInfo_WrongSize()
Assert.Throws<CryptographicException>(() => MLKem.ImportSubjectPublicKeyInfo(mlKem512BadEncapKey));
}

[Fact]
public static void ImportSubjectPublicKeyInfo_TrailingData()
{
byte[] spki = new byte[MLKemTestData.IetfMlKem512Spki.Length + 1];
MLKemTestData.IetfMlKem512Spki.AsSpan().CopyTo(spki);
Assert.Throws<CryptographicException>(() => MLKem.ImportSubjectPublicKeyInfo(spki));
Assert.Throws<CryptographicException>(() => MLKem.ImportSubjectPublicKeyInfo(new ReadOnlySpan<byte>(spki)));
}

[Fact]
public static void ImportPkcs8PrivateKey_NullSource()
{
Expand Down Expand Up @@ -453,7 +469,14 @@ public static void ImportPkcs8PrivateKey_Both_TrailingData()
}
}

[Fact]
[Fact]
public static void ImportPkcs8PrivateKey_NotAsn()
{
Assert.Throws<CryptographicException>(() => MLKem.ImportPkcs8PrivateKey("potatoes"u8));
Assert.Throws<CryptographicException>(() => MLKem.ImportPkcs8PrivateKey("potatoes"u8.ToArray()));
}

[Fact]
public static void ImportEncryptedPkcs8PrivateKey_WrongAlgorithm()
{
byte[] ecP256Key = Convert.FromBase64String(@"
Expand Down Expand Up @@ -492,6 +515,19 @@ public static void ImportEncryptedPkcs8PrivateKey_TrailingData()
}
}

[Fact]
public static void ImportEncryptedPkcs8PrivateKey_NotAsn()
{
Assert.Throws<CryptographicException>(() =>
MLKem.ImportEncryptedPkcs8PrivateKey("PLACEHOLDER", "potatoes"u8.ToArray()));

Assert.Throws<CryptographicException>(() =>
MLKem.ImportEncryptedPkcs8PrivateKey("PLACEHOLDER".AsSpan(), "potatoes"u8));

Assert.Throws<CryptographicException>(() =>
MLKem.ImportEncryptedPkcs8PrivateKey("PLACEHOLDER"u8, "potatoes"u8));
}

[Fact]
public static void ImportEncryptedPkcs8PrivateKey_DoesNotProcessUnencryptedData()
{
Expand Down
Loading