From 55f2f2cb079b326978632aa8a62a345cd632db87 Mon Sep 17 00:00:00 2001
From: Erik Godding Boye <egboye@gmail.com>
Date: Fri, 22 Sep 2023 21:38:32 +0200
Subject: [PATCH] Refactor Evaluate preparing for CEL

Signed-off-by: Erik Godding Boye <egboye@gmail.com>
---
 pkg/internal/approver/allowed/evaluator.go | 285 +++++++++++----------
 1 file changed, 149 insertions(+), 136 deletions(-)

diff --git a/pkg/internal/approver/allowed/evaluator.go b/pkg/internal/approver/allowed/evaluator.go
index 14bd78ce..f3207f15 100644
--- a/pkg/internal/approver/allowed/evaluator.go
+++ b/pkg/internal/approver/allowed/evaluator.go
@@ -18,6 +18,8 @@ package allowed
 
 import (
 	"context"
+	"crypto/x509"
+	"crypto/x509/pkix"
 	"strconv"
 	"strings"
 
@@ -55,178 +57,189 @@ func (a allowed) Evaluate(_ context.Context, policy *policyapi.CertificateReques
 		return approver.EvaluationResponse{}, err
 	}
 
-	if len(csr.Subject.CommonName) > 0 {
-		if allowed.CommonName == nil || allowed.CommonName.Value == nil {
-			el = append(el, field.Invalid(fldPath.Child("commonName", "value"), csr.Subject.CommonName, "nil"))
-		} else if !util.WildcardMatches(*allowed.CommonName.Value, csr.Subject.CommonName) {
-			el = append(el, field.Invalid(fldPath.Child("commonName", "value"), csr.Subject.CommonName, *allowed.CommonName.Value))
-		}
-	} else if allowed.CommonName != nil && allowed.CommonName.Required != nil && *allowed.CommonName.Required {
-		el = append(el, field.Required(fldPath.Child("commonName", "required"), strconv.FormatBool(*allowed.CommonName.Required)))
+	evaluate := evaluator{
+		request: request,
+		csr:     csr,
+		allowed: allowed,
+		fldPath: fldPath,
 	}
+	evaluateSubject := evaluate.Subject()
 
-	if len(csr.DNSNames) > 0 {
-		if allowed.DNSNames == nil || allowed.DNSNames.Values == nil {
-			el = append(el, field.Invalid(fldPath.Child("dnsNames", "values"), csr.DNSNames, "nil"))
-		} else if !util.WildcardSubset(*allowed.DNSNames.Values, csr.DNSNames) {
-			el = append(el, field.Invalid(fldPath.Child("dnsNames", "values"), csr.DNSNames, strings.Join(*allowed.DNSNames.Values, ", ")))
-		}
-	} else if allowed.DNSNames != nil && allowed.DNSNames.Required != nil && *allowed.DNSNames.Required {
-		el = append(el, field.Required(fldPath.Child("dnsNames", "required"), strconv.FormatBool(*allowed.DNSNames.Required)))
+	evaluateFns := []func() *field.Error{
+		evaluate.CommonName,
+		evaluate.DNSNames,
+		evaluate.IPAddresses,
+		evaluate.URIs,
+		evaluate.EmailAddresses,
+		evaluate.IsCA,
+		evaluate.Usages,
+		evaluateSubject.Organization,
+		evaluateSubject.Country,
+		evaluateSubject.OrganizationalUnit,
+		evaluateSubject.Locality,
+		evaluateSubject.Province,
+		evaluateSubject.StreetAddress,
+		evaluateSubject.PostalCode,
+		evaluateSubject.SerialNumber,
 	}
-
-	if len(csr.IPAddresses) > 0 {
-		var ips []string
-		for _, ip := range csr.IPAddresses {
-			ips = append(ips, ip.String())
+	for _, fn := range evaluateFns {
+		if e := fn(); e != nil {
+			el = append(el, e)
 		}
-		if allowed.IPAddresses == nil || allowed.IPAddresses.Values == nil {
-			el = append(el, field.Invalid(fldPath.Child("ipAddresses", "values"), ips, "nil"))
-		} else if !util.WildcardSubset(*allowed.IPAddresses.Values, ips) {
-			el = append(el, field.Invalid(fldPath.Child("ipAddresses", "values"), ips, strings.Join(*allowed.IPAddresses.Values, ", ")))
-		}
-	} else if allowed.IPAddresses != nil && allowed.IPAddresses.Required != nil && *allowed.IPAddresses.Required {
-		el = append(el, field.Required(fldPath.Child("ipAddresses", "required"), strconv.FormatBool(*allowed.IPAddresses.Required)))
 	}
 
-	if len(csr.URIs) > 0 {
-		var uris []string
-		for _, uri := range csr.URIs {
-			uris = append(uris, uri.String())
-		}
-		if allowed.URIs == nil || allowed.URIs.Values == nil {
-			el = append(el, field.Invalid(fldPath.Child("uris", "values"), uris, "nil"))
-		} else if !util.WildcardSubset(*allowed.URIs.Values, uris) {
-			el = append(el, field.Invalid(fldPath.Child("uris", "values"), uris, strings.Join(*allowed.URIs.Values, ", ")))
-		}
-	} else if allowed.URIs != nil && allowed.URIs.Required != nil && *allowed.URIs.Required {
-		el = append(el, field.Required(fldPath.Child("uris", "required"), strconv.FormatBool(*allowed.URIs.Required)))
+	// If there are errors, then return not approved and the aggregated errors
+	if len(el) > 0 {
+		return approver.EvaluationResponse{Result: approver.ResultDenied, Message: el.ToAggregate().Error()}, nil
 	}
 
-	if len(csr.EmailAddresses) > 0 {
-		if allowed.EmailAddresses == nil || allowed.EmailAddresses.Values == nil {
-			el = append(el, field.Invalid(fldPath.Child("emailAddresses", "values"), csr.EmailAddresses, "nil"))
-		} else if !util.WildcardSubset(*allowed.EmailAddresses.Values, csr.EmailAddresses) {
-			el = append(el, field.Invalid(fldPath.Child("emailAddresses", "values"), csr.EmailAddresses, strings.Join(*allowed.EmailAddresses.Values, ", ")))
-		}
-	} else if allowed.EmailAddresses != nil && allowed.EmailAddresses.Required != nil && *allowed.EmailAddresses.Required {
-		el = append(el, field.Required(fldPath.Child("emailAddresses", "required"), strconv.FormatBool(*allowed.EmailAddresses.Required)))
+	// If no evaluation errors resulting from this policy, return not denied
+	return approver.EvaluationResponse{Result: approver.ResultNotDenied}, nil
+}
+
+type evaluator struct {
+	request *cmapi.CertificateRequest
+	csr     *x509.CertificateRequest
+	allowed *policyapi.CertificateRequestPolicyAllowed
+	fldPath *field.Path
+}
+
+func (e evaluator) CommonName() *field.Error {
+	return evaluateString(e.csr.Subject.CommonName, e.allowed.CommonName, e.fldPath.Child("commonName"))
+}
+
+func (e evaluator) DNSNames() *field.Error {
+	return evaluateSlice(e.csr.DNSNames, e.allowed.DNSNames, e.fldPath.Child("dnsNames"))
+}
+
+func (e evaluator) IPAddresses() *field.Error {
+	var ips []string
+	for _, ip := range e.csr.IPAddresses {
+		ips = append(ips, ip.String())
 	}
+	return evaluateSlice(ips, e.allowed.IPAddresses, e.fldPath.Child("ipAddresses"))
+}
 
-	if request.Spec.IsCA {
-		if allowed.IsCA == nil {
-			el = append(el, field.Invalid(fldPath.Child("isCA"), request.Spec.IsCA, "nil"))
-		} else if !*allowed.IsCA {
-			el = append(el, field.Invalid(fldPath.Child("isCA"), request.Spec.IsCA, strconv.FormatBool(*allowed.IsCA)))
-		}
+func (e evaluator) URIs() *field.Error {
+	var uris []string
+	for _, uri := range e.csr.URIs {
+		uris = append(uris, uri.String())
 	}
+	return evaluateSlice(uris, e.allowed.URIs, e.fldPath.Child("uris"))
+}
+
+func (e evaluator) EmailAddresses() *field.Error {
+	return evaluateSlice(e.csr.EmailAddresses, e.allowed.EmailAddresses, e.fldPath.Child("emailAddresses"))
+}
 
-	if len(request.Spec.Usages) > 0 {
+func (e evaluator) IsCA() *field.Error {
+	return evaluateBool(e.request.Spec.IsCA, e.allowed.IsCA, e.fldPath.Child("isCA"))
+}
+
+func (e evaluator) Usages() *field.Error {
+	if len(e.request.Spec.Usages) > 0 {
 		var requestUsages []string
-		for _, usage := range request.Spec.Usages {
+		for _, usage := range e.request.Spec.Usages {
 			requestUsages = append(requestUsages, string(usage))
 		}
-		if allowed.Usages == nil {
-			el = append(el, field.Invalid(fldPath.Child("usages"), requestUsages, "nil"))
+		if e.allowed.Usages == nil {
+			return field.Invalid(e.fldPath.Child("usages"), requestUsages, "nil")
 		} else {
 			var policyUsages []string
-			for _, usage := range *allowed.Usages {
+			for _, usage := range *e.allowed.Usages {
 				policyUsages = append(policyUsages, string(usage))
 			}
 			if !util.WildcardSubset(policyUsages, requestUsages) {
-				el = append(el, field.Invalid(fldPath.Child("usages"), requestUsages, strings.Join(policyUsages, ", ")))
+				return field.Invalid(e.fldPath.Child("usages"), requestUsages, strings.Join(policyUsages, ", "))
 			}
 		}
 	}
+	return nil
+}
 
-	fldPath = fldPath.Child("subject")
-	allowedSub := allowed.Subject
-
-	if len(csr.Subject.Organization) > 0 {
-		if allowedSub == nil || allowedSub.Organizations == nil || allowedSub.Organizations.Values == nil {
-			el = append(el, field.Invalid(fldPath.Child("organizations", "values"), csr.Subject.Organization, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.Organizations.Values, csr.Subject.Organization) {
-			el = append(el, field.Invalid(fldPath.Child("organizations", "values"), csr.Subject.Organization, strings.Join(*allowedSub.Organizations.Values, ", ")))
-		}
-	} else if allowedSub != nil && allowedSub.Organizations != nil && allowedSub.Organizations.Required != nil && *allowedSub.Organizations.Required {
-		el = append(el, field.Required(fldPath.Child("organizations", "required"), strconv.FormatBool(*allowedSub.Organizations.Required)))
+func (e evaluator) Subject() subjectEvaluator {
+	allowed := e.allowed.Subject
+	if allowed == nil {
+		allowed = new(policyapi.CertificateRequestPolicyAllowedX509Subject)
 	}
-
-	if len(csr.Subject.Country) > 0 {
-		if allowedSub == nil || allowedSub.Countries == nil {
-			el = append(el, field.Invalid(fldPath.Child("countries", "values"), csr.Subject.Country, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.Countries.Values, csr.Subject.Country) {
-			el = append(el, field.Invalid(fldPath.Child("countries", "values"), csr.Subject.Country, strings.Join(*allowedSub.Countries.Values, ", ")))
-		}
-	} else if allowedSub != nil && allowedSub.Countries != nil && allowedSub.Countries.Required != nil && *allowedSub.Countries.Required {
-		el = append(el, field.Required(fldPath.Child("countries", "required"), strconv.FormatBool(*allowedSub.Countries.Required)))
+	return subjectEvaluator{
+		sub:     e.csr.Subject,
+		allowed: allowed,
+		fldPath: e.fldPath.Child("subject"),
 	}
+}
 
-	if len(csr.Subject.OrganizationalUnit) > 0 {
-		if allowedSub == nil || allowedSub.OrganizationalUnits == nil {
-			el = append(el, field.Invalid(fldPath.Child("organizationalUnits", "values"), csr.Subject.OrganizationalUnit, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.OrganizationalUnits.Values, csr.Subject.OrganizationalUnit) {
-			el = append(el, field.Invalid(fldPath.Child("organizationalUnits", "values"), csr.Subject.OrganizationalUnit, strings.Join(*allowedSub.OrganizationalUnits.Values, ", ")))
-		}
-	} else if allowedSub != nil && allowedSub.OrganizationalUnits != nil && allowedSub.OrganizationalUnits.Required != nil && *allowedSub.OrganizationalUnits.Required {
-		el = append(el, field.Required(fldPath.Child("organizationalUnits", "required"), strconv.FormatBool(*allowedSub.OrganizationalUnits.Required)))
-	}
+type subjectEvaluator struct {
+	sub     pkix.Name
+	allowed *policyapi.CertificateRequestPolicyAllowedX509Subject
+	fldPath *field.Path
+}
 
-	if len(csr.Subject.Locality) > 0 {
-		if allowedSub == nil || allowedSub.Localities == nil {
-			el = append(el, field.Invalid(fldPath.Child("localities", "values"), csr.Subject.Locality, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.Localities.Values, csr.Subject.Locality) {
-			el = append(el, field.Invalid(fldPath.Child("localities", "values"), csr.Subject.Locality, strings.Join(*allowedSub.Localities.Values, ", ")))
-		}
-	} else if allowedSub != nil && allowedSub.Localities != nil && allowedSub.Localities.Required != nil && *allowedSub.Localities.Required {
-		el = append(el, field.Required(fldPath.Child("localities", "required"), strconv.FormatBool(*allowedSub.Localities.Required)))
-	}
+func (e subjectEvaluator) Organization() *field.Error {
+	return evaluateSlice(e.sub.Organization, e.allowed.Organizations, e.fldPath.Child("organizations"))
+}
 
-	if len(csr.Subject.Province) > 0 {
-		if allowedSub == nil || allowedSub.Provinces == nil {
-			el = append(el, field.Invalid(fldPath.Child("provinces", "values"), csr.Subject.Province, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.Provinces.Values, csr.Subject.Province) {
-			el = append(el, field.Invalid(fldPath.Child("provinces", "values"), csr.Subject.Province, strings.Join(*allowedSub.Provinces.Values, ", ")))
-		}
-	} else if allowedSub != nil && allowedSub.Provinces != nil && allowedSub.Provinces.Required != nil && *allowedSub.Provinces.Required {
-		el = append(el, field.Required(fldPath.Child("provinces", "required"), strconv.FormatBool(*allowedSub.Provinces.Required)))
-	}
+func (e subjectEvaluator) Country() *field.Error {
+	return evaluateSlice(e.sub.Country, e.allowed.Countries, e.fldPath.Child("countries"))
+}
 
-	if len(csr.Subject.StreetAddress) > 0 {
-		if allowedSub == nil || allowedSub.StreetAddresses == nil {
-			el = append(el, field.Invalid(fldPath.Child("streetAddresses", "values"), csr.Subject.StreetAddress, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.StreetAddresses.Values, csr.Subject.StreetAddress) {
-			el = append(el, field.Invalid(fldPath.Child("streetAddresses", "values"), csr.Subject.StreetAddress, strings.Join(*allowedSub.StreetAddresses.Values, ", ")))
-		}
-	} else if allowedSub != nil && allowedSub.StreetAddresses != nil && allowedSub.StreetAddresses.Required != nil && *allowedSub.StreetAddresses.Required {
-		el = append(el, field.Required(fldPath.Child("streetAddresses", "required"), strconv.FormatBool(*allowedSub.StreetAddresses.Required)))
-	}
+func (e subjectEvaluator) OrganizationalUnit() *field.Error {
+	return evaluateSlice(e.sub.OrganizationalUnit, e.allowed.OrganizationalUnits, e.fldPath.Child("organizationalUnits"))
+}
+
+func (e subjectEvaluator) Locality() *field.Error {
+	return evaluateSlice(e.sub.Locality, e.allowed.Localities, e.fldPath.Child("localities"))
+}
+
+func (e subjectEvaluator) Province() *field.Error {
+	return evaluateSlice(e.sub.Province, e.allowed.Provinces, e.fldPath.Child("provinces"))
+}
+
+func (e subjectEvaluator) StreetAddress() *field.Error {
+	return evaluateSlice(e.sub.StreetAddress, e.allowed.StreetAddresses, e.fldPath.Child("streetAddresses"))
+}
+
+func (e subjectEvaluator) PostalCode() *field.Error {
+	return evaluateSlice(e.sub.PostalCode, e.allowed.PostalCodes, e.fldPath.Child("postalCodes"))
+}
+
+func (e subjectEvaluator) SerialNumber() *field.Error {
+	return evaluateString(e.sub.SerialNumber, e.allowed.SerialNumber, e.fldPath.Child("serialNumber"))
+}
 
-	if len(csr.Subject.PostalCode) > 0 {
-		if allowedSub == nil || allowedSub.PostalCodes == nil {
-			el = append(el, field.Invalid(fldPath.Child("postalCodes", "values"), csr.Subject.PostalCode, "nil"))
-		} else if !util.WildcardSubset(*allowedSub.PostalCodes.Values, csr.Subject.PostalCode) {
-			el = append(el, field.Invalid(fldPath.Child("postalCodes", "values"), csr.Subject.PostalCode, strings.Join(*allowedSub.PostalCodes.Values, ", ")))
+func evaluateString(s string, crp *policyapi.CertificateRequestPolicyAllowedString, fldPath *field.Path) *field.Error {
+	if len(s) > 0 {
+		if crp == nil || crp.Value == nil {
+			return field.Invalid(fldPath.Child("value"), s, "nil")
+		} else if !util.WildcardMatches(*crp.Value, s) {
+			return field.Invalid(fldPath.Child("value"), s, *crp.Value)
 		}
-	} else if allowedSub != nil && allowedSub.PostalCodes != nil && allowedSub.PostalCodes.Required != nil && *allowedSub.PostalCodes.Required {
-		el = append(el, field.Required(fldPath.Child("postalCodes", "required"), strconv.FormatBool(*allowedSub.PostalCodes.Required)))
+	} else if crp != nil && crp.Required != nil && *crp.Required {
+		return field.Required(fldPath.Child("required"), strconv.FormatBool(*crp.Required))
 	}
+	return nil
+}
 
-	if len(csr.Subject.SerialNumber) > 0 {
-		if allowedSub == nil || allowedSub.SerialNumber == nil {
-			el = append(el, field.Invalid(fldPath.Child("serialNumber", "value"), csr.Subject.SerialNumber, "nil"))
-		} else if !util.WildcardMatches(*allowedSub.SerialNumber.Value, csr.Subject.SerialNumber) {
-			el = append(el, field.Invalid(fldPath.Child("serialNumber", "value"), csr.Subject.SerialNumber, *allowedSub.SerialNumber.Value))
+func evaluateSlice(s []string, crp *policyapi.CertificateRequestPolicyAllowedStringSlice, fldPath *field.Path) *field.Error {
+	if len(s) > 0 {
+		if crp == nil || crp.Values == nil {
+			return field.Invalid(fldPath.Child("values"), s, "nil")
+		} else if !util.WildcardSubset(*crp.Values, s) {
+			return field.Invalid(fldPath.Child("values"), s, strings.Join(*crp.Values, ", "))
 		}
-	} else if allowedSub != nil && allowedSub.SerialNumber != nil && allowedSub.SerialNumber.Required != nil && *allowedSub.SerialNumber.Required {
-		el = append(el, field.Required(fldPath.Child("serialNumber", "required"), strconv.FormatBool(*allowedSub.SerialNumber.Required)))
+	} else if crp != nil && crp.Required != nil && *crp.Required {
+		return field.Required(fldPath.Child("required"), strconv.FormatBool(*crp.Required))
 	}
+	return nil
+}
 
-	// If there are errors, then return not approved and the aggregated errors
-	if len(el) > 0 {
-		return approver.EvaluationResponse{Result: approver.ResultDenied, Message: el.ToAggregate().Error()}, nil
+func evaluateBool(b bool, crp *bool, fldPath *field.Path) *field.Error {
+	if b {
+		if crp == nil {
+			return field.Invalid(fldPath, b, "nil")
+		} else if !*crp {
+			return field.Invalid(fldPath, b, strconv.FormatBool(*crp))
+		}
 	}
-
-	// If no evaluation errors resulting from this policy, return not denied
-	return approver.EvaluationResponse{Result: approver.ResultNotDenied}, nil
+	return nil
 }