Skip to content

Commit

Permalink
CLI: Tune plugin version for auth/secret mounts (#17277)
Browse files Browse the repository at this point in the history
* Add -plugin-version flag to vault auth/secrets tune
* CLI tests for auth/secrets tune
* CLI test for plugin register
* Plugin catalog listing bug where plugins of different type with the same name could be double counted
* Use constant for -plugin-version flag name
  • Loading branch information
tomhjp committed Sep 22, 2022
1 parent 6fc6bb1 commit 21d1363
Show file tree
Hide file tree
Showing 13 changed files with 208 additions and 55 deletions.
2 changes: 1 addition & 1 deletion api/sys_mounts.go
Expand Up @@ -247,7 +247,6 @@ type MountInput struct {
SealWrap bool `json:"seal_wrap" mapstructure:"seal_wrap"`
ExternalEntropyAccess bool `json:"external_entropy_access" mapstructure:"external_entropy_access"`
Options map[string]string `json:"options"`
PluginVersion string `json:"plugin_version,omitempty"`

// Deprecated: Newer server responses should be returning this information in the
// Type field (json: "type") instead.
Expand All @@ -267,6 +266,7 @@ type MountConfigInput struct {
AllowedResponseHeaders []string `json:"allowed_response_headers,omitempty" mapstructure:"allowed_response_headers"`
TokenType string `json:"token_type,omitempty" mapstructure:"token_type"`
AllowedManagedKeys []string `json:"allowed_managed_keys,omitempty" mapstructure:"allowed_managed_keys"`
PluginVersion string `json:"plugin_version,omitempty"`

// Deprecated: This field will always be blank for newer server responses.
PluginName string `json:"plugin_name,omitempty" mapstructure:"plugin_name"`
Expand Down
7 changes: 5 additions & 2 deletions command/auth_enable.go
Expand Up @@ -201,7 +201,7 @@ func (c *AuthEnableCommand) Flags() *FlagSets {
})

f.StringVar(&StringVar{
Name: "plugin-version",
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to enable.",
Expand Down Expand Up @@ -270,7 +270,6 @@ func (c *AuthEnableCommand) Run(args []string) int {

authOpts := &api.EnableAuthOptions{
Type: authType,
PluginVersion: c.flagPluginVersion,
Description: c.flagDescription,
Local: c.flagLocal,
SealWrap: c.flagSealWrap,
Expand Down Expand Up @@ -307,6 +306,10 @@ func (c *AuthEnableCommand) Run(args []string) int {
if fl.Name == flagNameTokenType {
authOpts.Config.TokenType = c.flagTokenType
}

if fl.Name == flagNamePluginVersion {
authOpts.Config.PluginVersion = c.flagPluginVersion
}
})

if err := client.Sys().EnableAuthWithOptions(authPath, authOpts); err != nil {
Expand Down
13 changes: 13 additions & 0 deletions command/auth_tune.go
Expand Up @@ -31,6 +31,7 @@ type AuthTuneCommand struct {
flagOptions map[string]string
flagTokenType string
flagVersion int
flagPluginVersion string
}

func (c *AuthTuneCommand) Synopsis() string {
Expand Down Expand Up @@ -144,6 +145,14 @@ func (c *AuthTuneCommand) Flags() *FlagSets {
Usage: "Select the version of the auth method to run. Not supported by all auth methods.",
})

f.StringVar(&StringVar{
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to run. The new version must be registered in " +
"the plugin catalog, and will not start running until the plugin is reloaded.",
})

return set
}

Expand Down Expand Up @@ -221,6 +230,10 @@ func (c *AuthTuneCommand) Run(args []string) int {
if fl.Name == flagNameTokenType {
mountConfigInput.TokenType = c.flagTokenType
}

if fl.Name == flagNamePluginVersion {
mountConfigInput.PluginVersion = c.flagPluginVersion
}
})

// Append /auth (since that's where auths live) and a trailing slash to
Expand Down
30 changes: 27 additions & 3 deletions command/auth_tune_test.go
Expand Up @@ -6,6 +6,8 @@ import (

"github.com/go-test/deep"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)

Expand Down Expand Up @@ -74,7 +76,10 @@ func TestAuthTuneCommand_Run(t *testing.T) {
t.Run("integration", func(t *testing.T) {
t.Run("flags_all", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
pluginDir, cleanup := vault.MakeTestPluginDir(t)
defer cleanup(t)

client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()

ui, cmd := testAuthTuneCommand(t)
Expand All @@ -87,6 +92,21 @@ func TestAuthTuneCommand_Run(t *testing.T) {
t.Fatal(err)
}

auths, err := client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}
mountInfo, ok := auths["my-auth/"]
if !ok {
t.Fatalf("expected mount to exist: %#v", auths)
}

if exp := ""; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}

_, _, version := testPluginCreateAndRegisterVersioned(t, client, pluginDir, "userpass", consts.PluginTypeCredential)

code := cmd.Run([]string{
"-description", "new description",
"-default-lease-ttl", "30m",
Expand All @@ -97,6 +117,7 @@ func TestAuthTuneCommand_Run(t *testing.T) {
"-passthrough-request-headers", "www-authentication",
"-allowed-response-headers", "authorization,www-authentication",
"-listing-visibility", "unauth",
"-plugin-version", version,
"my-auth/",
})
if exp := 0; code != exp {
Expand All @@ -109,12 +130,12 @@ func TestAuthTuneCommand_Run(t *testing.T) {
t.Errorf("expected %q to contain %q", combined, expected)
}

auths, err := client.Sys().ListAuth()
auths, err = client.Sys().ListAuth()
if err != nil {
t.Fatal(err)
}

mountInfo, ok := auths["my-auth/"]
mountInfo, ok = auths["my-auth/"]
if !ok {
t.Fatalf("expected auth to exist")
}
Expand All @@ -124,6 +145,9 @@ func TestAuthTuneCommand_Run(t *testing.T) {
if exp := "userpass"; mountInfo.Type != exp {
t.Errorf("expected %q to be %q", mountInfo.Type, exp)
}
if exp := version; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}
if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp {
t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp)
}
Expand Down
2 changes: 2 additions & 0 deletions command/commands.go
Expand Up @@ -124,6 +124,8 @@ const (
flagNameTokenType = "token-type"
// flagNameAllowedManagedKeys is the flag name used for auth/secrets enable
flagNameAllowedManagedKeys = "allowed-managed-keys"
// flagNamePluginVersion selects what version of a plugin should be used.
flagNamePluginVersion = "plugin-version"
)

var (
Expand Down
71 changes: 71 additions & 0 deletions command/plugin_register_test.go
@@ -1,6 +1,8 @@
package command

import (
"reflect"
"sort"
"strings"
"testing"

Expand Down Expand Up @@ -124,6 +126,75 @@ func TestPluginRegisterCommand_Run(t *testing.T) {
}
})

t.Run("integration with version", func(t *testing.T) {
t.Parallel()

pluginDir, cleanup := vault.MakeTestPluginDir(t)
defer cleanup(t)

client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()

const pluginName = "my-plugin"
versions := []string{"v1.0.0", "v2.0.1"}
_, sha256Sum := testPluginCreate(t, pluginDir, pluginName)
types := []consts.PluginType{consts.PluginTypeCredential, consts.PluginTypeDatabase, consts.PluginTypeSecrets}

for _, typ := range types {
for _, version := range versions {
ui, cmd := testPluginRegisterCommand(t)
cmd.client = client

code := cmd.Run([]string{
"-version=" + version,
"-sha256=" + sha256Sum,
typ.String(),
pluginName,
})
if exp := 0; code != exp {
t.Errorf("expected %d to be %d", code, exp)
}

expected := "Success! Registered plugin: my-plugin"
combined := ui.OutputWriter.String() + ui.ErrorWriter.String()
if !strings.Contains(combined, expected) {
t.Errorf("expected %q to contain %q", combined, expected)
}
}
}

resp, err := client.Sys().ListPlugins(&api.ListPluginsInput{
Type: consts.PluginTypeUnknown,
})
if err != nil {
t.Fatal(err)
}

found := make(map[consts.PluginType]int)
versionsFound := make(map[consts.PluginType][]string)
for _, p := range resp.Details {
if p.Name == pluginName {
typ, err := consts.ParsePluginType(p.Type)
if err != nil {
t.Fatal(err)
}
found[typ]++
versionsFound[typ] = append(versionsFound[typ], p.Version)
}
}

for _, typ := range types {
if found[typ] != 2 {
t.Fatalf("expected %q to be found 2 times, but found it %d times for %s type in %#v", pluginName, found[typ], typ.String(), resp.Details)
}
sort.Strings(versions)
sort.Strings(versionsFound[typ])
if !reflect.DeepEqual(versions, versionsFound[typ]) {
t.Fatalf("expected %v versions but got %v", versions, versionsFound[typ])
}
}
})

t.Run("communication_failure", func(t *testing.T) {
t.Parallel()

Expand Down
12 changes: 12 additions & 0 deletions command/secrets_enable.go
Expand Up @@ -32,6 +32,7 @@ type SecretsEnableCommand struct {
flagAllowedResponseHeaders []string
flagForceNoCache bool
flagPluginName string
flagPluginVersion string
flagOptions map[string]string
flagLocal bool
flagSealWrap bool
Expand Down Expand Up @@ -173,6 +174,13 @@ func (c *SecretsEnableCommand) Flags() *FlagSets {
"exist in Vault's plugin catalog.",
})

f.StringVar(&StringVar{
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to enable.",
})

f.StringMapVar(&StringMapVar{
Name: "options",
Target: &c.flagOptions,
Expand Down Expand Up @@ -320,6 +328,10 @@ func (c *SecretsEnableCommand) Run(args []string) int {
if fl.Name == flagNameAllowedManagedKeys {
mountInput.Config.AllowedManagedKeys = c.flagAllowedManagedKeys
}

if fl.Name == flagNamePluginVersion {
mountInput.Config.PluginVersion = c.flagPluginVersion
}
})

if err := client.Sys().Mount(mountPath, mountInput); err != nil {
Expand Down
13 changes: 13 additions & 0 deletions command/secrets_tune.go
Expand Up @@ -30,6 +30,7 @@ type SecretsTuneCommand struct {
flagAllowedResponseHeaders []string
flagOptions map[string]string
flagVersion int
flagPluginVersion string
flagAllowedManagedKeys []string
}

Expand Down Expand Up @@ -146,6 +147,14 @@ func (c *SecretsTuneCommand) Flags() *FlagSets {
"each time with 1 key.",
})

f.StringVar(&StringVar{
Name: flagNamePluginVersion,
Target: &c.flagPluginVersion,
Default: "",
Usage: "Select the semantic version of the plugin to run. The new version must be registered in " +
"the plugin catalog, and will not start running until the plugin is reloaded.",
})

return set
}

Expand Down Expand Up @@ -226,6 +235,10 @@ func (c *SecretsTuneCommand) Run(args []string) int {
if fl.Name == flagNameAllowedManagedKeys {
mountConfigInput.AllowedManagedKeys = c.flagAllowedManagedKeys
}

if fl.Name == flagNamePluginVersion {
mountConfigInput.PluginVersion = c.flagPluginVersion
}
})

if err := client.Sys().TuneMount(mountPath, mountConfigInput); err != nil {
Expand Down
30 changes: 27 additions & 3 deletions command/secrets_tune_test.go
Expand Up @@ -6,6 +6,8 @@ import (

"github.com/go-test/deep"
"github.com/hashicorp/vault/api"
"github.com/hashicorp/vault/sdk/helper/consts"
"github.com/hashicorp/vault/vault"
"github.com/mitchellh/cli"
)

Expand Down Expand Up @@ -148,7 +150,10 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
t.Run("integration", func(t *testing.T) {
t.Run("flags_all", func(t *testing.T) {
t.Parallel()
client, closer := testVaultServer(t)
pluginDir, cleanup := vault.MakeTestPluginDir(t)
defer cleanup(t)

client, _, closer := testVaultServerPluginDir(t, pluginDir)
defer closer()

ui, cmd := testSecretsTuneCommand(t)
Expand All @@ -161,6 +166,21 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
t.Fatal(err)
}

mounts, err := client.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}
mountInfo, ok := mounts["mount_tune_integration/"]
if !ok {
t.Fatalf("expected mount to exist")
}

if exp := ""; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}

_, _, version := testPluginCreateAndRegisterVersioned(t, client, pluginDir, "pki", consts.PluginTypeSecrets)

code := cmd.Run([]string{
"-description", "new description",
"-default-lease-ttl", "30m",
Expand All @@ -172,6 +192,7 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
"-allowed-response-headers", "authorization,www-authentication",
"-allowed-managed-keys", "key1,key2",
"-listing-visibility", "unauth",
"-plugin-version", version,
"mount_tune_integration/",
})
if exp := 0; code != exp {
Expand All @@ -184,12 +205,12 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
t.Errorf("expected %q to contain %q", combined, expected)
}

mounts, err := client.Sys().ListMounts()
mounts, err = client.Sys().ListMounts()
if err != nil {
t.Fatal(err)
}

mountInfo, ok := mounts["mount_tune_integration/"]
mountInfo, ok = mounts["mount_tune_integration/"]
if !ok {
t.Fatalf("expected mount to exist")
}
Expand All @@ -199,6 +220,9 @@ func TestSecretsTuneCommand_Run(t *testing.T) {
if exp := "pki"; mountInfo.Type != exp {
t.Errorf("expected %q to be %q", mountInfo.Type, exp)
}
if exp := version; mountInfo.PluginVersion != exp {
t.Errorf("expected %q to be %q", mountInfo.PluginVersion, exp)
}
if exp := 1800; mountInfo.Config.DefaultLeaseTTL != exp {
t.Errorf("expected %d to be %d", mountInfo.Config.DefaultLeaseTTL, exp)
}
Expand Down

0 comments on commit 21d1363

Please sign in to comment.