Skip to content

Commit

Permalink
crypto/x509: add support for CertPool to load certs lazily
Browse files Browse the repository at this point in the history
This will allow building CertPools that consume less memory. (Most
certs are never accessed. Different users/programs access different
ones, but not many.)

This CL only adds the new internal mechanism (and uses it for the
old AddCert) but does not modify any existing root pool behavior.
(That is, the default Unix roots are still all slurped into memory as
of this CL)

Change-Id: Ib3a42e4050627b5e34413c595d8ced839c7bfa14
Reviewed-on: https://go-review.googlesource.com/c/go/+/229917
Run-TryBot: Brad Fitzpatrick <bradfitz@golang.org>
TryBot-Result: Go Bot <gobot@golang.org>
Trust: Brad Fitzpatrick <bradfitz@golang.org>
Trust: Roland Shoemaker <roland@golang.org>
Reviewed-by: Filippo Valsorda <filippo@golang.org>
Reviewed-by: Roland Shoemaker <roland@golang.org>
  • Loading branch information
bradfitz committed Nov 7, 2020
1 parent 2c80de7 commit e8379ab
Show file tree
Hide file tree
Showing 9 changed files with 192 additions and 56 deletions.
124 changes: 98 additions & 26 deletions src/crypto/x509/cert_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,35 +6,87 @@ package x509

import (
"bytes"
"crypto/sha256"
"encoding/pem"
"errors"
"runtime"
)

type sum224 [sha256.Size224]byte

// CertPool is a set of certificates.
type CertPool struct {
byName map[string][]int
certs []*Certificate
byName map[string][]int // cert.RawSubject => index into lazyCerts

// lazyCerts contains funcs that return a certificate,
// lazily parsing/decompressing it as needed.
lazyCerts []lazyCert

// haveSum maps from sum224(cert.Raw) to true. It's used only
// for AddCert duplicate detection, to avoid CertPool.contains
// calls in the AddCert path (because the contains method can
// call getCert and otherwise negate savings from lazy getCert
// funcs).
haveSum map[sum224]bool
}

// lazyCert is minimal metadata about a Cert and a func to retrieve it
// in its normal expanded *Certificate form.
type lazyCert struct {
// rawSubject is the Certificate.RawSubject value.
// It's the same as the CertPool.byName key, but in []byte
// form to make CertPool.Subjects (as used by crypto/tls) do
// fewer allocations.
rawSubject []byte

// getCert returns the certificate.
//
// It is not meant to do network operations or anything else
// where a failure is likely; the func is meant to lazily
// parse/decompress data that is already known to be good. The
// error in the signature primarily is meant for use in the
// case where a cert file existed on local disk when the program
// started up is deleted later before it's read.
getCert func() (*Certificate, error)
}

// NewCertPool returns a new, empty CertPool.
func NewCertPool() *CertPool {
return &CertPool{
byName: make(map[string][]int),
byName: make(map[string][]int),
haveSum: make(map[sum224]bool),
}
}

// len returns the number of certs in the set.
// A nil set is a valid empty set.
func (s *CertPool) len() int {
if s == nil {
return 0
}
return len(s.lazyCerts)
}

// cert returns cert index n in s.
func (s *CertPool) cert(n int) (*Certificate, error) {
return s.lazyCerts[n].getCert()
}

func (s *CertPool) copy() *CertPool {
p := &CertPool{
byName: make(map[string][]int, len(s.byName)),
certs: make([]*Certificate, len(s.certs)),
byName: make(map[string][]int, len(s.byName)),
lazyCerts: make([]lazyCert, len(s.lazyCerts)),
haveSum: make(map[sum224]bool, len(s.haveSum)),
}
for k, v := range s.byName {
indexes := make([]int, len(v))
copy(indexes, v)
p.byName[k] = indexes
}
copy(p.certs, s.certs)
for k := range s.haveSum {
p.haveSum[k] = true
}
copy(p.lazyCerts, s.lazyCerts)
return p
}

