From 874fbd53dd898c325edc36ec37d0518f03bfd987 Mon Sep 17 00:00:00 2001 From: Roger Peppe Date: Thu, 25 May 2017 08:54:59 +0100 Subject: [PATCH] api: add explicit DNS cache The API connection logic in juju.NewAPIConnection implements a half-hearted DNS cache by resolving DNS addresses itself and storing them in the controllers.yaml fields. This obscures the actual addresses that have been returned from the controller. Instead of that, we make api.Open responsible for all DNS resolution and pass a DNS cache interface into DialOpts so that the NewAPIConnection logic can still do its own DNS caching. This means that the API connection logic always knows the symbolic host name of the controller that it's dialing, which is important when using public certificate validation. It also means that even when a Go client calls api.Open with a single host name, it will get the benefit of concurrent dialing to each of the resolved IP addresses. --- api/apiclient.go | 252 +++++++++++++++++++++++++--------- api/apiclient_test.go | 307 +++++++++++++++++++++++++++++++++++++----- api/export_test.go | 9 +- api/interface.go | 33 ++++- 4 files changed, 507 insertions(+), 94 deletions(-) diff --git a/api/apiclient.go b/api/apiclient.go index 69a96677813..d2710ea1014 100644 --- a/api/apiclient.go +++ b/api/apiclient.go @@ -72,6 +72,9 @@ type state struct { // addr is the address used to connect to the API server. addr string + // ipAddr is the IP address used to connect to the API server. + ipAddr string + // cookieURL is the URL that HTTP cookies for the API // will be associated with (specifically macaroon auth cookies). cookieURL *url.URL @@ -200,19 +203,12 @@ func Open(info *Info, opts DialOpts) (Connection, error) { httpc := *bakeryClient.Client bakeryClient.Client = &httpc } - apiURL, err := url.Parse(dialResult.urlStr) - if err != nil { - // This should never happen as the url would have failed during dialAPI above. 
- // However the code paths don't allow capture of the url.URL used. - return nil, errors.Trace(err) - } - apiHost := apiURL.Host // Technically when there's no CACert, we don't need this // machinery, because we could just use http.DefaultTransport // for everything, but it's easier just to leave it in place. bakeryClient.Client.Transport = &hostSwitchingTransport{ - primaryHost: apiHost, + primaryHost: dialResult.addr, primary: utils.NewHttpTLSTransport(dialResult.tlsConfig), fallback: http.DefaultTransport, } @@ -221,15 +217,16 @@ func Open(info *Info, opts DialOpts) (Connection, error) { client: client, conn: dialResult.conn, clock: opts.Clock, - addr: apiHost, + addr: dialResult.addr, + ipAddr: dialResult.ipAddr, cookieURL: &url.URL{ Scheme: "https", - Host: apiHost, + Host: dialResult.addr, Path: "/", }, pingerFacadeVersion: facadeVersions["Pinger"], serverScheme: "https", - serverRootAddress: apiHost, + serverRootAddress: dialResult.addr, // We populate the username and password before // login because, when doing HTTP requests, we'll want // to use the same username and password for authenticating @@ -522,7 +519,9 @@ func tagToString(tag names.Tag) string { // and TLS configuration used to connect to it. type dialResult struct { conn jsoncodec.JSONConn + addr string urlStr string + ipAddr string tlsConfig *tls.Config } @@ -537,9 +536,11 @@ func (c *dialResult) Close() error { // but adds some information for the local dial logic. type dialOpts struct { DialOpts - tlsConfig *tls.Config sniHostName string deadline time.Time + // certPool holds a cert pool containing the CACert + // if there is one. 
+ certPool *x509.CertPool } // dialAPI establishes a websocket connection to the RPC @@ -556,32 +557,27 @@ func dialAPI(info *Info, opts0 DialOpts) (*dialResult, error) { DialOpts: opts0, sniHostName: info.SNIHostName, } - tlsConfig := utils.SecureTLSConfig() - tlsConfig.InsecureSkipVerify = opts.InsecureSkipVerify if info.CACert != "" { - // We want to be specific here (rather than just using "anything". - // See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899). - tlsConfig.ServerName = "juju-apiserver" certPool, err := CreateCertPool(info.CACert) if err != nil { return nil, errors.Annotate(err, "cert pool creation failed") } - tlsConfig.RootCAs = certPool - } else { - // No CA certificate so use the SNI host name for all - // connections (if SNIHostName is empty, the host - // name in the address will be used as usual). - tlsConfig.ServerName = info.SNIHostName + opts.certPool = certPool } - opts.tlsConfig = tlsConfig // Set opts.DialWebsocket and opts.Clock here rather than in open because // some tests call dialAPI directly. if opts.DialWebsocket == nil { opts.DialWebsocket = gorillaDialWebsocket } + if opts.IPAddrResolver == nil { + opts.IPAddrResolver = net.DefaultResolver + } if opts.Clock == nil { opts.Clock = clock.WallClock } + if opts.DNSCache == nil { + opts.DNSCache = nopDNSCache{} + } // TODO(rogpeppe) Pass a context with an existing deadline into dialAPI. // We're avoiding that for the moment because it will break // lots of tests that rely on passing a zero timeout into @@ -604,15 +600,18 @@ func dialAPI(info *Info, opts0 DialOpts) (*dialResult, error) { } // gorillaDialWebsocket makes a websocket connection using the -// gorilla websocket package. -func gorillaDialWebsocket(ctx context.Context, urlStr string, tlsConfig *tls.Config) (jsoncodec.JSONConn, error) { +// gorilla websocket package. 
The ipAddr parameter holds the +// actual IP address that will be contacted - the host in urlStr +// is used only for TLS verification when tlsConfig.ServerName +// is empty. +func gorillaDialWebsocket(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { // TODO(rogpeppe) We'd like to set Deadline here // but that would break lots of tests that rely on // setting a zero timeout. netDialer := net.Dialer{} dialer := &websocket.Dialer{ NetDial: func(netw, addr string) (net.Conn, error) { - return netDialer.DialContext(ctx, netw, addr) + return netDialer.DialContext(ctx, netw, ipAddr) }, Proxy: proxy.DefaultConfig.GetProxy, TLSClientConfig: tlsConfig, @@ -637,17 +636,73 @@ func dialWebsocketMulti(ctx context.Context, addrs []string, path string, opts d // Dial all addresses at reasonable intervals. try := parallel.NewTry(0, nil) defer try.Kill() - for _, addr := range addrs { - err := startDialWebsocket(ctx, try, addr, path, opts) - if err == parallel.ErrStopped { + // Make a context that's cancelled when the try + // completes so that (for example) a slow DNS + // query will be cancelled if a previous try succeeds. + ctx, cancel := context.WithCancel(ctx) + go func() { + <-try.Dead() + cancel() + }() + tried := make(map[string]bool) + var cacheUsed []string + for { + if len(addrs) == 0 && len(cacheUsed) > 0 { + // We've tried all the addresses but for some + // of them we used cached values which might + // have become out of date, so retry them + // with no cache. + addrs = cacheUsed + cacheUsed = nil + opts.DNSCache = emptyDNSCache{opts.DNSCache} + } + if len(addrs) == 0 { break } + addr := addrs[0] + addrs = addrs[1:] + host, port, err := net.SplitHostPort(addr) if err != nil { - return nil, errors.Trace(err) + // Defensive - this should never happen because + // the addresses are checked with Info.Validate + // beforehand. 
+ err := errors.Errorf("invalid address %q: %v", addr, err) + recordTryError(try, err) + continue } - select { - case <-opts.Clock.After(opts.DialAddressInterval): - case <-try.Dead(): + ips := opts.DNSCache.Lookup(host) + if len(ips) > 0 { + cacheUsed = append(cacheUsed, addr) + } else if isNumericHost(host) { + ips = []string{host} + } else { + var err error + ips, err = lookupIPAddr(ctx, host, opts.IPAddrResolver) + if err != nil { + err := errors.Errorf("cannot resolve %q: %v", host, err) + recordTryError(try, err) + continue + } + opts.DNSCache.Add(host, ips) + logger.Debugf("looked up %v -> %v", host, ips) + } + for _, ip := range ips { + ipStr := net.JoinHostPort(ip, port) + if tried[ipStr] { + continue + } + tried[ipStr] = true + err := startDialWebsocket(ctx, try, ipStr, addr, path, opts) + if err == parallel.ErrStopped { + break + } + if err != nil { + return nil, errors.Trace(err) + } + select { + case <-opts.Clock.After(opts.DialAddressInterval): + case <-try.Dead(): + } } } try.Close() @@ -658,9 +713,37 @@ func dialWebsocketMulti(ctx context.Context, addrs []string, path string, opts d return result.(*dialResult), nil } +func lookupIPAddr(ctx context.Context, host string, resolver IPAddrResolver) ([]string, error) { + addrs, err := resolver.LookupIPAddr(ctx, host) + if err != nil { + return nil, errors.Trace(err) + } + ips := make([]string, 0, len(addrs)) + for _, addr := range addrs { + if addr.Zone != "" { + // Ignore IPv6 zone. Hopefully this shouldn't + // cause any problems in practice. + logger.Infof("ignoring IP address with zone %q", addr) + continue + } + ips = append(ips, addr.IP.String()) + } + return ips, nil +} + +// recordTryError starts a try that just returns the given error. +// This is so that we can use the usual Try error combination +// logic even for errors that happen before we start a try. 
+func recordTryError(try *parallel.Try, err error) { + logger.Infof("%v", err) + try.Start(func(_ <-chan struct{}) (io.Closer, error) { + return nil, errors.Trace(err) + }) +} + // startDialWebsocket starts websocket connection to a single address // on the given try instance. -func startDialWebsocket(ctx context.Context, try *parallel.Try, addr, path string, opts dialOpts) error { +func startDialWebsocket(ctx context.Context, try *parallel.Try, ipAddr, addr, path string, opts dialOpts) error { openAttempt := retry.Regular{ Total: opts.Timeout, Delay: opts.RetryDelay, @@ -671,8 +754,10 @@ func startDialWebsocket(ctx context.Context, try *parallel.Try, addr, path strin d := dialer{ ctx: ctx, openAttempt: openAttempt, - addr: addr, + serverName: opts.sniHostName, + ipAddr: ipAddr, urlStr: "wss://" + addr + path, + addr: addr, opts: opts, } return try.Start(d.dial) @@ -681,9 +766,23 @@ func startDialWebsocket(ctx context.Context, try *parallel.Try, addr, path strin type dialer struct { ctx context.Context openAttempt retry.Strategy - addr string - urlStr string - opts dialOpts + + // serverName holds the SNI name to use + // when connecting with a public certificate. + serverName string + + // addr holds the host:port that is being dialed. + addr string + + // ipAddr holds the ipaddr:port (one of the addresses + // that addr resolves to) that is being dialed. + ipAddr string + + // urlStr holds the URL that is being dialed. + urlStr string + + // opts holds the dial options. + opts dialOpts } // dial implements the function value expected by Try.Start @@ -696,6 +795,8 @@ func (d dialer) dial(_ <-chan struct{}) (io.Closer, error) { if err == nil { return &dialResult{ conn: conn, + addr: d.addr, + ipAddr: d.ipAddr, urlStr: d.urlStr, tlsConfig: tlsConfig, }, nil @@ -711,23 +812,31 @@ func (d dialer) dial(_ <-chan struct{}) (io.Closer, error) { // dial1 makes a single dial attempt. 
func (d dialer) dial1() (jsoncodec.JSONConn, *tls.Config, error) { - conn, err := d.opts.DialWebsocket(d.ctx, d.urlStr, d.opts.tlsConfig) + tlsConfig := utils.SecureTLSConfig() + tlsConfig.InsecureSkipVerify = d.opts.InsecureSkipVerify + if d.opts.certPool != nil { + // We want to be specific here (rather than just using "anything"). + // See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899). + tlsConfig.RootCAs = d.opts.certPool + tlsConfig.ServerName = "juju-apiserver" + } else { + tlsConfig.ServerName = d.serverName + } + conn, err := d.opts.DialWebsocket(d.ctx, d.urlStr, tlsConfig, d.ipAddr) if err == nil { logger.Debugf("successfully dialed %q", d.urlStr) - return conn, d.opts.tlsConfig, nil + return conn, tlsConfig, nil } if !isX509Error(err) { return nil, nil, errors.Trace(err) } - if (d.opts.sniHostName == "" && isNumericIP(d.addr)) || d.opts.tlsConfig.RootCAs == nil { - // We're trying to connect to a numeric IP address with - // no other server name available, or we're using public - // certificate validation. In the former case, using - // public cert validation won't help, because you - // generally can't obtain a public cert for a numeric IP - // address. In the latter case, we either don't have the - // private CA cert or we've already tried it. In both - // those cases, we won't succeed when trying again + if tlsConfig.RootCAs == nil || d.serverName == "" { + // There's no private certificate or we don't have a + // public hostname. In the former case, we've already + // tried public certificates; in the latter, public cert + // validation won't help, because you generally can't + // obtain a public cert for a numeric IP address. In + // both those cases, we won't succeed when trying again // because a cert error isn't temporary, so return // immediately. 
// @@ -738,23 +847,20 @@ func (d dialer) dial1() (jsoncodec.JSONConn, *tls.Config, error) { } // It's possible we're inappropriately using the private // CA certificate, so retry immediately with the public one. - tlsConfig1 := *d.opts.tlsConfig - tlsConfig1.RootCAs = nil - tlsConfig1.ServerName = d.opts.sniHostName - conn, rootCAErr := d.opts.DialWebsocket(d.ctx, d.urlStr, &tlsConfig1) + tlsConfig.RootCAs = nil + tlsConfig.ServerName = d.serverName + conn, rootCAErr := d.opts.DialWebsocket(d.ctx, d.urlStr, tlsConfig, d.ipAddr) if rootCAErr != nil { logger.Debugf("failed to dial websocket using fallback public CA: %v", rootCAErr) // We return the original error as it's usually more meaningful. return nil, nil, errors.Trace(err) } - return conn, &tlsConfig1, nil + return conn, tlsConfig, nil } -func isNumericIP(addr string) bool { - host, _, err := net.SplitHostPort(addr) - if err != nil { - return false - } +// isNumericHost reports whether the given host name is +// a numeric IP address. +func isNumericHost(host string) bool { return net.ParseIP(host) != nil } @@ -843,6 +949,12 @@ func (s *state) Addr() string { return s.addr } +// IPAddr returns the resolved IP address that was used to +// connect to the API server. +func (s *state) IPAddr() string { + return s.ipAddr +} + // ModelTag implements base.APICaller.ModelTag. func (s *state) ModelTag() (names.ModelTag, bool) { return s.modelTag, s.modelTag.Id() != "" @@ -911,3 +1023,23 @@ func (s *state) isLoggedIn() bool { func (s *state) setLoggedIn() { atomic.StoreInt32(&s.loggedIn, 1) } + +// emptyDNSCache implements DNSCache by +// never returning any entries but writing any +// added entries to the embedded DNSCache object. 
+type emptyDNSCache struct { + DNSCache +} + +func (emptyDNSCache) Lookup(host string) []string { + return nil +} + +type nopDNSCache struct{} + +func (nopDNSCache) Lookup(host string) []string { + return nil +} + +func (nopDNSCache) Add(host string, ips []string) { +} diff --git a/api/apiclient_test.go b/api/apiclient_test.go index 2c89581771e..c747a8f0586 100644 --- a/api/apiclient_test.go +++ b/api/apiclient_test.go @@ -7,6 +7,7 @@ import ( "context" "crypto/tls" "crypto/x509" + "fmt" "net" "sync" "sync/atomic" @@ -76,6 +77,7 @@ func (s *apiclientSuite) TestDialAPIMultiple(c *gc.C) { // Now break Addrs[0], and ensure that Addrs[1] // is successfully connected to. proxy.Close() + info.Addrs = []string{proxy.Addr(), serverAddr} conn, location, err = api.DialAPI(info, api.DialOpts{}) c.Assert(err, jc.ErrorIsNil) @@ -84,25 +86,29 @@ func (s *apiclientSuite) TestDialAPIMultiple(c *gc.C) { } func (s *apiclientSuite) TestDialAPIMultipleError(c *gc.C) { - listener, err := net.Listen("tcp", "127.0.0.1:0") - c.Assert(err, jc.ErrorIsNil) - defer listener.Close() + var addrs []string + // count holds the number of times we've accepted a connection. 
var count int32 - go func() { - for { - client, err := listener.Accept() - if err != nil { - return + for i := 0; i < 3; i++ { + listener, err := net.Listen("tcp", "127.0.0.1:0") + c.Assert(err, jc.ErrorIsNil) + defer listener.Close() + addrs = append(addrs, listener.Addr().String()) + go func() { + for { + client, err := listener.Accept() + if err != nil { + return + } + atomic.AddInt32(&count, 1) + client.Close() } - atomic.AddInt32(&count, 1) - client.Close() - } - }() + }() + } info := s.APIInfo(c) - addr := listener.Addr().String() - info.Addrs = []string{addr, addr, addr} - _, _, err = api.DialAPI(info, api.DialOpts{}) + info.Addrs = addrs + _, _, err := api.DialAPI(info, api.DialOpts{}) c.Assert(err, gc.ErrorMatches, `unable to connect to API: .*`) c.Assert(atomic.LoadInt32(&count), gc.Equals, int32(3)) } @@ -175,7 +181,7 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { replyc chan<- dialResponse } dialed := make(chan dialInfo) - fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config) (jsoncodec.JSONConn, error) { + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { reply := make(chan dialResponse) dialed <- dialInfo{ ctx: ctx, @@ -185,7 +191,7 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { r := <-reply return r.conn, nil } - conn0 := &fakeConn{} + conn0 := fakeConn{} clock := testing.NewClock(time.Now()) openDone := make(chan struct{}) const dialAddressInterval = 50 * time.Millisecond @@ -193,8 +199,8 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { defer close(openDone) conn, err := api.Open(&api.Info{ Addrs: []string{ - "place0.example:1234", "place1.example:1234", + "place2.example:1234", }, SkipLogin: true, CACert: jtesting.CACert, @@ -204,6 +210,10 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { DialAddressInterval: dialAddressInterval, 
DialWebsocket: fakeDialer, Clock: clock, + IPAddrResolver: fakeResolver{ + "place1.example": {"0.1.1.1"}, + "place2.example": {"0.2.2.2"}, + }, }) c.Check(api.UnderlyingConn(conn), gc.Equals, conn0) c.Check(err, jc.ErrorIsNil) @@ -219,7 +229,7 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { case <-time.After(jtesting.LongWait): c.Fatalf("timed out waiting for dial") } - c.Assert(info0.location, gc.Equals, "wss://place0.example:1234/api") + c.Assert(info0.location, gc.Equals, "wss://place1.example:1234/api") var info1 dialInfo // Wait for the next dial to be made. @@ -231,7 +241,7 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { case <-time.After(jtesting.LongWait): c.Fatalf("timed out waiting for dial") } - c.Assert(info1.location, gc.Equals, "wss://place1.example:1234/api") + c.Assert(info1.location, gc.Equals, "wss://place2.example:1234/api") // Allow the first dial to succeed. info0.replyc <- dialResponse{ @@ -253,7 +263,7 @@ func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) { case <-time.After(jtesting.LongWait): c.Fatalf("timed out waiting for context to be closed") } - conn1 := &fakeConn{ + conn1 := fakeConn{ closed: make(chan struct{}), } // Allow the second dial to succeed. 
@@ -305,13 +315,13 @@ var openWithSNIHostnameTests = []struct { }, { about: "with cert; DNS name - use cert", info: &api.Info{ - Addrs: []string{"foo.com:1234"}, + Addrs: []string{"0.1.1.1:1234"}, SNIHostName: "foo.com", SkipLogin: true, CACert: jtesting.CACert, }, expectDial: apiDialInfo{ - location: "wss://foo.com:1234/api", + location: "wss://0.1.1.1:1234/api", hasRootCAs: true, serverName: "juju-apiserver", }, @@ -431,7 +441,7 @@ type dialInfo struct { func (s *apiclientSuite) testOpenDialError(c *gc.C, t dialTest) { dialed := make(chan dialInfo) - fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config) (jsoncodec.JSONConn, error) { + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { reply := make(chan error) dialed <- dialInfo{ location: urlStr, @@ -444,10 +454,11 @@ func (s *apiclientSuite) testOpenDialError(c *gc.C, t dialTest) { go func() { defer close(done) conn, err := api.Open(t.apiInfo, api.DialOpts{ - Timeout: 5 * time.Second, - RetryDelay: 1 * time.Second, - DialWebsocket: fakeDialer, - Clock: &fakeClock{}, + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: seqResolver(t.apiInfo.Addrs...), + Clock: &fakeClock{}, }) c.Check(conn, gc.Equals, nil) c.Check(err, gc.ErrorMatches, t.expectOpenError) @@ -561,6 +572,185 @@ func (s *apiclientSuite) TestOpenWithRedirect(c *gc.C) { }) } +func (s *apiclientSuite) TestOpenCachesDNS(c *gc.C) { + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + return fakeConn{}, nil + } + dnsCache := make(dnsCacheMap) + conn, err := api.Open(&api.Info{ + Addrs: []string{ + "place1.example:1234", + }, + SkipLogin: true, + CACert: jtesting.CACert, + }, api.DialOpts{ + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: fakeResolver{ + "place1.example": 
{"0.1.1.1"}, + }, + DNSCache: dnsCache, + Clock: &fakeClock{}, + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(conn, gc.NotNil) + c.Assert(dnsCache.Lookup("place1.example"), jc.DeepEquals, []string{"0.1.1.1"}) +} + +func (s *apiclientSuite) TestDNSCacheUsed(c *gc.C) { + var dialed string + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + dialed = ipAddr + return fakeConn{}, nil + } + conn, err := api.Open(&api.Info{ + Addrs: []string{ + "place1.example:1234", + }, + SkipLogin: true, + CACert: jtesting.CACert, + }, api.DialOpts{ + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: fakeResolver{ + "place1.example": {"0.2.2.2"}, + }, + DNSCache: dnsCacheMap{ + "place1.example": {"0.1.1.1"}, + }, + Clock: &fakeClock{}, + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(conn, gc.NotNil) + // The dialed IP address should have come from the cache, not the IP address + // resolver. 
+ c.Assert(dialed, gc.Equals, "0.1.1.1:1234") + c.Assert(conn.IPAddr(), gc.Equals, "0.1.1.1:1234") +} + +func (s *apiclientSuite) TestNumericAddressIsNotAddedToCache(c *gc.C) { + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + return fakeConn{}, nil + } + dnsCache := make(dnsCacheMap) + conn, err := api.Open(&api.Info{ + Addrs: []string{ + "0.1.2.3:1234", + }, + SkipLogin: true, + CACert: jtesting.CACert, + }, api.DialOpts{ + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: fakeResolver{}, + DNSCache: dnsCache, + Clock: &fakeClock{}, + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(conn, gc.NotNil) + c.Assert(conn.Addr(), gc.Equals, "0.1.2.3:1234") + c.Assert(conn.IPAddr(), gc.Equals, "0.1.2.3:1234") + c.Assert(dnsCache, gc.HasLen, 0) +} + +func (s *apiclientSuite) TestFallbackToIPLookupWhenCacheOutOfDate(c *gc.C) { + var mu sync.Mutex + dialed := make(map[string]bool) + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + mu.Lock() + defer mu.Unlock() + dialed[ipAddr] = true + if ipAddr == "0.2.2.2:1234" { + return fakeConn{}, nil + } + return nil, errors.Errorf("bad address") + } + dnsCache := dnsCacheMap{ + "place1.example": {"0.1.1.1"}, + } + conn, err := api.Open(&api.Info{ + Addrs: []string{ + "place1.example:1234", + }, + SkipLogin: true, + CACert: jtesting.CACert, + }, api.DialOpts{ + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: fakeResolver{ + "place1.example": {"0.2.2.2"}, + }, + DNSCache: dnsCache, + Clock: &fakeClock{}, + }) + c.Assert(err, jc.ErrorIsNil) + c.Assert(conn, gc.NotNil) + mu.Lock() + defer mu.Unlock() + c.Assert(dialed, jc.DeepEquals, map[string]bool{ + "0.2.2.2:1234": true, + "0.1.1.1:1234": true, + }) + c.Assert(dnsCache.Lookup("place1.example"), jc.DeepEquals, []string{"0.2.2.2"}) 
+} + +func (s *apiclientSuite) TestWithUnresolvableAddr(c *gc.C) { + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + c.Errorf("dial was called but should not have been") + return nil, errors.Errorf("cannot dial") + } + conn, err := api.Open(&api.Info{ + Addrs: []string{ + "nowhere.example:1234", + }, + SkipLogin: true, + CACert: jtesting.CACert, + }, api.DialOpts{ + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: fakeResolver{}, + Clock: &fakeClock{}, + }) + c.Assert(err, gc.ErrorMatches, `cannot resolve "nowhere.example": mock resolver cannot resolve "nowhere.example"`) + c.Assert(conn, jc.ErrorIsNil) +} + +func (s *apiclientSuite) TestWithUnresolvableAddrAfterCacheFallback(c *gc.C) { + fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { + if ipAddr == "0.2.2.2:1234" { + return nil, errors.Errorf("cannot connect with real address") + } + return nil, errors.Errorf("bad address from cache") + } + dnsCache := dnsCacheMap{ + "place1.example": {"0.1.1.1"}, + } + conn, err := api.Open(&api.Info{ + Addrs: []string{ + "place1.example:1234", + }, + SkipLogin: true, + CACert: jtesting.CACert, + }, api.DialOpts{ + Timeout: 5 * time.Second, + RetryDelay: 1 * time.Second, + DialWebsocket: fakeDialer, + IPAddrResolver: fakeResolver{ + "place1.example": {"0.2.2.2"}, + }, + DNSCache: dnsCache, + Clock: &fakeClock{}, + }) + c.Assert(err, gc.ErrorMatches, `unable to connect to API: cannot connect with real address`) + c.Assert(conn, gc.Equals, nil) + c.Assert(dnsCache.Lookup("place1.example"), jc.DeepEquals, []string{"0.2.2.2"}) +} + func (s *apiclientSuite) TestAPICallNoError(c *gc.C) { clock := &fakeClock{} conn := api.NewTestingState(api.TestingStateParams{ @@ -773,15 +963,68 @@ type fakeConn struct { closed chan struct{} } -func (c *fakeConn) Receive(x interface{}) error { 
+func (c fakeConn) Receive(x interface{}) error { return errors.New("no data available from fake connection") } -func (c *fakeConn) Send(x interface{}) error { +func (c fakeConn) Send(x interface{}) error { return errors.New("cannot write to fake connection") } -func (c *fakeConn) Close() error { - close(c.closed) +func (c fakeConn) Close() error { + if c.closed != nil { + close(c.closed) + } return nil } + +// fakeResolver implements IPAddrResolver +// by looking up the addresses in the map, +// which maps host names to IP addresses. +type fakeResolver map[string][]string + +func (r fakeResolver) LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) { + if ip := net.ParseIP(host); ip != nil { + return []net.IPAddr{{IP: ip}}, nil + } + ipStrs := r[host] + if len(ipStrs) == 0 { + return nil, errors.Errorf("mock resolver cannot resolve %q", host) + } + ipAddrs := make([]net.IPAddr, len(ipStrs)) + for i, ipStr := range ipStrs { + ip := net.ParseIP(ipStr) + if ip == nil { + panic("invalid IP address: " + ipStr) + } + ipAddrs[i] = net.IPAddr{ + IP: ip, + } + } + return ipAddrs, nil +} + +// seqResolver returns an implementation of +// IPAddrResolver that maps the given addresses +// to sequential IP addresses 0.1.1.1, 0.2.2.2, etc. +func seqResolver(addrs ...string) api.IPAddrResolver { + r := make(fakeResolver) + for i, addr := range addrs { + host, _, err := net.SplitHostPort(addr) + if err != nil { + panic(err) + } + r[host] = []string{fmt.Sprintf("0.%[1]d.%[1]d.%[1]d", i+1)} + } + return r +} + +type dnsCacheMap map[string][]string + +func (m dnsCacheMap) Lookup(host string) []string { + return m[host] +} + +func (m dnsCacheMap) Add(host string, ips []string) { + m[host] = append([]string{}, ips...) 
+} diff --git a/api/export_test.go b/api/export_test.go index 2d515e5baea..bdfead89a4e 100644 --- a/api/export_test.go +++ b/api/export_test.go @@ -4,6 +4,8 @@ package api import ( + "net/url" + "github.com/juju/errors" "github.com/juju/utils/clock" "gopkg.in/juju/names.v2" @@ -28,7 +30,12 @@ func DialAPI(info *Info, opts DialOpts) (jsoncodec.JSONConn, string, error) { if err != nil { return nil, "", err } - return result.conn, result.urlStr, nil + // Replace the IP address in the URL with the + // host name so that tests can check it more + // easily. + u, _ := url.Parse(result.urlStr) + u.Host = result.addr + return result.conn, u.String(), nil } // RPCConnection defines the methods that are called on the rpc.Conn instance. diff --git a/api/interface.go b/api/interface.go index fdfaa8a7dde..f8fc2bb23fb 100644 --- a/api/interface.go +++ b/api/interface.go @@ -6,6 +6,7 @@ package api import ( "context" "crypto/tls" + "net" "net/url" "time" @@ -148,16 +149,43 @@ type DialOpts struct { // DialWebsocket is used to make connections to API servers. // It will be called with a websocket URL to connect to, // and the TLS configuration to use to secure the connection. + // If ipAddr is non-empty, the actual net.Dial should use + // that IP address, regardless of the URL host. // // If DialWebsocket is nil, a default implementation using // gorilla websockets will be used. - DialWebsocket func(ctx context.Context, urlStr string, tlsConfig *tls.Config) (jsoncodec.JSONConn, error) + DialWebsocket func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) + + // IPAddrResolver is used to resolve host names to IP addresses. + // If it is nil, net.DefaultResolver will be used. + IPAddrResolver IPAddrResolver + + // DNSCache is consulted to find and store cached DNS lookups. + // If it is nil, no cache will be used or updated. + DNSCache DNSCache // Clock is used as a time source for retries. 
// If it is nil, clock.WallClock will be used. Clock clock.Clock } +// IPAddrResolver implements a resolver from host name to the +// set of IP addresses associated with it. It is notably +// implemented by net.Resolver. +type IPAddrResolver interface { + LookupIPAddr(ctx context.Context, host string) ([]net.IPAddr, error) +} + +// DNSCache implements a cache of DNS lookup results. +type DNSCache interface { + // Lookup returns the IP addresses associated + // with the given host. + Lookup(host string) []string + // Add sets the IP addresses associated with + // the given host name. + Add(host string, ips []string) +} + // DefaultDialOpts returns a DialOpts representing the default // parameters for contacting a controller. func DefaultDialOpts() DialOpts { @@ -186,6 +214,9 @@ type Connection interface { // Addr returns the address used to connect to the API server. Addr() string + // IPAddr returns the IP address used to connect to the API server. + IPAddr() string + // APIHostPorts returns addresses that may be used to connect // to the API server, including the address used to connect. //