Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support workload identity federation #151

Merged
merged 11 commits into from
May 8, 2024
2 changes: 1 addition & 1 deletion .go-version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
1.20.7
1.22.2
56 changes: 55 additions & 1 deletion azure.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,14 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"
"github.com/hashicorp/go-cleanhttp"
"github.com/hashicorp/vault-plugin-auth-azure/client"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/pluginidentityutil"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/helper/useragent"
"github.com/hashicorp/vault/sdk/logical"
"golang.org/x/oauth2"

"github.com/hashicorp/vault-plugin-auth-azure/client"
)

// https://learn.microsoft.com/en-us/graph/sdks/national-clouds
Expand All @@ -54,6 +58,8 @@ type azureProvider struct {
oidcVerifier *oidc.IDTokenVerifier
settings *azureSettings
httpClient *http.Client
logger hclog.Logger
systemView logical.SystemView
}

type oidcDiscoveryInfo struct {
Expand Down Expand Up @@ -136,6 +142,8 @@ func (b *azureAuthBackend) newAzureProvider(ctx context.Context, config *azureCo
settings: settings,
oidcVerifier: oidcVerifier,
httpClient: httpClient,
logger: b.Logger(),
systemView: b.System(),
}, nil
}

Expand Down Expand Up @@ -266,6 +274,24 @@ func (p *azureProvider) getTokenCredential() (azcore.TokenCredential, error) {
return cred, nil
}

if p.settings.IdentityTokenAudience != "" {
options := &azidentity.ClientAssertionCredentialOptions{
ClientOptions: clientCloudOpts,
}
getAssertion := getAssertionFunc(p.logger, p.systemView, p.settings)
cred, err := azidentity.NewClientAssertionCredential(
p.settings.TenantID,
p.settings.ClientID,
getAssertion,
options,
)
if err != nil {
return nil, fmt.Errorf("failed to create client assertion credential: %w", err)
}

return cred, nil
}

// Fall back to using managed service identity
options := &azidentity.ManagedIdentityCredentialOptions{
ClientOptions: clientCloudOpts,
Expand All @@ -278,7 +304,32 @@ func (p *azureProvider) getTokenCredential() (azcore.TokenCredential, error) {
return cred, nil
}

type getAssertion func(context.Context) (string, error)

func getAssertionFunc(logger hclog.Logger, sys logical.SystemView, s *azureSettings) getAssertion {
return func(ctx context.Context) (string, error) {
req := &pluginutil.IdentityTokenRequest{
Audience: s.IdentityTokenAudience,
TTL: s.IdentityTokenTTL * time.Second,
}
resp, err := sys.GenerateIdentityToken(ctx, req)
if err != nil {
return "", fmt.Errorf("failed to generate plugin identity token: %w", err)
}
logger.Info("fetched new plugin identity token")

if resp.TTL < req.TTL {
logger.Debug("generated plugin identity token has shorter TTL than requested",
"requested", req.TTL, "actual", resp.TTL)
}

return resp.Token.Token(), nil
}
}

type azureSettings struct {
pluginidentityutil.PluginIdentityTokenParams

TenantID string
ClientID string
ClientSecret string
Expand Down Expand Up @@ -330,6 +381,9 @@ func (b *azureAuthBackend) getAzureSettings(ctx context.Context, config *azureCo
}
settings.ClientSecret = clientSecret

settings.IdentityTokenAudience = config.IdentityTokenAudience
settings.IdentityTokenTTL = config.IdentityTokenTTL

vinay-gopalan marked this conversation as resolved.
Show resolved Hide resolved
environment := os.Getenv("AZURE_ENVIRONMENT")
if environment == "" {
// set environment from config
Expand Down
35 changes: 18 additions & 17 deletions azure_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/msi/armmsi"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/resources/armresources"
"github.com/coreos/go-oidc"

"github.com/hashicorp/vault-plugin-auth-azure/client"
)

Expand Down Expand Up @@ -45,41 +46,41 @@ func newMockVerifier() client.TokenVerifier {
}

type mockComputeClient struct {
computeClientFunc func(vmName string) (armcompute.VirtualMachinesClientGetResponse, error)
computeClientFunc computeClientFunc
}

type mockVMSSClient struct {
vmssClientFunc func(vmssName string) (armcompute.VirtualMachineScaleSetsClientGetResponse, error)
vmssClientFunc vmssClientFunc
}

type mockMSIClient struct {
msiClientFunc func(resourceName string) (armmsi.UserAssignedIdentitiesClientGetResponse, error)
msiListFunc func(resourceGroup string) armmsi.UserAssignedIdentitiesClientListByResourceGroupResponse
msiClientFunc msiClientFunc
msiListFunc msiListFunc
}

type mockResourceClient struct {
resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)
resourceClientFunc resourceClientFunc
}

type mockProvidersClient struct {
providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error)
providersClientFunc providersClientFunc
}

