diff --git a/internal/cli/atlas/quickstart/access_list_setup.go b/internal/cli/atlas/quickstart/access_list_setup.go index ae9828a2d1..4e78015eef 100644 --- a/internal/cli/atlas/quickstart/access_list_setup.go +++ b/internal/cli/atlas/quickstart/access_list_setup.go @@ -20,6 +20,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/mongodb/mongodb-atlas-cli/internal/cli" + "github.com/mongodb/mongodb-atlas-cli/internal/flag" "github.com/mongodb/mongodb-atlas-cli/internal/store" "github.com/mongodb/mongodb-atlas-cli/internal/telemetry" atlas "go.mongodb.org/atlas/mongodbatlas" @@ -35,27 +36,30 @@ func (opts *Opts) createAccessList() error { } func (opts *Opts) askAccessListOptions() error { - if len(opts.IPAddresses) > 0 { + if !opts.shouldAskForValue(flag.AccessListIP) { return nil } + message := "" + if len(opts.IPAddresses) == 0 { + publicIP := store.IPAddress() + if publicIP != "" { + message = fmt.Sprintf(" [Press Enter to use your public IP address '%s']", publicIP) + } + opts.IPAddresses = append(opts.IPAddresses, publicIP) + } fmt.Print(` [Set up your database network access details] `) - message := "" - publicIP := store.IPAddress() - if publicIP != "" { - message = fmt.Sprintf(" [Press Enter to use your public IP address '%s']", publicIP) - } err := telemetry.TrackAskOne( - newAccessListQuestion(publicIP, message), + newAccessListQuestion(strings.Join(opts.IPAddresses, ", "), message), &opts.IPAddressesResponse, survey.WithValidator(survey.Required), ) if err == nil && opts.IPAddressesResponse != "" { ips := strings.Split(opts.IPAddressesResponse, ",") - opts.IPAddresses = append(opts.IPAddresses, ips...) + opts.IPAddresses = ips } return err } diff --git a/internal/cli/atlas/quickstart/cluster_setup.go b/internal/cli/atlas/quickstart/cluster_setup.go index e05e1cb335..4479922dba 100644 --- a/internal/cli/atlas/quickstart/cluster_setup.go +++ b/internal/cli/atlas/quickstart/cluster_setup.go @@ -21,6 +21,7 @@ import ( "github.com/AlecAivazis/survey/v2" "github.com/mongodb/mongodb-atlas-cli/internal/cli" + "github.com/mongodb/mongodb-atlas-cli/internal/flag" "github.com/mongodb/mongodb-atlas-cli/internal/search" "github.com/mongodb/mongodb-atlas-cli/internal/telemetry" "github.com/mongodb/mongodb-atlas-cli/internal/usage" @@ -38,16 +39,18 @@ func (opts *Opts) createCluster() error { func (opts *Opts) askClusterOptions() error { var qs []*survey.Question - if opts.ClusterName == "" { - opts.ClusterName = opts.defaultName + if opts.shouldAskForValue(flag.ClusterName) { + if opts.ClusterName == "" { + opts.ClusterName = opts.defaultName + } qs = append(qs, newClusterNameQuestion(opts.ClusterName)) } - if opts.Provider == "" { + if opts.shouldAskForValue(flag.Provider) { qs = append(qs, newClusterProviderQuestion()) } - if opts.Provider == "" || opts.ClusterName == "" || opts.Region == "" { + if opts.shouldAskForValue(flag.ClusterName) || opts.shouldAskForValue(flag.Provider) || opts.shouldAskForValue(flag.Region) { fmt.Print(` [Set up your Atlas cluster] `) @@ -58,7 +61,7 @@ func (opts *Opts) askClusterOptions() error { } // We need the provider to ask for the region - if opts.Region == "" { + if opts.shouldAskForValue(flag.Region) { return opts.askClusterRegion() } return nil diff --git a/internal/cli/atlas/quickstart/confirm_cluster_setup.go b/internal/cli/atlas/quickstart/confirm_cluster_setup.go index addb131527..7308a06f30 100644 --- a/internal/cli/atlas/quickstart/confirm_cluster_setup.go +++ b/internal/cli/atlas/quickstart/confirm_cluster_setup.go @@ -25,6 +25,8 @@ import ( const loadSampleDataMsg = ` Load sample data: Yes` +var ErrUserAborted = errors.New("user-aborted. Not creating cluster") + func (opts *Opts) askConfirmConfigQuestion() error { if opts.Confirm { return nil @@ -68,7 +70,7 @@ Allow connections from (IP Address): %s } if !opts.Confirm { - return errors.New("user-aborted. Not creating cluster") + return ErrUserAborted } return nil } diff --git a/internal/cli/atlas/quickstart/dbuser_setup.go b/internal/cli/atlas/quickstart/dbuser_setup.go index 80f7f9c0ea..d35fb204a0 100644 --- a/internal/cli/atlas/quickstart/dbuser_setup.go +++ b/internal/cli/atlas/quickstart/dbuser_setup.go @@ -18,12 +18,12 @@ import ( "errors" "fmt" - "github.com/mongodb/mongodb-atlas-cli/internal/config" - "github.com/mongodb/mongodb-atlas-cli/internal/telemetry" - "github.com/AlecAivazis/survey/v2" + "github.com/mongodb/mongodb-atlas-cli/internal/config" "github.com/mongodb/mongodb-atlas-cli/internal/convert" + "github.com/mongodb/mongodb-atlas-cli/internal/flag" "github.com/mongodb/mongodb-atlas-cli/internal/randgen" + "github.com/mongodb/mongodb-atlas-cli/internal/telemetry" atlas "go.mongodb.org/atlas/mongodbatlas" ) @@ -40,23 +40,28 @@ func (opts *Opts) askDBUserOptions() error { if opts.DBUsername == "" { opts.DBUsername = opts.defaultName + } + if opts.shouldAskForValue(flag.Username) { qs = append(qs, newDBUsernameQuestion(opts.DBUsername, opts.validateUniqueUsername)) } - if opts.DBUserPassword == "" { - pwd, err := generatePassword() - if err != nil { - return err + if opts.shouldAskForValue(flag.Password) { + if opts.DBUserPassword == "" { + pwd, err := generatePassword() + if err != nil { + return err + } + opts.DBUserPassword = pwd } - opts.DBUserPassword = pwd + minLength := 10 if config.Service() == config.CloudGovService { minLength = 12 } message := fmt.Sprintf(" [Must be >%d characters. Press Enter to use an auto-generated password]", minLength) - qs = append(qs, newDBUserPasswordQuestion(pwd, message)) + qs = append(qs, newDBUserPasswordQuestion(opts.DBUserPassword, message)) } if len(qs) == 0 { diff --git a/internal/cli/atlas/quickstart/quick_start.go b/internal/cli/atlas/quickstart/quick_start.go index 1395d9c71d..692a9b7f9a 100644 --- a/internal/cli/atlas/quickstart/quick_start.go +++ b/internal/cli/atlas/quickstart/quick_start.go @@ -37,6 +37,7 @@ import ( "github.com/mongodb/mongodb-atlas-cli/internal/usage" "github.com/mongodb/mongodb-atlas-cli/internal/validate" "github.com/spf13/cobra" + "github.com/spf13/pflag" atlas "go.mongodb.org/atlas/mongodbatlas" ) @@ -107,6 +108,8 @@ type Opts struct { CurrentIP bool store store.AtlasClusterQuickStarter shouldRunLogin bool + flags *pflag.FlagSet + flagSet map[string]struct{} } type quickstart struct { @@ -179,6 +182,23 @@ func (opts *Opts) quickstartPreRun(ctx context.Context, outWriter io.Writer) err return opts.login.Run(ctx) } +func (opts *Opts) shouldAskForValue(f string) bool { + _, isFlagSet := opts.flagSet[f] + return !isFlagSet +} + +func (opts *Opts) trackFlags() { + if opts.flags == nil { + opts.flagSet = make(map[string]struct{}) + return + } + + opts.flagSet = make(map[string]struct{}, opts.flags.NFlag()) + opts.flags.Visit(func(f *pflag.Flag) { + opts.flagSet[f.Name] = struct{}{} + }) +} + func (opts *Opts) PreRun(ctx context.Context, outWriter io.Writer) error { opts.shouldRunLogin = false opts.setTier() @@ -197,6 +217,7 @@ func (opts *Opts) Run() error { const base10 = 10 opts.defaultName = "Cluster" + strconv.FormatInt(time.Now().Unix(), base10)[5:] opts.providerAndRegionToConstant() + opts.trackFlags() if opts.CurrentIP { if publicIP := store.IPAddress(); publicIP != "" { @@ -457,23 +478,31 @@ func (opts *Opts) replaceWithDefaultSettings(values *quickstart) { } func (opts *Opts) interactiveSetup() error { - if err := opts.askClusterOptions(); err != nil { - return err - } + for { + if err := opts.askClusterOptions(); err != nil { + return err + } - if err := opts.askSampleDataQuestion(); err != nil { - return err - } + if err := opts.askSampleDataQuestion(); err != nil { + return err + } - if err := opts.askDBUserOptions(); err != nil { - return err - } + if err := opts.askDBUserOptions(); err != nil { + return err + } - if err := opts.askAccessListOptions(); err != nil { - return err - } + if err := opts.askAccessListOptions(); err != nil { + return err + } + + if err := opts.askConfirmConfigQuestion(); err != nil && !errors.Is(err, ErrUserAborted) { + return err + } - return opts.askConfirmConfigQuestion() + if opts.Confirm { + return nil + } + } } // Builder @@ -502,6 +531,7 @@ func Builder() *cobra.Command { return opts.quickstartPreRun(cmd.Context(), cmd.OutOrStdout()) }, RunE: func(cmd *cobra.Command, args []string) error { + opts.flags = cmd.Flags() return opts.Run() }, } diff --git a/internal/cli/atlas/quickstart/quick_start_test.go b/internal/cli/atlas/quickstart/quick_start_test.go index 2b4cdbb1fe..6f5c6efc77 100644 --- a/internal/cli/atlas/quickstart/quick_start_test.go +++ b/internal/cli/atlas/quickstart/quick_start_test.go @@ -20,7 +20,6 @@ package quickstart import ( "bytes" "context" - "fmt" "testing" "github.com/golang/mock/gomock" @@ -30,6 +29,8 @@ import ( "github.com/mongodb/mongodb-atlas-cli/internal/mocks" "github.com/mongodb/mongodb-atlas-cli/internal/test" "github.com/mongodb/mongodb-atlas-cli/internal/validate" + "github.com/spf13/pflag" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.mongodb.org/atlas/mongodbatlas" ) @@ -199,8 +200,6 @@ func TestQuickstartOpts_Run_NeedLogin_ForceAfterLogin(t *testing.T) { CreateDatabaseUser(opts.newDatabaseUser()).Return(expectedDBUser, nil). Times(1) - fmt.Println(config.Service()) - if err := opts.quickstartPreRun(ctx, buf); err != nil { t.Fatalf("Run() unexpected error: %v", err) } @@ -211,6 +210,80 @@ func TestQuickstartOpts_Run_NeedLogin_ForceAfterLogin(t *testing.T) { } } +func TestQuickstartOpts_Run_CheckFlagsSet(t *testing.T) { + t.Cleanup(test.CleanupConfig) + ctrl := gomock.NewController(t) + mockStore := mocks.NewMockAtlasClusterQuickStarter(ctrl) + defer ctrl.Finish() + + expectedCluster := &mongodbatlas.AdvancedCluster{ + StateName: "IDLE", + ConnectionStrings: &mongodbatlas.ConnectionStrings{ + StandardSrv: "", + }, + } + + expectedDBUser := &mongodbatlas.DatabaseUser{} + + var expectedProjectAccessLists *mongodbatlas.ProjectIPAccessLists + + opts := &Opts{ + ClusterName: "ProjectBar", + Region: "US", + store: mockStore, + IPAddresses: []string{"0.0.0.0"}, + DBUsername: "user", + DBUserPassword: "test", + Provider: "AWS", + SkipMongosh: true, + SkipSampleData: true, + Confirm: true, + } + + opts.runMongoShell = true + + projectIPAccessList := opts.newProjectIPAccessList() + + mockStore. + EXPECT(). + CreateCluster(opts.newCluster()).Return(expectedCluster, nil). + Times(1) + + mockStore. + EXPECT(). + CreateProjectIPAccessList(projectIPAccessList).Return(expectedProjectAccessLists, nil). + Times(1) + + mockStore. + EXPECT(). + AtlasCluster(opts.ConfigProjectID(), opts.ClusterName).Return(expectedCluster, nil). + Times(2) + + mockStore. + EXPECT(). + CreateDatabaseUser(opts.newDatabaseUser()).Return(expectedDBUser, nil). + Times(1) + + cmd := Builder() + cmd.Flags().VisitAll(func(f *pflag.Flag) { + f.Changed = false + _ = cmd.Flags().Set(f.Name, f.DefValue) + }) + + opts.flags = cmd.Flags() + + if err := opts.Run(); err != nil { + t.Fatalf("Run() unexpected error: %v", err) + } + + assert.False(t, opts.shouldAskForValue(flag.ClusterName)) + assert.False(t, opts.shouldAskForValue(flag.Region)) + assert.False(t, opts.shouldAskForValue(flag.AccessListIP)) + assert.False(t, opts.shouldAskForValue(flag.Region)) + assert.False(t, opts.shouldAskForValue(flag.Username)) + assert.False(t, opts.shouldAskForValue(flag.Password)) +} + func setConfig() func(ctx context.Context) error { return func(ctx context.Context) error { config.SetOrgID("a")