Skip to content

Commit

Permalink
Reduce visibility of internals
Browse files Browse the repository at this point in the history
  • Loading branch information
jchambers committed Feb 26, 2024
1 parent abb8850 commit 7f4df24
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 73 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ private static Map<String, String> buildTemplateModel(final HandshakePattern han
case RESPONDER -> "Responder";
};

final String methodSafePatternName = HandshakePattern.getFundamentalPatternName(handshakePattern.name()) +
HandshakePattern.getModifiers(handshakePattern.name()).stream()
final String methodSafePatternName = HandshakePattern.getFundamentalPatternName(handshakePattern.getName()) +
HandshakePattern.getModifiers(handshakePattern.getName()).stream()
.map(modifier -> {
final char firstChar = Character.toUpperCase(modifier.charAt(0));
return firstChar + modifier.substring(1);
Expand Down Expand Up @@ -112,7 +112,7 @@ private static Map<String, String> buildTemplateModel(final HandshakePattern han
"%METHOD_SAFE_PATTERN_NAME%", methodSafePatternName,
"%METHOD_SAFE_ROLE_NAME%", methodSafeRoleName,
"%ROLE_ENUM_KEY%", role.name(),
"%PATTERN_NAME%", handshakePattern.name(),
"%PATTERN_NAME%", handshakePattern.getName(),
"%ARGUMENT_LIST%", String.join(", ", arguments),
"%LOCAL_STATIC_KEY_PAIR_ARGUMENT%", localStaticKeyPairArgument,
"%REMOTE_STATIC_PUBLIC_KEY_ARGUMENT%", remoteStaticPublicKeyArgument,
Expand Down
72 changes: 48 additions & 24 deletions src/main/java/com/eatthepath/noise/HandshakePattern.java
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;

public record HandshakePattern(String name, MessagePattern[] preMessagePatterns, MessagePattern[] handshakeMessagePatterns) {
public class HandshakePattern {

private final String name;

private final MessagePattern[] preMessagePatterns;
private final MessagePattern[] handshakeMessagePatterns;

private static final Map<String, HandshakePattern> FUNDAMENTAL_PATTERNS_BY_NAME;

Expand Down Expand Up @@ -316,14 +321,21 @@ public record HandshakePattern(String name, MessagePattern[] preMessagePatterns,
-> se, es
""")
.map(HandshakePattern::fromString)
.collect(Collectors.toMap(HandshakePattern::name, handshakePattern -> handshakePattern));
.collect(Collectors.toMap(HandshakePattern::getName, handshakePattern -> handshakePattern));
}

private static final Map<String, HandshakePattern> DERIVED_PATTERNS_BY_NAME = new ConcurrentHashMap<>();

private static final String PRE_MESSAGE_SEPARATOR = "...";

public record MessagePattern(NoiseHandshake.Role sender, Token[] tokens) {
HandshakePattern(final String name, final MessagePattern[] preMessagePatterns, final MessagePattern[] handshakeMessagePatterns) {
this.name = name;

this.preMessagePatterns = preMessagePatterns;
this.handshakeMessagePatterns = handshakeMessagePatterns;
}

record MessagePattern(NoiseHandshake.Role sender, Token[] tokens) {
@Override
public boolean equals(final Object o) {
if (this == o) return true;
Expand All @@ -340,7 +352,7 @@ public int hashCode() {
}
}

public enum Token {
enum Token {
E,
S,
EE,
Expand All @@ -363,6 +375,18 @@ static Token fromString(final String string) {
}
}

public String getName() {
return name;
}

MessagePattern[] getPreMessagePatterns() {
return preMessagePatterns;
}

MessagePattern[] getHandshakeMessagePatterns() {
return handshakeMessagePatterns;
}

public static HandshakePattern getInstance(final String name) throws NoSuchPatternException {
if (FUNDAMENTAL_PATTERNS_BY_NAME.containsKey(name)) {
return FUNDAMENTAL_PATTERNS_BY_NAME.get(name);
Expand Down Expand Up @@ -431,18 +455,18 @@ HandshakePattern withModifier(final String modifier) {

if ("fallback".equals(modifier)) {
// TODO Make sure first handshake message is eligible for fallback
modifiedPreMessagePatterns = new MessagePattern[preMessagePatterns().length + 1];
modifiedHandshakeMessagePatterns = new MessagePattern[handshakeMessagePatterns().length - 1];
modifiedPreMessagePatterns = new MessagePattern[getPreMessagePatterns().length + 1];
modifiedHandshakeMessagePatterns = new MessagePattern[getHandshakeMessagePatterns().length - 1];

System.arraycopy(preMessagePatterns(), 0, modifiedPreMessagePatterns, 0, preMessagePatterns().length);
modifiedPreMessagePatterns[modifiedPreMessagePatterns.length - 1] = handshakeMessagePatterns()[0];
System.arraycopy(getPreMessagePatterns(), 0, modifiedPreMessagePatterns, 0, getPreMessagePatterns().length);
modifiedPreMessagePatterns[modifiedPreMessagePatterns.length - 1] = getHandshakeMessagePatterns()[0];

System.arraycopy(handshakeMessagePatterns(), 1, modifiedHandshakeMessagePatterns, 0, handshakeMessagePatterns().length - 1);
System.arraycopy(getHandshakeMessagePatterns(), 1, modifiedHandshakeMessagePatterns, 0, getHandshakeMessagePatterns().length - 1);
} else if (modifier.startsWith("psk")) {
final int pskIndex = Integer.parseInt(modifier.substring("psk".length()));

modifiedPreMessagePatterns = preMessagePatterns().clone();
modifiedHandshakeMessagePatterns = handshakeMessagePatterns().clone();
modifiedPreMessagePatterns = getPreMessagePatterns().clone();
modifiedHandshakeMessagePatterns = getHandshakeMessagePatterns().clone();

if (pskIndex == 0) {
// Insert a PSK token at the start of the first message
Expand All @@ -467,17 +491,17 @@ HandshakePattern withModifier(final String modifier) {

final String modifiedName;

if (name().equals(getFundamentalPatternName(name()))) {
if (getName().equals(getFundamentalPatternName(getName()))) {
// Our current name doesn't have any modifiers, and so this is the first
modifiedName = name() + modifier;
modifiedName = getName() + modifier;
} else {
modifiedName = name() + "+" + modifier;
modifiedName = getName() + "+" + modifier;
}

return new HandshakePattern(modifiedName, modifiedPreMessagePatterns, modifiedHandshakeMessagePatterns);
}

public static HandshakePattern fromString(final String patternString) {
static HandshakePattern fromString(final String patternString) {
final String name = patternString.lines()
.findFirst()
.filter(line -> line.endsWith(":"))
Expand Down Expand Up @@ -548,18 +572,18 @@ private static MessagePattern messagePatternFromString(final String messagePatte
}

public boolean isOneWayPattern() {
return Arrays.stream(handshakeMessagePatterns())
return Arrays.stream(getHandshakeMessagePatterns())
.allMatch(messagePattern -> messagePattern.sender() == NoiseHandshake.Role.INITIATOR);
}

public boolean isFallbackPattern() {
return getModifiers(name()).contains("fallback");
boolean isFallbackPattern() {
return getModifiers(getName()).contains("fallback");
}

public boolean requiresLocalStaticKeyPair(final NoiseHandshake.Role role) {
// The given role needs a local static key pair if any pre-handshake message or handshake message involves that role
// sending a static key to the other party
return Stream.concat(Arrays.stream(preMessagePatterns()), Arrays.stream(handshakeMessagePatterns()))
return Stream.concat(Arrays.stream(getPreMessagePatterns()), Arrays.stream(getHandshakeMessagePatterns()))
.filter(messagePattern -> messagePattern.sender() == role)
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == Token.S);
Expand All @@ -568,7 +592,7 @@ public boolean requiresLocalStaticKeyPair(final NoiseHandshake.Role role) {
public boolean requiresRemoteEphemeralPublicKey(final NoiseHandshake.Role role) {
// The given role needs a remote static key pair if the handshake pattern involves that role receiving an ephemeral
// key from the other party in a pre-handshake message
return Arrays.stream(preMessagePatterns())
return Arrays.stream(getPreMessagePatterns())
.filter(messagePattern -> messagePattern.sender() != role)
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == Token.E);
Expand All @@ -577,20 +601,20 @@ public boolean requiresRemoteEphemeralPublicKey(final NoiseHandshake.Role role)
public boolean requiresRemoteStaticPublicKey(final NoiseHandshake.Role role) {
// The given role needs a remote static key pair if the handshake pattern involves that role receiving a static key
// from the other party in a pre-handshake message
return Arrays.stream(preMessagePatterns())
return Arrays.stream(getPreMessagePatterns())
.filter(messagePattern -> messagePattern.sender() != role)
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == Token.S);
}

public boolean isPreSharedKeyHandshake() {
return Arrays.stream(handshakeMessagePatterns())
boolean isPreSharedKeyHandshake() {
return Arrays.stream(getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == Token.PSK);
}

public int getRequiredPreSharedKeyCount() {
return Math.toIntExact(Arrays.stream(handshakeMessagePatterns())
return Math.toIntExact(Arrays.stream(getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.filter(token -> token == Token.PSK)
.count());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ static void validatePublicKeysPresentForKeyAgreement(final HandshakePattern hand
boolean hasRemoteStaticKey = false;
boolean hasRemoteEphemeralKey = false;

for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.preMessagePatterns()) {
for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.getPreMessagePatterns()) {
if (messagePattern.sender() != role) {
for (final HandshakePattern.Token token : messagePattern.tokens()) {
switch (token) {
Expand All @@ -41,7 +41,7 @@ static void validatePublicKeysPresentForKeyAgreement(final HandshakePattern hand
}
}

for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.handshakeMessagePatterns()) {
for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.getHandshakeMessagePatterns()) {
for (HandshakePattern.Token token : messagePattern.tokens()) {
switch (token) {
case E -> {
Expand Down Expand Up @@ -114,7 +114,7 @@ static void validateKeyTransmissionLimits(final HandshakePattern handshakePatter
for (final NoiseHandshake.Role role : NoiseHandshake.Role.values()) {
for (final HandshakePattern.Token token : new HandshakePattern.Token[] { HandshakePattern.Token.E, HandshakePattern.Token.S }) {
final long tokenCount =
Stream.concat(Arrays.stream(handshakePattern.preMessagePatterns()), Arrays.stream(handshakePattern.handshakeMessagePatterns()))
Stream.concat(Arrays.stream(handshakePattern.getPreMessagePatterns()), Arrays.stream(handshakePattern.getHandshakeMessagePatterns()))
.filter(messagePattern -> messagePattern.sender() == role)
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.filter(t -> t == token)
Expand All @@ -134,7 +134,7 @@ static void validateKeyAgreementLimits(final HandshakePattern handshakePattern)
HandshakePattern.Token.EE, HandshakePattern.Token.ES, HandshakePattern.Token.SE, HandshakePattern.Token.SS
}) {

final long tokenCount = Arrays.stream(handshakePattern.handshakeMessagePatterns())
final long tokenCount = Arrays.stream(handshakePattern.getHandshakeMessagePatterns())
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.filter(t -> t == token)
.count();
Expand All @@ -152,7 +152,7 @@ static void validateKeyAgreementBeforeEncrypt(final HandshakePattern handshakePa
final Set<HandshakePattern.Token> encounteredTokens = new HashSet<>();
final EnumMap<NoiseHandshake.Role, Set<HandshakePattern.Token>> requiredTokensByRole = new EnumMap<>(NoiseHandshake.Role.class);

for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.handshakeMessagePatterns()) {
for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.getHandshakeMessagePatterns()) {
for (final HandshakePattern.Token token : messagePattern.tokens()) {
encounteredTokens.add(token);

Expand Down Expand Up @@ -184,14 +184,14 @@ static void validateKeyAgreementBeforeEncrypt(final HandshakePattern handshakePa

static void validatePreSharedKeyEphemeralKey(final HandshakePattern handshakePattern) {
for (final NoiseHandshake.Role role : NoiseHandshake.Role.values()) {
boolean hasSentEphemeralKey = Arrays.stream(handshakePattern.preMessagePatterns())
boolean hasSentEphemeralKey = Arrays.stream(handshakePattern.getPreMessagePatterns())
.filter(messagePattern -> messagePattern.sender() == role)
.flatMap(messagePattern -> Arrays.stream(messagePattern.tokens()))
.anyMatch(token -> token == HandshakePattern.Token.E);

boolean needsEphemeralKey = false;

for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.handshakeMessagePatterns()) {
for (final HandshakePattern.MessagePattern messagePattern : handshakePattern.getHandshakeMessagePatterns()) {
for (final HandshakePattern.Token token : messagePattern.tokens()) {
if (token == HandshakePattern.Token.PSK) {
needsEphemeralKey = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ public NamedProtocolHandshakeBuilder setLocalEphemeralKeyPair(@Nullable final Ke

public NamedProtocolHandshakeBuilder setLocalStaticKeyPair(@Nullable final KeyPair localStaticKeyPair) {
if (!handshakePattern.requiresLocalStaticKeyPair(role)) {
throw new IllegalStateException(handshakePattern.name() + " handshake pattern does not allow local static keys for " + role + " role");
throw new IllegalStateException(handshakePattern.getName() + " handshake pattern does not allow local static keys for " + role + " role");
}

this.localStaticKeyPair = localStaticKeyPair;
Expand All @@ -69,7 +69,7 @@ public NamedProtocolHandshakeBuilder setLocalStaticKeyPair(@Nullable final KeyPa

public NamedProtocolHandshakeBuilder setRemoteStaticPublicKey(@Nullable final PublicKey remoteStaticPublicKey) {
if (!handshakePattern.requiresRemoteStaticPublicKey(role)) {
throw new IllegalStateException(handshakePattern.name() + " handshake pattern does not allow remote static key for " + role + " role");
throw new IllegalStateException(handshakePattern.getName() + " handshake pattern does not allow remote static key for " + role + " role");
}

this.remoteStaticPublicKey = remoteStaticPublicKey;
Expand All @@ -80,11 +80,11 @@ public NamedProtocolHandshakeBuilder setPreSharedKeys(final List<byte[]> preShar
final int requiredPreSharedKeys = handshakePattern.getRequiredPreSharedKeyCount();

if (requiredPreSharedKeys == 0) {
throw new IllegalStateException(handshakePattern.name() + " handshake pattern does not allow pre-shared keys");
throw new IllegalStateException(handshakePattern.getName() + " handshake pattern does not allow pre-shared keys");
}

if (preSharedKeys.size() != requiredPreSharedKeys) {
throw new IllegalArgumentException(handshakePattern.name() + " requires exactly " + requiredPreSharedKeys + " pre-shared keys");
throw new IllegalArgumentException(handshakePattern.getName() + " requires exactly " + requiredPreSharedKeys + " pre-shared keys");
}

if (preSharedKeys.stream().anyMatch(preSharedKey -> preSharedKey.length != 32)) {
Expand All @@ -97,11 +97,11 @@ public NamedProtocolHandshakeBuilder setPreSharedKeys(final List<byte[]> preShar

public NoiseHandshake build() {
if (handshakePattern.requiresRemoteStaticPublicKey(role) && remoteStaticPublicKey == null) {
throw new IllegalStateException(handshakePattern.name() + " handshake pattern requires a remote static public key for the " + role + " role");
throw new IllegalStateException(handshakePattern.getName() + " handshake pattern requires a remote static public key for the " + role + " role");
}

if (handshakePattern.requiresLocalStaticKeyPair(role) && localStaticKeyPair == null) {
throw new IllegalStateException(handshakePattern.name() + " handshake pattern requires a local static key pair for the " + role + " role");
throw new IllegalStateException(handshakePattern.getName() + " handshake pattern requires a local static key pair for the " + role + " role");
}

// TODO Check key compatibility if applicable
Expand Down
Loading

0 comments on commit 7f4df24

Please sign in to comment.