Skip to content

Commit

Permalink
webtransport: use deterministic TLS certificates (#1833)
Browse files Browse the repository at this point in the history
* Use deterministic TLS certificates for webtransport

* Update test to work with buckets

* Make sure to overlap and use a random offset

* Fixup mistaken change in other test

* Add QuickCheck tests for cert behavior

* Lint fix

* Add more tests

* Add webtransport integration test

* Use same key

* Actually offset by at least clockSkew

* Use seeded key for certs after reboot test

* PR comments

* Remove debug code

* Fix calculation for cert having been valid

Fixes the logic that a cert has been valid for a clockSkew by
subtracting the clockSkew from the start time rather than incorporating
it into the offset. The offset should be used to shift the buckets.

* Update comment

* Lint fix

* Update TestGetCurrentBucketStartTimeIsWithinBounds to include clockSkew calculation

* Rebase fixes
  • Loading branch information
MarcoPolo committed Nov 14, 2022
1 parent c48e78f commit a0432e7
Show file tree
Hide file tree
Showing 7 changed files with 493 additions and 26 deletions.
53 changes: 53 additions & 0 deletions p2p/test/webtransport/webtransport_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package webtransport_test

import (
"testing"
"time"

"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p"
ic "github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/test"
libp2pwebtransport "github.com/libp2p/go-libp2p/p2p/transport/webtransport"
ma "github.com/multiformats/go-multiaddr"
"github.com/stretchr/testify/require"
)

func extractCertHashes(addr ma.Multiaddr) []string {
var certHashesStr []string
ma.ForEach(addr, func(c ma.Component) bool {
if c.Protocol().Code == ma.P_CERTHASH {
certHashesStr = append(certHashesStr, c.Value())
}
return true
})
return certHashesStr
}

func TestDeterministicCertsAfterReboot(t *testing.T) {
priv, _, err := test.RandTestKeyPair(ic.Ed25519, 256)
require.NoError(t, err)

cl := clock.NewMock()
// Move one year ahead to avoid edge cases around epoch
cl.Add(time.Hour * 24 * 365)
h, err := libp2p.New(libp2p.NoTransports, libp2p.Transport(libp2pwebtransport.New, libp2pwebtransport.WithClock(cl)), libp2p.Identity(priv))
require.NoError(t, err)
err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)

prevCerthashes := extractCertHashes(h.Addrs()[0])
h.Close()

h, err = libp2p.New(libp2p.NoTransports, libp2p.Transport(libp2pwebtransport.New, libp2pwebtransport.WithClock(cl)), libp2p.Identity(priv))
require.NoError(t, err)
defer h.Close()
err = h.Network().Listen(ma.StringCast("/ip4/127.0.0.1/udp/0/quic/webtransport"))
require.NoError(t, err)

nextCertHashes := extractCertHashes(h.Addrs()[0])

for i := range prevCerthashes {
require.Equal(t, prevCerthashes[i], nextCertHashes[i])
}
}
57 changes: 43 additions & 14 deletions p2p/transport/webtransport/cert_manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@ import (
"context"
"crypto/sha256"
"crypto/tls"
"encoding/binary"
"fmt"
"sync"
"time"

"github.com/benbjohnson/clock"
ic "github.com/libp2p/go-libp2p/core/crypto"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multihash"
)
Expand All @@ -17,6 +19,7 @@ import (
// When we generate a certificate, the NotBefore time is set to clockSkewAllowance before the current time.
// Similarly, we stop using a certificate one clockSkewAllowance before its expiry time.
const clockSkewAllowance = time.Hour
const validityMinusTwoSkew = certValidity - (2 * clockSkewAllowance)

type certConfig struct {
tlsConf *tls.Config
Expand All @@ -26,8 +29,8 @@ type certConfig struct {
func (c *certConfig) Start() time.Time { return c.tlsConf.Certificates[0].Leaf.NotBefore }
func (c *certConfig) End() time.Time { return c.tlsConf.Certificates[0].Leaf.NotAfter }

func newCertConfig(start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(start, end)
func newCertConfig(key ic.PrivKey, start, end time.Time) (*certConfig, error) {
conf, err := getTLSConf(key, start, end)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -57,32 +60,58 @@ type certManager struct {
serializedCertHashes [][]byte
}

func newCertManager(clock clock.Clock) (*certManager, error) {
func newCertManager(hostKey ic.PrivKey, clock clock.Clock) (*certManager, error) {
m := &certManager{clock: clock}
m.ctx, m.ctxCancel = context.WithCancel(context.Background())
if err := m.init(); err != nil {
if err := m.init(hostKey); err != nil {
return nil, err
}

m.background()
m.background(hostKey)
return m, nil
}

func (m *certManager) init() error {
start := m.clock.Now().Add(-clockSkewAllowance)
var err error
m.nextConfig, err = newCertConfig(start, start.Add(certValidity))
// getCurrentTimeBucket returns the canonical start time of the given time as
// bucketed by ranges of certValidity since unix epoch (plus an offset). This
// lets you get the same time ranges across reboots without having to persist
// state.
// ```
// ... v--- epoch + offset
// ... |--------| |--------| ...
// ... |--------| |--------| ...
// ```
func getCurrentBucketStartTime(now time.Time, offset time.Duration) time.Time {
currentBucket := (now.UnixMilli() - offset.Milliseconds()) / validityMinusTwoSkew.Milliseconds()
return time.UnixMilli(offset.Milliseconds() + currentBucket*validityMinusTwoSkew.Milliseconds())
}

func (m *certManager) init(hostKey ic.PrivKey) error {
start := m.clock.Now()
pubkeyBytes, err := hostKey.GetPublic().Raw()
if err != nil {
return err
}

// We want to add a random offset to each start time so that not all certs
// rotate at the same time across the network. The offset represents moving
// the bucket start time some `offset` earlier.
offset := (time.Duration(binary.LittleEndian.Uint16(pubkeyBytes)) * time.Minute) % certValidity

// We want the certificate have been valid for at least one clockSkewAllowance
start = start.Add(-clockSkewAllowance)
startTime := getCurrentBucketStartTime(start, offset)
m.nextConfig, err = newCertConfig(hostKey, startTime, startTime.Add(certValidity))
if err != nil {
return err
}
return m.rollConfig()
return m.rollConfig(hostKey)
}

func (m *certManager) rollConfig() error {
func (m *certManager) rollConfig(hostKey ic.PrivKey) error {
// We stop using the current certificate clockSkewAllowance before its expiry time.
// At this point, the next certificate needs to be valid for one clockSkewAllowance.
nextStart := m.nextConfig.End().Add(-2 * clockSkewAllowance)
c, err := newCertConfig(nextStart, nextStart.Add(certValidity))
c, err := newCertConfig(hostKey, nextStart, nextStart.Add(certValidity))
if err != nil {
return err
}
Expand All @@ -95,7 +124,7 @@ func (m *certManager) rollConfig() error {
return m.cacheAddrComponent()
}

func (m *certManager) background() {
func (m *certManager) background(hostKey ic.PrivKey) {
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(m.clock.Now())
log.Debugw("setting timer", "duration", d.String())
t := m.clock.Timer(d)
Expand All @@ -111,7 +140,7 @@ func (m *certManager) background() {
return
case now := <-t.C:
m.mx.Lock()
if err := m.rollConfig(); err != nil {
if err := m.rollConfig(hostKey); err != nil {
log.Errorw("rolling config failed", "error", err)
}
d := m.currentConfig.End().Add(-clockSkewAllowance).Sub(now)
Expand Down
82 changes: 78 additions & 4 deletions p2p/transport/webtransport/cert_manager_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,14 @@ package libp2pwebtransport
import (
"crypto/sha256"
"crypto/tls"
"fmt"
"testing"
"testing/quick"
"time"

"github.com/benbjohnson/clock"
"github.com/libp2p/go-libp2p/core/crypto"
"github.com/libp2p/go-libp2p/core/test"
ma "github.com/multiformats/go-multiaddr"
"github.com/multiformats/go-multibase"
"github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -39,14 +43,16 @@ func certHashFromComponent(t *testing.T, comp ma.Component) []byte {
func TestInitialCert(t *testing.T) {
cl := clock.NewMock()
cl.Add(1234567 * time.Hour)
m, err := newCertManager(cl)
priv, _, err := test.RandTestKeyPair(crypto.Ed25519, 256)
require.NoError(t, err)
m, err := newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
cert := conf.Certificates[0]
require.Equal(t, cl.Now().Add(-clockSkewAllowance).UTC(), cert.Leaf.NotBefore)
require.GreaterOrEqual(t, cl.Now().Add(-clockSkewAllowance), cert.Leaf.NotBefore)
require.Equal(t, cert.Leaf.NotBefore.Add(certValidity), cert.Leaf.NotAfter)
addr := m.AddrComponent()
components := splitMultiaddr(addr)
Expand All @@ -59,7 +65,11 @@ func TestInitialCert(t *testing.T) {

func TestCertRenewal(t *testing.T) {
cl := clock.NewMock()
m, err := newCertManager(cl)
// Add a year to avoid edge cases around the epoch
cl.Add(time.Hour * 24 * 365)
priv, _, err := test.SeededTestKeyPair(crypto.Ed25519, 256, 0)
require.NoError(t, err)
m, err := newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

Expand All @@ -68,7 +78,7 @@ func TestCertRenewal(t *testing.T) {
require.Len(t, first, 2)
require.NotEqual(t, first[0].Value(), first[1].Value(), "the hashes should differ")
// wait for a new certificate to be generated
cl.Add(certValidity - 2*clockSkewAllowance - time.Second)
cl.Set(m.currentConfig.End().Add(-(clockSkewAllowance + time.Second)))
require.Never(t, func() bool {
for i, c := range splitMultiaddr(m.AddrComponent()) {
if c.Value() != first[i].Value() {
Expand Down Expand Up @@ -100,3 +110,67 @@ func TestCertRenewal(t *testing.T) {
// check that the 2nd certificate from the beginning was rolled over to be the 1st certificate
require.Equal(t, second[1].Value(), third[0].Value())
}

func TestDeterministicCertsAcrossReboots(t *testing.T) {
// Run this test 100 times to make sure it's deterministic
runs := 100
for i := 0; i < runs; i++ {
t.Run(fmt.Sprintf("Run=%d", i), func(t *testing.T) {
cl := clock.NewMock()
priv, _, err := test.SeededTestKeyPair(crypto.Ed25519, 256, 0)
require.NoError(t, err)
m, err := newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

conf := m.GetConfig()
require.Len(t, conf.Certificates, 1)
oldCerts := m.serializedCertHashes

m.Close()

cl.Add(time.Hour)
// reboot
m, err = newCertManager(priv, cl)
require.NoError(t, err)
defer m.Close()

newCerts := m.serializedCertHashes

require.Equal(t, oldCerts, newCerts)
})
}
}

func TestDeterministicTimeBuckets(t *testing.T) {
cl := clock.NewMock()
cl.Add(time.Hour * 24 * 365)
startA := getCurrentBucketStartTime(cl.Now(), 0)
startB := getCurrentBucketStartTime(cl.Now().Add(time.Hour*24), 0)
require.Equal(t, startA, startB)

// 15 Days later
startC := getCurrentBucketStartTime(cl.Now().Add(time.Hour*24*15), 0)
require.NotEqual(t, startC, startB)
}

func TestGetCurrentBucketStartTimeIsWithinBounds(t *testing.T) {
require.NoError(t, quick.Check(func(timeSinceUnixEpoch time.Duration, offset time.Duration) bool {
if offset < 0 {
offset = -offset
}
if timeSinceUnixEpoch < 0 {
timeSinceUnixEpoch = -timeSinceUnixEpoch
}

offset = offset % certValidity
// Bound this to 100 years
timeSinceUnixEpoch = time.Duration(timeSinceUnixEpoch % (time.Hour * 24 * 365 * 100))
// Start a bit further in the future to avoid edge cases around epoch
timeSinceUnixEpoch += time.Hour * 24 * 365
start := time.UnixMilli(timeSinceUnixEpoch.Milliseconds())

bucketStart := getCurrentBucketStartTime(start.Add(-clockSkewAllowance), offset)
return !bucketStart.After(start.Add(-clockSkewAllowance)) || bucketStart.Equal(start.Add(-clockSkewAllowance))
}, nil))
}
58 changes: 51 additions & 7 deletions p2p/transport/webtransport/crypto.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,22 +4,27 @@ import (
"bytes"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/sha256"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"encoding/binary"
"errors"
"fmt"
"io"
"math/big"
"time"

ic "github.com/libp2p/go-libp2p/core/crypto"

"github.com/multiformats/go-multihash"
"golang.org/x/crypto/hkdf"
)

func getTLSConf(start, end time.Time) (*tls.Config, error) {
cert, priv, err := generateCert(start, end)
const deterministicCertInfo = "determinisitic cert"

func getTLSConf(key ic.PrivKey, start, end time.Time) (*tls.Config, error) {
cert, priv, err := generateCert(key, start, end)
if err != nil {
return nil, err
}
Expand All @@ -32,9 +37,20 @@ func getTLSConf(start, end time.Time) (*tls.Config, error) {
}, nil
}

func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) {
// generateCert generates certs deterministically based on the `key` and start
// time passed in. Uses `golang.org/x/crypto/hkdf`.
func generateCert(key ic.PrivKey, start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, error) {
keyBytes, err := key.Raw()
if err != nil {
return nil, nil, err
}

startTimeSalt := make([]byte, 8)
binary.LittleEndian.PutUint64(startTimeSalt, uint64(start.UnixNano()))
deterministicHKDFReader := newDeterministicReader(keyBytes, startTimeSalt, deterministicCertInfo)

b := make([]byte, 8)
if _, err := rand.Read(b); err != nil {
if _, err := deterministicHKDFReader.Read(b); err != nil {
return nil, nil, err
}
serial := int64(binary.BigEndian.Uint64(b))
Expand All @@ -51,11 +67,12 @@ func generateCert(start, end time.Time) (*x509.Certificate, *ecdsa.PrivateKey, e
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageCertSign,
BasicConstraintsValid: true,
}
caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)

caPrivateKey, err := ecdsa.GenerateKey(elliptic.P256(), deterministicHKDFReader)
if err != nil {
return nil, nil, err
}
caBytes, err := x509.CreateCertificate(rand.Reader, certTempl, certTempl, &caPrivateKey.PublicKey, caPrivateKey)
caBytes, err := x509.CreateCertificate(deterministicHKDFReader, certTempl, certTempl, caPrivateKey.Public(), caPrivateKey)
if err != nil {
return nil, nil, err
}
Expand Down Expand Up @@ -106,3 +123,30 @@ func verifyRawCerts(rawCerts [][]byte, certHashes []multihash.DecodedMultihash)
}
return nil
}

// deterministicReader is a hack. It counter-acts the Go library's attempt at
// making ECDSA signatures non-deterministic. Go adds non-determinism by
// randomly dropping a singly byte from the reader stream. This counteracts this
// by detecting when a read is a single byte and using a different reader
// instead.
type deterministicReader struct {
reader io.Reader
singleByteReader io.Reader
}

func newDeterministicReader(seed []byte, salt []byte, info string) io.Reader {
reader := hkdf.New(sha256.New, seed, salt, []byte(info))
singleByteReader := hkdf.New(sha256.New, seed, salt, []byte(info+" single byte"))

return &deterministicReader{
reader: reader,
singleByteReader: singleByteReader,
}
}

func (r *deterministicReader) Read(p []byte) (n int, err error) {
if len(p) == 1 {
return r.singleByteReader.Read(p)
}
return r.reader.Read(p)
}
Loading

0 comments on commit a0432e7

Please sign in to comment.