diff --git a/command/snapshot/save/snapshot_save.go b/command/snapshot/save/snapshot_save.go index 16e473b1f016..0f4e6d7e984a 100644 --- a/command/snapshot/save/snapshot_save.go +++ b/command/snapshot/save/snapshot_save.go @@ -6,9 +6,9 @@ package save import ( "flag" "fmt" - "github.com/hashicorp/consul/version" "golang.org/x/exp/slices" "os" + "path/filepath" "strings" "github.com/mitchellh/cli" @@ -26,16 +26,16 @@ func New(ui cli.Ui) *cmd { } type cmd struct { - UI cli.Ui - flags *flag.FlagSet - http *flags.HTTPFlags - help string - appendFileName flags.StringValue + UI cli.Ui + flags *flag.FlagSet + http *flags.HTTPFlags + help string + appendFileNameFlag flags.StringValue } func (c *cmd) getAppendFileNameFlag() *flag.FlagSet { fs := flag.NewFlagSet("", flag.ContinueOnError) - fs.Var(&c.appendFileName, "append-filename", "Append filename flag takes two possible values. "+ + fs.Var(&c.appendFileNameFlag, "append-filename", "Append filename flag takes two possible values. "+ "1. version, 2. dc. It appends consul version and datacenter to filename given in command") return fs } @@ -71,23 +71,38 @@ func (c *cmd) Run(args []string) int { // Create and test the HTTP client client, err := c.http.APIClient() - appendFileNameFlags := strings.Split(c.appendFileName.String(), ",") + appendFileNameFlags := strings.Split(c.appendFileNameFlag.String(), ",") - if slices.Contains(appendFileNameFlags, "version") { - file = file + "-" + version.GetHumanVersion() - } + var agentSelfResponse map[string]map[string]interface{} - if slices.Contains(appendFileNameFlags, "dc") { - agentSelfResponse, err := client.Agent().Self() + if len(appendFileNameFlags) != 0 { + agentSelfResponse, err = client.Agent().Self() if err != nil { - c.UI.Error(fmt.Sprintf("Error connecting to Consul agent and fetching datacenter: %s", err)) + c.UI.Error(fmt.Sprintf("Error connecting to Consul agent and fetching datacenter/version: %s", err)) return 1 } - if config, ok := agentSelfResponse["Config"]; ok { - if datacenter, ok := config["Datacenter"]; ok { - file = file + "-" + datacenter.(string) + + fileExt := filepath.Ext(file) + fileNameWithoutExt := strings.TrimSuffix(file, fileExt) + + if slices.Contains(appendFileNameFlags, "version") { + if config, ok := agentSelfResponse["Config"]; ok { + if version, ok := config["Version"]; ok { + fileNameWithoutExt = fileNameWithoutExt + "-" + version.(string) + } } } + + if slices.Contains(appendFileNameFlags, "dc") { + if config, ok := agentSelfResponse["Config"]; ok { + if datacenter, ok := config["Datacenter"]; ok { + fileNameWithoutExt = fileNameWithoutExt + "-" + datacenter.(string) + } + } + } + + //adding extension back + file = fileNameWithoutExt + fileExt } if err != nil { diff --git a/command/snapshot/save/snapshot_save_test.go b/command/snapshot/save/snapshot_save_test.go index a95b537e9475..47be3c77363b 100644 --- a/command/snapshot/save/snapshot_save_test.go +++ b/command/snapshot/save/snapshot_save_test.go @@ -72,6 +72,48 @@ func TestSnapshotSaveCommand_Validation(t *testing.T) { } } +func TestSnapshotSaveCommandWithAppendFileNameFlag(t *testing.T) { + if testing.Short() { + t.Skip("too slow for testing.Short") + } + + t.Parallel() + a := agent.NewTestAgent(t, ``) + defer a.Shutdown() + client := a.Client() + + ui := cli.NewMockUi() + c := New(ui) + + dir := testutil.TempDir(t, "snapshot") + file := filepath.Join(dir, "backup.tgz") + args := []string{ + "-append-filename=version,dc", + file, + } + + newFilePath := filepath.Join(dir, "backup"+"-"+a.Config.Version+"-"+a.Config.Datacenter+".tgz") + + code := c.Run(args) + if code != 0 { + t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String()) + } + + fi, err := os.Stat(newFilePath) + require.NoError(t, err) + require.Equal(t, fi.Mode(), os.FileMode(0600)) + + f, err := os.Open(newFilePath) + if err != nil { + t.Fatalf("err: %v", err) + } + defer f.Close() + + if err := client.Snapshot().Restore(nil, f); err != nil { + t.Fatalf("err: %v", err) + } +} + func TestSnapshotSaveCommand(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short")