Skip to content

Commit

Permalink
Added keep-alive mechanism that detects disconnects (Fixes #166)
Browse files Browse the repository at this point in the history
  • Loading branch information
hierynomus committed Jan 19, 2015
1 parent a7872b3 commit a7802dd
Show file tree
Hide file tree
Showing 16 changed files with 318 additions and 95 deletions.
10 changes: 10 additions & 0 deletions src/main/java/net/schmizz/concurrent/Promise.java
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,16 @@ public boolean inError() {
}
}

/** @return whether this promise was fulfilled with either a value or an error. */
public boolean isFulfilled() {
lock.lock();
try {
return pendingEx != null || val != null;
} finally {
lock.unlock();
}
}

/** @return whether this promise has threads waiting on it. */
public boolean hasWaiters() {
lock.lock();
Expand Down
34 changes: 34 additions & 0 deletions src/main/java/net/schmizz/keepalive/Heartbeater.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/**
* Copyright 2009 sshj contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package net.schmizz.keepalive;

import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;

final class Heartbeater
extends KeepAlive {

Heartbeater(ConnectionImpl conn) {
super(conn, "heartbeater");
}

@Override
protected void doKeepAlive() throws TransportException {
conn.getTransport().write(new SSHPacket(Message.IGNORE));
}
}
68 changes: 68 additions & 0 deletions src/main/java/net/schmizz/keepalive/KeepAlive.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
package net.schmizz.keepalive;

import net.schmizz.sshj.common.Message;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public abstract class KeepAlive extends Thread {
protected final Logger log = LoggerFactory.getLogger(getClass());

protected final ConnectionImpl conn;

protected int keepAliveInterval = 0;

protected KeepAlive(ConnectionImpl conn, String name) {
this.conn = conn;
setName(name);
}

public synchronized int getKeepAliveInterval() {
return keepAliveInterval;
}

public synchronized void setKeepAliveInterval(int keepAliveInterval) {
this.keepAliveInterval = keepAliveInterval;
if (keepAliveInterval > 0 && getState() == State.NEW) {
start();
}
notify();
}

synchronized protected int getPositiveInterval()
throws InterruptedException {
while (keepAliveInterval <= 0) {
wait();
}
return keepAliveInterval;
}

@Override
public void run() {
log.debug("Starting {}, sending keep-alive every {} seconds", getClass().getSimpleName(), keepAliveInterval);
try {
while (!isInterrupted()) {
final int hi = getPositiveInterval();
if (conn.getTransport().isRunning()) {
log.debug("Sending keep-alive since {} seconds elapsed", hi);
doKeepAlive();
}
Thread.sleep(hi * 1000);
}
} catch (Exception e) {
// If we weren't interrupted, kill the transport, then this exception was unexpected.
// Else we're in shutdown-mode already, so don't forcibly kill the transport.
if (!isInterrupted()) {
conn.getTransport().die(e);
}
}

log.debug("Stopping {}", getClass().getSimpleName());

}

protected abstract void doKeepAlive() throws TransportException, ConnectionException;
}
24 changes: 24 additions & 0 deletions src/main/java/net/schmizz/keepalive/KeepAliveProvider.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package net.schmizz.keepalive;

import net.schmizz.sshj.connection.ConnectionImpl;

public abstract class KeepAliveProvider {

public static final KeepAliveProvider HEARTBEAT = new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new Heartbeater(connection);
}
};

public static final KeepAliveProvider KEEP_ALIVE = new KeepAliveProvider() {
@Override
public KeepAlive provide(ConnectionImpl connection) {
return new KeepAliveRunner(connection);
}
};

public abstract KeepAlive provide(ConnectionImpl connection);


}
60 changes: 60 additions & 0 deletions src/main/java/net/schmizz/keepalive/KeepAliveRunner.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
package net.schmizz.keepalive;

import net.schmizz.concurrent.Promise;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.ConnectionImpl;
import net.schmizz.sshj.transport.TransportException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.LinkedList;
import java.util.Queue;

import static java.lang.String.format;
import static net.schmizz.sshj.common.DisconnectReason.CONNECTION_LOST;

public class KeepAliveRunner extends KeepAlive {

/** The max number of keep-alives that should be unanswered before killing the connection. */
private int maxAliveCount = 5;

/** The queue of promises. */
private final Queue<Promise<SSHPacket, ConnectionException>> queue =
new LinkedList<Promise<SSHPacket, ConnectionException>>();

KeepAliveRunner(ConnectionImpl conn) {
super(conn, "keep-alive");
}

synchronized public int getMaxAliveCount() {
return maxAliveCount;
}

synchronized public void setMaxAliveCount(int maxAliveCount) {
this.maxAliveCount = maxAliveCount;
}

@Override
protected void doKeepAlive() throws TransportException, ConnectionException {
emptyQueue(queue);
checkMaxReached(queue);
queue.add(conn.sendGlobalRequest("keepalive@openssh.com", true, new byte[0]));
}

private void checkMaxReached(Queue<Promise<SSHPacket, ConnectionException>> queue) throws ConnectionException {
if (queue.size() >= maxAliveCount) {
throw new ConnectionException(CONNECTION_LOST,
format("Did not receive any keep-alive response for %s seconds", maxAliveCount * keepAliveInterval));
}
}

private void emptyQueue(Queue<Promise<SSHPacket, ConnectionException>> queue) {
Promise<SSHPacket, ConnectionException> peek = queue.peek();
while (peek != null && peek.isFulfilled()) {
log.debug("Received response from server to our keep-alive.");
queue.remove();
peek = queue.peek();
}
}
}
11 changes: 11 additions & 0 deletions src/main/java/net/schmizz/sshj/Config.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;

