Skip to content

Commit

Permalink
"Wedge" handshakes after falling back
Browse files Browse the repository at this point in the history
  • Loading branch information
jchambers committed Mar 3, 2024
1 parent b1af66e commit 2cdf7be
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/main/java/com/eatthepath/noise/NoiseHandshake.java
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,7 @@ public class NoiseHandshake {

private int currentMessagePattern = 0;
private boolean hasSplit = false;
private boolean hasFallenBack = false;

private final CipherState cipherState;
private final NoiseHash noiseHash;
Expand Down Expand Up @@ -478,6 +479,10 @@ public boolean isOneWayHandshake() {
* @see #isDone()
*/
public boolean isExpectingRead() {
if (hasFallenBack) {
return false;
}

if (currentMessagePattern < handshakePattern.getHandshakeMessagePatterns().length) {
return handshakePattern.getHandshakeMessagePatterns()[currentMessagePattern].sender() != role;
}
Expand All @@ -497,6 +502,10 @@ public boolean isExpectingRead() {
* @see #isDone()
*/
public boolean isExpectingWrite() {
if (hasFallenBack) {
return false;
}

if (currentMessagePattern < handshakePattern.getHandshakeMessagePatterns().length) {
return handshakePattern.getHandshakeMessagePatterns()[currentMessagePattern].sender() == role;
}
Expand All @@ -514,6 +523,10 @@ public boolean isExpectingWrite() {
* @see #isExpectingWrite()
*/
public boolean isDone() {
if (hasFallenBack) {
return false;
}

return currentMessagePattern == handshakePattern.getHandshakeMessagePatterns().length;
}

Expand Down Expand Up @@ -1246,7 +1259,6 @@ private void handleMixKeyToken(final HandshakePattern.Token token) {
* @see HandshakePattern#isFallbackPattern()
*/
public NoiseHandshake fallbackTo(final String handshakePatternName) throws NoSuchPatternException {
// TODO Self-destruct after falling back
return fallbackTo(handshakePatternName, null);
}

Expand All @@ -1271,6 +1283,10 @@ public NoiseHandshake fallbackTo(final String handshakePatternName) throws NoSuc
* @see HandshakePattern#isFallbackPattern()
*/
public NoiseHandshake fallbackTo(final String handshakePatternName, @Nullable final List<byte[]> preSharedKeys) throws NoSuchPatternException {
if (hasFallenBack) {
throw new IllegalStateException("Handshake has already fallen back to another pattern");
}

final HandshakePattern fallbackPattern = HandshakePattern.getInstance(handshakePatternName);

if (!fallbackPattern.isFallbackPattern()) {
Expand Down Expand Up @@ -1313,6 +1329,8 @@ public NoiseHandshake fallbackTo(final String handshakePatternName, @Nullable fi
fallbackRemoteEphemeralPublicKey = null;
}

hasFallenBack = true;

return new NoiseHandshake(role,
fallbackPattern,
keyAgreement,
Expand Down
36 changes: 36 additions & 0 deletions src/test/java/com/eatthepath/noise/NoiseHandshakeTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
import com.eatthepath.noise.component.NoiseKeyAgreement;
import org.junit.jupiter.api.Test;

import javax.crypto.AEADBadTagException;
import javax.crypto.ShortBufferException;
import java.nio.ByteBuffer;
import java.security.KeyPair;
import java.security.NoSuchAlgorithmException;
import java.security.PublicKey;

import static org.junit.jupiter.api.Assertions.*;

Expand Down Expand Up @@ -130,4 +133,37 @@ void readMessageShortBuffer() throws NoSuchAlgorithmException {
assertThrows(ShortBufferException.class, () ->
handshake.readMessage(ByteBuffer.wrap(message), ByteBuffer.allocate(payloadLength - 1)));
}

@Test
void repeatedFallback() throws NoSuchAlgorithmException {
final NoiseKeyAgreement keyAgreement = NoiseKeyAgreement.getInstance("25519");

final KeyPair initiatorStaticKeyPair = keyAgreement.generateKeyPair();
final PublicKey staleRemoteStaticPublicKey = keyAgreement.generateKeyPair().getPublic();
final KeyPair currentResponderStaticKeyPair = keyAgreement.generateKeyPair();

final byte[] initiatorStaticKeyMessage;
{
final NoiseHandshake ikInitiatorHandshake =
NoiseHandshakeBuilder.forIKInitiator(initiatorStaticKeyPair, staleRemoteStaticPublicKey)
.setComponentsFromProtocolName("Noise_IK_25519_AESGCM_SHA256")
.build();

initiatorStaticKeyMessage = ikInitiatorHandshake.writeMessage((byte[]) null);
}

final NoiseHandshake ikResponderHandshake =
NoiseHandshakeBuilder.forIKResponder(currentResponderStaticKeyPair)
.setComponentsFromProtocolName("Noise_IK_25519_AESGCM_SHA256")
.build();

assertThrows(AEADBadTagException.class, () -> ikResponderHandshake.readMessage(initiatorStaticKeyMessage));

assertDoesNotThrow(() -> ikResponderHandshake.fallbackTo("XXfallback"));
assertThrows(IllegalStateException.class, () -> ikResponderHandshake.fallbackTo("XXfallback"));

assertFalse(ikResponderHandshake.isExpectingRead());
assertFalse(ikResponderHandshake.isExpectingWrite());
assertFalse(ikResponderHandshake.isDone());
}
}

0 comments on commit 2cdf7be

Please sign in to comment.