Skip to content

Commit

Permalink
Fix Azure join for identities across resource groups (#28961)
Browse files Browse the repository at this point in the history
This change fixes a bug in the Azure join method where a VM's identity can't be
verified if it's in a different resource group from its managed identity.
  • Loading branch information
atburke committed Jul 12, 2023
1 parent a2a7c50 commit 64c17cf
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 20 deletions.
2 changes: 1 addition & 1 deletion lib/auth/join_azure.go
Expand Up @@ -268,7 +268,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
// If the token is from a user-assigned managed identity, the resource ID is
// for the identity and we need to look the VM up by VM ID.
} else {
vm, err = vmClient.GetByVMID(ctx, resourceID.ResourceGroupName, vmID)
vm, err = vmClient.GetByVMID(ctx, types.Wildcard, vmID)
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.AccessDenied("no VM found with matching VM ID")
Expand Down
19 changes: 19 additions & 0 deletions lib/cloud/azure/mocks.go
Expand Up @@ -508,6 +508,25 @@ func (m *ARMComputeMock) NewListPager(resourceGroup string, _ *armcompute.Virtua
})
}

func (m *ARMComputeMock) NewListAllPager(_ *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse] {
var vms []*armcompute.VirtualMachine
for _, resourceGroupVMs := range m.VirtualMachines {
vms = append(vms, resourceGroupVMs...)
}
return runtime.NewPager(runtime.PagingHandler[armcompute.VirtualMachinesClientListAllResponse]{
More: func(page armcompute.VirtualMachinesClientListAllResponse) bool {
return page.NextLink != nil && len(*page.NextLink) > 0
},
Fetcher: func(ctx context.Context, page *armcompute.VirtualMachinesClientListAllResponse) (armcompute.VirtualMachinesClientListAllResponse, error) {
return armcompute.VirtualMachinesClientListAllResponse{
VirtualMachineListResult: armcompute.VirtualMachineListResult{
Value: vms,
},
}, nil
},
})
}

func (m *ARMComputeMock) Get(_ context.Context, _ string, _ string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
return armcompute.VirtualMachinesClientGetResponse{
VirtualMachine: m.GetResult,
Expand Down
73 changes: 54 additions & 19 deletions lib/cloud/azure/vm.go
Expand Up @@ -23,27 +23,31 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
)

// armCompute provides an interface for an Azure Virtual Machine client.
// armCompute provides an interface for an Azure virtual machine client.
type armCompute interface {
// Get retrieves information about an Azure Virtual Machine.
// Get retrieves information about an Azure virtual machine.
Get(ctx context.Context, resourceGroupName string, vmName string, options *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error)
// NewListPagers lists Azure Virtual Machines.
// NewListPagers lists Azure virtual Machines.
NewListPager(resourceGroup string, opts *armcompute.VirtualMachinesClientListOptions) *runtime.Pager[armcompute.VirtualMachinesClientListResponse]
// NewListAllPager lists Azure virtual machines in any resource group.
NewListAllPager(opts *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]
}

// VirtualMachinesClient is a client for Azure Virtual Machines.
// VirtualMachinesClient is a client for Azure virtual machines.
type VirtualMachinesClient interface {
// Get returns the Virtual Machine for the given resource ID.
// Get returns the virtual machine for the given resource ID.
Get(ctx context.Context, resourceID string) (*VirtualMachine, error)
// GetByVMID returns the Virtual Machine for a given VM ID.
// GetByVMID returns the virtual machine for a given VM ID.
GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error)
// ListVirtualMachines gets all of the virtual machines in the given resource group.
ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error)
}

