Skip to content

Commit 0898d9e

Browse files
committed
Exposed proxy balancer interface
Signed-off-by: Vishal Rana <vr@labstack.com>
1 parent a8cd0ad commit 0898d9e

File tree

3 files changed

+87
-93
lines changed

3 files changed

+87
-93
lines changed

echo.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ const (
165165

166166
// Headers
167167
const (
168+
HeaderAccept = "Accept"
168169
HeaderAcceptEncoding = "Accept-Encoding"
169170
HeaderAllow = "Allow"
170171
HeaderAuthorization = "Authorization"

middleware/proxy.go

Lines changed: 66 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
11
package middleware
22

33
import (
4+
"errors"
5+
"fmt"
46
"io"
57
"math/rand"
8+
"net"
69
"net/http"
710
"net/http/httputil"
811
"net/url"
912
"sync/atomic"
1013
"time"
1114

1215
"github.com/labstack/echo"
13-
"golang.org/x/net/websocket"
1416
)
1517

1618
type (
@@ -19,91 +21,95 @@ type (
1921
// Skipper defines a function to skip middleware.
2022
Skipper Skipper
2123

22-
// Load balancing technique.
23-
// Optional. Default value "random".
24-
// Possible values:
25-
// - "random"
26-
// - "round-robin"
27-
Balance string `json:"balance"`
28-
29-
// Upstream target URLs
24+
// Balance defines a load balancing technique.
3025
// Required.
31-
Targets []*ProxyTarget `json:"targets"`
32-
33-
balancer proxyBalancer
26+
// Possible values:
27+
// - ProxyRandom
28+
// - ProxyRoundRobin
29+
Balancer ProxyBalancer
3430
}
3531

3632
// ProxyTarget defines the upstream target.
3733
ProxyTarget struct {
38-
Name string `json:"name,omitempty"`
39-
URL string `json:"url"`
40-
url *url.URL
34+
URL *url.URL
4135
}
4236

43-
proxyRandom struct {
44-
targets []*ProxyTarget
37+
RandomBalancer struct {
38+
Targets []*ProxyTarget
4539
random *rand.Rand
4640
}
4741

48-
proxyRoundRobin struct {
49-
targets []*ProxyTarget
50-
i int32
42+
RoundRobinBalancer struct {
43+
Targets []*ProxyTarget
44+
i uint32
5145
}
5246

53-
proxyBalancer interface {
47+
ProxyBalancer interface {
5448
Next() *ProxyTarget
55-
Length() int
5649
}
5750
)
5851

59-
func proxyHTTP(u *url.URL, c echo.Context) http.Handler {
60-
return httputil.NewSingleHostReverseProxy(u)
52+
func proxyHTTP(t *ProxyTarget) http.Handler {
53+
return httputil.NewSingleHostReverseProxy(t.URL)
6154
}
6255

63-
func proxyWS(u *url.URL, c echo.Context) http.Handler {
64-
return websocket.Handler(func(in *websocket.Conn) {
56+
func proxyRaw(t *ProxyTarget, c echo.Context) http.Handler {
57+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
58+
h, ok := w.(http.Hijacker)
59+
if !ok {
60+
c.Error(errors.New("proxy raw, not a hijacker"))
61+
return
62+
}
63+
64+
in, _, err := h.Hijack()
65+
if err != nil {
66+
c.Error(fmt.Errorf("proxy raw hijack error=%v, url=%s", r.URL, err))
67+
return
68+
}
6569
defer in.Close()
6670

67-
r := in.Request()
68-
t := "ws://" + u.Host + r.RequestURI
69-
out, err := websocket.Dial(t, "", r.Header.Get("Origin"))
71+
out, err := net.Dial("tcp", t.URL.Host)
7072
if err != nil {
71-
c.Logger().Errorf("ws proxy error, target=%s, err=%v", t, err)
73+
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw dial error=%v, url=%s", r.URL, err))
74+
c.Error(he)
7275
return
7376
}
7477
defer out.Close()
7578

79+
err = r.Write(out)
80+
if err != nil {
81+
he := echo.NewHTTPError(http.StatusBadGateway, fmt.Sprintf("proxy raw request copy error=%v, url=%s", r.URL, err))
82+
c.Error(he)
83+
return
84+
}
85+
7686
errc := make(chan error, 2)
77-
cp := func(w io.Writer, r io.Reader) {
78-
_, err := io.Copy(w, r)
87+
cp := func(dst io.Writer, src io.Reader) {
88+
_, err := io.Copy(dst, src)
7989
errc <- err
8090
}
8191

82-
go cp(in, out)
8392
go cp(out, in)
93+
go cp(in, out)
8494
err = <-errc
8595
if err != nil && err != io.EOF {
86-
c.Logger().Errorf("ws proxy error, url=%s, err=%v", r.URL, err)
96+
c.Logger().Errorf("proxy raw error=%v, url=%s", r.URL, err)
8797
}
8898
})
8999
}
90100

91-
func (r *proxyRandom) Next() *ProxyTarget {
92-
return r.targets[r.random.Intn(len(r.targets))]
93-
}
94-
95-
func (r *proxyRandom) Length() int {
96-
return len(r.targets)
97-
}
98-
99-
func (r *proxyRoundRobin) Next() *ProxyTarget {
100-
r.i = r.i % int32(len(r.targets))
101-
atomic.AddInt32(&r.i, 1)
102-
return r.targets[r.i]
101+
func (r *RandomBalancer) Next() *ProxyTarget {
102+
if r.random == nil {
103+
r.random = rand.New(rand.NewSource(int64(time.Now().Nanosecond())))
104+
}
105+
return r.Targets[r.random.Intn(len(r.Targets))]
103106
}
104107

105-
func (r *proxyRoundRobin) Length() int {
106-
return len(r.targets)
108+
func (r *RoundRobinBalancer) Next() *ProxyTarget {
109+
r.i = r.i % uint32(len(r.Targets))
110+
t := r.Targets[r.i]
111+
atomic.AddUint32(&r.i, 1)
112+
return t
107113
}
108114

109115
// Proxy returns an HTTP/WebSocket reverse proxy middleware.
@@ -112,49 +118,26 @@ func Proxy(config ProxyConfig) echo.MiddlewareFunc {
112118
if config.Skipper == nil {
113119
config.Skipper = DefaultLoggerConfig.Skipper
114120
}
115-
if config.Targets == nil || len(config.Targets) == 0 {
116-
panic("echo: proxy middleware requires targets")
117-
}
118-
119-
// Initialize
120-
for _, t := range config.Targets {
121-
u, err := url.Parse(t.URL)
122-
if err != nil {
123-
panic("echo: proxy target url parsing failed" + err.Error())
124-
}
125-
t.url = u
126-
}
127-
128-
// Balancer
129-
switch config.Balance {
130-
case "round-robin":
131-
config.balancer = &proxyRoundRobin{
132-
targets: config.Targets,
133-
i: -1,
134-
}
135-
default: // random
136-
config.balancer = &proxyRandom{
137-
targets: config.Targets,
138-
random: rand.New(rand.NewSource(int64(time.Now().Nanosecond()))),
139-
}
121+
if config.Balancer == nil {
122+
panic("echo: proxy middleware requires balancer")
140123
}
141124

142125
return func(next echo.HandlerFunc) echo.HandlerFunc {
143126
return func(c echo.Context) (err error) {
144127
req := c.Request()
145128
res := c.Response()
146-
t := config.balancer.Next().url
147-
148-
// Tell upstream that the incoming request is HTTPS
149-
if c.IsTLS() {
150-
req.Header.Set(echo.HeaderXForwardedProto, "https")
151-
}
129+
t := config.Balancer.Next()
152130

153131
// Proxy
154-
if req.Header.Get(echo.HeaderUpgrade) == "websocket" {
155-
proxyWS(t, c).ServeHTTP(res, req)
156-
} else {
157-
proxyHTTP(t, c).ServeHTTP(res, req)
132+
upgrade := req.Header.Get(echo.HeaderUpgrade)
133+
accept := req.Header.Get(echo.HeaderAccept)
134+
135+
switch {
136+
case upgrade == "websocket" || upgrade == "Websocket":
137+
proxyRaw(t, c).ServeHTTP(res, req)
138+
case accept == "text/event-stream":
139+
default:
140+
proxyHTTP(t).ServeHTTP(res, req)
158141
}
159142

160143
return

middleware/proxy_test.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"net/http/httptest"
77
"testing"
88

9+
"net/url"
10+
911
"github.com/labstack/echo"
1012
"github.com/stretchr/testify/assert"
1113
)
@@ -38,18 +40,24 @@ func TestProxy(t *testing.T) {
3840
fmt.Fprint(w, "target 1")
3941
}))
4042
defer t1.Close()
43+
url1, _ := url.Parse(t1.URL)
4144
t2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
4245
fmt.Fprint(w, "target 2")
4346
}))
4447
defer t2.Close()
48+
url2, _ := url.Parse(t2.URL)
49+
50+
targets := []*ProxyTarget{
51+
&ProxyTarget{
52+
URL: url1,
53+
},
54+
&ProxyTarget{
55+
URL: url2,
56+
},
57+
}
4558
config := ProxyConfig{
46-
Targets: []*ProxyTarget{
47-
&ProxyTarget{
48-
URL: t1.URL,
49-
},
50-
&ProxyTarget{
51-
URL: t2.URL,
52-
},
59+
Balancer: &RandomBalancer{
60+
Targets: targets,
5361
},
5462
}
5563

@@ -60,16 +68,18 @@ func TestProxy(t *testing.T) {
6068
rec := newCloseNotifyRecorder()
6169
e.ServeHTTP(rec, req)
6270
body := rec.Body.String()
63-
targets := map[string]bool{
71+
expected := map[string]bool{
6472
"target 1": true,
6573
"target 2": true,
6674
}
6775
assert.Condition(t, func() bool {
68-
return targets[body]
76+
return expected[body]
6977
})
7078

7179
// Round-robin
72-
config.Balance = "round-robin"
80+
config.Balancer = &RoundRobinBalancer{
81+
Targets: targets,
82+
}
7383
e = echo.New()
7484
e.Use(Proxy(config))
7585
rec = newCloseNotifyRecorder()

0 commit comments

Comments
 (0)