diff --git a/internal/compile/manager.go b/internal/compile/manager.go index fccbfc884..72e1fbfb8 100644 --- a/internal/compile/manager.go +++ b/internal/compile/manager.go @@ -5,6 +5,7 @@ package compile import ( "context" + "errors" "fmt" "time" @@ -302,3 +303,7 @@ func (pce PolicyCompilationErr) Error() string { func (pce PolicyCompilationErr) Unwrap() error { return pce.underlying } + +func (pce PolicyCompilationErr) Is(target error) bool { + return errors.As(target, &PolicyCompilationErr{}) +} diff --git a/internal/engine/engine.go b/internal/engine/engine.go index f81ad9de8..eb16ff4d8 100644 --- a/internal/engine/engine.go +++ b/internal/engine/engine.go @@ -182,6 +182,115 @@ func (engine *Engine) submitWork(ctx context.Context, work workIn) error { } } +func (engine *Engine) PlanResources(ctx context.Context, input *enginev1.PlanResourcesInput) (*enginev1.PlanResourcesOutput, error) { + output, err := measurePlanLatency(func() (output *enginev1.PlanResourcesOutput, err error) { + ctx, span := tracing.StartSpan(ctx, "engine.Plan") + defer span.End() + + output, err = engine.doPlanResources(ctx, input) + if err != nil { + tracing.MarkFailed(span, http.StatusBadRequest, err) + } + + return output, err + }) + + return engine.logPlanDecision(ctx, input, output, err) +} + +func (engine *Engine) doPlanResources(ctx context.Context, input *enginev1.PlanResourcesInput) (*enginev1.PlanResourcesOutput, error) { + // exit early if the context is cancelled + if err := ctx.Err(); err != nil { + return nil, err + } + + // get the principal policy check + ppName, ppVersion, ppScope := engine.policyAttr(input.Principal.Id, input.Principal.PolicyVersion, input.Principal.Scope) + policySet, err := engine.getPrincipalPolicySet(ctx, ppName, ppVersion, ppScope) + if err != nil { + return nil, fmt.Errorf("failed to get check for [%s.%s]: %w", ppName, ppVersion, err) + } + + result := new(planner.PolicyPlanResult) + + if policy := policySet.GetPrincipalPolicy(); policy != nil { + policyEvaluator := planner.PrincipalPolicyEvaluator{Policy: policy} + result, err = policyEvaluator.EvaluateResourcesQueryPlan(ctx, input) + if err != nil { + return nil, err + } + } + + // get the resource policy check + rpName, rpVersion, rpScope := engine.policyAttr(input.Resource.Kind, input.Resource.PolicyVersion, input.Resource.Scope) + policySet, err = engine.getResourcePolicySet(ctx, rpName, rpVersion, rpScope) + if err != nil { + return nil, fmt.Errorf("failed to get check for [%s.%s]: %w", rpName, rpVersion, err) + } + + if policy := policySet.GetResourcePolicy(); policy != nil { + policyEvaluator := planner.ResourcePolicyEvaluator{Policy: policy, SchemaMgr: engine.schemaMgr} + plan, err := policyEvaluator.EvaluateResourcesQueryPlan(ctx, input) + if err != nil { + return nil, err + } + + result = planner.CombinePlans(result, plan) + } + + output, err := result.ToPlanResourcesOutput(input) + if err != nil { + return nil, err + } + + if result.Empty() { + output.FilterDebug = noPolicyMatch + } + + return output, nil +} + +func (engine *Engine) logPlanDecision(ctx context.Context, input *enginev1.PlanResourcesInput, output *enginev1.PlanResourcesOutput, planErr error) (*enginev1.PlanResourcesOutput, error) { + if err := engine.auditLog.WriteDecisionLogEntry(ctx, func() (*auditv1.DecisionLogEntry, error) { + callID, ok := audit.CallIDFromContext(ctx) + if !ok { + var err error + callID, err = audit.NewID() + if err != nil { + return nil, err + } + } + + planRes := &auditv1.DecisionLogEntry_PlanResources{ + Input: input, + Output: output, + } + + if planErr != nil { + planRes.Error = planErr.Error() + } + + entry := &auditv1.DecisionLogEntry{ + CallId: string(callID), + Timestamp: timestamppb.New(time.Now()), + Peer: audit.PeerFromContext(ctx), + Method: &auditv1.DecisionLogEntry_PlanResources_{ + PlanResources: planRes, + }, + } + + if engine.metadataExtractor != nil { + entry.Metadata = engine.metadataExtractor(ctx) + } + + return entry, nil + }); err != nil { + logging.FromContext(ctx).Warn("Failed to log decision", zap.Error(err)) + } + + return output, planErr +} + func (engine *Engine) Check(ctx context.Context, inputs []*enginev1.CheckInput, opts ...CheckOpt) ([]*enginev1.CheckOutput, error) { outputs, err := measureCheckLatency(len(inputs), func() (outputs []*enginev1.CheckOutput, err error) { ctx, span := tracing.StartSpan(ctx, "engine.Check") @@ -298,115 +407,6 @@ func (engine *Engine) checkParallel(ctx context.Context, inputs []*enginev1.Chec return outputs, nil } -func (engine *Engine) PlanResources(ctx context.Context, input *enginev1.PlanResourcesInput) (*enginev1.PlanResourcesOutput, error) { - output, err := measurePlanLatency(func() (output *enginev1.PlanResourcesOutput, err error) { - ctx, span := tracing.StartSpan(ctx, "engine.Plan") - defer span.End() - - output, err = engine.doPlanResources(ctx, input) - if err != nil { - tracing.MarkFailed(span, http.StatusBadRequest, err) - } - - return output, err - }) - - return engine.logPlanDecision(ctx, input, output, err) -} - -func (engine *Engine) doPlanResources(ctx context.Context, input *enginev1.PlanResourcesInput) (*enginev1.PlanResourcesOutput, error) { - // exit early if the context is cancelled - if err := ctx.Err(); err != nil { - return nil, err - } - - // get the principal policy check - ppName, ppVersion, ppScope := engine.policyAttr(input.Principal.Id, input.Principal.PolicyVersion, input.Principal.Scope) - policySet, err := engine.getPrincipalPolicySet(ctx, ppName, ppVersion, ppScope) - if err != nil { - return nil, fmt.Errorf("failed to get check for [%s.%s]: %w", ppName, ppVersion, err) - } - - result := new(planner.PolicyPlanResult) - - if policy := policySet.GetPrincipalPolicy(); policy != nil { - policyEvaluator := planner.PrincipalPolicyEvaluator{Policy: policy} - result, err = policyEvaluator.EvaluateResourcesQueryPlan(ctx, input) - if err != nil { - return nil, err - } - } - - // get the resource policy check - rpName, rpVersion, rpScope := engine.policyAttr(input.Resource.Kind, input.Resource.PolicyVersion, input.Resource.Scope) - policySet, err = engine.getResourcePolicySet(ctx, rpName, rpVersion, rpScope) - if err != nil { - return nil, fmt.Errorf("failed to get check for [%s.%s]: %w", rpName, rpVersion, err) - } - - if policy := policySet.GetResourcePolicy(); policy != nil { - policyEvaluator := planner.ResourcePolicyEvaluator{Policy: policy, SchemaMgr: engine.schemaMgr} - plan, err := policyEvaluator.EvaluateResourcesQueryPlan(ctx, input) - if err != nil { - return nil, err - } - - result = planner.CombinePlans(result, plan) - } - - output, err := result.ToPlanResourcesOutput(input) - if err != nil { - return nil, err - } - - if result.Empty() { - output.FilterDebug = noPolicyMatch - } - - return output, nil -} - -func (engine *Engine) logPlanDecision(ctx context.Context, input *enginev1.PlanResourcesInput, output *enginev1.PlanResourcesOutput, planErr error) (*enginev1.PlanResourcesOutput, error) { - if err := engine.auditLog.WriteDecisionLogEntry(ctx, func() (*auditv1.DecisionLogEntry, error) { - callID, ok := audit.CallIDFromContext(ctx) - if !ok { - var err error - callID, err = audit.NewID() - if err != nil { - return nil, err - } - } - - planRes := &auditv1.DecisionLogEntry_PlanResources{ - Input: input, - Output: output, - } - - if planErr != nil { - planRes.Error = planErr.Error() - } - - entry := &auditv1.DecisionLogEntry{ - CallId: string(callID), - Timestamp: timestamppb.New(time.Now()), - Peer: audit.PeerFromContext(ctx), - Method: &auditv1.DecisionLogEntry_PlanResources_{ - PlanResources: planRes, - }, - } - - if engine.metadataExtractor != nil { - entry.Metadata = engine.metadataExtractor(ctx) - } - - return entry, nil - }); err != nil { - logging.FromContext(ctx).Warn("Failed to log decision", zap.Error(err)) - } - - return output, planErr -} - func (engine *Engine) evaluate(ctx context.Context, input *enginev1.CheckInput, checkOpts *checkOptions) (*enginev1.CheckOutput, error) { ctx, span := tracing.StartSpan(ctx, "engine.Evaluate") defer span.End() @@ -482,6 +482,19 @@ func (engine *Engine) buildEvaluationCtx(ctx context.Context, eparams evalParams return ec, nil } +func (engine *Engine) getPrincipalPolicyEvaluator(ctx context.Context, eparams evalParams, principal, policyVer, scope string) (Evaluator, error) { + rps, err := engine.getPrincipalPolicySet(ctx, principal, policyVer, scope) + if err != nil { + return nil, err + } + + if rps == nil { + return nil, nil + } + + return NewEvaluator(rps, engine.schemaMgr, eparams), nil +} + func (engine *Engine) getPrincipalPolicySet(ctx context.Context, principal, policyVer, scope string) (*runtimev1.RunnablePolicySet, error) { ctx, span := tracing.StartSpan(ctx, "engine.GetPrincipalPolicy") defer span.End() @@ -497,8 +510,8 @@ func (engine *Engine) getPrincipalPolicySet(ctx context.Context, principal, poli return rps, nil } -func (engine *Engine) getPrincipalPolicyEvaluator(ctx context.Context, eparams evalParams, principal, policyVer, scope string) (Evaluator, error) { - rps, err := engine.getPrincipalPolicySet(ctx, principal, policyVer, scope) +func (engine *Engine) getResourcePolicyEvaluator(ctx context.Context, eparams evalParams, resource, policyVer, scope string) (Evaluator, error) { + rps, err := engine.getResourcePolicySet(ctx, resource, policyVer, scope) if err != nil { return nil, err } @@ -525,19 +538,6 @@ func (engine *Engine) getResourcePolicySet(ctx context.Context, resource, policy return rps, nil } -func (engine *Engine) getResourcePolicyEvaluator(ctx context.Context, eparams evalParams, resource, policyVer, scope string) (Evaluator, error) { - rps, err := engine.getResourcePolicySet(ctx, resource, policyVer, scope) - if err != nil { - return nil, err - } - - if rps == nil { - return nil, nil - } - - return NewEvaluator(rps, engine.schemaMgr, eparams), nil -} - func (engine *Engine) policyAttr(name, version, scope string) (pName, pVersion, pScope string) { pName = name pVersion = version diff --git a/internal/svc/cerbos_svc.go b/internal/svc/cerbos_svc.go index fed5f3f8d..d42a4eb26 100644 --- a/internal/svc/cerbos_svc.go +++ b/internal/svc/cerbos_svc.go @@ -68,6 +68,9 @@ func (cs *CerbosService) PlanResources(ctx context.Context, request *requestv1.P output, err := cs.eng.PlanResources(logging.ToContext(ctx, log), input) if err != nil { log.Error("Resources query plan request failed", zap.Error(err)) + if errors.Is(err, compile.PolicyCompilationErr{}) { + return nil, status.Errorf(codes.FailedPrecondition, "Resources query plan failed due to invalid policy") + } return nil, status.Errorf(codes.Internal, "Resources query plan request failed") } @@ -135,7 +138,7 @@ func (cs *CerbosService) CheckResourceSet(ctx context.Context, req *requestv1.Ch outputs, err := cs.eng.Check(logging.ToContext(ctx, log), inputs) if err != nil { log.Error("Policy check failed", zap.Error(err)) - if errors.As(err, &compile.PolicyCompilationErr{}) { + if errors.Is(err, compile.PolicyCompilationErr{}) { return nil, status.Errorf(codes.FailedPrecondition, "Check failed due to invalid policy") } return nil, status.Errorf(codes.Internal, "Policy check failed") @@ -183,6 +186,9 @@ func (cs *CerbosService) CheckResourceBatch(ctx context.Context, req *requestv1. outputs, err := cs.eng.Check(logging.ToContext(ctx, log), inputs) if err != nil { log.Error("Policy check failed", zap.Error(err)) + if errors.Is(err, compile.PolicyCompilationErr{}) { + return nil, status.Errorf(codes.FailedPrecondition, "Check failed due to invalid policy") + } return nil, status.Errorf(codes.Internal, "Policy check failed") } @@ -240,6 +246,9 @@ func (cs *CerbosService) CheckResources(ctx context.Context, req *requestv1.Chec outputs, err := cs.eng.Check(logging.ToContext(ctx, log), inputs) if err != nil { log.Error("Policy check failed", zap.Error(err)) + if errors.Is(err, compile.PolicyCompilationErr{}) { + return nil, status.Errorf(codes.FailedPrecondition, "Check failed due to invalid policy") + } return nil, status.Errorf(codes.Internal, "Policy check failed") }