Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SRV support for proxy upstream #1915

Merged
merged 3 commits into from Nov 6, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 2 additions & 1 deletion caddyhttp/proxy/proxy.go
Expand Up @@ -82,7 +82,8 @@ type UpstreamHost struct {
// This is an int32 so that we can use atomic operations to do concurrent
// reads & writes to this value. The default value of 0 indicates that it
// is healthy and any non-zero value indicates unhealthy.
Unhealthy int32
Unhealthy int32
HealthCheckResult atomic.Value
}

// Down checks whether the upstream host is down or not.
Expand Down
46 changes: 40 additions & 6 deletions caddyhttp/proxy/reverseproxy.go
Expand Up @@ -26,7 +26,9 @@
package proxy

import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -91,6 +93,8 @@ type ReverseProxy struct {
// response body.
// If zero, no periodic flushing is done.
FlushInterval time.Duration

srvResolver srvResolver
}

// Though the relevant directive prefix is just "unix:", url.Parse
Expand All @@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err
}
}

func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
service := locator
if strings.HasPrefix(locator, "srv://") {
service = locator[6:]
} else if strings.HasPrefix(locator, "srv+https://") {
service = locator[12:]
}

return func(network, addr string) (conn net.Conn, err error) {
_, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service)
if err != nil {
return nil, err
}
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
}
}

func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
Expand All @@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else if target.Scheme == "srv" {
req.URL.Scheme = "http"
req.URL.Host = target.Host
} else if target.Scheme == "srv+https" {
req.URL.Scheme = "https"
req.URL.Host = target.Host
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
Expand Down Expand Up @@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}
}

rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
rp := &ReverseProxy{
Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver,
}

if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
Expand All @@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
HandshakeTimeout: defaultCryptoHandshakeTimeout,
},
}
} else if keepalive != http.DefaultMaxIdleConnsPerHost {
// if keepalive is equal to the default,
// just use default transport, to avoid creating
// a brand new transport
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial
if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String())
}

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial,
Dial: dialFunc,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
ExpectContinueTimeout: 1 * time.Second,
}
Expand Down
94 changes: 94 additions & 0 deletions caddyhttp/proxy/reverseproxy_test.go
@@ -0,0 +1,94 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
)

const (
expectedResponse = "response from request proxied to upstream"
expectedStatus = http.StatusOK
)

var upstreamHost *httptest.Server

func setupTest() {
upstreamHost = httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/test-path" {
w.WriteHeader(expectedStatus)
w.Write([]byte(expectedResponse))
} else {
w.WriteHeader(404)
w.Write([]byte("Not found"))
}
}))
}

func tearDownTest() {
upstreamHost.Close()
}

func TestSingleSRVHostReverseProxy(t *testing.T) {
setupTest()
defer tearDownTest()

target, err := url.Parse("srv://test.upstream.service")
if err != nil {
t.Errorf("Failed to parse target URL. %s", err.Error())
}

upstream, err := url.Parse(upstreamHost.URL)
if err != nil {
t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error())
}
pp, err := strconv.Atoi(upstream.Port())
if err != nil {
t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error())
}
port := uint16(pp)

rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
rp.srvResolver = testResolver{
result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
},
}

resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "http://test.host/test-path", nil)
if err != nil {
t.Errorf("Failed to create new request. %s", err.Error())
}

err = rp.ServeHTTP(resp, req, nil)
if err != nil {
t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error())
}

if resp.Body.String() != expectedResponse {
t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String())
}

if resp.Code != expectedStatus {
t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code)
}
}