Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix #805: Prevent CHANNEL_CLOSE to be sent between Channel.isOpen and… #813

Merged
merged 2 commits into from
Sep 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions src/itest/groovy/com/hierynomus/sshj/ManyChannelsSpec.groovy
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
/*
* Copyright (C)2009 - SSHJ Contributors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.hierynomus.sshj

import net.schmizz.sshj.SSHClient
import net.schmizz.sshj.common.IOUtils
import net.schmizz.sshj.connection.channel.direct.Session
import spock.lang.Specification

import java.util.concurrent.*

import static org.codehaus.groovy.runtime.IOGroovyMethods.withCloseable

class ManyChannelsSpec extends Specification {

def "should work with many channels without nonexistent channel error (GH issue #805)"() {
given:
SshdContainer sshd = new SshdContainer.Builder()
.withSshdConfig("""${SshdContainer.Builder.DEFAULT_SSHD_CONFIG}
MaxSessions 200
""".stripMargin())
.build()
sshd.start()
SSHClient client = sshd.getConnectedClient()
client.authPublickey("sshj", "src/test/resources/id_rsa")

when:
List<Future<Exception>> futures = []
ExecutorService executorService = Executors.newCachedThreadPool()

for (int i in 0..20) {
futures.add(executorService.submit((Callable<Exception>) {
return execute(client)
}))
}
executorService.shutdown()
executorService.awaitTermination(1, TimeUnit.DAYS)

then:
futures*.get().findAll { it != null }.empty

cleanup:
client.close()
}


private static Exception execute(SSHClient sshClient) {
try {
for (def i in 0..100) {
withCloseable (sshClient.startSession()) {sshSession ->
Session.Command sshCommand = sshSession.exec("ls -la")
IOUtils.readFully(sshCommand.getInputStream()).toString()
sshCommand.close()
}
}
} catch (Exception e) {
return e
}
return null
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,25 @@ public boolean isOpen() {
}
}

// Prevent CHANNEL_CLOSE to be sent between isOpen and a Transport.write call in the runnable, otherwise
// a disconnect with a "packet referred to nonexistent channel" message can occur.
//
// This particularly happens when the transport.Reader thread passes an eof from the server to the
// ChannelInputStream, the reading library-user thread returns, and closes the channel at the same time as the
// transport.Reader thread receives the subsequent CHANNEL_CLOSE from the server.
boolean whileOpen(TransportRunnable runnable) throws TransportException, ConnectionException {
openCloseLock.lock();
try {
if (isOpen()) {
runnable.run();
return true;
}
} finally {
openCloseLock.unlock();
}
return false;
}

private void gotChannelRequest(SSHPacket buf)
throws ConnectionException, TransportException {
final String reqType;
Expand Down Expand Up @@ -427,5 +446,8 @@ public String toString() {
+ rwin + " >";
}

public interface TransportRunnable {
void run() throws TransportException, ConnectionException;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
*/
public final class ChannelOutputStream extends OutputStream implements ErrorNotifiable {

private final Channel chan;
private final AbstractChannel chan;
private final Transport trans;
private final Window.Remote win;

Expand All @@ -47,6 +47,12 @@ private final class DataBuffer {

private final SSHPacket packet = new SSHPacket(Message.CHANNEL_DATA);
private final Buffer.PlainBuffer leftOvers = new Buffer.PlainBuffer();
private final AbstractChannel.TransportRunnable packetWriteRunnable = new AbstractChannel.TransportRunnable() {
@Override
public void run() throws TransportException {
trans.write(packet);
}
};

DataBuffer() {
headerOffset = packet.rpos();
Expand Down Expand Up @@ -99,8 +105,9 @@ boolean flush(int bufferSize, boolean canAwaitExpansion) throws TransportExcepti
if (leftOverBytes > 0) {
leftOvers.putRawBytes(packet.array(), packet.wpos(), leftOverBytes);
}

trans.write(packet);
if (!chan.whileOpen(packetWriteRunnable)) {
throwStreamClosed();
}
win.consume(writeNow);

packet.rpos(headerOffset);
Expand All @@ -119,7 +126,7 @@ boolean flush(int bufferSize, boolean canAwaitExpansion) throws TransportExcepti

}

public ChannelOutputStream(Channel chan, Transport trans, Window.Remote win) {
public ChannelOutputStream(AbstractChannel chan, Transport trans, Window.Remote win) {
this.chan = chan;
this.trans = trans;
this.win = win;
Expand Down Expand Up @@ -157,17 +164,22 @@ private void checkClose() throws SSHException {
if (error != null) {
throw error;
} else {
throw new ConnectionException("Stream closed");
throwStreamClosed();
}
}
}

@Override
public synchronized void close() throws IOException {
// Not closed yet, and underlying channel is open to flush the data to.
if (!closed.getAndSet(true) && chan.isOpen()) {
buffer.flush(false);
trans.write(new SSHPacket(Message.CHANNEL_EOF).putUInt32(chan.getRecipient()));
if (!closed.getAndSet(true)) {
chan.whileOpen(new AbstractChannel.TransportRunnable() {
@Override
public void run() throws TransportException, ConnectionException {
buffer.flush(false);
trans.write(new SSHPacket(Message.CHANNEL_EOF).putUInt32(chan.getRecipient()));
}
});
}
}

Expand All @@ -188,4 +200,7 @@ public String toString() {
return "< ChannelOutputStream for Channel #" + chan.getID() + " >";
}

private static void throwStreamClosed() throws ConnectionException {
throw new ConnectionException("Stream closed");
}
}