Skip to content

Commit

Permalink
add proper character escaping:
Browse files Browse the repository at this point in the history
- count the number of \ in front
- do not escape it again if it was already escaped
  • Loading branch information
creativeprojects committed Nov 12, 2020
1 parent c9b0c34 commit be97fae
Show file tree
Hide file tree
Showing 8 changed files with 112 additions and 75 deletions.
63 changes: 32 additions & 31 deletions commands.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"encoding/base64"
"errors"
"fmt"
"io"
"os"
"path/filepath"
"runtime"
Expand All @@ -25,9 +26,9 @@ import (
type ownCommand struct {
name string
description string
action func(*config.Config, commandLineFlags, []string) error
action func(io.Writer, *config.Config, commandLineFlags, []string) error
needConfiguration bool // true if the action needs a configuration file loaded
hide bool
hide bool // don't display the command in the help
}

var (
Expand Down Expand Up @@ -108,8 +109,8 @@ var (
}
)

func displayOwnCommands() {
w := tabwriter.NewWriter(os.Stdout, 0, 0, 3, ' ', 0)
func displayOwnCommands(output io.Writer) {
w := tabwriter.NewWriter(output, 0, 0, 3, ' ', 0)
for _, command := range ownCommands {
if command.hide {
continue
Expand All @@ -131,23 +132,23 @@ func isOwnCommand(command string, configurationLoaded bool) bool {
func runOwnCommand(configuration *config.Config, command string, flags commandLineFlags, args []string) error {
for _, commandDef := range ownCommands {
if commandDef.name == command {
return commandDef.action(configuration, flags, args)
return commandDef.action(os.Stdout, configuration, flags, args)
}
}
return fmt.Errorf("command not found: %v", command)
}

func displayProfilesCommand(configuration *config.Config, _ commandLineFlags, _ []string) error {
displayProfiles(configuration)
displayGroups(configuration)
func displayProfilesCommand(output io.Writer, configuration *config.Config, _ commandLineFlags, _ []string) error {
displayProfiles(output, configuration)
displayGroups(output, configuration)
return nil
}

func displayVersion(_ *config.Config, flags commandLineFlags, _ []string) error {
fmt.Printf("resticprofile version %s commit %s.\n", version, commit)
func displayVersion(output io.Writer, _ *config.Config, flags commandLineFlags, _ []string) error {
fmt.Fprintf(output, "resticprofile version %s commit %s.\n", version, commit)

if flags.verbose {
w := tabwriter.NewWriter(os.Stderr, 0, 0, 3, ' ', 0)
w := tabwriter.NewWriter(output, 0, 0, 3, ' ', 0)
_, _ = fmt.Fprintf(w, "\n")
_, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "home", "https://github.com/creativeprojects/resticprofile")
_, _ = fmt.Fprintf(w, "\t%s:\t%s\n", "os", runtime.GOOS)
Expand All @@ -173,14 +174,14 @@ func displayVersion(_ *config.Config, flags commandLineFlags, _ []string) error
return nil
}

func displayProfiles(configuration *config.Config) {
func displayProfiles(output io.Writer, configuration *config.Config) {
profileSections := configuration.GetProfileSections()
keys := sortedMapKeys(profileSections)
if len(profileSections) == 0 {
fmt.Println("\nThere's no available profile in the configuration")
fmt.Fprintln(output, "\nThere's no available profile in the configuration")
} else {
fmt.Println("\nProfiles available:")
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(output, "\nProfiles available:")
w := tabwriter.NewWriter(output, 0, 0, 2, ' ', 0)
for _, name := range keys {
sections := profileSections[name]
sort.Strings(sections)
Expand All @@ -192,36 +193,36 @@ func displayProfiles(configuration *config.Config) {
}
_ = w.Flush()
}
fmt.Println("")
fmt.Fprintln(output, "")
}

func displayGroups(configuration *config.Config) {
func displayGroups(output io.Writer, configuration *config.Config) {
groups := configuration.GetProfileGroups()
if len(groups) == 0 {
return
}
fmt.Println("Groups available:")
w := tabwriter.NewWriter(os.Stdout, 0, 0, 2, ' ', 0)
fmt.Fprintln(output, "Groups available:")
w := tabwriter.NewWriter(output, 0, 0, 2, ' ', 0)
for name, groupList := range groups {
_, _ = fmt.Fprintf(w, "\t%s:\t%s\n", name, strings.Join(groupList, ", "))
}
_ = w.Flush()
fmt.Println("")
fmt.Fprintln(output, "")
}

func selfUpdate(_ *config.Config, flags commandLineFlags, args []string) error {
func selfUpdate(_ io.Writer, _ *config.Config, flags commandLineFlags, args []string) error {
err := confirmAndSelfUpdate(flags.quiet, flags.verbose, version)
if err != nil {
return err
}
return nil
}

func panicCommand(_ *config.Config, _ commandLineFlags, _ []string) error {
func panicCommand(_ io.Writer, _ *config.Config, _ commandLineFlags, _ []string) error {
panic("you asked for it")
}

func testCommand(_ *config.Config, _ commandLineFlags, _ []string) error {
func testCommand(_ io.Writer, _ *config.Config, _ commandLineFlags, _ []string) error {
clog.Info("Nothing to test")
return nil
}
Expand All @@ -235,7 +236,7 @@ func sortedMapKeys(data map[string][]string) []string {
return keys
}

func showProfile(c *config.Config, flags commandLineFlags, args []string) error {
func showProfile(output io.Writer, c *config.Config, flags commandLineFlags, args []string) error {
// Show global section first
global, err := c.GetGlobalSection()
if err != nil {
Expand Down Expand Up @@ -266,7 +267,7 @@ func showProfile(c *config.Config, flags commandLineFlags, args []string) error
}

// randomKey simply display a base64'd random key to the console
func randomKey(c *config.Config, flags commandLineFlags, args []string) error {
func randomKey(output io.Writer, c *config.Config, flags commandLineFlags, args []string) error {
var err error
size := uint64(1024)
// flags.resticArgs contain the command and the rest of the command line
Expand All @@ -285,14 +286,14 @@ func randomKey(c *config.Config, flags commandLineFlags, args []string) error {
if err != nil {
return err
}
encoder := base64.NewEncoder(base64.StdEncoding, os.Stdout)
encoder := base64.NewEncoder(base64.StdEncoding, output)
_, err = encoder.Write(buffer)
encoder.Close()
fmt.Println("")
fmt.Fprintln(output, "")
return err
}

func createSchedule(c *config.Config, flags commandLineFlags, args []string) error {
func createSchedule(_ io.Writer, c *config.Config, flags commandLineFlags, args []string) error {
profile, err := c.GetProfile(flags.name)
if err != nil {
return fmt.Errorf("cannot load profile '%s': %w", flags.name, err)
Expand All @@ -313,7 +314,7 @@ func createSchedule(c *config.Config, flags commandLineFlags, args []string) err
return nil
}

func removeSchedule(c *config.Config, flags commandLineFlags, args []string) error {
func removeSchedule(_ io.Writer, c *config.Config, flags commandLineFlags, args []string) error {
profile, err := c.GetProfile(flags.name)
if err != nil {
return fmt.Errorf("cannot load profile '%s': %w", flags.name, err)
Expand All @@ -334,7 +335,7 @@ func removeSchedule(c *config.Config, flags commandLineFlags, args []string) err
return nil
}

func statusSchedule(c *config.Config, flags commandLineFlags, args []string) error {
func statusSchedule(_ io.Writer, c *config.Config, flags commandLineFlags, args []string) error {
profile, err := c.GetProfile(flags.name)
if err != nil {
return fmt.Errorf("cannot load profile '%s': %w", flags.name, err)
Expand All @@ -355,7 +356,7 @@ func statusSchedule(c *config.Config, flags commandLineFlags, args []string) err
return nil
}

func testElevationCommand(c *config.Config, flags commandLineFlags, args []string) error {
func testElevationCommand(_ io.Writer, c *config.Config, flags commandLineFlags, args []string) error {
if flags.isChild {
client := remote.NewClient(flags.parentPort)
term.Print("first line", "\n")
Expand Down
26 changes: 14 additions & 12 deletions commands_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@ package main

import (
"errors"
"io"
"os"
"strings"
"testing"

"github.com/creativeprojects/resticprofile/config"
Expand Down Expand Up @@ -33,23 +36,22 @@ func init() {
}
}

func firstCommand(_ *config.Config, _ commandLineFlags, _ []string) error {
func firstCommand(_ io.Writer, _ *config.Config, _ commandLineFlags, _ []string) error {
return errors.New("first")
}

func secondCommand(_ *config.Config, _ commandLineFlags, _ []string) error {
func secondCommand(_ io.Writer, _ *config.Config, _ commandLineFlags, _ []string) error {
return errors.New("second")
}

func thirdCommand(_ *config.Config, _ commandLineFlags, _ []string) error {
func thirdCommand(_ io.Writer, _ *config.Config, _ commandLineFlags, _ []string) error {
return errors.New("third")
}

func ExampleDisplayOwnCommands() {
displayOwnCommands()
// Output:
// first first first
// second second second
func TestDisplayOwnCommands(t *testing.T) {
buffer := &strings.Builder{}
displayOwnCommands(buffer)
assert.Equal(t, " first first first\n second second second\n", buffer.String())
}

func TestIsOwnCommand(t *testing.T) {
Expand All @@ -68,19 +70,19 @@ func TestRunOwnCommand(t *testing.T) {

func TestPanicCommand(t *testing.T) {
assert.Panics(t, func() {
_ = panicCommand(nil, commandLineFlags{}, nil)
_ = panicCommand(nil, nil, commandLineFlags{}, nil)
})
}

func TestRandomKeyOfInvalidSize(t *testing.T) {
assert.Error(t, randomKey(nil, commandLineFlags{resticArgs: []string{"restic", "size"}}, nil))
assert.Error(t, randomKey(os.Stdout, nil, commandLineFlags{resticArgs: []string{"restic", "size"}}, nil))
}

func TestRandomKeyOfZeroSize(t *testing.T) {
assert.Error(t, randomKey(nil, commandLineFlags{resticArgs: []string{"restic", "0"}}, nil))
assert.Error(t, randomKey(os.Stdout, nil, commandLineFlags{resticArgs: []string{"restic", "0"}}, nil))
}

func TestRandomKey(t *testing.T) {
// doesn't look like much, but it's testing the random generator is not throwing an error
assert.NoError(t, randomKey(nil, commandLineFlags{}, nil))
assert.NoError(t, randomKey(os.Stdout, nil, commandLineFlags{}, nil))
}
37 changes: 33 additions & 4 deletions config/path.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,18 @@ func expandEnv(value string) string {
return value
}

func unixSpaces(value string) string {
// escapeSpaces escapes ' ' characters (unix only)
func escapeSpaces(value string) string {
if runtime.GOOS != "windows" {
value = strings.ReplaceAll(strings.TrimSpace(value), " ", `\ `)
value = escapeString(value, []byte{' '})
}
return value
}

func unixGlobs(value string) string {
// escapeShellString escapes ' ', '*' and '?' characters (unix only)
func escapeShellString(value string) string {
if runtime.GOOS != "windows" {
value = strings.ReplaceAll(value, "*", `\*`)
value = escapeString(value, []byte{' ', '*', '?'})
}
return value
}
Expand Down Expand Up @@ -81,3 +83,30 @@ func absolutePath(value string) string {
clog.Errorf("cannot determine absolute path for '%s'", value)
return value
}

// escapeString adds a '\' in front of the characters to escape.
// it checks for the number of '\' characters in front:
// - if even: add one
// - if odd: do nothing, it means the character is already escaped
func escapeString(value string, chars []byte) string {
output := &strings.Builder{}
escape := 0
for i := 0; i < len(value); i++ {
if value[i] == '\\' {
escape++
} else {
for _, char := range chars {
if value[i] == char {
if escape%2 == 0 {
// even number of escape characters in front, we need to escape this one
output.WriteByte('\\')
}
}
}
// reset number of '\'
escape = 0
}
output.WriteByte(value[i])
}
return output.String()
}
27 changes: 20 additions & 7 deletions config/path_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,29 @@ func TestFixUnixPaths(t *testing.T) {
expected string
}{
{"", ""},
{"dir", "prefix/dir"},
{"dir", "/prefix/dir"},
{"/dir", "/dir"},
{"~/dir", "~/dir"},
{"$TEMP_TEST_DIR/dir", "/home/dir"},
{"some file.txt", "prefix/some\\ file.txt"},
{"some file.txt", "/prefix/some\\ file.txt"},
{"/**/.git", "/\\*\\*/.git"},
{"/\\*\\*/.git", "/\\*\\*/.git"},
{`/?`, `/\?`},
{`/\?`, `/\?`},
{`/\\?`, `/\\\?`},
{`/\\\?`, `/\\\?`},
{`/\\\\?`, `/\\\\\?`},
{`/ ?*`, `/\ \?\*`},
}

err := os.Setenv("TEMP_TEST_DIR", "/home")
require.NoError(t, err)

for _, testPath := range paths {
fixed := fixPath(testPath.source, expandEnv, absolutePrefix("prefix"), unixSpaces, unixGlobs)
fixed := fixPath(testPath.source, expandEnv, absolutePrefix("/prefix"), escapeShellString)
assert.Equalf(t, testPath.expected, fixed, "source was '%s'", testPath.source)
// running it again should not change the value
fixed = fixPath(fixed, expandEnv, absolutePrefix("/prefix"), escapeShellString)
assert.Equalf(t, testPath.expected, fixed, "source was '%s'", testPath.source)
}
}
Expand All @@ -46,18 +56,21 @@ func TestFixWindowsPaths(t *testing.T) {
expected string
}{
{``, ``},
{`dir`, `prefix\dir`},
{`\dir`, `prefix\dir`},
{`dir`, `\prefix\dir`},
{`\dir`, `\prefix\dir`},
{`c:\dir`, `c:\dir`},
{`%TEMP_TEST_DIR%\dir`, `%TEMP_TEST_DIR%\dir`},
{"some file.txt", `prefix\some file.txt`},
{"some file.txt", `\prefix\some file.txt`},
}

err := os.Setenv("TEMP_TEST_DIR", "/home")
require.NoError(t, err)

for _, testPath := range paths {
fixed := fixPath(testPath.source, expandEnv, absolutePrefix("prefix"), unixSpaces)
fixed := fixPath(testPath.source, expandEnv, absolutePrefix("\\prefix"), escapeShellString)
assert.Equalf(t, testPath.expected, fixed, "source was '%s'", testPath.source)
// running it again should not change the value
fixed = fixPath(fixed, expandEnv, absolutePrefix("\\prefix"), escapeShellString)
assert.Equalf(t, testPath.expected, fixed, "source was '%s'", testPath.source)
}
}

0 comments on commit be97fae

Please sign in to comment.