Skip to content

Commit

Permalink
Resolve TODOs and robustify tests in HandshakePattern
Browse files Browse the repository at this point in the history
  • Loading branch information
jchambers committed Mar 3, 2024
1 parent 5fc40ea commit e89fc10
Show file tree
Hide file tree
Showing 2 changed files with 192 additions and 27 deletions.
57 changes: 36 additions & 21 deletions src/main/java/com/eatthepath/noise/HandshakePattern.java
Original file line number Diff line number Diff line change
Expand Up @@ -464,13 +464,14 @@ static List<String> getModifiers(final String fullPatternName) {
}

HandshakePattern withModifier(final String modifier) {
// TODO Disallow duplicate modifiers

final MessagePattern[] modifiedPreMessagePatterns;
final MessagePattern[] modifiedHandshakeMessagePatterns;

if ("fallback".equals(modifier)) {
// TODO Make sure first handshake message is eligible for fallback
if (!isValidFallbackMessagePattern(handshakeMessagePatterns[0])) {
throw new IllegalStateException("Cannot generate fallback pattern; first message pattern is not a fallback-eligible message pattern");
}

modifiedPreMessagePatterns = new MessagePattern[getPreMessagePatterns().length + 1];
modifiedHandshakeMessagePatterns = new MessagePattern[getHandshakeMessagePatterns().length - 1];

Expand Down Expand Up @@ -517,6 +518,20 @@ HandshakePattern withModifier(final String modifier) {
return new HandshakePattern(modifiedName, modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns);
}

static boolean isValidFallbackMessagePattern(final MessagePattern messagePattern) {
if (messagePattern.sender() != NoiseHandshake.Role.INITIATOR) {
return false;
}

if (messagePattern.tokens().length == 1) {
return messagePattern.tokens()[0] == Token.E || messagePattern.tokens()[0] == Token.S;
} else if (messagePattern.tokens().length == 2) {
return messagePattern.tokens()[0] == Token.E && messagePattern.tokens()[1] == Token.S;
}

return false;
}

static HandshakePattern fromString(final String patternString) {
final String name = patternString.lines()
.findFirst()
Expand Down Expand Up @@ -604,6 +619,24 @@ boolean isFallbackPattern() {
return getModifiers(getName()).contains("fallback");
}

boolean isPreSharedKeyHandshake() {
return Arrays.stream(getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == Token.PSK);
}

/**
* Returns the number of pre-shared keys either party in this handshake must provide prior to beginning the handshake.
*
* @return the number of pre-shared keys either party in this handshake must provide prior to beginning the handshake
*/
int getRequiredPreSharedKeyCount() {
return Math.toIntExact(Arrays.stream(getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.filter(token -> token == Token.PSK)
.count());
}

/**
* Checks whether the party with the given role in this handshake must supply a local static key pair prior to
* beginning the handshake.
Expand Down Expand Up @@ -658,24 +691,6 @@ boolean requiresRemoteStaticPublicKey(final NoiseHandshake.Role role) {
.anyMatch(token -> token == Token.S);
}

boolean isPreSharedKeyHandshake() {
return Arrays.stream(getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == Token.PSK);
}

/**
* Returns the number of pre-shared keys either party in this handshake must provide prior to beginning the handshake.
*
* @return the number of pre-shared keys either party in this handshake must provide prior to beginning the handshake
*/
int getRequiredPreSharedKeyCount() {
return Math.toIntExact(Arrays.stream(getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.filter(token -> token == Token.PSK)
.count());
}

/**
* Tests whether this handshake pattern is equal to another object. This handshake pattern is equal to the given
* object if the given object is also a handshake pattern and has the same name and message patterns as this handshake
Expand Down
162 changes: 156 additions & 6 deletions src/test/java/com/eatthepath/noise/HandshakePatternTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,36 @@ class HandshakePatternTest {

@Test
void getInstance() throws NoSuchPatternException {
{
final HandshakePattern expectedXXPattern = HandshakePattern.fromString("""
XX:
-> e
<- e, ee, s, es
-> s, se
""");

assertEquals(expectedXXPattern, HandshakePattern.getInstance("XX"));
}


{
final HandshakePattern expectedXXFallbackPsk0Pattern =HandshakePattern.fromString("""
XXfallback+psk0:
-> e
...
<- psk, e, ee, s, es
-> s, se
""");

assertEquals(expectedXXFallbackPsk0Pattern, HandshakePattern.getInstance("XXfallback+psk0"));
}

assertThrows(NoSuchPatternException.class,
() -> HandshakePattern.getInstance("This is not a legitimate Noise handshake pattern"));
}

@Test
void fromString() {
{
final HandshakePattern expectedXXPattern = new HandshakePattern("XX",
new MessagePattern[0],
Expand All @@ -28,7 +58,12 @@ void getInstance() throws NoSuchPatternException {
new MessagePattern(Role.INITIATOR, new Token[]{Token.S, Token.SE}),
});

assertEquals(expectedXXPattern, HandshakePattern.getInstance("XX"));
assertEquals(expectedXXPattern, HandshakePattern.fromString("""
XX:
-> e
<- e, ee, s, es
-> s, se
"""));
}

{
Expand All @@ -42,11 +77,15 @@ void getInstance() throws NoSuchPatternException {
new MessagePattern(Role.RESPONDER, new Token[]{Token.E, Token.EE, Token.SE})
});

assertEquals(expectedKKPattern, HandshakePattern.getInstance("KK"));
assertEquals(expectedKKPattern, HandshakePattern.fromString("""
KK:
-> s
<- s
...
-> e, es, ss
<- e, ee, se
"""));
}

assertThrows(NoSuchPatternException.class,
() -> HandshakePattern.getInstance("This is not a legitimate Noise handshake pattern"));
}

@ParameterizedTest
Expand Down Expand Up @@ -75,7 +114,7 @@ private static List<Arguments> getModifiers() {
}

@Test
void withModifier() throws NoSuchPatternException {
void withFallbackModifier() throws NoSuchPatternException {
final HandshakePattern expectedXXFallbackPattern = HandshakePattern.fromString("""
XXfallback:
-> e
Expand All @@ -86,4 +125,115 @@ void withModifier() throws NoSuchPatternException {

assertEquals(expectedXXFallbackPattern, HandshakePattern.getInstance("XX").withModifier("fallback"));
}

@Test
void withPskModifier() throws NoSuchPatternException {
{
final HandshakePattern expectedNKPsk0Pattern = HandshakePattern.fromString("""
NKpsk0:
<- s
...
-> psk, e, es
<- e, ee
""");

assertEquals(expectedNKPsk0Pattern, HandshakePattern.getInstance("NK").withModifier("psk0"));
}

{
final HandshakePattern expectedXXPsk3Pattern = HandshakePattern.fromString("""
XXpsk3:
-> e
<- e, ee, s, es
-> s, se, psk
""");

assertEquals(expectedXXPsk3Pattern, HandshakePattern.getInstance("XX").withModifier("psk3"));
}
}

@Test
void withModifierUnrecognized() {
assertThrows(IllegalArgumentException.class, () -> HandshakePattern.getInstance("XX").withModifier("fancy"));
}

@ParameterizedTest
@MethodSource
void isValidFallbackMessagePattern(final MessagePattern messagePattern, final boolean expectValidFallbackMessagePattern) {
assertEquals(expectValidFallbackMessagePattern, HandshakePattern.isValidFallbackMessagePattern(messagePattern));
}

private static List<Arguments> isValidFallbackMessagePattern() {
return List.of(
Arguments.of(new MessagePattern(Role.INITIATOR, new Token[] { Token.E }), true),
Arguments.of(new MessagePattern(Role.INITIATOR, new Token[] { Token.S }), true),
Arguments.of(new MessagePattern(Role.INITIATOR, new Token[] { Token.E, Token.S }), true),
Arguments.of(new MessagePattern(Role.RESPONDER, new Token[] { Token.E }), false),
Arguments.of(new MessagePattern(Role.RESPONDER, new Token[] { Token.S }), false),
Arguments.of(new MessagePattern(Role.RESPONDER, new Token[] { Token.E, Token.S }), false),
Arguments.of(new MessagePattern(Role.INITIATOR, new Token[] { Token.EE }), false),
Arguments.of(new MessagePattern(Role.INITIATOR, new Token[] { Token.E, Token.S, Token.EE }), false)
);
}

@Test
void isOneWayPattern() throws NoSuchPatternException {
assertTrue(HandshakePattern.getInstance("N").isOneWayPattern());
assertFalse(HandshakePattern.getInstance("NK").isOneWayPattern());
}

@Test
void isFallbackPattern() throws NoSuchPatternException {
assertTrue(HandshakePattern.getInstance("XXfallback").isFallbackPattern());
assertFalse(HandshakePattern.getInstance("XX").isFallbackPattern());
}

@Test
void isPreSharedKeyHandshake() throws NoSuchPatternException {
assertFalse(HandshakePattern.getInstance("N").isPreSharedKeyHandshake());
assertTrue(HandshakePattern.getInstance("Npsk0").isPreSharedKeyHandshake());

assertFalse(HandshakePattern.getInstance("NN").isPreSharedKeyHandshake());
assertTrue(HandshakePattern.getInstance("NNpsk2").isPreSharedKeyHandshake());
}

@ParameterizedTest
@CsvSource({
"N, 0",
"NN, 0",
"Npsk0, 1",
"NNpsk2, 1",
"NNpsk0+psk2, 2"
})
void getRequiredPreSharedKeyCount(final String handshakePatternName, final int expectedRequiredPreSharedKeyCount) throws NoSuchPatternException {
assertEquals(expectedRequiredPreSharedKeyCount,
HandshakePattern.getInstance(handshakePatternName).getRequiredPreSharedKeyCount());
}

@Test
void requiresLocalStaticKeyPair() throws NoSuchPatternException {
assertTrue(HandshakePattern.getInstance("XN").requiresLocalStaticKeyPair(Role.INITIATOR));
assertFalse(HandshakePattern.getInstance("XN").requiresLocalStaticKeyPair(Role.RESPONDER));

assertTrue(HandshakePattern.getInstance("NX").requiresLocalStaticKeyPair(Role.RESPONDER));
assertFalse(HandshakePattern.getInstance("NX").requiresLocalStaticKeyPair(Role.INITIATOR));
}

@Test
void requiresRemoteEphemeralPublicKey() throws NoSuchPatternException {
assertTrue(HandshakePattern.getInstance("XXfallback").requiresRemoteEphemeralPublicKey(Role.RESPONDER));
assertFalse(HandshakePattern.getInstance("XXfallback").requiresRemoteEphemeralPublicKey(Role.INITIATOR));

assertFalse(HandshakePattern.getInstance("NX").requiresRemoteEphemeralPublicKey(Role.RESPONDER));
assertFalse(HandshakePattern.getInstance("NX").requiresRemoteEphemeralPublicKey(Role.INITIATOR));
}

@Test
void requiresRemoteStaticPublicKey() throws NoSuchPatternException {
assertTrue(HandshakePattern.getInstance("NK").requiresRemoteStaticPublicKey(Role.INITIATOR));
assertFalse(HandshakePattern.getInstance("NK").requiresRemoteStaticPublicKey(Role.RESPONDER));

assertTrue(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.RESPONDER));
assertFalse(HandshakePattern.getInstance("KN").requiresRemoteStaticPublicKey(Role.INITIATOR));
}
}

0 comments on commit e89fc10

Please sign in to comment.