Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

bugfix: resolve several potential crashes in new PIM resources #1375

Merged
merged 2 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,31 @@ func (r PrivilegedAccessGroupAssignmentScheduleResource) Read() sdk.ResourceFunc
return sdk.ResourceFunc{
Timeout: 5 * time.Minute,
Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error {
cSchedule := metadata.Client.IdentityGovernance.PrivilegedAccessGroupAssignmentScheduleClient
cRequests := metadata.Client.IdentityGovernance.PrivilegedAccessGroupAssignmentScheduleRequestsClient

var request *msgraph.PrivilegedAccessGroupAssignmentScheduleRequest
scheduleClient := metadata.Client.IdentityGovernance.PrivilegedAccessGroupAssignmentScheduleClient
requestsClient := metadata.Client.IdentityGovernance.PrivilegedAccessGroupAssignmentScheduleRequestsClient

id, err := parse.ParsePrivilegedAccessGroupScheduleID(metadata.ResourceData.Id())
if err != nil {
return err
}

var model PrivilegedAccessGroupScheduleModel
if err := metadata.Decode(&model); err != nil {
if err = metadata.Decode(&model); err != nil {
return fmt.Errorf("decoding: %+v", err)
}

schedule, status, err := cSchedule.Get(ctx, id.ID())
if err != nil && status != http.StatusNotFound {
schedule, scheduleStatus, err := scheduleClient.Get(ctx, id.ID())
if err != nil && scheduleStatus != http.StatusNotFound {
return fmt.Errorf("retrieving %s: %+v", id, err)
}

var request *msgraph.PrivilegedAccessGroupAssignmentScheduleRequest

// Some details are only available on the request which is used for the create/update of the schedule.
// Schedule requests are never deleted. New ones are created when changes are made.
// Therefore on a read, we need to find the latest version of the request.
// This is to cater for changes being made outside of Terraform.
requests, _, err := cRequests.List(ctx, odata.Query{
requests, _, err := requestsClient.List(ctx, odata.Query{
Filter: fmt.Sprintf("groupId eq '%s' and targetScheduleId eq '%s'", id.GroupId, id.ID()),
OrderBy: odata.OrderBy{
Field: "createdDateTime",
Expand All @@ -135,45 +135,57 @@ func (r PrivilegedAccessGroupAssignmentScheduleResource) Read() sdk.ResourceFunc
if err != nil {
return fmt.Errorf("listing requests: %+v", err)
}
if len(*requests) == 0 {
if status == http.StatusNotFound {
if requests == nil || len(*requests) == 0 {
if scheduleStatus == http.StatusNotFound {
// No request and no schedule was found
return metadata.MarkAsGone(id)
}
} else {
request = pointer.To((*requests)[0])

model.Justification = *request.Justification
if request.TicketInfo.TicketNumber != nil {
model.TicketNumber = *request.TicketInfo.TicketNumber
}
if request.TicketInfo.TicketSystem != nil {
model.TicketSystem = *request.TicketInfo.TicketSystem
}
if request.ScheduleInfo.Expiration.Duration != nil {
model.Duration = *request.ScheduleInfo.Expiration.Duration
}
}

// Typically this is because the request has expired
// So we populate the model with the schedule details
if status == http.StatusNotFound {
var scheduleInfo *msgraph.RequestSchedule

if request != nil {
// The request is still present, populate from the request
scheduleInfo = request.ScheduleInfo

model.AssignmentType = request.AccessId
model.ExpirationDate = request.ScheduleInfo.Expiration.EndDateTime.Format(time.RFC3339)
model.GroupId = *request.GroupId
model.PermanentAssignment = *request.ScheduleInfo.Expiration.Type == msgraph.ExpirationPatternTypeNoExpiration
model.PrincipalId = *request.PrincipalId
model.StartDate = request.ScheduleInfo.StartDateTime.Format(time.RFC3339)
model.GroupId = pointer.From(request.GroupId)
model.Justification = pointer.From(request.Justification)
model.PrincipalId = pointer.From(request.PrincipalId)
model.Status = request.Status

if ticketInfo := request.TicketInfo; ticketInfo != nil {
model.TicketNumber = pointer.From(ticketInfo.TicketNumber)
model.TicketSystem = pointer.From(ticketInfo.TicketSystem)
}
} else {
// The request has likely expired, so populate from the schedule
scheduleInfo = schedule.ScheduleInfo

model.AssignmentType = schedule.AccessId
model.ExpirationDate = schedule.ScheduleInfo.Expiration.EndDateTime.Format(time.RFC3339)
model.GroupId = *schedule.GroupId
model.PermanentAssignment = *schedule.ScheduleInfo.Expiration.Type == msgraph.ExpirationPatternTypeNoExpiration
model.PrincipalId = *schedule.PrincipalId
model.StartDate = schedule.ScheduleInfo.StartDateTime.Format(time.RFC3339)
model.GroupId = pointer.From(schedule.GroupId)
model.PrincipalId = pointer.From(schedule.PrincipalId)
model.Status = schedule.Status
}

if scheduleInfo != nil {
if expiration := scheduleInfo.Expiration; expiration != nil {
model.Duration = pointer.From(expiration.Duration)

if expiration.EndDateTime != nil {
model.ExpirationDate = expiration.EndDateTime.Format(time.RFC3339)
}
if expiration.Type != nil {
model.PermanentAssignment = *expiration.Type == msgraph.ExpirationPatternTypeNoExpiration
}
}
if scheduleInfo.StartDateTime != nil {
model.StartDate = scheduleInfo.StartDateTime.Format(time.RFC3339)
}
}

return metadata.Encode(&model)
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,31 +101,31 @@ func (r PrivilegedAccessGroupEligibilityScheduleResource) Read() sdk.ResourceFun
return sdk.ResourceFunc{
Timeout: 5 * time.Minute,
Func: func(ctx context.Context, metadata sdk.ResourceMetaData) error {
cSchedule := metadata.Client.IdentityGovernance.PrivilegedAccessGroupEligibilityScheduleClient
cRequests := metadata.Client.IdentityGovernance.PrivilegedAccessGroupEligibilityScheduleRequestsClient

var request *msgraph.PrivilegedAccessGroupEligibilityScheduleRequest
scheduleClient := metadata.Client.IdentityGovernance.PrivilegedAccessGroupEligibilityScheduleClient
requestsClient := metadata.Client.IdentityGovernance.PrivilegedAccessGroupEligibilityScheduleRequestsClient

id, err := parse.ParsePrivilegedAccessGroupScheduleID(metadata.ResourceData.Id())
if err != nil {
return err
}

var model PrivilegedAccessGroupScheduleModel
if err := metadata.Decode(&model); err != nil {
if err = metadata.Decode(&model); err != nil {
return fmt.Errorf("decoding: %+v", err)
}

schedule, status, err := cSchedule.Get(ctx, id.ID())
if err != nil && status != http.StatusNotFound {
schedule, scheduleStatus, err := scheduleClient.Get(ctx, id.ID())
if err != nil && scheduleStatus != http.StatusNotFound {
return fmt.Errorf("retrieving %s: %+v", id, err)
}

var request *msgraph.PrivilegedAccessGroupEligibilityScheduleRequest

// Some details are only available on the request which is used for the create/update of the schedule.
// Schedule requests are never deleted. New ones are created when changes are made.
// Therefore on a read, we need to find the latest version of the request.
// This is to cater for changes being made outside of Terraform.
requests, _, err := cRequests.List(ctx, odata.Query{
requests, _, err := requestsClient.List(ctx, odata.Query{
Filter: fmt.Sprintf("groupId eq '%s' and targetScheduleId eq '%s'", id.GroupId, id.ID()),
OrderBy: odata.OrderBy{
Field: "createdDateTime",
Expand All @@ -135,45 +135,57 @@ func (r PrivilegedAccessGroupEligibilityScheduleResource) Read() sdk.ResourceFun
if err != nil {
return fmt.Errorf("listing requests: %+v", err)
}
if len(*requests) == 0 {
if status == http.StatusNotFound {
if requests == nil || len(*requests) == 0 {
if scheduleStatus == http.StatusNotFound {
// No request and no schedule was found
return metadata.MarkAsGone(id)
}
} else {
request = pointer.To((*requests)[0])

model.Justification = *request.Justification
if request.TicketInfo.TicketNumber != nil {
model.TicketNumber = *request.TicketInfo.TicketNumber
}
if request.TicketInfo.TicketSystem != nil {
model.TicketSystem = *request.TicketInfo.TicketSystem
}
if request.ScheduleInfo.Expiration.Duration != nil {
model.Duration = *request.ScheduleInfo.Expiration.Duration
}
}

// Typically this is because the request has expired
// So we populate the model with the schedule details
if status == http.StatusNotFound {
var scheduleInfo *msgraph.RequestSchedule

if request != nil {
// The request is still present, populate from the request
scheduleInfo = request.ScheduleInfo

model.AssignmentType = request.AccessId
model.ExpirationDate = request.ScheduleInfo.Expiration.EndDateTime.Format(time.RFC3339)
model.GroupId = *request.GroupId
model.PermanentAssignment = *request.ScheduleInfo.Expiration.Type == msgraph.ExpirationPatternTypeNoExpiration
model.PrincipalId = *request.PrincipalId
model.StartDate = request.ScheduleInfo.StartDateTime.Format(time.RFC3339)
model.GroupId = pointer.From(request.GroupId)
model.Justification = pointer.From(request.Justification)
model.PrincipalId = pointer.From(request.PrincipalId)
model.Status = request.Status

if ticketInfo := request.TicketInfo; ticketInfo != nil {
model.TicketNumber = pointer.From(ticketInfo.TicketNumber)
model.TicketSystem = pointer.From(ticketInfo.TicketSystem)
}
} else {
// The request has likely expired, so populate from the schedule
scheduleInfo = schedule.ScheduleInfo

model.AssignmentType = schedule.AccessId
model.ExpirationDate = schedule.ScheduleInfo.Expiration.EndDateTime.Format(time.RFC3339)
model.GroupId = *schedule.GroupId
model.PermanentAssignment = *schedule.ScheduleInfo.Expiration.Type == msgraph.ExpirationPatternTypeNoExpiration
model.PrincipalId = *schedule.PrincipalId
model.StartDate = schedule.ScheduleInfo.StartDateTime.Format(time.RFC3339)
model.GroupId = pointer.From(schedule.GroupId)
model.PrincipalId = pointer.From(schedule.PrincipalId)
model.Status = schedule.Status
}

if scheduleInfo != nil {
if expiration := scheduleInfo.Expiration; expiration != nil {
model.Duration = pointer.From(expiration.Duration)

if expiration.EndDateTime != nil {
model.ExpirationDate = expiration.EndDateTime.Format(time.RFC3339)
}
if expiration.Type != nil {
model.PermanentAssignment = *expiration.Type == msgraph.ExpirationPatternTypeNoExpiration
}
}
if scheduleInfo.StartDateTime != nil {
model.StartDate = scheduleInfo.StartDateTime.Format(time.RFC3339)
}
}

return metadata.Encode(&model)
},
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,9 +143,10 @@ func buildScheduleRequest(model *PrivilegedAccessGroupScheduleModel, metadata *s
schedule := msgraph.RequestSchedule{}
schedule.Expiration = &msgraph.ExpirationPattern{}
var startDate, expiryDate time.Time
var err error

if model.StartDate != "" {
startDate, err := time.Parse(time.RFC3339, model.StartDate)
startDate, err = time.Parse(time.RFC3339, model.StartDate)
if err != nil {
return nil, fmt.Errorf("parsing %s: %+v", model.StartDate, err)
}
Expand All @@ -159,7 +160,7 @@ func buildScheduleRequest(model *PrivilegedAccessGroupScheduleModel, metadata *s

switch {
case model.ExpirationDate != "":
expiryDate, err := time.Parse(time.RFC3339, model.ExpirationDate)
expiryDate, err = time.Parse(time.RFC3339, model.ExpirationDate)
if err != nil {
return nil, fmt.Errorf("parsing %s: %+v", model.ExpirationDate, err)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ package policies

import (
"context"
"errors"
"fmt"

"github.com/hashicorp/go-azure-helpers/lang/pointer"
"github.com/hashicorp/go-azure-sdk/sdk/odata"
"github.com/hashicorp/terraform-plugin-sdk/v2/helper/schema"
"github.com/hashicorp/terraform-provider-azuread/internal/sdk"
Expand Down Expand Up @@ -89,7 +89,7 @@ func (r GroupRoleManagementPolicyDataSource) Read() sdk.ResourceFunc {
roleID := metadata.ResourceData.Get("role_id").(string)
id, err := getPolicyId(ctx, metadata, groupID, roleID)
if err != nil {
return errors.New("Bad API response")
return fmt.Errorf("determining Policy ID: %+v", err)
}

result, _, err := clientPolicy.Get(ctx, id.ID())
Expand All @@ -106,15 +106,18 @@ func (r GroupRoleManagementPolicyDataSource) Read() sdk.ResourceFunc {
if err != nil {
return fmt.Errorf("retrieving %s: %+v", id, err)
}
if assignments == nil {
return fmt.Errorf("retrieving %s: expected 1 assignment, got nil result", id)
}
if len(*assignments) != 1 {
return fmt.Errorf("retrieving %s: expected 1 assignment, got %d", id, len(*assignments))
}

state := GroupRoleManagementPolicyDataSourceModel{
Description: *result.Description,
DisplayName: *result.DisplayName,
GroupId: *result.ScopeId,
RoleId: *(*assignments)[0].RoleDefinitionId,
Description: pointer.From(result.Description),
DisplayName: pointer.From(result.DisplayName),
GroupId: pointer.From(result.ScopeId),
RoleId: pointer.From((*assignments)[0].RoleDefinitionId),
}

metadata.ResourceData.SetId(id.ID())
Expand Down
Loading
Loading