From 81d77d277c96e24d76f705fa8cfc5d8daea13e44 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Henning=20P=C3=B6ttker?= Date: Mon, 15 Apr 2024 09:29:06 +0200 Subject: [PATCH] Don't send keep alive signals before kex is done (#934) Otherwise, they could interfere with strict key exchange. Co-authored-by: Jeroen van Erp --- .../transport/kex/StrictKeyExchangeTest.java | 56 ++++++++++++++++--- src/main/java/net/schmizz/sshj/SSHClient.java | 2 +- 2 files changed, 50 insertions(+), 8 deletions(-) diff --git a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java index 2abe71a7..9d207c0e 100644 --- a/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java +++ b/src/itest/java/com/hierynomus/sshj/transport/kex/StrictKeyExchangeTest.java @@ -18,15 +18,26 @@ import java.util.ArrayList; import java.util.List; import java.util.stream.Collectors; +import java.util.stream.Stream; import ch.qos.logback.classic.Logger; import ch.qos.logback.classic.spi.ILoggingEvent; import ch.qos.logback.core.read.ListAppender; import com.hierynomus.sshj.SshdContainer; +import net.schmizz.keepalive.KeepAlive; +import net.schmizz.keepalive.KeepAliveProvider; +import net.schmizz.sshj.Config; +import net.schmizz.sshj.DefaultConfig; import net.schmizz.sshj.SSHClient; +import net.schmizz.sshj.common.Message; +import net.schmizz.sshj.common.SSHPacket; +import net.schmizz.sshj.connection.ConnectionImpl; +import net.schmizz.sshj.transport.TransportException; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; -import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; import org.slf4j.LoggerFactory; import org.testcontainers.junit.jupiter.Container; import org.testcontainers.junit.jupiter.Testcontainers; @@ -62,14 +73,27 @@ private void setUpLogger(String className) { watchedLoggers.add(logger); } - @Test - void strictKeyExchange() throws Throwable { - try (SSHClient client = sshd.getConnectedClient()) { + private static Stream strictKeyExchange() { + Config defaultConfig = new DefaultConfig(); + Config heartbeaterConfig = new DefaultConfig(); + heartbeaterConfig.setKeepAliveProvider(new KeepAliveProvider() { + @Override + public KeepAlive provide(ConnectionImpl connection) { + return new HotLoopHeartbeater(connection); + } + }); + return Stream.of(defaultConfig, heartbeaterConfig).map(Arguments::of); + } + + @MethodSource + @ParameterizedTest + void strictKeyExchange(Config config) throws Throwable { + try (SSHClient client = sshd.getConnectedClient(config)) { client.authPublickey("sshj", "src/itest/resources/keyfiles/id_rsa_opensshv1"); assertTrue(client.isAuthenticated()); } List keyExchangerLogs = getLogs("KeyExchanger"); - assertThat(keyExchangerLogs).containsSequence( + assertThat(keyExchangerLogs).contains( "Initiating key exchange", "Sending SSH_MSG_KEXINIT", "Received SSH_MSG_KEXINIT", @@ -78,7 +102,7 @@ void strictKeyExchange() throws Throwable { List decoderLogs = getLogs("Decoder").stream() .map(log -> log.split(":")[0]) .collect(Collectors.toList()); - assertThat(decoderLogs).containsExactly( + assertThat(decoderLogs).startsWith( "Received packet #0", "Received packet #1", "Received packet #2", @@ -90,7 +114,7 @@ void strictKeyExchange() throws Throwable { List encoderLogs = getLogs("Encoder").stream() .map(log -> log.split(":")[0]) .collect(Collectors.toList()); - assertThat(encoderLogs).containsExactly( + assertThat(encoderLogs).startsWith( "Encoding packet #0", "Encoding packet #1", "Encoding packet #2", @@ -108,4 +132,22 @@ private List getLogs(String className) { .collect(Collectors.toList()); } + private static class HotLoopHeartbeater extends KeepAlive { + + HotLoopHeartbeater(ConnectionImpl conn) { + super(conn, "sshj-Heartbeater"); + } + + @Override + public boolean isEnabled() { + return true; + } + + @Override + protected void doKeepAlive() throws TransportException { + conn.getTransport().write(new SSHPacket(Message.IGNORE)); + } + + } + } diff --git a/src/main/java/net/schmizz/sshj/SSHClient.java b/src/main/java/net/schmizz/sshj/SSHClient.java index dd0e3817..78b91c5f 100644 --- a/src/main/java/net/schmizz/sshj/SSHClient.java +++ b/src/main/java/net/schmizz/sshj/SSHClient.java @@ -804,12 +804,12 @@ protected void onConnect() throws IOException { super.onConnect(); trans.init(getRemoteHostname(), getRemotePort(), getInputStream(), getOutputStream()); + doKex(); final KeepAlive keepAliveThread = conn.getKeepAlive(); if (keepAliveThread.isEnabled()) { ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans); keepAliveThread.start(); } - doKex(); } /**