Skip to content

Commit

Permalink
Move verify func to a separate function, make subject field a type, s…
Browse files Browse the repository at this point in the history
…tyle
  • Loading branch information
radekg committed Sep 16, 2020
1 parent 8177347 commit 7f01d78
Showing 1 changed file with 38 additions and 27 deletions.
65 changes: 38 additions & 27 deletions proxy/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,17 @@ import (
"github.com/pkg/errors"
)

type clientCertSubjectField string

const (
clientCertSubjectCommonName = "CN"
clientCertSubjectCountry = "C"
clientCertSubjectProvince = "S"
clientCertSubjectLocality = "L"
clientCertSubjectOrganization = "O"
clientCertSubjectOrganizationalUnit = "OU"
)

var (
defaultCurvePreferences = []tls.CurveID{
tls.CurveP256,
Expand Down Expand Up @@ -123,46 +134,52 @@ func newTLSListenerConfig(conf *config.Config) (*tls.Config, error) {
cfg.ClientAuth = tls.RequireAndVerifyClientCert
}

cfg.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
cfg.VerifyPeerCertificate = tlsClientCertVerificationFunc(conf)

return cfg, nil
}

func tlsClientCertVerificationFunc(conf *config.Config) func([][]byte, [][]*x509.Certificate) error {
return func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
if conf.Proxy.TLS.ClientCert.ValidateSubject {

expectedFields := map[string]string{}
expectedFields := map[clientCertSubjectField]string{}
expectedParts := []string{"s:"}
values := []string{}

if conf.Proxy.TLS.ClientCert.Subject.CommonName != "" {
expectedFields["CN"] = conf.Proxy.TLS.ClientCert.Subject.CommonName
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", "CN", expectedFields["CN"]))
expectedFields[clientCertSubjectCommonName] = conf.Proxy.TLS.ClientCert.Subject.CommonName
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", clientCertSubjectCommonName, expectedFields[clientCertSubjectCommonName]))
}
values = removeEmptyStrings(conf.Proxy.TLS.ClientCert.Subject.Country)
if len(values) > 0 {
sort.Strings(values)
expectedFields["C"] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", "C", expectedFields["C"]))
expectedFields[clientCertSubjectCountry] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", clientCertSubjectCountry, expectedFields[clientCertSubjectCountry]))
}
values = removeEmptyStrings(conf.Proxy.TLS.ClientCert.Subject.Province)
if len(values) > 0 {
sort.Strings(values)
expectedFields["S"] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", "S", expectedFields["S"]))
expectedFields[clientCertSubjectProvince] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", clientCertSubjectProvince, expectedFields[clientCertSubjectProvince]))
}
values = removeEmptyStrings(conf.Proxy.TLS.ClientCert.Subject.Locality)
if len(values) > 0 {
sort.Strings(values)
expectedFields["L"] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", "L", expectedFields["L"]))
expectedFields[clientCertSubjectLocality] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", clientCertSubjectLocality, expectedFields[clientCertSubjectLocality]))
}
values = removeEmptyStrings(conf.Proxy.TLS.ClientCert.Subject.Organization)
if len(values) > 0 {
sort.Strings(values)
expectedFields["O"] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", "O", expectedFields["O"]))
expectedFields[clientCertSubjectOrganization] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", clientCertSubjectOrganization, expectedFields[clientCertSubjectOrganization]))
}
values = removeEmptyStrings(conf.Proxy.TLS.ClientCert.Subject.OrganizationalUnit)
if len(values) > 0 {
sort.Strings(values)
expectedFields["OU"] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", "OU", expectedFields["OU"]))
expectedFields[clientCertSubjectOrganizationalUnit] = fmt.Sprintf("%v", values)
expectedParts = append(expectedParts, fmt.Sprintf("%s=%s", clientCertSubjectOrganizationalUnit, expectedFields[clientCertSubjectOrganizationalUnit]))
}

if len(expectedFields) == 0 {
Expand All @@ -175,45 +192,41 @@ func newTLSListenerConfig(conf *config.Config) (*tls.Config, error) {
certificateAcceptable := true

for k, v := range expectedFields {
if k == "CN" {
switch k {
case clientCertSubjectCommonName:
if v != cert.Subject.CommonName {
certificateAcceptable = false
break
}
}
if k == "C" {
case clientCertSubjectCountry:
currentValues := cert.Subject.Country
sort.Strings(currentValues)
if fmt.Sprintf("%v", currentValues) != v {
certificateAcceptable = false
break
}
}
if k == "S" {
case clientCertSubjectProvince:
currentValues := cert.Subject.Province
sort.Strings(currentValues)
if fmt.Sprintf("%v", currentValues) != v {
certificateAcceptable = false
break
}
}
if k == "L" {
case clientCertSubjectLocality:
currentValues := cert.Subject.Locality
sort.Strings(currentValues)
if fmt.Sprintf("%v", currentValues) != v {
certificateAcceptable = false
break
}
}
if k == "O" {
case clientCertSubjectOrganization:
currentValues := cert.Subject.Organization
sort.Strings(currentValues)
if fmt.Sprintf("%v", currentValues) != v {
certificateAcceptable = false
break
}
}
if k == "OU" {
case clientCertSubjectOrganizationalUnit:
currentValues := cert.Subject.OrganizationalUnit
sort.Strings(currentValues)
if fmt.Sprintf("%v", currentValues) != v {
Expand All @@ -235,8 +248,6 @@ func newTLSListenerConfig(conf *config.Config) (*tls.Config, error) {
}
return nil
}

return cfg, nil
}

func removeEmptyStrings(input []string) []string {
Expand Down

0 comments on commit 7f01d78

Please sign in to comment.