Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[v13] Fix Azure join for identities across resource groups #28961

Merged
merged 1 commit into from Jul 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion lib/auth/join_azure.go
Expand Up @@ -268,7 +268,7 @@ func verifyVMIdentity(ctx context.Context, cfg *azureRegisterConfig, accessToken
// If the token is from a user-assigned managed identity, the resource ID is
// for the identity and we need to look the VM up by VM ID.
} else {
vm, err = vmClient.GetByVMID(ctx, resourceID.ResourceGroupName, vmID)
vm, err = vmClient.GetByVMID(ctx, types.Wildcard, vmID)
if err != nil {
if trace.IsNotFound(err) {
return nil, trace.AccessDenied("no VM found with matching VM ID")
Expand Down
19 changes: 19 additions & 0 deletions lib/cloud/azure/mocks.go
Expand Up @@ -508,6 +508,25 @@ func (m *ARMComputeMock) NewListPager(resourceGroup string, _ *armcompute.Virtua
})
}

func (m *ARMComputeMock) NewListAllPager(_ *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse] {
var vms []*armcompute.VirtualMachine
for _, resourceGroupVMs := range m.VirtualMachines {
vms = append(vms, resourceGroupVMs...)
}
return runtime.NewPager(runtime.PagingHandler[armcompute.VirtualMachinesClientListAllResponse]{
More: func(page armcompute.VirtualMachinesClientListAllResponse) bool {
return page.NextLink != nil && len(*page.NextLink) > 0
},
Fetcher: func(ctx context.Context, page *armcompute.VirtualMachinesClientListAllResponse) (armcompute.VirtualMachinesClientListAllResponse, error) {
return armcompute.VirtualMachinesClientListAllResponse{
VirtualMachineListResult: armcompute.VirtualMachineListResult{
Value: vms,
},
}, nil
},
})
}

func (m *ARMComputeMock) Get(_ context.Context, _ string, _ string, _ *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error) {
return armcompute.VirtualMachinesClientGetResponse{
VirtualMachine: m.GetResult,
Expand Down
73 changes: 54 additions & 19 deletions lib/cloud/azure/vm.go
Expand Up @@ -23,27 +23,31 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/gravitational/trace"

"github.com/gravitational/teleport/api/types"
)

// armCompute provides an interface for an Azure Virtual Machine client.
// armCompute provides an interface for an Azure virtual machine client.
type armCompute interface {
// Get retrieves information about an Azure Virtual Machine.
// Get retrieves information about an Azure virtual machine.
Get(ctx context.Context, resourceGroupName string, vmName string, options *armcompute.VirtualMachinesClientGetOptions) (armcompute.VirtualMachinesClientGetResponse, error)
// NewListPagers lists Azure Virtual Machines.
// NewListPagers lists Azure virtual Machines.
NewListPager(resourceGroup string, opts *armcompute.VirtualMachinesClientListOptions) *runtime.Pager[armcompute.VirtualMachinesClientListResponse]
// NewListAllPager lists Azure virtual machines in any resource group.
NewListAllPager(opts *armcompute.VirtualMachinesClientListAllOptions) *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]
}

// VirtualMachinesClient is a client for Azure Virtual Machines.
// VirtualMachinesClient is a client for Azure virtual machines.
type VirtualMachinesClient interface {
// Get returns the Virtual Machine for the given resource ID.
// Get returns the virtual machine for the given resource ID.
Get(ctx context.Context, resourceID string) (*VirtualMachine, error)
// GetByVMID returns the Virtual Machine for a given VM ID.
// GetByVMID returns the virtual machine for a given VM ID.
GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error)
// ListVirtualMachines gets all of the virtual machines in the given resource group.
ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error)
}

