Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
chiiph committed Sep 15, 2021
1 parent d3fcc05 commit b5297f5
Show file tree
Hide file tree
Showing 14 changed files with 84 additions and 103 deletions.
11 changes: 4 additions & 7 deletions cmd/fleetctl/apply_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,7 @@ var userRoleSpecList = []*fleet.User{
}

func TestApplyUserRoles(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
return userRoleSpecList, nil
Expand Down Expand Up @@ -102,11 +101,10 @@ spec:

func TestApplyTeamSpecs(t *testing.T) {
license := &fleet.LicenseInfo{Tier: fleet.TierPremium, Expiration: time.Now().Add(24 * time.Hour)}
server, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license})
defer server.Close()
_, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license})

teamsByName := map[string]*fleet.Team{
"team1": &fleet.Team{
"team1": {
ID: 42,
Name: "team1",
Description: "team1 description",
Expand Down Expand Up @@ -185,8 +183,7 @@ func writeTmpYml(t *testing.T, contents string) string {
}

func TestApplyAppConfig(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
return userRoleSpecList, nil
Expand Down
3 changes: 1 addition & 2 deletions cmd/fleetctl/debug_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ oug6edBNpdhp8r2/4t6n3AouK0/zG2naAlmXV0JoFuEvy2bX0BbbbPg+v4WNZIsC

func TestDebugConnectionCommand(t *testing.T) {
t.Run("without certificate", func(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.VerifyEnrollSecretFunc = func(ctx context.Context, secret string) (*fleet.EnrollSecret, error) {
return nil, errors.New("invalid")
Expand Down
15 changes: 5 additions & 10 deletions cmd/fleetctl/get_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,7 @@ var userRoleList = []*fleet.User{
}

func TestGetUserRoles(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.ListUsersFunc = func(ctx context.Context, opt fleet.UserListOptions) ([]*fleet.User, error) {
return userRoleList, nil
Expand Down Expand Up @@ -114,8 +113,7 @@ func TestGetTeams(t *testing.T) {
for _, tt := range testCases {
t.Run(tt.name, func(t *testing.T) {
license := tt.license
server, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license})
defer server.Close()
_, ds := runServerWithMockedDS(t, service.TestServerOpts{License: license})

agentOpts := json.RawMessage(`{"config":{"foo":"bar"},"overrides":{"platforms":{"darwin":{"foo":"override"}}}}`)
ds.ListTeamsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.ListOptions) ([]*fleet.Team, error) {
Expand Down Expand Up @@ -196,8 +194,7 @@ spec:
}

func TestGetHosts(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

// this func is called when no host is specified i.e. `fleetctl get hosts --json`
ds.ListHostsFunc = func(ctx context.Context, filter fleet.TeamFilter, opt fleet.HostListOptions) ([]*fleet.Host, error) {
Expand Down Expand Up @@ -343,8 +340,7 @@ func TestGetHosts(t *testing.T) {
}

func TestGetConfig(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.AppConfigFunc = func(ctx context.Context) (*fleet.AppConfig, error) {
return &fleet.AppConfig{
Expand Down Expand Up @@ -412,8 +408,7 @@ spec:
}

func TestGetSoftawre(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

foo001 := fleet.Software{
Name: "foo", Version: "0.0.1", Source: "chrome_extensions", GenerateCPE: "somecpe",
Expand Down
15 changes: 5 additions & 10 deletions cmd/fleetctl/hosts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ import (
)

func TestHostTransferFlagChecks(t *testing.T) {
server, _ := runServerWithMockedDS(t)
defer server.Close()
runServerWithMockedDS(t)

runAppCheckErr(t,
[]string{"hosts", "transfer", "--team", "team1", "--hosts", "host1", "--label", "AAA"},
Expand All @@ -24,8 +23,7 @@ func TestHostTransferFlagChecks(t *testing.T) {
}

func TestHostsTransferByHosts(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
require.Equal(t, "host1", identifier)
Expand All @@ -48,8 +46,7 @@ func TestHostsTransferByHosts(t *testing.T) {
}

func TestHostsTransferByLabel(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
require.Equal(t, "host1", identifier)
Expand Down Expand Up @@ -83,8 +80,7 @@ func TestHostsTransferByLabel(t *testing.T) {
}

func TestHostsTransferByStatus(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
require.Equal(t, "host1", identifier)
Expand Down Expand Up @@ -118,8 +114,7 @@ func TestHostsTransferByStatus(t *testing.T) {
}

func TestHostsTransferByStatusAndSearchQuery(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
require.Equal(t, "host1", identifier)
Expand Down
3 changes: 1 addition & 2 deletions cmd/fleetctl/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ import (
func TestLiveQuery(t *testing.T) {
rs := pubsub.NewInmemQueryResults()
lq := new(live_query.MockLiveQuery)
server, ds := runServerWithMockedDS(t, service.TestServerOpts{Rs: rs, Lq: lq})
defer server.Close()
_, ds := runServerWithMockedDS(t, service.TestServerOpts{Rs: rs, Lq: lq})

ds.HostIDsByNameFunc = func(ctx context.Context, filter fleet.TeamFilter, hostnames []string) ([]uint, error) {
return []uint{1234}, nil
Expand Down
3 changes: 1 addition & 2 deletions cmd/fleetctl/users_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@ import (
)

func TestUserDelete(t *testing.T) {
server, ds := runServerWithMockedDS(t)
defer server.Close()
_, ds := runServerWithMockedDS(t)

ds.UserByEmailFunc = func(ctx context.Context, email string) (*fleet.User, error) {
return &fleet.User{
Expand Down
2 changes: 1 addition & 1 deletion server/datastore/mysql/testing_utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,5 @@ func CreateNamedMySQLDS(t *testing.T, name string) *Datastore {
}

t.Parallel()
return initializeDatabase(t, name)
return initializeDatabase(t, name, new(DatastoreTestOptions))
}
2 changes: 0 additions & 2 deletions server/service/http_auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (

func TestLogin(t *testing.T) {
ds, users, server := setupAuthTest(t)
defer server.Close()
var loginTests = []struct {
email string
status int
Expand Down Expand Up @@ -191,7 +190,6 @@ func getTestAdminToken(t *testing.T, server *httptest.Server) string {

func TestNoHeaderErrorsDifferently(t *testing.T) {
_, _, server := setupAuthTest(t)
defer server.Close()

req, _ := http.NewRequest("GET", server.URL+"/api/v1/fleet/users", nil)
client := &http.Client{}
Expand Down
43 changes: 18 additions & 25 deletions server/service/integration_core_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package service

import (
"context"
"fmt"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -42,10 +43,8 @@ func (s *integrationTestSuite) TestDoubleUserCreationErrors() {
GlobalRole: ptr.String(fleet.RoleObserver),
}

respFirst := s.Do("POST", "/api/v1/fleet/users/admin", &params, http.StatusOK)
defer respFirst.Body.Close()
s.Do("POST", "/api/v1/fleet/users/admin", &params, http.StatusOK)
respSecond := s.Do("POST", "/api/v1/fleet/users/admin", &params, http.StatusConflict)
defer respSecond.Body.Close()

assertBodyContains(t, respSecond, `Error 1062: Duplicate entry 'email@asd.com'`)
}
Expand All @@ -60,7 +59,6 @@ func (s *integrationTestSuite) TestUserWithoutRoleErrors() {
}

resp := s.Do("POST", "/api/v1/fleet/users/admin", &params, http.StatusUnprocessableEntity)
defer resp.Body.Close()
assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "either global role or team role needs to be defined")
}

Expand All @@ -74,7 +72,6 @@ func (s *integrationTestSuite) TestUserWithWrongRoleErrors() {
GlobalRole: ptr.String("wrongrole"),
}
resp := s.Do("POST", "/api/v1/fleet/users/admin", &params, http.StatusUnprocessableEntity)
defer resp.Body.Close()
assertErrorCodeAndMessage(t, resp, fleet.ErrNoRoleNeeded, "GlobalRole role can only be admin, observer, or maintainer.")
}

Expand All @@ -97,7 +94,6 @@ func (s *integrationTestSuite) TestUserCreationWrongTeamErrors() {
Teams: &teams,
}
resp := s.Do("POST", "/api/v1/fleet/users/admin", &params, http.StatusUnprocessableEntity)
defer resp.Body.Close()
assertBodyContains(t, resp, `Error 1452: Cannot add or update a child row: a foreign key constraint fails`)
}

Expand All @@ -106,15 +102,14 @@ func (s *integrationTestSuite) TestQueryCreationLogsActivity() {

admin1 := s.users["admin1@example.com"]
admin1.GravatarURL = "http://iii.com"
err := s.ds.SaveUser(&admin1)
err := s.ds.SaveUser(context.Background(), &admin1)
require.NoError(t, err)

params := fleet.QueryPayload{
Name: ptr.String("user1"),
Query: ptr.String("select * from time;"),
}
resp := s.Do("POST", "/api/v1/fleet/queries", &params, http.StatusOK)
defer resp.Body.Close()
s.Do("POST", "/api/v1/fleet/queries", &params, http.StatusOK)

activities := listActivitiesResponse{}
s.DoJSON("GET", "/api/v1/fleet/activities", nil, http.StatusOK, &activities)
Expand Down Expand Up @@ -165,7 +160,7 @@ func (s *integrationTestSuite) TestAppConfigDefaultValues() {
func (s *integrationTestSuite) TestUserRolesSpec() {
t := s.T()

_, err := s.ds.NewTeam(&fleet.Team{
_, err := s.ds.NewTeam(context.Background(), &fleet.Team{
ID: 42,
Name: "team1",
Description: "desc team1",
Expand All @@ -180,7 +175,7 @@ func (s *integrationTestSuite) TestUserRolesSpec() {
GravatarURL: "http://asd.com",
GlobalRole: ptr.String(fleet.RoleObserver),
}
user, err := s.ds.NewUser(u)
user, err := s.ds.NewUser(context.Background(), u)
require.NoError(t, err)
assert.Len(t, user.Teams, 0)

Expand All @@ -198,10 +193,9 @@ func (s *integrationTestSuite) TestUserRolesSpec() {
err = yaml.Unmarshal(spec, &userRoleSpec.Spec)
require.NoError(t, err)

resp := s.Do("POST", "/api/v1/fleet/users/roles/spec", &userRoleSpec, http.StatusOK)
defer resp.Body.Close()
s.Do("POST", "/api/v1/fleet/users/roles/spec", &userRoleSpec, http.StatusOK)

user, err = s.ds.UserByEmail(email)
user, err = s.ds.UserByEmail(context.Background(), email)
require.NoError(t, err)
require.Len(t, user.Teams, 1)
assert.Equal(t, fleet.RoleMaintainer, user.Teams[0].Role)
Expand All @@ -214,7 +208,7 @@ func (s *integrationTestSuite) TestGlobalSchedule() {
s.DoJSON("GET", "/api/v1/fleet/global/schedule", nil, http.StatusOK, &gs)
require.Len(t, gs.GlobalSchedule, 0)

qr, err := s.ds.NewQuery(&fleet.Query{
qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{
Name: "TestQuery1",
Description: "Some description",
Query: "select * from osquery;",
Expand Down Expand Up @@ -269,7 +263,7 @@ func (s *integrationTestSuite) TestTranslator() {
func (s *integrationTestSuite) TestVulnerableSoftware() {
t := s.T()

host, err := s.ds.NewHost(&fleet.Host{
host, err := s.ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now(),
Expand All @@ -290,19 +284,18 @@ func (s *integrationTestSuite) TestVulnerableSoftware() {
},
}
host.HostSoftware = soft
require.NoError(t, s.ds.SaveHostSoftware(host))
require.NoError(t, s.ds.LoadHostSoftware(host))
require.NoError(t, s.ds.SaveHostSoftware(context.Background(), host))
require.NoError(t, s.ds.LoadHostSoftware(context.Background(), host))

soft1 := host.Software[0]
if soft1.Name != "bar" {
soft1 = host.Software[1]
}

require.NoError(t, s.ds.AddCPEForSoftware(soft1, "somecpe"))
require.NoError(t, s.ds.InsertCVEForCPE("cve-123-123-132", []string{"somecpe"}))
require.NoError(t, s.ds.AddCPEForSoftware(context.Background(), soft1, "somecpe"))
require.NoError(t, s.ds.InsertCVEForCPE(context.Background(), "cve-123-123-132", []string{"somecpe"}))

resp := s.Do("GET", fmt.Sprintf("/api/v1/fleet/hosts/%d", host.ID), nil, http.StatusOK)
defer resp.Body.Close()
bodyBytes, err := ioutil.ReadAll(resp.Body)
require.NoError(t, err)

Expand Down Expand Up @@ -331,7 +324,7 @@ func (s *integrationTestSuite) TestGlobalPolicies() {
t := s.T()

for i := 0; i < 3; i++ {
_, err := s.ds.NewHost(&fleet.Host{
_, err := s.ds.NewHost(context.Background(), &fleet.Host{
DetailUpdatedAt: time.Now(),
LabelUpdatedAt: time.Now(),
SeenTime: time.Now().Add(-time.Duration(i) * time.Minute),
Expand All @@ -343,7 +336,7 @@ func (s *integrationTestSuite) TestGlobalPolicies() {
require.NoError(t, err)
}

qr, err := s.ds.NewQuery(&fleet.Query{
qr, err := s.ds.NewQuery(context.Background(), &fleet.Query{
Name: "TestQuery3",
Description: "Some description",
Query: "select * from osquery;",
Expand Down Expand Up @@ -381,8 +374,8 @@ func (s *integrationTestSuite) TestGlobalPolicies() {
s.DoJSON("GET", listHostsURL, nil, http.StatusOK, &listHostsResp)
require.Len(t, listHostsResp.Hosts, 0)

require.NoError(t, s.ds.RecordPolicyQueryExecutions(h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now()))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now()))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h1.Host, map[uint]*bool{policiesResponse.Policies[0].ID: ptr.Bool(true)}, time.Now()))
require.NoError(t, s.ds.RecordPolicyQueryExecutions(context.Background(), h2.Host, map[uint]*bool{policiesResponse.Policies[0].ID: nil}, time.Now()))

listHostsURL = fmt.Sprintf("/api/v1/fleet/hosts?policy_id=%d&policy_response=passing", policiesResponse.Policies[0].ID)
listHostsResp = listHostsResponse{}
Expand Down
10 changes: 4 additions & 6 deletions server/service/integration_ds_only_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ func (s *integrationDSTestSuite) TestLicenseExpiration() {
expiration time.Time
shouldHaveHeader bool
}{
{"basic expired", fleet.TierBasic, time.Now().Add(-24 * time.Hour), true},
{"basic not expired", fleet.TierBasic, time.Now().Add(24 * time.Hour), false},
{"core expired", fleet.TierCore, time.Now().Add(-24 * time.Hour), false},
{"core not expired", fleet.TierCore, time.Now().Add(24 * time.Hour), false},
{"basic expired", fleet.TierPremium, time.Now().Add(-24 * time.Hour), true},
{"basic not expired", fleet.TierPremium, time.Now().Add(24 * time.Hour), false},
{"core expired", fleet.TierFree, time.Now().Add(-24 * time.Hour), false},
{"core not expired", fleet.TierFree, time.Now().Add(24 * time.Hour), false},
}

createTestUsers(s.T(), s.ds)
Expand All @@ -45,14 +45,12 @@ func (s *integrationDSTestSuite) TestLicenseExpiration() {

license := &fleet.LicenseInfo{Tier: tt.tier, Expiration: tt.expiration}
_, server := RunServerForTestsWithDS(t, s.ds, TestServerOpts{License: license, SkipCreateTestUsers: true})
defer server.Close()

ts := withServer{server: server}
ts.s = &s.Suite
ts.token = ts.getTestAdminToken()

resp := ts.Do("GET", "/api/v1/fleet/config", nil, http.StatusOK)
defer resp.Body.Close()
if tt.shouldHaveHeader {
require.Equal(t, fleet.HeaderLicenseValueExpired, resp.Header.Get(fleet.HeaderLicenseKey))
} else {
Expand Down
Loading

0 comments on commit b5297f5

Please sign in to comment.