diff --git a/cmd/fleetctl/apply_test.go b/cmd/fleetctl/apply_test.go index 64cd4bfc2186..84ec2fa99640 100644 --- a/cmd/fleetctl/apply_test.go +++ b/cmd/fleetctl/apply_test.go @@ -41,8 +41,7 @@ var userRoleSpecList = []*fleet.User{ } func TestApplyUserRoles(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { return userRoleSpecList, nil @@ -102,11 +101,10 @@ spec: func TestApplyTeamSpecs(t *testing.T) { license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)} - server, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) - defer server.Close() + _, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) teamsByName := map[string]*fleet.Team{ - "team1": &fleet.Team{ + "team1": { ID: 42, Name: "team1", Description: "team1 description", @@ -185,8 +183,7 @@ func writeTmpYml(t *testing.T, contents string) string { } func TestApplyAppConfig(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { return userRoleSpecList, nil diff --git a/cmd/fleetctl/debug_test.go b/cmd/fleetctl/debug_test.go index f8a30a99b848..d71e72196ef0 100644 --- a/cmd/fleetctl/debug_test.go +++ b/cmd/fleetctl/debug_test.go @@ -44,8 +44,7 @@ oug6edBNpdhp8r2/4t6n3AouK0/zG2naAlmXV0JoFuEvy2bX0BbbbPg+v4WNZIsC func TestDebugConnectionCommand(t *testing.T) { t.Run("without certificate", func(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) { return nil, errors.New("invalid") diff --git a/cmd/fleetctl/get_test.go b/cmd/fleetctl/get_test.go index 8212b268fed2..2e96b18ca6db 100644 --- a/cmd/fleetctl/get_test.go +++ b/cmd/fleetctl/get_test.go @@ -55,8 +55,7 @@ var userRoleList = []*fleet.User{ } func TestGetUserRoles(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) { return userRoleList, nil @@ -114,8 +113,7 @@ func TestGetTeams(t *testing.T) { for _, tt := range testCases { t.Run(tt.name, func(t *testing.T) { license := tt.license - server, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) - defer server.Close() + _, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license}) agentOpts := json.RawMessage(`{"config":{"foo":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override"}}}}`) ds.ListTeamsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Team, error) { @@ -196,8 +194,7 @@ spec: } func TestGetHosts(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) // this func is called when no host is specified i.e. `fleetctl get hosts --json` ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) { @@ -343,8 +340,7 @@ func TestGetHosts(t *testing.T) { } func TestGetConfig(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) { return &fleet.AppConfig{ @@ -412,8 +408,7 @@ spec: } func TestGetSoftawre(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) foo001 := fleet.Software{ Name: "foo", Version: "0.0.1", Source: "chrome_extensions", GenerateCPE: "somecpe", diff --git a/cmd/fleetctl/hosts_test.go b/cmd/fleetctl/hosts_test.go index 7e8f08060589..936f0cae63b5 100644 --- a/cmd/fleetctl/hosts_test.go +++ b/cmd/fleetctl/hosts_test.go @@ -10,8 +10,7 @@ import ( ) func TestHostTransferFlagChecks(t *testing.T) { - server, _ := runServerWithMockedDS(t) - defer server.Close() + runServerWithMockedDS(t) runAppCheckErr(t, []string{"hosts", "transfer", "--team", "team1", "--hosts", "host1", "--label", "AAA"}, @@ -24,8 +23,7 @@ func TestHostTransferFlagChecks(t *testing.T) { } func TestHostsTransferByHosts(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) @@ -48,8 +46,7 @@ func TestHostsTransferByHosts(t *testing.T) { } func TestHostsTransferByLabel(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) @@ -83,8 +80,7 @@ func TestHostsTransferByLabel(t *testing.T) { } func TestHostsTransferByStatus(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) @@ -118,8 +114,7 @@ func TestHostsTransferByStatus(t *testing.T) { } func TestHostsTransferByStatusAndSearchQuery(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) { require.Equal(t, "host1", identifier) diff --git a/cmd/fleetctl/query_test.go b/cmd/fleetctl/query_test.go index 306ecf173401..145d5d0f6e90 100644 --- a/cmd/fleetctl/query_test.go +++ b/cmd/fleetctl/query_test.go @@ -16,8 +16,7 @@ import ( func TestLiveQuery(t *testing.T) { rs := pubsub.NewInmemQueryResults() lq := new(live_query.MockLiveQuery) - server, ds := runServerWithMockedDS(t, service.TestServerOpts{Rs: rs, Lq: lq}) - defer server.Close() + _, ds := runServerWithMockedDS(t, service.TestServerOpts{Rs: rs, Lq: lq}) ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) { return []uint{1234}, nil diff --git a/cmd/fleetctl/users_test.go b/cmd/fleetctl/users_test.go index 663c6514bc10..482a1d4e0fd7 100644 --- a/cmd/fleetctl/users_test.go +++ b/cmd/fleetctl/users_test.go @@ -9,8 +9,7 @@ import ( ) func TestUserDelete(t *testing.T) { - server, ds := runServerWithMockedDS(t) - defer server.Close() + _, ds := runServerWithMockedDS(t) ds.UserByEmailFunc = func(ctx context.Context, email string) (*fleet.User, error) { return &fleet.User{ diff --git a/server/datastore/mysql/testing_utils.go b/server/datastore/mysql/testing_utils.go index 0a56e9b964f0..be4d04a80e5f 100644 --- a/server/datastore/mysql/testing_utils.go +++ b/server/datastore/mysql/testing_utils.go @@ -250,5 +250,5 @@ func CreateNamedMySQLDS(t *testing.T, name string) *Datastore { } t.Parallel() - return initializeDatabase(t, name) + return initializeDatabase(t, name, new(DatastoreTestOptions)) } diff --git a/server/service/http_auth_test.go b/server/service/http_auth_test.go index 1469f4fe304a..eea04c61642c 100644 --- a/server/service/http_auth_test.go +++ b/server/service/http_auth_test.go @@ -21,7 +21,6 @@ import ( func TestLogin(t *testing.T) { ds, users, server := setupAuthTest(t) - defer server.Close() var loginTests = []struct { email string status int @@ -191,7 +190,6 @@ func getTestAdminToken(t *testing.T, server *httptest.Server) string { func TestNoHeaderErrorsDifferently(t *testing.T) { _, _, server := setupAuthTest(t) - defer server.Close() req, _ := http.NewRequest("GET", server.URL+"/api/v1/fleet/users", nil) client := &http.Client{} diff --git a/server/service/integration_core_test.go b/server/service/integration_core_test.go index 6d726d73d107..3922225b700b 100644 --- a/server/service/integration_core_test.go +++ b/server/service/integration_core_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "fmt" "io/ioutil" "net/http" @@ -42,10 +43,8 @@ func (s *integrationTestSuite) TestDoubleUserCreationErrors() { GlobalRole: ptr.String(fleet.RoleObserver), } - respFirst := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusOK) - defer respFirst.Body.Close() + s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusOK) respSecond := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusConflict) - defer respSecond.Body.Close() assertBodyContains(t, respSecond, `Error 1062: Duplicate entry 'email@asd.com'`) } @@ -60,7 +59,6 @@ func (s *integrationTestSuite) TestUserWithoutRoleErrors() { } resp := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusUnprocessableEntity) - defer resp.Body.Close() assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "either global role or team role needs to be defined") } @@ -74,7 +72,6 @@ func (s *integrationTestSuite) TestUserWithWrongRoleErrors() { GlobalRole: ptr.String("wrongrole"), } resp := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusUnprocessableEntity) - defer resp.Body.Close() assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "GlobalRole role can only be admin, observer, or maintainer.") } @@ -97,7 +94,6 @@ func (s *integrationTestSuite) TestUserCreationWrongTeamErrors() { Teams: &teams, } resp := s.Do("POST", "/api/v1/fleet/users/admin", ¶ms, http.StatusUnprocessableEntity) - defer resp.Body.Close() assertBodyContains(t, resp, `Error 1452: Cannot add or update a child row: a foreign key constraint fails`) } @@ -106,15 +102,14 @@ func (s *integrationTestSuite) TestQueryCreationLogsActivity() { admin1 := s.users["admin1@example.com"] admin1.GravatarURL = "http://iii.com" - err := s.ds.SaveUser(&admin1) + err := s.ds.SaveUser(context.Background(), &admin1) require.NoError(t, err) params := fleet.QueryPayload{ Name: ptr.String("user1"), Query: ptr.String("select * from time;"), } - resp := s.Do("POST", "/api/v1/fleet/queries", ¶ms, http.StatusOK) - defer resp.Body.Close() + s.Do("POST", "/api/v1/fleet/queries", ¶ms, http.StatusOK) activities := listActivitiesResponse{} s.DoJSON("GET", "/api/v1/fleet/activities", nil, http.StatusOK, &activities) @@ -165,7 +160,7 @@ func (s *integrationTestSuite) TestAppConfigDefaultValues() { func (s *integrationTestSuite) TestUserRolesSpec() { t := s.T() - _, err := s.ds.NewTeam(&fleet.Team{ + _, err := s.ds.NewTeam(context.Background(), &fleet.Team{ ID: 42, Name: "team1", Description: "desc team1", @@ -180,7 +175,7 @@ func (s *integrationTestSuite) TestUserRolesSpec() { GravatarURL: "http://asd.com", GlobalRole: ptr.String(fleet.RoleObserver), } - user, err := s.ds.NewUser(u) + user, err := s.ds.NewUser(context.Background(), u) require.NoError(t, err) assert.Len(t, user.Teams, 0) @@ -198,10 +193,9 @@ func (s *integrationTestSuite) TestUserRolesSpec() { err = yaml.Unmarshal(spec, &userRoleSpec.Spec) require.NoError(t, err) - resp := s.Do("POST", "/api/v1/fleet/users/roles/spec", &userRoleSpec, http.StatusOK) - defer resp.Body.Close() + s.Do("POST", "/api/v1/fleet/users/roles/spec", &userRoleSpec, http.StatusOK) - user, err = s.ds.UserByEmail(email) + user, err = s.ds.UserByEmail(context.Background(), email) require.NoError(t, err) require.Len(t, user.Teams, 1) assert.Equal(t, fleet.RoleMaintainer, user.Teams[0].Role) @@ -214,7 +208,7 @@ func (s *integrationTestSuite) TestGlobalSchedule() { s.DoJSON("GET", "/api/v1/fleet/global/schedule", nil, http.StatusOK, &gs) require.Len(t, gs.GlobalSchedule, 0) - qr, err := s.ds.NewQuery(&fleet.Query{ + qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{ Name: "TestQuery1", Description: "Some description", Query: "select * from osquery;", @@ -269,7 +263,7 @@ func (s *integrationTestSuite) TestTranslator() { func (s *integrationTestSuite) TestVulnerableSoftware() { t := s.T() - host, err := s.ds.NewHost(&fleet.Host{ + host, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), SeenTime: time.Now(), @@ -290,19 +284,18 @@ func (s *integrationTestSuite) TestVulnerableSoftware() { }, } host.HostSoftware = soft - require.NoError(t, s.ds.SaveHostSoftware(host)) - require.NoError(t, s.ds.LoadHostSoftware(host)) + require.NoError(t, s.ds.SaveHostSoftware(context.Background(), host)) + require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host)) soft1 := host.Software[0] if soft1.Name != "bar" { soft1 = host.Software[1] } - require.NoError(t, s.ds.AddCPEForSoftware(soft1, "somecpe")) - require.NoError(t, s.ds.InsertCVEForCPE("cve-123-123-132", []string{"somecpe"})) + require.NoError(t, s.ds.AddCPEForSoftware(context.Background(), soft1, "somecpe")) + require.NoError(t, s.ds.InsertCVEForCPE(context.Background(), "cve-123-123-132", []string{"somecpe"})) resp := s.Do("GET", fmt.Sprintf("/api/v1/fleet/hosts/%d", host.ID), nil, http.StatusOK) - defer resp.Body.Close() bodyBytes, err := ioutil.ReadAll(resp.Body) require.NoError(t, err) @@ -331,7 +324,7 @@ func (s *integrationTestSuite) TestGlobalPolicies() { t := s.T() for i := 0; i < 3; i++ { - _, err := s.ds.NewHost(&fleet.Host{ + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), SeenTime: time.Now().Add(-time.Duration(i) * time.Minute), @@ -343,7 +336,7 @@ func (s *integrationTestSuite) TestGlobalPolicies() { require.NoError(t, err) } - qr, err := s.ds.NewQuery(&fleet.Query{ + qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{ Name: "TestQuery3", Description: "Some description", Query: "select * from osquery;", @@ -381,8 +374,8 @@ func (s *integrationTestSuite) TestGlobalPolicies() { s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp) require.Len(t, listHostsResp.Hosts, 0) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now())) - require.NoError(t, s.ds.RecordPolicyQueryExecutions(h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now())) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now())) + require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now())) listHostsURL = fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID) listHostsResp = listHostsResponse{} diff --git a/server/service/integration_ds_only_test.go b/server/service/integration_ds_only_test.go index 1ac83bb08267..780a281fb2b4 100644 --- a/server/service/integration_ds_only_test.go +++ b/server/service/integration_ds_only_test.go @@ -32,10 +32,10 @@ func (s *integrationDSTestSuite) TestLicenseExpiration() { expiration time.Time shouldHaveHeader bool }{ - {"basic expired", fleet.TierBasic, time.Now().Add(-24 * time.Hour), true}, - {"basic not expired", fleet.TierBasic, time.Now().Add(24 * time.Hour), false}, - {"core expired", fleet.TierCore, time.Now().Add(-24 * time.Hour), false}, - {"core not expired", fleet.TierCore, time.Now().Add(24 * time.Hour), false}, + {"basic expired", fleet.TierPremium, time.Now().Add(-24 * time.Hour), true}, + {"basic not expired", fleet.TierPremium, time.Now().Add(24 * time.Hour), false}, + {"core expired", fleet.TierFree, time.Now().Add(-24 * time.Hour), false}, + {"core not expired", fleet.TierFree, time.Now().Add(24 * time.Hour), false}, } createTestUsers(s.T(), s.ds) @@ -45,14 +45,12 @@ func (s *integrationDSTestSuite) TestLicenseExpiration() { license := &fleet.LicenseInfo{Tier: tt.tier, Expiration: tt.expiration} _, server := RunServerForTestsWithDS(t, s.ds, TestServerOpts{License: license, SkipCreateTestUsers: true}) - defer server.Close() ts := withServer{server: server} ts.s = &s.Suite ts.token = ts.getTestAdminToken() resp := ts.Do("GET", "/api/v1/fleet/config", nil, http.StatusOK) - defer resp.Body.Close() if tt.shouldHaveHeader { require.Equal(t, fleet.HeaderLicenseValueExpired, resp.Header.Get(fleet.HeaderLicenseKey)) } else { diff --git a/server/service/integration_enterprise_test.go b/server/service/integration_enterprise_test.go index eb231e62e219..53af9f01b67b 100644 --- a/server/service/integration_enterprise_test.go +++ b/server/service/integration_enterprise_test.go @@ -1,6 +1,7 @@ package service import ( + "context" "encoding/json" "fmt" "net/http" @@ -28,7 +29,7 @@ func (s *integrationEnterpriseTestSuite) SetupSuite() { s.withDS.SetupSuite("integrationEnterpriseTestSuite") users, server := RunServerForTestsWithDS( - s.T(), s.ds, TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierBasic}}) + s.T(), s.ds, TestServerOpts{License: &fleet.LicenseInfo{Tier: fleet.TierPremium}}) s.server = server s.users = users s.token = s.getTestAdminToken() @@ -44,38 +45,35 @@ func (s *integrationEnterpriseTestSuite) TestTeamSpecs() { Description: "desc team1", } - resp := s.Do("POST", "/api/v1/fleet/teams", team, http.StatusOK) - defer resp.Body.Close() + s.Do("POST", "/api/v1/fleet/teams", team, http.StatusOK) // updates a team agentOpts := json.RawMessage(`{"config": {"foo": "bar"}, "overrides": {"platforms": {"darwin": {"foo": "override"}}}}`) teamSpecs := applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: teamName, AgentOptions: &agentOpts}}} - respUpdateTeam := s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) - defer respUpdateTeam.Body.Close() + s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) - team, err := s.ds.TeamByName(teamName) + team, err := s.ds.TeamByName(context.Background(), teamName) require.NoError(t, err) assert.Len(t, team.Secrets, 0) require.JSONEq(t, string(agentOpts), string(*team.AgentOptions)) // creates a team with default agent options - user, err := s.ds.UserByEmail("admin1@example.com") + user, err := s.ds.UserByEmail(context.Background(), "admin1@example.com") require.NoError(t, err) - teams, err := s.ds.ListTeams(fleet.TeamFilter{User: user}, fleet.ListOptions{}) + teams, err := s.ds.ListTeams(context.Background(), fleet.TeamFilter{User: user}, fleet.ListOptions{}) require.NoError(t, err) require.True(t, len(teams) >= 1) teamSpecs = applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team2"}}} - respUpdateTeam2 := s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) - defer respUpdateTeam2.Body.Close() + s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) - teams, err = s.ds.ListTeams(fleet.TeamFilter{User: user}, fleet.ListOptions{}) + teams, err = s.ds.ListTeams(context.Background(), fleet.TeamFilter{User: user}, fleet.ListOptions{}) require.NoError(t, err) assert.True(t, len(teams) >= 2) - team, err = s.ds.TeamByName("team2") + team, err = s.ds.TeamByName(context.Background(), "team2") require.NoError(t, err) defaultOpts := `{"config": {"options": {"logger_plugin": "tls", "pack_delimiter": "/", "logger_tls_period": 10, "distributed_plugin": "tls", "disable_distributed": false, "logger_tls_endpoint": "/api/v1/osquery/log", "distributed_interval": 10, "distributed_tls_max_attempts": 3}, "decorators": {"load": ["SELECT uuid AS host_uuid FROM system_info;", "SELECT hostname AS hostname FROM system_info;"]}}, "overrides": {}}` @@ -85,10 +83,9 @@ func (s *integrationEnterpriseTestSuite) TestTeamSpecs() { // updates secrets teamSpecs = applyTeamSpecsRequest{Specs: []*fleet.TeamSpec{{Name: "team2", Secrets: []fleet.EnrollSecret{{Secret: "ABC"}}}}} - respUpdateSecrets := s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) - defer respUpdateSecrets.Body.Close() + s.Do("POST", "/api/v1/fleet/spec/teams", teamSpecs, http.StatusOK) - team, err = s.ds.TeamByName("team2") + team, err = s.ds.TeamByName(context.Background(), "team2") require.NoError(t, err) require.Len(t, team.Secrets, 1) @@ -98,7 +95,7 @@ func (s *integrationEnterpriseTestSuite) TestTeamSpecs() { func (s *integrationEnterpriseTestSuite) TestTeamSchedule() { t := s.T() - team1, err := s.ds.NewTeam(&fleet.Team{ + team1, err := s.ds.NewTeam(context.Background(), &fleet.Team{ ID: 42, Name: "team1", Description: "desc team1", @@ -109,7 +106,7 @@ func (s *integrationEnterpriseTestSuite) TestTeamSchedule() { s.DoJSON("GET", fmt.Sprintf("/api/v1/fleet/team/%d/schedule", team1.ID), nil, http.StatusOK, &ts) require.Len(t, ts.Scheduled, 0) - qr, err := s.ds.NewQuery(&fleet.Query{Name: "TestQuery2", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true}) + qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{Name: "TestQuery2", Description: "Some description", Query: "select * from osquery;", ObserverCanRun: true}) require.NoError(t, err) gsParams := teamScheduleQueryRequest{ScheduledQueryPayload: fleet.ScheduledQueryPayload{QueryID: &qr.ID, Interval: ptr.Uint(42)}} diff --git a/server/service/integration_logger_test.go b/server/service/integration_logger_test.go index 85d4b2900221..0816e6e8e8f9 100644 --- a/server/service/integration_logger_test.go +++ b/server/service/integration_logger_test.go @@ -2,6 +2,7 @@ package service import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -98,7 +99,7 @@ func (s *integrationLoggerTestSuite) TestLogger() { func (s *integrationLoggerTestSuite) TestOsqueryEndpointsLogErrors() { t := s.T() - _, err := s.ds.NewHost(&fleet.Host{ + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), SeenTime: time.Now(), @@ -125,7 +126,7 @@ func (s *integrationLoggerTestSuite) TestOsqueryEndpointsLogErrors() { func (s *integrationLoggerTestSuite) TestSubmitStatusLog() { t := s.T() - _, err := s.ds.NewHost(&fleet.Host{ + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), SeenTime: time.Now(), @@ -153,7 +154,7 @@ func (s *integrationLoggerTestSuite) TestSubmitStatusLog() { func (s *integrationLoggerTestSuite) TestEnrollAgentLogsErrors() { t := s.T() - _, err := s.ds.NewHost(&fleet.Host{ + _, err := s.ds.NewHost(context.Background(), &fleet.Host{ DetailUpdatedAt: time.Now(), LabelUpdatedAt: time.Now(), SeenTime: time.Now(), @@ -172,12 +173,7 @@ func (s *integrationLoggerTestSuite) TestEnrollAgentLogsErrors() { }) require.NoError(t, err) - requestBody := io.NopCloser(bytes.NewBuffer(j)) - req, _ := http.NewRequest("POST", s.server.URL+"/api/v1/osquery/enroll", requestBody) - client := &http.Client{} - resp, err := client.Do(req) - require.NoError(t, err) - require.NoError(t, resp.Body.Close()) + s.DoRawNoAuth("POST", "/api/v1/osquery/enroll", j, http.StatusUnauthorized) parts := strings.Split(strings.TrimSpace(s.buf.String()), "\n") require.Len(t, parts, 1) diff --git a/server/service/testing_client.go b/server/service/testing_client.go index 8262114e2b20..e5c1e3be7b3b 100644 --- a/server/service/testing_client.go +++ b/server/service/testing_client.go @@ -50,8 +50,6 @@ func (ts *withServer) SetupSuite(dbName string) { func (ts *withServer) TearDownSuite() { ts.withDS.TearDownSuite() - - ts.server.Close() } func (ts *withServer) Do(verb, path string, params interface{}, expectedStatusCode int) *http.Response { @@ -62,25 +60,40 @@ func (ts *withServer) Do(verb, path string, params interface{}, expectedStatusCo resp := ts.DoRaw(verb, path, j, expectedStatusCode) + t.Cleanup(func() { + resp.Body.Close() + }) return resp } -func (ts *withServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response { +func (ts *withServer) DoRawWithHeaders(verb string, path string, rawBytes []byte, expectedStatusCode int, headers map[string]string) *http.Response { t := ts.s.T() requestBody := io.NopCloser(bytes.NewBuffer(rawBytes)) req, _ := http.NewRequest(verb, ts.server.URL+path, requestBody) - req.Header.Add("Authorization", fmt.Sprintf("Bearer %s", ts.token)) + for key, val := range headers { + req.Header.Add(key, val) + } client := &http.Client{} resp, err := client.Do(req) require.NoError(t, err) require.Equal(t, expectedStatusCode, resp.StatusCode) + return resp } +func (ts *withServer) DoRaw(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response { + return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, map[string]string{ + "Authorization": fmt.Sprintf("Bearer %s", ts.token), + }) +} + +func (ts *withServer) DoRawNoAuth(verb string, path string, rawBytes []byte, expectedStatusCode int) *http.Response { + return ts.DoRawWithHeaders(verb, path, rawBytes, expectedStatusCode, nil) +} + func (ts *withServer) DoJSON(verb, path string, params interface{}, expectedStatusCode int, v interface{}) { resp := ts.Do(verb, path, params, expectedStatusCode) - defer resp.Body.Close() err := json.NewDecoder(resp.Body).Decode(v) require.NoError(ts.s.T(), err) if e, ok := v.(errorer); ok { @@ -120,8 +133,7 @@ func (ts *withServer) applyConfig(spec []byte) { err := yaml.Unmarshal(spec, &appConfigSpec) require.NoError(ts.s.T(), err) - resp := ts.Do("PATCH", "/api/v1/fleet/config", appConfigSpec, http.StatusOK) - resp.Body.Close() + ts.Do("PATCH", "/api/v1/fleet/config", appConfigSpec, http.StatusOK) } func (ts *withServer) getConfig() *appConfigResponse { diff --git a/server/service/testing_utils.go b/server/service/testing_utils.go index 40d83fa69ec1..8d57e4f3448e 100644 --- a/server/service/testing_utils.go +++ b/server/service/testing_utils.go @@ -1,10 +1,10 @@ package service import ( + "context" "encoding/json" "io/ioutil" "net/http" - "context" "net/http/httptest" "os" "strings" @@ -167,6 +167,9 @@ func RunServerForTestsWithDS(t *testing.T, ds fleet.Datastore, opts ...TestServe limitStore, _ := memstore.New(0) r := MakeHandler(svc, config.FleetConfig{}, logger, limitStore) server := httptest.NewServer(r) + t.Cleanup(func() { + server.Close() + }) return users, server } @@ -240,11 +243,11 @@ func testStdoutPluginConfig() config.FleetConfig { return c } -func assertBodyContains(t *testing.T, resp *http.Response, expectedError string) { +func assertBodyContains(t *testing.T, resp *http.Response, expected string) { bodyBytes, err := ioutil.ReadAll(resp.Body) require.Nil(t, err) bodyString := string(bodyBytes) - assert.Contains(t, bodyString, expectedError) + assert.Contains(t, bodyString, expected) } func getJSON(r *http.Response, target interface{}) error {