Skip to content

Commit

Permalink
support workload identity federation (#151)
Browse files Browse the repository at this point in the history
Co-authored-by: Vinay Gopalan <vinay@hashicorp.com>
  • Loading branch information
fairclothjm and vinay-gopalan committed May 8, 2024
1 parent a24ca1e commit 52ae174
Show file tree
Hide file tree
Showing 13 changed files with 395 additions and 158 deletions.
2 changes: 1 addition & 1 deletion .go-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.20.7
1.22.2
56 changes: 55 additions & 1 deletion azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault-plugin-auth-azure/client"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/helper/useragent"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/oauth2"

"github.com/hashicorp/vault-plugin-auth-azure/client"
)

// https://learn.microsoft.com/en-us/graph/sdks/national-clouds
Expand All @@ -54,6 +58,8 @@ type azureProvider struct {
oidcVerifier *oidc.IDTokenVerifier
settings *azureSettings
httpClient *http.Client
logger hclog.Logger
systemView logical.SystemView
}

type oidcDiscoveryInfo struct {
Expand Down Expand Up @@ -136,6 +142,8 @@ func (b *azureAuthBackend) newAzureProvider(ctx context.Context, config *azureCo
settings: settings,
oidcVerifier: oidcVerifier,
httpClient: httpClient,
logger: b.Logger(),
systemView: b.System(),
}, nil
}

Expand Down Expand Up @@ -266,6 +274,24 @@ func (p *azureProvider) getTokenCredential() (azcore.TokenCredential, error) {
return cred, nil
}

if p.settings.IdentityTokenAudience != "" {
options := &azidentity.ClientAssertionCredentialOptions{
ClientOptions: clientCloudOpts,
}
getAssertion := getAssertionFunc(p.logger, p.systemView, p.settings)
cred, err := azidentity.NewClientAssertionCredential(
p.settings.TenantID,
p.settings.ClientID,
getAssertion,
options,
)
if err != nil {
return nil, fmt.Errorf("failed to create client assertion credential: %w", err)
}

return cred, nil
}

// Fall back to using managed service identity
options := &azidentity.ManagedIdentityCredentialOptions{
ClientOptions: clientCloudOpts,
Expand All @@ -278,7 +304,32 @@ func (p *azureProvider) getTokenCredential() (azcore.TokenCredential, error) {
return cred, nil
}

type getAssertion func(context.Context) (string, error)

func getAssertionFunc(logger hclog.Logger, sys logical.SystemView, s *azureSettings) getAssertion {
return func(ctx context.Context) (string, error) {
req := &pluginutil.IdentityTokenRequest{
Audience: s.IdentityTokenAudience,
TTL: s.IdentityTokenTTL * time.Second,
}
resp, err := sys.GenerateIdentityToken(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to generate plugin identity token: %w", err)
}
logger.Info("fetched new plugin identity token")

if resp.TTL < req.TTL {
logger.Debug("generated plugin identity token has shorter TTL than requested",
"requested", req.TTL, "actual", resp.TTL)
}

return resp.Token.Token(), nil
}
}

type azureSettings struct {
pluginidentityutil.PluginIdentityTokenParams

TenantID string
ClientID string
ClientSecret string
Expand Down Expand Up @@ -330,6 +381,9 @@ func (b *azureAuthBackend) getAzureSettings(ctx context.Context, config *azureCo
}
settings.ClientSecret = clientSecret

settings.IdentityTokenAudience = config.IdentityTokenAudience
settings.IdentityTokenTTL = config.IdentityTokenTTL

environment := os.Getenv("AZURE_ENVIRONMENT")
if environment == "" {
// set environment from config
Expand Down
35 changes: 18 additions & 17 deletions azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"

"github.com/hashicorp/vault-plugin-auth-azure/client"
)

Expand Down Expand Up @@ -45,41 +46,41 @@ func newMockVerifier() client.TokenVerifier {
}

type mockComputeClient struct {
computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error)
computeClientFunc computeClientFunc
}

type mockVMSSClient struct {
vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)
vmssClientFunc vmssClientFunc
}

type mockMSIClient struct {
msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)
msiListFunc func(resourceGroup string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse
msiClientFunc msiClientFunc
msiListFunc msiListFunc
}

type mockResourceClient struct {
resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)
resourceClientFunc resourceClientFunc
}

type mockProvidersClient struct {
providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error)
providersClientFunc providersClientFunc
}

