diff --git a/pkg/core/policy/policy.go b/pkg/core/policy/policy.go index 6a991aaa4..a5d5dadc0 100644 --- a/pkg/core/policy/policy.go +++ b/pkg/core/policy/policy.go @@ -5,5 +5,5 @@ type Provider interface { } type Checker interface { - Check(organisation string, namespace string, projectname string, input interface{}) (bool, error) + Check(organisation string, namespace string, projectname string, command string, requestedBy string) (bool, error) } diff --git a/pkg/digger/digger.go b/pkg/digger/digger.go index 55fba5dc9..b8ab4d00d 100644 --- a/pkg/digger/digger.go +++ b/pkg/digger/digger.go @@ -76,24 +76,12 @@ func RunCommandsPerProject( plansToPublish := make([]string, 0) organisation := strings.Split(projectNamespace, "/")[0] - teams, err := ciService.GetUserTeams(organisation, requestedBy) - if err != nil { - fmt.Printf("Error while fetching user teams for CI service: %v", err) - } for _, projectCommands := range commandsPerProject { for _, command := range projectCommands.Commands { fmt.Printf("Running '%s' for project '%s'\n", command, projectCommands.ProjectName) - policyInput := map[string]interface{}{ - "user": requestedBy, - "organisation": organisation, - "teams": teams, - "action": command, - "project": projectCommands.ProjectName, - } - - allowedToPerformCommand, err := policyChecker.Check(organisation, projectNamespace, projectCommands.ProjectName, policyInput) + allowedToPerformCommand, err := policyChecker.Check(organisation, projectNamespace, projectCommands.ProjectName, command, requestedBy) if err != nil { return false, false, fmt.Errorf("error checking policy: %v", err) diff --git a/pkg/policy/policy.go b/pkg/policy/policy.go index 0d97503ec..26b5fb23b 100644 --- a/pkg/policy/policy.go +++ b/pkg/policy/policy.go @@ -2,6 +2,7 @@ package policy import ( "context" + "digger/pkg/ci" "errors" "fmt" "github.com/open-policy-agent/opa/rego" @@ -23,7 +24,7 @@ type DiggerHttpPolicyProvider struct { type NoOpPolicyChecker struct { } -func (p NoOpPolicyChecker) Check(_ string, _ string, _ string, _ interface{}) (bool, error) { +func (p NoOpPolicyChecker) Check(_ string, _ string, _ string, _ string, _ string) (bool, error) { return true, nil } @@ -100,15 +101,26 @@ func (p *DiggerHttpPolicyProvider) GetPolicy(organisation string, namespace stri type DiggerPolicyChecker struct { PolicyProvider PolicyProvider + ciService ci.CIService } -func (p DiggerPolicyChecker) Check(organisation string, namespace string, projectName string, input interface{}) (bool, error) { +func (p DiggerPolicyChecker) Check(organisation string, namespace string, projectName string, command string, requestedBy string) (bool, error) { policy, err := p.PolicyProvider.GetPolicy(organisation, namespace, projectName) + teams, err := p.ciService.GetUserTeams(organisation, requestedBy) if err != nil { + fmt.Printf("Error while fetching user teams for CI service: %v", err) return false, err } + input := map[string]interface{}{ + "user": requestedBy, + "organisation": organisation, + "teams": teams, + "action": command, + "project": projectName, + } + if policy == "" { return true, nil } diff --git a/pkg/policy/policy_test.go b/pkg/policy/policy_test.go index 501963f3a..3079b2754 100644 --- a/pkg/policy/policy_test.go +++ b/pkg/policy/policy_test.go @@ -1,6 +1,7 @@ package policy import ( + "digger/pkg/utils" "testing" ) @@ -95,87 +96,58 @@ func TestDiggerPolicyChecker_Check(t *testing.T) { args args want bool wantErr bool + command string + requestedBy string }{ - { - name: "test opa example", - fields: fields{ - PolicyProvider: &OpaExamplePolicyProvider{}, - }, - args: args{ - input: map[string]interface{}{ - "organisation": "diggerhq", - "user": "alice", - "action": "read", - "object": "server123", - }, - }, - want: true, - wantErr: false, - }, { name: "test digger example", fields: fields{ PolicyProvider: &DiggerExamplePolicyProvider{}, }, - args: args{ - input: map[string]interface{}{ - "user": "motatoes", - "action": "digger plan", - }, - }, - want: true, - wantErr: false, + want: true, + wantErr: false, + command: "digger plan", + requestedBy: "motatoes", }, { name: "test digger example 2", fields: fields{ PolicyProvider: &DiggerExamplePolicyProvider{}, }, - args: args{ - input: map[string]interface{}{ - "user": "Spartakovic", - "action": "digger unlock", - }, - }, - want: false, - wantErr: false, + want: false, + wantErr: false, + command: "digger unlock", + requestedBy: "Spartakovic", }, { name: "test digger example 3", fields: fields{ PolicyProvider: &DiggerExamplePolicyProvider{}, }, - args: args{ - input: map[string]interface{}{ - "user": "rando", - "action": "digger apply", - }, - }, - want: false, - wantErr: false, + want: false, + wantErr: false, + command: "digger apply", + requestedBy: "rando", }, { name: "test digger example 4", fields: fields{ PolicyProvider: &DiggerExamplePolicyProvider2{}, }, - args: args{ - input: map[string]interface{}{ - "user": "motatoes", - "action": "digger plan", - }, - }, - want: true, - wantErr: false, + want: true, + wantErr: false, + command: "digger plan", + requestedBy: "motatoes", }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - p := &DiggerPolicyChecker{ + var p = &DiggerPolicyChecker{ PolicyProvider: tt.fields.PolicyProvider, + ciService: utils.MockPullRequestManager{Teams: []string{"engineering"}}, } - got, err := p.Check(tt.organisation, tt.name, tt.name, tt.args.input) + got, err := p.Check(tt.organisation, tt.name, tt.name, tt.command, tt.requestedBy) if (err != nil) != tt.wantErr { t.Errorf("DiggerPolicyChecker.Check() error = %v, wantErr %v", err, tt.wantErr) return diff --git a/pkg/utils/mocks.go b/pkg/utils/mocks.go index fd8b34e3c..50366fd70 100644 --- a/pkg/utils/mocks.go +++ b/pkg/utils/mocks.go @@ -42,16 +42,17 @@ func (lock *MockLock) GetLock(resource string) (*int, error) { type MockPolicyChecker struct { } -func (t MockPolicyChecker) Check(organisation string, namespace string, projectname string, input interface{}) (bool, error) { +func (t MockPolicyChecker) Check(organisation string, namespace string, projectname string, command string, requestedBy string) (bool, error) { return false, nil } type MockPullRequestManager struct { ChangedFiles []string + Teams []string } func (t MockPullRequestManager) GetUserTeams(organisation string, user string) ([]string, error) { - return []string{}, nil + return t.Teams, nil } func (t MockPullRequestManager) GetChangedFiles(prNumber int) ([]string, error) {