Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add directory Cache implementation #9

Merged
merged 4 commits into from
Mar 27, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
119 changes: 119 additions & 0 deletions cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@ package certify
import (
"context"
"crypto/tls"
"encoding/gob"
"errors"
"io/ioutil"
"os"
"path/filepath"
"sync"
)

Expand Down Expand Up @@ -66,3 +70,118 @@ func (m *memCache) Delete(_ context.Context, key string) error {
delete(m.cache, key)
return nil
}

// DirCache implements Cache using a directory on the local filesystem.
// If the directory does not exist, it will be created with 0700 permissions.
//
// It is strongly based on the acme/autocert DirCache type.
// https://github.com/golang/crypto/blob/88942b9c40a4c9d203b82b3731787b672d6e809b/acme/autocert/cache.go#L40
type DirCache string

// Get reads a certificate data from the specified file name.
func (d DirCache) Get(ctx context.Context, name string) (*tls.Certificate, error) {
name = filepath.Join(string(d), name)

var (
cert tls.Certificate
err error
done = make(chan struct{})
)

go func() {
defer close(done)
var f *os.File
f, err = os.Open(name)
if err != nil {
return
}
err = gob.NewDecoder(f).Decode(&cert)
_ = f.Close()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Any reason too keep the assignment?

Copy link
Owner Author

@johanbrandhorst johanbrandhorst Mar 27, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like showing that we're explicitly ignoring the error rather than just calling the function without assignment, which could be mistaken for accidentally ignoring the error.

}()

select {
case <-ctx.Done():
return nil, ctx.Err()
case <-done:
}

if os.IsNotExist(err) {
return nil, ErrCacheMiss
}
if err != nil {
return nil, err
}

return &cert, nil
}

// Put writes the certificate data to the specified file name.
// The file will be created with 0600 permissions.
func (d DirCache) Put(ctx context.Context, name string, cert *tls.Certificate) error {
if err := os.MkdirAll(string(d), 0700); err != nil {
return err
}

done := make(chan struct{})
var err error
go func() {
defer close(done)

var tmp string
if tmp, err = d.writeTempFile(name, cert); err != nil {
return
}

select {
case <-ctx.Done():
// Don't overwrite the file if the context was canceled.
default:
newName := filepath.Join(string(d), name)
err = os.Rename(tmp, newName)
}
}()

select {
case <-ctx.Done():
return ctx.Err()
case <-done:
}

return err
}

// Delete removes the specified file name.
func (d DirCache) Delete(ctx context.Context, name string) error {
name = filepath.Join(string(d), name)
var (
err error
done = make(chan struct{})
)
go func() {
err = os.Remove(name)
close(done)
}()
select {
case <-ctx.Done():
return ctx.Err()
case <-done:
}
if err != nil && !os.IsNotExist(err) {
return err
}
return nil
}

// writeTempFile writes b to a temporary file, closes the file and returns its path.
func (d DirCache) writeTempFile(prefix string, cert *tls.Certificate) (string, error) {
// TempFile uses 0600 permissions
f, err := ioutil.TempFile(string(d), prefix)
if err != nil {
return "", err
}
if err = gob.NewEncoder(f).Encode(cert); err != nil {
_ = f.Close()
return "", err
}
return f.Name(), f.Close()
}
132 changes: 73 additions & 59 deletions certify_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ import (

//go:generate protoc --go_out=plugins=grpc:./ ./proto/test.proto

func mustMakeTempDir() string {
n, err := ioutil.TempDir("", "")
if err != nil {
panic(err)
}
return n
}

var _ = Describe("Certify", func() {
var issuer certify.Issuer

Expand Down Expand Up @@ -63,7 +71,7 @@ var _ = Describe("Certify", func() {
Expect(cert1.Leaf.IPAddresses).To(HaveLen(1))
Expect(cert1.Leaf.IPAddresses[0].Equal(cli.CertConfig.IPSubjectAlternativeNames[0])).To(BeTrue())
Expect(cert1.Leaf.NotBefore).To(BeTemporally("<", time.Now()))
Expect(cert1.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 2*time.Second))
Expect(cert1.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 5*time.Second))
Expect(cert1.Leaf.Issuer.SerialNumber).To(Equal(caCert.Subject.SerialNumber))

cert2, err := cli.GetClientCertificate(nil)
Expand All @@ -76,7 +84,7 @@ var _ = Describe("Certify", func() {
Expect(cert2.Leaf.IPAddresses).To(HaveLen(1))
Expect(cert2.Leaf.IPAddresses[0].Equal(cli.CertConfig.IPSubjectAlternativeNames[0])).To(BeTrue())
Expect(cert2.Leaf.NotBefore).To(BeTemporally("<", time.Now()))
Expect(cert2.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 2*time.Second))
Expect(cert2.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 5*time.Second))
Expect(cert2.Leaf.Issuer.SerialNumber).To(Equal(caCert.Subject.SerialNumber))
})

Expand Down Expand Up @@ -128,7 +136,7 @@ var _ = Describe("Certify", func() {
Expect(cert1.Leaf.Subject.CommonName).To(Equal(cli.CommonName))
Expect(cert1.Leaf.DNSNames).To(ConsistOf(cli.CommonName))
Expect(cert1.Leaf.NotBefore).To(BeTemporally("<", time.Now()))
Expect(cert1.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 2*time.Second))
Expect(cert1.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 5*time.Second))
Expect(cert1.Leaf.Issuer.SerialNumber).To(Equal(caCert.Subject.SerialNumber))

cert2, err := cli.GetClientCertificate(nil)
Expand All @@ -139,7 +147,7 @@ var _ = Describe("Certify", func() {
Expect(cert2.Leaf.Subject.CommonName).To(Equal(cli.CommonName))
Expect(cert2.Leaf.DNSNames).To(Equal(append(cli.CertConfig.SubjectAlternativeNames, cli.CommonName)))
Expect(cert2.Leaf.NotBefore).To(BeTemporally("<", time.Now()))
Expect(cert2.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 2*time.Second))
Expect(cert2.Leaf.NotAfter).To(BeTemporally("~", time.Now().Add(cli.CertConfig.TimeToLive), 5*time.Second))
Expect(cert2.Leaf.Issuer.SerialNumber).To(Equal(caCert.Subject.SerialNumber))
})
})
Expand All @@ -148,72 +156,78 @@ var _ = Describe("Certify", func() {
})

