From 7f5aaa3148f89299aa8622951cdd4b74214abb76 Mon Sep 17 00:00:00 2001 From: James Rouzier Date: Fri, 9 Sep 2022 19:13:53 +0000 Subject: [PATCH] Init commit for custom forward plugin --- go/cmd/pfdns/plugin.go | 3 +- go/plugin/coredns/forward/README.md | 265 ++++++++++++++ go/plugin/coredns/forward/connect.go | 152 ++++++++ go/plugin/coredns/forward/dnstap.go | 63 ++++ go/plugin/coredns/forward/forward.go | 239 +++++++++++++ go/plugin/coredns/forward/forward_test.go | 24 ++ go/plugin/coredns/forward/fuzz.go | 34 ++ go/plugin/coredns/forward/health.go | 106 ++++++ go/plugin/coredns/forward/health_test.go | 283 +++++++++++++++ go/plugin/coredns/forward/log_test.go | 5 + go/plugin/coredns/forward/metrics.go | 61 ++++ go/plugin/coredns/forward/persistent.go | 161 +++++++++ go/plugin/coredns/forward/persistent_test.go | 109 ++++++ go/plugin/coredns/forward/policy.go | 68 ++++ go/plugin/coredns/forward/proxy.go | 82 +++++ go/plugin/coredns/forward/proxy_test.go | 99 ++++++ go/plugin/coredns/forward/setup.go | 292 +++++++++++++++ .../coredns/forward/setup_policy_test.go | 47 +++ go/plugin/coredns/forward/setup_test.go | 334 ++++++++++++++++++ go/plugin/coredns/forward/type.go | 37 ++ 20 files changed, 2463 insertions(+), 1 deletion(-) create mode 100644 go/plugin/coredns/forward/README.md create mode 100644 go/plugin/coredns/forward/connect.go create mode 100644 go/plugin/coredns/forward/dnstap.go create mode 100644 go/plugin/coredns/forward/forward.go create mode 100644 go/plugin/coredns/forward/forward_test.go create mode 100644 go/plugin/coredns/forward/fuzz.go create mode 100644 go/plugin/coredns/forward/health.go create mode 100644 go/plugin/coredns/forward/health_test.go create mode 100644 go/plugin/coredns/forward/log_test.go create mode 100644 go/plugin/coredns/forward/metrics.go create mode 100644 go/plugin/coredns/forward/persistent.go create mode 100644 go/plugin/coredns/forward/persistent_test.go create mode 100644 go/plugin/coredns/forward/policy.go create mode 100644 go/plugin/coredns/forward/proxy.go create mode 100644 go/plugin/coredns/forward/proxy_test.go create mode 100644 go/plugin/coredns/forward/setup.go create mode 100644 go/plugin/coredns/forward/setup_policy_test.go create mode 100644 go/plugin/coredns/forward/setup_test.go create mode 100644 go/plugin/coredns/forward/type.go diff --git a/go/cmd/pfdns/plugin.go b/go/cmd/pfdns/plugin.go index 1b9deae9f3cb..9bd83b4ba110 100644 --- a/go/cmd/pfdns/plugin.go +++ b/go/cmd/pfdns/plugin.go @@ -20,7 +20,6 @@ import ( _ "github.com/coredns/coredns/plugin/erratic" _ "github.com/coredns/coredns/plugin/errors" _ "github.com/coredns/coredns/plugin/file" - _ "github.com/coredns/coredns/plugin/forward" _ "github.com/coredns/coredns/plugin/geoip" _ "github.com/coredns/coredns/plugin/grpc" _ "github.com/coredns/coredns/plugin/header" @@ -50,5 +49,7 @@ import ( _ "github.com/coredns/coredns/plugin/transfer" _ "github.com/coredns/coredns/plugin/tsig" _ "github.com/coredns/coredns/plugin/whoami" + _ "github.com/inverse-inc/packetfence/go/plugin/coredns/pfdns" + _ "github.com/inverse-inc/packetfence/go/plugin/coredns/forward" ) diff --git a/go/plugin/coredns/forward/README.md b/go/plugin/coredns/forward/README.md new file mode 100644 index 000000000000..0088c9c7cb39 --- /dev/null +++ b/go/plugin/coredns/forward/README.md @@ -0,0 +1,265 @@ +# forward + +## Name + +*forward* - facilitates proxying DNS messages to upstream resolvers. + +## Description + +The *forward* plugin re-uses already opened sockets to the upstreams. 
It supports UDP, TCP and
+DNS-over-TLS and uses in-band health checking.
+
+When it detects an error, a health check is performed. This check runs in a loop, performing each
+check at a *0.5s* interval for as long as the upstream reports unhealthy. Once healthy we stop
+health checking (until the next error). The health checks use a recursive DNS query (`. IN NS`)
+to get upstream health. Any response that is not a network error (REFUSED, NOTIMPL, SERVFAIL, etc.)
+is taken as a healthy upstream. The health check uses the same protocol as specified in **TO**. If
+`max_fails` is set to 0, no checking is performed and upstreams will always be considered healthy.
+
+When *all* upstreams are down, it assumes health checking as a mechanism has failed and will try to
+connect to a random upstream (which may or may not work).
+
+## Syntax
+
+In its most basic form, a simple forwarder uses this syntax:
+
+~~~
+forward FROM TO...
+~~~
+
+* **FROM** is the base domain to match for the request to be forwarded. Domains using CIDR notation
+  that expand to multiple reverse zones are not fully supported; only the first expanded zone is used.
+* **TO...** are the destination endpoints to forward to. The **TO** syntax allows you to specify
+  a protocol, `tls://9.9.9.9` or `dns://` (or no protocol) for plain DNS. The number of upstreams is
+  limited to 15.
+
+Multiple upstreams are randomized (see `policy`) on first use. When a healthy proxy returns an error
+during the exchange, the next upstream in the list is tried.
+
+Extra knobs are available with an expanded syntax:
+
+~~~
+forward FROM TO... {
+    except IGNORED_NAMES...
+    force_tcp
+    prefer_udp
+    expire DURATION
+    max_fails INTEGER
+    tls CERT KEY CA
+    tls_servername NAME
+    policy random|round_robin|sequential
+    health_check DURATION [no_rec] [domain FQDN]
+    max_concurrent MAX
+}
+~~~
+
+* **FROM** and **TO...** as above.
+* **IGNORED_NAMES** in `except` is a space-separated list of domains to exclude from forwarding.
+  Requests that match one of these names are not forwarded and are handled by the next plugin in the chain.
+* `force_tcp`, use TCP even when the request comes in over UDP.
+* `prefer_udp`, try first using UDP even when the request comes in over TCP. If the response is
+  truncated (TC flag set in the response) then another attempt is made over TCP. If both `force_tcp`
+  and `prefer_udp` are specified, `force_tcp` takes precedence.
+* `max_fails` is the number of subsequent failed health checks that are needed before considering
+  an upstream to be down. If 0, the upstream will never be marked as down (nor health checked).
+  Default is 2.
+* `expire` **DURATION**, expire (cached) connections after this time; the default is 10s.
+* `tls` **CERT** **KEY** **CA** defines the TLS properties for the TLS connection. From 0 to 3 arguments can be
+  provided with the meanings described below:
+
+  * `tls` - no client authentication is used, and the system CAs are used to verify the server certificate
+  * `tls` **CA** - no client authentication is used, and the file CA is used to verify the server certificate
+  * `tls` **CERT** **KEY** - client authentication is used with the specified cert/key pair.
+    The server certificate is verified with the system CAs
+  * `tls` **CERT** **KEY** **CA** - client authentication is used with the specified cert/key pair.
+    The server certificate is verified using the specified CA file
+
+* `tls_servername` **NAME** allows you to set a server name in the TLS configuration; for instance 9.9.9.9
+  needs this to be set to `dns.quad9.net`. Multiple upstreams are still allowed in this scenario,
+  but they have to use the same `tls_servername`. E.g. mixing 9.9.9.9 (Quad9) with 1.1.1.1
+  (Cloudflare) will not work. Using TLS forwarding without setting `tls_servername` allows anyone
+  to man-in-the-middle your connection to the DNS server you are forwarding to. Because of this,
+  it is strongly recommended to set this value when using TLS forwarding.
+* `policy` specifies the policy to use for selecting upstream servers. The default is `random`.
+  * `random` is a policy that implements random upstream selection.
+  * `round_robin` is a policy that selects hosts based on round robin ordering.
+  * `sequential` is a policy that selects hosts based on sequential ordering.
+* `health_check` configures the behaviour of health checking of the upstream servers:
+  * **DURATION** - use a different duration for health checking; the default duration is 0.5s.
+  * `no_rec` - optional argument that sets the RecursionDesired flag of the DNS query used in
+    health checking to `false`. The flag defaults to `true`.
+  * `domain FQDN` - set the domain name used for health checks to **FQDN**.
+    If not configured, the domain name used for health checks is `.`.
+* `max_concurrent` **MAX** will limit the number of concurrent queries to **MAX**. Any new query that would
+  raise the number of concurrent queries above **MAX** will result in a REFUSED response. This
+  response does not count as a health failure. When choosing a value for **MAX**, pick a number
+  at least greater than the expected *upstream query rate* multiplied by the *latency* of the upstream servers.
+  As an upper bound for **MAX**, consider that each concurrent query will use about 2kb of memory.
+
+Also note that the TLS config is "global" for the whole forwarding proxy; if you need a different
+`tls_servername` for different upstreams, you're out of luck (but see the multi-listener example below).
+
+On each endpoint, the timeouts for communication are set as follows:
+
+* The dial timeout by default is 30s, and can decrease automatically down to 1s based on early results.
+* The read timeout is static at 2s.
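+
+To illustrate several of these knobs together (the addresses and the `internal.example.com` zone
+below are hypothetical), the following stanza forwards everything except `internal.example.com` to
+two upstreams, prefers UDP, selects upstreams round-robin, and caps concurrent queries:
+
+~~~ corefile
+. {
+    forward . 10.0.0.10 10.0.0.11 {
+        except internal.example.com
+        prefer_udp
+        policy round_robin
+        max_fails 3
+        max_concurrent 1000
+    }
+}
+~~~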
+
+## Metadata
+
+The forward plugin will publish the following metadata, if the *metadata*
+plugin is also enabled:
+
+* `forward/upstream`: the upstream used to forward the request
+
+## Metrics
+
+If monitoring is enabled (via the *prometheus* plugin) then the following metrics are exported:
+
+* `coredns_forward_requests_total{to}` - query count per upstream.
+* `coredns_forward_responses_total{to, rcode}` - counter of responses received per upstream and RCODE.
+* `coredns_forward_request_duration_seconds{to, rcode}` - duration per upstream and RCODE.
+* `coredns_forward_healthcheck_failures_total{to}` - number of failed health checks per upstream.
+* `coredns_forward_healthcheck_broken_total{}` - counter of when all upstreams are unhealthy,
+  and we are randomly (this always uses the `random` policy) spraying to an upstream.
+* `coredns_forward_max_concurrent_rejects_total{}` - counter of the number of queries rejected because the
+  number of concurrent queries was at maximum.
+* `coredns_forward_conn_cache_hits_total{to, proto}` - counter of connection cache hits per upstream and protocol.
+* `coredns_forward_conn_cache_misses_total{to, proto}` - counter of connection cache misses per upstream and protocol.
+
+Where `to` is one of the upstream servers (**TO** from the config), `rcode` is the RCODE returned
+by the upstream, and `proto` is the transport protocol (`udp`, `tcp`, or `tcp-tls`).
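+
+These metrics only appear when the *prometheus* plugin is enabled in the same server block; a
+minimal sketch (the listen address shown is the conventional default):
+
+~~~ corefile
+. {
+    prometheus localhost:9153
+    forward . 8.8.8.8
+}
+~~~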
+
+## Examples
+
+Proxy all requests within `example.org.` to a nameserver running on a different port:
+
+~~~ corefile
+example.org {
+    forward . 127.0.0.1:9005
+}
+~~~
+
+Send all requests within `lab.example.local.` to `10.20.0.1`, all requests within `example.local.` (and not in
+`lab.example.local.`) to `10.0.0.1`, all other requests to the servers defined in `/etc/resolv.conf`, and
+cache results. Note that a CoreDNS server configured with multiple _forward_ plugins in a server block will evaluate those
+forward plugins in the order they are listed when serving a request. Therefore, subdomains should be
+placed before parent domains, otherwise subdomain requests will be forwarded to the parent domain's upstream.
+Accordingly, in this example `lab.example.local` is before `example.local`, and `example.local` is before `.`.
+
+~~~ corefile
+. {
+    cache
+    forward lab.example.local 10.20.0.1
+    forward example.local 10.0.0.1
+    forward . /etc/resolv.conf
+}
+~~~
+
+The example above is almost equivalent to the following example, except that the example below defines three separate plugin
+chains (and thus 3 separate instances of _cache_).
+
+~~~ corefile
+lab.example.local {
+    cache
+    forward . 10.20.0.1
+}
+example.local {
+    cache
+    forward . 10.0.0.1
+}
+. {
+    cache
+    forward . /etc/resolv.conf
+}
+~~~
+
+Load balance all requests between three resolvers, one of which has an IPv6 address.
+
+~~~ corefile
+. {
+    forward . 10.0.0.10:53 10.0.0.11:1053 [2003::1]:53
+}
+~~~
+
+Forward everything except requests to `example.org`:
+
+~~~ corefile
+. {
+    forward . 10.0.0.10:1234 {
+        except example.org
+    }
+}
+~~~
+
+Proxy everything except `example.org` using the host's `resolv.conf`'s nameservers:
+
+~~~ corefile
+. {
+    forward . /etc/resolv.conf {
+        except example.org
+    }
+}
+~~~
+
+Proxy all requests to 9.9.9.9 using the DNS-over-TLS (DoT) protocol, and cache every answer for up to 30
+seconds. Note the `tls_servername` is mandatory if you want a working setup, as 9.9.9.9 can't be
+used in the TLS negotiation. Also set the health check duration to 5s to not completely swamp the
+service with health checks.
+
+~~~ corefile
+. {
+    forward . tls://9.9.9.9 {
+        tls_servername dns.quad9.net
+        health_check 5s
+    }
+    cache 30
+}
+~~~
+
+Or configure another domain name for health-check requests:
+
+~~~ corefile
+. {
+    forward . tls://9.9.9.9 {
+        tls_servername dns.quad9.net
+        health_check 5s domain example.org
+    }
+    cache 30
+}
+~~~
+
+Or use multiple upstreams from the same provider:
+
+~~~ corefile
+. {
+    forward . tls://1.1.1.1 tls://1.0.0.1 {
+        tls_servername cloudflare-dns.com
+        health_check 5s
+    }
+    cache 30
+}
+~~~
+
+Or when you have multiple DoT upstreams with different `tls_servername`s, you can do the following:
+
+~~~ corefile
+. {
+    forward . 127.0.0.1:5301 127.0.0.1:5302
+}
+
+.:5301 {
+    forward . 8.8.8.8 8.8.4.4 {
+        tls_servername dns.google
+    }
+}
+
+.:5302 {
+    forward . 1.1.1.1 1.0.0.1 {
+        tls_servername cloudflare-dns.com
+    }
+}
+~~~
+
+## See Also
+
+[RFC 7858](https://tools.ietf.org/html/rfc7858) for DNS over TLS.
diff --git a/go/plugin/coredns/forward/connect.go b/go/plugin/coredns/forward/connect.go
new file mode 100644
index 000000000000..3d53044e5efe
--- /dev/null
+++ b/go/plugin/coredns/forward/connect.go
@@ -0,0 +1,152 @@
+// Package forward implements a forwarding proxy. 
It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. +package forward + +import ( + "context" + "io" + "strconv" + "sync/atomic" + "time" + + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +// limitTimeout is a utility function to auto-tune timeout values +// average observed time is moved towards the last observed delay moderated by a weight +// next timeout to use will be the double of the computed average, limited by min and max frame. +func limitTimeout(currentAvg *int64, minValue time.Duration, maxValue time.Duration) time.Duration { + rt := time.Duration(atomic.LoadInt64(currentAvg)) + if rt < minValue { + return minValue + } + if rt < maxValue/2 { + return 2 * rt + } + return maxValue +} + +func averageTimeout(currentAvg *int64, observedDuration time.Duration, weight int64) { + dt := time.Duration(atomic.LoadInt64(currentAvg)) + atomic.AddInt64(currentAvg, int64(observedDuration-dt)/weight) +} + +func (t *Transport) dialTimeout() time.Duration { + return limitTimeout(&t.avgDialTime, minDialTimeout, maxDialTimeout) +} + +func (t *Transport) updateDialTimeout(newDialTime time.Duration) { + averageTimeout(&t.avgDialTime, newDialTime, cumulativeAvgWeight) +} + +// Dial dials the address configured in transport, potentially reusing a connection or creating a new one. +func (t *Transport) Dial(proto string) (*persistConn, bool, error) { + // If tls has been configured; use it. + if t.tlsConfig != nil { + proto = "tcp-tls" + } + + t.dial <- proto + pc := <-t.ret + + if pc != nil { + ConnCacheHitsCount.WithLabelValues(t.addr, proto).Add(1) + return pc, true, nil + } + ConnCacheMissesCount.WithLabelValues(t.addr, proto).Add(1) + + reqTime := time.Now() + timeout := t.dialTimeout() + if proto == "tcp-tls" { + conn, err := dns.DialTimeoutWithTLS("tcp", t.addr, t.tlsConfig, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err + } + conn, err := dns.DialTimeout(proto, t.addr, timeout) + t.updateDialTimeout(time.Since(reqTime)) + return &persistConn{c: conn}, false, err +} + +// Connect selects an upstream, sends the request and waits for a response. +func (p *Proxy) Connect(ctx context.Context, state request.Request, opts options) (*dns.Msg, error) { + start := time.Now() + + proto := "" + switch { + case opts.forceTCP: // TCP flag has precedence over UDP flag + proto = "tcp" + case opts.preferUDP: + proto = "udp" + default: + proto = state.Proto() + } + + pc, cached, err := p.transport.Dial(proto) + if err != nil { + return nil, err + } + + // Set buffer size correctly for this client. + pc.c.UDPSize = uint16(state.Size()) + if pc.c.UDPSize < 512 { + pc.c.UDPSize = 512 + } + + pc.c.SetWriteDeadline(time.Now().Add(maxTimeout)) + // records the origin Id before upstream. 
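+	// A fresh random ID is used for the upstream exchange so a response read back
+	// on a reused connection can be matched to this query; the original ID is
+	// restored before the reply is returned to the client.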
+ originId := state.Req.Id + state.Req.Id = dns.Id() + defer func() { + state.Req.Id = originId + }() + + if err := pc.c.WriteMsg(state.Req); err != nil { + pc.c.Close() // not giving it back + if err == io.EOF && cached { + return nil, ErrCachedClosed + } + return nil, err + } + + var ret *dns.Msg + pc.c.SetReadDeadline(time.Now().Add(readTimeout)) + for { + ret, err = pc.c.ReadMsg() + if err != nil { + pc.c.Close() // not giving it back + if err == io.EOF && cached { + return nil, ErrCachedClosed + } + // recovery the origin Id after upstream. + if ret != nil { + ret.Id = originId + } + return ret, err + } + // drop out-of-order responses + if state.Req.Id == ret.Id { + break + } + } + // recovery the origin Id after upstream. + ret.Id = originId + + p.transport.Yield(pc) + + rc, ok := dns.RcodeToString[ret.Rcode] + if !ok { + rc = strconv.Itoa(ret.Rcode) + } + + RequestCount.WithLabelValues(p.addr).Add(1) + RcodeCount.WithLabelValues(rc, p.addr).Add(1) + RequestDuration.WithLabelValues(p.addr, rc).Observe(time.Since(start).Seconds()) + + return ret, nil +} + +const cumulativeAvgWeight = 4 diff --git a/go/plugin/coredns/forward/dnstap.go b/go/plugin/coredns/forward/dnstap.go new file mode 100644 index 000000000000..4e06ac1ff9d6 --- /dev/null +++ b/go/plugin/coredns/forward/dnstap.go @@ -0,0 +1,63 @@ +package forward + +import ( + "net" + "strconv" + "time" + + "github.com/coredns/coredns/plugin/dnstap/msg" + "github.com/coredns/coredns/request" + + tap "github.com/dnstap/golang-dnstap" + "github.com/miekg/dns" +) + +// toDnstap will send the forward and received message to the dnstap plugin. +func toDnstap(f *Forward, host string, state request.Request, opts options, reply *dns.Msg, start time.Time) { + // Query + q := new(tap.Message) + msg.SetQueryTime(q, start) + h, p, _ := net.SplitHostPort(host) // this is preparsed and can't err here + port, _ := strconv.ParseUint(p, 10, 32) // same here + ip := net.ParseIP(h) + + var ta net.Addr = &net.UDPAddr{IP: ip, Port: int(port)} + t := state.Proto() + switch { + case opts.forceTCP: + t = "tcp" + case opts.preferUDP: + t = "udp" + } + + if t == "tcp" { + ta = &net.TCPAddr{IP: ip, Port: int(port)} + } + + // Forwarder dnstap messages are from the perspective of the downstream server + // (upstream is the forward server) + msg.SetQueryAddress(q, state.W.RemoteAddr()) + msg.SetResponseAddress(q, ta) + + if f.tapPlugin.IncludeRawMessage { + buf, _ := state.Req.Pack() + q.QueryMessage = buf + } + msg.SetType(q, tap.Message_FORWARDER_QUERY) + f.tapPlugin.TapMessage(q) + + // Response + if reply != nil { + r := new(tap.Message) + if f.tapPlugin.IncludeRawMessage { + buf, _ := reply.Pack() + r.ResponseMessage = buf + } + msg.SetQueryTime(r, start) + msg.SetQueryAddress(r, state.W.RemoteAddr()) + msg.SetResponseAddress(r, ta) + msg.SetResponseTime(r, time.Now()) + msg.SetType(r, tap.Message_FORWARDER_RESPONSE) + f.tapPlugin.TapMessage(r) + } +} diff --git a/go/plugin/coredns/forward/forward.go b/go/plugin/coredns/forward/forward.go new file mode 100644 index 000000000000..90ae1aef885e --- /dev/null +++ b/go/plugin/coredns/forward/forward.go @@ -0,0 +1,239 @@ +// Package forward implements a forwarding proxy. It caches an upstream net.Conn for some time, so if the same +// client returns the upstream's Conn will be precached. Depending on how you benchmark this looks to be +// 50% faster than just opening a new connection for every client. It works with UDP and TCP and uses +// inband healthchecking. 
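+//
+// This is a vendored copy of the upstream CoreDNS *forward* plugin; it registers
+// itself under the same "forward" name, replacing the stock plugin import that
+// this patch removes from go/cmd/pfdns/plugin.go.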
+package forward + +import ( + "context" + "crypto/tls" + "errors" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/debug" + "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/metadata" + clog "github.com/coredns/coredns/plugin/pkg/log" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" + ot "github.com/opentracing/opentracing-go" + otext "github.com/opentracing/opentracing-go/ext" +) + +var log = clog.NewWithPlugin("forward") + +// Forward represents a plugin instance that can proxy requests to another (DNS) server. It has a list +// of proxies each representing one upstream proxy. +type Forward struct { + concurrent int64 // atomic counters need to be first in struct for proper alignment + + proxies []*Proxy + p Policy + hcInterval time.Duration + + from string + ignored []string + + tlsConfig *tls.Config + tlsServerName string + maxfails uint32 + expire time.Duration + maxConcurrent int64 + + opts options // also here for testing + + // ErrLimitExceeded indicates that a query was rejected because the number of concurrent queries has exceeded + // the maximum allowed (maxConcurrent) + ErrLimitExceeded error + + tapPlugin *dnstap.Dnstap // when the dnstap plugin is loaded, we use to this to send messages out. + + Next plugin.Handler +} + +// New returns a new Forward. +func New() *Forward { + f := &Forward{maxfails: 2, tlsConfig: new(tls.Config), expire: defaultExpire, p: new(random), from: ".", hcInterval: hcInterval, opts: options{forceTCP: false, preferUDP: false, hcRecursionDesired: true, hcDomain: "."}} + return f +} + +// SetProxy appends p to the proxy list and starts healthchecking. +func (f *Forward) SetProxy(p *Proxy) { + f.proxies = append(f.proxies, p) + p.start(f.hcInterval) +} + +// Len returns the number of configured proxies. +func (f *Forward) Len() int { return len(f.proxies) } + +// Name implements plugin.Handler. +func (f *Forward) Name() string { return "forward" } + +// ServeDNS implements plugin.Handler. +func (f *Forward) ServeDNS(ctx context.Context, w dns.ResponseWriter, r *dns.Msg) (int, error) { + state := request.Request{W: w, Req: r} + if !f.match(state) { + return plugin.NextOrFailure(f.Name(), f.Next, ctx, w, r) + } + + if f.maxConcurrent > 0 { + count := atomic.AddInt64(&(f.concurrent), 1) + defer atomic.AddInt64(&(f.concurrent), -1) + if count > f.maxConcurrent { + MaxConcurrentRejectCount.Add(1) + return dns.RcodeRefused, f.ErrLimitExceeded + } + } + + fails := 0 + var span, child ot.Span + var upstreamErr error + span = ot.SpanFromContext(ctx) + i := 0 + list := f.List() + deadline := time.Now().Add(defaultTimeout) + start := time.Now() + for time.Now().Before(deadline) { + if i >= len(list) { + // reached the end of list, reset to begin + i = 0 + fails = 0 + } + + proxy := list[i] + i++ + if proxy.Down(f.maxfails) { + fails++ + if fails < len(f.proxies) { + continue + } + // All upstream proxies are dead, assume healthcheck is completely broken and randomly + // select an upstream to connect to. 
+ r := new(random) + proxy = r.List(f.proxies)[0] + + HealthcheckBrokenCount.Add(1) + } + + if span != nil { + child = span.Tracer().StartSpan("connect", ot.ChildOf(span.Context())) + otext.PeerAddress.Set(child, proxy.addr) + ctx = ot.ContextWithSpan(ctx, child) + } + + metadata.SetValueFunc(ctx, "forward/upstream", func() string { + return proxy.addr + }) + + var ( + ret *dns.Msg + err error + ) + opts := f.opts + for { + ret, err = proxy.Connect(ctx, state, opts) + if err == ErrCachedClosed { // Remote side closed conn, can only happen with TCP. + continue + } + // Retry with TCP if truncated and prefer_udp configured. + if ret != nil && ret.Truncated && !opts.forceTCP && opts.preferUDP { + opts.forceTCP = true + continue + } + break + } + + if child != nil { + child.Finish() + } + + if f.tapPlugin != nil { + toDnstap(f, proxy.addr, state, opts, ret, start) + } + + upstreamErr = err + + if err != nil { + // Kick off health check to see if *our* upstream is broken. + if f.maxfails != 0 { + proxy.Healthcheck() + } + + if fails < len(f.proxies) { + continue + } + break + } + + // Check if the reply is correct; if not return FormErr. + if !state.Match(ret) { + debug.Hexdumpf(ret, "Wrong reply for id: %d, %s %d", ret.Id, state.QName(), state.QType()) + + formerr := new(dns.Msg) + formerr.SetRcode(state.Req, dns.RcodeFormatError) + w.WriteMsg(formerr) + return 0, nil + } + + w.WriteMsg(ret) + return 0, nil + } + + if upstreamErr != nil { + return dns.RcodeServerFailure, upstreamErr + } + + return dns.RcodeServerFailure, ErrNoHealthy +} + +func (f *Forward) match(state request.Request) bool { + if !plugin.Name(f.from).Matches(state.Name()) || !f.isAllowedDomain(state.Name()) { + return false + } + + return true +} + +func (f *Forward) isAllowedDomain(name string) bool { + if dns.Name(name) == dns.Name(f.from) { + return true + } + + for _, ignore := range f.ignored { + if plugin.Name(ignore).Matches(name) { + return false + } + } + return true +} + +// ForceTCP returns if TCP is forced to be used even when the request comes in over UDP. +func (f *Forward) ForceTCP() bool { return f.opts.forceTCP } + +// PreferUDP returns if UDP is preferred to be used even when the request comes in over TCP. +func (f *Forward) PreferUDP() bool { return f.opts.preferUDP } + +// List returns a set of proxies to be used for this client depending on the policy in f. +func (f *Forward) List() []*Proxy { return f.p.List(f.proxies) } + +var ( + // ErrNoHealthy means no healthy proxies left. + ErrNoHealthy = errors.New("no healthy proxies") + // ErrNoForward means no forwarder defined. + ErrNoForward = errors.New("no forwarder defined") + // ErrCachedClosed means cached connection was closed by peer. + ErrCachedClosed = errors.New("cached connection was closed by peer") +) + +// options holds various options that can be set. 
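+// Each field mirrors a setting from the plugin's Corefile block (see setup.go).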
+type options struct { + forceTCP bool + preferUDP bool + hcRecursionDesired bool + hcDomain string +} + +var defaultTimeout = 5 * time.Second diff --git a/go/plugin/coredns/forward/forward_test.go b/go/plugin/coredns/forward/forward_test.go new file mode 100644 index 000000000000..b0ef47ba9244 --- /dev/null +++ b/go/plugin/coredns/forward/forward_test.go @@ -0,0 +1,24 @@ +package forward + +import ( + "testing" +) + +func TestList(t *testing.T) { + f := Forward{ + proxies: []*Proxy{{addr: "1.1.1.1:53"}, {addr: "2.2.2.2:53"}, {addr: "3.3.3.3:53"}}, + p: &roundRobin{}, + } + + expect := []*Proxy{{addr: "2.2.2.2:53"}, {addr: "1.1.1.1:53"}, {addr: "3.3.3.3:53"}} + got := f.List() + + if len(got) != len(expect) { + t.Fatalf("Expected: %v results, got: %v", len(expect), len(got)) + } + for i, p := range got { + if p.addr != expect[i].addr { + t.Fatalf("Expected proxy %v to be '%v', got: '%v'", i, expect[i].addr, p.addr) + } + } +} diff --git a/go/plugin/coredns/forward/fuzz.go b/go/plugin/coredns/forward/fuzz.go new file mode 100644 index 000000000000..bec573e47695 --- /dev/null +++ b/go/plugin/coredns/forward/fuzz.go @@ -0,0 +1,34 @@ +//go:build gofuzz + +package forward + +import ( + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/fuzz" + + "github.com/miekg/dns" +) + +var f *Forward + +// abuse init to setup an environment to test against. This start another server to that will +// reflect responses. +func init() { + f = New() + s := dnstest.NewServer(r{}.reflectHandler) + f.SetProxy(NewProxy(s.Addr, "tcp")) + f.SetProxy(NewProxy(s.Addr, "udp")) +} + +// Fuzz fuzzes forward. +func Fuzz(data []byte) int { + return fuzz.Do(f, data) +} + +type r struct{} + +func (r r) reflectHandler(w dns.ResponseWriter, req *dns.Msg) { + m := new(dns.Msg) + m.SetReply(req) + w.WriteMsg(m) +} diff --git a/go/plugin/coredns/forward/health.go b/go/plugin/coredns/forward/health.go new file mode 100644 index 000000000000..ec0b4814359b --- /dev/null +++ b/go/plugin/coredns/forward/health.go @@ -0,0 +1,106 @@ +package forward + +import ( + "crypto/tls" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +// HealthChecker checks the upstream health. +type HealthChecker interface { + Check(*Proxy) error + SetTLSConfig(*tls.Config) + SetRecursionDesired(bool) + GetRecursionDesired() bool + SetDomain(domain string) + GetDomain() string + SetTCPTransport() +} + +// dnsHc is a health checker for a DNS endpoint (DNS, and DoT). +type dnsHc struct { + c *dns.Client + recursionDesired bool + domain string +} + +var ( + hcReadTimeout = 1 * time.Second + hcWriteTimeout = 1 * time.Second +) + +// NewHealthChecker returns a new HealthChecker based on transport. 
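+// Only plain DNS and DNS-over-TLS are supported; for any other transport a
+// warning is logged and nil is returned.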
+func NewHealthChecker(trans string, recursionDesired bool, domain string) HealthChecker { + switch trans { + case transport.DNS, transport.TLS: + c := new(dns.Client) + c.Net = "udp" + c.ReadTimeout = hcReadTimeout + c.WriteTimeout = hcWriteTimeout + + return &dnsHc{c: c, recursionDesired: recursionDesired, domain: domain} + } + + log.Warningf("No healthchecker for transport %q", trans) + return nil +} + +func (h *dnsHc) SetTLSConfig(cfg *tls.Config) { + h.c.Net = "tcp-tls" + h.c.TLSConfig = cfg +} + +func (h *dnsHc) SetRecursionDesired(recursionDesired bool) { + h.recursionDesired = recursionDesired +} +func (h *dnsHc) GetRecursionDesired() bool { + return h.recursionDesired +} + +func (h *dnsHc) SetDomain(domain string) { + h.domain = domain +} +func (h *dnsHc) GetDomain() string { + return h.domain +} + +func (h *dnsHc) SetTCPTransport() { + h.c.Net = "tcp" +} + +// For HC we send to . IN NS +[no]rec message to the upstream. Dial timeouts and empty +// replies are considered fails, basically anything else constitutes a healthy upstream. + +// Check is used as the up.Func in the up.Probe. +func (h *dnsHc) Check(p *Proxy) error { + err := h.send(p.addr) + if err != nil { + HealthcheckFailureCount.WithLabelValues(p.addr).Add(1) + atomic.AddUint32(&p.fails, 1) + return err + } + + atomic.StoreUint32(&p.fails, 0) + return nil +} + +func (h *dnsHc) send(addr string) error { + ping := new(dns.Msg) + ping.SetQuestion(h.domain, dns.TypeNS) + ping.MsgHdr.RecursionDesired = h.recursionDesired + + m, _, err := h.c.Exchange(ping, addr) + // If we got a header, we're alright, basically only care about I/O errors 'n stuff. + if err != nil && m != nil { + // Silly check, something sane came back. + if m.Response || m.Opcode == dns.OpcodeQuery { + err = nil + } + } + + return err +} diff --git a/go/plugin/coredns/forward/health_test.go b/go/plugin/coredns/forward/health_test.go new file mode 100644 index 000000000000..9917b3a37c31 --- /dev/null +++ b/go/plugin/coredns/forward/health_test.go @@ -0,0 +1,283 @@ +package forward + +import ( + "context" + "sync/atomic" + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/coredns/coredns/plugin/test" + + "github.com/miekg/dns" +) + +func TestHealth(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." 
&& r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthTCP(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + p.health.SetTCPTransport() + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{TCP: true}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==true to be %d, got %d", 1, i1) + } +} + +func TestHealthNoRecursion(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == "." && r.RecursionDesired == false { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + p.health.SetRecursionDesired(false) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with RecursionDesired==false to be %d, got %d", 1, i1) + } +} + +func TestHealthTimeout(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + + i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." 
{ + // health check, answer + atomic.AddUint32(&i, 1) + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + return + } + if atomic.LoadUint32(&q) == 0 { //drop only first query + atomic.AddUint32(&q, 1) + return + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks to be %d, got %d", 1, i1) + } +} + +func TestHealthMaxFails(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + hcInterval = 10 * time.Millisecond + + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + // timeout + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + f := New() + f.maxfails = 2 + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(100 * time.Millisecond) + fails := atomic.LoadUint32(&p.fails) + if !p.Down(f.maxfails) { + t.Errorf("Expected Proxy fails to be greater than %d, got %d", f.maxfails, fails) + } +} + +func TestHealthNoMaxFails(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + hcInterval = 10 * time.Millisecond + + i := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if r.Question[0].Name == "." { + // health check, answer + atomic.AddUint32(&i, 1) + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + } + }) + defer s.Close() + + p := NewProxy(s.Addr, transport.DNS) + f := New() + f.maxfails = 0 + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion("example.org.", dns.TypeA) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 0 { + t.Errorf("Expected number of health checks to be %d, got %d", 0, i1) + } +} + +func TestHealthDomain(t *testing.T) { + hcReadTimeout = 10 * time.Millisecond + readTimeout = 10 * time.Millisecond + defaultTimeout = 10 * time.Millisecond + hcWriteTimeout = 10 * time.Millisecond + hcDomain := "example.org." 
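+	// i counts health checks that queried hcDomain; q is used to drop the first
+	// query and thereby trigger health checking.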
+ i := uint32(0) + q := uint32(0) + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + if atomic.LoadUint32(&q) == 0 { //drop the first query to trigger health-checking + atomic.AddUint32(&q, 1) + return + } + if r.Question[0].Name == hcDomain && r.RecursionDesired == true { + atomic.AddUint32(&i, 1) + } + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + p := NewProxy(s.Addr, transport.DNS) + p.health.SetDomain(hcDomain) + f := New() + f.SetProxy(p) + defer f.OnShutdown() + + req := new(dns.Msg) + req.SetQuestion(".", dns.TypeNS) + + f.ServeDNS(context.TODO(), &test.ResponseWriter{}, req) + + time.Sleep(20 * time.Millisecond) + i1 := atomic.LoadUint32(&i) + if i1 != 1 { + t.Errorf("Expected number of health checks with Domain==%s to be %d, got %d", hcDomain, 1, i1) + } +} diff --git a/go/plugin/coredns/forward/log_test.go b/go/plugin/coredns/forward/log_test.go new file mode 100644 index 000000000000..a7f0a8589f6e --- /dev/null +++ b/go/plugin/coredns/forward/log_test.go @@ -0,0 +1,5 @@ +package forward + +import clog "github.com/coredns/coredns/plugin/pkg/log" + +func init() { clog.Discard() } diff --git a/go/plugin/coredns/forward/metrics.go b/go/plugin/coredns/forward/metrics.go new file mode 100644 index 000000000000..f1f0c48d67e5 --- /dev/null +++ b/go/plugin/coredns/forward/metrics.go @@ -0,0 +1,61 @@ +package forward + +import ( + "github.com/coredns/coredns/plugin" + + "github.com/prometheus/client_golang/prometheus" + "github.com/prometheus/client_golang/prometheus/promauto" +) + +// Variables declared for monitoring. +var ( + RequestCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "requests_total", + Help: "Counter of requests made per upstream.", + }, []string{"to"}) + RcodeCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "responses_total", + Help: "Counter of responses received per upstream.", + }, []string{"rcode", "to"}) + RequestDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "request_duration_seconds", + Buckets: plugin.TimeBuckets, + Help: "Histogram of the time each request took.", + }, []string{"to", "rcode"}) + HealthcheckFailureCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "healthcheck_failures_total", + Help: "Counter of the number of failed healthchecks.", + }, []string{"to"}) + HealthcheckBrokenCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "healthcheck_broken_total", + Help: "Counter of the number of complete failures of the healthchecks.", + }) + MaxConcurrentRejectCount = promauto.NewCounter(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "max_concurrent_rejects_total", + Help: "Counter of the number of queries rejected because the concurrent queries were at maximum.", + }) + ConnCacheHitsCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "conn_cache_hits_total", + Help: "Counter of connection cache hits per upstream and protocol.", + }, []string{"to", "proto"}) + ConnCacheMissesCount = promauto.NewCounterVec(prometheus.CounterOpts{ + Namespace: plugin.Namespace, + Subsystem: "forward", + Name: "conn_cache_misses_total", + Help: "Counter of connection cache misses 
per upstream and protocol.", + }, []string{"to", "proto"}) +) diff --git a/go/plugin/coredns/forward/persistent.go b/go/plugin/coredns/forward/persistent.go new file mode 100644 index 000000000000..95d08e1e148c --- /dev/null +++ b/go/plugin/coredns/forward/persistent.go @@ -0,0 +1,161 @@ +package forward + +import ( + "crypto/tls" + "sort" + "time" + + "github.com/miekg/dns" +) + +// a persistConn hold the dns.Conn and the last used time. +type persistConn struct { + c *dns.Conn + used time.Time +} + +// Transport hold the persistent cache. +type Transport struct { + avgDialTime int64 // kind of average time of dial time + conns [typeTotalCount][]*persistConn // Buckets for udp, tcp and tcp-tls. + expire time.Duration // After this duration a connection is expired. + addr string + tlsConfig *tls.Config + + dial chan string + yield chan *persistConn + ret chan *persistConn + stop chan bool +} + +func newTransport(addr string) *Transport { + t := &Transport{ + avgDialTime: int64(maxDialTimeout / 2), + conns: [typeTotalCount][]*persistConn{}, + expire: defaultExpire, + addr: addr, + dial: make(chan string), + yield: make(chan *persistConn), + ret: make(chan *persistConn), + stop: make(chan bool), + } + return t +} + +// connManagers manages the persistent connection cache for UDP and TCP. +func (t *Transport) connManager() { + ticker := time.NewTicker(defaultExpire) +Wait: + for { + select { + case proto := <-t.dial: + transtype := stringToTransportType(proto) + // take the last used conn - complexity O(1) + if stack := t.conns[transtype]; len(stack) > 0 { + pc := stack[len(stack)-1] + if time.Since(pc.used) < t.expire { + // Found one, remove from pool and return this conn. + t.conns[transtype] = stack[:len(stack)-1] + t.ret <- pc + continue Wait + } + // clear entire cache if the last conn is expired + t.conns[transtype] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + } + t.ret <- nil + + case pc := <-t.yield: + transtype := t.transportTypeFromConn(pc) + t.conns[transtype] = append(t.conns[transtype], pc) + + case <-ticker.C: + t.cleanup(false) + + case <-t.stop: + t.cleanup(true) + close(t.ret) + return + } + } +} + +// closeConns closes connections. +func closeConns(conns []*persistConn) { + for _, pc := range conns { + pc.c.Close() + } +} + +// cleanup removes connections from cache. +func (t *Transport) cleanup(all bool) { + staleTime := time.Now().Add(-t.expire) + for transtype, stack := range t.conns { + if len(stack) == 0 { + continue + } + if all { + t.conns[transtype] = nil + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack) + continue + } + if stack[0].used.After(staleTime) { + continue + } + + // connections in stack are sorted by "used" + good := sort.Search(len(stack), func(i int) bool { + return stack[i].used.After(staleTime) + }) + t.conns[transtype] = stack[good:] + // now, the connections being passed to closeConns() are not reachable from + // transport methods anymore. So, it's safe to close them in a separate goroutine + go closeConns(stack[:good]) + } +} + +// It is hard to pin a value to this, the import thing is to no block forever, losing at cached connection is not terrible. +const yieldTimeout = 25 * time.Millisecond + +// Yield returns the connection to transport for reuse. 
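+// The hand-off is non-blocking: if the connection manager does not pick the
+// connection up within yieldTimeout it is simply dropped, as losing a cached
+// connection is cheap.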
+func (t *Transport) Yield(pc *persistConn) { + pc.used = time.Now() // update used time + + // Make this non-blocking, because in the case of a very busy forwarder we will *block* on this yield. This + // blocks the outer go-routine and stuff will just pile up. We timeout when the send fails to as returning + // these connection is an optimization anyway. + select { + case t.yield <- pc: + return + case <-time.After(yieldTimeout): + return + } +} + +// Start starts the transport's connection manager. +func (t *Transport) Start() { go t.connManager() } + +// Stop stops the transport's connection manager. +func (t *Transport) Stop() { close(t.stop) } + +// SetExpire sets the connection expire time in transport. +func (t *Transport) SetExpire(expire time.Duration) { t.expire = expire } + +// SetTLSConfig sets the TLS config in transport. +func (t *Transport) SetTLSConfig(cfg *tls.Config) { t.tlsConfig = cfg } + +const ( + defaultExpire = 10 * time.Second + minDialTimeout = 1 * time.Second + maxDialTimeout = 30 * time.Second +) + +// Make a var for minimizing this value in tests. +var ( + // Some resolves might take quite a while, usually (cached) responses are fast. Set to 2s to give us some time to retry a different upstream. + readTimeout = 2 * time.Second +) diff --git a/go/plugin/coredns/forward/persistent_test.go b/go/plugin/coredns/forward/persistent_test.go new file mode 100644 index 000000000000..633696ac01b8 --- /dev/null +++ b/go/plugin/coredns/forward/persistent_test.go @@ -0,0 +1,109 @@ +package forward + +import ( + "testing" + "time" + + "github.com/coredns/coredns/plugin/pkg/dnstest" + + "github.com/miekg/dns" +) + +func TestCached(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport(s.Addr) + tr.Start() + defer tr.Stop() + + c1, cache1, _ := tr.Dial("udp") + c2, cache2, _ := tr.Dial("udp") + + if cache1 || cache2 { + t.Errorf("Expected non-cached connection") + } + + tr.Yield(c1) + tr.Yield(c2) + c3, cached3, _ := tr.Dial("udp") + if !cached3 { + t.Error("Expected cached connection (c3)") + } + if c2 != c3 { + t.Error("Expected c2 == c3") + } + + tr.Yield(c3) + + // dial another protocol + c4, cached4, _ := tr.Dial("tcp") + if cached4 { + t.Errorf("Expected non-cached connection (c4)") + } + tr.Yield(c4) +} + +func TestCleanupByTimer(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport(s.Addr) + tr.SetExpire(100 * time.Millisecond) + tr.Start() + defer tr.Stop() + + c1, _, _ := tr.Dial("udp") + c2, _, _ := tr.Dial("udp") + tr.Yield(c1) + time.Sleep(10 * time.Millisecond) + tr.Yield(c2) + + time.Sleep(120 * time.Millisecond) + c3, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c3)") + } + tr.Yield(c3) + + time.Sleep(120 * time.Millisecond) + c4, cached, _ := tr.Dial("udp") + if cached { + t.Error("Expected non-cached connection (c4)") + } + tr.Yield(c4) +} + +func TestCleanupAll(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + w.WriteMsg(ret) + }) + defer s.Close() + + tr := newTransport(s.Addr) + + c1, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + c2, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + c3, _ := dns.DialTimeout("udp", tr.addr, maxDialTimeout) + + tr.conns[typeUDP] = 
[]*persistConn{{c1, time.Now()}, {c2, time.Now()}, {c3, time.Now()}} + + if len(tr.conns[typeUDP]) != 3 { + t.Error("Expected 3 connections") + } + tr.cleanup(true) + + if len(tr.conns[typeUDP]) > 0 { + t.Error("Expected no cached connections") + } +} diff --git a/go/plugin/coredns/forward/policy.go b/go/plugin/coredns/forward/policy.go new file mode 100644 index 000000000000..e81e4ab91043 --- /dev/null +++ b/go/plugin/coredns/forward/policy.go @@ -0,0 +1,68 @@ +package forward + +import ( + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/rand" +) + +// Policy defines a policy we use for selecting upstreams. +type Policy interface { + List([]*Proxy) []*Proxy + String() string +} + +// random is a policy that implements random upstream selection. +type random struct{} + +func (r *random) String() string { return "random" } + +func (r *random) List(p []*Proxy) []*Proxy { + switch len(p) { + case 1: + return p + case 2: + if rn.Int()%2 == 0 { + return []*Proxy{p[1], p[0]} // swap + } + return p + } + + perms := rn.Perm(len(p)) + rnd := make([]*Proxy, len(p)) + + for i, p1 := range perms { + rnd[i] = p[p1] + } + return rnd +} + +// roundRobin is a policy that selects hosts based on round robin ordering. +type roundRobin struct { + robin uint32 +} + +func (r *roundRobin) String() string { return "round_robin" } + +func (r *roundRobin) List(p []*Proxy) []*Proxy { + poolLen := uint32(len(p)) + i := atomic.AddUint32(&r.robin, 1) % poolLen + + robin := []*Proxy{p[i]} + robin = append(robin, p[:i]...) + robin = append(robin, p[i+1:]...) + + return robin +} + +// sequential is a policy that selects hosts based on sequential ordering. +type sequential struct{} + +func (r *sequential) String() string { return "sequential" } + +func (r *sequential) List(p []*Proxy) []*Proxy { + return p +} + +var rn = rand.New(time.Now().UnixNano()) diff --git a/go/plugin/coredns/forward/proxy.go b/go/plugin/coredns/forward/proxy.go new file mode 100644 index 000000000000..6a4b5693e654 --- /dev/null +++ b/go/plugin/coredns/forward/proxy.go @@ -0,0 +1,82 @@ +package forward + +import ( + "crypto/tls" + "runtime" + "sync/atomic" + "time" + + "github.com/coredns/coredns/plugin/pkg/up" +) + +// Proxy defines an upstream host. +type Proxy struct { + fails uint32 + addr string + + transport *Transport + + // health checking + probe *up.Probe + health HealthChecker +} + +// NewProxy returns a new proxy. +func NewProxy(addr, trans string) *Proxy { + p := &Proxy{ + addr: addr, + fails: 0, + probe: up.New(), + transport: newTransport(addr), + } + p.health = NewHealthChecker(trans, true, ".") + runtime.SetFinalizer(p, (*Proxy).finalizer) + return p +} + +// SetTLSConfig sets the TLS config in the lower p.transport and in the healthchecking client. +func (p *Proxy) SetTLSConfig(cfg *tls.Config) { + p.transport.SetTLSConfig(cfg) + p.health.SetTLSConfig(cfg) +} + +// SetExpire sets the expire duration in the lower p.transport. +func (p *Proxy) SetExpire(expire time.Duration) { p.transport.SetExpire(expire) } + +// Healthcheck kicks of a round of health checks for this proxy. +func (p *Proxy) Healthcheck() { + if p.health == nil { + log.Warning("No healthchecker") + return + } + + p.probe.Do(func() error { + return p.health.Check(p) + }) +} + +// Down returns true if this proxy is down, i.e. has *more* fails than maxfails. 
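+// A maxfails of 0 means the proxy is never reported as down.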
+func (p *Proxy) Down(maxfails uint32) bool { + if maxfails == 0 { + return false + } + + fails := atomic.LoadUint32(&p.fails) + return fails > maxfails +} + +// close stops the health checking goroutine. +func (p *Proxy) stop() { p.probe.Stop() } +func (p *Proxy) finalizer() { p.transport.Stop() } + +// start starts the proxy's healthchecking. +func (p *Proxy) start(duration time.Duration) { + p.probe.Start(duration) + p.transport.Start() +} + +const ( + maxTimeout = 2 * time.Second +) + +var hcInterval = 500 * time.Millisecond diff --git a/go/plugin/coredns/forward/proxy_test.go b/go/plugin/coredns/forward/proxy_test.go new file mode 100644 index 000000000000..74a0b5c4b6cc --- /dev/null +++ b/go/plugin/coredns/forward/proxy_test.go @@ -0,0 +1,99 @@ +package forward + +import ( + "context" + "testing" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/plugin/pkg/dnstest" + "github.com/coredns/coredns/plugin/pkg/transport" + "github.com/coredns/coredns/plugin/test" + "github.com/coredns/coredns/request" + + "github.com/miekg/dns" +) + +func TestProxy(t *testing.T) { + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . "+s.Addr) + fs, err := parseForward(c) + f := fs[0] + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err != nil { + t.Fatal("Expected to receive reply, but didn't") + } + if x := rec.Msg.Answer[0].Header().Name; x != "example.org." { + t.Errorf("Expected %s, got %s", "example.org.", x) + } +} + +func TestProxyTLSFail(t *testing.T) { + // This is an udp/tcp test server, so we shouldn't reach it with TLS. + s := dnstest.NewServer(func(w dns.ResponseWriter, r *dns.Msg) { + ret := new(dns.Msg) + ret.SetReply(r) + ret.Answer = append(ret.Answer, test.A("example.org. IN A 127.0.0.1")) + w.WriteMsg(ret) + }) + defer s.Close() + + c := caddy.NewTestController("dns", "forward . 
tls://"+s.Addr) + fs, err := parseForward(c) + f := fs[0] + if err != nil { + t.Errorf("Failed to create forwarder: %s", err) + } + f.OnStartup() + defer f.OnShutdown() + + m := new(dns.Msg) + m.SetQuestion("example.org.", dns.TypeA) + rec := dnstest.NewRecorder(&test.ResponseWriter{}) + + if _, err := f.ServeDNS(context.TODO(), rec, m); err == nil { + t.Fatal("Expected *not* to receive reply, but got one") + } +} + +func TestProtocolSelection(t *testing.T) { + p := NewProxy("bad_address", transport.DNS) + + stateUDP := request.Request{W: &test.ResponseWriter{}, Req: new(dns.Msg)} + stateTCP := request.Request{W: &test.ResponseWriter{TCP: true}, Req: new(dns.Msg)} + ctx := context.TODO() + + go func() { + p.Connect(ctx, stateUDP, options{}) + p.Connect(ctx, stateUDP, options{forceTCP: true}) + p.Connect(ctx, stateUDP, options{preferUDP: true}) + p.Connect(ctx, stateUDP, options{preferUDP: true, forceTCP: true}) + p.Connect(ctx, stateTCP, options{}) + p.Connect(ctx, stateTCP, options{forceTCP: true}) + p.Connect(ctx, stateTCP, options{preferUDP: true}) + p.Connect(ctx, stateTCP, options{preferUDP: true, forceTCP: true}) + }() + + for i, exp := range []string{"udp", "tcp", "udp", "tcp", "tcp", "tcp", "udp", "tcp"} { + proto := <-p.transport.dial + p.transport.ret <- nil + if proto != exp { + t.Errorf("Unexpected protocol in case %d, expected %q, actual %q", i, exp, proto) + } + } +} diff --git a/go/plugin/coredns/forward/setup.go b/go/plugin/coredns/forward/setup.go new file mode 100644 index 000000000000..dfae70d37806 --- /dev/null +++ b/go/plugin/coredns/forward/setup.go @@ -0,0 +1,292 @@ +package forward + +import ( + "crypto/tls" + "errors" + "fmt" + "strconv" + "time" + + "github.com/coredns/caddy" + "github.com/coredns/coredns/core/dnsserver" + "github.com/coredns/coredns/plugin" + "github.com/coredns/coredns/plugin/dnstap" + "github.com/coredns/coredns/plugin/pkg/parse" + pkgtls "github.com/coredns/coredns/plugin/pkg/tls" + "github.com/coredns/coredns/plugin/pkg/transport" + + "github.com/miekg/dns" +) + +func init() { plugin.Register("forward", setup) } + +func setup(c *caddy.Controller) error { + fs, err := parseForward(c) + if err != nil { + return plugin.Error("forward", err) + } + for i := range fs { + f := fs[i] + if f.Len() > max { + return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len())) + } + + if i == len(fs)-1 { + // last forward: point next to next plugin + dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler { + f.Next = next + return f + }) + } else { + // middle forward: point next to next forward + nextForward := fs[i+1] + dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler { + f.Next = nextForward + return f + }) + } + + c.OnStartup(func() error { + return f.OnStartup() + }) + c.OnStartup(func() error { + if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil { + if tapPlugin, ok := taph.(dnstap.Dnstap); ok { + f.tapPlugin = &tapPlugin + } + } + return nil + }) + + c.OnShutdown(func() error { + return f.OnShutdown() + }) + } + + return nil +} + +// OnStartup starts a goroutines for all proxies. +func (f *Forward) OnStartup() (err error) { + for _, p := range f.proxies { + p.start(f.hcInterval) + } + return nil +} + +// OnShutdown stops all configured proxies. 
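+// It is registered as a shutdown hook with caddy in setup().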
diff --git a/go/plugin/coredns/forward/setup.go b/go/plugin/coredns/forward/setup.go
new file mode 100644
index 000000000000..dfae70d37806
--- /dev/null
+++ b/go/plugin/coredns/forward/setup.go
@@ -0,0 +1,292 @@
+package forward
+
+import (
+	"crypto/tls"
+	"errors"
+	"fmt"
+	"strconv"
+	"time"
+
+	"github.com/coredns/caddy"
+	"github.com/coredns/coredns/core/dnsserver"
+	"github.com/coredns/coredns/plugin"
+	"github.com/coredns/coredns/plugin/dnstap"
+	"github.com/coredns/coredns/plugin/pkg/parse"
+	pkgtls "github.com/coredns/coredns/plugin/pkg/tls"
+	"github.com/coredns/coredns/plugin/pkg/transport"
+
+	"github.com/miekg/dns"
+)
+
+func init() { plugin.Register("forward", setup) }
+
+func setup(c *caddy.Controller) error {
+	fs, err := parseForward(c)
+	if err != nil {
+		return plugin.Error("forward", err)
+	}
+	for i := range fs {
+		f := fs[i]
+		if f.Len() > max {
+			return plugin.Error("forward", fmt.Errorf("more than %d TOs configured: %d", max, f.Len()))
+		}
+
+		if i == len(fs)-1 {
+			// Last forward: point Next at the next plugin in the server chain.
+			dnsserver.GetConfig(c).AddPlugin(func(next plugin.Handler) plugin.Handler {
+				f.Next = next
+				return f
+			})
+		} else {
+			// Middle forward: point Next at the forward that follows it.
+			nextForward := fs[i+1]
+			dnsserver.GetConfig(c).AddPlugin(func(plugin.Handler) plugin.Handler {
+				f.Next = nextForward
+				return f
+			})
+		}
+
+		c.OnStartup(func() error {
+			return f.OnStartup()
+		})
+		c.OnStartup(func() error {
+			if taph := dnsserver.GetConfig(c).Handler("dnstap"); taph != nil {
+				if tapPlugin, ok := taph.(dnstap.Dnstap); ok {
+					f.tapPlugin = &tapPlugin
+				}
+			}
+			return nil
+		})
+
+		c.OnShutdown(func() error {
+			return f.OnShutdown()
+		})
+	}
+
+	return nil
+}
+
+// OnStartup starts a health-checking goroutine for each proxy.
+func (f *Forward) OnStartup() (err error) {
+	for _, p := range f.proxies {
+		p.start(f.hcInterval)
+	}
+	return nil
+}
+
+// OnShutdown stops all configured proxies.
+func (f *Forward) OnShutdown() error {
+	for _, p := range f.proxies {
+		p.stop()
+	}
+	return nil
+}
+
+func parseForward(c *caddy.Controller) ([]*Forward, error) {
+	var fs = []*Forward{}
+	for c.Next() {
+		f, err := parseStanza(c)
+		if err != nil {
+			return nil, err
+		}
+		fs = append(fs, f)
+	}
+	return fs, nil
+}
+
+func parseStanza(c *caddy.Controller) (*Forward, error) {
+	f := New()
+
+	if !c.Args(&f.from) {
+		return f, c.ArgErr()
+	}
+	origFrom := f.from
+	zones := plugin.Host(f.from).NormalizeExact()
+	if len(zones) == 0 {
+		return f, fmt.Errorf("unable to normalize '%s'", f.from)
+	}
+	f.from = zones[0] // there can only be one here, won't work with non-octet reverse
+
+	if len(zones) > 1 {
+		log.Warningf("Unsupported CIDR notation: '%s' expands to multiple zones. Using only '%s'.", origFrom, f.from)
+	}
+
+	to := c.RemainingArgs()
+	if len(to) == 0 {
+		return f, c.ArgErr()
+	}
+
+	toHosts, err := parse.HostPortOrFile(to...)
+	if err != nil {
+		return f, err
+	}
+
+	transports := make([]string, len(toHosts))
+	allowedTrans := map[string]bool{"dns": true, "tls": true}
+	for i, host := range toHosts {
+		trans, h := parse.Transport(host)
+
+		if !allowedTrans[trans] {
+			return f, fmt.Errorf("'%s' is not supported as a destination protocol in forward: %s", trans, host)
+		}
+		p := NewProxy(h, trans)
+		f.proxies = append(f.proxies, p)
+		transports[i] = trans
+	}
+
+	for c.NextBlock() {
+		if err := parseBlock(c, f); err != nil {
+			return f, err
+		}
+	}
+
+	if f.tlsServerName != "" {
+		f.tlsConfig.ServerName = f.tlsServerName
+	}
+
+	// Initialize ClientSessionCache in tls.Config. This may speed up a TLS handshake
+	// in upcoming connections to the same TLS server.
+	f.tlsConfig.ClientSessionCache = tls.NewLRUClientSessionCache(len(f.proxies))
+
+	for i := range f.proxies {
+		// Only set the TLS config for proxies that need it.
+		if transports[i] == transport.TLS {
+			f.proxies[i].SetTLSConfig(f.tlsConfig)
+		}
+		f.proxies[i].SetExpire(f.expire)
+		f.proxies[i].health.SetRecursionDesired(f.opts.hcRecursionDesired)
+		// When TLS is used, health checks are set to tcp-tls.
+		if f.opts.forceTCP && transports[i] != transport.TLS {
+			f.proxies[i].health.SetTCPTransport()
+		}
+		f.proxies[i].health.SetDomain(f.opts.hcDomain)
+	}
+
+	return f, nil
+}
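Because parseForward returns one Forward per stanza and setup chains them via Next, a server block with several stanzas yields handlers that try each matching forward in declaration order (an illustrative Corefile, assuming the usual CoreDNS server-block syntax; addresses are made up):

~~~ corefile
. {
    forward 1st.example.org 10.0.0.1
    forward 2nd.example.org 10.0.0.2
    forward . 10.0.0.3
}
~~~

Only the last stanza falls through to the next plugin in the server's chain; every other stanza points directly at the forward that follows it, which is what TestMultiForward below asserts.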
+func parseBlock(c *caddy.Controller, f *Forward) error {
+	switch c.Val() {
+	case "except":
+		ignore := c.RemainingArgs()
+		if len(ignore) == 0 {
+			return c.ArgErr()
+		}
+		for i := 0; i < len(ignore); i++ {
+			f.ignored = append(f.ignored, plugin.Host(ignore[i]).NormalizeExact()...)
+		}
+	case "max_fails":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		n, err := strconv.ParseUint(c.Val(), 10, 32)
+		if err != nil {
+			return err
+		}
+		f.maxfails = uint32(n)
+	case "health_check":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		dur, err := time.ParseDuration(c.Val())
+		if err != nil {
+			return err
+		}
+		if dur < 0 {
+			return fmt.Errorf("health_check can't be negative: %s", dur)
+		}
+		f.hcInterval = dur
+		f.opts.hcDomain = "."
+
+		for c.NextArg() {
+			switch hcOpts := c.Val(); hcOpts {
+			case "no_rec":
+				f.opts.hcRecursionDesired = false
+			case "domain":
+				if !c.NextArg() {
+					return c.ArgErr()
+				}
+				hcDomain := c.Val()
+				if _, ok := dns.IsDomainName(hcDomain); !ok {
+					return fmt.Errorf("health_check: invalid domain name %s", hcDomain)
+				}
+				f.opts.hcDomain = plugin.Name(hcDomain).Normalize()
+			default:
+				return fmt.Errorf("health_check: unknown option %s", hcOpts)
+			}
+		}
+
+	case "force_tcp":
+		if c.NextArg() {
+			return c.ArgErr()
+		}
+		f.opts.forceTCP = true
+	case "prefer_udp":
+		if c.NextArg() {
+			return c.ArgErr()
+		}
+		f.opts.preferUDP = true
+	case "tls":
+		args := c.RemainingArgs()
+		if len(args) > 3 {
+			return c.ArgErr()
+		}
+
+		tlsConfig, err := pkgtls.NewTLSConfigFromArgs(args...)
+		if err != nil {
+			return err
+		}
+		f.tlsConfig = tlsConfig
+	case "tls_servername":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		f.tlsServerName = c.Val()
+	case "expire":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		dur, err := time.ParseDuration(c.Val())
+		if err != nil {
+			return err
+		}
+		if dur < 0 {
+			return fmt.Errorf("expire can't be negative: %s", dur)
+		}
+		f.expire = dur
+	case "policy":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		switch x := c.Val(); x {
+		case "random":
+			f.p = &random{}
+		case "round_robin":
+			f.p = &roundRobin{}
+		case "sequential":
+			f.p = &sequential{}
+		default:
+			return c.Errf("unknown policy '%s'", x)
+		}
+	case "max_concurrent":
+		if !c.NextArg() {
+			return c.ArgErr()
+		}
+		n, err := strconv.Atoi(c.Val())
+		if err != nil {
+			return err
+		}
+		if n < 0 {
+			return fmt.Errorf("max_concurrent can't be negative: %d", n)
+		}
+		f.ErrLimitExceeded = errors.New("concurrent queries exceeded maximum " + c.Val())
+		f.maxConcurrent = int64(n)
+
+	default:
+		return c.Errf("unknown property '%s'", c.Val())
+	}
+
+	return nil
+}
+
+const max = 15 // Maximum number of upstreams.
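For reference, a single stanza exercising most of the knobs parseBlock accepts (the values are illustrative, not recommendations):

~~~ corefile
forward . tls://9.9.9.9 {
    tls_servername dns.quad9.net
    policy sequential
    max_fails 3
    expire 20s
    health_check 5s domain example.org
    max_concurrent 1000
}
~~~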
diff --git a/go/plugin/coredns/forward/setup_policy_test.go b/go/plugin/coredns/forward/setup_policy_test.go
new file mode 100644
index 000000000000..13466d7a34dd
--- /dev/null
+++ b/go/plugin/coredns/forward/setup_policy_test.go
@@ -0,0 +1,47 @@
+package forward
+
+import (
+	"strings"
+	"testing"
+
+	"github.com/coredns/caddy"
+)
+
+func TestSetupPolicy(t *testing.T) {
+	tests := []struct {
+		input          string
+		shouldErr      bool
+		expectedPolicy string
+		expectedErr    string
+	}{
+		// positive
+		{"forward . 127.0.0.1 {\npolicy random\n}\n", false, "random", ""},
+		{"forward . 127.0.0.1 {\npolicy round_robin\n}\n", false, "round_robin", ""},
+		{"forward . 127.0.0.1 {\npolicy sequential\n}\n", false, "sequential", ""},
+		// negative
+		{"forward . 127.0.0.1 {\npolicy random2\n}\n", true, "random", "unknown policy"},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		fs, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if !test.shouldErr {
+			if len(fs) == 0 {
+				t.Fatalf("Test %d: expected a forwarder for input %s, got none", i, test.input)
+			}
+			if fs[0].p.String() != test.expectedPolicy {
+				t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedPolicy, fs[0].p.String())
+			}
+		}
+	}
+}
diff --git a/go/plugin/coredns/forward/setup_test.go b/go/plugin/coredns/forward/setup_test.go
new file mode 100644
index 000000000000..4b17430985fa
--- /dev/null
+++ b/go/plugin/coredns/forward/setup_test.go
@@ -0,0 +1,334 @@
+package forward
+
+import (
+	"os"
+	"reflect"
+	"strings"
+	"testing"
+
+	"github.com/coredns/caddy"
+	"github.com/coredns/coredns/core/dnsserver"
+
+	"github.com/miekg/dns"
+)
+
+func TestSetup(t *testing.T) {
+	tests := []struct {
+		input           string
+		shouldErr       bool
+		expectedFrom    string
+		expectedIgnored []string
+		expectedFails   uint32
+		expectedOpts    options
+		expectedErr     string
+	}{
+		// positive
+		{"forward . 127.0.0.1", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org\n}\n", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "example.org."}, ""},
+		{"forward . 127.0.0.1 {\nexcept miek.nl\n}\n", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1 {\nmax_fails 3\n}\n", false, ".", nil, 3, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1 {\nforce_tcp\n}\n", false, ".", nil, 2, options{forceTCP: true, hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1 {\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1 {\nforce_tcp\nprefer_udp\n}\n", false, ".", nil, 2, options{preferUDP: true, forceTCP: true, hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1:53", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1:8080", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . [::1]:53", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . [2003::1]:53", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward . 127.0.0.1 \n", false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{"forward 10.9.3.0/18 127.0.0.1", false, "0.9.10.in-addr.arpa.", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, ""},
+		{`forward . ::1
+		forward com ::2`, false, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, "plugin"},
+		// negative
+		{"forward . a27.0.0.1", true, "", nil, 0, options{hcRecursionDesired: true, hcDomain: "."}, "not an IP"},
+		{"forward . 127.0.0.1 {\nblaatl\n}\n", true, "", nil, 0, options{hcRecursionDesired: true, hcDomain: "."}, "unknown property"},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, "", nil, 0, options{hcRecursionDesired: true, hcDomain: "."}, "Wrong argument count or unexpected line ending after 'domain'"},
+		{"forward . https://127.0.0.1 \n", true, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, "'https' is not supported as a destination protocol in forward: https://127.0.0.1"},
+		{"forward xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx 127.0.0.1 \n", true, ".", nil, 2, options{hcRecursionDesired: true, hcDomain: "."}, "unable to normalize 'xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx'"},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		fs, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Fatalf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if !test.shouldErr {
+			f := fs[0]
+			if f.from != test.expectedFrom {
+				t.Errorf("Test %d: expected: %s, got: %s", i, test.expectedFrom, f.from)
+			}
+			if test.expectedIgnored != nil {
+				if !reflect.DeepEqual(f.ignored, test.expectedIgnored) {
+					t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedIgnored, f.ignored)
+				}
+			}
+			if f.maxfails != test.expectedFails {
+				t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedFails, f.maxfails)
+			}
+			if f.opts != test.expectedOpts {
+				t.Errorf("Test %d: expected: %v, got: %v", i, test.expectedOpts, f.opts)
+			}
+		}
+	}
+}
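The `10.9.3.0/18` case above exercises the single-zone limitation that parseStanza warns about: a prefix off an octet boundary expands to many reverse zones, and only the first is kept. A small sketch of that expansion (assuming `plugin.Host(...).NormalizeExact()` behaves as in upstream CoreDNS):

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin"
)

func main() {
	// A /18 is not on an octet boundary, so it expands to several
	// reverse zones; parseStanza keeps only zones[0] and logs a warning.
	zones := plugin.Host("10.9.3.0/18").NormalizeExact()
	fmt.Println(len(zones)) // expected: 64 /24-sized reverse zones
	fmt.Println(zones[0])   // expected: "0.9.10.in-addr.arpa."
}
~~~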
+func TestSetupTLS(t *testing.T) {
+	tests := []struct {
+		input              string
+		shouldErr          bool
+		expectedServerName string
+		expectedErr        string
+	}{
+		// positive
+		{`forward . tls://127.0.0.1 {
+			tls_servername dns
+		}`, false, "dns", ""},
+		{`forward . 127.0.0.1 {
+			tls_servername dns
+		}`, false, "", ""},
+		{`forward . 127.0.0.1 {
+			tls
+		}`, false, "", ""},
+		{`forward . tls://127.0.0.1`, false, "", ""},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		fs, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if test.shouldErr || err != nil {
+			continue
+		}
+		f := fs[0]
+
+		if test.expectedServerName != "" && test.expectedServerName != f.tlsConfig.ServerName {
+			t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.tlsConfig.ServerName)
+		}
+
+		if test.expectedServerName != "" && test.expectedServerName != f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName {
+			t.Errorf("Test %d: expected: %q, actual: %q", i, test.expectedServerName, f.proxies[0].health.(*dnsHc).c.TLSConfig.ServerName)
+		}
+	}
+}
+
+func TestSetupResolvconf(t *testing.T) {
+	const resolv = "resolv.conf"
+	if err := os.WriteFile(resolv,
+		[]byte(`nameserver 10.10.255.252
+nameserver 10.10.255.253`), 0666); err != nil {
+		t.Fatalf("Failed to write resolv.conf file: %s", err)
+	}
+	defer os.Remove(resolv)
+
+	tests := []struct {
+		input         string
+		shouldErr     bool
+		expectedErr   string
+		expectedNames []string
+	}{
+		// pass
+		{`forward . ` + resolv, false, "", []string{"10.10.255.252:53", "10.10.255.253:53"}},
+		// fail
+		{`forward . /dev/null`, true, "no nameservers", nil},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		fs, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
+			continue
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if test.shouldErr {
+			continue
+		}
+
+		f := fs[0]
+		for j, n := range test.expectedNames {
+			addr := f.proxies[j].addr
+			if n != addr {
+				t.Errorf("Test %d, proxy %d: expected %q, got %q", i, j, n, addr)
+			}
+		}
+
+		for _, p := range f.proxies {
+			p.health.Check(p) // this will almost always err; we don't care, it just shouldn't crash
+		}
+	}
+}
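TestSetupResolvconf leans on `parse.HostPortOrFile` accepting either host[:port] arguments or a resolv.conf-style file path. A hedged sketch of that expansion (the path and the printed output are illustrative, not guaranteed):

~~~ go
package main

import (
	"fmt"

	"github.com/coredns/coredns/plugin/pkg/parse"
)

func main() {
	// Each argument is either host[:port] or a resolv.conf-style file;
	// bare IPs get the default port 53 appended.
	hosts, err := parse.HostPortOrFile("127.0.0.1", "/etc/resolv.conf")
	if err != nil {
		fmt.Println(err) // e.g. a "no nameservers" error for a file with none
		return
	}
	fmt.Println(hosts) // e.g. [127.0.0.1:53 10.10.255.252:53 10.10.255.253:53]
}
~~~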
+func TestSetupMaxConcurrent(t *testing.T) {
+	tests := []struct {
+		input       string
+		shouldErr   bool
+		expectedVal int64
+		expectedErr string
+	}{
+		// positive
+		{"forward . 127.0.0.1 {\nmax_concurrent 1000\n}\n", false, 1000, ""},
+		// negative
+		{"forward . 127.0.0.1 {\nmax_concurrent many\n}\n", true, 0, "invalid"},
+		{"forward . 127.0.0.1 {\nmax_concurrent -4\n}\n", true, 0, "negative"},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		fs, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if test.shouldErr {
+			continue
+		}
+		f := fs[0]
+		if f.maxConcurrent != test.expectedVal {
+			t.Errorf("Test %d: expected: %d, got: %d", i, test.expectedVal, f.maxConcurrent)
+		}
+	}
+}
+
+func TestSetupHealthCheck(t *testing.T) {
+	tests := []struct {
+		input          string
+		shouldErr      bool
+		expectedRecVal bool
+		expectedDomain string
+		expectedErr    string
+	}{
+		// positive
+		{"forward . 127.0.0.1\n", false, true, ".", ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s\n}\n", false, true, ".", ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s no_rec\n}\n", false, false, ".", ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s no_rec domain example.org\n}\n", false, false, "example.org.", ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org\n}\n", false, true, "example.org.", ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain .\n}\n", false, true, ".", ""},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain example.org.\n}\n", false, true, "example.org.", ""},
+		// negative
+		{"forward . 127.0.0.1 {\nhealth_check no_rec\n}\n", true, true, ".", "time: invalid duration"},
+		{"forward . 127.0.0.1 {\nhealth_check domain example.org\n}\n", true, true, "example.org", "time: invalid duration"},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s rec\n}\n", true, true, ".", "health_check: unknown option rec"},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain\n}\n", true, true, ".", "Wrong argument count or unexpected line ending after 'domain'"},
+		{"forward . 127.0.0.1 {\nhealth_check 0.5s domain example..org\n}\n", true, true, ".", "health_check: invalid domain name"},
+	}
+
+	for i, test := range tests {
+		c := caddy.NewTestController("dns", test.input)
+		fs, err := parseForward(c)
+
+		if test.shouldErr && err == nil {
+			t.Errorf("Test %d: expected error but found none for input %s", i, test.input)
+		}
+
+		if err != nil {
+			if !test.shouldErr {
+				t.Errorf("Test %d: expected no error but found one for input %s, got: %v", i, test.input, err)
+			}
+			if !strings.Contains(err.Error(), test.expectedErr) {
+				t.Errorf("Test %d: expected error to contain: %v, found error: %v, input: %s", i, test.expectedErr, err, test.input)
+			}
+		}
+
+		if test.shouldErr {
+			continue
+		}
+
+		f := fs[0]
+		if f.opts.hcRecursionDesired != test.expectedRecVal || f.proxies[0].health.GetRecursionDesired() != test.expectedRecVal ||
+			f.opts.hcDomain != test.expectedDomain || f.proxies[0].health.GetDomain() != test.expectedDomain || !dns.IsFqdn(f.proxies[0].health.GetDomain()) {
+			t.Errorf("Test %d: expectedRec: %v, got: %v. expectedDomain: %s, got: %s.", i, test.expectedRecVal, f.opts.hcRecursionDesired, test.expectedDomain, f.opts.hcDomain)
+		}
+	}
+}
+
+func TestMultiForward(t *testing.T) {
+	input := `
+      forward 1st.example.org 10.0.0.1
+      forward 2nd.example.org 10.0.0.2
+      forward 3rd.example.org 10.0.0.3
+    `
+
+	c := caddy.NewTestController("dns", input)
+	if err := setup(c); err != nil {
+		t.Fatalf("Expected no error, got %v", err)
+	}
+	dnsserver.NewServer("", []*dnsserver.Config{dnsserver.GetConfig(c)})
+
+	handlers := dnsserver.GetConfig(c).Handlers()
+	f1, ok := handlers[0].(*Forward)
+	if !ok {
+		t.Fatalf("expected first plugin to be Forward, got %v", reflect.TypeOf(handlers[0]))
+	}
+
+	if f1.from != "1st.example.org." {
+		t.Errorf("expected first forward from \"1st.example.org.\", got %q", f1.from)
+	}
+	if f1.Next == nil {
+		t.Fatal("expected first forward to point to next forward instance, not nil")
+	}
+
+	f2, ok := f1.Next.(*Forward)
+	if !ok {
+		t.Fatalf("expected second plugin to be Forward, got %v", reflect.TypeOf(f1.Next))
+	}
+	if f2.from != "2nd.example.org." {
+		t.Errorf("expected second forward from \"2nd.example.org.\", got %q", f2.from)
+	}
+	if f2.Next == nil {
+		t.Fatal("expected second forward to point to third forward instance, got nil")
+	}
+
+	f3, ok := f2.Next.(*Forward)
+	if !ok {
+		t.Fatalf("expected third plugin to be Forward, got %v", reflect.TypeOf(f2.Next))
+	}
+	if f3.from != "3rd.example.org." {
+		t.Errorf("expected third forward from \"3rd.example.org.\", got %q", f3.from)
+	}
+	if f3.Next != nil {
+		t.Error("expected third plugin to be last, but Next is not nil")
+	}
+}
diff --git a/go/plugin/coredns/forward/type.go b/go/plugin/coredns/forward/type.go
new file mode 100644
index 000000000000..9de842fbeaf4
--- /dev/null
+++ b/go/plugin/coredns/forward/type.go
@@ -0,0 +1,37 @@
+package forward
+
+import "net"
+
+type transportType int
+
+const (
+	typeUDP transportType = iota
+	typeTCP
+	typeTLS
+	typeTotalCount // keep this last
+)
+
+func stringToTransportType(s string) transportType {
+	switch s {
+	case "udp":
+		return typeUDP
+	case "tcp":
+		return typeTCP
+	case "tcp-tls":
+		return typeTLS
+	}
+
+	return typeUDP
+}
+
+func (t *Transport) transportTypeFromConn(pc *persistConn) transportType {
+	if _, ok := pc.c.Conn.(*net.UDPConn); ok {
+		return typeUDP
+	}
+
+	if t.tlsConfig == nil {
+		return typeTCP
+	}
+
+	return typeTLS
+}