Skip to content

Commit

Permalink
pluginsdk wrapper for diag package
Browse files Browse the repository at this point in the history
  • Loading branch information
manicminer committed Sep 27, 2023
1 parent 93a61f3 commit af67e10
Show file tree
Hide file tree
Showing 58 changed files with 247 additions and 252 deletions.
82 changes: 41 additions & 41 deletions internal/provider/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,11 @@ import (

"github.com/hashicorp/go-azure-sdk/sdk/auth"
"github.com/hashicorp/go-azure-sdk/sdk/environments"
"github.com/hashicorp/terraform-plugin-sdk/v2/diag"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/validation"
"github.com/hashicorp/terraform-provider-azuread/internal/clients"
"github.com/hashicorp/terraform-provider-azuread/internal/sdk"
"github.com/hashicorp/terraform-provider-azuread/internal/tf/pluginsdk"
"github.com/hashicorp/terraform-provider-azuread/internal/tf/validation"
)

// Microsoft’s Terraform Partner ID is this specific GUID
Expand All @@ -31,10 +31,10 @@ type ServiceRegistration interface {
WebsiteCategories() []string

// SupportedDataSources returns the supported Data Sources supported by this Service
SupportedDataSources() map[string]*schema.Resource
SupportedDataSources() map[string]*pluginsdk.Resource

// SupportedResources returns the supported Resources supported by this Service
SupportedResources() map[string]*schema.Resource
SupportedResources() map[string]*pluginsdk.Resource
}

// AzureADProvider returns a schema.Provider.
Expand All @@ -52,8 +52,8 @@ func AzureADProvider() *schema.Provider {
log.Printf(f, v...)
}

dataSources := make(map[string]*schema.Resource)
resources := make(map[string]*schema.Resource)
dataSources := make(map[string]*pluginsdk.Resource)
resources := make(map[string]*pluginsdk.Resource)

// first handle the typed services
for _, service := range SupportedTypedServices() {
Expand Down Expand Up @@ -111,149 +111,149 @@ func AzureADProvider() *schema.Provider {
}

p := &schema.Provider{
Schema: map[string]*schema.Schema{
Schema: map[string]*pluginsdk.Schema{
"client_id": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_ID", ""),
Description: "The Client ID which should be used for service principal authentication",
},

"client_id_file_path": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_ID_FILE_PATH", ""),
Description: "The path to a file containing the Client ID which should be used for service principal authentication",
},

"tenant_id": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_TENANT_ID", ""),
Description: "The Tenant ID which should be used. Works with all authentication methods except Managed Identity",
},

"environment": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Required: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_ENVIRONMENT", "global"),
Description: "The cloud environment which should be used. Possible values are: `global` (also `public`), `usgovernmentl4` (also `usgovernment`), `usgovernmentl5` (also `dod`), and `china`. Defaults to `global`",
},

"metadata_host": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Required: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_METADATA_HOSTNAME", ""),
Description: "The Hostname which should be used for the Azure Metadata Service.",
},

// Client Certificate specific fields
"client_certificate": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_CERTIFICATE", ""),
Description: "Base64 encoded PKCS#12 certificate bundle to use when authenticating as a Service Principal using a Client Certificate",
},

"client_certificate_password": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_CERTIFICATE_PASSWORD", ""),
Description: "The password to decrypt the Client Certificate. For use when authenticating as a Service Principal using a Client Certificate",
},

"client_certificate_path": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_CERTIFICATE_PATH", ""),
Description: "The path to the Client Certificate associated with the Service Principal for use when authenticating as a Service Principal using a Client Certificate",
},

// Client Secret specific fields
"client_secret": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_SECRET", ""),
Description: "The application password to use when authenticating as a Service Principal using a Client Secret",
},

"client_secret_file_path": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_CLIENT_SECRET_FILE_PATH", ""),
Description: "The path to a file containing the application password to use when authenticating as a Service Principal using a Client Secret",
},

// OIDC specific fields
"use_oidc": {
Type: schema.TypeBool,
Type: pluginsdk.TypeBool,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_USE_OIDC", false),
Description: "Allow OpenID Connect to be used for authentication",
},