var _ = Describe("The Cache", func() {
var c certify.Cache

Context("when using the memcache", func() {
BeforeEach(func() {
c = certify.NewMemCache()
})

Context("after putting in a certificate", func() {
It("allows a user to get and delete it", func() {
cert := &tls.Certificate{
Leaf: &x509.Certificate{
IsCA: true,
},
}
Expect(c.Put(context.Background(), "key1", cert)).To(Succeed())
Expect(c.Get(context.Background(), "key1")).To(Equal(cert))
Expect(c.Delete(context.Background(), "key1")).To(Succeed())
_, err := c.Get(context.Background(), "key1")
Expect(err).To(Equal(certify.ErrCacheMiss))
caches := []struct {
Type string
Cache certify.Cache
}{
{Type: "MemCache", Cache: certify.NewMemCache()},
{Type: "DirCache", Cache: certify.DirCache(mustMakeTempDir())},
}

for _, cache := range caches {
c := cache
Context("when using a "+c.Type, func() {
Context("after putting in a certificate", func() {
It("allows a user to get and delete it", func() {
cert := &tls.Certificate{
Leaf: &x509.Certificate{
IsCA: true,
},
}
Expect(c.Cache.Put(context.Background(), "key1", cert)).To(Succeed())
Expect(c.Cache.Get(context.Background(), "key1")).To(Equal(cert))
Expect(c.Cache.Delete(context.Background(), "key1")).To(Succeed())
_, err := c.Cache.Get(context.Background(), "key1")
Expect(err).To(Equal(certify.ErrCacheMiss))
})
})
})

Context("when getting a key that doesn't exist", func() {
It("returns ErrCacheMiss", func() {
_, err := c.Get(context.Background(), "key1")
Expect(err).To(Equal(certify.ErrCacheMiss))
Context("when getting a key that doesn't exist", func() {
It("returns ErrCacheMiss", func() {
_, err := c.Cache.Get(context.Background(), "key1")
Expect(err).To(Equal(certify.ErrCacheMiss))
})
})
})

Context("when deleting a key that doesn't exist", func() {
It("does not return an error", func() {
Expect(c.Delete(context.Background(), "key1")).To(Succeed())
Context("when deleting a key that doesn't exist", func() {
It("does not return an error", func() {
Expect(c.Cache.Delete(context.Background(), "key1")).To(Succeed())
})
})
})

Context("when accessing the cache concurrently", func() {
It("does not cause any race conditions", func() {
start := make(chan struct{})
wg := sync.WaitGroup{}
key := "key1"
Context("when accessing the cache concurrently", func() {
It("does not cause any race conditions", func() {
start := make(chan struct{})
wg := sync.WaitGroup{}
key := "key1"

cert := &tls.Certificate{
Leaf: &x509.Certificate{
IsCA: true,
},
}
cert := &tls.Certificate{
Leaf: &x509.Certificate{
IsCA: true,
},
}

for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()

Eventually(start).Should(BeClosed())
Expect(c.Delete(context.Background(), key)).To(Succeed())
Expect(c.Put(context.Background(), key, cert)).To(Succeed())
Expect(c.Get(context.Background(), key)).NotTo(BeNil())
}()
}
for i := 0; i < 3; i++ {
wg.Add(1)
go func() {
defer wg.Done()
defer GinkgoRecover()

Eventually(start).Should(BeClosed())
Expect(c.Cache.Put(context.Background(), key, cert)).To(Succeed())
Expect(c.Cache.Get(context.Background(), key)).NotTo(BeNil())
}()
}

// Synchronize goroutines
close(start)
wg.Wait()
// Synchronize goroutines
close(start)
wg.Wait()

Expect(c.Cache.Delete(context.Background(), key)).To(Succeed())
})
})
})
})
}
})

type backend struct{}
Expand Down
2 changes: 1 addition & 1 deletion vault.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func connect(

dl, ok := ctx.Deadline()
if ok {
vConf.Timeout = dl.Sub(time.Now())
vConf.Timeout = time.Until(dl)
}
vConf.Address = vaultURL.String()
cli, err := api.NewClient(vConf)
Expand Down