Skip to content

Commit

Permalink
Update azurerm_policy_remediation - support `resource_discovery_mod…
Browse files Browse the repository at this point in the history
…e` (#9210)

Co-authored-by: kt <kt@katbyte.me>
  • Loading branch information
ArcturusZhang and katbyte committed Nov 9, 2020
1 parent 8ffe768 commit 3cae546
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 12 deletions.
82 changes: 82 additions & 0 deletions azurerm/internal/services/policy/policy_remediation_resource.go
Expand Up @@ -7,7 +7,9 @@ import (
"time"

"github.com/Azure/azure-sdk-for-go/services/preview/policyinsights/mgmt/2019-10-01-preview/policyinsights"
"github.com/hashicorp/terraform-plugin-sdk/helper/resource"
"github.com/hashicorp/terraform-plugin-sdk/helper/schema"
"github.com/hashicorp/terraform-plugin-sdk/helper/validation"
"github.com/terraform-providers/terraform-provider-azurerm/azurerm/helpers/suppress"
"github.com/terraform-providers/terraform-provider-azurerm/azurerm/helpers/tf"
"github.com/terraform-providers/terraform-provider-azurerm/azurerm/internal/clients"
Expand Down Expand Up @@ -77,6 +79,16 @@ func resourceArmPolicyRemediation() *schema.Resource {
// TODO: remove this suppression when github issue https://github.com/Azure/azure-rest-api-specs/issues/8353 is addressed
DiffSuppressFunc: suppress.CaseDifference,
},

"resource_discovery_mode": {
Type: schema.TypeString,
Optional: true,
Default: string(policyinsights.ExistingNonCompliant),
ValidateFunc: validation.StringInSlice([]string{
string(policyinsights.ExistingNonCompliant),
string(policyinsights.ReEvaluateCompliance),
}, false),
},
},
}
}
Expand Down Expand Up @@ -111,6 +123,7 @@ func resourceArmPolicyRemediationCreateUpdate(d *schema.ResourceData, meta inter
},
PolicyAssignmentID: utils.String(d.Get("policy_assignment_id").(string)),
PolicyDefinitionReferenceID: utils.String(d.Get("policy_definition_reference_id").(string)),
ResourceDiscoveryMode: policyinsights.ResourceDiscoveryMode(d.Get("resource_discovery_mode").(string)),
},
}

Expand Down Expand Up @@ -177,6 +190,7 @@ func resourceArmPolicyRemediationRead(d *schema.ResourceData, meta interface{})

d.Set("policy_assignment_id", props.PolicyAssignmentID)
d.Set("policy_definition_reference_id", props.PolicyDefinitionReferenceID)
d.Set("resource_discovery_mode", string(props.ResourceDiscoveryMode))
}

return nil
Expand All @@ -192,6 +206,38 @@ func resourceArmPolicyRemediationDelete(d *schema.ResourceData, meta interface{}
return err
}

// we have to cancel the remediation first before deleting it when the resource_discovery_mode is set to ReEvaluateCompliance
// therefore we first retrieve the remediation to see if the resource_discovery_mode is switched to ReEvaluateCompliance
existing, err := RemediationGetAtScope(ctx, client, id.Name, id.PolicyScopeId)
if err != nil {
if utils.ResponseWasNotFound(existing.Response) {
return nil
}
return fmt.Errorf("retrieving Policy Remediation %q (Scope %q): %+v", id.Name, id.ScopeId(), err)
}

if existing.RemediationProperties != nil && existing.RemediationProperties.ResourceDiscoveryMode == policyinsights.ReEvaluateCompliance {
log.Printf("[DEBUG] cancelling the remediation first before deleting it when `resource_discovery_mode` is set to `ReEvaluateCompliance`")
if err := cancelRemediation(ctx, client, id.Name, id.PolicyScopeId); err != nil {
return fmt.Errorf("cancelling Policy Remediation %q (Scope %q): %+v", id.Name, id.ScopeId(), err)
}

log.Printf("[DEBUG] waiting for the Policy Remediation %q (Scope %q) to be canceled", id.Name, id.ScopeId())
stateConf := &resource.StateChangeConf{
Pending: []string{"Cancelling"},
Target: []string{
"Succeeded", "Canceled", "Failed",
},
Refresh: policyRemediationCancellationRefreshFunc(ctx, client, id.Name, id.PolicyScopeId),
MinTimeout: 10 * time.Second,
Timeout: d.Timeout(schema.TimeoutDelete),
}

if _, err := stateConf.WaitForState(); err != nil {
return fmt.Errorf("waiting for Policy Remediation %q to be canceled: %+v", id.Name, err)
}
}

