Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

reverse_proxy: Add support for SRV backends #3180

Merged
merged 2 commits into from Mar 24, 2020
Merged
Changes from all commits
Commits
File filter...
Filter file types
Jump to…
Jump to file
Failed to load files.

Always

Just for now

@@ -41,7 +41,7 @@ type (
// especially A/AAAA pointed at your server.
//
// Automatic HTTPS can be
// [customized or disabled](/docs/json/apps/http/servers/automatic_https/).
// [customized or disabled](/docs/modules/http#servers/automatic_https).
MatchHost []string

// MatchPath matches requests by the URI's path (case-insensitive). Path
@@ -177,13 +177,36 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return net.JoinHostPort(host, port), nil
}

// appendUpstream creates an upstream for address and adds
// it to the list. If the address starts with "srv+" it is
// treated as a SRV-based upstream, and any port will be
// dropped.
appendUpstream := func(address string) error {
isSRV := strings.HasPrefix(address, "srv+")
if isSRV {
address = strings.TrimPrefix(address, "srv+")
}
dialAddr, err := upstreamDialAddress(address)
if err != nil {
return err
}
if isSRV {
if host, _, err := net.SplitHostPort(dialAddr); err == nil {
dialAddr = host
}
h.Upstreams = append(h.Upstreams, &Upstream{LookupSRV: dialAddr})
} else {
h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
}
return nil
}

for d.Next() {
for _, up := range d.RemainingArgs() {
dialAddr, err := upstreamDialAddress(up)
err := appendUpstream(up)
if err != nil {
return err
}
h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
}

for d.NextBlock(0) {
@@ -194,11 +217,10 @@ func (h *Handler) UnmarshalCaddyfile(d *caddyfile.Dispenser) error {
return d.ArgErr()
}
for _, up := range args {
dialAddr, err := upstreamDialAddress(up)
err := appendUpstream(up)
if err != nil {
return err
}
h.Upstreams = append(h.Upstreams, &Upstream{Dial: dialAddr})
}

case "lb_policy":
@@ -17,6 +17,8 @@ package reverseproxy
import (
"context"
"fmt"
"net"
"net/http"
"strconv"
"sync/atomic"

@@ -63,10 +65,10 @@ type UpstreamPool []*Upstream
type Upstream struct {
Host `json:"-"`

// The [network address](/docs/json/apps/http/#servers/listen)
// The [network address](/docs/conventions#network-addresses)
// to dial to connect to the upstream. Must represent precisely
// one socket (i.e. no port ranges). A valid network address
// either has a host and port, or is a unix socket address.
// either has a host and port or is a unix socket address.
//
// Placeholders may be used to make the upstream dynamic, but be
// aware of the health check implications of this: a single
@@ -75,6 +77,11 @@ type Upstream struct {
// backends is down. Also be aware of open proxy vulnerabilities.
Dial string `json:"dial,omitempty"`

// If DNS SRV records are used for service discovery with this
// upstream, specify the DNS name for which to look up SRV
// records here, instead of specifying a dial address.
LookupSRV string `json:"lookup_srv,omitempty"`

// The maximum number of simultaneous requests to allow to
// this upstream. If set, overrides the global passive health
// check UnhealthyRequestCount value.
@@ -118,6 +125,47 @@ func (u *Upstream) Full() bool {
return u.MaxRequests > 0 && u.Host.NumRequests() >= u.MaxRequests
}

// fillDialInfo returns a filled DialInfo for upstream u, using the request
// context. If the upstream has a SRV lookup configured, that is done and a
// returned address is chosen; otherwise, the upstream's regular dial address
// field is used. Note that the returned value is not a pointer.
func (u *Upstream) fillDialInfo(r *http.Request) (DialInfo, error) {
repl := r.Context().Value(caddy.ReplacerCtxKey).(*caddy.Replacer)
var addr caddy.ParsedAddress

if u.LookupSRV != "" {
// perform DNS lookup for SRV records and choose one
srvName := repl.ReplaceAll(u.LookupSRV, "")
_, records, err := net.DefaultResolver.LookupSRV(r.Context(), "", "", srvName)
if err != nil {
return DialInfo{}, err
}
addr.Network = "tcp"
addr.Host = records[0].Target
addr.StartPort, addr.EndPort = uint(records[0].Port), uint(records[0].Port)
} else {
// use provided dial address
var err error
dial := repl.ReplaceAll(u.Dial, "")
addr, err = caddy.ParseNetworkAddress(dial)
if err != nil {
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", u.Dial, dial, err)
}
if numPorts := addr.PortRangeSize(); numPorts != 1 {
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
u.Dial, dial, numPorts)
}
}

return DialInfo{
Upstream: u,
Network: addr.Network,
Address: addr.JoinHostPort(0),
Host: addr.Host,
Port: strconv.Itoa(int(addr.StartPort)),
}, nil
}

// upstreamHost is the basic, in-memory representation
// of the state of a remote host. It implements the
// Host interface.
@@ -204,27 +252,6 @@ func (di DialInfo) String() string {
return caddy.JoinNetworkAddress(di.Network, di.Host, di.Port)
}

// fillDialInfo returns a filled DialInfo for the given upstream, using
// the given Replacer. Note that the returned value is not a pointer.
func fillDialInfo(upstream *Upstream, repl *caddy.Replacer) (DialInfo, error) {
dial := repl.ReplaceAll(upstream.Dial, "")
addr, err := caddy.ParseNetworkAddress(dial)
if err != nil {
return DialInfo{}, fmt.Errorf("upstream %s: invalid dial address %s: %v", upstream.Dial, dial, err)
}
if numPorts := addr.PortRangeSize(); numPorts != 1 {
return DialInfo{}, fmt.Errorf("upstream %s: dial address must represent precisely one socket: %s represents %d",
upstream.Dial, dial, numPorts)
}
return DialInfo{
Upstream: upstream,
Network: addr.Network,
Address: addr.JoinHostPort(0),
Host: addr.Host,
Port: strconv.Itoa(int(addr.StartPort)),
}, nil
}

// GetDialInfo gets the upstream dialing info out of the context,
// and returns true if there was a valid value; false otherwise.
func GetDialInfo(ctx context.Context) (DialInfo, bool) {
@@ -313,7 +313,7 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request, next caddyht
// the dial address may vary per-request if placeholders are
// used, so perform those replacements here; the resulting
// DialInfo struct should have valid network address syntax
dialInfo, err := fillDialInfo(upstream, repl)
dialInfo, err := upstream.fillDialInfo(r)
if err != nil {
return fmt.Errorf("making dial info: %v", err)
}
ProTip! Use n and p to navigate between commits in a pull request.
You can’t perform that action at this time.