Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

webtransport: use deterministic TLS certificates #1833

Merged
merged 18 commits into from
Nov 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions p2p/test/webtransport/webtransport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package webtransport_test

import (
"testing"
"time"

"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/test"
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)

func extractCertHashes(addr ma.Multiaddr) []string {
var certHashesStr []string
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
certHashesStr = append(certHashesStr, c.Value())
}
return true
})
return certHashesStr
}

func TestDeterministicCertsAfterReboot(t *testing.T) {
priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256)
require.NoError(t, err)

cl := clock.NewMock()
// Move one year ahead to avoid edge cases around epoch
cl.Add(time.Hour * 24 * 365)
h, err := libp2p.New(libp2p.NoTransports, libp2p.Transport(libp2pwebtransport.New, libp2pwebtransport.WithClock(cl)), libp2p.Identity(priv))
require.NoError(t, err)
err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)

prevCerthashes := extractCertHashes(h.Addrs()[0])
h.Close()

h, err = libp2p.New(libp2p.NoTransports, libp2p.Transport(libp2pwebtransport.New, libp2pwebtransport.WithClock(cl)), libp2p.Identity(priv))
require.NoError(t, err)
defer h.Close()
err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)

nextCertHashes := extractCertHashes(h.Addrs()[0])

for i := range prevCerthashes {
require.Equal(t, prevCerthashes[i], nextCertHashes[i])
}
}
57 changes: 43 additions & 14 deletions p2p/transport/webtransport/cert_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"fmt"
"sync"
"time"