func (c *mockComputeClient) Get(_ context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
func (c *mockComputeClient) Get(ctx context.Context, _, vmName string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
if c.computeClientFunc != nil {
return c.computeClientFunc(vmName)
}
return armcompute.VirtualMachinesClientGetResponse{}, nil
}

func (c *mockVMSSClient) Get(_ context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
func (c *mockVMSSClient) Get(ctx context.Context, _, vmssName string, _ *armcompute.VirtualMachineScaleSetsClientGetOptions) (armcompute.VirtualMachineScaleSetsClientGetResponse, error) {
if c.vmssClientFunc != nil {
return c.vmssClientFunc(vmssName)
}
return armcompute.VirtualMachineScaleSetsClientGetResponse{}, nil
}

func (c *mockMSIClient) Get(_ context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
func (c *mockMSIClient) Get(ctx context.Context, _, resourceName string, _ *armmsi.UserAssignedIdentitiesClientGetOptions) (armmsi.UserAssignedIdentitiesClientGetResponse, error) {
if c.msiClientFunc != nil {
return c.msiClientFunc(resourceName)
}
Expand All @@ -101,14 +102,14 @@ func (c *mockMSIClient) NewListByResourceGroupPager(resourceGroup string, _ *arm
return nil
}

func (c *mockResourceClient) GetByID(_ context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
func (c *mockResourceClient) GetByID(ctx context.Context, resourceID, _ string, _ *armresources.ClientGetByIDOptions) (armresources.ClientGetByIDResponse, error) {
if c.resourceClientFunc != nil {
return c.resourceClientFunc(resourceID)
}
return armresources.ClientGetByIDResponse{}, nil
}

func (c *mockProvidersClient) Get(_ context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
func (c *mockProvidersClient) Get(ctx context.Context, resourceID string, _ *armresources.ProvidersClientGetOptions) (armresources.ProvidersClientGetResponse, error) {
if c.providersClientFunc != nil {
return c.providersClientFunc(resourceID)
}
Expand All @@ -127,7 +128,7 @@ type msGraphClientFunc func() (client.MSGraphClient, error)

type resourceClientFunc func(resourceID string) (armresources.ClientGetByIDResponse, error)

type providersClientFunc func(string) (armresources.ProvidersClientGetResponse, error)
type providersClientFunc func(s string) (armresources.ProvidersClientGetResponse, error)

type mockProvider struct {
computeClientFunc
Expand All @@ -153,19 +154,19 @@ func (*mockProvider) TokenVerifier() client.TokenVerifier {
return newMockVerifier()
}

func (p *mockProvider) ComputeClient(string) (client.ComputeClient, error) {
func (p *mockProvider) ComputeClient(subscriptionID string) (client.ComputeClient, error) {
return &mockComputeClient{
computeClientFunc: p.computeClientFunc,
}, nil
}

func (p *mockProvider) VMSSClient(string) (client.VMSSClient, error) {
func (p *mockProvider) VMSSClient(subscriptionID string) (client.VMSSClient, error) {
return &mockVMSSClient{
vmssClientFunc: p.vmssClientFunc,
}, nil
}

func (p *mockProvider) MSIClient(string) (client.MSIClient, error) {
func (p *mockProvider) MSIClient(subscriptionID string) (client.MSIClient, error) {
return &mockMSIClient{
msiClientFunc: p.msiClientFunc,
msiListFunc: p.msiListFunc,
Expand All @@ -176,13 +177,13 @@ func (p *mockProvider) MSGraphClient() (client.MSGraphClient, error) {
return nil, nil
}

func (p *mockProvider) ResourceClient(string) (client.ResourceClient, error) {
func (p *mockProvider) ResourceClient(subscriptionID string) (client.ResourceClient, error) {
return &mockResourceClient{
resourceClientFunc: p.resourceClientFunc,
}, nil
}

func (p *mockProvider) ProvidersClient(string) (client.ProvidersClient, error) {
func (p *mockProvider) ProvidersClient(subscriptionID string) (client.ProvidersClient, error) {
return &mockProvidersClient{
providersClientFunc: p.providersClientFunc,
}, nil
Expand Down
6 changes: 3 additions & 3 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ func backend() *azureAuthBackend {
// operations. Due to Azure's eventual consistency model, the new credential will not be
// available immediately, and hence we check periodically and delete the old credential
// only once the new credential is at least a minute old
func (b *azureAuthBackend) periodicFunc(ctx context.Context, sys *logical.Request) error {
func (b *azureAuthBackend) periodicFunc(ctx context.Context, req *logical.Request) error {
// Root rotation through the periodic func writes to storage. Only run this on the
// active instance in the primary cluster or local mounts. The periodic func doesn't
// run on perf standbys or DR secondaries, but we still protect against this here.
Expand All @@ -106,7 +106,7 @@ func (b *azureAuthBackend) periodicFunc(ctx context.Context, sys *logical.Reques
return nil
}

config, err := b.config(ctx, sys.Storage)
config, err := b.config(ctx, req.Storage)
if err != nil {
return err
}
Expand Down Expand Up @@ -161,7 +161,7 @@ func (b *azureAuthBackend) periodicFunc(ctx context.Context, sys *logical.Reques
config.NewClientSecretKeyID = ""
config.NewClientSecretCreated = time.Time{}

err = b.saveConfig(ctx, config, sys.Storage)
err = b.saveConfig(ctx, config, req.Storage)
if err != nil {
return err
}
Expand Down
36 changes: 22 additions & 14 deletions backend_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,32 @@ import (
"time"

log "github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/helper/pluginutil"
"github.com/hashicorp/vault/sdk/logical"
)

type testSystemViewEnt struct {
logical.StaticSystemView
}

func (d testSystemViewEnt) GenerateIdentityToken(_ context.Context, _ *pluginutil.IdentityTokenRequest) (*pluginutil.IdentityTokenResponse, error) {
return &pluginutil.IdentityTokenResponse{}, nil
}

func getTestBackend(t *testing.T) (*azureAuthBackend, logical.Storage) {
return getTestBackendWithComputeClient(t, nil, nil, nil, nil, nil)
}

func getTestBackendWithResourceClient(t *testing.T, r resourceClientFunc, p providersClientFunc) (*azureAuthBackend, logical.Storage) {
t.Helper()
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24

sysView := testSystemViewEnt{}
sysView.DefaultLeaseTTLVal = time.Hour * 12
sysView.MaxLeaseTTLVal = time.Hour * 24

config := &logical.BackendConfig{
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &sysView,
StorageView: &logical.InmemStorage{},
}
b := backend()
Expand All @@ -43,14 +52,13 @@ func getTestBackendWithResourceClient(t *testing.T, r resourceClientFunc, p prov

func getTestBackendWithComputeClient(t *testing.T, c computeClientFunc, v vmssClientFunc, m msiClientFunc, ml msiListFunc, g msGraphClientFunc) (*azureAuthBackend, logical.Storage) {
t.Helper()
defaultLeaseTTLVal := time.Hour * 12
maxLeaseTTLVal := time.Hour * 24
sysView := testSystemViewEnt{}
sysView.DefaultLeaseTTLVal = time.Hour * 12
sysView.MaxLeaseTTLVal = time.Hour * 24

config := &logical.BackendConfig{
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &logical.StaticSystemView{
DefaultLeaseTTLVal: defaultLeaseTTLVal,
MaxLeaseTTLVal: maxLeaseTTLVal,
},
Logger: log.New(&log.LoggerOptions{Level: log.Trace}),
System: &sysView,
StorageView: &logical.InmemStorage{},
}
b := backend()
Expand Down
2 changes: 1 addition & 1 deletion bootstrap/terraform/vm.tf
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ resource "azurerm_public_ip" "vault_azure_pub_ip" {

# Restrict SSH access to the local public IP
data "http" "my_ip" {
url = "https://ifconfig.me/ip"
url = "https://ipinfo.io/ip"
thyton marked this conversation as resolved.
Show resolved Hide resolved
}

resource "azurerm_network_security_group" "vault_azure_sg" {
Expand Down