Skip to content

Commit

Permalink
Add validating/defaulting webhook for machines Azure provider spec
Browse files Browse the repository at this point in the history
  • Loading branch information
JoelSpeed committed Jun 5, 2020
1 parent 7f8f353 commit 562403b
Show file tree
Hide file tree
Showing 2 changed files with 651 additions and 3 deletions.
189 changes: 187 additions & 2 deletions pkg/apis/machine/v1beta1/machine_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,14 @@ import (
"k8s.io/klog"
"k8s.io/utils/pointer"
aws "sigs.k8s.io/cluster-api-provider-aws/pkg/apis/awsprovider/v1beta1"
azure "sigs.k8s.io/cluster-api-provider-azure/pkg/apis/azureprovider/v1beta1"
ctrl "sigs.k8s.io/controller-runtime"
"sigs.k8s.io/controller-runtime/pkg/webhook/admission"
yaml "sigs.k8s.io/yaml"
)

var (
// AWS Defaults
defaultAWSIAMInstanceProfile = func(clusterID string) *string {
return pointer.StringPtr(fmt.Sprintf("%s-worker-profile", clusterID))
}
Expand All @@ -31,12 +33,40 @@ var (
defaultAWSSubnet = func(clusterID, az string) string {
return fmt.Sprintf("%s-private-%s", clusterID, az)
}

// Azure Defaults
defaultAzureVnet = func(clusterID string) string {
return fmt.Sprintf("%s-vnet", clusterID)
}
defaultAzureSubnet = func(clusterID string) string {
return fmt.Sprintf("%s-worker-subnet", clusterID)
}
defaultAzureNetworkResourceGroup = func(clusterID string) string {
return fmt.Sprintf("%s-rg", clusterID)
}
defaultAzureImageResourceID = func(clusterID string) string {
return fmt.Sprintf("/resourceGroups/%s/providers/Microsoft.Compute/images/%s", clusterID+"-rg", clusterID)
}
defaultAzureManagedIdentiy = func(clusterID string) string {
return fmt.Sprintf("%s-identity", clusterID)
}
defaultAzureResourceGroup = func(clusterID string) string {
return fmt.Sprintf("%s-rg", clusterID)
}
)

const (
defaultAWSUserDataSecret = "worker-user-data"
defaultUserDataSecret = "worker-user-data"

// AWS Defaults
defaultAWSCredentialsSecret = "aws-cloud-credentials"
defaultAWSInstanceType = "m4.large"

// Azure Defaults
defaultAzureVMSize = "Standard_D4s_V3"
defaultAzureCredentialsSecret = "azure-cloud-credentials"
defaultAzureOSDiskOSType = "Linux"
defaultAzureOSDiskStorageType = "Premium_LRS"
)

func getInfra() (*osconfigv1.Infrastructure, error) {
Expand Down Expand Up @@ -94,6 +124,8 @@ func createMachineValidator(platform osconfigv1.PlatformType, clusterID string)
switch platform {
case osconfigv1.AWSPlatformType:
h.webhookOperations = validateAWS
case osconfigv1.AzurePlatformType:
h.webhookOperations = validateAzure
default:
// just no-op
h.webhookOperations = func(h *validatorHandler, m *Machine) (bool, utilerrors.Aggregate) {
Expand Down Expand Up @@ -121,6 +153,8 @@ func createMachineDefaulter(platform osconfigv1.PlatformType, clusterID string)
switch platform {
case osconfigv1.AWSPlatformType:
h.webhookOperations = defaultAWS
case osconfigv1.AzurePlatformType:
h.webhookOperations = defaultAzure
default:
// just no-op
h.webhookOperations = func(h *defaulterHandler, m *Machine) (bool, utilerrors.Aggregate) {
Expand Down Expand Up @@ -209,7 +243,7 @@ func defaultAWS(h *defaulterHandler, m *Machine) (bool, utilerrors.Aggregate) {
providerSpec.IAMInstanceProfile = &aws.AWSResourceReference{ID: defaultAWSIAMInstanceProfile(h.clusterID)}
}
if providerSpec.UserDataSecret == nil {
providerSpec.UserDataSecret = &corev1.LocalObjectReference{Name: defaultAWSUserDataSecret}
providerSpec.UserDataSecret = &corev1.LocalObjectReference{Name: defaultUserDataSecret}
}

if providerSpec.CredentialsSecret == nil {
Expand Down Expand Up @@ -346,3 +380,154 @@ func validateAWS(h *validatorHandler, m *Machine) (bool, utilerrors.Aggregate) {

return true, nil
}

func defaultAzure(h *defaulterHandler, m *Machine) (bool, utilerrors.Aggregate) {
klog.V(3).Infof("Defaulting Azure providerSpec")

var errs []error
providerSpec := new(azure.AzureMachineProviderSpec)
if err := yaml.Unmarshal(m.Spec.ProviderSpec.Value.Raw, &providerSpec); err != nil {
errs = append(
errs,
field.Invalid(
field.NewPath("providerSpec", "value"),
providerSpec,
err.Error(),
),
)
return false, utilerrors.NewAggregate(errs)
}

if providerSpec.VMSize == "" {
providerSpec.VMSize = defaultAzureVMSize
}

// Vnet and Subnet need to be provided together by the user
if providerSpec.Vnet == "" && providerSpec.Subnet == "" {
providerSpec.Vnet = defaultAzureVnet(h.clusterID)
providerSpec.Subnet = defaultAzureSubnet(h.clusterID)

// NetworkResourceGroup can be set by the user without Vnet and Subnet,
// only override if they didn't set it
if providerSpec.NetworkResourceGroup == "" {
providerSpec.NetworkResourceGroup = defaultAzureNetworkResourceGroup(h.clusterID)
}
}

if providerSpec.Image.ResourceID == "" {
providerSpec.Image.ResourceID = defaultAzureImageResourceID(h.clusterID)
}

if providerSpec.ManagedIdentity == "" {
providerSpec.ManagedIdentity = defaultAzureManagedIdentiy(h.clusterID)
}

if providerSpec.ResourceGroup == "" {
providerSpec.ResourceGroup = defaultAzureResourceGroup(h.clusterID)
}

if providerSpec.UserDataSecret == nil {
providerSpec.UserDataSecret = &corev1.SecretReference{Name: defaultUserDataSecret}
}

if providerSpec.CredentialsSecret == nil {
providerSpec.CredentialsSecret = &corev1.SecretReference{Name: defaultAzureCredentialsSecret}
}

if providerSpec.OSDisk.OSType == "" {
providerSpec.OSDisk.OSType = defaultAzureOSDiskOSType
}

if providerSpec.OSDisk.ManagedDisk.StorageAccountType == "" {
providerSpec.OSDisk.ManagedDisk.StorageAccountType = defaultAzureOSDiskStorageType
}

rawBytes, err := json.Marshal(providerSpec)
if err != nil {
errs = append(errs, err)
}

if len(errs) > 0 {
return false, utilerrors.NewAggregate(errs)
}

m.Spec.ProviderSpec.Value = &runtime.RawExtension{Raw: rawBytes}
return true, nil
}

func validateAzure(h *validatorHandler, m *Machine) (bool, utilerrors.Aggregate) {
klog.V(3).Infof("Validating Azure providerSpec")

var errs []error
providerSpec := new(azure.AzureMachineProviderSpec)
if err := yaml.Unmarshal(m.Spec.ProviderSpec.Value.Raw, &providerSpec); err != nil {
errs = append(
errs,
field.Invalid(
field.NewPath("providerSpec", "value"),
providerSpec,
err.Error(),
),
)
return false, utilerrors.NewAggregate(errs)
}

if providerSpec.Location == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "location"), "location should be set to one of the supported Azure regions"))
}

if providerSpec.VMSize == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "vmSize"), "vmSize should be set to one of the supported Azure VM sizes"))
}

