From 67feb5a937f0c57f993f774ccf9deb0b48f7e602 Mon Sep 17 00:00:00 2001 From: Teemu Koponen Date: Wed, 3 Apr 2024 15:54:18 -0700 Subject: [PATCH] server: Remove unnecessary AST-to-JSON conversions. This time for v0QueryPath, v1DataGet, and v1DataPost. Signed-off-by: Teemu Koponen --- server/server.go | 80 +++++++++++++++---------------------------- server/server_test.go | 12 +++++-- 2 files changed, 38 insertions(+), 54 deletions(-) diff --git a/server/server.go b/server/server.go index b968815f7c..436d89662d 100644 --- a/server/server.go +++ b/server/server.go @@ -1028,22 +1028,12 @@ func (s *Server) v0QueryPath(w http.ResponseWriter, r *http.Request, urlPath str ctx := logging.WithDecisionID(r.Context(), decisionID) annotateSpan(ctx, decisionID) - input, err := readInputV0(r) + input, goInput, err := readInputV0(r) if err != nil { writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, fmt.Errorf("unexpected parse error for input: %w", err)) return } - var goInput *interface{} - if input != nil { - x, err := ast.JSON(input) - if err != nil { - writer.ErrorString(w, http.StatusInternalServerError, types.CodeInvalidParameter, fmt.Errorf("could not marshal input: %w", err)) - return - } - goInput = &x - } - // Prepare for query. txn, err := s.store.NewTransaction(ctx) if err != nil { @@ -1446,26 +1436,17 @@ func (s *Server) v1DataGet(w http.ResponseWriter, r *http.Request) { inputs := r.URL.Query()[types.ParamInputV1] var input ast.Value + var goInput *interface{} if len(inputs) > 0 { var err error - input, err = readInputGetV1(inputs[len(inputs)-1]) + input, goInput, err = readInputGetV1(inputs[len(inputs)-1]) if err != nil { writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) return } } - var goInput *interface{} - if input != nil { - x, err := ast.JSON(input) - if err != nil { - writer.ErrorString(w, http.StatusInternalServerError, types.CodeInvalidParameter, fmt.Errorf("could not marshal input: %w", err)) - return - } - goInput = &x - } - m.Timer(metrics.RegoInputParse).Stop() // Prepare for query. @@ -1678,22 +1659,12 @@ func (s *Server) v1DataPost(w http.ResponseWriter, r *http.Request) { m.Timer(metrics.RegoInputParse).Start() - input, err := readInputPostV1(r) + input, goInput, err := readInputPostV1(r) if err != nil { writer.ErrorString(w, http.StatusBadRequest, types.CodeInvalidParameter, err) return } - var goInput *interface{} - if input != nil { - x, err := ast.JSON(input) - if err != nil { - writer.ErrorString(w, http.StatusInternalServerError, types.CodeInvalidParameter, fmt.Errorf("could not marshal input: %w", err)) - return - } - goInput = &x - } - m.Timer(metrics.RegoInputParse).Stop() txn, err := s.store.NewTransaction(ctx, storage.TransactionParams{Context: storage.NewContext().WithMetrics(m)}) @@ -2752,17 +2723,18 @@ func getExplain(p []string, zero types.ExplainModeV1) types.ExplainModeV1 { return zero } -func readInputV0(r *http.Request) (ast.Value, error) { +func readInputV0(r *http.Request) (ast.Value, *interface{}, error) { parsed, ok := authorizer.GetBodyOnContext(r.Context()) if ok { - return ast.InterfaceToValue(parsed) + v, err := ast.InterfaceToValue(parsed) + return v, &parsed, err } // decompress the input if sent as zip body, err := readPlainBody(r) if err != nil { - return nil, fmt.Errorf("could not decompress the body: %w", err) + return nil, nil, fmt.Errorf("could not decompress the body: %w", err) } var x interface{} @@ -2770,41 +2742,44 @@ func readInputV0(r *http.Request) (ast.Value, error) { if strings.Contains(r.Header.Get("Content-Type"), "yaml") { bs, err := io.ReadAll(body) if err != nil { - return nil, err + return nil, nil, err } if len(bs) > 0 { if err = util.Unmarshal(bs, &x); err != nil { - return nil, fmt.Errorf("body contains malformed input document: %w", err) + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) } } } else { dec := util.NewJSONDecoder(body) if err := dec.Decode(&x); err != nil && err != io.EOF { - return nil, fmt.Errorf("body contains malformed input document: %w", err) + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) } } - return ast.InterfaceToValue(x) + v, err := ast.InterfaceToValue(x) + return v, &x, err } -func readInputGetV1(str string) (ast.Value, error) { +func readInputGetV1(str string) (ast.Value, *interface{}, error) { var input interface{} if err := util.UnmarshalJSON([]byte(str), &input); err != nil { - return nil, fmt.Errorf("parameter contains malformed input document: %w", err) + return nil, nil, fmt.Errorf("parameter contains malformed input document: %w", err) } - return ast.InterfaceToValue(input) + v, err := ast.InterfaceToValue(input) + return v, &input, err } -func readInputPostV1(r *http.Request) (ast.Value, error) { +func readInputPostV1(r *http.Request) (ast.Value, *interface{}, error) { parsed, ok := authorizer.GetBodyOnContext(r.Context()) if ok { if obj, ok := parsed.(map[string]interface{}); ok { if input, ok := obj["input"]; ok { - return ast.InterfaceToValue(input) + v, err := ast.InterfaceToValue(input) + return v, &input, err } } - return nil, nil + return nil, nil, nil } var request types.DataRequestV1 @@ -2812,7 +2787,7 @@ func readInputPostV1(r *http.Request) (ast.Value, error) { // decompress the input if sent as zip body, err := readPlainBody(r) if err != nil { - return nil, fmt.Errorf("could not decompress the body: %w", err) + return nil, nil, fmt.Errorf("could not decompress the body: %w", err) } ct := r.Header.Get("Content-Type") @@ -2821,25 +2796,26 @@ func readInputPostV1(r *http.Request) (ast.Value, error) { if strings.Contains(ct, "yaml") { bs, err := io.ReadAll(body) if err != nil { - return nil, err + return nil, nil, err } if len(bs) > 0 { if err = util.Unmarshal(bs, &request); err != nil { - return nil, fmt.Errorf("body contains malformed input document: %w", err) + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) } } } else { dec := util.NewJSONDecoder(body) if err := dec.Decode(&request); err != nil && err != io.EOF { - return nil, fmt.Errorf("body contains malformed input document: %w", err) + return nil, nil, fmt.Errorf("body contains malformed input document: %w", err) } } if request.Input == nil { - return nil, nil + return nil, nil, nil } - return ast.InterfaceToValue(*request.Input) + v, err := ast.InterfaceToValue(*request.Input) + return v, request.Input, err } type compileRequest struct { diff --git a/server/server_test.go b/server/server_test.go index 9a827af79b..9d7b48efb4 100644 --- a/server/server_test.go +++ b/server/server_test.go @@ -4119,7 +4119,7 @@ func TestServerUsesAuthorizerParsedBody(t *testing.T) { }) // Check that v1 reader function behaves correctly. - inp, err := readInputPostV1(req.WithContext(ctx)) + inp, goInp, err := readInputPostV1(req.WithContext(ctx)) if err != nil { t.Fatal(err) } @@ -4130,12 +4130,16 @@ func TestServerUsesAuthorizerParsedBody(t *testing.T) { t.Fatalf("expected %v but got %v", exp, inp) } + if exp.Value.Compare(ast.MustInterfaceToValue(*goInp)) != 0 { + t.Fatalf("expected %v but got %v", exp, *goInp) + } + // Check that v0 reader function behaves correctly. ctx = authorizer.SetBodyOnContext(req.Context(), map[string]interface{}{ "foo": "good", }) - inp, err = readInputV0(req.WithContext(ctx)) + inp, goInp, err = readInputV0(req.WithContext(ctx)) if err != nil { t.Fatal(err) } @@ -4143,6 +4147,10 @@ func TestServerUsesAuthorizerParsedBody(t *testing.T) { if exp.Value.Compare(inp) != 0 { t.Fatalf("expected %v but got %v", exp, inp) } + + if exp.Value.Compare(ast.MustInterfaceToValue(*goInp)) != 0 { + t.Fatalf("expected %v but got %v", exp, *goInp) + } } func TestServerReloadTrigger(t *testing.T) {