func (c *mockComputeClient) Get(_ context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
func (c *mockComputeClient) Get(ctx context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
if c.computeClientFunc != nil {
return c.computeClientFunc(vmName)
}
return armcompute.VirtualMachinesClientGetResponse{}, nil
}

func (c *mockVMSSClient) Get(_ context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
func (c *mockVMSSClient) Get(ctx context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
if c.vmssClientFunc != nil {
return c.vmssClientFunc(vmssName)
}
return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil
}

func (c *mockMSIClient) Get(_ context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
func (c *mockMSIClient) Get(ctx context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
if c.msiClientFunc != nil {
return c.msiClientFunc(resourceName)
}
Expand All @@ -101,14 +102,14 @@ func (c *mockMSIClient) NewListByResourceGroupPager(resourceGroup string, _ *arm
return nil
}

func (c *mockResourceClient) GetByID(_ context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
func (c *mockResourceClient) GetByID(ctx context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
if c.resourceClientFunc != nil {
return c.resourceClientFunc(resourceID)
}
return armresources.ClientGetByIDResponse{}, nil
}

func (c *mockProvidersClient) Get(_ context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
func (c *mockProvidersClient) Get(ctx context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
if c.providersClientFunc != nil {
return c.providersClientFunc(resourceID)
}
Expand All @@ -127,7 +128,7 @@ type msGraphClientFunc func() (client.MSGraphClient, error)

type resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)

type providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error)
type providersClientFunc func(s string) (armresources.ProvidersClientGetResponse, error)

type mockProvider struct {
computeClientFunc
Expand All @@ -153,19 +154,19 @@ func (*mockProvider) TokenVerifier() client.TokenVerifier {
return newMockVerifier()
}

func (p *mockProvider) ComputeClient(string) (client.ComputeClient, error) {
func (p *mockProvider) ComputeClient(subscriptionID string) (client.ComputeClient, error) {
return &mockComputeClient{
computeClientFunc: p.computeClientFunc,
}, nil
}

func (p *mockProvider) VMSSClient(string) (client.VMSSClient, error) {
func (p *mockProvider) VMSSClient(subscriptionID string) (client.VMSSClient, error) {
return &mockVMSSClient{
vmssClientFunc: p.vmssClientFunc,
}, nil
}

func (p *mockProvider) MSIClient(string) (client.MSIClient, error) {
func (p *mockProvider) MSIClient(subscriptionID string) (client.MSIClient, error) {
return &mockMSIClient{
msiClientFunc: p.msiClientFunc,
msiListFunc: p.msiListFunc,
Expand All @@ -176,13 +177,13 @@ func (p *mockProvider) MSGraphClient() (client.MSGraphClient, error) {
return nil, nil
}

func (p *mockProvider) ResourceClient(string) (client.ResourceClient, error) {
func (p *mockProvider) ResourceClient(subscriptionID string) (client.ResourceClient, error) {
return &mockResourceClient{
resourceClientFunc: p.resourceClientFunc,
}, nil
}

func (p *mockProvider) ProvidersClient(string) (client.ProvidersClient, error) {
func (p *mockProvider) ProvidersClient(subscriptionID string) (client.ProvidersClient, error) {
return &mockProvidersClient{
providersClientFunc: p.providersClientFunc,
}, nil
Expand Down
6 changes: 3 additions & 3 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func backend() *azureAuthBackend {
// operations. Due to Azure's eventual consistency model, the new credential will not be
// available immediately, and hence we check periodically and delete the old credential
// only once the new credential is at least a minute old
func (b *azureAuthBackend) periodicFunc(ctx context.Context, sys *logical.Request) error {
func (b *azureAuthBackend) periodicFunc(ctx context.Context, req *logical.Request) error {
// Root rotation through the periodic func writes to storage. Only run this on the
// active instance in the primary cluster or local mounts. The periodic func doesn't
// run on perf standbys or DR secondaries, but we still protect against this here.
Expand All @@ -106,7 +106,7 @@ func (b *azureAuthBackend) periodicFunc(ctx context.Context, sys *logical.Reques
return nil
}

config, err := b.config(ctx, sys.Storage)
config, err := b.config(ctx, req.Storage)
if err != nil {
return err
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (b *azureAuthBackend) periodicFunc(ctx context.Context, sys *logical.Reques
config.NewClientSecretKeyID = ""
config.NewClientSecretCreated = time.Time{}

err = b.saveConfig(ctx, config, sys.Storage)
err = b.saveConfig(ctx, config, req.Storage)
if err != nil {
return err
}
Expand Down
36 changes: 22 additions & 14 deletions backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,32 @@ import (
"time"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)

type testSystemViewEnt struct {
logical.StaticSystemView
}

func (d testSystemViewEnt) GenerateIdentityToken(_ context.Context, _ *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) {
return &pluginutil.IdentityTokenResponse{}, nil
}

func getTestBackend(t *testing.T) (*azureAuthBackend, logical.Storage) {
return getTestBackendWithComputeClient(t, nil, nil, nil, nil, nil)
}

func getTestBackendWithResourceClient(t *testing.T, r resourceClientFunc, p providersClientFunc) (*azureAuthBackend, logical.Storage) {
t.Helper()
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24

sysView := testSystemViewEnt{}
sysView.DefaultLeaseTTLVal = time.Hour * 12
sysView.MaxLeaseTTLVal = time.Hour * 24

config := &logical.BackendConfig{
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &sysView,
StorageView: &logical.InmemStorage{},
}
b := backend()
Expand All @@ -43,14 +52,13 @@ func getTestBackendWithResourceClient(t *testing.T, r resourceClientFunc, p prov

func getTestBackendWithComputeClient(t *testing.T, c computeClientFunc, v vmssClientFunc, m msiClientFunc, ml msiListFunc, g msGraphClientFunc) (*azureAuthBackend, logical.Storage) {
t.Helper()
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24
sysView := testSystemViewEnt{}
sysView.DefaultLeaseTTLVal = time.Hour * 12
sysView.MaxLeaseTTLVal = time.Hour * 24

config := &logical.BackendConfig{
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &sysView,
StorageView: &logical.InmemStorage{},
}
b := backend()
Expand Down
2 changes: 1 addition & 1 deletion bootstrap/terraform/vm.tf
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ resource "azurerm_public_ip" "vault_azure_pub_ip" {

# Restrict SSH access to the local public IP
data "http" "my_ip" {
url = "https://ifconfig.me/ip"
url = "https://ipinfo.io/ip"
}

resource "azurerm_network_security_group" "vault_azure_sg" {
Expand Down
Loading

0 comments on commit 52ae174

Please sign in to comment.