switch scope := id.PolicyScopeId.(type) {
case parse.ScopeAtSubscription:
_, err = client.DeleteAtSubscription(ctx, scope.SubscriptionId, id.Name)
Expand All @@ -211,6 +257,42 @@ func resourceArmPolicyRemediationDelete(d *schema.ResourceData, meta interface{}
return nil
}

func cancelRemediation(ctx context.Context, client *policyinsights.RemediationsClient, name string, scopeId parse.PolicyScopeId) error {
switch scopeId := scopeId.(type) {
case parse.ScopeAtSubscription:
_, err := client.CancelAtSubscription(ctx, scopeId.SubscriptionId, name)
return err
case parse.ScopeAtResourceGroup:
_, err := client.CancelAtResourceGroup(ctx, scopeId.SubscriptionId, scopeId.ResourceGroup, name)
return err
case parse.ScopeAtResource:
_, err := client.CancelAtResource(ctx, scopeId.ScopeId(), name)
return err
case parse.ScopeAtManagementGroup:
_, err := client.CancelAtManagementGroup(ctx, scopeId.ManagementGroupName, name)
return err
default:
return fmt.Errorf("nvalid scope type")
}
}

func policyRemediationCancellationRefreshFunc(ctx context.Context, client *policyinsights.RemediationsClient, name string, scopeId parse.PolicyScopeId) resource.StateRefreshFunc {
return func() (interface{}, string, error) {
resp, err := RemediationGetAtScope(ctx, client, name, scopeId)
if err != nil {
return nil, "", fmt.Errorf("issuing read request in policyRemediationCancellationRefreshFunc for Policy Remediation %q (Scope %q): %+v", name, scopeId.ScopeId(), err)
}

if resp.RemediationProperties == nil {
return nil, "", fmt.Errorf("`properties` was nil")
}
if resp.RemediationProperties.ProvisioningState == nil {
return nil, "", fmt.Errorf("`properties.ProvisioningState` was nil")
}
return resp, *resp.RemediationProperties.ProvisioningState, nil
}
}

// RemediationGetAtScope is a wrapper of the 4 Get functions on RemediationsClient, combining them into one to simplify code.
func RemediationGetAtScope(ctx context.Context, client *policyinsights.RemediationsClient, name string, scopeId parse.PolicyScopeId) (policyinsights.Remediation, error) {
switch scopeId := scopeId.(type) {
Expand Down
Expand Up @@ -25,8 +25,6 @@ func TestAccAzureRMPolicyRemediation_atSubscription(t *testing.T) {
Config: testAccAzureRMPolicyRemediation_atSubscription(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttrSet(data.ResourceName, "scope"),
resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"),
),
},
data.ImportStep(),
Expand All @@ -47,8 +45,6 @@ func TestAccAzureRMPolicyRemediation_atSubscriptionWithDefinitionSet(t *testing.
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttrSet(data.ResourceName, "scope"),
resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"),
resource.TestCheckResourceAttrSet(data.ResourceName, "policy_definition_reference_id"),
),
},
data.ImportStep(),
Expand All @@ -68,8 +64,25 @@ func TestAccAzureRMPolicyRemediation_atResourceGroup(t *testing.T) {
Config: testAccAzureRMPolicyRemediation_atResourceGroup(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttrSet(data.ResourceName, "scope"),
resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"),
),
},
data.ImportStep(),
},
})
}

