Skip to content

Commit

Permalink
apiserver: set issues to nil when kolide is disabled
Browse files Browse the repository at this point in the history
  • Loading branch information
sechmann committed Jul 15, 2024
1 parent ccd4642 commit b083c8d
Show file tree
Hide file tree
Showing 7 changed files with 42 additions and 36 deletions.
2 changes: 1 addition & 1 deletion cmd/apiserver/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func run(log *logrus.Entry, cfg config.Config) error {

v4Allocator := ip.NewV4Allocator(wireguardPrefix, []string{cfg.WireGuardIPv4Prefix.Addr().String()})
v6Allocator := ip.NewV6Allocator(cfg.WireGuardIPv6Prefix)
db, err := database.New(cfg.DBPath, v4Allocator, v6Allocator, !cfg.KolideEventHandlerEnabled, log.WithField("component", "database"))
db, err := database.New(cfg.DBPath, v4Allocator, v6Allocator, cfg.KolideEventHandlerEnabled, log.WithField("component", "database"))
if err != nil {
return fmt.Errorf("initialize database: %w", err)
}
Expand Down
2 changes: 1 addition & 1 deletion hack/local-device.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func main() {
}
v4Allocator := ip.NewV4Allocator(wireguardPrefix, []string{cfg.WireGuardIPv4Prefix.Addr().String()})
v6Allocator := ip.NewV6Allocator(cfg.WireGuardIPv6Prefix)
db, err := database.New(cfg.DBPath, v4Allocator, v6Allocator, !cfg.KolideEventHandlerEnabled, logrus.New())
db, err := database.New(cfg.DBPath, v4Allocator, v6Allocator, cfg.KolideEventHandlerEnabled, logrus.New())
if err != nil {
panic(fmt.Sprint("initialize database:", err))
}
Expand Down
10 changes: 5 additions & 5 deletions internal/apiserver/auth/sessionstore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

func TestSessionStore_SetAndGetFromCache(t *testing.T) {
ctx := context.Background()
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, false)
store := auth.NewSessionStore(db)

session := &pb.Session{
Expand Down Expand Up @@ -62,7 +62,7 @@ func TestSessionStore_SetAndGetFromCache(t *testing.T) {

func TestSessionStore_Errors(t *testing.T) {
ctx := context.Background()
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, false)
store := auth.NewSessionStore(db)

session := &pb.Session{
Expand All @@ -81,7 +81,7 @@ func TestSessionStore_Errors(t *testing.T) {

func TestSessionStore_Warmup(t *testing.T) {
ctx := context.Background()
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, false)
store := auth.NewSessionStore(db)

for i := range 20 {
Expand Down Expand Up @@ -121,7 +121,7 @@ func TestSessionStore_Warmup(t *testing.T) {

func TestSessionStore_UpdateDevice(t *testing.T) {
ctx := context.Background()
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, false)
store := auth.NewSessionStore(db)

now := time.Now()
Expand Down Expand Up @@ -168,7 +168,7 @@ func TestSessionStore_UpdateDevice(t *testing.T) {
// Test that existing sessions with the same device id are removed.
func TestSessionStore_ReplaceOnSet(t *testing.T) {
ctx := context.Background()
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, false)
store := auth.NewSessionStore(db)

now := time.Now()
Expand Down
54 changes: 30 additions & 24 deletions internal/apiserver/database/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,28 @@ import (
)

type database struct {
queries Querier
ipv4Allocator ip.Allocator
ipv6Allocator ip.Allocator
defaultDeviceHealth bool
log logrus.FieldLogger
queries Querier
ipv4Allocator ip.Allocator
ipv6Allocator ip.Allocator
kolideEnabled bool
log logrus.FieldLogger
}

var mux sync.Mutex

func New(dbPath string, v4Allocator ip.Allocator, v6Allocator ip.Allocator, defaultDeviceHealth bool, log logrus.FieldLogger) (*database, error) {
func New(dbPath string, v4Allocator ip.Allocator, v6Allocator ip.Allocator, kolideEnabled bool, log logrus.FieldLogger) (*database, error) {
connectionString := "file:" + dbPath + "?_foreign_keys=1&_cache_size=-100000&_busy_timeout=5000&_journal_mode=WAL"
db, err := sql.Open("sqlite3", connectionString)
if err != nil {
return nil, fmt.Errorf("open database: %w", err)
}

apiServerDB := database{
queries: NewQuerier(db),
ipv4Allocator: v4Allocator,
ipv6Allocator: v6Allocator,
defaultDeviceHealth: defaultDeviceHealth,
log: log,
queries: NewQuerier(db),
ipv4Allocator: v4Allocator,
ipv6Allocator: v6Allocator,
kolideEnabled: kolideEnabled,
log: log,
}

if err = runMigrations(dbPath); err != nil {
Expand All @@ -60,7 +60,7 @@ func (db *database) ReadDevices(ctx context.Context) ([]*pb.Device, error) {

devices := make([]*pb.Device, 0)
for _, row := range rows {
device, err := sqlcDeviceToPbDevice(*row)
device, err := db.sqlcDeviceToPbDevice(*row)
if err != nil {
return nil, fmt.Errorf("converting device %v: %w", row.ID, err)
}
Expand Down Expand Up @@ -362,13 +362,15 @@ func (db *database) AddDevice(ctx context.Context, device *pb.Device) error {
return fmt.Errorf("finding available ip: %w", err)
}

initialHealthy := !db.kolideEnabled

err = db.queries.AddDevice(ctx, sqlc.AddDeviceParams{
Serial: device.Serial,
Username: device.Username,
PublicKey: device.PublicKey,
Ipv4: availableIpV4,
Ipv6: availableIpV6,
Healthy: db.defaultDeviceHealth,
Healthy: initialHealthy,
Platform: device.Platform,
})
if err != nil {
Expand All @@ -384,7 +386,7 @@ func (db *database) ReadDevice(ctx context.Context, publicKey string) (*pb.Devic
return nil, err
}

return sqlcDeviceToPbDevice(*device)
return db.sqlcDeviceToPbDevice(*device)
}

func (db *database) ReadDeviceById(ctx context.Context, deviceID int64) (*pb.Device, error) {
Expand All @@ -393,7 +395,7 @@ func (db *database) ReadDeviceById(ctx context.Context, deviceID int64) (*pb.Dev
return nil, err
}

return sqlcDeviceToPbDevice(*device)
return db.sqlcDeviceToPbDevice(*device)
}

func (db *database) ReadDeviceByExternalID(ctx context.Context, externalID string) (*pb.Device, error) {
Expand All @@ -406,7 +408,7 @@ func (db *database) ReadDeviceByExternalID(ctx context.Context, externalID strin
return nil, err
}

return sqlcDeviceToPbDevice(*device)
return db.sqlcDeviceToPbDevice(*device)
}

func (db *database) ReadGateways(ctx context.Context) ([]*pb.Gateway, error) {
Expand Down Expand Up @@ -485,7 +487,7 @@ func (db *database) ReadDeviceBySerialPlatform(ctx context.Context, serial, plat
return nil, err
}

return sqlcDeviceToPbDevice(*device)
return db.sqlcDeviceToPbDevice(*device)
}

func (db *database) AddSessionInfo(ctx context.Context, si *pb.Session) error {
Expand Down Expand Up @@ -530,7 +532,7 @@ func (db *database) ReadSessionInfo(ctx context.Context, key string) (*pb.Sessio
return nil, err
}

return sqlcSessionAndDeviceToPbSession(row.Session, row.Device, groupIDs)
return db.sqlcSessionAndDeviceToPbSession(row.Session, row.Device, groupIDs)
}

func (db *database) ReadSessionInfos(ctx context.Context) ([]*pb.Session, error) {
Expand All @@ -546,7 +548,7 @@ func (db *database) ReadSessionInfos(ctx context.Context) ([]*pb.Session, error)
return nil, err
}

session, err := sqlcSessionAndDeviceToPbSession(row.Session, row.Device, groupIDs)
session, err := db.sqlcSessionAndDeviceToPbSession(row.Session, row.Device, groupIDs)
if err != nil {
return nil, err
}
Expand All @@ -568,7 +570,7 @@ func (db *database) ReadMostRecentSessionInfo(ctx context.Context, deviceID int6
return nil, err
}

return sqlcSessionAndDeviceToPbSession(row.Session, row.Device, groupIDs)
return db.sqlcSessionAndDeviceToPbSession(row.Session, row.Device, groupIDs)
}

func (db *database) getNextAvailableIPv4(ctx context.Context) (string, error) {
Expand Down Expand Up @@ -604,7 +606,7 @@ func (db *database) RemoveExpiredSessions(ctx context.Context) error {
return db.queries.RemoveExpiredSessions(ctx)
}

func sqlcDeviceToPbDevice(sqlcDevice sqlc.Device) (*pb.Device, error) {
func (db *database) sqlcDeviceToPbDevice(sqlcDevice sqlc.Device) (*pb.Device, error) {
pbDevice := &pb.Device{
Id: int64(sqlcDevice.ID),
Serial: sqlcDevice.Serial,
Expand All @@ -630,7 +632,11 @@ func sqlcDeviceToPbDevice(sqlcDevice sqlc.Device) (*pb.Device, error) {
pbDevice.LastSeen = timestamppb.New(stringToTime(sqlcDevice.LastSeen.String))
}

pbDevice.UpdateLastSeenIssues()
if db.kolideEnabled {
pbDevice.UpdateLastSeenIssues()
} else {
pbDevice.Issues = nil
}

return pbDevice, nil
}
Expand Down Expand Up @@ -673,8 +679,8 @@ func sqlcGatewayToPbGateway(g sqlc.Gateway, groupIDs []string, routes []*sqlc.Ge
}
}

func sqlcSessionAndDeviceToPbSession(s sqlc.Session, d sqlc.Device, groupIDs []string) (*pb.Session, error) {
device, err := sqlcDeviceToPbDevice(d)
func (db *database) sqlcSessionAndDeviceToPbSession(s sqlc.Session, d sqlc.Device, groupIDs []string) (*pb.Session, error) {
device, err := db.sqlcDeviceToPbDevice(d)
if err != nil {
return nil, fmt.Errorf("converting device: %w", err)
}
Expand Down
4 changes: 2 additions & 2 deletions internal/apiserver/database/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import (
const timeout = time.Second * 5

func TestAddGateway(t *testing.T) {
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, false)

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
Expand Down Expand Up @@ -91,7 +91,7 @@ func TestAddGateway(t *testing.T) {
}

func TestAddDevice(t *testing.T) {
db := testdatabase.Setup(t)
db := testdatabase.Setup(t, true)

ctx, cancel := context.WithTimeout(context.Background(), timeout)
defer cancel()
Expand Down
4 changes: 2 additions & 2 deletions internal/apiserver/testdatabase/testdatabase.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ const (
apiserverWireGuardIP = "10.255.240.1"
)

func Setup(t *testing.T) database.Database {
func Setup(t *testing.T, kolideEnabled bool) database.Database {
testDir := filepath.Join(os.TempDir(), "naisdevice-tests")
err := os.MkdirAll(testDir, 0o755)
if err != nil {
Expand Down Expand Up @@ -48,7 +48,7 @@ func Setup(t *testing.T) database.Database {
ipAllocator := ip.NewV4Allocator(netip.MustParsePrefix(wireguardNetworkAddress), []string{apiserverWireGuardIP})
prefix := netip.MustParsePrefix("fd00::/64")
ip6Allocator := ip.NewV6Allocator(&prefix)
db, err := database.New(tempFile.Name(), ipAllocator, ip6Allocator, false, logrus.New())
db, err := database.New(tempFile.Name(), ipAllocator, ip6Allocator, kolideEnabled, logrus.New())
if err != nil {
t.Fatalf("Instantiating database: %v", err)
}
Expand Down
2 changes: 1 addition & 1 deletion internal/integration_test/database_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@ import (
)

func NewDB(t *testing.T) database.Database {
return testdatabase.Setup(t)
return testdatabase.Setup(t, true)
}

0 comments on commit b083c8d

Please sign in to comment.