"github.com/benbjohnson/clock"
ic "github.com/libp2p/go-libp2p/core/crypto"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
)
Expand All @@ -17,6 +19,7 @@ import (
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour
const validityMinusTwoSkew = certValidity - (2 * clockSkewAllowance)

type certConfig struct {
tlsConf *tls.Config
Expand All @@ -26,8 +29,8 @@ type certConfig struct {
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }

func newCertConfig(start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(start, end)
func newCertConfig(key ic.PrivKey, start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(key, start, end)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -57,32 +60,58 @@ type certManager struct {
serializedCertHashes [][]byte
}

func newCertManager(clock clock.Clock) (*certManager, error) {
func newCertManager(hostKey ic.PrivKey, clock clock.Clock) (*certManager, error) {
m := &certManager{clock: clock}
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
if err := m.init(); err != nil {
if err := m.init(hostKey); err != nil {
return nil, err
}

m.background()
m.background(hostKey)
return m, nil
}

func (m *certManager) init() error {
start := m.clock.Now().Add(-clockSkewAllowance)
var err error
m.nextConfig, err = newCertConfig(start, start.Add(certValidity))
// getCurrentTimeBucket returns the canonical start time of the given time as
// bucketed by ranges of certValidity since unix epoch (plus an offset). This
// lets you get the same time ranges across reboots without having to persist
// state.
// ```
// ... v--- epoch + offset
// ... |--------| |--------| ...
// ... |--------| |--------| ...
// ```
func getCurrentBucketStartTime(now time.Time, offset time.Duration) time.Time {
currentBucket := (now.UnixMilli() - offset.Milliseconds()) / validityMinusTwoSkew.Milliseconds()
return time.UnixMilli(offset.Milliseconds() + currentBucket*validityMinusTwoSkew.Milliseconds())
}

func (m *certManager) init(hostKey ic.PrivKey) error {
start := m.clock.Now()
pubkeyBytes, err := hostKey.GetPublic().Raw()
if err != nil {
return err
}

// We want to add a random offset to each start time so that not all certs
// rotate at the same time across the network. The offset represents moving
// the bucket start time some `offset` earlier.
offset := (time.Duration(binary.LittleEndian.Uint16(pubkeyBytes)) * time.Minute) % certValidity

// We want the certificate have been valid for at least one clockSkewAllowance
start = start.Add(-clockSkewAllowance)
startTime := getCurrentBucketStartTime(start, offset)
m.nextConfig, err = newCertConfig(hostKey, startTime, startTime.Add(certValidity))
if err != nil {
return err
}
return m.rollConfig()
return m.rollConfig(hostKey)
}

func (m *certManager) rollConfig() error {
func (m *certManager) rollConfig(hostKey ic.PrivKey) error {
// We stop using the current certificate clockSkewAllowance before its expiry time.
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(nextStart, nextStart.Add(certValidity))
c, err := newCertConfig(hostKey, nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
Expand All @@ -95,7 +124,7 @@ func (m *certManager) rollConfig() error {
return m.cacheAddrComponent()
}

func (m *certManager) background() {
func (m *certManager) background(hostKey ic.PrivKey) {
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
log.Debugw("setting timer", "duration", d.String())
t := m.clock.Timer(d)
Expand All @@ -111,7 +140,7 @@ func (m *certManager) background() {
return
case now := <-t.C:
m.mx.Lock()
if err := m.rollConfig(); err != nil {
if err := m.rollConfig(hostKey); err != nil {
log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
Expand Down
82 changes: 78 additions & 4 deletions p2p/transport/webtransport/cert_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ package libp2pwebtransport
import (
"crypto/sha256"
"crypto/tls"
"fmt"
"testing"
"testing/quick"
"time"

"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/test"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -39,14 +43,16 @@ func certHashFromComponent(t *testing.T, comp ma.Component) []byte {
func TestInitialCert(t *testing.T) {
cl := clock.NewMock()
cl.Add(1234567 * time.Hour)
m, err := newCertManager(cl)
priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
require.NoError(t, err)
m, err := newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0]
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore)
require.GreaterOrEqual(t, cl.Now().Add(-clockSkewAllowance), cert.Leaf.NotBefore)
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter)
addr := m.AddrComponent()
components := splitMultiaddr(addr)
Expand All @@ -59,7 +65,11 @@ func TestInitialCert(t *testing.T) {

func TestCertRenewal(t *testing.T) {
cl := clock.NewMock()
m, err := newCertManager(cl)
// Add a year to avoid edge cases around the epoch
cl.Add(time.Hour * 24 * 365)
priv, _, err := test.SeededTestKeyPair(crypto.Ed25519, 256, 0)
require.NoError(t, err)
m, err := newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

Expand All @@ -68,7 +78,7 @@ func TestCertRenewal(t *testing.T) {
require.Len(t, first, 2)
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ")
// wait for a new certificate to be generated
cl.Add(certValidity - 2*clockSkewAllowance - time.Second)
cl.Set(m.currentConfig.End().Add(-(clockSkewAllowance + time.Second)))
require.Never(t, func() bool {
for i, c := range splitMultiaddr(m.AddrComponent()) {
if c.Value() != first[i].Value() {
Expand Down Expand Up @@ -100,3 +110,67 @@ func TestCertRenewal(t *testing.T) {
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, second[1].Value(), third[0].Value())
}

func TestDeterministicCertsAcrossReboots(t *testing.T) {
marten-seemann marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to skip on OSX?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The MacOS issue was around listening, this should be okay (after using a deterministic seed, since otherwise we have random offsets and adding an hour might get us over the border).

// Run this test 100 times to make sure it's deterministic
runs := 100
for i := 0; i < runs; i++ {
t.Run(fmt.Sprintf("Run=%d", i), func(t *testing.T) {
cl := clock.NewMock()
priv, _, err := test.SeededTestKeyPair(crypto.Ed25519, 256, 0)
require.NoError(t, err)
m, err := newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
oldCerts := m.serializedCertHashes

m.Close()

cl.Add(time.Hour)
// reboot
m, err = newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

newCerts := m.serializedCertHashes

require.Equal(t, oldCerts, newCerts)
})
}
}

func TestDeterministicTimeBuckets(t *testing.T) {
cl := clock.NewMock()
cl.Add(time.Hour * 24 * 365)
startA := getCurrentBucketStartTime(cl.Now(), 0)
startB := getCurrentBucketStartTime(cl.Now().Add(time.Hour*24), 0)
require.Equal(t, startA, startB)

// 15 Days later
startC := getCurrentBucketStartTime(cl.Now().Add(time.Hour*24*15), 0)
require.NotEqual(t, startC, startB)
}

func TestGetCurrentBucketStartTimeIsWithinBounds(t *testing.T) {
require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, offset time.Duration) bool {
if offset < 0 {
offset = -offset
}
if timeSinceUnixEpoch < 0 {
timeSinceUnixEpoch = -timeSinceUnixEpoch
}

offset = offset % certValidity
// Bound this to 100 years
timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100))
// Start a bit further in the future to avoid edge cases around epoch
timeSinceUnixEpoch += time.Hour * 24 * 365
start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds())

bucketStart := getCurrentBucketStartTime(start.Add(-clockSkewAllowance), offset)
return !bucketStart.After(start.Add(-clockSkewAllowance)) || bucketStart.Equal(start.Add(-clockSkewAllowance))
}, nil))
}
58 changes: 51 additions & 7 deletions p2p/transport/webtransport/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@ import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"time"

ic "github.com/libp2p/go-libp2p/core/crypto"

"github.com/multiformats/go-multihash"
"golang.org/x/crypto/hkdf"
)

func getTLSConf(start, end time.Time) (*tls.Config, error) {
cert, priv, err := generateCert(start, end)
const deterministicCertInfo = "determinisitic cert"

func getTLSConf(key ic.PrivKey, start, end time.Time) (*tls.Config, error) {
cert, priv, err := generateCert(key, start, end)
if err != nil {
return nil, err
}
Expand All @@ -32,9 +37,20 @@ func getTLSConf(start, end time.Time) (*tls.Config, error) {
}, nil
}

func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) {
// generateCert generates certs deterministically based on the `key` and start
// time passed in. Uses `golang.org/x/crypto/hkdf`.
func generateCert(key ic.PrivKey, start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) {
keyBytes, err := key.Raw()
if err != nil {
return nil, nil, err
}

startTimeSalt := make([]byte, 8)
binary.LittleEndian.PutUint64(startTimeSalt, uint64(start.UnixNano()))
deterministicHKDFReader := newDeterministicReader(keyBytes, startTimeSalt, deterministicCertInfo)

b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
if _, err := deterministicHKDFReader.Read(b); err != nil {
return nil, nil, err
}
serial := int64(binary.BigEndian.Uint64(b))
Expand All @@ -51,11 +67,12 @@ func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, e
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)

caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey)
caBytes, err := x509.CreateCertificate(deterministicHKDFReader, certTempl, certTempl, caPrivateKey.Public(), caPrivateKey)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -106,3 +123,30 @@ func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash)
}
return nil
}

// deterministicReader is a hack. It counter-acts the Go library's attempt at
// making ECDSA signatures non-deterministic. Go adds non-determinism by
// randomly dropping a singly byte from the reader stream. This counteracts this
// by detecting when a read is a single byte and using a different reader
// instead.
type deterministicReader struct {
reader io.Reader
singleByteReader io.Reader
}

func newDeterministicReader(seed []byte, salt []byte, info string) io.Reader {
reader := hkdf.New(sha256.New, seed, salt, []byte(info))
singleByteReader := hkdf.New(sha256.New, seed, salt, []byte(info+" single byte"))

return &deterministicReader{
reader: reader,
singleByteReader: singleByteReader,
}
}

func (r *deterministicReader) Read(p []byte) (n int, err error) {
if len(p) == 1 {
return r.singleByteReader.Read(p)
}
return r.reader.Read(p)
}
Loading