// VirtualMachine represents an Azure Virtual Machine.
// VirtualMachine represents an Azure virtual machine.
type VirtualMachine struct {
// ID resource ID.
ID string `json:"id,omitempty"`
Expand All @@ -59,18 +63,18 @@ type VirtualMachine struct {
Identities []Identity
}

// Identitiy represents an Azure Virtual Machine identity.
// Identitiy represents an Azure virtual machine identity.
type Identity struct {
// ResourceID the identity resource ID.
ResourceID string
}

type vmClient struct {
// api is the Azure Virtual Machine client.
// api is the Azure virtual machine client.
api armCompute
}

// NewVirtualMachinesClient creates a new Azure Virtual Machines client by
// NewVirtualMachinesClient creates a new Azure virtual machines client by
// subscription and credentials.
func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (VirtualMachinesClient, error) {
computeAPI, err := armcompute.NewVirtualMachinesClient(subscription, cred, options)
Expand All @@ -81,7 +85,7 @@ func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential,
return NewVirtualMachinesClientByAPI(computeAPI), nil
}

// NewVirtualMachinesClientByAPI creates a new Azure Virtual Machines client by
// NewVirtualMachinesClientByAPI creates a new Azure virtual machines client by
// ARM API client.
func NewVirtualMachinesClientByAPI(api armCompute) VirtualMachinesClient {
return &vmClient{
Expand Down Expand Up @@ -121,7 +125,7 @@ func parseVirtualMachine(vm *armcompute.VirtualMachine) (*VirtualMachine, error)
}, nil
}

// Get returns the Virtual Machine for the given resource ID.
// Get returns the virtual machine for the given resource ID.
func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, error) {
parsedResourceID, err := arm.ParseResourceID(resourceID)
if err != nil {
Expand All @@ -137,7 +141,7 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine,
return vm, trace.Wrap(err)
}

// GetByVMID returns the Virtual Machine for a given VM ID.
// GetByVMID returns the virtual machine for a given VM ID.
func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) {
vms, err := c.ListVirtualMachines(ctx, resourceGroup)
if err != nil {
Expand All @@ -152,17 +156,48 @@ func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*
return nil, trace.NotFound("no VM with ID %q", vmID)
}

// ListVirtualMachines lists all virtual machines in a given resource group using the Azure Virtual Machines API.
type vmPager struct {
more func() bool
nextPage func(context.Context) ([]*armcompute.VirtualMachine, error)
}

func newListPager(azurePager *runtime.Pager[armcompute.VirtualMachinesClientListResponse]) vmPager {
return vmPager{
more: azurePager.More,
nextPage: func(ctx context.Context) ([]*armcompute.VirtualMachine, error) {
res, err := azurePager.NextPage(ctx)
return res.Value, trace.Wrap(err)
},
}
}

func newListAllPager(azurePager *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]) vmPager {
return vmPager{
more: azurePager.More,
nextPage: func(ctx context.Context) ([]*armcompute.VirtualMachine, error) {
res, err := azurePager.NextPage(ctx)
return res.Value, trace.Wrap(err)
},
}
}

// ListVirtualMachines lists all virtual machines in a given resource group
// using the Azure virtual machines API. If resourceGroup is "*", it lists
// all virtual machines in any resource group.
func (c *vmClient) ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) {
pagerOpts := &armcompute.VirtualMachinesClientListOptions{}
pager := c.api.NewListPager(resourceGroup, pagerOpts)
var pager vmPager
if resourceGroup == types.Wildcard {
pager = newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{}))
} else {
pager = newListPager(c.api.NewListPager(resourceGroup, &armcompute.VirtualMachinesClientListOptions{}))
}
var virtualMachines []*armcompute.VirtualMachine
for pager.More() {
res, err := pager.NextPage(ctx)
for pager.more() {
res, err := pager.nextPage(ctx)
if err != nil {
return nil, trace.Wrap(ConvertResponseError(err))
}
virtualMachines = append(virtualMachines, res.Value...)
virtualMachines = append(virtualMachines, res...)
}

return virtualMachines, nil
Expand Down
7 changes: 7 additions & 0 deletions lib/cloud/azure/vm_test.go
Expand Up @@ -22,6 +22,8 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
)

func TestGetVirtualMachine(t *testing.T) {
Expand Down Expand Up @@ -169,6 +171,11 @@ func TestListVirtualMachines(t *testing.T) {
resourceGroup: "rgfake",
wantIDs: []string{},
},
{
name: "all resource groups",
resourceGroup: types.Wildcard,
wantIDs: []string{"vm1", "vm2", "vm3", "vm4"},
},
}

for _, tc := range tests {
Expand Down

0 comments on commit 64c17cf

Please sign in to comment.