diff --git a/Makefile b/Makefile index 3eae270c..c4556f8c 100644 --- a/Makefile +++ b/Makefile @@ -19,6 +19,9 @@ cross-compile: clean ## Compile vsh binaries for multiple platforms and architec compile: clean ## Compile vsh for platform based on uname go build -ldflags "-X main.vshVersion=$(VERSION)" -o build/${APP_NAME}_$(shell uname | tr '[:upper:]' '[:lower:]')_amd64 +compile-debug: clean + go build -ldflags "-X main.vshVersion=$(VERSION)" -o build/${APP_NAME}_$(shell uname | tr '[:upper:]' '[:lower:]')_amd64 -gcflags="all=-N -l" + get-bats: ## Download bats dependencies to test directory rm -rf test/bin/ mkdir -p test/bin/core diff --git a/cli/append.go b/cli/append.go index 6907da18..6618bb79 100644 --- a/cli/append.go +++ b/cli/append.go @@ -120,7 +120,7 @@ func (cmd *AppendCommand) createDummySecret(target string) error { dummy := make(map[string]interface{}) dummy["placeholder"] = struct{}{} - dummySecret := client.NewSecret(&api.Secret{Data: dummy}) + dummySecret := client.NewSecret(&api.Secret{Data: dummy}, target) if targetSecret == nil { if err = cmd.client.Write(target, dummySecret); err != nil { return err @@ -150,7 +150,7 @@ func (cmd *AppendCommand) mergeSecrets(source string, target string) error { } // write - resultSecret := client.NewSecret(&api.Secret{Data: merged}) + resultSecret := client.NewSecret(&api.Secret{Data: merged}, target) if err := cmd.client.Write(target, resultSecret); err != nil { fmt.Println(err) return err diff --git a/cli/command.go b/cli/command.go index 783dcb7f..778ecece 100644 --- a/cli/command.go +++ b/cli/command.go @@ -2,7 +2,9 @@ package cli import ( "path/filepath" + "sort" "strings" + "sync" "github.com/fishi0x01/vsh/client" "github.com/fishi0x01/vsh/log" @@ -90,3 +92,32 @@ func transportSecrets(c *client.Client, source string, target string, transport return 0 } + +func funcOnPaths(c *client.Client, paths []string, f func(s *client.Secret) (matches []*Match)) (matches []*Match, err error) { + secrets, err := c.BatchRead(c.FilterPaths(paths, client.LEAF)) + if err != nil { + return nil, err + } + + var wg sync.WaitGroup + queue := make(chan *client.Secret, len(paths)) + recv := make(chan []*Match, len(paths)) + for _, secret := range secrets { + queue <- secret + } + for range secrets { + wg.Add(1) + go func() { + recv <- f(<-queue) + wg.Done() + }() + } + wg.Wait() + close(recv) + + for m := range recv { + matches = append(matches, m...) + } + sort.Slice(matches, func(i, j int) bool { return matches[i].path < matches[j].path }) + return matches, nil +} diff --git a/cli/grep.go b/cli/grep.go index 384aea59..4c9f470f 100644 --- a/cli/grep.go +++ b/cli/grep.go @@ -95,14 +95,12 @@ func (cmd *GrepCommand) Run() int { return 1 } - for _, curPath := range filePaths { - matches, err := cmd.grepFile(cmd.args.Search, curPath) - if err != nil { - return 1 - } - for _, match := range matches { - match.print(os.Stdout, false) - } + matches, err := cmd.grepPaths(cmd.args.Search, filePaths) + if err != nil { + return 1 + } + for _, match := range matches { + match.print(os.Stdout, false) } return 0 } @@ -116,19 +114,11 @@ func (cmd *GrepCommand) GetSearchParams() SearchParameters { } } -func (cmd *GrepCommand) grepFile(search string, path string) (matches []*Match, err error) { - matches = []*Match{} - - if cmd.client.GetType(path) == client.LEAF { - secret, err := cmd.client.Read(path) - if err != nil { - return matches, err - } - - for k, v := range secret.GetData() { - matches = append(matches, cmd.searcher.DoSearch(path, k, fmt.Sprintf("%v", v))...) +func (cmd *GrepCommand) grepPaths(search string, paths []string) (matches []*Match, err error) { + return funcOnPaths(cmd.client, paths, func(s *client.Secret) []*Match { + for k, v := range s.GetData() { + matches = append(matches, cmd.searcher.DoSearch(s.Path, k, fmt.Sprintf("%v", v))...) } - } - - return matches, nil + return matches + }) } diff --git a/cli/replace.go b/cli/replace.go index eef0767e..50eb14f6 100644 --- a/cli/replace.go +++ b/cli/replace.go @@ -114,7 +114,7 @@ func (cmd *ReplaceCommand) Run() int { return 1 } - allMatches, err := cmd.findMatches(filePaths) + allMatches, err := cmd.FindMatches(filePaths) if err != nil { log.UserError(fmt.Sprintf("%s", err)) return 1 @@ -122,25 +122,37 @@ func (cmd *ReplaceCommand) Run() int { return cmd.commitMatches(allMatches) } -func (cmd *ReplaceCommand) findMatches(filePaths []string) (matchesByPath map[string][]*Match, err error) { - matchesByPath = make(map[string][]*Match, 0) - for _, curPath := range filePaths { - matches, err := cmd.FindReplacements(cmd.args.Search, cmd.args.Replacement, curPath) - if err != nil { - return matchesByPath, err +func (cmd *ReplaceCommand) grepPaths(search string, paths []string) (matches []*Match, err error) { + return funcOnPaths(cmd.client, paths, func(s *client.Secret) []*Match { + for k, v := range s.GetData() { + matches = append(matches, cmd.searcher.DoSearch(s.Path, k, fmt.Sprintf("%v", v))...) } - for _, match := range matches { - match.print(os.Stdout, true) - } - if len(matches) > 0 { - _, ok := matchesByPath[curPath] - if ok == false { - matchesByPath[curPath] = make([]*Match, 0) - } - matchesByPath[curPath] = append(matchesByPath[curPath], matches...) + return matches + }) +} + +// FindMatches will return a map of files sorted by path in which the search occurs +func (cmd *ReplaceCommand) FindMatches(filePaths []string) (matchesByPath map[string][]*Match, err error) { + matches, err := cmd.grepPaths(cmd.args.Search, filePaths) + if err != nil { + return matchesByPath, err + } + for _, match := range matches { + match.print(os.Stdout, true) + } + return cmd.groupMatchesByPath(matches), nil +} + +func (cmd *ReplaceCommand) groupMatchesByPath(matches []*Match) (matchesByPath map[string][]*Match) { + matchesByPath = make(map[string][]*Match, 0) + for _, m := range matches { + _, ok := matchesByPath[m.path] + if ok == false { + matchesByPath[m.path] = make([]*Match, 0) } + matchesByPath[m.path] = append(matchesByPath[m.path], matches...) } - return matchesByPath, nil + return matchesByPath } func (cmd *ReplaceCommand) commitMatches(matchesByPath map[string][]*Match) int { @@ -165,34 +177,23 @@ func (cmd *ReplaceCommand) commitMatches(matchesByPath map[string][]*Match) int return 0 } -// FindReplacements will find the matches for a given search string to be replaced -func (cmd *ReplaceCommand) FindReplacements(search string, replacement string, path string) (matches []*Match, err error) { - if cmd.client.GetType(path) == client.LEAF { - secret, err := cmd.client.Read(path) - if err != nil { - return matches, err - } - - for k, v := range secret.GetData() { - match := cmd.searcher.DoSearch(path, k, fmt.Sprintf("%v", v)) - matches = append(matches, match...) - } - } - return matches, nil -} - // WriteReplacements will write replacement data back to Vault func (cmd *ReplaceCommand) WriteReplacements(groupedMatches map[string][]*Match) error { - // process matches by vault path - for path, matches := range groupedMatches { - secret, err := cmd.client.Read(path) - if err != nil { - return err - } - data := secret.GetData() + // Re-read paths because they could've gone stale + paths := make([]string, 0) + for path, _ := range groupedMatches { + paths = append(paths, path) + } + secrets, err := cmd.client.BatchRead(paths) + if err != nil { + return err + } + + for _, secret := range secrets { + data, path := secret.GetData(), secret.Path // update secret with changes. remove key w/ prior names, add renamed keys, update values. - for _, match := range matches { + for _, match := range groupedMatches[path] { if path != match.path { return fmt.Errorf("match path does not equal group path") } diff --git a/client/batch.go b/client/batch.go new file mode 100644 index 00000000..9a37ed6a --- /dev/null +++ b/client/batch.go @@ -0,0 +1,95 @@ +package client + +import ( + "fmt" +) + +// BatchOperation is a kind of operation to perform +type BatchOperation int + +// types of operations +const ( + OP_READ BatchOperation = 0 + OP_WRITE BatchOperation = 1 +) + +// how many worker threads to use for batch operations +const ( + VAULT_CONCURENCY = 5 +) + +// BatchOperation can perform reads or writes with concurrency +func (client *Client) BatchOperation(absolutePaths []string, op BatchOperation, secretsIn []*Secret) (secrets []*Secret, err error) { + read_queue := make(chan string, len(absolutePaths)) + write_queue := make(chan *Secret, len(absolutePaths)) + results := make(chan *secretOperation, len(absolutePaths)) + + // load up queue for operation + switch op { + case OP_READ: + for _, path := range absolutePaths { + read_queue <- path + } + case OP_WRITE: + for _, secret := range secretsIn { + write_queue <- secret + } + default: + return nil, fmt.Errorf("invalid batch operation") + } + + // fire off goroutines for operation + for i := 0; i < VAULT_CONCURENCY; i++ { + client.waitGroup.Add(1) + switch op { + case OP_READ: + go client.readWorker(read_queue, results) + case OP_WRITE: + go client.writeWorker(write_queue, results) + } + } + client.waitGroup.Wait() + close(results) + + // read results from the queue and return as array + for result := range results { + err = result.Error + if err != nil { + return secrets, err + } + if result.Result != nil { + secrets = append(secrets, result.Result) + } + } + return secrets, nil +} + +// readWorker fetches paths to be read from the queue until empty +func (client *Client) readWorker(queue chan string, out chan *secretOperation) { + defer client.waitGroup.Done() +readFromQueue: + for { + select { + case path := <-queue: + s, err := client.Read(path) + out <- &secretOperation{Result: s, Path: path, Error: err} + default: + break readFromQueue + } + } +} + +// writeWorker writes secrets to Vault in parallel +func (client *Client) writeWorker(queue chan *Secret, out chan *secretOperation) { + defer client.waitGroup.Done() +readFromQueue: + for { + select { + case secret := <-queue: + err := client.Write(secret.Path, secret) + out <- &secretOperation{Result: nil, Path: secret.Path, Error: err} + default: + break readFromQueue + } + } +} diff --git a/client/client.go b/client/client.go index b18b49bb..23f07681 100644 --- a/client/client.go +++ b/client/client.go @@ -6,6 +6,7 @@ import ( "path/filepath" "strconv" "strings" + "sync" "github.com/fishi0x01/vsh/log" "github.com/hashicorp/vault/api" @@ -18,6 +19,8 @@ type Client struct { Pwd string KVBackends map[string]int listCache map[string][]string + cacheMutex sync.Mutex + waitGroup sync.WaitGroup } // VaultConfig container to keep parameters for Client configuration @@ -27,6 +30,12 @@ type VaultConfig struct { StartPath string } +type secretOperation struct { + Result *Secret + Path string + Error error +} + func verifyClientPwd(client *Client) (*Client, error) { if client.Pwd == "" { client.Pwd = "/" @@ -105,11 +114,16 @@ func (client *Client) Read(absolutePath string) (secret *Secret, err error) { apiSecret, err = client.lowLevelRead(normalizedVaultPath(absolutePath)) } if apiSecret != nil { - secret = NewSecret(apiSecret) + secret = NewSecret(apiSecret, absolutePath) } return secret, err } +// BatchRead returns secrets for given paths +func (client *Client) BatchRead(absolutePaths []string) (secrets []*Secret, err error) { + return client.BatchOperation(absolutePaths, OP_READ, make([]*Secret, 0)) +} + // Write writes secret to given path, using given Client func (client *Client) Write(absolutePath string, secret *Secret) (err error) { if client.isTopLevelPath(absolutePath) { @@ -121,6 +135,12 @@ func (client *Client) Write(absolutePath string, secret *Secret) (err error) { return err } +// BatchWrite writes provided secrets to Vault +func (client *Client) BatchWrite(absolutePaths []string, secrets []*Secret) (err error) { + _, err = client.BatchOperation(absolutePaths, OP_WRITE, secrets) + return err +} + // Delete deletes secret at given absolutePath, using given client func (client *Client) Delete(absolutePath string) (err error) { if client.isTopLevelPath(absolutePath) { @@ -134,14 +154,19 @@ func (client *Client) Delete(absolutePath string) (err error) { // List elements at the given absolutePath, using the given client func (client *Client) List(absolutePath string) (result []string, err error) { + defer client.cacheMutex.Unlock() + + client.cacheMutex.Lock() if val, ok := client.listCache[absolutePath]; ok { return val, nil } + client.cacheMutex.Unlock() // reading from Vault will be relatively slow if client.isTopLevelPath(absolutePath) { result = client.listTopLevel() } else { result, err = client.listLowLevel(normalizedVaultPath(absolutePath)) } + client.cacheMutex.Lock() client.listCache[absolutePath] = result return result, err @@ -186,5 +211,7 @@ func (client *Client) SubpathsForPath(path string) (filePaths []string, err erro // ClearCache clears the list cache func (client *Client) ClearCache() { + client.cacheMutex.Lock() client.listCache = make(map[string][]string) + client.cacheMutex.Unlock() } diff --git a/client/secret.go b/client/secret.go index f2099aa7..6d8f4a5a 100644 --- a/client/secret.go +++ b/client/secret.go @@ -7,12 +7,14 @@ import ( // Secret holds vault secret and offers operations to simplify KV abstraction type Secret struct { vaultSecret *api.Secret + Path string } // NewSecret create a new Secret object -func NewSecret(vaultSecret *api.Secret) *Secret { +func NewSecret(vaultSecret *api.Secret, path string) *Secret { return &Secret{ vaultSecret: vaultSecret, + Path: path, } } diff --git a/client/type.go b/client/type.go index ccc41898..e30c69a8 100644 --- a/client/type.go +++ b/client/type.go @@ -26,6 +26,15 @@ func (client *Client) topLevelType(path string) PathKind { } } +func (client *Client) FilterPaths(paths []string, kind PathKind) (filtered []string) { + for _, path := range paths { + if client.GetType(path) == kind { + filtered = append(filtered, path) + } + } + return filtered +} + var cachedPath = "" var cachedDirFiles = make(map[string]int)