From e6368cc57fa20743c7210fd745799712fd33d253 Mon Sep 17 00:00:00 2001 From: Tomas Touceda Date: Wed, 15 Sep 2021 16:27:53 -0300 Subject: [PATCH] Refactor integration tests (#1821) * Refactor integration tests * Remove nopCloser and use io.NopCloser * Address review comments --- cmd/fleetctl/apply_test.go | 11 +- cmd/fleetctl/debug_test.go | 3 +- cmd/fleetctl/get_test.go | 15 +- cmd/fleetctl/hosts_test.go | 15 +- cmd/fleetctl/query_test.go | 3 +- cmd/fleetctl/users_test.go | 3 +- server/datastore/mysql/testing_utils.go | 9 + server/service/client.go | 2 +- server/service/http_auth_test.go | 13 +- server/service/integration_core_test.go | 392 ++++++++ server/service/integration_ds_only_test.go | 61 ++ server/service/integration_enterprise_test.go | 142 +++ server/service/integration_logger_test.go | 183 ++++ server/service/integration_test.go | 931 ------------------ server/service/testing_client.go | 143 +++ server/service/testing_utils.go | 25 + 16 files changed, 975 insertions(+), 976 deletions(-) create mode 100644 server/service/integration_core_test.go create mode 100644 server/service/integration_ds_only_test.go create mode 100644 server/service/integration_enterprise_test.go create mode 100644 server/service/integration_logger_test.go delete mode 100644 server/service/integration_test.go create mode 100644 server/service/testing_client.go diff --git a/cmd/fleetctl/apply_test.go b/cmd/fleetctl/apply_test.go index 64cd4bfc2186..84ec2fa99640 100644 --- a/cmd/fleetctl/apply_test.go +++ b/cmd/fleetctl/apply_test.go @@ -41,8 +41,7 @@ var userRoleSpecList = []*fleet.User{ } func TestApplyUserRoles(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { return userRoleSpecList, nil @@ -102,11 +101,10 @@ spec: func TestApplyTeamSpecs(t *testing.T) { license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)} - server, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) - defer server.Close() + _, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) teamsByName := map[string]*fleet.Team{ - "team1": &fleet.Team{ + "team1": { ID: 42, Name: "team1", Description: "team1 description", @@ -185,8 +183,7 @@ func writeTmpYml(t *testing.T, contents string) string { } func TestApplyAppConfig(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { return userRoleSpecList, nil diff --git a/cmd/fleetctl/debug_test.go b/cmd/fleetctl/debug_test.go index f8a30a99b848..d71e72196ef0 100644 --- a/cmd/fleetctl/debug_test.go +++ b/cmd/fleetctl/debug_test.go @@ -44,8 +44,7 @@ oug6edBNpdhp8r2/4t6n3AouK0/zG2naAlmXV0JoFuEvy2bX0BbbbPg+v4WNZIsC func TestDebugConnectionCommand(t *testing.T) { t.Run("without certificate", func(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { return nil, errors.New("invalid") diff --git a/cmd/fleetctl/get_test.go b/cmd/fleetctl/get_test.go index 8212b268fed2..2e96b18ca6db 100644 --- a/cmd/fleetctl/get_test.go +++ b/cmd/fleetctl/get_test.go @@ -55,8 +55,7 @@ var userRoleList = []*fleet.User{ } func TestGetUserRoles(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { return userRoleList, nil @@ -114,8 +113,7 @@ func TestGetTeams(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { license := tt.license - server, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) - defer server.Close() + _, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) agentOpts := json.RawMessage(`{"config":{"foo":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override"}}}}`) ds.ListTeamsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Team, error) { @@ -196,8 +194,7 @@ spec: } func TestGetHosts(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) // this func is called when no host is specified i.e. `fleetctl get hosts --json` ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) { @@ -343,8 +340,7 @@ func TestGetHosts(t *testing.T) { } func TestGetConfig(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{ @@ -412,8 +408,7 @@ spec: } func TestGetSoftawre(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) foo001 := fleet.Software{ Name: "foo", Version: "0.0.1", Source: "chrome_extensions", GenerateCPE: "somecpe", diff --git a/cmd/fleetctl/hosts_test.go b/cmd/fleetctl/hosts_test.go index 7e8f08060589..936f0cae63b5 100644 --- a/cmd/fleetctl/hosts_test.go +++ b/cmd/fleetctl/hosts_test.go @@ -10,8 +10,7 @@ import ( ) func TestHostTransferFlagChecks(t *testing.T) { - server, _ := runServerWithMockedDS(t) - defer server.Close() + runServerWithMockedDS(t) runAppCheckErr(t, []string{"hosts", "transfer", "--team", "team1", "--hosts", "host1", "--label", "AAA"}, @@ -24,8 +23,7 @@ func TestHostTransferFlagChecks(t *testing.T) { } func TestHostsTransferByHosts(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) @@ -48,8 +46,7 @@ func TestHostsTransferByHosts(t *testing.T) { } func TestHostsTransferByLabel(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) @@ -83,8 +80,7 @@ func TestHostsTransferByLabel(t *testing.T) { } func TestHostsTransferByStatus(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) @@ -118,8 +114,7 @@ func TestHostsTransferByStatus(t *testing.T) { } func TestHostsTransferByStatusAndSearchQuery(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) diff --git a/cmd/fleetctl/query_test.go b/cmd/fleetctl/query_test.go index 306ecf173401..145d5d0f6e90 100644 --- a/cmd/fleetctl/query_test.go +++ b/cmd/fleetctl/query_test.go @@ -16,8 +16,7 @@ import ( func TestLiveQuery(t *testing.T) { rs := pubsub.NewInmemQueryResults() lq := new(live_query.MockLiveQuery) - server, ds := runServerWithMockedDS(t, service.TestServerOpts{Rs: rs, Lq: lq}) - defer server.Close() + _, ds := runServerWithMockedDS(t, service.TestServerOpts{Rs: rs, Lq: lq}) ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { return []uint{1234}, nil diff --git a/cmd/fleetctl/users_test.go b/cmd/fleetctl/users_test.go index 663c6514bc10..482a1d4e0fd7 100644 --- a/cmd/fleetctl/users_test.go +++ b/cmd/fleetctl/users_test.go @@ -9,8 +9,7 @@ import ( ) func TestUserDelete(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.UserByEmailFunc = func(ctx context.Context, email string) (*fleet.User, error) { return &fleet.User{ diff --git a/server/datastore/mysql/testing_utils.go b/server/datastore/mysql/testing_utils.go index 6d84ca5e8afd..be4d04a80e5f 100644 --- a/server/datastore/mysql/testing_utils.go +++ b/server/datastore/mysql/testing_utils.go @@ -243,3 +243,12 @@ func CreateMySQLDSWithOptions(t *testing.T, opts *DatastoreTestOptions) *Datasto func CreateMySQLDS(t *testing.T) *Datastore { return createMySQLDSWithOptions(t, nil) } + +func CreateNamedMySQLDS(t *testing.T, name string) *Datastore { + if _, ok := os.LookupEnv("MYSQL_TEST"); !ok { + t.Skip("MySQL tests are disabled") + } + + t.Parallel() + return initializeDatabase(t, name, new(DatastoreTestOptions)) +} diff --git a/server/service/client.go b/server/service/client.go index abb98c15afee..4e7a86b500a6 100644 --- a/server/service/client.go +++ b/server/service/client.go @@ -230,7 +230,7 @@ func (l *logRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { fmt.Fprintf(os.Stderr, "Read body error: %v", err) return nil, err } - res.Body = ioutil.NopCloser(resBody) + res.Body = io.NopCloser(resBody) return res, nil } diff --git a/server/service/http_auth_test.go b/server/service/http_auth_test.go index 3c9533de8589..eea04c61642c 100644 --- a/server/service/http_auth_test.go +++ b/server/service/http_auth_test.go @@ -21,7 +21,6 @@ import ( func TestLogin(t *testing.T) { ds, users, server := setupAuthTest(t) - defer server.Close() var loginTests = []struct { email string status int @@ -60,7 +59,7 @@ func TestLogin(t *testing.T) { j, err := json.Marshal(¶ms) assert.Nil(t, err) - requestBody := &nopCloser{bytes.NewBuffer(j)} + requestBody := io.NopCloser(bytes.NewBuffer(j)) resp, err := http.Post(server.URL+"/api/v1/fleet/login", "application/json", requestBody) require.Nil(t, err) assert.Equal(t, tt.status, resp.StatusCode) @@ -173,7 +172,7 @@ func getTestAdminToken(t *testing.T, server *httptest.Server) string { j, err := json.Marshal(¶ms) assert.Nil(t, err) - requestBody := &nopCloser{bytes.NewBuffer(j)} + requestBody := io.NopCloser(bytes.NewBuffer(j)) resp, err := http.Post(server.URL+"/api/v1/fleet/login", "application/json", requestBody) require.Nil(t, err) assert.Equal(t, http.StatusOK, resp.StatusCode) @@ -191,7 +190,6 @@ func getTestAdminToken(t *testing.T, server *httptest.Server) string { func TestNoHeaderErrorsDifferently(t *testing.T) { _, _, server := setupAuthTest(t) - defer server.Close() req, _ := http.NewRequest("GET", server.URL+"/api/v1/fleet/users", nil) client := &http.Client{} @@ -230,10 +228,3 @@ func TestNoHeaderErrorsDifferently(t *testing.T) { } `, string(bodyBytes)) } - -// an io.ReadCloser for new request body -type nopCloser struct { - io.Reader -} - -func (nopCloser) Close() error { return nil } diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go new file mode 100644 index 000000000000..3922225b700b --- /dev/null +++ b/server/service/integration_core_test.go @@ -0,0 +1,392 @@ +package service + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "strconv" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/ghodss/yaml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type integrationTestSuite struct { + suite.Suite + + withServer +} + +func (s *integrationTestSuite) SetupSuite() { + s.withServer.SetupSuite("integrationTestSuite") +} + +func TestIntegrations(t *testing.T) { + testingSuite := new(integrationTestSuite) + testingSuite.s = &testingSuite.Suite + suite.Run(t, testingSuite) +} + +func (s *integrationTestSuite) TestDoubleUserCreationErrors() { + t := s.T() + + params := fleet.UserPayload{ + Name: ptr.String("user1"), + Email: ptr.String("email@asd.com"), + Password: ptr.String("pass"), + GlobalRole: ptr.String(fleet.RoleObserver), + } + + s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusOK) + respSecond := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusConflict) + + assertBodyContains(t, respSecond, `Error 1062: Duplicate entry 'email@asd.com'`) +} + +func (s *integrationTestSuite) TestUserWithoutRoleErrors() { + t := s.T() + + params := fleet.UserPayload{ + Name: ptr.String("user1"), + Email: ptr.String("email@asd.com"), + Password: ptr.String("pass"), + } + + resp := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusUnprocessableEntity) + assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "either global role or team role needs to be defined") +} + +func (s *integrationTestSuite) TestUserWithWrongRoleErrors() { + t := s.T() + + params := fleet.UserPayload{ + Name: ptr.String("user1"), + Email: ptr.String("email@asd.com"), + Password: ptr.String("pass"), + GlobalRole: ptr.String("wrongrole"), + } + resp := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusUnprocessableEntity) + assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "GlobalRole role can only be admin, observer, or maintainer.") +} + +func (s *integrationTestSuite) TestUserCreationWrongTeamErrors() { + t := s.T() + + teams := []fleet.UserTeam{ + { + Team: fleet.Team{ + ID: 9999, + }, + Role: fleet.RoleObserver, + }, + } + + params := fleet.UserPayload{ + Name: ptr.String("user2"), + Email: ptr.String("email2@asd.com"), + Password: ptr.String("pass"), + Teams: &teams, + } + resp := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusUnprocessableEntity) + assertBodyContains(t, resp, `Error 1452: Cannot add or update a child row: a foreign key constraint fails`) +} + +func (s *integrationTestSuite) TestQueryCreationLogsActivity() { + t := s.T() + + admin1 := s.users["admin1@example.com"] + admin1.GravatarURL = "http://iii.com" + err := s.ds.SaveUser(context.Background(), &admin1) + require.NoError(t, err) + + params := fleet.QueryPayload{ + Name: ptr.String("user1"), + Query: ptr.String("select * from time;"), + } + s.Do("POST", "/api/v1/fleet/queries", ¶ms, http.StatusOK) + + activities := listActivitiesResponse{} + s.DoJSON("GET", "/api/v1/fleet/activities", nil, http.StatusOK, &activities) + + assert.Len(t, activities.Activities, 1) + assert.Equal(t, "Test Name admin1@example.com", activities.Activities[0].ActorFullName) + require.NotNil(t, activities.Activities[0].ActorGravatar) + assert.Equal(t, "http://iii.com", *activities.Activities[0].ActorGravatar) + assert.Equal(t, "created_saved_query", activities.Activities[0].Type) +} +func (s *integrationTestSuite) TestAppConfigAdditionalQueriesCanBeRemoved() { + t := s.T() + + spec := []byte(` + host_expiry_settings: + host_expiry_enabled: true + host_expiry_window: 0 + host_settings: + additional_queries: + time: SELECT * FROM time + enable_host_users: true +`) + s.applyConfig(spec) + + spec = []byte(` + host_settings: + enable_host_users: true + additional_queries: null +`) + s.applyConfig(spec) + + config := s.getConfig() + assert.Nil(t, config.HostSettings.AdditionalQueries) + assert.True(t, config.HostExpirySettings.HostExpiryEnabled) +} + +func (s *integrationTestSuite) TestAppConfigDefaultValues() { + config := s.getConfig() + s.Run("Update interval", func() { + require.Equal(s.T(), 1*time.Hour, config.UpdateInterval.OSQueryDetail) + }) + + s.Run("has logging", func() { + require.NotNil(s.T(), config.Logging) + }) +} + +func (s *integrationTestSuite) TestUserRolesSpec() { + t := s.T() + + _, err := s.ds.NewTeam(context.Background(), &fleet.Team{ + ID: 42, + Name: "team1", + Description: "desc team1", + }) + require.NoError(t, err) + + email := t.Name() + "@asd.com" + u := &fleet.User{ + Password: []byte("asd"), + Name: t.Name(), + Email: email, + GravatarURL: "http://asd.com", + GlobalRole: ptr.String(fleet.RoleObserver), + } + user, err := s.ds.NewUser(context.Background(), u) + require.NoError(t, err) + assert.Len(t, user.Teams, 0) + + spec := []byte(fmt.Sprintf(` + roles: + %s: + global_role: null + teams: + - role: maintainer + team: team1 +`, + email)) + + var userRoleSpec applyUserRoleSpecsRequest + err = yaml.Unmarshal(spec, &userRoleSpec.Spec) + require.NoError(t, err) + + s.Do("POST", "/api/v1/fleet/users/roles/spec", &userRoleSpec, http.StatusOK) + + user, err = s.ds.UserByEmail(context.Background(), email) + require.NoError(t, err) + require.Len(t, user.Teams, 1) + assert.Equal(t, fleet.RoleMaintainer, user.Teams[0].Role) +} + +func (s *integrationTestSuite) TestGlobalSchedule() { + t := s.T() + + gs := fleet.GlobalSchedulePayload{} + s.DoJSON("GET", "/api/v1/fleet/global/schedule", nil, http.StatusOK, &gs) + require.Len(t, gs.GlobalSchedule, 0) + + qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{ + Name: "TestQuery1", + Description: "Some description", + Query: "select * from osquery;", + ObserverCanRun: true, + }) + require.NoError(t, err) + + gsParams := fleet.ScheduledQueryPayload{QueryID: ptr.Uint(qr.ID), Interval: ptr.Uint(42)} + r := globalScheduleQueryResponse{} + s.DoJSON("POST", "/api/v1/fleet/global/schedule", gsParams, http.StatusOK, &r) + + gs = fleet.GlobalSchedulePayload{} + s.DoJSON("GET", "/api/v1/fleet/global/schedule", nil, http.StatusOK, &gs) + require.Len(t, gs.GlobalSchedule, 1) + assert.Equal(t, uint(42), gs.GlobalSchedule[0].Interval) + assert.Equal(t, "TestQuery1", gs.GlobalSchedule[0].Name) + id := gs.GlobalSchedule[0].ID + + gs = fleet.GlobalSchedulePayload{} + gsParams = fleet.ScheduledQueryPayload{Interval: ptr.Uint(55)} + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/global/schedule/%d", id), gsParams, http.StatusOK, &gs) + + gs = fleet.GlobalSchedulePayload{} + s.DoJSON("GET", "/api/v1/fleet/global/schedule", nil, http.StatusOK, &gs) + require.Len(t, gs.GlobalSchedule, 1) + assert.Equal(t, uint(55), gs.GlobalSchedule[0].Interval) + + r = globalScheduleQueryResponse{} + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/global/schedule/%d", id), nil, http.StatusOK, &r) + + gs = fleet.GlobalSchedulePayload{} + s.DoJSON("GET", "/api/v1/fleet/global/schedule", nil, http.StatusOK, &gs) + require.Len(t, gs.GlobalSchedule, 0) +} + +func (s *integrationTestSuite) TestTranslator() { + t := s.T() + + payload := translatorResponse{} + params := translatorRequest{List: []fleet.TranslatePayload{ + { + Type: fleet.TranslatorTypeUserEmail, + Payload: fleet.StringIdentifierToIDPayload{Identifier: "admin1@example.com"}, + }, + }} + s.DoJSON("POST", "/api/v1/fleet/translate", ¶ms, http.StatusOK, &payload) + require.Len(t, payload.List, 1) + + assert.Equal(t, s.users[payload.List[0].Payload.Identifier].ID, payload.List[0].Payload.ID) +} + +func (s *integrationTestSuite) TestVulnerableSoftware() { + t := s.T() + + host, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: t.Name() + "1", + UUID: t.Name() + "1", + Hostname: "foo.local", + PrimaryIP: "192.168.1.1", + PrimaryMac: "30-65-EC-6F-C4-58", + }) + require.NoError(t, err) + require.NotNil(t, host) + + soft := fleet.HostSoftware{ + Modified: true, + Software: []fleet.Software{ + {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, + {Name: "bar", Version: "0.0.3", Source: "apps"}, + }, + } + host.HostSoftware = soft + require.NoError(t, s.ds.SaveHostSoftware(context.Background(), host)) + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host)) + + soft1 := host.Software[0] + if soft1.Name != "bar" { + soft1 = host.Software[1] + } + + require.NoError(t, s.ds.AddCPEForSoftware(context.Background(), soft1, "somecpe")) + require.NoError(t, s.ds.InsertCVEForCPE(context.Background(), "cve-123-123-132", []string{"somecpe"})) + + resp := s.Do("GET", fmt.Sprintf("/api/v1/fleet/hosts/%d", host.ID), nil, http.StatusOK) + bodyBytes, err := ioutil.ReadAll(resp.Body) + require.NoError(t, err) + + expectedJSONSoft2 := `"name": "bar", + "version": "0.0.3", + "source": "apps", + "generated_cpe": "somecpe", + "vulnerabilities": [ + { + "cve": "cve-123-123-132", + "details_link": "https://nvd.nist.gov/vuln/detail/cve-123-123-132" + } + ]` + expectedJSONSoft1 := `"name": "foo", + "version": "0.0.1", + "source": "chrome_extensions", + "generated_cpe": "", + "vulnerabilities": null` + // We are doing Contains instead of equals to test the output for software in particular + // ignoring other things like timestamps and things that are outside the cope of this ticket + assert.Contains(t, string(bodyBytes), expectedJSONSoft2) + assert.Contains(t, string(bodyBytes), expectedJSONSoft1) +} + +func (s *integrationTestSuite) TestGlobalPolicies() { + t := s.T() + + for i := 0; i < 3; i++ { + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + SeenTime: time.Now().Add(-time.Duration(i) * time.Minute), + OsqueryHostID: strconv.Itoa(i), + NodeKey: fmt.Sprintf("%d", i), + UUID: fmt.Sprintf("%d", i), + Hostname: fmt.Sprintf("foo.local%d", i), + }) + require.NoError(t, err) + } + + qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{ + Name: "TestQuery3", + Description: "Some description", + Query: "select * from osquery;", + ObserverCanRun: true, + }) + require.NoError(t, err) + + gpParams := globalPolicyRequest{QueryID: qr.ID} + gpResp := globalPolicyResponse{} + s.DoJSON("POST", "/api/v1/fleet/global/policies", gpParams, http.StatusOK, &gpResp) + require.NotNil(t, gpResp.Policy) + assert.Equal(t, qr.ID, gpResp.Policy.QueryID) + + policiesResponse := listGlobalPoliciesResponse{} + s.DoJSON("GET", "/api/v1/fleet/global/policies", nil, http.StatusOK, &policiesResponse) + require.Len(t, policiesResponse.Policies, 1) + assert.Equal(t, qr.ID, policiesResponse.Policies[0].QueryID) + + singlePolicyResponse := getPolicyByIDResponse{} + singlePolicyURL := fmt.Sprintf("/api/v1/fleet/global/policies/%d", policiesResponse.Policies[0].ID) + s.DoJSON("GET", singlePolicyURL, nil, http.StatusOK, &singlePolicyResponse) + assert.Equal(t, qr.ID, singlePolicyResponse.Policy.QueryID) + assert.Equal(t, qr.Name, singlePolicyResponse.Policy.QueryName) + + listHostsURL := fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID) + listHostsResp := listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 3) + + h1 := listHostsResp.Hosts[0] + h2 := listHostsResp.Hosts[1] + + listHostsURL = fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 0) + + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now())) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now())) + + listHostsURL = fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) + listHostsResp = listHostsResponse{} + s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) + require.Len(t, listHostsResp.Hosts, 1) + + deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}} + deletePolicyResp := deleteGlobalPoliciesResponse{} + s.DoJSON("POST", "/api/v1/fleet/global/policies/delete", deletePolicyParams, http.StatusOK, &deletePolicyResp) + + policiesResponse = listGlobalPoliciesResponse{} + s.DoJSON("GET", "/api/v1/fleet/global/policies", nil, http.StatusOK, &policiesResponse) + require.Len(t, policiesResponse.Policies, 0) +} diff --git a/server/service/integration_ds_only_test.go b/server/service/integration_ds_only_test.go new file mode 100644 index 000000000000..780a281fb2b4 --- /dev/null +++ b/server/service/integration_ds_only_test.go @@ -0,0 +1,61 @@ +package service + +import ( + "net/http" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type integrationDSTestSuite struct { + withDS + suite.Suite +} + +func TestIntegrationDSTestSuite(t *testing.T) { + testingSuite := new(integrationDSTestSuite) + testingSuite.s = &testingSuite.Suite + suite.Run(t, testingSuite) +} + +func (s *integrationDSTestSuite) SetupSuite() { + s.withDS.SetupSuite("integrationDSTestSuite") +} + +func (s *integrationDSTestSuite) TestLicenseExpiration() { + testCases := []struct { + name string + tier string + expiration time.Time + shouldHaveHeader bool + }{ + {"basic expired", fleet.TierPremium, time.Now().Add(-24 * time.Hour), true}, + {"basic not expired", fleet.TierPremium, time.Now().Add(24 * time.Hour), false}, + {"core expired", fleet.TierFree, time.Now().Add(-24 * time.Hour), false}, + {"core not expired", fleet.TierFree, time.Now().Add(24 * time.Hour), false}, + } + + createTestUsers(s.T(), s.ds) + for _, tt := range testCases { + s.Run(tt.name, func() { + t := s.T() + + license := &fleet.LicenseInfo{Tier: tt.tier, Expiration: tt.expiration} + _, server := RunServerForTestsWithDS(t, s.ds, TestServerOpts{License: license, SkipCreateTestUsers: true}) + + ts := withServer{server: server} + ts.s = &s.Suite + ts.token = ts.getTestAdminToken() + + resp := ts.Do("GET", "/api/v1/fleet/config", nil, http.StatusOK) + if tt.shouldHaveHeader { + require.Equal(t, fleet.HeaderLicenseValueExpired, resp.Header.Get(fleet.HeaderLicenseKey)) + } else { + require.Equal(t, "", resp.Header.Get(fleet.HeaderLicenseKey)) + } + }) + } +} diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go new file mode 100644 index 000000000000..53af9f01b67b --- /dev/null +++ b/server/service/integration_enterprise_test.go @@ -0,0 +1,142 @@ +package service + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "testing" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestIntegrationsEnterprise(t *testing.T) { + testingSuite := new(integrationEnterpriseTestSuite) + testingSuite.s = &testingSuite.Suite + suite.Run(t, testingSuite) +} + +type integrationEnterpriseTestSuite struct { + withServer + suite.Suite +} + +func (s *integrationEnterpriseTestSuite) SetupSuite() { + s.withDS.SetupSuite("integrationEnterpriseTestSuite") + + users, server := RunServerForTestsWithDS( + s.T(), s.ds, TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}}) + s.server = server + s.users = users + s.token = s.getTestAdminToken() +} + +func (s *integrationEnterpriseTestSuite) TestTeamSpecs() { + t := s.T() + + // create a team through the service so it initializes the agent ops + teamName := t.Name() + "team1" + team := &fleet.Team{ + Name: teamName, + Description: "desc team1", + } + + s.Do("POST", "/api/v1/fleet/teams", team, http.StatusOK) + + // updates a team + agentOpts := json.RawMessage(`{"config": {"foo": "bar"}, "overrides": {"platforms": {"darwin": {"foo": "override"}}}}`) + teamSpecs := applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: teamName, AgentOptions: &agentOpts}}} + s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) + + team, err := s.ds.TeamByName(context.Background(), teamName) + require.NoError(t, err) + + assert.Len(t, team.Secrets, 0) + require.JSONEq(t, string(agentOpts), string(*team.AgentOptions)) + + // creates a team with default agent options + user, err := s.ds.UserByEmail(context.Background(), "admin1@example.com") + require.NoError(t, err) + + teams, err := s.ds.ListTeams(context.Background(), fleet.TeamFilter{User: user}, fleet.ListOptions{}) + require.NoError(t, err) + require.True(t, len(teams) >= 1) + + teamSpecs = applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team2"}}} + s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) + + teams, err = s.ds.ListTeams(context.Background(), fleet.TeamFilter{User: user}, fleet.ListOptions{}) + require.NoError(t, err) + assert.True(t, len(teams) >= 2) + + team, err = s.ds.TeamByName(context.Background(), "team2") + require.NoError(t, err) + + defaultOpts := `{"config": {"options": {"logger_plugin": "tls", "pack_delimiter": "/", "logger_tls_period": 10, "distributed_plugin": "tls", "disable_distributed": false, "logger_tls_endpoint": "/api/v1/osquery/log", "distributed_interval": 10, "distributed_tls_max_attempts": 3}, "decorators": {"load": ["SELECT uuid AS host_uuid FROM system_info;", "SELECT hostname AS hostname FROM system_info;"]}}, "overrides": {}}` + assert.Len(t, team.Secrets, 0) + require.NotNil(t, team.AgentOptions) + require.JSONEq(t, defaultOpts, string(*team.AgentOptions)) + + // updates secrets + teamSpecs = applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team2", Secrets: []fleet.EnrollSecret{{Secret: "ABC"}}}}} + s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) + + team, err = s.ds.TeamByName(context.Background(), "team2") + require.NoError(t, err) + + require.Len(t, team.Secrets, 1) + assert.Equal(t, "ABC", team.Secrets[0].Secret) +} + +func (s *integrationEnterpriseTestSuite) TestTeamSchedule() { + t := s.T() + + team1, err := s.ds.NewTeam(context.Background(), &fleet.Team{ + ID: 42, + Name: "team1", + Description: "desc team1", + }) + require.NoError(t, err) + + ts := getTeamScheduleResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), nil, http.StatusOK, &ts) + require.Len(t, ts.Scheduled, 0) + + qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{Name: "TestQuery2", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true}) + require.NoError(t, err) + + gsParams := teamScheduleQueryRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{QueryID: &qr.ID, Interval: ptr.Uint(42)}} + r := teamScheduleQueryResponse{} + s.DoJSON("POST", fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), gsParams, http.StatusOK, &r) + + ts = getTeamScheduleResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), nil, http.StatusOK, &ts) + require.Len(t, ts.Scheduled, 1) + assert.Equal(t, uint(42), ts.Scheduled[0].Interval) + assert.Equal(t, "TestQuery2", ts.Scheduled[0].Name) + assert.Equal(t, qr.ID, ts.Scheduled[0].QueryID) + id := ts.Scheduled[0].ID + + modifyResp := modifyTeamScheduleResponse{} + modifyParams := modifyTeamScheduleRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{Interval: ptr.Uint(55)}} + s.DoJSON("PATCH", fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), modifyParams, http.StatusOK, &modifyResp) + + // just to satisfy my paranoia, wanted to make sure the contents of the json would work + s.DoRaw("PATCH", fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), []byte(`{"interval": 77}`), http.StatusOK) + + ts = getTeamScheduleResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), nil, http.StatusOK, &ts) + require.Len(t, ts.Scheduled, 1) + assert.Equal(t, uint(77), ts.Scheduled[0].Interval) + + deleteResp := deleteTeamScheduleResponse{} + s.DoJSON("DELETE", fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), nil, http.StatusOK, &deleteResp) + + ts = getTeamScheduleResponse{} + s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), nil, http.StatusOK, &ts) + require.Len(t, ts.Scheduled, 0) +} diff --git a/server/service/integration_logger_test.go b/server/service/integration_logger_test.go new file mode 100644 index 000000000000..0816e6e8e8f9 --- /dev/null +++ b/server/service/integration_logger_test.go @@ -0,0 +1,183 @@ +package service + +import ( + "bytes" + "context" + "encoding/json" + "io" + "net/http" + "strings" + "testing" + "time" + + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/ptr" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/log/level" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +func TestIntegrationLoggerTestSuite(t *testing.T) { + testingSuite := new(integrationLoggerTestSuite) + testingSuite.s = &testingSuite.Suite + suite.Run(t, testingSuite) +} + +type integrationLoggerTestSuite struct { + withServer + suite.Suite + + buf *bytes.Buffer +} + +func (s *integrationLoggerTestSuite) SetupSuite() { + s.withDS.SetupSuite("integrationLoggerTestSuite") + + s.buf = new(bytes.Buffer) + logger := log.NewJSONLogger(s.buf) + logger = level.NewFilter(logger, level.AllowDebug()) + + users, server := RunServerForTestsWithDS(s.T(), s.ds, TestServerOpts{Logger: logger}) + s.server = server + s.users = users +} + +func (s *integrationLoggerTestSuite) TearDownTest() { + s.buf.Reset() +} + +func (s *integrationLoggerTestSuite) TestLogger() { + t := s.T() + + s.token = getTestAdminToken(t, s.server) + + s.getConfig() + + params := fleet.QueryPayload{ + Name: ptr.String("somequery"), + Description: ptr.String("desc"), + Query: ptr.String("select 1 from osquery;"), + } + payload := createQueryRequest{} + s.DoJSON("POST", "/api/v1/fleet/queries", params, http.StatusOK, &payload) + + logs := s.buf.String() + parts := strings.Split(strings.TrimSpace(logs), "\n") + assert.Len(t, parts, 3) + for i, part := range parts { + kv := make(map[string]string) + err := json.Unmarshal([]byte(part), &kv) + require.NoError(t, err) + + assert.NotEqual(t, "", kv["took"]) + + switch i { + case 0: + assert.Equal(t, "info", kv["level"]) + assert.Equal(t, "POST", kv["method"]) + assert.Equal(t, "/api/v1/fleet/login", kv["uri"]) + case 1: + assert.Equal(t, "debug", kv["level"]) + assert.Equal(t, "GET", kv["method"]) + assert.Equal(t, "/api/v1/fleet/config", kv["uri"]) + assert.Equal(t, "admin1@example.com", kv["user"]) + case 2: + assert.Equal(t, "info", kv["level"]) + assert.Equal(t, "POST", kv["method"]) + assert.Equal(t, "/api/v1/fleet/queries", kv["uri"]) + assert.Equal(t, "admin1@example.com", kv["user"]) + assert.Equal(t, "somequery", kv["name"]) + assert.Equal(t, "select 1 from osquery;", kv["sql"]) + default: + t.Fail() + } + } +} + +func (s *integrationLoggerTestSuite) TestOsqueryEndpointsLogErrors() { + t := s.T() + + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: t.Name() + "1234", + UUID: "1", + Hostname: "foo.local", + OsqueryHostID: t.Name(), + PrimaryIP: "192.168.1.1", + PrimaryMac: "30-65-EC-6F-C4-58", + }) + require.NoError(t, err) + + requestBody := io.NopCloser(bytes.NewBuffer([]byte(`{"node_key":"1234","log_type":"status","data":[}`))) + req, _ := http.NewRequest("POST", s.server.URL+"/api/v1/osquery/log", requestBody) + client := &http.Client{} + _, err = client.Do(req) + require.Nil(t, err) + + logString := s.buf.String() + assert.Equal(t, `{"err":"decoding JSON: invalid character '}' looking for beginning of value","level":"info","path":"/api/v1/osquery/log"} +`, logString) +} + +func (s *integrationLoggerTestSuite) TestSubmitStatusLog() { + t := s.T() + + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: t.Name() + "1234", + UUID: "1", + Hostname: "foo.local", + PrimaryIP: "192.168.1.1", + PrimaryMac: "30-65-EC-6F-C4-58", + OsqueryHostID: t.Name(), + }) + require.NoError(t, err) + + req := submitLogsRequest{ + NodeKey: "1234", + LogType: "status", + Data: nil, + } + res := submitLogsResponse{} + s.DoJSON("POST", "/api/v1/osquery/log", req, http.StatusOK, &res) + + logString := s.buf.String() + assert.Equal(t, 1, strings.Count(logString, "\"ip_addr\"")) + assert.Equal(t, 1, strings.Count(logString, "x_for_ip_addr")) +} + +func (s *integrationLoggerTestSuite) TestEnrollAgentLogsErrors() { + t := s.T() + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ + DetailUpdatedAt: time.Now(), + LabelUpdatedAt: time.Now(), + SeenTime: time.Now(), + NodeKey: "1234", + UUID: "1", + Hostname: "foo.local", + PrimaryIP: "192.168.1.1", + PrimaryMac: "30-65-EC-6F-C4-58", + }) + require.NoError(t, err) + + j, err := json.Marshal(&enrollAgentRequest{ + EnrollSecret: "1234", + HostIdentifier: "4321", + HostDetails: nil, + }) + require.NoError(t, err) + + s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusUnauthorized) + + parts := strings.Split(strings.TrimSpace(s.buf.String()), "\n") + require.Len(t, parts, 1) + logData := make(map[string]json.RawMessage) + require.NoError(t, json.Unmarshal([]byte(parts[0]), &logData)) + assert.Equal(t, json.RawMessage(`["enroll failed: no matching secret found"]`), logData["err"]) +} diff --git a/server/service/integration_test.go b/server/service/integration_test.go deleted file mode 100644 index ca24a626c21a..000000000000 --- a/server/service/integration_test.go +++ /dev/null @@ -1,931 +0,0 @@ -package service - -import ( - "bytes" - "context" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "net/http/httptest" - "strconv" - "strings" - "testing" - "time" - - "github.com/fleetdm/fleet/v4/server/datastore/mysql" - "github.com/fleetdm/fleet/v4/server/fleet" - "github.com/fleetdm/fleet/v4/server/ptr" - "github.com/fleetdm/fleet/v4/server/test" - "github.com/ghodss/yaml" - "github.com/go-kit/kit/log" - "github.com/go-kit/kit/log/level" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestDoubleUserCreationErrors(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - params := fleet.UserPayload{ - Name: ptr.String("user1"), - Email: ptr.String("email@asd.com"), - Password: ptr.String("pass"), - GlobalRole: ptr.String(fleet.RoleObserver), - } - j, err := json.Marshal(¶ms) - assert.Nil(t, err) - - requestBody := &nopCloser{bytes.NewBuffer(j)} - req, _ := http.NewRequest("POST", server.URL+"/api/v1/fleet/users/admin", requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - client := &http.Client{} - resp, err := client.Do(req) - require.Nil(t, err) - assert.Equal(t, http.StatusOK, resp.StatusCode) - - requestBody = &nopCloser{bytes.NewBuffer(j)} - req, _ = http.NewRequest("POST", server.URL+"/api/v1/fleet/users/admin", requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - resp, err = client.Do(req) - require.Nil(t, err) - assert.Equal(t, http.StatusConflict, resp.StatusCode) - assertBodyContains(t, resp, `Error 1062: Duplicate entry 'email@asd.com'`) -} - -func TestUserWithoutRoleErrors(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - params := fleet.UserPayload{ - Name: ptr.String("user1"), - Email: ptr.String("email@asd.com"), - Password: ptr.String("pass"), - } - j, err := json.Marshal(¶ms) - assert.Nil(t, err) - - requestBody := &nopCloser{bytes.NewBuffer(j)} - req, _ := http.NewRequest("POST", server.URL+"/api/v1/fleet/users/admin", requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - client := &http.Client{} - resp, err := client.Do(req) - require.Nil(t, err) - assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) - assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "either global role or team role needs to be defined") -} - -func TestUserWithWrongRoleErrors(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - params := fleet.UserPayload{ - Name: ptr.String("user1"), - Email: ptr.String("email@asd.com"), - Password: ptr.String("pass"), - GlobalRole: ptr.String("wrongrole"), - } - j, err := json.Marshal(¶ms) - assert.Nil(t, err) - - requestBody := &nopCloser{bytes.NewBuffer(j)} - req, _ := http.NewRequest("POST", server.URL+"/api/v1/fleet/users/admin", requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - client := &http.Client{} - resp, err := client.Do(req) - require.Nil(t, err) - assert.Equal(t, http.StatusUnprocessableEntity, resp.StatusCode) - assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "GlobalRole role can only be admin, observer, or maintainer.") -} - -func TestUserCreationWrongTeamErrors(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - teams := []fleet.UserTeam{ - { - Team: fleet.Team{ - ID: 9999, - }, - Role: fleet.RoleObserver, - }, - } - - params := fleet.UserPayload{ - Name: ptr.String("user1"), - Email: ptr.String("email@asd.com"), - Password: ptr.String("pass"), - Teams: &teams, - } - method := "POST" - path := "/api/v1/fleet/users/admin" - expectedStatusCode := http.StatusUnprocessableEntity - - resp, closeFunc := doReq(t, params, method, server, path, token, expectedStatusCode) - defer closeFunc() - assertBodyContains(t, resp, `Error 1452: Cannot add or update a child row: a foreign key constraint fails`) -} - -func doReq( - t *testing.T, - params interface{}, - method string, - server *httptest.Server, - path string, - token string, - expectedStatusCode int, -) (*http.Response, func()) { - j, err := json.Marshal(¶ms) - assert.Nil(t, err) - - requestBody := &nopCloser{bytes.NewBuffer(j)} - req, _ := http.NewRequest(method, server.URL+path, requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - client := &http.Client{} - resp, err := client.Do(req) - require.Nil(t, err) - assert.Equal(t, expectedStatusCode, resp.StatusCode) - return resp, func() { - thisResp := resp - thisResp.Body.Close() - } -} - -func doRawReq( - t *testing.T, - body []byte, - method string, - server *httptest.Server, - path string, - token string, - expectedStatusCode int, -) *http.Response { - requestBody := &nopCloser{bytes.NewBuffer(body)} - req, _ := http.NewRequest(method, server.URL+path, requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", token)) - client := &http.Client{} - resp, err := client.Do(req) - require.Nil(t, err) - assert.Equal(t, expectedStatusCode, resp.StatusCode) - return resp -} - -func doJSONReq( - t *testing.T, - params interface{}, - method string, - server *httptest.Server, - path string, - token string, - expectedStatusCode int, - v interface{}, -) { - resp, closeFunc := doReq(t, params, method, server, path, token, expectedStatusCode) - defer closeFunc() - err := json.NewDecoder(resp.Body).Decode(v) - require.Nil(t, err) -} - -func assertBodyContains(t *testing.T, resp *http.Response, expectedError string) { - bodyBytes, err := ioutil.ReadAll(resp.Body) - require.Nil(t, err) - bodyString := string(bodyBytes) - assert.Contains(t, bodyString, expectedError) -} - -func TestQueryCreationLogsActivity(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - users, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - admin1 := users["admin1@example.com"] - admin1.GravatarURL = "http://iii.com" - err := ds.SaveUser(context.Background(), &admin1) - require.NoError(t, err) - - params := fleet.QueryPayload{ - Name: ptr.String("user1"), - Query: ptr.String("select * from time;"), - } - _, closeFunc := doReq(t, params, "POST", server, "/api/v1/fleet/queries", token, http.StatusOK) - defer closeFunc() - type activitiesRespose struct { - Activities []map[string]interface{} `json:"activities"` - } - activities := activitiesRespose{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/activities", token, http.StatusOK, &activities) - - assert.Len(t, activities.Activities, 1) - assert.Equal(t, "Test Name admin1@example.com", activities.Activities[0]["actor_full_name"]) - assert.Equal(t, "http://iii.com", activities.Activities[0]["actor_gravatar"]) - assert.Equal(t, "created_saved_query", activities.Activities[0]["type"]) -} - -func getJSON(r *http.Response, target interface{}) error { - return json.NewDecoder(r.Body).Decode(target) -} - -func assertErrorCodeAndMessage(t *testing.T, resp *http.Response, code int, message string) { - err := &fleet.Error{} - require.Nil(t, getJSON(resp, err)) - assert.Equal(t, code, err.Code) - assert.Equal(t, message, err.Message) -} - -func TestAppConfigAdditionalQueriesCanBeRemoved(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - spec := []byte(` - host_expiry_settings: - host_expiry_enabled: true - host_expiry_window: 0 - host_settings: - additional_queries: - time: SELECT * FROM time - enable_host_users: true -`) - applyConfig(t, spec, server, token) - - spec = []byte(` - host_settings: - enable_host_users: true - additional_queries: null -`) - applyConfig(t, spec, server, token) - - config := getConfig(t, server, token) - assert.Nil(t, config.HostSettings.AdditionalQueries) - assert.True(t, config.HostExpirySettings.HostExpiryEnabled) -} - -func TestAppConfigUpdateInterval(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - config := getConfig(t, server, token) - require.Equal(t, 1*time.Hour, config.UpdateInterval.OSQueryDetail) -} - -func TestAppConfigHasLogging(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - config := getConfig(t, server, token) - require.NotNil(t, config.Logging) -} - -func applyConfig(t *testing.T, spec []byte, server *httptest.Server, token string) { - var appConfigSpec interface{} - err := yaml.Unmarshal(spec, &appConfigSpec) - require.NoError(t, err) - - _, closeFunc := doReq(t, appConfigSpec, "PATCH", server, "/api/v1/fleet/config", token, http.StatusOK) - closeFunc() -} - -func getConfig(t *testing.T, server *httptest.Server, token string) *appConfigResponse { - var responseBody *appConfigResponse - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/config", token, http.StatusOK, &responseBody) - return responseBody -} - -func TestUserRolesSpec(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - _, err := ds.NewTeam(context.Background(), &fleet.Team{ - ID: 42, - Name: "team1", - Description: "desc team1", - }) - require.NoError(t, err) - token := getTestAdminToken(t, server) - - user, err := ds.UserByEmail(context.Background(), "user1@example.com") - require.NoError(t, err) - assert.Len(t, user.Teams, 0) - - spec := []byte(` - roles: - user1@example.com: - global_role: null - teams: - - role: maintainer - team: team1 -`) - - var userRoleSpec applyUserRoleSpecsRequest - err = yaml.Unmarshal(spec, &userRoleSpec.Spec) - require.NoError(t, err) - - _, closeFunc := doReq(t, userRoleSpec, "POST", server, "/api/v1/fleet/users/roles/spec", token, http.StatusOK) - closeFunc() - - user, err = ds.UserByEmail(context.Background(), "user1@example.com") - require.NoError(t, err) - require.Len(t, user.Teams, 1) - assert.Equal(t, fleet.RoleMaintainer, user.Teams[0].Role) - - // But users are not deleted - users, err := ds.ListUsers(context.Background(), fleet.UserListOptions{}) - require.NoError(t, err) - assert.Len(t, users, 3) -} - -func TestGlobalSchedule(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - test.AddAllHostsLabel(t, ds) - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - gs := fleet.GlobalSchedulePayload{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/global/schedule", token, http.StatusOK, &gs) - assert.Len(t, gs.GlobalSchedule, 0) - - qr, err := ds.NewQuery(context.Background(), &fleet.Query{ - Name: "TestQuery", - Description: "Some description", - Query: "select * from osquery;", - ObserverCanRun: true, - }) - require.NoError(t, err) - - gsParams := fleet.ScheduledQueryPayload{QueryID: ptr.Uint(qr.ID), Interval: ptr.Uint(42)} - r := globalScheduleQueryResponse{} - doJSONReq(t, gsParams, "POST", server, "/api/v1/fleet/global/schedule", token, http.StatusOK, &r) - require.Nil(t, r.Err) - - gs = fleet.GlobalSchedulePayload{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/global/schedule", token, http.StatusOK, &gs) - require.Len(t, gs.GlobalSchedule, 1) - assert.Equal(t, uint(42), gs.GlobalSchedule[0].Interval) - assert.Equal(t, "TestQuery", gs.GlobalSchedule[0].Name) - id := gs.GlobalSchedule[0].ID - - gs = fleet.GlobalSchedulePayload{} - gsParams = fleet.ScheduledQueryPayload{Interval: ptr.Uint(55)} - doJSONReq( - t, gsParams, "PATCH", server, - fmt.Sprintf("/api/v1/fleet/global/schedule/%d", id), - token, http.StatusOK, &gs, - ) - - gs = fleet.GlobalSchedulePayload{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/global/schedule", token, http.StatusOK, &gs) - require.Len(t, gs.GlobalSchedule, 1) - assert.Equal(t, uint(55), gs.GlobalSchedule[0].Interval) - - r = globalScheduleQueryResponse{} - doJSONReq( - t, nil, "DELETE", server, - fmt.Sprintf("/api/v1/fleet/global/schedule/%d", id), - token, http.StatusOK, &r, - ) - require.Nil(t, r.Err) - - gs = fleet.GlobalSchedulePayload{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/global/schedule", token, http.StatusOK, &gs) - require.Len(t, gs.GlobalSchedule, 0) -} - -func TestTeamSpecs(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds, TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}}) - defer server.Close() - token := getTestAdminToken(t, server) - - // create a team through the service so it initializes the agent ops - team := &fleet.Team{ - Name: "team1", - Description: "desc team1", - } - _, closeFunc := doReq(t, team, "POST", server, "/api/v1/fleet/teams", token, http.StatusOK) - defer closeFunc() - - // updates a team - agentOpts := json.RawMessage(`{"config": {"foo": "bar"}, "overrides": {"platforms": {"darwin": {"foo": "override"}}}}`) - teamSpecs := applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team1", AgentOptions: &agentOpts}}} - _, closeFunc = doReq(t, teamSpecs, "POST", server, "/api/v1/fleet/spec/teams", token, http.StatusOK) - defer closeFunc() - - team, err := ds.TeamByName(context.Background(), "team1") - require.NoError(t, err) - - assert.Len(t, team.Secrets, 0) - require.JSONEq(t, string(agentOpts), string(*team.AgentOptions)) - - // creates a team with default agent options - user, err := ds.UserByEmail(context.Background(), "admin1@example.com") - require.NoError(t, err) - - teams, err := ds.ListTeams(context.Background(), fleet.TeamFilter{User: user}, fleet.ListOptions{}) - require.NoError(t, err) - assert.Len(t, teams, 1) - - teamSpecs = applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team2"}}} - _, closeFunc = doReq(t, teamSpecs, "POST", server, "/api/v1/fleet/spec/teams", token, http.StatusOK) - defer closeFunc() - - teams, err = ds.ListTeams(context.Background(), fleet.TeamFilter{User: user}, fleet.ListOptions{}) - require.NoError(t, err) - assert.Len(t, teams, 2) - - team, err = ds.TeamByName(context.Background(), "team2") - require.NoError(t, err) - - defaultOpts := `{"config": {"options": {"logger_plugin": "tls", "pack_delimiter": "/", "logger_tls_period": 10, "distributed_plugin": "tls", "disable_distributed": false, "logger_tls_endpoint": "/api/v1/osquery/log", "distributed_interval": 10, "distributed_tls_max_attempts": 3}, "decorators": {"load": ["SELECT uuid AS host_uuid FROM system_info;", "SELECT hostname AS hostname FROM system_info;"]}}, "overrides": {}}` - assert.Len(t, team.Secrets, 0) - require.NotNil(t, team.AgentOptions) - require.JSONEq(t, defaultOpts, string(*team.AgentOptions)) - - // updates secrets - teamSpecs = applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team2", Secrets: []fleet.EnrollSecret{{Secret: "ABC"}}}}} - _, closeFunc = doReq(t, teamSpecs, "POST", server, "/api/v1/fleet/spec/teams", token, http.StatusOK) - defer closeFunc() - - team, err = ds.TeamByName(context.Background(), "team2") - require.NoError(t, err) - - require.Len(t, team.Secrets, 1) - assert.Equal(t, "ABC", team.Secrets[0].Secret) -} - -func TestTranslator(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - users, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - payload := translatorResponse{} - params := translatorRequest{List: []fleet.TranslatePayload{ - { - Type: fleet.TranslatorTypeUserEmail, - Payload: fleet.StringIdentifierToIDPayload{Identifier: "admin1@example.com"}, - }, - }} - doJSONReq(t, ¶ms, "POST", server, "/api/v1/fleet/translate", token, http.StatusOK, &payload) - - require.Nil(t, payload.Err) - assert.Len(t, payload.List, 1) - - assert.Equal(t, users[payload.List[0].Payload.Identifier].ID, payload.List[0].Payload.ID) -} - -func TestTeamSchedule(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - test.AddAllHostsLabel(t, ds) - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - team1, err := ds.NewTeam(context.Background(), &fleet.Team{ - ID: 42, - Name: "team1", - Description: "desc team1", - }) - require.NoError(t, err) - - ts := getTeamScheduleResponse{} - doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) - assert.Len(t, ts.Scheduled, 0) - - qr, err := ds.NewQuery(context.Background(), &fleet.Query{Name: "TestQuery", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true}) - require.NoError(t, err) - - gsParams := teamScheduleQueryRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{QueryID: &qr.ID, Interval: ptr.Uint(42)}} - r := teamScheduleQueryResponse{} - doJSONReq(t, gsParams, "POST", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &r) - require.Nil(t, r.Err) - - ts = getTeamScheduleResponse{} - doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) - require.Len(t, ts.Scheduled, 1) - assert.Equal(t, uint(42), ts.Scheduled[0].Interval) - assert.Equal(t, "TestQuery", ts.Scheduled[0].Name) - assert.Equal(t, qr.ID, ts.Scheduled[0].QueryID) - id := ts.Scheduled[0].ID - - modifyResp := modifyTeamScheduleResponse{} - modifyParams := modifyTeamScheduleRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{Interval: ptr.Uint(55)}} - doJSONReq( - t, modifyParams, "PATCH", server, - fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), - token, http.StatusOK, &modifyResp, - ) - - // just to satisfy my paranoia, wanted to make sure the contents of the json would work - doRawReq(t, []byte(`{"interval": 77}`), "PATCH", server, - fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), - token, http.StatusOK) - - ts = getTeamScheduleResponse{} - doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) - assert.Len(t, ts.Scheduled, 1) - assert.Equal(t, uint(77), ts.Scheduled[0].Interval) - - deleteResp := deleteTeamScheduleResponse{} - doJSONReq( - t, nil, "DELETE", server, - fmt.Sprintf("/api/v1/fleet/team/%d/schedule/%d", team1.ID, id), - token, http.StatusOK, &deleteResp, - ) - require.Nil(t, r.Err) - - ts = getTeamScheduleResponse{} - doJSONReq(t, nil, "GET", server, fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), token, http.StatusOK, &ts) - assert.Len(t, ts.Scheduled, 0) -} - -func TestLogger(t *testing.T) { - buf := new(bytes.Buffer) - logger := log.NewJSONLogger(buf) - logger = level.NewFilter(logger, level.AllowDebug()) - - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds, TestServerOpts{Logger: logger}) - defer server.Close() - token := getTestAdminToken(t, server) - - getConfig(t, server, token) - params := fleet.QueryPayload{ - Name: ptr.String("somequery"), - Description: ptr.String("desc"), - Query: ptr.String("select 1 from osquery;"), - } - payload := createQueryRequest{} - doJSONReq(t, params, "POST", server, "/api/v1/fleet/queries", token, http.StatusOK, &payload) - - logs := buf.String() - parts := strings.Split(strings.TrimSpace(logs), "\n") - assert.Len(t, parts, 3) - for i, part := range parts { - kv := make(map[string]string) - err := json.Unmarshal([]byte(part), &kv) - require.NoError(t, err) - - assert.NotEqual(t, "", kv["took"]) - - switch i { - case 0: - assert.Equal(t, "info", kv["level"]) - assert.Equal(t, "POST", kv["method"]) - assert.Equal(t, "/api/v1/fleet/login", kv["uri"]) - case 1: - assert.Equal(t, "debug", kv["level"]) - assert.Equal(t, "GET", kv["method"]) - assert.Equal(t, "/api/v1/fleet/config", kv["uri"]) - assert.Equal(t, "admin1@example.com", kv["user"]) - case 2: - assert.Equal(t, "info", kv["level"]) - assert.Equal(t, "POST", kv["method"]) - assert.Equal(t, "/api/v1/fleet/queries", kv["uri"]) - assert.Equal(t, "admin1@example.com", kv["user"]) - assert.Equal(t, "somequery", kv["name"]) - assert.Equal(t, "select 1 from osquery;", kv["sql"]) - default: - t.Fail() - } - } -} - -func TestVulnerableSoftware(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - host, err := ds.NewHost(context.Background(), &fleet.Host{ - DetailUpdatedAt: time.Now(), - LabelUpdatedAt: time.Now(), - SeenTime: time.Now(), - NodeKey: "1", - UUID: "1", - Hostname: "foo.local", - PrimaryIP: "192.168.1.1", - PrimaryMac: "30-65-EC-6F-C4-58", - }) - require.NoError(t, err) - require.NotNil(t, host) - - soft := fleet.HostSoftware{ - Modified: true, - Software: []fleet.Software{ - {Name: "foo", Version: "0.0.1", Source: "chrome_extensions"}, - {Name: "bar", Version: "0.0.3", Source: "apps"}, - }, - } - host.HostSoftware = soft - require.NoError(t, ds.SaveHostSoftware(context.Background(), host)) - require.NoError(t, ds.LoadHostSoftware(context.Background(), host)) - - soft1 := host.Software[0] - if soft1.Name != "bar" { - soft1 = host.Software[1] - } - - require.NoError(t, ds.AddCPEForSoftware(context.Background(), soft1, "somecpe")) - require.NoError(t, ds.InsertCVEForCPE(context.Background(), "cve-123-123-132", []string{"somecpe"})) - - path := fmt.Sprintf("/api/v1/fleet/hosts/%d", host.ID) - resp, closeFunc := doReq(t, nil, "GET", server, path, token, http.StatusOK) - defer closeFunc() - bodyBytes, err := ioutil.ReadAll(resp.Body) - require.NoError(t, err) - - expectedJSONSoft2 := `"name": "bar", - "version": "0.0.3", - "source": "apps", - "generated_cpe": "somecpe", - "vulnerabilities": [ - { - "cve": "cve-123-123-132", - "details_link": "https://nvd.nist.gov/vuln/detail/cve-123-123-132" - } - ]` - expectedJSONSoft1 := `"name": "foo", - "version": "0.0.1", - "source": "chrome_extensions", - "generated_cpe": "", - "vulnerabilities": null` - // We are doing Contains instead of equals to test the output for software in particular - // ignoring other things like timestamps and things that are outside the cope of this ticket - assert.Contains(t, string(bodyBytes), expectedJSONSoft2) - assert.Contains(t, string(bodyBytes), expectedJSONSoft1) -} - -func TestGlobalPolicies(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds) - defer server.Close() - token := getTestAdminToken(t, server) - - for i := 0; i < 3; i++ { - _, err := ds.NewHost(context.Background(), &fleet.Host{ - DetailUpdatedAt: time.Now(), - LabelUpdatedAt: time.Now(), - SeenTime: time.Now().Add(-time.Duration(i) * time.Minute), - OsqueryHostID: strconv.Itoa(i), - NodeKey: fmt.Sprintf("%d", i), - UUID: fmt.Sprintf("%d", i), - Hostname: fmt.Sprintf("foo.local%d", i), - }) - require.NoError(t, err) - } - - qr, err := ds.NewQuery(context.Background(), &fleet.Query{ - Name: "TestQuery", - Description: "Some description", - Query: "select * from osquery;", - ObserverCanRun: true, - }) - require.NoError(t, err) - - gpParams := globalPolicyRequest{QueryID: qr.ID} - gpResp := globalPolicyResponse{} - doJSONReq(t, gpParams, "POST", server, "/api/v1/fleet/global/policies", token, http.StatusOK, &gpResp) - require.NotNil(t, gpResp.Policy) - assert.Equal(t, qr.ID, gpResp.Policy.QueryID) - - policiesResponse := listGlobalPoliciesResponse{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/global/policies", token, http.StatusOK, &policiesResponse) - require.Len(t, policiesResponse.Policies, 1) - assert.Equal(t, qr.ID, policiesResponse.Policies[0].QueryID) - - singlePolicyResponse := getPolicyByIDResponse{} - singlePolicyURL := fmt.Sprintf("/api/v1/fleet/global/policies/%d", policiesResponse.Policies[0].ID) - doJSONReq(t, nil, "GET", server, singlePolicyURL, token, http.StatusOK, &singlePolicyResponse) - assert.Equal(t, qr.ID, singlePolicyResponse.Policy.QueryID) - assert.Equal(t, qr.Name, singlePolicyResponse.Policy.QueryName) - - listHostsURL := fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d", policiesResponse.Policies[0].ID) - listHostsResp := listHostsResponse{} - doJSONReq(t, nil, "GET", server, listHostsURL, token, http.StatusOK, &listHostsResp) - require.Len(t, listHostsResp.Hosts, 3) - - h1 := listHostsResp.Hosts[0] - h2 := listHostsResp.Hosts[1] - - listHostsURL = fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) - listHostsResp = listHostsResponse{} - doJSONReq(t, nil, "GET", server, listHostsURL, token, http.StatusOK, &listHostsResp) - require.Len(t, listHostsResp.Hosts, 0) - - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now())) - require.NoError(t, ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now())) - - listHostsURL = fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) - listHostsResp = listHostsResponse{} - doJSONReq(t, nil, "GET", server, listHostsURL, token, http.StatusOK, &listHostsResp) - require.Len(t, listHostsResp.Hosts, 1) - - deletePolicyParams := deleteGlobalPoliciesRequest{IDs: []uint{policiesResponse.Policies[0].ID}} - deletePolicyResp := deleteGlobalPoliciesResponse{} - doJSONReq(t, deletePolicyParams, "POST", server, "/api/v1/fleet/global/policies/delete", token, http.StatusOK, &deletePolicyResp) - - policiesResponse = listGlobalPoliciesResponse{} - doJSONReq(t, nil, "GET", server, "/api/v1/fleet/global/policies", token, http.StatusOK, &policiesResponse) - require.Len(t, policiesResponse.Policies, 0) -} - -func TestOsqueryEndpointsLogErrors(t *testing.T) { - buf := new(bytes.Buffer) - logger := log.NewJSONLogger(buf) - logger = level.NewFilter(logger, level.AllowDebug()) - - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds, TestServerOpts{Logger: logger}) - defer server.Close() - - _, err := ds.NewHost(context.Background(), &fleet.Host{ - DetailUpdatedAt: time.Now(), - LabelUpdatedAt: time.Now(), - SeenTime: time.Now(), - NodeKey: "1234", - UUID: "1", - Hostname: "foo.local", - PrimaryIP: "192.168.1.1", - PrimaryMac: "30-65-EC-6F-C4-58", - }) - require.NoError(t, err) - - requestBody := &nopCloser{bytes.NewBuffer([]byte(`{"node_key":"1234","log_type":"status","data":[}`))} - req, _ := http.NewRequest("POST", server.URL+"/api/v1/osquery/log", requestBody) - client := &http.Client{} - _, err = client.Do(req) - require.Nil(t, err) - - logString := buf.String() - assert.Equal(t, `{"err":"decoding JSON: invalid character '}' looking for beginning of value","level":"info","path":"/api/v1/osquery/log"} -`, logString) -} - -func TestSubmitStatusLog(t *testing.T) { - buf := new(bytes.Buffer) - logger := log.NewJSONLogger(buf) - logger = level.NewFilter(logger, level.AllowDebug()) - - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds, TestServerOpts{Logger: logger}) - defer server.Close() - token := getTestAdminToken(t, server) - - _, err := ds.NewHost(context.Background(), &fleet.Host{ - DetailUpdatedAt: time.Now(), - LabelUpdatedAt: time.Now(), - SeenTime: time.Now(), - NodeKey: "1234", - UUID: "1", - Hostname: "foo.local", - PrimaryIP: "192.168.1.1", - PrimaryMac: "30-65-EC-6F-C4-58", - }) - require.NoError(t, err) - - req := submitLogsRequest{ - NodeKey: "1234", - LogType: "status", - Data: nil, - } - res := submitLogsResponse{} - doJSONReq(t, req, "POST", server, "/api/v1/osquery/log", token, http.StatusOK, &res) - - logString := buf.String() - assert.Equal(t, 1, strings.Count(logString, "\"ip_addr\"")) - assert.Equal(t, 1, strings.Count(logString, "x_for_ip_addr")) -} - -func TestEnrollAgentLogsErrors(t *testing.T) { - buf := new(bytes.Buffer) - logger := log.NewJSONLogger(buf) - logger = level.NewFilter(logger, level.AllowDebug()) - - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - _, server := RunServerForTestsWithDS(t, ds, TestServerOpts{Logger: logger}) - defer server.Close() - - _, err := ds.NewHost(context.Background(), &fleet.Host{ - DetailUpdatedAt: time.Now(), - LabelUpdatedAt: time.Now(), - SeenTime: time.Now(), - NodeKey: "1234", - UUID: "1", - Hostname: "foo.local", - PrimaryIP: "192.168.1.1", - PrimaryMac: "30-65-EC-6F-C4-58", - }) - require.NoError(t, err) - - j, err := json.Marshal(&enrollAgentRequest{ - EnrollSecret: "1234", - HostIdentifier: "4321", - HostDetails: nil, - }) - require.NoError(t, err) - - requestBody := &nopCloser{bytes.NewBuffer(j)} - req, _ := http.NewRequest("POST", server.URL+"/api/v1/osquery/enroll", requestBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.NoError(t, resp.Body.Close()) - - parts := strings.Split(strings.TrimSpace(buf.String()), "\n") - require.Len(t, parts, 1) - logData := make(map[string]json.RawMessage) - require.NoError(t, json.Unmarshal([]byte(parts[0]), &logData)) - assert.Equal(t, json.RawMessage(`["enroll failed: no matching secret found"]`), logData["err"]) -} - -func TestLicenseExpiration(t *testing.T) { - ds := mysql.CreateMySQLDS(t) - defer ds.Close() - - testCases := []struct { - name string - tier string - expiration time.Time - shouldHaveHeader bool - }{ - {"premium expired", fleet.TierPremium, time.Now().Add(-24 * time.Hour), true}, - {"premium not expired", fleet.TierPremium, time.Now().Add(24 * time.Hour), false}, - {"free expired", fleet.TierFree, time.Now().Add(-24 * time.Hour), false}, - {"free not expired", fleet.TierFree, time.Now().Add(24 * time.Hour), false}, - } - - _ = createTestUsers(t, ds) - for _, tt := range testCases { - t.Run(tt.name, func(t *testing.T) { - license := &fleet.LicenseInfo{Tier: tt.tier, Expiration: tt.expiration} - _, server := RunServerForTestsWithDS(t, ds, TestServerOpts{License: license, SkipCreateTestUsers: true}) - defer server.Close() - - token := getTestAdminToken(t, server) - - resp, closeFunc := doReq(t, nil, "GET", server, "/api/v1/fleet/config", token, http.StatusOK) - defer closeFunc() - if tt.shouldHaveHeader { - require.Equal(t, fleet.HeaderLicenseValueExpired, resp.Header.Get(fleet.HeaderLicenseKey)) - } else { - require.Equal(t, "", resp.Header.Get(fleet.HeaderLicenseKey)) - } - }) - } -} diff --git a/server/service/testing_client.go b/server/service/testing_client.go new file mode 100644 index 000000000000..e5c1e3be7b3b --- /dev/null +++ b/server/service/testing_client.go @@ -0,0 +1,143 @@ +package service + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "net/http" + "net/http/httptest" + + "github.com/fleetdm/fleet/v4/server/datastore/mysql" + "github.com/fleetdm/fleet/v4/server/fleet" + "github.com/fleetdm/fleet/v4/server/test" + "github.com/ghodss/yaml" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/stretchr/testify/suite" +) + +type withDS struct { + s *suite.Suite + ds *mysql.Datastore +} + +func (ts *withDS) SetupSuite(dbName string) { + ts.ds = mysql.CreateNamedMySQLDS(ts.s.T(), dbName) + test.AddAllHostsLabel(ts.s.T(), ts.ds) +} + +func (ts *withDS) TearDownSuite() { + ts.ds.Close() +} + +type withServer struct { + withDS + + server *httptest.Server + users map[string]fleet.User + token string +} + +func (ts *withServer) SetupSuite(dbName string) { + ts.withDS.SetupSuite(dbName) + + users, server := RunServerForTestsWithDS(ts.s.T(), ts.ds) + ts.server = server + ts.users = users + ts.token = ts.getTestAdminToken() +} + +func (ts *withServer) TearDownSuite() { + ts.withDS.TearDownSuite() +} + +func (ts *withServer) Do(verb, path string, params interface{}, expectedStatusCode int) *http.Response { + t := ts.s.T() + + j, err := json.Marshal(params) + require.NoError(t, err) + + resp := ts.DoRaw(verb, path, j, expectedStatusCode) + + t.Cleanup(func() { + resp.Body.Close() + }) + return resp +} + +func (ts *withServer) DoRawWithHeaders(verb string, path string, rawBytes []byte, expectedStatusCode int, headers map[string]string) *http.Response { + t := ts.s.T() + + requestBody := io.NopCloser(bytes.NewBuffer(rawBytes)) + req, _ := http.NewRequest(verb, ts.server.URL+path, requestBody) + for key, val := range headers { + req.Header.Add(key, val) + } + client := &http.Client{} + resp, err := client.Do(req) + require.NoError(t, err) + require.Equal(t, expectedStatusCode, resp.StatusCode) + + return resp +} + +func (ts *withServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response { + return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", ts.token), + }) +} + +func (ts *withServer) DoRawNoAuth(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response { + return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, nil) +} + +func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStatusCode int, v interface{}) { + resp := ts.Do(verb, path, params, expectedStatusCode) + err := json.NewDecoder(resp.Body).Decode(v) + require.NoError(ts.s.T(), err) + if e, ok := v.(errorer); ok { + require.NoError(ts.s.T(), e.error()) + } +} + +func (ts *withServer) getTestAdminToken() string { + testUser := testUsers["admin1"] + + params := loginRequest{ + Email: testUser.Email, + Password: testUser.PlaintextPassword, + } + j, err := json.Marshal(¶ms) + require.NoError(ts.s.T(), err) + + requestBody := io.NopCloser(bytes.NewBuffer(j)) + resp, err := http.Post(ts.server.URL+"/api/v1/fleet/login", "application/json", requestBody) + require.NoError(ts.s.T(), err) + defer resp.Body.Close() + assert.Equal(ts.s.T(), http.StatusOK, resp.StatusCode) + + var jsn = struct { + User *fleet.User `json:"user"` + Token string `json:"token"` + Err []map[string]string `json:"errors,omitempty"` + }{} + err = json.NewDecoder(resp.Body).Decode(&jsn) + require.Nil(ts.s.T(), err) + + return jsn.Token +} + +func (ts *withServer) applyConfig(spec []byte) { + var appConfigSpec interface{} + err := yaml.Unmarshal(spec, &appConfigSpec) + require.NoError(ts.s.T(), err) + + ts.Do("PATCH", "/api/v1/fleet/config", appConfigSpec, http.StatusOK) +} + +func (ts *withServer) getConfig() *appConfigResponse { + var responseBody *appConfigResponse + ts.DoJSON("GET", "/api/v1/fleet/config", nil, http.StatusOK, &responseBody) + return responseBody +} diff --git a/server/service/testing_utils.go b/server/service/testing_utils.go index 3b2ab6316a3b..8d57e4f3448e 100644 --- a/server/service/testing_utils.go +++ b/server/service/testing_utils.go @@ -2,6 +2,9 @@ package service import ( "context" + "encoding/json" + "io/ioutil" + "net/http" "net/http/httptest" "os" "strings" @@ -9,6 +12,7 @@ import ( eeservice "github.com/fleetdm/fleet/v4/ee/server/service" "github.com/fleetdm/fleet/v4/server/logging" + "github.com/stretchr/testify/assert" "github.com/WatchBeam/clock" "github.com/fleetdm/fleet/v4/server/config" @@ -163,6 +167,9 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe limitStore, _ := memstore.New(0) r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore) server := httptest.NewServer(r) + t.Cleanup(func() { + server.Close() + }) return users, server } @@ -235,3 +242,21 @@ func testStdoutPluginConfig() config.FleetConfig { c.Osquery.StatusLogPlugin = "stdout" return c } + +func assertBodyContains(t *testing.T, resp *http.Response, expected string) { + bodyBytes, err := ioutil.ReadAll(resp.Body) + require.Nil(t, err) + bodyString := string(bodyBytes) + assert.Contains(t, bodyString, expected) +} + +func getJSON(r *http.Response, target interface{}) error { + return json.NewDecoder(r.Body).Decode(target) +} + +func assertErrorCodeAndMessage(t *testing.T, resp *http.Response, code int, message string) { + err := &fleet.Error{} + require.Nil(t, getJSON(resp, err)) + assert.Equal(t, code, err.Code) + assert.Equal(t, message, err.Message) +}