diff --git a/server/resourcemanager/operations.go b/server/resourcemanager/operations.go index 9a30a27a97..6d1c9d92c4 100644 --- a/server/resourcemanager/operations.go +++ b/server/resourcemanager/operations.go @@ -17,6 +17,9 @@ const ( OperationUpdate Operation = "update" OperationGet Operation = "get" OperationDelete Operation = "delete" + + OperationGetAugmented Operation = "getAugmented" + OperationListAugmented Operation = "listAugmented" ) var availableOperations = []Operation{ @@ -67,6 +70,14 @@ type Current[T ResourceSpec] interface { Current(context.Context) (T, error) } +type GetAugmented[T ResourceSpec] interface { + GetAugmented(context.Context, id.ID) (T, error) +} + +type ListAugmented[T ResourceSpec] interface { + ListAugmented(_ context.Context, take, skip int, query, sortBy, sortDirection string) ([]T, error) +} + type resourceHandler[T ResourceSpec] struct { SetID func(T, id.ID) T List func(_ context.Context, take, skip int, query, sortBy, sortDirection string) ([]T, error) @@ -77,6 +88,9 @@ type resourceHandler[T ResourceSpec] struct { Get func(context.Context, id.ID) (T, error) Delete func(context.Context, id.ID) error Provision func(context.Context, T) error + + GetAugmented func(context.Context, id.ID) (T, error) + ListAugmented func(_ context.Context, take, skip int, query, sortBy, sortDirection string) ([]T, error) } func (rh *resourceHandler[T]) bindOperations(enabledOperations []Operation, handler any) error { @@ -119,6 +133,20 @@ func (rh *resourceHandler[T]) bindOperations(enabledOperations []Operation, hand } } + if slices.Contains(enabledOperations, OperationGetAugmented) { + err := rh.bindGetAugmented(handler) + if err != nil { + return err + } + } + + if slices.Contains(enabledOperations, OperationListAugmented) { + err := rh.bindListAugmented(handler) + if err != nil { + return err + } + } + err := rh.bindProvisionOperation(handler) if err != nil { return err @@ -192,3 +220,23 @@ func (rh *resourceHandler[T]) bindProvisionOperation(handler any) error { return nil } + +func (rh *resourceHandler[T]) bindGetAugmented(handler any) error { + casted, ok := handler.(GetAugmented[T]) + if !ok { + return fmt.Errorf("handler does not implement interface `GetAugmented[T]`") + } + rh.GetAugmented = casted.GetAugmented + + return nil +} + +func (rh *resourceHandler[T]) bindListAugmented(handler any) error { + casted, ok := handler.(ListAugmented[T]) + if !ok { + return fmt.Errorf("handler does not implement interface `ListAugmented[T]`") + } + rh.ListAugmented = casted.ListAugmented + + return nil +} diff --git a/server/resourcemanager/resource_manager.go b/server/resourcemanager/resource_manager.go index 628839352a..5b1bf28863 100644 --- a/server/resourcemanager/resource_manager.go +++ b/server/resourcemanager/resource_manager.go @@ -270,7 +270,12 @@ func (m *manager[T]) list(w http.ResponseWriter, r *http.Request) { return } - items, err := m.rh.List( + listFn := m.rh.List + if isRequestForAugmented(r) { + listFn = m.rh.ListAugmented + } + + items, err := listFn( ctx, take, skip, @@ -316,6 +321,12 @@ func (m *manager[T]) list(w http.ResponseWriter, r *http.Request) { writeResponse(w, http.StatusOK, string(bytes)) } +const HeaderAugmented = "X-Tracetest-Augmented" + +func isRequestForAugmented(r *http.Request) bool { + return r.Header.Get(HeaderAugmented) == "true" +} + func (m *manager[T]) get(w http.ResponseWriter, r *http.Request) { encoder, err := encoderFromRequest(r) if err != nil { @@ -327,7 +338,12 @@ func (m *manager[T]) get(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) id := id.ID(vars["id"]) - item, err := m.rh.Get(r.Context(), id) + getterFn := m.rh.Get + if isRequestForAugmented(r) { + getterFn = m.rh.GetAugmented + } + + item, err := getterFn(r.Context(), id) if err != nil { m.handleResourceHandlerError(w, "getting", err, encoder) return diff --git a/server/resourcemanager/resource_manager_test.go b/server/resourcemanager/resource_manager_test.go index 36d0e9f47b..aa292a6c9b 100644 --- a/server/resourcemanager/resource_manager_test.go +++ b/server/resourcemanager/resource_manager_test.go @@ -8,7 +8,6 @@ import ( "github.com/gorilla/mux" "github.com/kubeshop/tracetest/server/pkg/id" - "github.com/kubeshop/tracetest/server/resourcemanager" rm "github.com/kubeshop/tracetest/server/resourcemanager" rmtests "github.com/kubeshop/tracetest/server/resourcemanager/testutil" "github.com/stretchr/testify/mock" @@ -27,55 +26,10 @@ func TestSampleResource(t *testing.T) { SomeValue: "the value updated", } - prepareSortByID := func(m *sampleResourceManager) { - m.On("List", mock.Anything, mock.Anything, mock.Anything, "id", "asc"). - Return([]sampleResource{ - {ID: "1", Name: "3", SomeValue: "3"}, - {ID: "2", Name: "1", SomeValue: "1"}, - {ID: "3", Name: "2", SomeValue: "2"}, - }, nil) - m.On("List", mock.Anything, mock.Anything, mock.Anything, "id", "desc"). - Return([]sampleResource{ - {ID: "3", Name: "2", SomeValue: "2"}, - {ID: "2", Name: "1", SomeValue: "1"}, - {ID: "1", Name: "3", SomeValue: "3"}, - }, nil) - } - - prepareSortByName := func(m *sampleResourceManager) { - m.On("List", mock.Anything, mock.Anything, mock.Anything, "name", "asc"). - Return([]sampleResource{ - {Name: "1", ID: "3", SomeValue: "3"}, - {Name: "2", ID: "1", SomeValue: "1"}, - {Name: "3", ID: "2", SomeValue: "2"}, - }, nil) - m.On("List", mock.Anything, mock.Anything, mock.Anything, "name", "desc"). - Return([]sampleResource{ - {Name: "3", ID: "2", SomeValue: "2"}, - {Name: "2", ID: "1", SomeValue: "1"}, - {Name: "1", ID: "3", SomeValue: "3"}, - }, nil) - } - - prepareSortBySomeValue := func(m *sampleResourceManager) { - m.On("List", mock.Anything, mock.Anything, mock.Anything, "some_value", "asc"). - Return([]sampleResource{ - {SomeValue: "1", ID: "3", Name: "3"}, - {SomeValue: "2", ID: "1", Name: "1"}, - {SomeValue: "3", ID: "2", Name: "2"}, - }, nil) - m.On("List", mock.Anything, mock.Anything, mock.Anything, "some_value", "desc"). - Return([]sampleResource{ - {SomeValue: "3", ID: "2", Name: "2"}, - {SomeValue: "2", ID: "1", Name: "1"}, - {SomeValue: "1", ID: "3", Name: "3"}, - }, nil) - } - rmtests.TestResourceTypeWithErrorOperations(t, rmtests.ResourceTypeTest{ ResourceTypeSingular: "SampleResource", ResourceTypePlural: "SampleResources", - RegisterManagerFn: func(router *mux.Router, db *sql.DB) resourcemanager.Manager { + RegisterManagerFn: func(router *mux.Router, db *sql.DB) rm.Manager { mockManager := new(sampleResourceManager) manager := rm.New[sampleResource]( "SampleResource", @@ -99,7 +53,8 @@ func TestSampleResource(t *testing.T) { mockManager. On("Provision", sample). Return(nil) - // Create + + // Create case rmtests.OperationCreateNoID: withGenID := sample withGenID.ID = id.ID("3") @@ -115,7 +70,7 @@ func TestSampleResource(t *testing.T) { On("Create", sample). Return(sampleResource{}, fmt.Errorf("some error")) - // Update + // Update case rmtests.OperationUpdateNotFound: mockManager. On("Update", sampleUpdated). @@ -129,7 +84,7 @@ func TestSampleResource(t *testing.T) { On("Update", sampleUpdated). Return(sampleResource{}, fmt.Errorf("some error")) - // Get + // Get case rmtests.OperationGetNotFound: mockManager. On("Get", sample.ID). @@ -160,7 +115,7 @@ func TestSampleResource(t *testing.T) { On("Delete", sample.ID). Return(fmt.Errorf("some error")) - // List + // List case rmtests.OperationListSuccess: mockManager. On("Count", mock.Anything). @@ -233,7 +188,7 @@ func TestRestrictedResource(t *testing.T) { rm.WithIDGen(func() id.ID { return id.ID("3") }), - rm.WithOperations(rm.OperationGet, rm.OperationUpdate), + rm.WithOperations(mockManager.Operations()...), ) manager.RegisterRoutes(router) @@ -296,13 +251,130 @@ func TestRestrictedResource(t *testing.T) { }) } +func TestAugmentedResource(t *testing.T) { + sample := sampleResource{ + ID: "1", + Name: "the name", + SomeValue: "the value", + } + + sampleAugmented := sampleResource{ + ID: "1", + Name: "the name", + SomeValue: "the value", + SomeAugmentedOnlyValue: "augmentation works", + } + + rmtests.TestResourceTypeWithErrorOperations(t, rmtests.ResourceTypeTest{ + ResourceTypeSingular: "AugmentedResource", + ResourceTypePlural: "AugmentedResources", + RegisterManagerFn: func(router *mux.Router, db *sql.DB) rm.Manager { + mockManager := new(augmentedResourceManager) + manager := rm.New[sampleResource]( + "AugmentedResource", + "AugmentedResources", + mockManager, + rm.WithOperations(mockManager.Operations()...), + ) + manager.RegisterRoutes(router) + + return manager + }, + Prepare: func(t *testing.T, op rmtests.Operation, manager rm.Manager) { + mockManager := manager.Handler().(*augmentedResourceManager) + mockManager.Test(t) + + switch op { + // Provisioning + case rmtests.OperationProvisioningSuccess: + mockManager. + On("Provision", sample). + Return(nil) + + // Get + case rmtests.OperationGetNotFound: + mockManager. + On("Get", sample.ID). + Return(sampleResource{}, sql.ErrNoRows) + case rmtests.OperationGetSuccess: + mockManager. + On("Get", sample.ID). + Return(sample, nil) + case rmtests.OperationGetInternalError: + mockManager. + On("Get", sample.ID). + Return(sampleResource{}, fmt.Errorf("some error")) + + // List + case rmtests.OperationListSuccess: + mockManager. + On("Count", mock.Anything). + Return(1, nil) + mockManager. + On("List", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]sampleResource{sample}, nil) + case rmtests.OperationListNoResults: + mockManager. + On("Count", mock.Anything). + Return(0, nil) + mockManager. + On("List", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]sampleResource{}, nil) + case rmtests.OperationListPaginatedSuccess: + mockManager. + On("Count", mock.Anything). + Return(3, nil) + + prepareSortByID(mockManager) + prepareSortByName(mockManager) + prepareSortBySomeValue(mockManager) + case rmtests.OperationListInternalError: + mockManager. + On("Count", mock.Anything). + Return(0, fmt.Errorf("some error")) + + // Augmented + case rmtests.OperationGetAugmentedSuccess: + mockManager. + On("GetAugmented", sampleAugmented.ID). + Return(sampleAugmented, nil) + case rmtests.OperationListAugmentedSuccess: + mockManager. + On("Count", mock.Anything). + Return(1, nil) + mockManager. + On("ListAugmented", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything). + Return([]sampleResource{sampleAugmented}, nil) + } + }, + SampleJSON: `{ + "type": "AugmentedResource", + "spec": { + "id": "1", + "name": "the name", + "some_value": "the value" + } + }`, + SampleJSONAugmented: `{ + "type": "AugmentedResource", + "spec": { + "id": "1", + "name": "the name", + "some_value": "the value", + "some_augmented_value": "augmentation works" + } + }`, + }) +} + // test structures and mocks type sampleResource struct { ID id.ID `mapstructure:"id"` Name string `mapstructure:"name"` - SomeValue string `mapstructure:"some_value"` + SomeValue string `mapstructure:"some_value"` + SomeAugmentedOnlyValue string `mapstructure:"some_augmented_value,omitempty"` } func (sr sampleResource) HasID() bool { @@ -348,6 +420,29 @@ func (m *restrictedResourceManager) Operations() []rm.Operation { } } +type augmentedResourceManager struct { + sampleResourceManager +} + +func (m *augmentedResourceManager) Operations() []rm.Operation { + return []rm.Operation{ + rm.OperationGet, + rm.OperationGetAugmented, + rm.OperationList, + rm.OperationListAugmented, + } +} + +func (m *augmentedResourceManager) GetAugmented(_ context.Context, id id.ID) (sampleResource, error) { + args := m.Called(id) + return args.Get(0).(sampleResource), args.Error(1) +} + +func (m *augmentedResourceManager) ListAugmented(_ context.Context, take, skip int, query, sortBy, sortDirection string) ([]sampleResource, error) { + args := m.Called(take, skip, query, sortBy, sortDirection) + return args.Get(0).([]sampleResource), args.Error(1) +} + type sampleResourceManager struct { baseResourceManager } @@ -375,3 +470,52 @@ func (m *sampleResourceManager) Count(_ context.Context, query string) (int, err args := m.Called(query) return args.Int(0), args.Error(1) } + +type mockable interface { + On(string, ...interface{}) *mock.Call +} + +func prepareSortByID(m mockable) { + m.On("List", mock.Anything, mock.Anything, mock.Anything, "id", "asc"). + Return([]sampleResource{ + {ID: "1", Name: "3", SomeValue: "3"}, + {ID: "2", Name: "1", SomeValue: "1"}, + {ID: "3", Name: "2", SomeValue: "2"}, + }, nil) + m.On("List", mock.Anything, mock.Anything, mock.Anything, "id", "desc"). + Return([]sampleResource{ + {ID: "3", Name: "2", SomeValue: "2"}, + {ID: "2", Name: "1", SomeValue: "1"}, + {ID: "1", Name: "3", SomeValue: "3"}, + }, nil) +} + +func prepareSortByName(m mockable) { + m.On("List", mock.Anything, mock.Anything, mock.Anything, "name", "asc"). + Return([]sampleResource{ + {Name: "1", ID: "3", SomeValue: "3"}, + {Name: "2", ID: "1", SomeValue: "1"}, + {Name: "3", ID: "2", SomeValue: "2"}, + }, nil) + m.On("List", mock.Anything, mock.Anything, mock.Anything, "name", "desc"). + Return([]sampleResource{ + {Name: "3", ID: "2", SomeValue: "2"}, + {Name: "2", ID: "1", SomeValue: "1"}, + {Name: "1", ID: "3", SomeValue: "3"}, + }, nil) +} + +func prepareSortBySomeValue(m mockable) { + m.On("List", mock.Anything, mock.Anything, mock.Anything, "some_value", "asc"). + Return([]sampleResource{ + {SomeValue: "1", ID: "3", Name: "3"}, + {SomeValue: "2", ID: "1", Name: "1"}, + {SomeValue: "3", ID: "2", Name: "2"}, + }, nil) + m.On("List", mock.Anything, mock.Anything, mock.Anything, "some_value", "desc"). + Return([]sampleResource{ + {SomeValue: "3", ID: "2", Name: "2"}, + {SomeValue: "2", ID: "1", Name: "1"}, + {SomeValue: "1", ID: "3", Name: "3"}, + }, nil) +} diff --git a/server/resourcemanager/testutil/operations.go b/server/resourcemanager/testutil/operations.go index 7aaf16632f..5d51de2ffc 100644 --- a/server/resourcemanager/testutil/operations.go +++ b/server/resourcemanager/testutil/operations.go @@ -66,6 +66,8 @@ var ( getNotFoundOperation, getSuccessOperation, + getAugmentedSuccessOperation, + deleteNotFoundOperation, deleteSuccessOperation, diff --git a/server/resourcemanager/testutil/operations_augmented.go b/server/resourcemanager/testutil/operations_augmented.go new file mode 100644 index 0000000000..f7b75b21a3 --- /dev/null +++ b/server/resourcemanager/testutil/operations_augmented.go @@ -0,0 +1,70 @@ +package testutil + +import ( + "net/http" + "net/http/httptest" + "testing" + + rm "github.com/kubeshop/tracetest/server/resourcemanager" + "github.com/stretchr/testify/require" +) + +func buildAugmentedGetRequest(rt ResourceTypeTest, ct contentTypeConverter, testServer *httptest.Server, t *testing.T) *http.Request { + id := extractID(rt.SampleJSONAugmented) + req, err := getRequestForID(id, rt, testServer) + require.NoError(t, err) + req.Header.Set(rm.HeaderAugmented, "true") + return req +} + +const OperationGetAugmentedSuccess Operation = "GetAugmentedSuccess" + +var getAugmentedSuccessOperation = buildSingleStepOperation(singleStepOperationTester{ + name: OperationGetAugmentedSuccess, + neededForOperation: rm.OperationGetAugmented, + buildRequest: func(t *testing.T, testServer *httptest.Server, ct contentTypeConverter, rt ResourceTypeTest) *http.Request { + return buildAugmentedGetRequest(rt, ct, testServer, t) + }, + assertResponse: func(t *testing.T, resp *http.Response, ct contentTypeConverter, rt ResourceTypeTest) { + t.Helper() + require.Equal(t, 200, resp.StatusCode) + + jsonBody := responseBodyJSON(t, resp, ct) + + expected := ct.toJSON(rt.SampleJSONAugmented) + + rt.customJSONComparer(t, OperationGetAugmentedSuccess, expected, jsonBody) + }, +}) + +func buildAugmentedListRequest(rt ResourceTypeTest, ct contentTypeConverter, testServer *httptest.Server, t *testing.T) *http.Request { + req := buildListRequest( + rt.ResourceTypePlural, + map[string]string{}, + ct, + testServer, + t, + ) + req.Header.Set(rm.HeaderAugmented, "true") + return req +} + +const OperationListAugmentedSuccess Operation = "ListAugmentedSuccess" + +var ListAugmentedSuccessOperation = buildSingleStepOperation(singleStepOperationTester{ + name: OperationListAugmentedSuccess, + neededForOperation: rm.OperationListAugmented, + buildRequest: func(t *testing.T, testServer *httptest.Server, ct contentTypeConverter, rt ResourceTypeTest) *http.Request { + return buildAugmentedListRequest(rt, ct, testServer, t) + }, + assertResponse: func(t *testing.T, resp *http.Response, ct contentTypeConverter, rt ResourceTypeTest) { + t.Helper() + require.Equal(t, 200, resp.StatusCode) + + jsonBody := responseBodyJSON(t, resp, ct) + + expected := ct.toJSON(rt.SampleJSONAugmented) + + rt.customJSONComparer(t, OperationGetAugmentedSuccess, expected, jsonBody) + }, +}) diff --git a/server/resourcemanager/testutil/operations_get.go b/server/resourcemanager/testutil/operations_get.go index 042c7a1159..9883977050 100644 --- a/server/resourcemanager/testutil/operations_get.go +++ b/server/resourcemanager/testutil/operations_get.go @@ -11,8 +11,7 @@ import ( "github.com/stretchr/testify/require" ) -func buildGetRequest(rt ResourceTypeTest, ct contentTypeConverter, testServer *httptest.Server, t *testing.T) *http.Request { - id := extractID(rt.SampleJSON) +func getRequestForID(id string, rt ResourceTypeTest, testServer *httptest.Server) (*http.Request, error) { url := fmt.Sprintf( "%s/%s/%s", testServer.URL, @@ -20,7 +19,12 @@ func buildGetRequest(rt ResourceTypeTest, ct contentTypeConverter, testServer *h id, ) - req, err := http.NewRequest(http.MethodGet, url, nil) + return http.NewRequest(http.MethodGet, url, nil) +} + +func buildGetRequest(rt ResourceTypeTest, ct contentTypeConverter, testServer *httptest.Server, t *testing.T) *http.Request { + id := extractID(rt.SampleJSON) + req, err := getRequestForID(id, rt, testServer) require.NoError(t, err) return req } diff --git a/server/resourcemanager/testutil/test_resource.go b/server/resourcemanager/testutil/test_resource.go index aaafaa369b..0d5b80e413 100644 --- a/server/resourcemanager/testutil/test_resource.go +++ b/server/resourcemanager/testutil/test_resource.go @@ -20,8 +20,9 @@ type ResourceTypeTest struct { RegisterManagerFn func(*mux.Router, *sql.DB) rm.Manager Prepare func(t *testing.T, operation Operation, manager rm.Manager) - SampleJSON string - SampleJSONUpdated string + SampleJSON string + SampleJSONUpdated string + SampleJSONAugmented string // private fields sortFields []string