// VirtualMachine represents an Azure Virtual Machine.
// VirtualMachine represents an Azure virtual machine.
type VirtualMachine struct {
// ID resource ID.
ID string `json:"id,omitempty"`
Expand All @@ -59,18 +63,18 @@ type VirtualMachine struct {
Identities []Identity
}

// Identitiy represents an Azure Virtual Machine identity.
// Identitiy represents an Azure virtual machine identity.
type Identity struct {
// ResourceID the identity resource ID.
ResourceID string
}

type vmClient struct {
// api is the Azure Virtual Machine client.
// api is the Azure virtual machine client.
api armCompute
}

// NewVirtualMachinesClient creates a new Azure Virtual Machines client by
// NewVirtualMachinesClient creates a new Azure virtual machines client by
// subscription and credentials.
func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential, options *arm.ClientOptions) (VirtualMachinesClient, error) {
computeAPI, err := armcompute.NewVirtualMachinesClient(subscription, cred, options)
Expand All @@ -81,7 +85,7 @@ func NewVirtualMachinesClient(subscription string, cred azcore.TokenCredential,
return NewVirtualMachinesClientByAPI(computeAPI), nil
}

// NewVirtualMachinesClientByAPI creates a new Azure Virtual Machines client by
// NewVirtualMachinesClientByAPI creates a new Azure virtual machines client by
// ARM API client.
func NewVirtualMachinesClientByAPI(api armCompute) VirtualMachinesClient {
return &vmClient{
Expand Down Expand Up @@ -121,7 +125,7 @@ func parseVirtualMachine(vm *armcompute.VirtualMachine) (*VirtualMachine, error)
}, nil
}

// Get returns the Virtual Machine for the given resource ID.
// Get returns the virtual machine for the given resource ID.
func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine, error) {
parsedResourceID, err := arm.ParseResourceID(resourceID)
if err != nil {
Expand All @@ -137,7 +141,7 @@ func (c *vmClient) Get(ctx context.Context, resourceID string) (*VirtualMachine,
return vm, trace.Wrap(err)
}

// GetByVMID returns the Virtual Machine for a given VM ID.
// GetByVMID returns the virtual machine for a given VM ID.
func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*VirtualMachine, error) {
vms, err := c.ListVirtualMachines(ctx, resourceGroup)
if err != nil {
Expand All @@ -152,17 +156,48 @@ func (c *vmClient) GetByVMID(ctx context.Context, resourceGroup, vmID string) (*
return nil, trace.NotFound("no VM with ID %q", vmID)
}

// ListVirtualMachines lists all virtual machines in a given resource group using the Azure Virtual Machines API.
type vmPager struct {
more func() bool
nextPage func(context.Context) ([]*armcompute.VirtualMachine, error)
}

func newListPager(azurePager *runtime.Pager[armcompute.VirtualMachinesClientListResponse]) vmPager {
return vmPager{
more: azurePager.More,
nextPage: func(ctx context.Context) ([]*armcompute.VirtualMachine, error) {
res, err := azurePager.NextPage(ctx)
return res.Value, trace.Wrap(err)
},
}
}

func newListAllPager(azurePager *runtime.Pager[armcompute.VirtualMachinesClientListAllResponse]) vmPager {
return vmPager{
more: azurePager.More,
nextPage: func(ctx context.Context) ([]*armcompute.VirtualMachine, error) {
res, err := azurePager.NextPage(ctx)
return res.Value, trace.Wrap(err)
},
}
}

// ListVirtualMachines lists all virtual machines in a given resource group
// using the Azure virtual machines API. If resourceGroup is "*", it lists
// all virtual machines in any resource group.
func (c *vmClient) ListVirtualMachines(ctx context.Context, resourceGroup string) ([]*armcompute.VirtualMachine, error) {
pagerOpts := &armcompute.VirtualMachinesClientListOptions{}
pager := c.api.NewListPager(resourceGroup, pagerOpts)
var pager vmPager
if resourceGroup == types.Wildcard {
pager = newListAllPager(c.api.NewListAllPager(&armcompute.VirtualMachinesClientListAllOptions{}))
} else {
pager = newListPager(c.api.NewListPager(resourceGroup, &armcompute.VirtualMachinesClientListOptions{}))
}
var virtualMachines []*armcompute.VirtualMachine
for pager.More() {
res, err := pager.NextPage(ctx)
for pager.more() {
res, err := pager.nextPage(ctx)
if err != nil {
return nil, trace.Wrap(ConvertResponseError(err))
}
virtualMachines = append(virtualMachines, res.Value...)
virtualMachines = append(virtualMachines, res...)
}

return virtualMachines, nil
Expand Down
7 changes: 7 additions & 0 deletions lib/cloud/azure/vm_test.go
Expand Up @@ -22,6 +22,8 @@ import (
"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
"github.com/Azure/azure-sdk-for-go/sdk/resourcemanager/compute/armcompute/v3"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
)

func TestGetVirtualMachine(t *testing.T) {
Expand Down Expand Up @@ -169,6 +171,11 @@ func TestListVirtualMachines(t *testing.T) {
resourceGroup: "rgfake",
wantIDs: []string{},
},
{
name: "all resource groups",
resourceGroup: types.Wildcard,
wantIDs: []string{"vm1", "vm2", "vm3", "vm4"},
},
}

for _, tc := range tests {
Expand Down