Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix update destination #3604

Merged
merged 1 commit into from
Nov 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion internal/access/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ func CreateDestination(c *gin.Context, destination *models.Destination) error {
return data.CreateDestination(db, destination)
}

func SaveDestination(rCtx RequestContext, destination *models.Destination) error {
func UpdateDestination(rCtx RequestContext, destination *models.Destination) error {
roles := []string{models.InfraAdminRole, models.InfraConnectorRole}
if err := IsAuthorized(rCtx, roles...); err != nil {
return HandleAuthErr(err, "destination", "update", roles...)
Expand Down
19 changes: 1 addition & 18 deletions internal/server/data/destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,23 +27,6 @@ func (d *destinationsTable) ScanFields() []any {
return []any{&d.ConnectionCA, &d.ConnectionURL, &d.CreatedAt, &d.DeletedAt, &d.ID, &d.Kind, &d.LastSeenAt, &d.Name, &d.OrganizationID, &d.Resources, &d.Roles, &d.UniqueID, &d.UpdatedAt, &d.Version}
}

// destinationsUpdateTable is used to update the destination. It excludes
// the CreatedAt field, because that field is not part of the input to
// UpdateDestination.
type destinationsUpdateTable models.Destination

func (d destinationsUpdateTable) Table() string {
return "destinations"
}

func (d destinationsUpdateTable) Columns() []string {
return []string{"connection_ca", "connection_url", "deleted_at", "id", "last_seen_at", "name", "organization_id", "resources", "roles", "unique_id", "updated_at", "version"}
}

func (d destinationsUpdateTable) Values() []any {
return []any{d.ConnectionCA, d.ConnectionURL, d.DeletedAt, d.ID, d.LastSeenAt, d.Name, d.OrganizationID, d.Resources, d.Roles, d.UniqueID, d.UpdatedAt, d.Version}
}

func validateDestination(dest *models.Destination) error {
if dest.Name == "" {
return fmt.Errorf("Destination.Name is required")
Expand All @@ -68,7 +51,7 @@ func UpdateDestination(tx WriteTxn, destination *models.Destination) error {
if err := validateDestination(destination); err != nil {
return err
}
return update(tx, (*destinationsUpdateTable)(destination))
return update(tx, (*destinationsTable)(destination))
}

type GetDestinationOptions struct {
Expand Down
5 changes: 1 addition & 4 deletions internal/server/data/destination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,11 +93,8 @@ func TestUpdateDestination(t *testing.T) {
}
createDestinations(t, tx, orig)

// Unlike other update operations, the passed in destination
// may be constructed entirely by the caller and may not have the
// created, or updated time set.
destination := &models.Destination{
Model: models.Model{ID: orig.ID},
Model: orig.Model,
Name: "example-cluster-2",
UniqueID: "22222",
Kind: "kubernetes",
Expand Down
168 changes: 168 additions & 0 deletions internal/server/destination_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,12 @@ import (
"time"

gocmp "github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"gotest.tools/v3/assert"

"github.com/infrahq/infra/api"
"github.com/infrahq/infra/internal/server/data"
"github.com/infrahq/infra/internal/server/models"
)

func TestAPI_CreateDestination(t *testing.T) {
Expand Down Expand Up @@ -113,3 +116,168 @@ var cmpAPIDestinationJSON = gocmp.Options{
gocmp.FilterPath(pathMapKey(`created`, `updated`), cmpApproximateTime),
gocmp.FilterPath(pathMapKey(`id`), cmpAnyValidUID),
}

func TestAPI_UpdateDestination(t *testing.T) {
srv := setupServer(t, withAdminUser)
routes := srv.GenerateRoutes()

dest := &models.Destination{
Name: "the-dest",
Kind: models.DestinationKindSSH,
UniqueID: "unique-id",
}
assert.NilError(t, data.CreateDestination(srv.db, dest))

type testCase struct {
name string
setup func(t *testing.T, req *http.Request)
body func(t *testing.T) api.UpdateDestinationRequest
expected func(t *testing.T, resp *httptest.ResponseRecorder)
}

run := func(t *testing.T, tc testCase) {
createReq := tc.body(t)
body := jsonBody(t, &createReq)
req := httptest.NewRequest(http.MethodPut, "/api/destinations/"+dest.ID.String(), body)
req.Header.Set("Authorization", "Bearer "+adminAccessKey(srv))
req.Header.Set("Infra-Version", apiVersionLatest)

if tc.setup != nil {
tc.setup(t, req)
}

resp := httptest.NewRecorder()
routes.ServeHTTP(resp, req)

tc.expected(t, resp)
}

testCases := []testCase{
{
name: "not authenticated",
body: func(t *testing.T) api.UpdateDestinationRequest {
return api.UpdateDestinationRequest{}
},
setup: func(t *testing.T, req *http.Request) {
req.Header.Del("Authorization")
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusUnauthorized, (*responseDebug)(resp))
},
},
{
name: "not authorized",
body: func(t *testing.T) api.UpdateDestinationRequest {
return api.UpdateDestinationRequest{
Name: "the-dest",
UniqueID: "unique-id",
Connection: api.DestinationConnection{
URL: "10.10.10.10:12345",
CA: "the-ca-or-fingerprint",
},
}
},
setup: func(t *testing.T, req *http.Request) {
token, _ := createAccessKey(t, srv.db, "notauth@example.com")
req.Header.Set("Authorization", "Bearer "+token)
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusForbidden, (*responseDebug)(resp))
},
},
{
name: "missing required fields",
body: func(t *testing.T) api.UpdateDestinationRequest {
return api.UpdateDestinationRequest{}
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusBadRequest, (*responseDebug)(resp))

respBody := &api.Error{}
err := json.Unmarshal(resp.Body.Bytes(), respBody)
assert.NilError(t, err)

expected := []api.FieldError{
{FieldName: "connection.ca", Errors: []string{"is required"}},
{FieldName: "name", Errors: []string{"is required"}},
{FieldName: "uniqueID", Errors: []string{"is required"}},
}
assert.DeepEqual(t, respBody.FieldErrors, expected)
},
},
{
name: "success",
body: func(t *testing.T) api.UpdateDestinationRequest {
return api.UpdateDestinationRequest{
Name: "the-dest",
UniqueID: "unique-id",
Connection: api.DestinationConnection{
URL: "10.10.10.10:12345",
CA: "the-ca-or-fingerprint",
},
Roles: []string{"one", "two"},
}
},
setup: func(t *testing.T, req *http.Request) {
// Set the header that connectors use
req.Header.Set(headerInfraDestination, "unique-id")
},
expected: func(t *testing.T, resp *httptest.ResponseRecorder) {
assert.Equal(t, resp.Code, http.StatusOK, (*responseDebug)(resp))

expectedBody := jsonUnmarshal(t, fmt.Sprintf(`
{
"id": "%[2]v",
"name": "the-dest",
"kind": "ssh",
"uniqueID": "unique-id",
"version": "",
"connection": {
"url": "10.10.10.10:12345",
"ca": "the-ca-or-fingerprint"
},
"connected": true,
"lastSeen": "%[1]v",
"resources": null,
"roles": ["one", "two"],
"created": "%[1]v",
"updated": "%[1]v"
}
`, time.Now().UTC().Format(time.RFC3339), dest.ID))

actualBody := jsonUnmarshal(t, resp.Body.String())
assert.DeepEqual(t, actualBody, expectedBody, cmpAPIDestinationJSON)

expected := &models.Destination{
Model: dest.Model,
OrganizationMember: models.OrganizationMember{
OrganizationID: srv.db.DefaultOrg.ID,
},
Name: "the-dest",
UniqueID: "unique-id",
Kind: models.DestinationKindSSH,
ConnectionURL: "10.10.10.10:12345",
ConnectionCA: "the-ca-or-fingerprint",
LastSeenAt: time.Now(),
Roles: []string{"one", "two"},
}

actual, err := data.GetDestination(srv.db, data.GetDestinationOptions{ByID: dest.ID})
assert.NilError(t, err)

var cmpDestination = gocmp.Options{
cmpopts.EquateApproxTime(2 * time.Second),
cmpopts.EquateEmpty(),
}
assert.DeepEqual(t, actual, expected, cmpDestination)
assert.Assert(t, dest.UpdatedAt != actual.UpdatedAt)
},
},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
run(t, tc)
})
}
}
26 changes: 14 additions & 12 deletions internal/server/destinations.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,20 +74,22 @@ func (a *API) CreateDestination(c *gin.Context, r *api.CreateDestinationRequest)

func (a *API) UpdateDestination(c *gin.Context, r *api.UpdateDestinationRequest) (*api.Destination, error) {
rCtx := getRequestContext(c)
destination := &models.Destination{
Model: models.Model{
ID: r.ID,
},
Name: r.Name,
UniqueID: r.UniqueID,
ConnectionURL: r.Connection.URL,
ConnectionCA: string(r.Connection.CA),
Resources: r.Resources,
Roles: r.Roles,
Version: r.Version,

// Start with the existing value, so that non-update fields are not set to zero.
destination, err := access.GetDestination(c, r.ID)
if err != nil {
return nil, err
}

if err := access.SaveDestination(rCtx, destination); err != nil {
destination.Name = r.Name
destination.UniqueID = r.UniqueID
destination.ConnectionURL = r.Connection.URL
destination.ConnectionCA = string(r.Connection.CA)
destination.Resources = r.Resources
destination.Roles = r.Roles
destination.Version = r.Version

if err := access.UpdateDestination(rCtx, destination); err != nil {
return nil, fmt.Errorf("update destination: %w", err)
}

Expand Down
2 changes: 1 addition & 1 deletion internal/server/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func handleInfraDestinationHeader(rCtx access.RequestContext, uniqueID string) e
// only save if there's significant difference between LastSeenAt and Now
if time.Since(destination.LastSeenAt) > lastSeenUpdateThreshold {
destination.LastSeenAt = time.Now()
if err := access.SaveDestination(rCtx, destination); err != nil {
if err := access.UpdateDestination(rCtx, destination); err != nil {
return fmt.Errorf("failed to update destination lastSeenAt: %w", err)
}
}
Expand Down