func TestAccAzureRMPolicyRemediation_atResourceGroupWithDiscoveryMode(t *testing.T) {
data := acceptance.BuildTestData(t, "azurerm_policy_remediation", "test")

resource.ParallelTest(t, resource.TestCase{
PreCheck: func() { acceptance.PreCheck(t) },
Providers: acceptance.SupportedProviders,
CheckDestroy: testCheckAzureRMPolicyRemediationDestroy,
Steps: []resource.TestStep{
{
Config: testAccAzureRMPolicyRemediation_atResourceGroupWithDiscoveryMode(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
),
},
data.ImportStep(),
Expand All @@ -89,8 +102,6 @@ func TestAccAzureRMPolicyRemediation_atManagementGroup(t *testing.T) {
Config: testAccAzureRMPolicyRemediation_atManagementGroup(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttrSet(data.ResourceName, "scope"),
resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"),
),
},
data.ImportStep(),
Expand All @@ -110,8 +121,6 @@ func TestAccAzureRMPolicyRemediation_atResource(t *testing.T) {
Config: testAccAzureRMPolicyRemediation_atResource(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttrSet(data.ResourceName, "scope"),
resource.TestCheckResourceAttrSet(data.ResourceName, "policy_assignment_id"),
),
},
data.ImportStep(),
Expand All @@ -131,14 +140,13 @@ func TestAccAzureRMPolicyRemediation_updateLocation(t *testing.T) {
Config: testAccAzureRMPolicyRemediation_atResourceGroup(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttr(data.ResourceName, "location_filters.#", "0"),
),
},
data.ImportStep(),
{
Config: testAccAzureRMPolicyRemediation_updateLocation(data),
Check: resource.ComposeTestCheckFunc(
testCheckAzureRMPolicyRemediationExists(data.ResourceName),
resource.TestCheckResourceAttr(data.ResourceName, "location_filters.#", "1"),
),
},
data.ImportStep(),
Expand Down Expand Up @@ -451,6 +459,77 @@ resource "azurerm_policy_remediation" "test" {
`, data.RandomString, data.Locations.Primary)
}

func testAccAzureRMPolicyRemediation_atResourceGroupWithDiscoveryMode(data acceptance.TestData) string {
return fmt.Sprintf(`
provider "azurerm" {
features {}
}
resource "azurerm_resource_group" "test" {
name = "acctestRG-policy-%[1]s"
location = "%[2]s"
}
resource "azurerm_policy_definition" "test" {
name = "acctestDef-%[1]s"
policy_type = "Custom"
mode = "All"
display_name = "my-policy-definition"
policy_rule = <<POLICY_RULE
{
"if": {
"not": {
"field": "location",
"in": "[parameters('allowedLocations')]"
}
},
"then": {
"effect": "audit"
}
}
POLICY_RULE
parameters = <<PARAMETERS
{
"allowedLocations": {
"type": "Array",
"metadata": {
"description": "The list of allowed locations for resources.",
"displayName": "Allowed locations",
"strongType": "location"
}
}
}
PARAMETERS
}
resource "azurerm_policy_assignment" "test" {
name = "acctestAssign-%[1]s"
scope = azurerm_resource_group.test.id
policy_definition_id = azurerm_policy_definition.test.id
description = "Policy Assignment created via an Acceptance Test"
display_name = "My Example Policy Assignment"
parameters = <<PARAMETERS
{
"allowedLocations": {
"value": [ "West Europe" ]
}
}
PARAMETERS
}
resource "azurerm_policy_remediation" "test" {
name = "acctestremediation-%[1]s"
scope = azurerm_policy_assignment.test.scope
policy_assignment_id = azurerm_policy_assignment.test.id
resource_discovery_mode = "ReEvaluateCompliance"
}
`, data.RandomString, data.Locations.Primary)
}

func testAccAzureRMPolicyRemediation_updateLocation(data acceptance.TestData) string {
return fmt.Sprintf(`
provider "azurerm" {
Expand Down
2 changes: 2 additions & 0 deletions website/docs/r/policy_remediation.html.markdown
Expand Up @@ -95,6 +95,8 @@ The following arguments are supported:

* `location_filters` - (Optional) A list of the resource locations that will be remediated.

* `resource_discovery_mode` - (Optional) The way that resources to remediate are discovered. Possible values are `ExistingNonCompliant`, `ReEvaluateCompliance`. Defaults to `ExistingNonCompliant`.

## Attributes Reference

The following attributes are exported:
Expand Down

0 comments on commit 3cae546

Please sign in to comment.