diff --git a/internal/config/database.go b/internal/config/database.go index b32b090..6aaabdd 100644 --- a/internal/config/database.go +++ b/internal/config/database.go @@ -19,6 +19,10 @@ type DatabaseAliases interface { Aliases() []string } +type DatabasePriority interface { + Priority() uint8 +} + type DatabaseDump interface { Database DumpCommand(conf Dump) *command.Builder diff --git a/internal/database/detect_dialect.go b/internal/database/detect_dialect.go index a77a545..3ab2fa6 100644 --- a/internal/database/detect_dialect.go +++ b/internal/database/detect_dialect.go @@ -11,19 +11,50 @@ import ( var ErrDatabaseNotFound = errors.New("could not detect a database") -func DetectDialect(ctx context.Context, client kubernetes.KubeClient) (config.Database, []v1.Pod, error) { +type DetectResult map[config.Database][]v1.Pod + +func DetectDialect(ctx context.Context, client kubernetes.KubeClient) (DetectResult, error) { podList, err := client.GetNamespacedPods(ctx) if err != nil { - return nil, []v1.Pod{}, err + return nil, err } + result := make(DetectResult) for _, db := range All() { pods := kubernetes.FilterPodList(podList.Items, db.PodFilters()) if len(pods) != 0 { - return db, pods, nil + result[db] = pods + } + } + if len(result) == 0 { + return nil, ErrDatabaseNotFound + } + if len(result) > 1 { + // Find the highest priority dialects + var maxPriority uint8 + for dialect := range result { + if dbPriority, ok := dialect.(config.DatabasePriority); ok { + priority := dbPriority.Priority() + if maxPriority < priority { + maxPriority = priority + } + } + } + if maxPriority != 0 { + // Filter out dialects that are lower than the max + for dialect := range result { + if dbPriority, ok := dialect.(config.DatabasePriority); ok { + priority := dbPriority.Priority() + if priority < maxPriority { + delete(result, dialect) + } + } else { + delete(result, dialect) + } + } } } - return nil, []v1.Pod{}, ErrDatabaseNotFound + return result, nil } func DetectDialectFromPod(pod v1.Pod) (config.Database, error) { diff --git a/internal/database/detect_dialect_test.go b/internal/database/detect_dialect_test.go index bd50e59..b006a48 100644 --- a/internal/database/detect_dialect_test.go +++ b/internal/database/detect_dialect_test.go @@ -4,7 +4,7 @@ import ( "context" "testing" - "github.com/clevyr/kubedb/internal/config" + "github.com/clevyr/kubedb/internal/database/mariadb" "github.com/clevyr/kubedb/internal/database/postgres" "github.com/clevyr/kubedb/internal/kubernetes" "github.com/stretchr/testify/assert" @@ -28,7 +28,7 @@ func TestDetectDialect(t *testing.T) { mariadbPod := v1.Pod{ ObjectMeta: metav1.ObjectMeta{ Labels: map[string]string{ - "app.kubernetes.io/name": "postgresql", + "app.kubernetes.io/name": "mariadb", "app.kubernetes.io/component": "primary", }, }, @@ -40,8 +40,7 @@ func TestDetectDialect(t *testing.T) { tests := []struct { name string args args - want config.Database - want1 []v1.Pod + want DetectResult wantErr require.ErrorAssertionFunc }{ { @@ -51,8 +50,7 @@ func TestDetectDialect(t *testing.T) { ClientSet: kubernetesfake.NewSimpleClientset(&postgresPod), }, }, - postgres.Postgres{}, - []v1.Pod{postgresPod}, + DetectResult{postgres.Postgres{}: []v1.Pod{postgresPod}}, require.NoError, }, { @@ -62,8 +60,7 @@ func TestDetectDialect(t *testing.T) { ClientSet: kubernetesfake.NewSimpleClientset(&mariadbPod), }, }, - postgres.Postgres{}, - []v1.Pod{mariadbPod}, + DetectResult{mariadb.MariaDB{}: []v1.Pod{mariadbPod}}, require.NoError, }, { @@ -74,7 +71,6 @@ func TestDetectDialect(t *testing.T) { }, }, nil, - []v1.Pod{}, require.Error, }, { @@ -85,17 +81,15 @@ func TestDetectDialect(t *testing.T) { }, }, nil, - []v1.Pod{}, require.Error, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { t.Parallel() - got, got1, err := DetectDialect(context.Background(), tt.args.client) + got, err := DetectDialect(context.Background(), tt.args.client) tt.wantErr(t, err) assert.Equal(t, tt.want, got) - assert.Equal(t, tt.want1, got1) }) } } diff --git a/internal/database/mariadb/mariadb.go b/internal/database/mariadb/mariadb.go index 6eff273..f73c089 100644 --- a/internal/database/mariadb/mariadb.go +++ b/internal/database/mariadb/mariadb.go @@ -13,6 +13,7 @@ import ( var ( _ config.DatabaseAliases = MariaDB{} + _ config.DatabasePriority = MariaDB{} _ config.DatabaseDump = MariaDB{} _ config.DatabaseExec = MariaDB{} _ config.DatabaseRestore = MariaDB{} @@ -35,6 +36,8 @@ func (MariaDB) Aliases() []string { return []string{"maria", "mysql"} } +func (MariaDB) Priority() uint8 { return 255 } + func (MariaDB) PortEnvNames() kubernetes.ConfigLookups { return kubernetes.ConfigLookups{kubernetes.LookupEnv{"MARIADB_PORT_NUMBER", "MYSQL_PORT_NUMBER"}} } diff --git a/internal/database/mongodb/mongodb.go b/internal/database/mongodb/mongodb.go index 2d76cab..e605dfe 100644 --- a/internal/database/mongodb/mongodb.go +++ b/internal/database/mongodb/mongodb.go @@ -12,6 +12,7 @@ import ( var ( _ config.DatabaseAliases = MongoDB{} + _ config.DatabasePriority = MongoDB{} _ config.DatabaseDump = MongoDB{} _ config.DatabaseExec = MongoDB{} _ config.DatabaseRestore = MongoDB{} @@ -33,6 +34,8 @@ func (MongoDB) Aliases() []string { return []string{"mongo"} } +func (MongoDB) Priority() uint8 { return 255 } + func (MongoDB) PortEnvNames() kubernetes.ConfigLookups { return kubernetes.ConfigLookups{kubernetes.LookupEnv{"MONGODB_PORT_NUMBER"}} } diff --git a/internal/database/postgres/postgres.go b/internal/database/postgres/postgres.go index 12d5b26..85f5d7f 100644 --- a/internal/database/postgres/postgres.go +++ b/internal/database/postgres/postgres.go @@ -22,6 +22,7 @@ import ( var ( _ config.DatabaseAliases = Postgres{} + _ config.DatabasePriority = Postgres{} _ config.DatabaseDump = Postgres{} _ config.DatabaseExec = Postgres{} _ config.DatabaseRestore = Postgres{} @@ -46,6 +47,8 @@ func (Postgres) Aliases() []string { return []string{"postgresql", "psql", "pg"} } +func (Postgres) Priority() uint8 { return 255 } + func (Postgres) PortEnvNames() kubernetes.ConfigLookups { return kubernetes.ConfigLookups{ kubernetes.LookupEnv{"POSTGRESQL_PORT_NUMBER"}, diff --git a/internal/util/cmd_setup.go b/internal/util/cmd_setup.go index 654f843..4ded667 100644 --- a/internal/util/cmd_setup.go +++ b/internal/util/cmd_setup.go @@ -84,10 +84,33 @@ func DefaultSetup(cmd *cobra.Command, conf *config.Global, opts SetupOptions) er if dialectFlag == "" { // Configure via detection if len(pods) == 0 { - conf.Dialect, pods, err = database.DetectDialect(ctx, conf.Client) + result, err := database.DetectDialect(ctx, conf.Client) if err != nil { return err } + if len(result) == 1 || opts.NoSurvey { + for dialect, p := range result { + conf.Dialect = dialect + pods = p + break + } + } else { + opts := make([]huh.Option[config.Database], 0, len(result)) + for dialect := range result { + opts = append(opts, huh.NewOption(dialect.Name(), dialect)) + } + var chosen config.Database + if err := huh.NewForm(huh.NewGroup( + huh.NewSelect[config.Database](). + Title("Select database type"). + Options(opts...). + Value(&chosen), + )).Run(); err != nil { + return err + } + conf.Dialect = chosen + pods = result[chosen] + } log.Debug().Str("dialect", conf.Dialect.Name()).Msg("detected dialect") } else { conf.Dialect, err = database.DetectDialectFromPod(pods[0])