Skip to content

Commit

Permalink
switched from #readchar to #readpartial(1) in lib/net/ssh/transport/s…
Browse files Browse the repository at this point in the history
…erver_version.rb, so that closed sockets are recognized
  • Loading branch information
aaalex authored and delano committed Nov 9, 2009
1 parent 63fda96 commit 2fc2d20
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
7 changes: 4 additions & 3 deletions lib/net/ssh/transport/server_version.rb
Expand Up @@ -44,9 +44,10 @@ def negotiate!(socket)
@version = "" @version = ""
loop do loop do
# b = socket.recv(1) # b = socket.recv(1)
b = socket.readchar begin

b = socket.readpartial(1)
if b.nil? raise Net::SSH::Disconnect, "connection closed by remote host" if b.nil?
rescue EOFError => e
raise Net::SSH::Disconnect, "connection closed by remote host" raise Net::SSH::Disconnect, "connection closed by remote host"
end end
@version << b @version << b
Expand Down
19 changes: 14 additions & 5 deletions test/transport/test_server_version.rb
Expand Up @@ -28,6 +28,10 @@ def test_unacceptible_server_version_should_raise_exception
assert_raises(Net::SSH::Exception) { subject(socket(false, "SSH-1.4-Testing_1.0\r\n")) } assert_raises(Net::SSH::Exception) { subject(socket(false, "SSH-1.4-Testing_1.0\r\n")) }
end end


def test_unexpected_server_close_should_raise_exception
assert_raises(Net::SSH::Disconnect) { subject(socket(false, "\r\nDestination server does not have Ssh activated.\r\nContact Cisco Systems, Inc to purchase a\r\nlicense key to activate Ssh.\r\n", true)) }
end

def test_header_lines_should_be_accumulated def test_header_lines_should_be_accumulated
s = subject(socket(true, "Welcome\r\nAnother line\r\nSSH-2.0-Testing_1.0\r\n")) s = subject(socket(true, "Welcome\r\nAnother line\r\nSSH-2.0-Testing_1.0\r\n"))
assert_equal "Welcome\r\nAnother line\r\n", s.header assert_equal "Welcome\r\nAnother line\r\n", s.header
Expand All @@ -40,16 +44,21 @@ def test_server_disconnect_should_raise_exception


private private


def socket(good, version_header) def socket(good, version_header, raise_eot=false)
socket = mock("socket") socket = mock("socket")


data = version_header.split('') data = version_header.split('')
recv_times = data.length recv_times = data.length
if data[-1] != "\n" recv_times += 1 if data[-1] != "\n"
recv_times += 1
end unless raise_eot

# socket.expects(:recv).with(1).times(recv_times).returns(*data).then.returns(nil) # socket.expects(:recv).with(1).times(recv_times).returns(*data).then.returns(nil)
socket.expects(:readchar).times(recv_times).returns(*data).then.returns(nil) # socket.expects(:readchar).times(recv_times).returns(*data).then.returns(nil)
socket.expects(:readpartial).with(1).times(recv_times).returns(*data).then.returns(nil)
else
socket.expects(:readpartial).with(1).times(recv_times+1).returns(*data).then.raises(EOFError, "end of file reached")
end


if good if good
socket.expects(:write).with("#{Net::SSH::Transport::ServerVersion::PROTO_VERSION}\r\n") socket.expects(:write).with("#{Net::SSH::Transport::ServerVersion::PROTO_VERSION}\r\n")
Expand Down

0 comments on commit 2fc2d20

Please sign in to comment.