Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow non root resources #454

Merged
merged 10 commits into from
May 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions go/grpcweb/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@ import (
"fmt"
"net/http"
"net/url"
"regexp"
"strings"

"google.golang.org/grpc"
)

var pathMatcher = regexp.MustCompile(`/[^/]*/[^/]*$`)

// ListGRPCResources is a helper function that lists all URLs that are registered on gRPC server.
//
// This makes it easy to register all the relevant routes in your HTTP router of choice.
Expand All @@ -35,3 +39,12 @@ func WebsocketRequestOrigin(req *http.Request) (string, error) {
}
return parsed.Host, nil
}

func getGRPCEndpoint(req *http.Request) string {
endpoint := pathMatcher.FindString(strings.TrimRight(req.URL.Path, "/"))
if len(endpoint) == 0 {
return req.URL.Path
}

return endpoint
}
29 changes: 29 additions & 0 deletions go/grpcweb/helpers_internal_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package grpcweb

import (
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
)

func TestGetGRPCEndpoint(t *testing.T) {
cases := []struct {
input string
output string
}{
{input: "/", output: "/"},
{input: "/resource", output: "/resource"},
{input: "/improbable.grpcweb.test.TestService/PingEmpty", output: "/improbable.grpcweb.test.TestService/PingEmpty"},
{input: "/improbable.grpcweb.test.TestService/PingEmpty/", output: "/improbable.grpcweb.test.TestService/PingEmpty"},
{input: "/a/b/c/improbable.grpcweb.test.TestService/PingEmpty", output: "/improbable.grpcweb.test.TestService/PingEmpty"},
{input: "/a/b/c/improbable.grpcweb.test.TestService/PingEmpty/", output: "/improbable.grpcweb.test.TestService/PingEmpty"},
}

for _, c := range cases {
req := httptest.NewRequest("GET", c.input, nil)
result := getGRPCEndpoint(req)

assert.Equal(t, c.output, result)
}
}
15 changes: 15 additions & 0 deletions go/grpcweb/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ var (
allowedRequestHeaders: []string{"*"},
corsForRegisteredEndpointsOnly: true,
originFunc: func(origin string) bool { return false },
allowNonRootResources: false,
}
)

Expand All @@ -19,6 +20,7 @@ type options struct {
originFunc func(origin string) bool
enableWebsockets bool
websocketOriginFunc func(req *http.Request) bool
allowNonRootResources bool
}

func evaluateOptions(opts []Option) *options {
Expand Down Expand Up @@ -99,3 +101,16 @@ func WithWebsocketOriginFunc(websocketOriginFunc func(req *http.Request) bool) O
o.websocketOriginFunc = websocketOriginFunc
}
}

// WithAllowNonRootResource enables the gRPC wrapper to serve requests that have a path prefix
// added to the URL, before the service name and method placeholders.
//
// This should be set to false when exposing the endpoint as the root resource, to avoid
// the performance cost of path processing for every request.
//
// The default behaviour is `false`, i.e. always serves requests assuming there is no prefix to the gRPC endpoint.
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
func WithAllowNonRootResource(allowNonRootResources bool) Option {
johanbrandhorst marked this conversation as resolved.
Show resolved Hide resolved
return func(o *options) {
o.allowNonRootResources = allowNonRootResources
}
}
15 changes: 14 additions & 1 deletion go/grpcweb/wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ type WrappedGrpcServer struct {
originFunc func(origin string) bool
enableWebsockets bool
websocketOriginFunc func(req *http.Request) bool
endpointFunc func(req *http.Request) string
}

// WrapServer takes a gRPC Server in Go and returns a WrappedGrpcServer that provides gRPC-Web Compatibility.
Expand All @@ -56,13 +57,23 @@ func WrapServer(server *grpc.Server, options ...Option) *WrappedGrpcServer {
if websocketOriginFunc == nil {
websocketOriginFunc = defaultWebsocketOriginFunc
}

endpointFunc := func(req *http.Request) string {
return req.URL.Path
}

if opts.allowNonRootResources {
endpointFunc = getGRPCEndpoint
}

return &WrappedGrpcServer{
server: server,
opts: opts,
corsWrapper: corsWrapper,
originFunc: opts.originFunc,
enableWebsockets: opts.enableWebsockets,
websocketOriginFunc: websocketOriginFunc,
endpointFunc: endpointFunc,
}
}

Expand Down Expand Up @@ -105,6 +116,7 @@ func (w *WrappedGrpcServer) IsGrpcWebSocketRequest(req *http.Request) bool {
func (w *WrappedGrpcServer) HandleGrpcWebRequest(resp http.ResponseWriter, req *http.Request) {
intReq, isTextFormat := hackIntoNormalGrpcRequest(req)
intResp := newGrpcWebResponse(resp, isTextFormat)
req.URL.Path = w.endpointFunc(req)
w.server.ServeHTTP(intResp, intReq)
intResp.finishRequest(req)
}
Expand Down Expand Up @@ -161,6 +173,7 @@ func (w *WrappedGrpcServer) handleWebSocket(wsConn *websocket.Conn, req *http.Re
grpclog.Errorf("web socket text format requests not yet supported")
return
}
req.URL.Path = w.endpointFunc(req)
w.server.ServeHTTP(respWriter, interceptedRequest)
}

Expand All @@ -187,7 +200,7 @@ func (w *WrappedGrpcServer) IsAcceptableGrpcCorsRequest(req *http.Request) bool

func (w *WrappedGrpcServer) isRequestForRegisteredEndpoint(req *http.Request) bool {
registeredEndpoints := ListGRPCResources(w.server)
requestedEndpoint := req.URL.Path
requestedEndpoint := w.endpointFunc(req)
for _, v := range registeredEndpoints {
if v == requestedEndpoint {
return true
Expand Down
21 changes: 21 additions & 0 deletions go/grpcweb/wrapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"log"
"net"
"net/http"
"net/http/httptest"
"net/textproto"
"os"
"strconv"
Expand Down Expand Up @@ -67,6 +68,26 @@ func TestHttp1GrpcWebWrapperTestSuite(t *testing.T) {
suite.Run(t, &GrpcWebWrapperTestSuite{httpMajorVersion: 1})
}

func TestNonRootResource(t *testing.T) {
grpcServer := grpc.NewServer()
testproto.RegisterTestServiceServer(grpcServer, &testServiceImpl{})
wrappedServer := grpcweb.WrapServer(grpcServer,
grpcweb.WithAllowNonRootResource(true),
grpcweb.WithOriginFunc(func(origin string) bool {
return true
}))

headers := http.Header{}
headers.Add("Access-Control-Request-Method", "POST")
headers.Add("Access-Control-Request-Headers", "origin, x-something-custom, x-grpc-web, accept")
req := httptest.NewRequest("OPTIONS", "http://host/grpc/improbable.grpcweb.test.TestService/Echo", nil)
req.Header = headers
resp := httptest.NewRecorder()
wrappedServer.ServeHTTP(resp, req)

assert.Equal(t, http.StatusOK, resp.Code)
}

func (s *GrpcWebWrapperTestSuite) SetupTest() {
var err error
s.grpcServer = grpc.NewServer()
Expand Down