"oidc_token": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_OIDC_TOKEN", ""),
Description: "The ID token for use when authenticating as a Service Principal using OpenID Connect.",
},

"oidc_token_file_path": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_OIDC_TOKEN_FILE_PATH", ""),
Description: "The path to a file containing an ID token for use when authenticating as a Service Principal using OpenID Connect.",
},

"oidc_request_token": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.MultiEnvDefaultFunc([]string{"ARM_OIDC_REQUEST_TOKEN", "ACTIONS_ID_TOKEN_REQUEST_TOKEN"}, ""),
Description: "The bearer token for the request to the OIDC provider. For use when authenticating as a Service Principal using OpenID Connect.",
},

"oidc_request_url": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.MultiEnvDefaultFunc([]string{"ARM_OIDC_REQUEST_URL", "ACTIONS_ID_TOKEN_REQUEST_URL"}, ""),
Description: "The URL for the OIDC provider from which to request an ID token. For use when authenticating as a Service Principal using OpenID Connect.",
},

// CLI authentication specific fields
"use_cli": {
Type: schema.TypeBool,
Type: pluginsdk.TypeBool,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_USE_CLI", true),
Description: "Allow Azure CLI to be used for Authentication",
},

// Managed Identity specific fields
"use_msi": {
Type: schema.TypeBool,
Type: pluginsdk.TypeBool,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_USE_MSI", false),
Description: "Allow Managed Identity to be used for Authentication",
},

"msi_endpoint": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_MSI_ENDPOINT", ""),
Description: "The path to a custom endpoint for Managed Identity - in most circumstances this should be detected automatically",
},

// Managed Tracking GUID for User-agent
"partner_id": {
Type: schema.TypeString,
Type: pluginsdk.TypeString,
Optional: true,
ValidateFunc: validation.Any(validation.IsUUID, validation.StringIsEmpty),
DefaultFunc: schema.EnvDefaultFunc("ARM_PARTNER_ID", ""),
Description: "A GUID/UUID that is registered with Microsoft to facilitate partner resource usage attribution",
},

"disable_terraform_partner_id": {
Type: schema.TypeBool,
Type: pluginsdk.TypeBool,
Optional: true,
DefaultFunc: schema.EnvDefaultFunc("ARM_DISABLE_TERRAFORM_PARTNER_ID", false),
Description: "Disable the Terraform Partner ID, which is used if a custom `partner_id` isn't specified",
Expand All @@ -270,24 +270,24 @@ func AzureADProvider() *schema.Provider {
}

func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
return func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
return func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
var certData []byte
if encodedCert := d.Get("client_certificate").(string); encodedCert != "" {
var err error
certData, err = decodeCertificate(encodedCert)
if err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}
}

clientSecret, err := getClientSecret(d)
if err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}

clientId, err := getClientId(d)
if err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}

var (
Expand All @@ -299,21 +299,21 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {

if metadataHost != "" {
if env, err = environments.FromEndpoint(ctx, fmt.Sprintf("https://%s", metadataHost), envName); err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}
} else if env, err = environments.FromName(envName); err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}

if env.MicrosoftGraph == nil {
return nil, diag.Errorf("Microsoft Graph was not configured for the specified environment")
return nil, pluginsdk.DiagErrorf("Microsoft Graph was not configured for the specified environment")
} else if endpoint, ok := env.MicrosoftGraph.Endpoint(); !ok || *endpoint == "" {
return nil, diag.Errorf("Microsoft Graph endpoint could not be determined for the specified environment")
return nil, pluginsdk.DiagErrorf("Microsoft Graph endpoint could not be determined for the specified environment")
}

idToken, err := oidcToken(d)
if err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}

authConfig := &auth.Credentials{
Expand Down Expand Up @@ -348,7 +348,7 @@ func providerConfigure(p *schema.Provider) schema.ConfigureContextFunc {
}
}

