From d30638b96739e6f97e2852a715cd2c014a518d51 Mon Sep 17 00:00:00 2001 From: aguynamedryan Date: Thu, 27 Jan 2011 14:28:50 -0800 Subject: [PATCH] Add support for multiple gateways --- lib/capistrano/configuration/connections.rb | 32 +++++++++++++++++---- test/configuration/connections_test.rb | 29 +++++++++++++++++++ 2 files changed, 55 insertions(+), 6 deletions(-) diff --git a/lib/capistrano/configuration/connections.rb b/lib/capistrano/configuration/connections.rb index 994e01b11..4656eac9a 100644 --- a/lib/capistrano/configuration/connections.rb +++ b/lib/capistrano/configuration/connections.rb @@ -24,13 +24,29 @@ def connect_to(server) class GatewayConnectionFactory #:nodoc: def initialize(gateway, options) @options = options - @options[:logger].debug "Creating gateway using #{[*gateway].join(', ')}" if @options[:logger] Thread.abort_on_exception = true - @gateways = [*gateway].collect { |g| ServerDefinition.new(g) } - tunnel = SSH.connection_strategy(@gateways[0], @options) do |host, user, connect_options| + @gateways = {} + if gateway.is_a?(Hash) + @options[:logger].debug "Creating multiple gateways using #{gateway.inspect}" if @options[:logger] + gateway.each do |gw, hosts| + gateway_connection = add_gateway(gw) + [*hosts].each do |host| + @gateways[:default] ||= gateway_connection + @gateways[host] = gateway_connection + end + end + else + @options[:logger].debug "Creating gateway using #{[*gateway].join(', ')}" if @options[:logger] + @gateways[:default] = add_gateway(gateway) + end + end + + def add_gateway(gateway) + gateways = [*gateway].collect { |g| ServerDefinition.new(g) } + tunnel = SSH.connection_strategy(gateways[0], @options) do |host, user, connect_options| Net::SSH::Gateway.new(host, user, connect_options) end - @gateway = (@gateways[1..-1]).inject(tunnel) do |tunnel, destination| + (gateways[1..-1]).inject(tunnel) do |tunnel, destination| @options[:logger].debug "Creating tunnel to #{destination}" if @options[:logger] local_host = ServerDefinition.new("127.0.0.1", :user => destination.user, :port => tunnel.open(destination.host, (destination.port || 22))) SSH.connection_strategy(local_host, @options) do |host, user, connect_options| @@ -41,11 +57,15 @@ def initialize(gateway, options) def connect_to(server) @options[:logger].debug "establishing connection to `#{server}' via gateway" if @options[:logger] - local_host = ServerDefinition.new("127.0.0.1", :user => server.user, :port => @gateway.open(server.host, server.port || 22)) + local_host = ServerDefinition.new("127.0.0.1", :user => server.user, :port => gateway_for(server).open(server.host, server.port || 22)) session = SSH.connect(local_host, @options) session.xserver = server session end + + def gateway_for(server) + @gateways[server.host] || @gateways[:default] + end end # A hash of the SSH sessions that are currently open and available. @@ -87,7 +107,7 @@ def connect!(options={}) def connection_factory @connection_factory ||= begin if exists?(:gateway) - logger.debug "establishing connection to gateway `#{fetch(:gateway)}'" + logger.debug "establishing connection to gateway `#{fetch(:gateway).inspect}'" GatewayConnectionFactory.new(fetch(:gateway), self) else DefaultConnectionFactory.new(self) diff --git a/test/configuration/connections_test.rb b/test/configuration/connections_test.rb index ee99205c7..014ecaa57 100644 --- a/test/configuration/connections_test.rb +++ b/test/configuration/connections_test.rb @@ -81,6 +81,15 @@ def test_connection_factory_as_gateway_should_chain_gateways_if_gateway_variable assert_instance_of Capistrano::Configuration::Connections::GatewayConnectionFactory, @config.connection_factory end + def test_connection_factory_as_gateway_should_chain_gateways_if_gateway_variable_is_a_hash + @config.values[:gateway] = { ["j@gateway1", "k@gateway2"] => :default } + gateway1 = mock + Net::SSH::Gateway.expects(:new).with("gateway1", "j", :password => nil, :auth_methods => %w(publickey hostbased), :config => false).returns(gateway1) + gateway1.expects(:open).returns(65535) + Net::SSH::Gateway.expects(:new).with("127.0.0.1", "k", :port => 65535, :password => nil, :auth_methods => %w(publickey hostbased), :config => false).returns(stub_everything) + assert_instance_of Capistrano::Configuration::Connections::GatewayConnectionFactory, @config.connection_factory + end + def test_connection_factory_as_gateway_should_share_gateway_between_connections @config.values[:gateway] = "j@gateway" Net::SSH::Gateway.expects(:new).once.with("gateway", "j", :password => nil, :auth_methods => %w(publickey hostbased), :config => false).returns(stub_everything) @@ -89,6 +98,26 @@ def test_connection_factory_as_gateway_should_share_gateway_between_connections @config.establish_connections_to(server("capistrano")) @config.establish_connections_to(server("another")) end + + def test_connection_factory_as_gateway_should_share_gateway_between_like_connections_if_gateway_variable_is_a_hash + @config.values[:gateway] = { "j@gateway" => [ "capistrano", "another"] } + Net::SSH::Gateway.expects(:new).once.with("gateway", "j", :password => nil, :auth_methods => %w(publickey hostbased), :config => false).returns(stub_everything) + Capistrano::SSH.stubs(:connect).returns(stub_everything) + assert_instance_of Capistrano::Configuration::Connections::GatewayConnectionFactory, @config.connection_factory + @config.establish_connections_to(server("capistrano")) + @config.establish_connections_to(server("another")) + end + + def test_connection_factory_as_gateways_should_not_share_gateway_between_unlike_connections_if_gateway_variable_is_a_hash + @config.values[:gateway] = { "j@gateway" => [ "capistrano", "another"], "k@gateway2" => "yafhost" } + Net::SSH::Gateway.expects(:new).once.with("gateway", "j", :password => nil, :auth_methods => %w(publickey hostbased), :config => false).returns(stub_everything) + Net::SSH::Gateway.expects(:new).once.with("gateway2", "k", :password => nil, :auth_methods => %w(publickey hostbased), :config => false).returns(stub_everything) + Capistrano::SSH.stubs(:connect).returns(stub_everything) + assert_instance_of Capistrano::Configuration::Connections::GatewayConnectionFactory, @config.connection_factory + @config.establish_connections_to(server("capistrano")) + @config.establish_connections_to(server("another")) + @config.establish_connections_to(server("yafhost")) + end def test_establish_connections_to_should_accept_a_single_nonarray_parameter Capistrano::SSH.expects(:connect).with { |s,| s.host == "capistrano" }.returns(:success)