import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.cipher.Cipher;
Expand Down Expand Up @@ -144,4 +145,14 @@ public interface Config {
*/
void setVersion(String version);

/**
* @return The provider that creates the keep-alive implementation of choice.
*/
KeepAliveProvider getKeepAliveProvider();

/**
* Set the provider that provides the keep-alive implementation.
* @param keepAliveProvider keep-alive provider
*/
void setKeepAliveProvider(KeepAliveProvider keepAliveProvider);
}
11 changes: 11 additions & 0 deletions src/main/java/net/schmizz/sshj/ConfigImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;

import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.signature.Signature;
import net.schmizz.sshj.transport.cipher.Cipher;
Expand All @@ -34,6 +35,7 @@ public class ConfigImpl
private String version;

private Factory<Random> randomFactory;
private KeepAliveProvider keepAliveProvider;

private List<Factory.Named<KeyExchange>> kexFactories;
private List<Factory.Named<Cipher>> cipherFactories;
Expand Down Expand Up @@ -146,4 +148,13 @@ public void setVersion(String version) {
this.version = version;
}

@Override
public KeepAliveProvider getKeepAliveProvider() {
return keepAliveProvider;
}

@Override
public void setKeepAliveProvider(KeepAliveProvider keepAliveProvider) {
this.keepAliveProvider = keepAliveProvider;
}
}
2 changes: 2 additions & 0 deletions src/main/java/net/schmizz/sshj/DefaultConfig.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/
package net.schmizz.sshj;

import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.common.Factory;
import net.schmizz.sshj.common.SecurityUtils;
import net.schmizz.sshj.signature.SignatureDSA;
Expand Down Expand Up @@ -92,6 +93,7 @@ public DefaultConfig() {
initCompressionFactories();
initMACFactories();
initSignatureFactories();
setKeepAliveProvider(KeepAliveProvider.HEARTBEAT);
}

protected void initKeyExchangeFactories(boolean bouncyCastleRegistered) {
Expand Down
4 changes: 2 additions & 2 deletions src/main/java/net/schmizz/sshj/SSHClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ public SSHClient() {
*/
public SSHClient(Config config) {
super(DEFAULT_PORT);
this.trans = new TransportImpl(config);
this.trans = new TransportImpl(config, this);
this.auth = new UserAuthImpl(trans);
this.conn = new ConnectionImpl(trans);
this.conn = new ConnectionImpl(trans, config.getKeepAliveProvider());
}

/**
Expand Down
6 changes: 6 additions & 0 deletions src/main/java/net/schmizz/sshj/connection/Connection.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package net.schmizz.sshj.connection;

import net.schmizz.concurrent.Promise;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.sshj.common.SSHPacket;
import net.schmizz.sshj.connection.channel.Channel;
import net.schmizz.sshj.connection.channel.OpenFailException;
Expand Down Expand Up @@ -150,4 +151,9 @@ void sendOpenFailure(int recipient, OpenFailException.Reason reason, String mess
* @param timeout timeout in milliseconds
*/
void setTimeoutMs(int timeout);

/**
* @return The configured {@link net.schmizz.keepalive.KeepAlive} mechanism.
*/
KeepAlive getKeepAlive();
}
16 changes: 15 additions & 1 deletion src/main/java/net/schmizz/sshj/connection/ConnectionImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@

import net.schmizz.concurrent.ErrorDeliveryUtil;
import net.schmizz.concurrent.Promise;
import net.schmizz.keepalive.KeepAlive;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.AbstractService;
import net.schmizz.sshj.common.Buffer;
import net.schmizz.sshj.common.DisconnectReason;
Expand Down Expand Up @@ -51,6 +53,9 @@ public class ConnectionImpl

private final Queue<Promise<SSHPacket, ConnectionException>> globalReqPromises = new LinkedList<Promise<SSHPacket, ConnectionException>>();

/** {@code keep-alive} mechanism */
private final KeepAlive keepAlive;

private long windowSize = 2048 * 1024;
private int maxPacketSize = 32 * 1024;

Expand All @@ -59,11 +64,14 @@ public class ConnectionImpl
/**
* Create with an associated {@link Transport}.
*
* @param config the ssh config
* @param trans transport layer
* @param keepAlive
*/
public ConnectionImpl(Transport trans) {
public ConnectionImpl(Transport trans, KeepAliveProvider keepAlive) {
super("ssh-connection", trans);
timeoutMs = trans.getTimeoutMs();
this.keepAlive = keepAlive.provide(this);
}

@Override
Expand Down Expand Up @@ -250,6 +258,7 @@ public void notifyError(SSHException error) {
ErrorDeliveryUtil.alertPromises(error, globalReqPromises);
globalReqPromises.clear();
}
keepAlive.interrupt();
ErrorNotifiable.Util.alertAll(error, channels.values());
channels.clear();
}
Expand All @@ -264,4 +273,9 @@ public int getTimeoutMs() {
return timeoutMs;
}

@Override
public KeepAlive getKeepAlive() {
return keepAlive;
}

}

0 comments on commit a7802dd

Please sign in to comment.