func buildClient(ctx context.Context, p *schema.Provider, authConfig *auth.Credentials, partnerId string) (*clients.Client, diag.Diagnostics) {
func buildClient(ctx context.Context, p *schema.Provider, authConfig *auth.Credentials, partnerId string) (*clients.Client, pluginsdk.Diagnostics) {
clientBuilder := clients.ClientBuilder{
AuthConfig: authConfig,
PartnerID: partnerId,
Expand All @@ -362,7 +362,7 @@ func buildClient(ctx context.Context, p *schema.Provider, authConfig *auth.Crede

client, err := clientBuilder.Build(stopCtx)
if err != nil {
return nil, diag.FromErr(err)
return nil, pluginsdk.DiagFromErr(err)
}

return client, nil
Expand All @@ -381,7 +381,7 @@ func decodeCertificate(clientCertificate string) ([]byte, error) {
return pfx, nil
}

func oidcToken(d *schema.ResourceData) (*string, error) {
func oidcToken(d *pluginsdk.ResourceData) (*string, error) {
idToken := d.Get("oidc_token").(string)

if path := d.Get("oidc_token_file_path").(string); path != "" {
Expand All @@ -403,7 +403,7 @@ func oidcToken(d *schema.ResourceData) (*string, error) {
return &idToken, nil
}

func getClientId(d *schema.ResourceData) (*string, error) {
func getClientId(d *pluginsdk.ResourceData) (*string, error) {
clientId := strings.TrimSpace(d.Get("client_id").(string))

if path := d.Get("client_id_file_path").(string); path != "" {
Expand All @@ -425,7 +425,7 @@ func getClientId(d *schema.ResourceData) (*string, error) {
return &clientId, nil
}

func getClientSecret(d *schema.ResourceData) (*string, error) {
func getClientSecret(d *pluginsdk.ResourceData) (*string, error) {
clientSecret := strings.TrimSpace(d.Get("client_secret").(string))

if path := d.Get("client_secret_file_path").(string); path != "" {
Expand Down
17 changes: 9 additions & 8 deletions internal/provider/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/v2/terraform"
"github.com/hashicorp/terraform-provider-azuread/internal/clients"
"github.com/hashicorp/terraform-provider-azuread/internal/tf/pluginsdk"
)

func TestProvider(t *testing.T) {
Expand All @@ -36,7 +37,7 @@ func TestAccProvider_cliAuth(t *testing.T) {
ctx := context.Background()

// Support only Azure CLI authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down Expand Up @@ -74,7 +75,7 @@ func TestAccProvider_clientCertificateAuth(t *testing.T) {
ctx := context.Background()

// Support only client certificate authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down Expand Up @@ -115,7 +116,7 @@ func TestAccProvider_clientCertificateInlineAuth(t *testing.T) {
ctx := context.Background()

// Support only client certificate authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
var certData []byte
if encodedCert := d.Get("client_certificate").(string); encodedCert != "" {
var err error
Expand Down Expand Up @@ -181,7 +182,7 @@ func testAccProvider_clientSecretAuthFromEnvironment(t *testing.T) {
ctx := context.Background()

// Support only client secret authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down Expand Up @@ -242,7 +243,7 @@ func testAccProvider_clientSecretAuthFromFiles(t *testing.T) {
ctx := context.Background()

// Support only client secret authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down Expand Up @@ -304,7 +305,7 @@ func testAccProvider_genericOidcAuthFromEnvironment(t *testing.T) {
ctx := context.Background()

// Support only oidc authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down Expand Up @@ -356,7 +357,7 @@ func testAccProvider_genericOidcAuthFromFiles(t *testing.T) {
ctx := context.Background()

// Support only oidc authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down Expand Up @@ -407,7 +408,7 @@ func TestAccProvider_githubOidcAuth(t *testing.T) {
ctx := context.Background()

// Support only oidc authentication
provider.ConfigureContextFunc = func(ctx context.Context, d *schema.ResourceData) (interface{}, diag.Diagnostics) {
provider.ConfigureContextFunc = func(ctx context.Context, d *pluginsdk.ResourceData) (interface{}, pluginsdk.Diagnostics) {
envName := d.Get("environment").(string)
env, err := environments.FromName(envName)
if err != nil {
Expand Down
Loading

0 comments on commit af67e10

Please sign in to comment.