Expand Down Expand Up @@ -64,7 +116,7 @@ func SystemCertPool() (*CertPool, error) {

// findPotentialParents returns the indexes of certificates in s which might
// have signed cert.
func (s *CertPool) findPotentialParents(cert *Certificate) []int {
func (s *CertPool) findPotentialParents(cert *Certificate) []*Certificate {
if s == nil {
return nil
}
Expand All @@ -75,41 +127,46 @@ func (s *CertPool) findPotentialParents(cert *Certificate) []int {
// AKID and SKID match
// AKID present, SKID missing / AKID missing, SKID present
// AKID and SKID don't match
var matchingKeyID, oneKeyID, mismatchKeyID []int
var matchingKeyID, oneKeyID, mismatchKeyID []*Certificate
for _, c := range s.byName[string(cert.RawIssuer)] {
candidate := s.certs[c]
candidate, err := s.cert(c)
if err != nil {
continue
}
kidMatch := bytes.Equal(candidate.SubjectKeyId, cert.AuthorityKeyId)
switch {
case kidMatch:
matchingKeyID = append(matchingKeyID, c)
matchingKeyID = append(matchingKeyID, candidate)
case (len(candidate.SubjectKeyId) == 0 && len(cert.AuthorityKeyId) > 0) ||
(len(candidate.SubjectKeyId) > 0 && len(cert.AuthorityKeyId) == 0):
oneKeyID = append(oneKeyID, c)
oneKeyID = append(oneKeyID, candidate)
default:
mismatchKeyID = append(mismatchKeyID, c)
mismatchKeyID = append(mismatchKeyID, candidate)
}
}

found := len(matchingKeyID) + len(oneKeyID) + len(mismatchKeyID)
if found == 0 {
return nil
}
candidates := make([]int, 0, found)
candidates := make([]*Certificate, 0, found)
candidates = append(candidates, matchingKeyID...)
candidates = append(candidates, oneKeyID...)
candidates = append(candidates, mismatchKeyID...)

return candidates
}

func (s *CertPool) contains(cert *Certificate) bool {
if s == nil {
return false
}

candidates := s.byName[string(cert.RawSubject)]
for _, c := range candidates {
if s.certs[c].Equal(cert) {
for _, i := range candidates {
c, err := s.cert(i)
if err != nil {
return false
}
if c.Equal(cert) {
return true
}
}
Expand All @@ -122,17 +179,32 @@ func (s *CertPool) AddCert(cert *Certificate) {
if cert == nil {
panic("adding nil Certificate to CertPool")
}
s.addCertFunc(sha256.Sum224(cert.Raw), string(cert.RawSubject), func() (*Certificate, error) {
return cert, nil
})
}

// addCertFunc adds metadata about a certificate to a pool, along with
// a func to fetch that certificate later when needed.
//
// The rawSubject is Certificate.RawSubject and must be non-empty.
// The getCert func may be called 0 or more times.
func (s *CertPool) addCertFunc(rawSum224 sum224, rawSubject string, getCert func() (*Certificate, error)) {
if getCert == nil {
panic("getCert can't be nil")
}

// Check that the certificate isn't being added twice.
if s.contains(cert) {
if s.haveSum[rawSum224] {
return
}

n := len(s.certs)
s.certs = append(s.certs, cert)

name := string(cert.RawSubject)
s.byName[name] = append(s.byName[name], n)
s.haveSum[rawSum224] = true
s.lazyCerts = append(s.lazyCerts, lazyCert{
rawSubject: []byte(rawSubject),
getCert: getCert,
})
s.byName[rawSubject] = append(s.byName[rawSubject], len(s.lazyCerts)-1)
}

// AppendCertsFromPEM attempts to parse a series of PEM encoded certificates.
Expand Down Expand Up @@ -167,9 +239,9 @@ func (s *CertPool) AppendCertsFromPEM(pemCerts []byte) (ok bool) {
// Subjects returns a list of the DER-encoded subjects of
// all of the certificates in the pool.
func (s *CertPool) Subjects() [][]byte {
res := make([][]byte, len(s.certs))
for i, c := range s.certs {
res[i] = c.RawSubject
res := make([][]byte, s.len())
for i, lc := range s.lazyCerts {
res[i] = lc.rawSubject
}
return res
}
12 changes: 6 additions & 6 deletions src/crypto/x509/name_constraints_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1941,7 +1941,7 @@ func TestConstraintCases(t *testing.T) {
// Skip tests with CommonName set because OpenSSL will try to match it
// against name constraints, while we ignore it when it's not hostname-looking.
if !test.noOpenSSL && testNameConstraintsAgainstOpenSSL && test.leaf.cn == "" {
output, err := testChainAgainstOpenSSL(leafCert, intermediatePool, rootPool)
output, err := testChainAgainstOpenSSL(t, leafCert, intermediatePool, rootPool)
if err == nil && len(test.expectedError) > 0 {
t.Errorf("#%d: unexpectedly succeeded against OpenSSL", i)
if debugOpenSSLFailure {
Expand Down Expand Up @@ -1993,7 +1993,7 @@ func TestConstraintCases(t *testing.T) {
pem.Encode(&buf, &pem.Block{Type: "CERTIFICATE", Bytes: cert.Raw})
return buf.String()
}
t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.certs[0]))
t.Errorf("#%d: root:\n%s", i, certAsPEM(rootPool.mustCert(t, 0)))
t.Errorf("#%d: leaf:\n%s", i, certAsPEM(leafCert))
}

Expand All @@ -2019,19 +2019,19 @@ func writePEMsToTempFile(certs []*Certificate) *os.File {
return file
}

func testChainAgainstOpenSSL(leaf *Certificate, intermediates, roots *CertPool) (string, error) {
func testChainAgainstOpenSSL(t *testing.T, leaf *Certificate, intermediates, roots *CertPool) (string, error) {
args := []string{"verify", "-no_check_time"}

rootsFile := writePEMsToTempFile(roots.certs)
rootsFile := writePEMsToTempFile(allCerts(t, roots))
if debugOpenSSLFailure {
println("roots file:", rootsFile.Name())
} else {
defer os.Remove(rootsFile.Name())
}
args = append(args, "-CAfile", rootsFile.Name())

if len(intermediates.certs) > 0 {
intermediatesFile := writePEMsToTempFile(intermediates.certs)
if intermediates.len() > 0 {
intermediatesFile := writePEMsToTempFile(allCerts(t, intermediates))
if debugOpenSSLFailure {
println("intermediates file:", intermediatesFile.Name())
} else {
Expand Down
6 changes: 5 additions & 1 deletion src/crypto/x509/root_cgo_darwin.go
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,11 @@ func _loadSystemRootsWithCgo() (*CertPool, error) {
untrustedRoots.AppendCertsFromPEM(buf)

trustedRoots := NewCertPool()
for _, c := range roots.certs {
for _, lc := range roots.lazyCerts {
c, err := lc.getCert()
if err != nil {
return nil, err
}
if !untrustedRoots.contains(c) {
trustedRoots.AddCert(c)
}
Expand Down
11 changes: 7 additions & 4 deletions src/crypto/x509/root_darwin_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ func TestSystemRoots(t *testing.T) {

// There are 174 system roots on Catalina, and 163 on iOS right now, require
// at least 100 to make sure this is not completely broken.
if want, have := 100, len(sysRoots.certs); have < want {
if want, have := 100, sysRoots.len(); have < want {
t.Errorf("want at least %d system roots, have %d", want, have)
}

Expand All @@ -43,11 +43,14 @@ func TestSystemRoots(t *testing.T) {
t.Logf("loadSystemRootsWithCgo: %v", cgoSysRootsDuration)

// Check that the two cert pools are the same.
sysPool := make(map[string]*Certificate, len(sysRoots.certs))
for _, c := range sysRoots.certs {
sysPool := make(map[string]*Certificate, sysRoots.len())
for i := 0; i < sysRoots.len(); i++ {
c := sysRoots.mustCert(t, i)
sysPool[string(c.Raw)] = c
}
for _, c := range cgoRoots.certs {
for i := 0; i < cgoRoots.len(); i++ {
c := cgoRoots.mustCert(t, i)

if _, ok := sysPool[string(c.Raw)]; ok {
delete(sysPool, string(c.Raw))
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/crypto/x509/root_unix.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func loadSystemRoots() (*CertPool, error) {
}
}

if len(roots.certs) > 0 || firstErr == nil {
if roots.len() > 0 || firstErr == nil {
return roots, nil
}

Expand Down
15 changes: 8 additions & 7 deletions src/crypto/x509/root_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,15 +113,15 @@ func TestEnvVars(t *testing.T) {

// Verify that the returned certs match, otherwise report where the mismatch is.
for i, cn := range tc.cns {
if i >= len(r.certs) {
if i >= r.len() {
t.Errorf("missing cert %v @ %v", cn, i)
} else if r.certs[i].Subject.CommonName != cn {
fmt.Printf("%#v\n", r.certs[0].Subject)
t.Errorf("unexpected cert common name %q, want %q", r.certs[i].Subject.CommonName, cn)
} else if r.mustCert(t, i).Subject.CommonName != cn {
fmt.Printf("%#v\n", r.mustCert(t, 0).Subject)
t.Errorf("unexpected cert common name %q, want %q", r.mustCert(t, i).Subject.CommonName, cn)
}
}
if len(r.certs) > len(tc.cns) {
t.Errorf("got %v certs, which is more than %v wanted", len(r.certs), len(tc.cns))
if r.len() > len(tc.cns) {
t.Errorf("got %v certs, which is more than %v wanted", r.len(), len(tc.cns))
}
})
}
Expand Down Expand Up @@ -197,7 +197,8 @@ func TestLoadSystemCertsLoadColonSeparatedDirs(t *testing.T) {
strCertPool := func(p *CertPool) string {
return string(bytes.Join(p.Subjects(), []byte("\n")))
}
if !reflect.DeepEqual(gotPool, wantPool) {

if !certPoolEqual(gotPool, wantPool) {
g, w := strCertPool(gotPool), strCertPool(wantPool)
t.Fatalf("Mismatched certPools\nGot:\n%s\n\nWant:\n%s", g, w)
}
Expand Down
6 changes: 5 additions & 1 deletion src/crypto/x509/root_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,11 @@ func createStoreContext(leaf *Certificate, opts *VerifyOptions) (*syscall.CertCo
}

if opts.Intermediates != nil {
for _, intermediate := range opts.Intermediates.certs {
for i := 0; i < opts.Intermediates.len(); i++ {
intermediate, err := opts.Intermediates.cert(i)
if err != nil {
return nil, err
}
ctx, err := syscall.CertCreateCertificateContext(syscall.X509_ASN_ENCODING|syscall.PKCS_7_ASN_ENCODING, &intermediate.Raw[0], uint32(len(intermediate.Raw)))
if err != nil {
return nil, err
Expand Down
20 changes: 11 additions & 9 deletions src/crypto/x509/verify.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,11 +761,13 @@ func (c *Certificate) Verify(opts VerifyOptions) (chains [][]*Certificate, err e
if len(c.Raw) == 0 {
return nil, errNotParsed
}
if opts.Intermediates != nil {
for _, intermediate := range opts.Intermediates.certs {
if len(intermediate.Raw) == 0 {
return nil, errNotParsed
}
for i := 0; i < opts.Intermediates.len(); i++ {
c, err := opts.Intermediates.cert(i)
if err != nil {
return nil, fmt.Errorf("crypto/x509: error fetching intermediate: %w", err)
}
if len(c.Raw) == 0 {
return nil, errNotParsed
}
}

Expand Down Expand Up @@ -891,11 +893,11 @@ func (c *Certificate) buildChains(cache map[*Certificate][][]*Certificate, curre
}
}

for _, rootNum := range opts.Roots.findPotentialParents(c) {
considerCandidate(rootCertificate, opts.Roots.certs[rootNum])
for _, root := range opts.Roots.findPotentialParents(c) {
considerCandidate(rootCertificate, root)
}
for _, intermediateNum := range opts.Intermediates.findPotentialParents(c) {
considerCandidate(intermediateCertificate, opts.Intermediates.certs[intermediateNum])
for _, intermediate := range opts.Intermediates.findPotentialParents(c) {
considerCandidate(intermediateCertificate, intermediate)
}

if len(chains) > 0 {
Expand Down

0 comments on commit e8379ab

Please sign in to comment.