Skip to content

Commit

Permalink
Fix CLIENT_HELLO processing.
Browse files Browse the repository at this point in the history
Consider message sequence number also to determine, if a CLIENT_HELLO
starts a new handshake.

Signed-off-by: Achim Kraus <achim.kraus@bosch-si.com>
  • Loading branch information
Achim Kraus committed Jul 30, 2019
1 parent 48aa680 commit 200fe1e
Show file tree
Hide file tree
Showing 3 changed files with 114 additions and 73 deletions.
Expand Up @@ -243,6 +243,8 @@ public class DTLSConnector implements Connector, RecordLayer {
*/
private static final int TLS12_CID_PADDING = 0;

private static final long CLIENT_HELLO_TIMEOUT_MILLIS = TimeUnit.SECONDS.toMillis(60);

/** all the configuration options for the DTLS connector */
private final DtlsConnectorConfig config;

Expand Down Expand Up @@ -385,6 +387,16 @@ public void sessionEstablished(Handshaker handshaker, DTLSSession establishedSes
DTLSConnector.this.sessionEstablished(handshaker, establishedSession);
}

@Override
public void handshakeCompleted(final Handshaker handshaker) {
timer.schedule(new Runnable() {
@Override
public void run() {
handshaker.getConnection().isStartedByClientHello(null);
}
}, CLIENT_HELLO_TIMEOUT_MILLIS, TimeUnit.MILLISECONDS);
}

@Override
public void handshakeFailed(Handshaker handshaker, Throwable error) {
List<RawData> listOut = handshaker.takeDeferredApplicationData();
Expand Down Expand Up @@ -1533,37 +1545,52 @@ private void processNewClientHello(final Record record) {
// the IP address indicated in the client hello message
final AvailableConnections connections = new AvailableConnections();
if (isClientInControlOfSourceIpAddress(clientHello, record, connections)) {
boolean verify = false;
Connection connection;
synchronized (connectionStore) {
connection = connectionStore.get(peerAddress);
if (connection == null || !connection.isStartedByClientHello(clientHello)) {
if (connection != null && !connection.isStartedByClientHello(clientHello)) {
Connection sessionConnection = connections.getConnectionBySessionId();
if (sessionConnection != null && sessionConnection != connection) {
// don't overwrite
verify = true;
} else {
if (sessionConnection != null && sessionConnection == connection) {
connections.setRemoveConnectionBySessionId(true);
}
connection = null;
}
}
if (connection == null) {
connection = new Connection(peerAddress, new SerialExecutor(getExecutorService()));
connection.startByClientHello(clientHello);
if (!connectionStore.put(connection)) {
return;
}
}
}
connections.setConnectionByAddress(connection);
try {

connection.getExecutor().execute(new Runnable() {

@Override
public void run() {
if (running.get()) {
processClientHello(clientHello, record, connections);
if (verify) {
sendHelloVerify(clientHello, record, null);
} else {
connections.setConnectionByAddress(connection);
try {
connection.getExecutor().execute(new Runnable() {
@Override
public void run() {
if (running.get()) {
processClientHello(clientHello, record, connections);
}
}
}
});
} catch (RejectedExecutionException e) {
// dont't terminate connection on shutdown!
LOGGER.debug("Execution rejected while processing record [type: {}, peer: {}]",
record.getType(), peerAddress, e);
} catch (RuntimeException e) {
LOGGER.warn("Unexpected error occurred while processing record [type: {}, peer: {}]",
record.getType(), peerAddress, e);
terminateConnection(connections.getConnectionByAddress(), e, AlertLevel.FATAL, AlertDescription.INTERNAL_ERROR);
});
} catch (RejectedExecutionException e) {
// dont't terminate connection on shutdown!
LOGGER.debug("Execution rejected while processing record [type: {}, peer: {}]",
record.getType(), peerAddress, e);
} catch (RuntimeException e) {
LOGGER.warn("Unexpected error occurred while processing record [type: {}, peer: {}]",
record.getType(), peerAddress, e);
terminateConnection(connections.getConnectionByAddress(), e, AlertLevel.FATAL, AlertDescription.INTERNAL_ERROR);
}
}
}
} catch (HandshakeException e) {
Expand All @@ -1587,10 +1614,9 @@ private void processClientHello(ClientHello clientHello, Record record, Availabl
Connection connection = connections.getConnectionByAddress();
if (connection == null) {
throw new NullPointerException("connection by address must not be null!");
}
if (!connection.equalsPeerAddress(record.getPeerAddress())) {
LOGGER.warn("Drop CLIENT_HELLO, changed address {} => {}!",
record.getPeerAddress(), connection.getPeerAddress());
} else if (!connection.equalsPeerAddress(record.getPeerAddress())) {
LOGGER.warn("Drop CLIENT_HELLO, changed address {} => {}!", record.getPeerAddress(),
connection.getPeerAddress());
return;
}
if (LOGGER.isDebugEnabled()) {
Expand All @@ -1602,11 +1628,9 @@ private void processClientHello(ClientHello clientHello, Record record, Availabl
}

try {
if (connection.hasOngoingHandshake() && connection.isStartedByClientHello(clientHello)) {
// client has sent this message before (maybe our response flight has been lost)
// but we do not want to start over again, so let the existing handshaker handle
// the duplicate
processOngoingHandshakeMessage(clientHello, record, connection);
if (connection.hasEstablishedSession() || connection.getOngoingHandshake() != null) {
LOGGER.debug("Discarding duplicate CLIENT_HELLO message [epoch={}] from peer [{}]!", record.getEpoch(),
record.getPeerAddress());
} else if (clientHello.hasSessionId()) {
// client wants to resume a cached session
resumeExistingSession(clientHello, record, connections);
Expand Down Expand Up @@ -1661,7 +1685,7 @@ private boolean isClientInControlOfSourceIpAddress(ClientHello clientHello, Reco
try {
byte[] expectedCookie = null;
byte[] providedCookie = clientHello.getCookie();
if (providedCookie != null && providedCookie.length > 0) {
if (providedCookie.length > 0) {
expectedCookie = cookieGenerator.generateCookie(clientHello);
// if cookie is present, it must match
if (Arrays.equals(expectedCookie, providedCookie)) {
Expand All @@ -1674,7 +1698,7 @@ private boolean isClientInControlOfSourceIpAddress(ClientHello clientHello, Reco
record.getPeerAddress());
}
// otherwise send verify request
} else {
} else {
// threshold 0 always use a verify request
if (0 < thresholdHandshakesWithoutVerifiedPeer) {
int pending = pendingHandshakesWithoutVerifiedPeer.get();
Expand All @@ -1688,28 +1712,11 @@ private boolean isClientInControlOfSourceIpAddress(ClientHello clientHello, Reco
connections.setConnectionBySessionId(sessionConnection);
if (sessionConnection != null) {
// found provided session.
if (sessionConnection.equalsPeerAddress(record.getPeerAddress())) {
// same peer wants to resume his session,
// no verify request required
LOGGER.trace("resuming peer's [{}] session", record.getPeerAddress());
return true;
} else {
Connection addressConnection = connectionStore.get(record.getPeerAddress());
if (addressConnection == null || !addressConnection.hasEstablishedSession()) {
LOGGER.trace("fast resume for peer [{}] [{}]", record.getPeerAddress(),
pending);
return true;
}
}
// for connection with other established session,
// use the verify request
return true;
}
}
}
}
if (expectedCookie == null) {
expectedCookie = cookieGenerator.generateCookie(clientHello);
}
// for all cases not detected above, use a verify request.
sendHelloVerify(clientHello, record, expectedCookie);
return false;
Expand Down Expand Up @@ -1804,7 +1811,7 @@ private void resumeExistingSession(ClientHello clientHello, Record record, final
if (previousConnection.hasEstablishedSession()) {
// client wants to resume a session that has been negotiated by this node
// make sure that the same client only has a single active connection to this server
if (previousConnection.getPeerAddress() == null || previousConnection.equalsPeerAddress(peerAddress)) {
if (connections.isRemoveConnectionBySessionId()) {
// immediately remove previous connection
connectionStore.remove(previousConnection, false);
} else if (clientHello.getCookie().length == 0) {
Expand Down Expand Up @@ -1837,10 +1844,13 @@ public void handshakeFailed(Handshaker handshaker, Throwable error) {
}
}

private void sendHelloVerify(ClientHello clientHello, Record record, byte[] expectedCookie) {
private void sendHelloVerify(ClientHello clientHello, Record record, byte[] expectedCookie) throws GeneralSecurityException {
// send CLIENT_HELLO_VERIFY with cookie in order to prevent
// DOS attack as described in DTLS 1.2 spec
LOGGER.debug("Verifying client IP address [{}] using HELLO_VERIFY_REQUEST", record.getPeerAddress());
if (expectedCookie == null) {
expectedCookie = cookieGenerator.generateCookie(clientHello);
}
HelloVerifyRequest msg = new HelloVerifyRequest(new ProtocolVersion(), expectedCookie, record.getPeerAddress());
// because we do not have a handshaker in place yet that
// manages message_seq numbers, we need to set it explicitly
Expand Down Expand Up @@ -2216,7 +2226,7 @@ private void handleTimeout(DTLSFlight flight, Connection connection) {
Exception cause = null;
if (!connection.isExecuting() || !running.get()) {
cause = new Exception("Stopped by shutdown!");
} else if (!connection.equalsPeerAddress(flight.getPeerAddress())) {
} else if (connectionStore.get(flight.getPeerAddress()) != connection) {
cause = new Exception("Stopped by address change!");
} else {
// set DTLS retransmission maximum
Expand Down
Expand Up @@ -43,6 +43,10 @@ public class AvailableConnections {
* called.
*/
private boolean setBySessionId;
/**
* Indicates, that {@link #getConnectionBySessionId()} must be removed.
*/
private boolean removeBySessionId;

/**
* Creates a new connection pair.
Expand Down Expand Up @@ -104,4 +108,24 @@ public Connection getConnectionBySessionId() {
public boolean isConnectionBySessionIdKnown() {
return setBySessionId;
}

/**
* Set, if the connection for the provided session id must be removed.
*
* @param remove {@code true}, if {@link #getConnectionBySessionId()} must be
* removed, {@code false}, otherwise (default).
*/
public void setRemoveConnectionBySessionId(boolean remove) {
removeBySessionId = remove;
}

/**
* Check, if the connection for the provided session id must be removed.
*
* @return {@code true}, if {@link #getConnectionBySessionId()} must be
* removed, {@code false}, otherwise.
*/
public boolean isRemoveConnectionBySessionId() {
return removeBySessionId;
}
}
Expand Up @@ -62,13 +62,12 @@ public final class Connection {
private static final Logger LOGGER = LoggerFactory.getLogger(Connection.class.getName());

private final AtomicReference<Handshaker> ongoingHandshake = new AtomicReference<Handshaker>();
private final SessionListener sessionListener = new ConnectionSessionListener();
/**
* Random used by client to start the handshake.
*
* Note: used outside of serial-execution!
* Random used by client to start the handshake. Maybe {@code null}, for
* client side connections. Note: used outside of serial-execution!
*/
private final AtomicReference<Random> startedByClient = new AtomicReference<Random>();
private final SessionListener sessionListener = new ConnectionSessionListener();
private ClientHello startingClientHello;
/**
* Expired real time nanoseconds of the last message send or received.
*/
Expand All @@ -91,7 +90,7 @@ public final class Connection {
* @param serialExecutor serial executor.
* @throws NullPointerException if the peer address or the serial executor is {@code null}
*/
public Connection(final InetSocketAddress peerAddress, final SerialExecutor serialExecutor) {
public Connection(InetSocketAddress peerAddress, SerialExecutor serialExecutor) {
if (peerAddress == null) {
throw new NullPointerException("Peer address must not be null");
} else if (serialExecutor == null) {
Expand Down Expand Up @@ -337,36 +336,43 @@ public boolean hasOngoingHandshake() {
/**
* Checks whether this connection is started for the provided CLIENT_HELLO.
*
* Use the random contained in the CLIENT_HELLO.
* Use the random and message sequence number contained in the CLIENT_HELLO.
*
* Note: called outside of serial-execution!
* Note: called outside of serial-execution and so requires external synchronization!
*
* @param clientHello the message to check.
* @param clientHello the message to check. If {@code null}, reset the
* {@link #startingClientHello}.
* @return {@code true} if the given client hello has initially started this
* connection.
* @see #startByClientHello(ClientHello)
*/
public boolean isStartedByClientHello(ClientHello clientHello) {
Random startRandom = startedByClient.get();
if (startRandom != null) {
return startRandom.equals(clientHello.getRandom());
if (clientHello == null) {
startingClientHello = null;
} else if (startingClientHello != null) {
if (startingClientHello.getRandom().equals(clientHello.getRandom())) {
if (startingClientHello.getMessageSeq() >= clientHello.getMessageSeq()) {
return true;
}
}
}
return false;
}

/**
* Set starting CLIENT_HELLO.
*
* Use the random contained in the CLIENT_HELLO. Removed, if when the
* handshake is completed or fails.
* Use the random and handshake message sequence number contained in the
* CLIENT_HELLO. Removed, if when the handshake fails or with configurable
* timeout after handshake completion.
*
* Note: called outside of serial-execution!
* Note: called outside of serial-execution and so requires external synchronization!
*
* @param clientHello message which starts the connection.
* @see #isStartedByClientHello(ClientHello)
*/
public void startByClientHello(ClientHello clientHello) {
startedByClient.set(clientHello.getRandom());
startingClientHello = clientHello;
}

/**
Expand Down Expand Up @@ -500,13 +506,15 @@ public String toString() {
}
if (peerAddress != null) {
builder.append(", ").append(peerAddress);
if (hasOngoingHandshake()) {
builder.append(", ongoing handshake");
if (getOngoingHandshake() != null) {
String id = getOngoingHandshake().getSession().getSessionIdentifier().getAsString().substring(0, 8);
builder.append(", ongoing handshake ").append(id);
}
if (isResumptionRequired()) {
builder.append(", resumption required");
} else if (hasEstablishedSession()) {
builder.append(", session established");
String id = getEstablishedSession().getSessionIdentifier().getAsString().substring(0, 8);
builder.append(", session established ").append(id);
}
}
if (sessionId != null) {
Expand Down Expand Up @@ -535,15 +543,14 @@ public void sessionEstablished(Handshaker handshaker, DTLSSession session) throw
@Override
public void handshakeCompleted(Handshaker handshaker) {
if (ongoingHandshake.compareAndSet(handshaker, null)) {
startedByClient.set(null);
LOGGER.debug("Handshake with [{}] has been completed", handshaker.getPeerAddress());
}
}

@Override
public void handshakeFailed(Handshaker handshaker, Throwable error) {
if (ongoingHandshake.compareAndSet(handshaker, null)) {
startedByClient.set(null);
startingClientHello = null;
LOGGER.debug("Handshake with [{}] has failed", handshaker.getPeerAddress());
}
}
Expand Down

0 comments on commit 200fe1e

Please sign in to comment.