Skip to content

Commit

Permalink
Retry authentication with all remaining auth methods after partial su…
Browse files Browse the repository at this point in the history
…ccess

Signed-off-by: Jeroen van Erp <jeroen@hierynomus.com>
  • Loading branch information
hierynomus committed Sep 23, 2022
1 parent d628c47 commit e43c672
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 11 deletions.
22 changes: 20 additions & 2 deletions src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
import net.schmizz.sshj.transport.verification.FingerprintVerifier;
import net.schmizz.sshj.transport.verification.HostKeyVerifier;
import net.schmizz.sshj.transport.verification.OpenSSHKnownHosts;
import net.schmizz.sshj.userauth.AuthResult;
import net.schmizz.sshj.userauth.UserAuth;
import net.schmizz.sshj.userauth.UserAuthException;
import net.schmizz.sshj.userauth.UserAuthImpl;
Expand Down Expand Up @@ -218,13 +219,30 @@ public void auth(String username, Iterable<AuthMethod> methods)
throws UserAuthException, TransportException {
checkConnected();
final Deque<UserAuthException> savedEx = new LinkedList<UserAuthException>();
for (AuthMethod method: methods) {
final List<AuthMethod> tried = new LinkedList<AuthMethod>();

for (Iterator<AuthMethod> it = methods.iterator(); it.hasNext();) {
AuthMethod method = it.next();
method.setLoggerFactory(loggerFactory);

try {
if (auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs()))
AuthResult result = auth.authenticate(username, (Service) conn, method, trans.getTimeoutMs());

if (result == AuthResult.SUCCESS) {
return;
} else if (result == AuthResult.PARTIAL) {
// Put all remaining methods in the tried list, so that we can try them for the second round of authentication
while (it.hasNext()) {
tried.add(it.next());
}

auth(username, tried);
return;
}
tried.add(method);
} catch (UserAuthException e) {
savedEx.push(e);
tried.add(method);
}
}
throw new UserAuthException("Exhausted available authentication methods", savedEx.peek());
Expand Down
7 changes: 7 additions & 0 deletions src/main/java/net/schmizz/sshj/userauth/AuthResult.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package net.schmizz.sshj.userauth;

public enum AuthResult {
SUCCESS,
FAILURE,
PARTIAL
}
4 changes: 2 additions & 2 deletions src/main/java/net/schmizz/sshj/userauth/UserAuth.java
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ public interface UserAuth {
* @param nextService the service to set on successful authentication
* @param methods the {@link AuthMethod}'s to try
*
* @return whether authentication was successful
* @return whether authentication was successful, failed, or partially successful
*
* @throws UserAuthException in case of authentication failure
* @throws TransportException if there was a transport-layer error
*/
boolean authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs)
AuthResult authenticate(String username, Service nextService, AuthMethod methods, int timeoutMs)
throws UserAuthException, TransportException;

/**
Expand Down
16 changes: 9 additions & 7 deletions src/main/java/net/schmizz/sshj/userauth/UserAuthImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public class UserAuthImpl
extends AbstractService
implements UserAuth {

private final Promise<Boolean, UserAuthException> authenticated;
private final Promise<AuthResult, UserAuthException> authenticated;

// Externally available
private volatile String banner = "";
Expand All @@ -53,13 +53,13 @@ public class UserAuthImpl

public UserAuthImpl(Transport trans) {
super("ssh-userauth", trans);
authenticated = new Promise<Boolean, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory());
authenticated = new Promise<AuthResult, UserAuthException>("authenticated", UserAuthException.chainer, trans.getConfig().getLoggerFactory());
}

@Override
public boolean authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
public AuthResult authenticate(String username, Service nextService, AuthMethod method, int timeoutMs)
throws UserAuthException, TransportException {
final boolean outcome;
final AuthResult outcome;

authenticated.lock();
try {
Expand All @@ -73,8 +73,10 @@ public boolean authenticate(String username, Service nextService, AuthMethod met
currentMethod.request();
outcome = authenticated.retrieve(timeoutMs, TimeUnit.MILLISECONDS);

if (outcome) {
if (outcome == AuthResult.SUCCESS) {
log.debug("`{}` auth successful", method.getName());
} else if (outcome == AuthResult.PARTIAL) {
log.debug("`{}` auth partially successful", method.getName());
} else {
log.debug("`{}` auth failed", method.getName());
}
Expand Down Expand Up @@ -124,7 +126,7 @@ public void handle(Message msg, SSHPacket buf)
// Should fix https://github.com/hierynomus/sshj/issues/237
trans.setAuthenticated(); // So it can put delayed compression into force if applicable
trans.setService(nextService); // We aren't in charge anymore, next service is
authenticated.deliver(true);
authenticated.deliver(AuthResult.SUCCESS);
break;

case USERAUTH_FAILURE:
Expand All @@ -133,7 +135,7 @@ public void handle(Message msg, SSHPacket buf)
if (allowedMethods.contains(currentMethod.getName()) && currentMethod.shouldRetry()) {
currentMethod.request();
} else {
authenticated.deliver(false);
authenticated.deliver(partialSuccess ? AuthResult.PARTIAL : AuthResult.FAILURE);
}
break;

Expand Down

0 comments on commit e43c672

Please sign in to comment.