Skip to content

Commit

Permalink
proxy: Add SRV support for proxy upstream (#1915)
Browse files Browse the repository at this point in the history
* Simplify parseUpstream function

* Add SRV support for proxy upstream
  • Loading branch information
Mohammad Gufran authored and mholt committed Nov 6, 2017
1 parent 5cca9cc commit 63fd264
Show file tree
Hide file tree
Showing 5 changed files with 518 additions and 82 deletions.
3 changes: 2 additions & 1 deletion caddyhttp/proxy/proxy.go
Expand Up @@ -82,7 +82,8 @@ type UpstreamHost struct {
// This is an int32 so that we can use atomic operations to do concurrent
// reads & writes to this value. The default value of 0 indicates that it
// is healthy and any non-zero value indicates unhealthy.
Unhealthy int32
Unhealthy int32
HealthCheckResult atomic.Value
}

// Down checks whether the upstream host is down or not.
Expand Down
46 changes: 40 additions & 6 deletions caddyhttp/proxy/reverseproxy.go
Expand Up @@ -26,7 +26,9 @@
package proxy

import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/http"
Expand Down Expand Up @@ -91,6 +93,8 @@ type ReverseProxy struct {
// response body.
// If zero, no periodic flushing is done.
FlushInterval time.Duration

srvResolver srvResolver
}

// Though the relevant directive prefix is just "unix:", url.Parse
Expand All @@ -105,6 +109,23 @@ func socketDial(hostName string) func(network, addr string) (conn net.Conn, err
}
}

func (rp *ReverseProxy) srvDialerFunc(locator string) func(network, addr string) (conn net.Conn, err error) {
service := locator
if strings.HasPrefix(locator, "srv://") {
service = locator[6:]
} else if strings.HasPrefix(locator, "srv+https://") {
service = locator[12:]
}

return func(network, addr string) (conn net.Conn, err error) {
_, addrs, err := rp.srvResolver.LookupSRV(context.Background(), "", "", service)
if err != nil {
return nil, err
}
return net.Dial("tcp", fmt.Sprintf("%s:%d", addrs[0].Target, addrs[0].Port))
}
}

func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
Expand All @@ -131,6 +152,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
// scheme and host have to be faked
req.URL.Scheme = "http"
req.URL.Host = "socket"
} else if target.Scheme == "srv" {
req.URL.Scheme = "http"
req.URL.Host = target.Host
} else if target.Scheme == "srv+https" {
req.URL.Scheme = "https"
req.URL.Host = target.Host
} else {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
Expand Down Expand Up @@ -199,7 +226,12 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
}
}

rp := &ReverseProxy{Director: director, FlushInterval: 250 * time.Millisecond} // flushing good for streaming & server-sent events
rp := &ReverseProxy{
Director: director,
FlushInterval: 250 * time.Millisecond, // flushing good for streaming & server-sent events
srvResolver: net.DefaultResolver,
}

if target.Scheme == "unix" {
rp.Transport = &http.Transport{
Dial: socketDial(target.String()),
Expand All @@ -210,13 +242,15 @@ func NewSingleHostReverseProxy(target *url.URL, without string, keepalive int) *
HandshakeTimeout: defaultCryptoHandshakeTimeout,
},
}
} else if keepalive != http.DefaultMaxIdleConnsPerHost {
// if keepalive is equal to the default,
// just use default transport, to avoid creating
// a brand new transport
} else if keepalive != http.DefaultMaxIdleConnsPerHost || strings.HasPrefix(target.Scheme, "srv") {
dialFunc := defaultDialer.Dial
if strings.HasPrefix(target.Scheme, "srv") {
dialFunc = rp.srvDialerFunc(target.String())
}

transport := &http.Transport{
Proxy: http.ProxyFromEnvironment,
Dial: defaultDialer.Dial,
Dial: dialFunc,
TLSHandshakeTimeout: defaultCryptoHandshakeTimeout,
ExpectContinueTimeout: 1 * time.Second,
}
Expand Down
94 changes: 94 additions & 0 deletions caddyhttp/proxy/reverseproxy_test.go
@@ -0,0 +1,94 @@
// Copyright 2015 Light Code Labs, LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package proxy

import (
"net"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"testing"
)

const (
expectedResponse = "response from request proxied to upstream"
expectedStatus = http.StatusOK
)

var upstreamHost *httptest.Server

func setupTest() {
upstreamHost = httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
if r.URL.Path == "/test-path" {
w.WriteHeader(expectedStatus)
w.Write([]byte(expectedResponse))
} else {
w.WriteHeader(404)
w.Write([]byte("Not found"))
}
}))
}

func tearDownTest() {
upstreamHost.Close()
}

func TestSingleSRVHostReverseProxy(t *testing.T) {
setupTest()
defer tearDownTest()

target, err := url.Parse("srv://test.upstream.service")
if err != nil {
t.Errorf("Failed to parse target URL. %s", err.Error())
}

upstream, err := url.Parse(upstreamHost.URL)
if err != nil {
t.Errorf("Failed to parse test server URL [%s]. %s", upstreamHost.URL, err.Error())
}
pp, err := strconv.Atoi(upstream.Port())
if err != nil {
t.Errorf("Failed to parse upstream server port [%s]. %s", upstream.Port(), err.Error())
}
port := uint16(pp)

rp := NewSingleHostReverseProxy(target, "", http.DefaultMaxIdleConnsPerHost)
rp.srvResolver = testResolver{
result: []*net.SRV{
{Target: upstream.Hostname(), Port: port, Priority: 1, Weight: 1},
},
}

resp := httptest.NewRecorder()
req, err := http.NewRequest("GET", "http://test.host/test-path", nil)
if err != nil {
t.Errorf("Failed to create new request. %s", err.Error())
}

err = rp.ServeHTTP(resp, req, nil)
if err != nil {
t.Errorf("Failed to perform reverse proxy to upstream host. %s", err.Error())
}

if resp.Body.String() != expectedResponse {
t.Errorf("Unexpected proxy response received. Expected: '%s', Got: '%s'", expectedResponse, resp.Body.String())
}

if resp.Code != expectedStatus {
t.Errorf("Unexpected proxy status. Expected: '%d', Got: '%d'", expectedStatus, resp.Code)
}
}

0 comments on commit 63fd264

Please sign in to comment.