Skip to content
Permalink
Browse files Browse the repository at this point in the history
Merge pull request from GHSA-95x7-mh78-7w2r
* fix: add interceptors for streaming endpoints

* add tests

* remove unnecessary code
  • Loading branch information
miparnisari committed Oct 21, 2022
1 parent c8db1ee commit 779d73d
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 70 deletions.
10 changes: 8 additions & 2 deletions pkg/cmd/service/service.go
Expand Up @@ -359,11 +359,16 @@ func BuildService(config *Config, logger logger.Logger) (*service, error) {
return nil, errors.Errorf("failed to initialize authenticator: %v", err)
}

interceptors := []grpc.UnaryServerInterceptor{
unaryServerInterceptors := []grpc.UnaryServerInterceptor{
grpc_auth.UnaryServerInterceptor(middleware.AuthFunc(authenticator)),
middleware.NewErrorLoggingInterceptor(logger),
}

streamingServerInterceptors := []grpc.StreamServerInterceptor{
grpc_auth.StreamServerInterceptor(middleware.AuthFunc(authenticator)),
middleware.NewStreamingErrorLoggingInterceptor(logger),
}

grpcHostAddr, grpcHostPort, err := net.SplitHostPort(config.GRPC.Addr)
if err != nil {
return nil, errors.Errorf("`grpc.addr` config must be in the form [host]:port")
Expand Down Expand Up @@ -415,7 +420,8 @@ func BuildService(config *Config, logger logger.Logger) (*service, error) {
ChangelogHorizonOffset: config.ChangelogHorizonOffset,
ListObjectsDeadline: config.ListObjectsDeadline,
ListObjectsMaxResults: config.ListObjectsMaxResults,
UnaryInterceptors: interceptors,
UnaryInterceptors: unaryServerInterceptors,
StreamingInterceptors: streamingServerInterceptors,
MuxOptions: nil,
})
if err != nil {
Expand Down
186 changes: 121 additions & 65 deletions pkg/cmd/service/service_test.go
Expand Up @@ -7,6 +7,7 @@ import (
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/json"
"encoding/pem"
"fmt"
"io"
Expand All @@ -24,6 +25,7 @@ import (
"github.com/hashicorp/go-retryablehttp"
"github.com/openfga/openfga/pkg/logger"
"github.com/openfga/openfga/server/authn/mocks"
serverErrors "github.com/openfga/openfga/server/errors"
"github.com/stretchr/testify/require"
"github.com/tidwall/gjson"
openfgapb "go.buf.build/openfga/go/openfga/api/openfga/v1"
Expand All @@ -33,6 +35,7 @@ import (
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure"
healthv1pb "google.golang.org/grpc/health/grpc_health_v1"
"google.golang.org/protobuf/encoding/protojson"
)

func TestMain(m *testing.M) {
Expand Down Expand Up @@ -205,9 +208,10 @@ func createCertsAndKeys(t *testing.T) certHandle {
}

type authTest struct {
_name string
authHeader string
expectedError string
_name string
authHeader string
expectedErrorResponse *serverErrors.ErrorResponse
expectedStatusCode int
}

func TestBuildServiceWithNoAuth(t *testing.T) {
Expand Down Expand Up @@ -251,48 +255,39 @@ func TestBuildServiceWithPresharedKeyAuthentication(t *testing.T) {
ensureServiceUp(t, service, nil, true)

tests := []authTest{{
_name: "Header with incorrect key fails",
authHeader: "Bearer incorrectkey",
expectedError: "unauthenticated",
_name: "Header with incorrect key fails",
authHeader: "Bearer incorrectkey",
expectedErrorResponse: &serverErrors.ErrorResponse{
Code: "unauthenticated",
Message: "unauthenticated",
},
expectedStatusCode: 401,
}, {
_name: "Missing header fails",
authHeader: "",
expectedError: "missing bearer token",
_name: "Missing header fails",
authHeader: "",
expectedErrorResponse: &serverErrors.ErrorResponse{
Code: "bearer_token_missing",
Message: "missing bearer token",
},
expectedStatusCode: 401,
}, {
_name: "Correct key one succeeds",
authHeader: "Bearer KEYONE",
expectedError: "",
_name: "Correct key one succeeds",
authHeader: fmt.Sprintf("Bearer %s", config.Authn.AuthnPresharedKeyConfig.Keys[0]),
expectedStatusCode: 200,
}, {
_name: "Correct key two succeeds",
authHeader: "Bearer KEYTWO",
expectedError: "",
_name: "Correct key two succeeds",
authHeader: fmt.Sprintf("Bearer %s", config.Authn.AuthnPresharedKeyConfig.Keys[1]),
expectedStatusCode: 200,
}}

retryClient := retryablehttp.NewClient()
for _, test := range tests {
t.Run(test._name, func(t *testing.T) {
payload := strings.NewReader(`{"name": "some-store-name"}`)
req, err := retryablehttp.NewRequest("POST", fmt.Sprintf("http://localhost:%d/stores", service.GetHTTPAddrPort().Port()), payload)
require.NoError(t, err, "Failed to construct request")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", test.authHeader)

res, err := retryClient.Do(req)
require.NoError(t, err, "Failed to execute request")

defer res.Body.Close()
body, err := io.ReadAll(res.Body)
require.NoError(t, err, "Failed to read response")

stringBody := string(body)

if test.expectedError == "" && strings.Contains(stringBody, "code") {
t.Fatalf("Expected no error but got '%v'", stringBody)
}
tryGetStores(t, test, service, retryClient)
})

if !strings.Contains(stringBody, test.expectedError) && test.expectedError != "" {
t.Fatalf("Expected '%v' to contain '%v'", stringBody, test.expectedError)
}
t.Run(test._name+"/streaming", func(t *testing.T) {
tryStreamingListObjects(t, test, service, retryClient, config.Authn.AuthnPresharedKeyConfig.Keys[0])
})
}

Expand All @@ -301,6 +296,75 @@ func TestBuildServiceWithPresharedKeyAuthentication(t *testing.T) {
require.NoError(t, g.Wait())
}

func tryStreamingListObjects(t *testing.T, test authTest, service *service, retryClient *retryablehttp.Client, validToken string) {
// create a store
createStorePayload := strings.NewReader(`{"name": "some-store-name"}`)
req, err := retryablehttp.NewRequest("POST", fmt.Sprintf("http://localhost:%d/stores", service.GetHTTPAddrPort().Port()), createStorePayload)
require.NoError(t, err, "Failed to construct create store request")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", validToken))
res, err := retryClient.Do(req)
require.NoError(t, err, "Failed to execute create store request")
defer res.Body.Close()
body, err := io.ReadAll(res.Body)
require.NoError(t, err, "Failed to read create store response")
var createStoreResponse openfgapb.CreateStoreResponse
err = protojson.Unmarshal(body, &createStoreResponse)
require.NoError(t, err, "Failed to unmarshal create store response")

// create an authorization model
authModelPayload := strings.NewReader(`{"type_definitions":[{"type":"document","relations":{"owner":{"this":{}}}}]}`)
req, err = retryablehttp.NewRequest("POST", fmt.Sprintf("http://localhost:%d/stores/%s/authorization-models", service.GetHTTPAddrPort().Port(), createStoreResponse.Id), authModelPayload)
require.NoError(t, err, "Failed to construct create authorization model request")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", fmt.Sprintf("Bearer %s", validToken))
_, err = retryClient.Do(req)
require.NoError(t, err, "Failed to execute create authorization model request")

// call one streaming endpoint
listObjectsPayload := strings.NewReader(`{"type": "document", "user": "anne", "relation": "owner"}`)
req, err = retryablehttp.NewRequest("POST", fmt.Sprintf("http://localhost:%d/stores/%s/streamed-list-objects", service.GetHTTPAddrPort().Port(), createStoreResponse.Id), listObjectsPayload)
require.NoError(t, err, "Failed to construct request")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", test.authHeader)

res, err = retryClient.Do(req)
require.Equal(t, test.expectedStatusCode, res.StatusCode)
require.NoError(t, err, "Failed to execute streaming request")

defer res.Body.Close()
body, err = io.ReadAll(res.Body)
require.NoError(t, err, "Failed to read response")

if test.expectedErrorResponse != nil {
require.Contains(t, string(body), fmt.Sprintf(",\"message\":\"%s\"", test.expectedErrorResponse.Message))
}
}

func tryGetStores(t *testing.T, test authTest, service *service, retryClient *retryablehttp.Client) {
req, err := retryablehttp.NewRequest("GET", fmt.Sprintf("http://localhost:%d/stores", service.GetHTTPAddrPort().Port()), nil)
require.NoError(t, err, "Failed to construct request")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", test.authHeader)

res, err := retryClient.Do(req)
require.NoError(t, err, "Failed to execute request")
require.Equal(t, test.expectedStatusCode, res.StatusCode)

defer res.Body.Close()
body, err := io.ReadAll(res.Body)
require.NoError(t, err, "Failed to read response")

if test.expectedErrorResponse != nil {
var actualErrorResponse serverErrors.ErrorResponse
err = json.Unmarshal(body, &actualErrorResponse)

require.NoError(t, err, "Failed to unmarshal response")

require.Equal(t, test.expectedErrorResponse, &actualErrorResponse)
}
}

func TestHTTPServerWithCORS(t *testing.T) {
config, err := DefaultConfigWithRandomPorts()
require.NoError(t, err)
Expand Down Expand Up @@ -436,43 +500,35 @@ func TestBuildServerWithOIDCAuthentication(t *testing.T) {
require.NoError(t, err)

tests := []authTest{{
_name: "Header with invalid token fails",
authHeader: "Bearer incorrecttoken",
expectedError: "invalid bearer token",
_name: "Header with invalid token fails",
authHeader: "Bearer incorrecttoken",
expectedErrorResponse: &serverErrors.ErrorResponse{
Code: "auth_failed_invalid_bearer_token",
Message: "invalid bearer token",
},
expectedStatusCode: 401,
}, {
_name: "Missing header fails",
authHeader: "",
expectedError: "missing bearer token",
_name: "Missing header fails",
authHeader: "",
expectedErrorResponse: &serverErrors.ErrorResponse{
Code: "bearer_token_missing",
Message: "missing bearer token",
},
expectedStatusCode: 401,
}, {
_name: "Correct token succeeds",
authHeader: "Bearer " + trustedToken,
expectedError: "",
_name: "Correct token succeeds",
authHeader: "Bearer " + trustedToken,
expectedStatusCode: 200,
}}

retryClient := retryablehttp.NewClient()
for _, test := range tests {
t.Run(test._name, func(t *testing.T) {
payload := strings.NewReader(`{"name": "some-store-name"}`)
req, err := retryablehttp.NewRequest("POST", fmt.Sprintf("http://localhost:%d/stores", service.GetHTTPAddrPort().Port()), payload)
require.NoError(t, err, "Failed to construct request")
req.Header.Set("content-type", "application/json")
req.Header.Set("authorization", test.authHeader)

res, err := retryClient.Do(req)
require.NoError(t, err, "Failed to execute request")

defer res.Body.Close()
body, err := io.ReadAll(res.Body)
require.NoError(t, err, "Failed to read response")

stringBody := string(body)
if test.expectedError == "" && strings.Contains(stringBody, "code") {
t.Fatalf("Expected no error but got %v", stringBody)
}
tryGetStores(t, test, service, retryClient)
})

if !strings.Contains(stringBody, test.expectedError) && test.expectedError != "" {
t.Fatalf("Expected %v to contain %v", stringBody, test.expectedError)
}
t.Run(test._name+"/streaming", func(t *testing.T) {
tryStreamingListObjects(t, test, service, retryClient, trustedToken)
})
}

Expand Down
15 changes: 15 additions & 0 deletions server/middleware/logging.go
Expand Up @@ -25,3 +25,18 @@ func NewErrorLoggingInterceptor(logger logger.Logger) grpc.UnaryServerIntercepto
return resp, nil
}
}

func NewStreamingErrorLoggingInterceptor(logger logger.Logger) grpc.StreamServerInterceptor {
return func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
err := handler(srv, stream)
if err != nil {
var e error
if internalError, ok := err.(serverErrors.InternalError); ok {
e = internalError.Internal()
}
logger.Error("grpc_error", zap.Error(e), zap.String("public_error", err.Error()))
}

return err
}
}
13 changes: 10 additions & 3 deletions server/server.go
Expand Up @@ -81,6 +81,7 @@ type Config struct {
ListObjectsDeadline time.Duration
ListObjectsMaxResults uint32
UnaryInterceptors []grpc.UnaryServerInterceptor
StreamingInterceptors []grpc.StreamServerInterceptor
MuxOptions []runtime.ServeMuxOption
}

Expand Down Expand Up @@ -486,13 +487,19 @@ func (s *Server) IsReady(ctx context.Context) (bool, error) {
// server cancel the provided ctx.
func (s *Server) Run(ctx context.Context) error {

interceptors := []grpc.UnaryServerInterceptor{
unaryServerInterceptors := []grpc.UnaryServerInterceptor{
grpc_validator.UnaryServerInterceptor(),
}
interceptors = append(interceptors, s.config.UnaryInterceptors...)
unaryServerInterceptors = append(unaryServerInterceptors, s.config.UnaryInterceptors...)

streamingInterceptors := []grpc.StreamServerInterceptor{
grpc_validator.StreamServerInterceptor(),
}
streamingInterceptors = append(streamingInterceptors, s.config.StreamingInterceptors...)

opts := []grpc.ServerOption{
grpc.ChainUnaryInterceptor(interceptors...),
grpc.ChainUnaryInterceptor(unaryServerInterceptors...),
grpc.ChainStreamInterceptor(streamingInterceptors...),
}

if s.config.GRPCServer.TLSConfig != nil {
Expand Down

0 comments on commit 779d73d

Please sign in to comment.