Skip to content

Commit

Permalink
Add Transport.isKeyExchangeRequired() to avoid unnecessary KEXINIT (#811
Browse files Browse the repository at this point in the history
)

* Added Transport.isKeyExchangeRequired() to avoid unnecessary KEXINIT

- Updated SSHClient.onConnect() to check isKeyExchangeRequired() before calling doKex()
- Added started timestamp in ThreadNameProvider for improved tracking

* Moved KeepAliveThread State check after authentication to avoid test timing issues
  • Loading branch information
exceptionfactory authored Sep 16, 2022
1 parent 430cbfc commit 2551f8e
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ public class ThreadNameProvider {
public static void setThreadName(final Thread thread, final RemoteAddressProvider remoteAddressProvider) {
final InetSocketAddress remoteSocketAddress = remoteAddressProvider.getRemoteSocketAddress();
final String address = remoteSocketAddress == null ? DISCONNECTED : remoteSocketAddress.toString();
final String threadName = String.format("sshj-%s-%s", thread.getClass().getSimpleName(), address);
final long started = System.currentTimeMillis();
final String threadName = String.format("sshj-%s-%s-%d", thread.getClass().getSimpleName(), address, started);
thread.setName(threadName);
}
}
7 changes: 6 additions & 1 deletion src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -810,7 +810,12 @@ protected void onConnect()
ThreadNameProvider.setThreadName(conn.getKeepAlive(), trans);
keepAliveThread.start();
}
doKex();
if (trans.isKeyExchangeRequired()) {
log.debug("Initiating Key Exchange for new connection");
doKex();
} else {
log.debug("Key Exchange already completed for new connection");
}
}

/**
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/Transport.java
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,13 @@ void init(String host, int port, InputStream in, OutputStream out)
void doKex()
throws TransportException;

/**
* Is Key Exchange required based on current transport status
*
* @return Key Exchange required status
*/
boolean isKeyExchangeRequired();

/** @return the version string used by this client to identify itself to an SSH server, e.g. "SSHJ_3_0" */
String getClientVersion();

Expand Down
10 changes: 10 additions & 0 deletions src/main/java/net/schmizz/sshj/transport/TransportImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,16 @@ public void doKex()
kexer.startKex(true);
}

/**
* Is Key Exchange required returns true when Key Exchange is not done and when Key Exchange is not ongoing
*
* @return Key Exchange required status
*/
@Override
public boolean isKeyExchangeRequired() {
return !kexer.isKexDone() && !kexer.isKexOngoing();
}

public boolean isKexDone() {
return kexer.isKexDone();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,11 @@ public void shouldStartThreadOnConnectAndInterruptOnDisconnect() throws IOExcept
assertEquals(Thread.State.NEW, keepAlive.getState());

fixture.connectClient(sshClient);
assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());

assertThrows(UserAuthException.class, () -> sshClient.authPassword("bad", "credentials"));

assertEquals(Thread.State.TIMED_WAITING, keepAlive.getState());

fixture.stopClient();
Thread.sleep(STOP_SLEEP);

Expand Down

0 comments on commit 2551f8e

Please sign in to comment.