From cf1467ac0c084139ad91379cb58323772854ef29 Mon Sep 17 00:00:00 2001 From: Daniel Nephin Date: Wed, 9 Nov 2022 11:04:14 -0500 Subject: [PATCH] Fix update destination Start with the existing data so that we don't lose values for fields that are not updated. We have 3 non-updated fields now: CreatedAt, LastSeenAt, and Kind. Also remove the extra logic in the data layer, we no longer need this now that we're handling it properly in the API. This allows the update operation for destination to look like all the other update operations. Also add tests for the API endpoint. --- internal/access/destination.go | 2 +- internal/server/data/destination.go | 19 +-- internal/server/data/destination_test.go | 5 +- internal/server/destination_test.go | 168 +++++++++++++++++++++++ internal/server/destinations.go | 26 ++-- internal/server/middleware.go | 2 +- 6 files changed, 186 insertions(+), 36 deletions(-) diff --git a/internal/access/destination.go b/internal/access/destination.go index 29283534e0..4d09d8ac76 100644 --- a/internal/access/destination.go +++ b/internal/access/destination.go @@ -18,7 +18,7 @@ func CreateDestination(c *gin.Context, destination *models.Destination) error { return data.CreateDestination(db, destination) } -func SaveDestination(rCtx RequestContext, destination *models.Destination) error { +func UpdateDestination(rCtx RequestContext, destination *models.Destination) error { roles := []string{models.InfraAdminRole, models.InfraConnectorRole} if err := IsAuthorized(rCtx, roles...); err != nil { return HandleAuthErr(err, "destination", "update", roles...) diff --git a/internal/server/data/destination.go b/internal/server/data/destination.go index 0941cfe912..5430d26824 100644 --- a/internal/server/data/destination.go +++ b/internal/server/data/destination.go @@ -27,23 +27,6 @@ func (d *destinationsTable) ScanFields() []any { return []any{&d.ConnectionCA, &d.ConnectionURL, &d.CreatedAt, &d.DeletedAt, &d.ID, &d.Kind, &d.LastSeenAt, &d.Name, &d.OrganizationID, &d.Resources, &d.Roles, &d.UniqueID, &d.UpdatedAt, &d.Version} } -// destinationsUpdateTable is used to update the destination. It excludes -// the CreatedAt field, because that field is not part of the input to -// UpdateDestination. -type destinationsUpdateTable models.Destination - -func (d destinationsUpdateTable) Table() string { - return "destinations" -} - -func (d destinationsUpdateTable) Columns() []string { - return []string{"connection_ca", "connection_url", "deleted_at", "id", "last_seen_at", "name", "organization_id", "resources", "roles", "unique_id", "updated_at", "version"} -} - -func (d destinationsUpdateTable) Values() []any { - return []any{d.ConnectionCA, d.ConnectionURL, d.DeletedAt, d.ID, d.LastSeenAt, d.Name, d.OrganizationID, d.Resources, d.Roles, d.UniqueID, d.UpdatedAt, d.Version} -} - func validateDestination(dest *models.Destination) error { if dest.Name == "" { return fmt.Errorf("Destination.Name is required") @@ -68,7 +51,7 @@ func UpdateDestination(tx WriteTxn, destination *models.Destination) error { if err := validateDestination(destination); err != nil { return err } - return update(tx, (*destinationsUpdateTable)(destination)) + return update(tx, (*destinationsTable)(destination)) } type GetDestinationOptions struct { diff --git a/internal/server/data/destination_test.go b/internal/server/data/destination_test.go index e3d79b0578..1d0c013851 100644 --- a/internal/server/data/destination_test.go +++ b/internal/server/data/destination_test.go @@ -93,11 +93,8 @@ func TestUpdateDestination(t *testing.T) { } createDestinations(t, tx, orig) - // Unlike other update operations, the passed in destination - // may be constructed entirely by the caller and may not have the - // created, or updated time set. destination := &models.Destination{ - Model: models.Model{ID: orig.ID}, + Model: orig.Model, Name: "example-cluster-2", UniqueID: "22222", Kind: "kubernetes", diff --git a/internal/server/destination_test.go b/internal/server/destination_test.go index 196c56cca7..6fd0d63675 100644 --- a/internal/server/destination_test.go +++ b/internal/server/destination_test.go @@ -9,9 +9,12 @@ import ( "time" gocmp "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gotest.tools/v3/assert" "github.com/infrahq/infra/api" + "github.com/infrahq/infra/internal/server/data" + "github.com/infrahq/infra/internal/server/models" ) func TestAPI_CreateDestination(t *testing.T) { @@ -113,3 +116,168 @@ var cmpAPIDestinationJSON = gocmp.Options{ gocmp.FilterPath(pathMapKey(`created`, `updated`), cmpApproximateTime), gocmp.FilterPath(pathMapKey(`id`), cmpAnyValidUID), } + +func TestAPI_UpdateDestination(t *testing.T) { + srv := setupServer(t, withAdminUser) + routes := srv.GenerateRoutes() + + dest := &models.Destination{ + Name: "the-dest", + Kind: models.DestinationKindSSH, + UniqueID: "unique-id", + } + assert.NilError(t, data.CreateDestination(srv.db, dest)) + + type testCase struct { + name string + setup func(t *testing.T, req *http.Request) + body func(t *testing.T) api.UpdateDestinationRequest + expected func(t *testing.T, resp *httptest.ResponseRecorder) + } + + run := func(t *testing.T, tc testCase) { + createReq := tc.body(t) + body := jsonBody(t, &createReq) + req := httptest.NewRequest(http.MethodPut, "/api/destinations/"+dest.ID.String(), body) + req.Header.Set("Authorization", "Bearer "+adminAccessKey(srv)) + req.Header.Set("Infra-Version", apiVersionLatest) + + if tc.setup != nil { + tc.setup(t, req) + } + + resp := httptest.NewRecorder() + routes.ServeHTTP(resp, req) + + tc.expected(t, resp) + } + + testCases := []testCase{ + { + name: "not authenticated", + body: func(t *testing.T) api.UpdateDestinationRequest { + return api.UpdateDestinationRequest{} + }, + setup: func(t *testing.T, req *http.Request) { + req.Header.Del("Authorization") + }, + expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, resp.Code, http.StatusUnauthorized, (*responseDebug)(resp)) + }, + }, + { + name: "not authorized", + body: func(t *testing.T) api.UpdateDestinationRequest { + return api.UpdateDestinationRequest{ + Name: "the-dest", + UniqueID: "unique-id", + Connection: api.DestinationConnection{ + URL: "10.10.10.10:12345", + CA: "the-ca-or-fingerprint", + }, + } + }, + setup: func(t *testing.T, req *http.Request) { + token, _ := createAccessKey(t, srv.db, "notauth@example.com") + req.Header.Set("Authorization", "Bearer "+token) + }, + expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, resp.Code, http.StatusForbidden, (*responseDebug)(resp)) + }, + }, + { + name: "missing required fields", + body: func(t *testing.T) api.UpdateDestinationRequest { + return api.UpdateDestinationRequest{} + }, + expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, resp.Code, http.StatusBadRequest, (*responseDebug)(resp)) + + respBody := &api.Error{} + err := json.Unmarshal(resp.Body.Bytes(), respBody) + assert.NilError(t, err) + + expected := []api.FieldError{ + {FieldName: "connection.ca", Errors: []string{"is required"}}, + {FieldName: "name", Errors: []string{"is required"}}, + {FieldName: "uniqueID", Errors: []string{"is required"}}, + } + assert.DeepEqual(t, respBody.FieldErrors, expected) + }, + }, + { + name: "success", + body: func(t *testing.T) api.UpdateDestinationRequest { + return api.UpdateDestinationRequest{ + Name: "the-dest", + UniqueID: "unique-id", + Connection: api.DestinationConnection{ + URL: "10.10.10.10:12345", + CA: "the-ca-or-fingerprint", + }, + Roles: []string{"one", "two"}, + } + }, + setup: func(t *testing.T, req *http.Request) { + // Set the header that connectors use + req.Header.Set(headerInfraDestination, "unique-id") + }, + expected: func(t *testing.T, resp *httptest.ResponseRecorder) { + assert.Equal(t, resp.Code, http.StatusOK, (*responseDebug)(resp)) + + expectedBody := jsonUnmarshal(t, fmt.Sprintf(` + { + "id": "%[2]v", + "name": "the-dest", + "kind": "ssh", + "uniqueID": "unique-id", + "version": "", + "connection": { + "url": "10.10.10.10:12345", + "ca": "the-ca-or-fingerprint" + }, + "connected": true, + "lastSeen": "%[1]v", + "resources": null, + "roles": ["one", "two"], + "created": "%[1]v", + "updated": "%[1]v" + } + `, time.Now().UTC().Format(time.RFC3339), dest.ID)) + + actualBody := jsonUnmarshal(t, resp.Body.String()) + assert.DeepEqual(t, actualBody, expectedBody, cmpAPIDestinationJSON) + + expected := &models.Destination{ + Model: dest.Model, + OrganizationMember: models.OrganizationMember{ + OrganizationID: srv.db.DefaultOrg.ID, + }, + Name: "the-dest", + UniqueID: "unique-id", + Kind: models.DestinationKindSSH, + ConnectionURL: "10.10.10.10:12345", + ConnectionCA: "the-ca-or-fingerprint", + LastSeenAt: time.Now(), + Roles: []string{"one", "two"}, + } + + actual, err := data.GetDestination(srv.db, data.GetDestinationOptions{ByID: dest.ID}) + assert.NilError(t, err) + + var cmpDestination = gocmp.Options{ + cmpopts.EquateApproxTime(2 * time.Second), + cmpopts.EquateEmpty(), + } + assert.DeepEqual(t, actual, expected, cmpDestination) + assert.Assert(t, dest.UpdatedAt != actual.UpdatedAt) + }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + run(t, tc) + }) + } +} diff --git a/internal/server/destinations.go b/internal/server/destinations.go index c51deafa50..38a39993ff 100644 --- a/internal/server/destinations.go +++ b/internal/server/destinations.go @@ -74,20 +74,22 @@ func (a *API) CreateDestination(c *gin.Context, r *api.CreateDestinationRequest) func (a *API) UpdateDestination(c *gin.Context, r *api.UpdateDestinationRequest) (*api.Destination, error) { rCtx := getRequestContext(c) - destination := &models.Destination{ - Model: models.Model{ - ID: r.ID, - }, - Name: r.Name, - UniqueID: r.UniqueID, - ConnectionURL: r.Connection.URL, - ConnectionCA: string(r.Connection.CA), - Resources: r.Resources, - Roles: r.Roles, - Version: r.Version, + + // Start with the existing value, so that non-update fields are not set to zero. + destination, err := access.GetDestination(c, r.ID) + if err != nil { + return nil, err } - if err := access.SaveDestination(rCtx, destination); err != nil { + destination.Name = r.Name + destination.UniqueID = r.UniqueID + destination.ConnectionURL = r.Connection.URL + destination.ConnectionCA = string(r.Connection.CA) + destination.Resources = r.Resources + destination.Roles = r.Roles + destination.Version = r.Version + + if err := access.UpdateDestination(rCtx, destination); err != nil { return nil, fmt.Errorf("update destination: %w", err) } diff --git a/internal/server/middleware.go b/internal/server/middleware.go index 0bfbe073fa..a7000910f4 100644 --- a/internal/server/middleware.go +++ b/internal/server/middleware.go @@ -34,7 +34,7 @@ func handleInfraDestinationHeader(rCtx access.RequestContext, uniqueID string) e // only save if there's significant difference between LastSeenAt and Now if time.Since(destination.LastSeenAt) > lastSeenUpdateThreshold { destination.LastSeenAt = time.Now() - if err := access.SaveDestination(rCtx, destination); err != nil { + if err := access.UpdateDestination(rCtx, destination); err != nil { return fmt.Errorf("failed to update destination lastSeenAt: %w", err) } }