// Vnet requires Subnet
if providerSpec.Vnet != "" && providerSpec.Subnet == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "subnet"), "must provide a subnet when a virtual network is specified"))
}

// Subnet requires Vnet
if providerSpec.Subnet != "" && providerSpec.Vnet == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "vnet"), "must provide a virtual network when supplying subnets"))
}

// Vnet + Subnet requires NetworkResourceGroup
if (providerSpec.Vnet != "" || providerSpec.Subnet != "") && providerSpec.NetworkResourceGroup == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "networkResourceGroup"), "must provide a network resource group when a virtual network or subnet is specified"))
}

if providerSpec.Image.ResourceID == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "image", "resourceID"), "resourceID must be provided"))
}

if providerSpec.ManagedIdentity == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "managedIdentity"), "managedIdentity must be provided"))
}

if providerSpec.ResourceGroup == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "resourceGropu"), "resourceGroup must be provided"))
}

if providerSpec.UserDataSecret == nil {
errs = append(errs, field.Required(field.NewPath("providerSpec", "userDataSecret"), "userDataSecret must be provided"))
}

if providerSpec.CredentialsSecret == nil {
errs = append(errs, field.Required(field.NewPath("providerSpec", "credentialsSecret"), "credentialsSecret must be provided"))
}

if providerSpec.OSDisk.DiskSizeGB <= 0 {
errs = append(errs, field.Invalid(field.NewPath("providerSpec", "osDisk", "diskSizeGB"), providerSpec.OSDisk.DiskSizeGB, "diskSizeGB must be greater than zero"))
}

if providerSpec.OSDisk.OSType == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "osDisk", "osType"), "osType must be provided"))
}
if providerSpec.OSDisk.ManagedDisk.StorageAccountType == "" {
errs = append(errs, field.Required(field.NewPath("providerSpec", "osDisk", "managedDisk", "storageAccountType"), "storageAccountType must be provided"))
}

if len(errs) > 0 {
return false, utilerrors.NewAggregate(errs)
}
return true, nil
}

0 comments on commit 562403b

Please sign in to comment.