Skip to content

Commit

Permalink
Merge pull request #7139 from evanchaoli/soft-policy
Browse files Browse the repository at this point in the history
Support soft policy enforcement
  • Loading branch information
chenbh committed Jun 24, 2021
2 parents d7346bf + 945343e commit 5558b97
Show file tree
Hide file tree
Showing 28 changed files with 1,019 additions and 286 deletions.
8 changes: 4 additions & 4 deletions atc/api/policychecker/checker.go
Expand Up @@ -16,7 +16,7 @@ import (

//counterfeiter:generate . PolicyChecker
type PolicyChecker interface {
Check(string, accessor.Access, *http.Request) (policy.PolicyCheckOutput, error)
Check(string, accessor.Access, *http.Request) (policy.PolicyCheckResult, error)
}

type checker struct {
Expand All @@ -27,7 +27,7 @@ func NewApiPolicyChecker(policyChecker policy.Checker) PolicyChecker {
return &checker{policyChecker: policyChecker}
}

func (c *checker) Check(action string, acc accessor.Access, req *http.Request) (policy.PolicyCheckOutput, error) {
func (c *checker) Check(action string, acc accessor.Access, req *http.Request) (policy.PolicyCheckResult, error) {
// Ignore self invoked API calls.
if acc.IsSystem() {
return policy.PassedPolicyCheck(), nil
Expand Down Expand Up @@ -59,15 +59,15 @@ func (c *checker) Check(action string, acc accessor.Access, req *http.Request) (
case "application/json", "text/vnd.yaml", "text/yaml", "text/x-yaml", "application/x-yaml":
body, err := ioutil.ReadAll(req.Body)
if err != nil {
return policy.FailedPolicyCheck(), err
return nil, err
} else if len(body) > 0 {
if ct == "application/json" {
err = json.Unmarshal(body, &input.Data)
} else {
err = yaml.Unmarshal(body, &input.Data)
}
if err != nil {
return policy.FailedPolicyCheck(), err
return nil, err
}

req.Body = ioutil.NopCloser(bytes.NewBuffer(body))
Expand Down
46 changes: 26 additions & 20 deletions atc/api/policychecker/checker_test.go
Expand Up @@ -19,16 +19,19 @@ import (

var _ = Describe("PolicyChecker", func() {
var (
policyFilter policy.Filter
fakeAccess *accessorfakes.FakeAccess
fakeRequest *http.Request
result policy.PolicyCheckOutput
checkErr error
policyFilter policy.Filter
fakeAccess *accessorfakes.FakeAccess
fakeRequest *http.Request
result policy.PolicyCheckResult
checkErr error
fakePolicyCheckResult *policyfakes.FakePolicyCheckResult
)

BeforeEach(func() {
fakeAccess = new(accessorfakes.FakeAccess)
fakePolicyCheckResult = new(policyfakes.FakePolicyCheckResult)
fakePolicyAgent = new(policyfakes.FakeAgent)
fakePolicyAgent.CheckReturns(fakePolicyCheckResult, nil)
fakePolicyAgentFactory.NewAgentReturns(fakePolicyAgent, nil)

policyFilter = policy.Filter{
Expand All @@ -51,7 +54,7 @@ var _ = Describe("PolicyChecker", func() {
})
It("should pass", func() {
Expect(checkErr).ToNot(HaveOccurred())
Expect(result.Allowed).To(BeTrue())
Expect(result.Allowed()).To(BeTrue())
})
It("Agent should not be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
Expand All @@ -69,7 +72,7 @@ var _ = Describe("PolicyChecker", func() {
})
It("should pass", func() {
Expect(checkErr).ToNot(HaveOccurred())
Expect(result.Allowed).To(BeTrue())
Expect(result.Allowed()).To(BeTrue())
})
It("Agent should not be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
Expand All @@ -83,7 +86,7 @@ var _ = Describe("PolicyChecker", func() {
})
It("should pass", func() {
Expect(checkErr).ToNot(HaveOccurred())
Expect(result.Allowed).To(BeTrue())
Expect(result.Allowed()).To(BeTrue())
})
It("Agent should not be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
Expand All @@ -98,7 +101,7 @@ var _ = Describe("PolicyChecker", func() {
})
It("should pass", func() {
Expect(checkErr).ToNot(HaveOccurred())
Expect(result.Allowed).To(BeTrue())
Expect(result.Allowed()).To(BeTrue())
})
It("Agent should not be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
Expand All @@ -121,7 +124,7 @@ var _ = Describe("PolicyChecker", func() {
It("should error", func() {
Expect(checkErr).To(HaveOccurred())
Expect(checkErr.Error()).To(Equal(`invalid character 'h' looking for beginning of value`))
Expect(result.Allowed).To(BeFalse())
Expect(result).To(BeNil())
})
It("Agent should not be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
Expand All @@ -138,8 +141,9 @@ var _ = Describe("PolicyChecker", func() {
It("should error", func() {
Expect(checkErr).To(HaveOccurred())
Expect(checkErr.Error()).To(Equal(`error converting YAML to JSON: yaml: line 3: could not find expected ':'`))
Expect(result.Allowed).To(BeFalse())
Expect(result).To(BeNil())
})

It("Agent should not be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(0))
})
Expand All @@ -160,9 +164,11 @@ var _ = Describe("PolicyChecker", func() {
It("should not error", func() {
Expect(checkErr).ToNot(HaveOccurred())
})

It("Agent should be called", func() {
Expect(fakePolicyAgent.CheckCallCount()).To(Equal(1))
})

It("Agent should take correct input", func() {
Expect(fakePolicyAgent.CheckArgsForCall(0)).To(Equal(policy.PolicyCheckInput{
Service: "concourse",
Expand Down Expand Up @@ -191,34 +197,34 @@ var _ = Describe("PolicyChecker", func() {

It("it should pass", func() {
Expect(checkErr).ToNot(HaveOccurred())
Expect(result.Allowed).To(BeTrue())
Expect(result.Allowed()).To(BeTrue())
})
})

Context("when Agent says not-pass", func() {
BeforeEach(func() {
fakePolicyAgent.CheckReturns(policy.PolicyCheckOutput{
Allowed: false,
Reasons: []string{"a policy says you can't do that"},
}, nil)
fakePolicyCheckResult.AllowedReturns(false)
fakePolicyCheckResult.ShouldBlockReturns(true)
fakePolicyCheckResult.MessagesReturns([]string{"a policy says you can't do that"})
})

It("should not pass", func() {
Expect(checkErr).ToNot(HaveOccurred())
Expect(result.Allowed).To(BeFalse())
Expect(result.Reasons).To(ConsistOf("a policy says you can't do that"))
Expect(result.Allowed()).To(BeFalse())
Expect(result.ShouldBlock()).To(BeTrue())
Expect(result.Messages()).To(ConsistOf("a policy says you can't do that"))
})
})

Context("when Agent says error", func() {
BeforeEach(func() {
fakePolicyAgent.CheckReturns(policy.FailedPolicyCheck(), errors.New("some-error"))
fakePolicyAgent.CheckReturns(nil, errors.New("some-error"))
})

It("should not pass", func() {
Expect(checkErr).To(HaveOccurred())
Expect(checkErr.Error()).To(Equal("some-error"))
Expect(result.Allowed).To(BeFalse())
Expect(result).To(BeNil())
})
})
})
Expand Down
14 changes: 9 additions & 5 deletions atc/api/policychecker/handler.go
Expand Up @@ -41,13 +41,17 @@ func (h policyCheckingHandler) ServeHTTP(w http.ResponseWriter, r *http.Request)
return
}

if !result.Allowed {
w.WriteHeader(http.StatusForbidden)
if !result.Allowed() {
policyCheckErr := policy.PolicyCheckNotPass{
Reasons: result.Reasons,
Messages: result.Messages(),
}
if result.ShouldBlock() {
w.WriteHeader(http.StatusForbidden)
fmt.Fprint(w, policyCheckErr.Error())
return
} else {
w.Header().Add("X-Concourse-Policy-Check-Warning", policyCheckErr.Error())
}
fmt.Fprint(w, policyCheckErr.Error())
return
}

h.handler.ServeHTTP(w, r)
Expand Down
61 changes: 43 additions & 18 deletions atc/api/policychecker/handler_test.go
Expand Up @@ -11,25 +11,28 @@ import (
"github.com/concourse/concourse/atc/api/policychecker"
"github.com/concourse/concourse/atc/api/policychecker/policycheckerfakes"
"github.com/concourse/concourse/atc/policy"
"github.com/concourse/concourse/atc/policy/policyfakes"

. "github.com/onsi/ginkgo"
. "github.com/onsi/gomega"
)

var _ = Describe("Handler", func() {
var (
innerHandlerCalled bool
dummyHandler http.HandlerFunc
policyCheckerHandler http.Handler
req *http.Request
fakePolicyChecker *policycheckerfakes.FakePolicyChecker
responseWriter *httptest.ResponseRecorder
innerHandlerCalled bool
dummyHandler http.HandlerFunc
policyCheckerHandler http.Handler
req *http.Request
fakePolicyChecker *policycheckerfakes.FakePolicyChecker
fakePolicyCheckResult *policyfakes.FakePolicyCheckResult
responseWriter *httptest.ResponseRecorder

logger = lagertest.NewTestLogger("test")
)

BeforeEach(func() {
fakePolicyChecker = new(policycheckerfakes.FakePolicyChecker)
fakePolicyCheckResult = new(policyfakes.FakePolicyCheckResult)

innerHandlerCalled = false
dummyHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
Expand Down Expand Up @@ -63,28 +66,50 @@ var _ = Describe("Handler", func() {

Context("policy check doesn't pass", func() {
BeforeEach(func() {
fakePolicyChecker.CheckReturns(policy.PolicyCheckOutput{
Allowed: false,
Reasons: []string{"a policy says you can't do that", "another policy also says you can't do that"},
}, nil)
fakePolicyCheckResult.AllowedReturns(false)
fakePolicyCheckResult.MessagesReturns([]string{"a policy says you can't do that", "another policy also says you can't do that"})
fakePolicyChecker.CheckReturns(fakePolicyCheckResult, nil)
})

It("return http forbidden", func() {
Expect(responseWriter.Code).To(Equal(http.StatusForbidden))
Context("when should block", func() {
BeforeEach(func() {
fakePolicyCheckResult.ShouldBlockReturns(true)
})

msg, err := ioutil.ReadAll(responseWriter.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(msg)).To(Equal("policy check failed: a policy says you can't do that, another policy also says you can't do that"))
It("return http forbidden", func() {
Expect(responseWriter.Code).To(Equal(http.StatusForbidden))

msg, err := ioutil.ReadAll(responseWriter.Body)
Expect(err).ToNot(HaveOccurred())
Expect(string(msg)).To(ContainSubstring("a policy says you can't do that"))
Expect(string(msg)).To(ContainSubstring("another policy also says you can't do that"))
})

It("not call the inner handler", func() {
Expect(innerHandlerCalled).To(BeFalse())
})
})

It("not call the inner handler", func() {
Expect(innerHandlerCalled).To(BeFalse())
Context("when should not block", func() {
BeforeEach(func() {
fakePolicyCheckResult.ShouldBlockReturns(false)
})

It("calls the inner handler", func() {
Expect(innerHandlerCalled).To(BeTrue())
})

It("response should have a header about policy check warning", func() {
value := responseWriter.Header().Get("X-Concourse-Policy-Check-Warning")
Expect(value).To(ContainSubstring("a policy says you can't do that"))
Expect(value).To(ContainSubstring("another policy also says you can't do that"))
})
})
})

Context("policy check errors", func() {
BeforeEach(func() {
fakePolicyChecker.CheckReturns(policy.FailedPolicyCheck(), errors.New("some-error"))
fakePolicyChecker.CheckReturns(nil, errors.New("some-error"))
})

It("return http bad request", func() {
Expand Down
20 changes: 10 additions & 10 deletions atc/api/policychecker/policycheckerfakes/fake_policy_checker.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 5558b97

Please sign in to comment.