Skip to content

Commit

Permalink
use a random length destination connection ID on the Initial packet
Browse files Browse the repository at this point in the history
The destination connection ID on the Initial packet must be at least 8
bytes long. By using all valid values, we make sure that the everything
works correctly. The server chooses a new connection ID with the Retry
or Handshake packet it sends, so the overhead of this is negligible.
  • Loading branch information
marten-seemann committed Jul 3, 2018
1 parent 0bd7e74 commit 73f7636
Show file tree
Hide file tree
Showing 6 changed files with 78 additions and 38 deletions.
47 changes: 23 additions & 24 deletions client.go
Expand Up @@ -140,23 +140,13 @@ func newClient(
) (*client, error) {
clientConfig := populateClientConfig(config)
version := clientConfig.Versions[0]
srcConnID, err := generateConnectionID()
if err != nil {
return nil, err
}
destConnID := srcConnID
if version.UsesTLS() {
destConnID, err = generateConnectionID()
if err != nil {
return nil, err
}
}

var hostname string
if tlsConf != nil {
hostname = tlsConf.ServerName
}
if hostname == "" {
var err error
hostname, _, err = net.SplitHostPort(host)
if err != nil {
return nil, err
Expand All @@ -175,18 +165,17 @@ func newClient(
if closeCallback != nil {
onClose = closeCallback
}
return &client{
c := &client{
conn: &conn{pconn: pconn, currentAddr: remoteAddr},
srcConnID: srcConnID,
destConnID: destConnID,
hostname: hostname,
tlsConf: tlsConf,
config: clientConfig,
version: version,
handshakeChan: make(chan struct{}),
closeCallback: onClose,
logger: utils.DefaultLogger.WithPrefix("client"),
}, nil
}
return c, c.generateConnectionIDs()
}

// populateClientConfig populates fields in the quic.Config with their default values, if none are set
Expand Down Expand Up @@ -243,6 +232,23 @@ func populateClientConfig(config *Config) *Config {
}
}

func (c *client) generateConnectionIDs() error {
srcConnID, err := generateConnectionID(protocol.ConnectionIDLenGQUIC)
if err != nil {
return err
}
destConnID := srcConnID
if c.version.UsesTLS() {
destConnID, err = protocol.GenerateDestinationConnectionID()
if err != nil {
return err
}
}
c.srcConnID = srcConnID
c.destConnID = destConnID
return nil
}

func (c *client) dial(ctx context.Context) error {
c.logger.Infof("Starting new connection to %s (%s -> %s), source connection ID %s, destination connection ID %s, version %s", c.hostname, c.conn.LocalAddr(), c.conn.RemoteAddr(), c.srcConnID, c.destConnID, c.version)

Expand Down Expand Up @@ -506,15 +512,8 @@ func (c *client) handleVersionNegotiationPacket(hdr *wire.Header) error {
// switch to negotiated version
c.initialVersion = c.version
c.version = newVersion
var err error
c.destConnID, err = generateConnectionID()
if err != nil {
return err
}
// in gQUIC, there's only one connection ID
if !c.version.UsesTLS() {
c.srcConnID = c.destConnID
}
c.generateConnectionIDs()

c.logger.Infof("Switching to QUIC version %s. New connection ID: %s", newVersion, c.destConnID)
c.session.Close(errCloseSessionForNewVersion)
return nil
Expand Down
4 changes: 2 additions & 2 deletions client_test.go
Expand Up @@ -81,11 +81,11 @@ var _ = Describe("Client", func() {
})

Context("Dialing", func() {
var origGenerateConnectionID func() (protocol.ConnectionID, error)
var origGenerateConnectionID func(int) (protocol.ConnectionID, error)

BeforeEach(func() {
origGenerateConnectionID = generateConnectionID
generateConnectionID = func() (protocol.ConnectionID, error) {
generateConnectionID = func(int) (protocol.ConnectionID, error) {
return connID, nil
}
})
Expand Down
17 changes: 15 additions & 2 deletions internal/protocol/connection_id.go
Expand Up @@ -10,15 +10,28 @@ import (
// A ConnectionID in QUIC
type ConnectionID []byte

const maxConnectionIDLen = 18

// GenerateConnectionID generates a connection ID using cryptographic random
func GenerateConnectionID() (ConnectionID, error) {
b := make([]byte, ConnectionIDLenGQUIC)
func GenerateConnectionID(len int) (ConnectionID, error) {
b := make([]byte, len)
if _, err := rand.Read(b); err != nil {
return nil, err
}
return ConnectionID(b), nil
}

// GenerateDestinationConnectionID generates a connection ID for the Initial packet.
// It uses a length randomly chosen between 8 and 18 bytes.
func GenerateDestinationConnectionID() (ConnectionID, error) {
r := make([]byte, 1)
if _, err := rand.Read(r); err != nil {
return nil, err
}
len := MinConnectionIDLenInitial + int(r[0])%(maxConnectionIDLen-MinConnectionIDLenInitial+1)
return GenerateConnectionID(len)
}

// ReadConnectionID reads a connection ID of length len from the given io.Reader.
// It returns io.EOF if there are not enough bytes to read.
func ReadConnectionID(r io.Reader, len int) (ConnectionID, error) {
Expand Down
28 changes: 26 additions & 2 deletions internal/protocol/connection_id_test.go
Expand Up @@ -10,14 +10,38 @@ import (

var _ = Describe("Connection ID generation", func() {
It("generates random connection IDs", func() {
c1, err := GenerateConnectionID()
c1, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(BeZero())
c2, err := GenerateConnectionID()
c2, err := GenerateConnectionID(8)
Expect(err).ToNot(HaveOccurred())
Expect(c1).ToNot(Equal(c2))
})

It("generates connection IDs with the requested length", func() {
c, err := GenerateConnectionID(5)
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(Equal(5))
})

It("generates random length destination connection IDs", func() {
var has8ByteConnID, has18ByteConnID bool
for i := 0; i < 1000; i++ {
c, err := GenerateDestinationConnectionID()
Expect(err).ToNot(HaveOccurred())
Expect(c.Len()).To(BeNumerically(">=", 8))
Expect(c.Len()).To(BeNumerically("<=", 18))
if c.Len() == 8 {
has8ByteConnID = true
}
if c.Len() == 18 {
has18ByteConnID = true
}
}
Expect(has8ByteConnID).To(BeTrue())
Expect(has18ByteConnID).To(BeTrue())
})

It("says if connection IDs are equal", func() {
c1 := ConnectionID{1, 2, 3, 4, 5, 6, 7, 8}
c2 := ConnectionID{8, 7, 6, 5, 4, 3, 2, 1}
Expand Down
8 changes: 6 additions & 2 deletions server_tls.go
Expand Up @@ -194,11 +194,15 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
StreamID: version.CryptoStreamID(),
Data: bc.GetDataForWriting(),
}
srcConnID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC)
if err != nil {
return nil, nil, err
}
replyHdr := &wire.Header{
IsLongHeader: true,
Type: protocol.PacketTypeRetry,
DestConnectionID: hdr.SrcConnectionID,
SrcConnectionID: hdr.DestConnectionID,
SrcConnectionID: srcConnID,
PayloadLen: f.Length(version) + protocol.ByteCount(aead.Overhead()),
PacketNumber: hdr.PacketNumber, // echo the client's packet number
PacketNumberLen: hdr.PacketNumberLen,
Expand All @@ -224,7 +228,7 @@ func (s *serverTLS) handleUnpackedInitial(remoteAddr net.Addr, hdr *wire.Header,
return nil, nil, fmt.Errorf("Expected mint state to be %s, got %s", mint.StateServerWaitFlight2, tls.State())
}
params := <-paramsChan
connID, err := protocol.GenerateConnectionID()
connID, err := protocol.GenerateConnectionID(protocol.ConnectionIDLenGQUIC)
if err != nil {
return nil, nil, err
}
Expand Down
12 changes: 6 additions & 6 deletions server_tls_test.go
Expand Up @@ -71,7 +71,7 @@ var _ = Describe("Stateless TLS handling", func() {
return hdr, data
}

unpackPacket := func(data []byte) (*wire.Header, []byte) {
unpackPacket := func(data []byte, clientDestConnID protocol.ConnectionID) (*wire.Header, []byte) {
r := bytes.NewReader(conn.dataWritten.Bytes())
iHdr, err := wire.ParseInvariantHeader(r, 0)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -80,7 +80,7 @@ var _ = Describe("Stateless TLS handling", func() {
hdr.Raw = data[:len(data)-r.Len()]
var payload []byte
if r.Len() > 0 {
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, hdr.SrcConnectionID, protocol.VersionTLS)
aead, err := crypto.NewNullAEAD(protocol.PerspectiveClient, clientDestConnID, protocol.VersionTLS)
Expect(err).ToNot(HaveOccurred())
payload, err = aead.Open(nil, data[len(data)-r.Len():], hdr.PacketNumber, hdr.Raw)
Expect(err).ToNot(HaveOccurred())
Expand All @@ -97,7 +97,7 @@ var _ = Describe("Stateless TLS handling", func() {
}
server.HandleInitial(nil, hdr, bytes.Repeat([]byte{0}, protocol.MinInitialPacketSize))
Expect(conn.dataWritten.Len()).ToNot(BeZero())
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes())
replyHdr, _ := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.IsVersionNegotiation).To(BeTrue())
Expect(sessionChan).ToNot(Receive())
})
Expand Down Expand Up @@ -134,9 +134,9 @@ var _ = Describe("Stateless TLS handling", func() {
hdr, data := getPacket(&wire.StreamFrame{Data: []byte("Client Hello")})
server.HandleInitial(nil, hdr, data)
Expect(conn.dataWritten.Len()).ToNot(BeZero())
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes())
replyHdr, payload := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeRetry))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.SrcConnectionID).ToNot(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expect(replyHdr.PayloadLen).To(BeEquivalentTo(len(payload) + 16 /* AEAD overhead */))
Expect(sessionChan).ToNot(Receive())
Expand Down Expand Up @@ -187,7 +187,7 @@ var _ = Describe("Stateless TLS handling", func() {
// the Handshake packet is written by the session
Expect(conn.dataWritten.Bytes()).ToNot(BeEmpty())
// unpack the packet to check that it actually contains a CONNECTION_CLOSE
replyHdr, data := unpackPacket(conn.dataWritten.Bytes())
replyHdr, data := unpackPacket(conn.dataWritten.Bytes(), hdr.DestConnectionID)
Expect(replyHdr.Type).To(Equal(protocol.PacketTypeHandshake))
Expect(replyHdr.SrcConnectionID).To(Equal(hdr.DestConnectionID))
Expect(replyHdr.DestConnectionID).To(Equal(hdr.SrcConnectionID))
Expand Down

0 comments on commit 73f7636

Please sign in to comment.