forked from quic-go/quic-go
-
Notifications
You must be signed in to change notification settings - Fork 0
/
conn_id_generator.go
163 lines (147 loc) · 5.65 KB
/
conn_id_generator.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package quic
import (
"fmt"
"github.com/danielpfeifer02/quic-go-prio-packs/internal/protocol"
"github.com/danielpfeifer02/quic-go-prio-packs/internal/qerr"
"github.com/danielpfeifer02/quic-go-prio-packs/internal/wire"
)
// connIDGenerator keeps track of the source connection IDs issued for a
// connection, indexed by their sequence number, together with the callbacks
// that are invoked whenever an ID is issued, retired or removed.
type connIDGenerator struct {
	// generator creates new connection IDs and reports their length.
	generator ConnectionIDGenerator
	// highestSeq is the sequence number of the most recently issued connection ID.
	highestSeq uint64

	// activeSrcConnIDs holds every issued connection ID that has not been
	// retired yet, keyed by sequence number. The initial connection ID is
	// stored under sequence number 0.
	activeSrcConnIDs map[uint64]protocol.ConnectionID
	// initialClientDestConnID is nil for the client. It is reset to nil once
	// SetHandshakeComplete has retired it.
	initialClientDestConnID *protocol.ConnectionID

	// addConnectionID is called with every newly issued connection ID.
	addConnectionID func(protocol.ConnectionID)
	// getStatelessResetToken supplies the token placed in the
	// NEW_CONNECTION_ID frame of a newly issued connection ID.
	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken
	// removeConnectionID is called for every tracked connection ID by RemoveAll.
	removeConnectionID func(protocol.ConnectionID)
	// retireConnectionID is called when a connection ID is retired.
	retireConnectionID func(protocol.ConnectionID)
	// replaceWithClosed receives all tracked connection IDs and the
	// connection-close bytes passed to ReplaceWithClosed.
	replaceWithClosed func([]protocol.ConnectionID, []byte)
	// queueControlFrame is called with the NEW_CONNECTION_ID frame built for
	// every newly issued connection ID.
	queueControlFrame func(wire.Frame)
}
// newConnIDGenerator returns a connIDGenerator that starts out with a single
// active connection ID: initialConnectionID, stored under sequence number 0.
// The remaining arguments are stored unchanged as the generator's callbacks
// and connection ID source. initialClientDestConnID is nil for the client.
func newConnIDGenerator(
	initialConnectionID protocol.ConnectionID,
	initialClientDestConnID *protocol.ConnectionID, // nil for the client
	addConnectionID func(protocol.ConnectionID),
	getStatelessResetToken func(protocol.ConnectionID) protocol.StatelessResetToken,
	removeConnectionID func(protocol.ConnectionID),
	retireConnectionID func(protocol.ConnectionID),
	replaceWithClosed func([]protocol.ConnectionID, []byte),
	queueControlFrame func(wire.Frame),
	generator ConnectionIDGenerator,
) *connIDGenerator {
	return &connIDGenerator{
		generator: generator,
		// The initial connection ID always has sequence number 0.
		activeSrcConnIDs:        map[uint64]protocol.ConnectionID{0: initialConnectionID},
		initialClientDestConnID: initialClientDestConnID,
		addConnectionID:         addConnectionID,
		getStatelessResetToken:  getStatelessResetToken,
		removeConnectionID:      removeConnectionID,
		retireConnectionID:      retireConnectionID,
		replaceWithClosed:       replaceWithClosed,
		queueControlFrame:       queueControlFrame,
	}
}
// SetMaxActiveConnIDs issues new connection IDs until the number of active
// IDs reaches limit, capped at protocol.MaxIssuedConnectionIDs.
//
// It does nothing when the generator produces zero-length connection IDs.
// When the generator is a *protocol.PriorityConnectionIDGenerator, limit is
// raised to the number of priorities if it is smaller, and the function
// panics if protocol.MaxIssuedConnectionIDs cannot accommodate one connection
// ID per priority.
//
// It returns the first error encountered while issuing a connection ID.
func (m *connIDGenerator) SetMaxActiveConnIDs(limit uint64) error {
	if m.generator.ConnectionIDLen() == 0 {
		return nil
	}

	// PRIO_PACKS_TAG
	if prioGen, isPrio := m.generator.(*protocol.PriorityConnectionIDGenerator); isPrio {
		required := uint64(prioGen.NumberOfPriorities)
		// Using < guarantees at least one connection ID per priority;
		// != would instead force *exactly* one connection ID per priority.
		if limit < required {
			fmt.Println("WARNING: active_connection_id_limit is smaller than the number of priorities. Set to the number of priorities.")
			limit = required
		}
		if protocol.MaxIssuedConnectionIDs < required {
			panic("MaxIssuedConnectionIDs is smaller than the number of priorities. Choose a smaller number of priorities or increase MaxIssuedConnectionIDs.")
		}
	}

	// The active_connection_id_limit transport parameter is the number of
	// connection IDs the peer will store. This limit includes the connection ID
	// used during the handshake, and the one sent in the preferred_address
	// transport parameter.
	// We currently don't send the preferred_address transport parameter,
	// so we can issue (limit - 1) connection IDs.
	target := min(limit, protocol.MaxIssuedConnectionIDs)
	for active := uint64(len(m.activeSrcConnIDs)); active < target; active++ {
		if err := m.issueNewConnID(); err != nil {
			return err
		}
	}
	return nil
}
// Retire retires the connection ID with sequence number seq and, unless it was
// the initial connection ID (sequence number 0), issues a replacement.
//
// sentWithDestConnID is compared against the connection ID being retired; if
// they are equal, a PROTOCOL_VIOLATION transport error is returned. The same
// error code is returned when seq is higher than the highest sequence number
// issued so far. Retiring a sequence number that is no longer active is a
// no-op.
func (m *connIDGenerator) Retire(seq uint64, sentWithDestConnID protocol.ConnectionID) error {
	if seq > m.highestSeq {
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: fmt.Sprintf("retired connection ID %d (highest issued: %d)", seq, m.highestSeq),
		}
	}

	retired, active := m.activeSrcConnIDs[seq]
	switch {
	case !active:
		// We might already have deleted this connection ID, if this is a duplicate frame.
		return nil
	case retired == sentWithDestConnID:
		return &qerr.TransportError{
			ErrorCode:    qerr.ProtocolViolation,
			ErrorMessage: fmt.Sprintf("retired connection ID %d (%s), which was used as the Destination Connection ID on this packet", seq, retired),
		}
	}

	m.retireConnectionID(retired)
	delete(m.activeSrcConnIDs, seq)

	// Don't issue a replacement for the initial connection ID.
	if seq == 0 {
		return nil
	}

	// PRIO_PACKS_TAG
	// Pin the generator's next priority to the first byte of the retired
	// connection ID before the replacement is generated.
	if prioGen, isPrio := m.generator.(*protocol.PriorityConnectionIDGenerator); isPrio {
		prioGen.NextPriority = int8(retired.Bytes()[0])
		prioGen.NextPriorityValid = true
	}
	return m.issueNewConnID()
}
// issueNewConnID generates one connection ID, records it under the next
// sequence number, passes it to addConnectionID and queues a
// NEW_CONNECTION_ID frame carrying its stateless reset token.
// highestSeq is advanced only after the frame has been queued.
// A generator error is returned unchanged, leaving all state untouched.
func (m *connIDGenerator) issueNewConnID() error {
	newID, err := m.generator.GenerateConnectionID()
	if err != nil {
		return err
	}

	seq := m.highestSeq + 1
	m.activeSrcConnIDs[seq] = newID
	m.addConnectionID(newID)

	frame := &wire.NewConnectionIDFrame{
		SequenceNumber:      seq,
		ConnectionID:        newID,
		StatelessResetToken: m.getStatelessResetToken(newID),
	}
	m.queueControlFrame(frame)

	m.highestSeq = seq
	return nil
}
// SetHandshakeComplete retires the initial client destination connection ID,
// if one is still set, and clears it so that it is retired only once.
func (m *connIDGenerator) SetHandshakeComplete() {
	initial := m.initialClientDestConnID
	if initial == nil {
		return
	}
	m.retireConnectionID(*initial)
	m.initialClientDestConnID = nil
}
// RemoveAll calls removeConnectionID for the initial client destination
// connection ID (if still set) and for every active source connection ID.
// The generator's own bookkeeping is left unchanged.
func (m *connIDGenerator) RemoveAll() {
	if initial := m.initialClientDestConnID; initial != nil {
		m.removeConnectionID(*initial)
	}
	for _, id := range m.activeSrcConnIDs {
		m.removeConnectionID(id)
	}
}
// ReplaceWithClosed collects the initial client destination connection ID
// (if still set) and all active source connection IDs, and hands them to the
// replaceWithClosed callback together with connClose.
func (m *connIDGenerator) ReplaceWithClosed(connClose []byte) {
	// The extra slot leaves room for the initial client destination connection ID.
	all := make([]protocol.ConnectionID, 0, len(m.activeSrcConnIDs)+1)
	if initial := m.initialClientDestConnID; initial != nil {
		all = append(all, *initial)
	}
	for _, id := range m.activeSrcConnIDs {
		all = append(all, id)
	}
	m.replaceWithClosed(all, connClose)
}