Skip to content

Commit

Permalink
Resolve NoiseHandshake TODOs and add tests for message length checks
Browse files Browse the repository at this point in the history
  • Loading branch information
jchambers committed Mar 3, 2024
1 parent 228cafc commit 6b1e07c
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 4 deletions.
51 changes: 47 additions & 4 deletions src/main/java/com/eatthepath/noise/NoiseHandshake.java
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,8 @@ public class NoiseHandshake {

private int currentPreSharedKey;

static final int MAX_NOISE_MESSAGE_SIZE = 65_535;

private static final byte[] EMPTY_BYTE_ARRAY = new byte[0];
private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.wrap(EMPTY_BYTE_ARRAY);

Expand Down Expand Up @@ -643,8 +645,9 @@ static int getPayloadLength(final HandshakePattern handshakePattern,
* @see <a href="https://noiseprotocol.org/noise.html#payload-security-properties">The Noise Protocol Framework - Payload security properties</a>
*/
public byte[] writeMessage(@Nullable final byte[] payload) {
// TODO Verify that message size is within bounds
final int payloadLength = payload != null ? payload.length : 0;
checkOutboundMessageSize(payloadLength);

final byte[] message = new byte[getOutboundMessageLength(payloadLength)];

try {
Expand Down Expand Up @@ -690,7 +693,11 @@ public int writeMessage(@Nullable final byte[] payload,
final byte[] message,
final int messageOffset) throws ShortBufferException {

// TODO Check message buffer length, or just let plumbing deeper down complain?
checkOutboundMessageSize(payloadLength);

if (message.length - messageOffset < getOutboundMessageLength(payloadLength)) {
throw new ShortBufferException("Message array after offset is not large enough to hold handshake message");
}

if (!isExpectingWrite()) {
throw new IllegalStateException("Handshake not currently expecting to write a message");
Expand Down Expand Up @@ -770,7 +777,10 @@ public int writeMessage(@Nullable final byte[] payload,
* @see <a href="https://noiseprotocol.org/noise.html#payload-security-properties">The Noise Protocol Framework - Payload security properties</a>
*/
public ByteBuffer writeMessage(@Nullable final ByteBuffer payload) {
final ByteBuffer message = ByteBuffer.allocate(getOutboundMessageLength(payload != null ? payload.remaining() : 0));
final int payloadLength = payload != null ? payload.remaining() : 0;
checkOutboundMessageSize(payloadLength);

final ByteBuffer message = ByteBuffer.allocate(getOutboundMessageLength(payloadLength));

try {
writeMessage(payload, message);
Expand Down Expand Up @@ -813,7 +823,12 @@ public ByteBuffer writeMessage(@Nullable final ByteBuffer payload) {
public int writeMessage(@Nullable final ByteBuffer payload,
final ByteBuffer message) throws ShortBufferException {

// TODO Check message buffer length, or just let plumbing deeper down complain?
final int payloadLength = payload != null ? payload.remaining() : 0;
checkOutboundMessageSize(payloadLength);

if (message.remaining() < getOutboundMessageLength(payloadLength)) {
throw new ShortBufferException("Message buffer is not large enough to hold handshake message");
}

if (!isExpectingWrite()) {
throw new IllegalStateException("Handshake not currently expecting to write a message");
Expand Down Expand Up @@ -869,6 +884,12 @@ public int writeMessage(@Nullable final ByteBuffer payload,
return bytesWritten;
}

private void checkOutboundMessageSize(final int payloadLength) {
if (getOutboundMessageLength(payloadLength) > MAX_NOISE_MESSAGE_SIZE) {
throw new IllegalArgumentException("Message containing payload would be larger than maximum allowed Noise message size");
}
}

/**
* Reads the next handshake message, advancing this handshake's internal state.
*
Expand All @@ -881,6 +902,8 @@ public int writeMessage(@Nullable final ByteBuffer payload,
* @throws IllegalArgumentException if the given message is too short to contain the expected handshake message
*/
public byte[] readMessage(final byte[] message) throws AEADBadTagException {
checkInboundMessageSize(message.length);

final byte[] payload = new byte[getPayloadLength(message.length)];

try {
Expand Down Expand Up @@ -921,6 +944,12 @@ public int readMessage(final byte[] message,
final byte[] payload,
final int payloadOffset) throws ShortBufferException, AEADBadTagException {

checkInboundMessageSize(messageLength);

if (payload.length - payloadOffset < getPayloadLength(messageLength)) {
throw new ShortBufferException("Payload array after offset is not large enough to hold payload");
}

if (!isExpectingRead()) {
throw new IllegalStateException("Handshake not currently expecting to read a message");
}
Expand Down Expand Up @@ -987,6 +1016,8 @@ public int readMessage(final byte[] message,
* @throws IllegalArgumentException if the given message is too short to contain the expected handshake message
*/
public ByteBuffer readMessage(final ByteBuffer message) throws AEADBadTagException {
checkInboundMessageSize(message.remaining());

final ByteBuffer payload = ByteBuffer.allocate(getPayloadLength(message.remaining()));

try {
Expand Down Expand Up @@ -1025,6 +1056,12 @@ public ByteBuffer readMessage(final ByteBuffer message) throws AEADBadTagExcepti
public int readMessage(final ByteBuffer message,
final ByteBuffer payload) throws ShortBufferException, AEADBadTagException {

checkInboundMessageSize(message.remaining());

if (payload.remaining() < getPayloadLength(message.remaining())) {
throw new ShortBufferException("Payload buffer is not large enough to hold payload");
}

if (!isExpectingRead()) {
throw new IllegalStateException("Handshake not currently expecting to read a message");
}
Expand Down Expand Up @@ -1077,6 +1114,12 @@ public int readMessage(final ByteBuffer message,
return decryptAndHash(message, payload);
}

private void checkInboundMessageSize(final int messageSize) {
if (messageSize > MAX_NOISE_MESSAGE_SIZE) {
throw new IllegalArgumentException("Message is larger than maximum allowed Noise message size");
}
}

private void handleMixKeyToken(final HandshakePattern.Token token) {
switch (token) {
case EE -> {
Expand Down
86 changes: 86 additions & 0 deletions src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import javax.annotation.Nullable;
import javax.crypto.AEADBadTagException;
import javax.crypto.ShortBufferException;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
Expand Down Expand Up @@ -191,6 +192,91 @@ void getPayloadLength() throws NoSuchPatternException {
() -> NoiseHandshake.getPayloadLength(handshakePattern, 0, publicKeyLength, 55));
}

@Test
void writeMessageOversize() throws NoSuchAlgorithmException {
final NoiseKeyAgreement keyAgreement = NoiseKeyAgreement.getInstance("25519");

final NoiseHandshake handshake =
NoiseHandshakeBuilder.forIKInitiator(keyAgreement.generateKeyPair(), keyAgreement.generateKeyPair().getPublic())
.setComponentsFromProtocolName("Noise_IK_25519_AESGCM_SHA256")
.build();

// We want to make sure we're testing the size of the resulting message (which may include key material and AEAD
// tags) rather than the length of just the payload
final int payloadLength = NoiseHandshake.MAX_NOISE_MESSAGE_SIZE - 1;
final int messageLength = handshake.getOutboundMessageLength(payloadLength);

assertTrue(messageLength > NoiseHandshake.MAX_NOISE_MESSAGE_SIZE);

assertThrows(IllegalArgumentException.class,
() -> handshake.writeMessage(new byte[payloadLength]));

assertThrows(IllegalArgumentException.class,
() -> handshake.writeMessage(new byte[payloadLength], 0, payloadLength, new byte[messageLength], 0));

assertThrows(IllegalArgumentException.class,
() -> handshake.writeMessage(ByteBuffer.allocate(payloadLength)));

assertThrows(IllegalArgumentException.class,
() -> handshake.writeMessage(ByteBuffer.allocate(payloadLength), ByteBuffer.allocate(messageLength)));
}

@Test
void writeMessageShortBuffer() throws NoSuchAlgorithmException {
final NoiseHandshake handshake =
NoiseHandshakeBuilder.forNNInitiator()
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
.build();

final byte[] payload = new byte[32];
final byte[] message = new byte[payload.length - 1];

assertThrows(ShortBufferException.class, () ->
handshake.writeMessage(payload, 0, payload.length, message, 0));

assertThrows(ShortBufferException.class, () ->
handshake.writeMessage(ByteBuffer.wrap(payload), ByteBuffer.wrap(message)));
}

@Test
void readMessageOversize() throws NoSuchAlgorithmException {
final NoiseHandshake handshake =
NoiseHandshakeBuilder.forNNResponder()
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
.build();

final int messageLength = NoiseHandshake.MAX_NOISE_MESSAGE_SIZE + 1;

assertThrows(IllegalArgumentException.class, () ->
handshake.readMessage(new byte[messageLength]));

assertThrows(IllegalArgumentException.class, () ->
handshake.readMessage(new byte[messageLength], 0, messageLength, new byte[messageLength], 0));

assertThrows(IllegalArgumentException.class, () ->
handshake.readMessage(ByteBuffer.allocate(messageLength)));

assertThrows(IllegalArgumentException.class, () ->
handshake.readMessage(ByteBuffer.allocate(messageLength), ByteBuffer.allocate(messageLength)));
}

@Test
void readMessageShortBuffer() throws NoSuchAlgorithmException {
final NoiseHandshake handshake =
NoiseHandshakeBuilder.forNNResponder()
.setComponentsFromProtocolName("Noise_NN_25519_AESGCM_SHA256")
.build();

final byte[] message = new byte[128];
final int payloadLength = handshake.getPayloadLength(message.length);

assertThrows(ShortBufferException.class, () ->
handshake.readMessage(message, 0, message.length, new byte[payloadLength - 1], 0));

assertThrows(ShortBufferException.class, () ->
handshake.readMessage(ByteBuffer.wrap(message), ByteBuffer.allocate(payloadLength - 1)));
}

@ParameterizedTest
@MethodSource("cacophonyTestVectors")
void cacophonyTestsWithNewByteArray(final CacophonyTestVector testVector) throws AEADBadTagException {
Expand Down

0 comments on commit 6b1e07c

Please sign in to comment.