Skip to content

Commit

Permalink
Verify WS handshake success
Browse files Browse the repository at this point in the history
  • Loading branch information
magiconair committed Mar 22, 2018
1 parent 9e26812 commit 96f93e0
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
1 change: 0 additions & 1 deletion proxy/http_proxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ func (p *HTTPProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} else {
h = newRawProxy(targetURL.Host, net.Dial)
}
r.Header.Set("Connection", "close")

case accept == "text/event-stream":
// use the flush interval for SSE (server-sent events)
Expand Down
33 changes: 33 additions & 0 deletions proxy/http_raw_handler.go
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package proxy

import (
"bytes"
"fmt"
"io"
"log"
"net"
"net/http"
"time"

"github.com/fabiolb/fabio/metrics"
)
Expand Down Expand Up @@ -51,6 +54,36 @@ func newRawProxy(host string, dial dialFunc) http.Handler {
return
}

// read the initial response to check whether we get an HTTP/1.1 101 ... response
// to determine whether the handshake worked.
b := make([]byte, 1024)
if err := out.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
log.Printf("[ERROR] Error setting read timeout for %s: %s", r.URL, err)
http.Error(w, "error setting read timeout", http.StatusInternalServerError)
return
}

n, err := out.Read(b)
if err != nil {
log.Printf("[ERROR] Error reading response for %s: %s", r.URL, err)
http.Error(w, "error reading response", http.StatusInternalServerError)
return
}

b = b[:n]
if m, err := in.Write(b); err != nil || n != m {
log.Printf("[ERROR] Error sending header for %s: %s", r.URL, err)
http.Error(w, "error sending response", http.StatusInternalServerError)
return
}

if !bytes.HasPrefix(b, []byte("HTTP/1.1 101")) {
fmt.Println("boom")
log.Printf("[INFO] WS Upgrade failed for %s", r.URL)
http.Error(w, "error handling ws upgrade", http.StatusInternalServerError)
return
}

errc := make(chan error, 2)
cp := func(dst io.Writer, src io.Reader) {
_, err := io.Copy(dst, src)
Expand Down

0 comments on commit 96f93e0

Please sign in to comment.