Skip to content

Commit

Permalink
feat: Prompt when multiple database types are discovered
Browse files Browse the repository at this point in the history
  • Loading branch information
gabe565 committed May 1, 2024
1 parent 90d0e91 commit cfd6eda
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 17 deletions.
4 changes: 4 additions & 0 deletions internal/config/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@ type DatabaseAliases interface {
Aliases() []string
}

type DatabasePriority interface {
Priority() uint8
}

type DatabaseDump interface {
Database
DumpCommand(conf Dump) *command.Builder
Expand Down
39 changes: 35 additions & 4 deletions internal/database/detect_dialect.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,19 +11,50 @@ import (

var ErrDatabaseNotFound = errors.New("could not detect a database")

func DetectDialect(ctx context.Context, client kubernetes.KubeClient) (config.Database, []v1.Pod, error) {
type DetectResult map[config.Database][]v1.Pod

func DetectDialect(ctx context.Context, client kubernetes.KubeClient) (DetectResult, error) {
podList, err := client.GetNamespacedPods(ctx)
if err != nil {
return nil, []v1.Pod{}, err
return nil, err
}

result := make(DetectResult)
for _, db := range All() {
pods := kubernetes.FilterPodList(podList.Items, db.PodFilters())
if len(pods) != 0 {
return db, pods, nil
result[db] = pods
}
}
if len(result) == 0 {
return nil, ErrDatabaseNotFound
}
if len(result) > 1 {
// Find the highest priority dialects
var maxPriority uint8
for dialect := range result {
if dbPriority, ok := dialect.(config.DatabasePriority); ok {
priority := dbPriority.Priority()
if maxPriority < priority {
maxPriority = priority
}
}
}
if maxPriority != 0 {
// Filter out dialects that are lower than the max
for dialect := range result {
if dbPriority, ok := dialect.(config.DatabasePriority); ok {
priority := dbPriority.Priority()
if priority < maxPriority {
delete(result, dialect)
}
} else {
delete(result, dialect)
}
}
}
}
return nil, []v1.Pod{}, ErrDatabaseNotFound
return result, nil
}

func DetectDialectFromPod(pod v1.Pod) (config.Database, error) {
Expand Down
18 changes: 6 additions & 12 deletions internal/database/detect_dialect_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import (
"context"
"testing"

"github.com/clevyr/kubedb/internal/config"
"github.com/clevyr/kubedb/internal/database/mariadb"
"github.com/clevyr/kubedb/internal/database/postgres"
"github.com/clevyr/kubedb/internal/kubernetes"
"github.com/stretchr/testify/assert"
Expand All @@ -28,7 +28,7 @@ func TestDetectDialect(t *testing.T) {
mariadbPod := v1.Pod{
ObjectMeta: metav1.ObjectMeta{
Labels: map[string]string{
"app.kubernetes.io/name": "postgresql",
"app.kubernetes.io/name": "mariadb",
"app.kubernetes.io/component": "primary",
},
},
Expand All @@ -40,8 +40,7 @@ func TestDetectDialect(t *testing.T) {
tests := []struct {
name string
args args
want config.Database
want1 []v1.Pod
want DetectResult
wantErr require.ErrorAssertionFunc
}{
{
Expand All @@ -51,8 +50,7 @@ func TestDetectDialect(t *testing.T) {
ClientSet: kubernetesfake.NewSimpleClientset(&postgresPod),
},
},
postgres.Postgres{},
[]v1.Pod{postgresPod},
DetectResult{postgres.Postgres{}: []v1.Pod{postgresPod}},
require.NoError,
},
{
Expand All @@ -62,8 +60,7 @@ func TestDetectDialect(t *testing.T) {
ClientSet: kubernetesfake.NewSimpleClientset(&mariadbPod),
},
},
postgres.Postgres{},
[]v1.Pod{mariadbPod},
DetectResult{mariadb.MariaDB{}: []v1.Pod{mariadbPod}},
require.NoError,
},
{
Expand All @@ -74,7 +71,6 @@ func TestDetectDialect(t *testing.T) {
},
},
nil,
[]v1.Pod{},
require.Error,
},
{
Expand All @@ -85,17 +81,15 @@ func TestDetectDialect(t *testing.T) {
},
},
nil,
[]v1.Pod{},
require.Error,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, got1, err := DetectDialect(context.Background(), tt.args.client)
got, err := DetectDialect(context.Background(), tt.args.client)
tt.wantErr(t, err)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.want1, got1)
})
}
}
3 changes: 3 additions & 0 deletions internal/database/mariadb/mariadb.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

var (
_ config.DatabaseAliases = MariaDB{}
_ config.DatabasePriority = MariaDB{}
_ config.DatabaseDump = MariaDB{}
_ config.DatabaseExec = MariaDB{}
_ config.DatabaseRestore = MariaDB{}
Expand All @@ -35,6 +36,8 @@ func (MariaDB) Aliases() []string {
return []string{"maria", "mysql"}
}

func (MariaDB) Priority() uint8 { return 255 }

func (MariaDB) PortEnvNames() kubernetes.ConfigLookups {
return kubernetes.ConfigLookups{kubernetes.LookupEnv{"MARIADB_PORT_NUMBER", "MYSQL_PORT_NUMBER"}}
}
Expand Down
3 changes: 3 additions & 0 deletions internal/database/mongodb/mongodb.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

var (
_ config.DatabaseAliases = MongoDB{}
_ config.DatabasePriority = MongoDB{}
_ config.DatabaseDump = MongoDB{}
_ config.DatabaseExec = MongoDB{}
_ config.DatabaseRestore = MongoDB{}
Expand All @@ -33,6 +34,8 @@ func (MongoDB) Aliases() []string {
return []string{"mongo"}
}

func (MongoDB) Priority() uint8 { return 255 }

func (MongoDB) PortEnvNames() kubernetes.ConfigLookups {
return kubernetes.ConfigLookups{kubernetes.LookupEnv{"MONGODB_PORT_NUMBER"}}
}
Expand Down
3 changes: 3 additions & 0 deletions internal/database/postgres/postgres.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

var (
_ config.DatabaseAliases = Postgres{}
_ config.DatabasePriority = Postgres{}
_ config.DatabaseDump = Postgres{}
_ config.DatabaseExec = Postgres{}
_ config.DatabaseRestore = Postgres{}
Expand All @@ -46,6 +47,8 @@ func (Postgres) Aliases() []string {
return []string{"postgresql", "psql", "pg"}
}

func (Postgres) Priority() uint8 { return 255 }

func (Postgres) PortEnvNames() kubernetes.ConfigLookups {
return kubernetes.ConfigLookups{
kubernetes.LookupEnv{"POSTGRESQL_PORT_NUMBER"},
Expand Down
25 changes: 24 additions & 1 deletion internal/util/cmd_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,33 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) er
if dialectFlag == "" {
// Configure via detection
if len(pods) == 0 {
conf.Dialect, pods, err = database.DetectDialect(ctx, conf.Client)
result, err := database.DetectDialect(ctx, conf.Client)
if err != nil {
return err
}
if len(result) == 1 || opts.NoSurvey {
for dialect, p := range result {
conf.Dialect = dialect
pods = p
break
}
} else {
opts := make([]huh.Option[config.Database], 0, len(result))
for dialect := range result {
opts = append(opts, huh.NewOption(dialect.Name(), dialect))
}
var chosen config.Database
if err := huh.NewForm(huh.NewGroup(
huh.NewSelect[config.Database]().
Title("Select database type").
Options(opts...).
Value(&chosen),
)).Run(); err != nil {
return err
}
conf.Dialect = chosen
pods = result[chosen]
}
log.Debug().Str("dialect", conf.Dialect.Name()).Msg("detected dialect")
} else {
conf.Dialect, err = database.DetectDialectFromPod(pods[0])
Expand Down

0 comments on commit cfd6eda

Please sign in to comment.