/
session.go
327 lines (282 loc) · 9.28 KB
/
session.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
// Copyright © 2023 Brett Vickers.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// Package nts provides a client implementation of Network Time Security (NTS)
// for the Network Time Protocol (NTP). It enables the secure querying of
// time-related information that can be used to synchronize the local system
// clock with a more accurate network clock. See RFC 8915
// (https://tools.ietf.org/html/rfc8915) for more details.
package nts
import (
"bytes"
"crypto/cipher"
"crypto/rand"
"encoding/binary"
"errors"
"io"
"net"
"strconv"
"strings"
"unsafe"
"github.com/beevik/ntp"
)
var (
	// ErrAuthFailedOnClient is returned when the authenticator extension
	// field of a response fails to verify under the session's
	// server-to-client key.
	ErrAuthFailedOnClient = errors.New("authentication failed on client")

	// ErrAuthFailedOnServer is returned when the server replies with a
	// crypto-NAK kiss-of-death (kiss code "NTSN"), i.e. the server rejected
	// the client's request.
	ErrAuthFailedOnServer = errors.New("authentication failed on server")

	// ErrInvalidFormat is returned when the extension fields of a response
	// are malformed, for example when a field's length runs past the end of
	// the packet.
	ErrInvalidFormat = errors.New("invalid packet format")

	// ErrNoCookies is returned when a query needs an NTS cookie but none is
	// available.
	ErrNoCookies = errors.New("no NTS cookies available")

	// ErrUniqueIDMismatch is returned when the Unique Identifier in a
	// response does not match the one sent with the query.
	ErrUniqueIDMismatch = errors.New("client and server unique ID mismatch")
)
// Session holds the state of an active NTS session.
//
// NewSession sets it up by exchanging keys and cookies with an NTS
// key-exchange server, closing that connection as soon as the exchange is
// done. From then on, the session's state evolves as NTP queries are sent to
// the NTS-capable NTP server.
//
// A Session has no internal locking, and every query records (and later
// clears) a per-query unique ID in the session. A single Session therefore
// should not be shared by concurrent queries.
type Session struct {
	// Endpoints, each in "host:port" form.
	ntskeAddr string // NTS key-exchange server
	ntpAddr   string // NTP server that queries are sent to

	cookies cookieJar // each NTP query consumes one cookie from this jar

	// AEAD ciphers protecting the NTS extension fields, one per direction.
	cipherC2S cipher.AEAD // client -> server
	cipherS2C cipher.AEAD // server -> client

	uniqueID []byte // unique ID of the query awaiting its response, or nil
}
// NewSession creates an NTS session by connecting to an NTS key-exchange
// server and requesting keys and cookies to be used for future secure NTP
// queries. Once keys and cookies have been received, the connection is
// dropped.
//
// The address is of the form "host", "host:port", "host%zone:port", "[host]",
// "[host]:port" or "[host%zone]:port". If no port is included, the NTS
// default port 4460 is used.
func NewSession(address string) (*Session, error) {
	if _, _, err := net.SplitHostPort(address); err != nil {
		// Only a missing port can be repaired. Any other parse failure is
		// returned unchanged. Inspect the structured AddrError rather than
		// the full error text, which also contains the caller's address.
		var addrErr *net.AddrError
		if !errors.As(err, &addrErr) || !strings.Contains(addrErr.Err, "missing port") {
			return nil, err
		}

		// JoinHostPort adds its own brackets around hosts containing ':'.
		// Unwrap a bracketed literal such as "[::1]" first, or the result
		// would be "[[::1]]:4460".
		host := address
		if len(host) >= 2 && strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
		address = net.JoinHostPort(host, strconv.Itoa(defaultNtsPort))
	}

	s := &Session{ntskeAddr: address}
	if err := s.performKeyExchange(); err != nil {
		return nil, err
	}
	return s, nil
}
// Address reports the "host:port" of the NTP server that this session sends
// its queries to.
func (s *Session) Address() string {
	return s.ntpAddr
}
// Query requests time data from the session's NTP server using default query
// options. The response carries the information needed to work out an
// accurate local time.
func (s *Session) Query() (response *ntp.Response, err error) {
	var opt ntp.QueryOptions
	return s.QueryWithOptions(&opt)
}
// QueryWithOptions performs the same function as Query but allows for the
// customization of certain NTP behaviors. A nil opt is treated as a zero
// ntp.QueryOptions.
//
// The session's NTS extension is attached to a private copy of the options.
// opt itself, including its Extensions slice and that slice's backing array,
// is never modified, so the same options value can be reused across queries
// without accumulating duplicate NTS extensions.
func (s *Session) QueryWithOptions(opt *ntp.QueryOptions) (response *ntp.Response, err error) {
	var o ntp.QueryOptions
	if opt != nil {
		o = *opt
	}

	// The full slice expression caps capacity at the current length. The
	// append therefore always copies into a fresh array instead of writing
	// into spare capacity the caller still owns.
	n := len(o.Extensions)
	o.Extensions = append(o.Extensions[:n:n], privateWrapper{s})

	return ntp.QueryWithOptions(s.ntpAddr, o)
}
// Refresh discards the session's cookies, keys and per-query state, then
// performs a new key exchange. Queries also trigger it automatically once the
// cookie jar is empty. Beyond that, an explicit refresh should only be needed
// when the session has not been used for a very long time (e.g., more than
// 24 hours).
func (s *Session) Refresh() error {
	// Drop everything tied to the previous key exchange.
	s.cookies.Clear()
	s.uniqueID = nil
	s.cipherC2S, s.cipherS2C = nil, nil
	s.ntpAddr = ""

	return s.performKeyExchange()
}
// privateWrapper is the value handed to the ntp package as a query
// extension. It carries the ProcessQuery/ProcessResponse hooks required by
// ntp.Extension, so Session itself does not have to export those methods.
type privateWrapper struct {
	session *Session // the session whose state the hooks read and update
}
// ProcessQuery is the ntp.Extension hook for outgoing requests. It hands the
// request buffer to the wrapped session.
func (w privateWrapper) ProcessQuery(buf *bytes.Buffer) error {
	sess := w.session
	return sess.processQuery(buf)
}
// ProcessResponse is the ntp.Extension hook for incoming responses. It hands
// the raw response bytes to the wrapped session.
func (w privateWrapper) ProcessResponse(buf []byte) error {
	sess := w.session
	return sess.processResponse(buf)
}
// processQuery appends the NTS extension fields to an outgoing NTP request.
//
// buf holds the request built so far (presumably the NTP header written by
// the ntp package — confirm against ntp.Extension). If the cookie jar is
// empty, the session is refreshed first. The following fields are then
// appended, in order:
//   - a Unique Identifier of 32 random bytes, remembered in s.uniqueID so
//     processResponse can match the reply to this request;
//   - one cookie consumed from the jar;
//   - cookie placeholders, enough to refill the jar when the server returns
//     a cookie for the consumed cookie and for each placeholder;
//   - an authenticator holding the client-to-server AEAD seal of an empty
//     plaintext, with everything previously written to buf as associated
//     data.
//
// An error is returned if refreshing fails, random generation fails, or no
// cookie is available (ErrNoCookies).
func (s *Session) processQuery(buf *bytes.Buffer) error {
	// Refresh the session if we're out of cookies. Use the jar's accessor,
	// as below, rather than its internal field.
	if s.cookies.Count() == 0 {
		if err := s.Refresh(); err != nil {
			return err
		}
	}

	// Generate the unique ID into a local buffer. It is only recorded on the
	// session once it is known to be fully random.
	uniqueID := make([]byte, 32)
	if _, err := rand.Read(uniqueID); err != nil {
		return err
	}
	s.uniqueID = uniqueID
	writeExtUniqueID(buf, uniqueID)

	cookie := s.cookies.Consume()
	if cookie == nil {
		return ErrNoCookies
	}
	writeExtCookie(buf, cookie)

	// Request enough additional cookies to fill the jar. The +1 accounts for
	// the cookie returned in exchange for the one just consumed.
	if phCount := cookieJarSize - (s.cookies.Count() + 1); phCount > 0 {
		placeholder := make([]byte, paddedLen(len(cookie)))
		for i := 0; i < phCount; i++ {
			writeExtCookiePlaceholder(buf, placeholder)
		}
	}

	// Authenticate the packet up to this point and append the AEAD field.
	// The nonce is aligned because of the AEAD implementation's alignment
	// requirement (see allocAligned).
	nonce := allocAligned(s.cipherC2S.NonceSize())
	if _, err := rand.Read(nonce); err != nil {
		return err
	}
	ciphertext := s.cipherC2S.Seal(nil, nonce, nil, buf.Bytes())
	writeExtAEAD(buf, nonce, ciphertext)
	return nil
}
// processResponse validates an NTS-protected NTP response and harvests the
// fresh cookies it carries.
//
// buf holds the complete response: the 48-byte NTP header followed by NTP
// extension fields. The response is accepted (nil is returned) only if all of
// the following hold:
//   - it is not a crypto-NAK kiss-of-death (otherwise ErrAuthFailedOnServer);
//   - every extension field is well formed (otherwise ErrInvalidFormat);
//   - an authenticator field verifies under the server-to-client key with
//     everything before it as associated data (otherwise
//     ErrAuthFailedOnClient, which is also returned if no authenticator is
//     present);
//   - a Unique Identifier matching the one sent by processQuery appears
//     before the authenticator, inside the authenticated region (otherwise
//     ErrUniqueIDMismatch).
//
// Cookies found in the decrypted authenticator contents are added to the
// jar. Extension fields after the authenticator are not covered by it and
// are ignored. The stored unique ID is cleared on return whatever the
// outcome, so it can be matched at most once.
func (s *Session) processResponse(buf []byte) error {
	const (
		cryptoNAK    = 0x4e54534e // Kiss code "NTSN"
		ntpHeaderLen = 48
		extHeaderLen = 4 // field type (2 bytes) + field length (2 bytes)
	)

	defer func() {
		s.uniqueID = nil
	}()

	// The header is indexed directly below. A truncated packet must not
	// panic.
	if len(buf) < ntpHeaderLen {
		return ErrInvalidFormat
	}

	// Check the NTP header for a crypto-NAK kiss-of-death: stratum 0 with the
	// kiss code in the reference ID (bytes 12-15).
	if buf[1] == 0 && binary.BigEndian.Uint32(buf[12:16]) == cryptoNAK {
		return ErrAuthFailedOnServer
	}

	// Walk the extension fields. offset always marks the start of the field
	// being examined, so buf[:offset] is exactly the data that an
	// authenticator at that position covers.
	uniqueIDMatched := false
	offset := ntpHeaderLen
	cur := buf[offset:]
	for len(cur) >= extHeaderLen {
		xtype := extType(binary.BigEndian.Uint16(cur[0:2]))
		xlen := int(binary.BigEndian.Uint16(cur[2:4]))
		// The length includes the header itself. Anything shorter would make
		// cur[4:xlen] panic (and a zero length would never advance).
		if xlen < extHeaderLen || xlen > len(cur) {
			return ErrInvalidFormat
		}
		body := cur[extHeaderLen:xlen]
		cur = cur[xlen:]

		switch xtype {
		case extUniqueID:
			if len(s.uniqueID) == 0 || !bytes.Equal(s.uniqueID, body) {
				return ErrUniqueIDMismatch
			}
			uniqueIDMatched = true

		case extAEAD:
			if len(body) < 4 {
				return ErrInvalidFormat
			}
			nonceLen := int(binary.BigEndian.Uint16(body[0:2]))
			ciphertextLen := int(binary.BigEndian.Uint16(body[2:4]))
			nonceLenPadded := paddedLen(nonceLen)
			ciphertextLenPadded := paddedLen(ciphertextLen)
			// An empty ciphertext cannot even hold the authentication tag.
			// NOTE(review): an empty nonce is likewise treated as invalid for
			// the negotiated AEAD; rejecting it also avoids allocating empty
			// aligned buffers.
			if nonceLen == 0 || ciphertextLen == 0 ||
				len(body) < 4+nonceLenPadded+ciphertextLenPadded {
				return ErrInvalidFormat
			}

			// NOTE: The siv-go package has an undocumented issue where all
			// memory accesses must be 64-bit aligned or it segfaults. To
			// prevent this, copy the nonce and ciphertext into newly
			// allocated, memory-aligned slices before decrypting and
			// authenticating.
			nonce := allocAligned(nonceLen)
			copy(nonce, body[4:])
			ciphertext := allocAligned(ciphertextLen)
			copy(ciphertext, body[4+nonceLenPadded:])

			// Authenticate everything before this field and decrypt the
			// encrypted extension fields, which contain only cookies.
			plaintext, err := s.cipherS2C.Open(nil, nonce, ciphertext, buf[:offset])
			if err != nil {
				return ErrAuthFailedOnClient
			}

			// The echoed unique ID only proves this response answers our
			// query if it lies inside the authenticated region.
			if !uniqueIDMatched {
				return ErrUniqueIDMismatch
			}

			// Anything after the authenticator is unauthenticated. Stop here.
			return s.processCookies(plaintext)
		}

		offset += xlen
	}

	// No authenticator was found, so nothing in the response can be trusted.
	return ErrAuthFailedOnClient
}
// processCookies scans buf, the decrypted contents of a response's
// authenticator field, for NTS Cookie extension fields. It adds a private
// copy of each non-empty cookie to the session's cookie jar; other field
// types are skipped.
//
// ErrInvalidFormat is returned if a field's length is shorter than its own
// 4-byte header or runs past the end of buf. Cookies found before the
// malformed field will already have been added.
func (s *Session) processCookies(buf []byte) error {
	const extHeaderLen = 4 // field type (2 bytes) + field length (2 bytes)

	for len(buf) >= extHeaderLen {
		xtype := extType(binary.BigEndian.Uint16(buf[0:2]))
		xlen := int(binary.BigEndian.Uint16(buf[2:4]))
		// The length includes the header. Rejecting xlen < 4 prevents a
		// slice-bounds panic on buf[4:xlen] and a zero-length field that
		// would never advance.
		if xlen < extHeaderLen || xlen > len(buf) {
			return ErrInvalidFormat
		}
		body := buf[extHeaderLen:xlen]
		buf = buf[xlen:]

		// An empty cookie could never be used in a query, so it is not
		// stored.
		if xtype == extCookie && len(body) > 0 {
			// Copy, because body aliases the plaintext buffer.
			cookie := make([]byte, len(body))
			copy(cookie, body)
			s.cookies.Add(cookie)
		}
	}
	return nil
}
// allocAligned returns a zeroed byte slice of length size whose first element
// is 8-byte aligned. Nonces and ciphertexts are staged in such buffers
// because the AEAD implementation in use (siv-go) faults on unaligned memory
// access.
//
// A zero size returns an empty slice. There is no first element to align,
// and taking its address would panic; such sizes can come from
// server-supplied lengths.
func allocAligned(size int) []byte {
	if size == 0 {
		return []byte{}
	}

	// Fast path: the allocator often hands out aligned memory already.
	buf := make([]byte, size)
	if uintptr(unsafe.Pointer(&buf[0]))&7 == 0 {
		return buf
	}

	// Over-allocate by up to 7 bytes and start at the first aligned address.
	buf = make([]byte, size+7)
	ptr := uintptr(unsafe.Pointer(&buf[0]))
	offset := (8 - int(ptr&7)) & 7
	// Cap the result so appends cannot spill into the unused slack.
	return buf[offset : offset+size : offset+size]
}
// pad supplies the zero bytes used to pad extension-field contents out to a
// 4-byte boundary. The writers in this file only read from it; it must stay
// all zeros.
var pad = []byte{0, 0, 0, 0}
// paddedLen rounds n up to the next multiple of 4, the boundary that NTS
// extension-field contents are padded to.
func paddedLen(n int) int {
	const align = 4
	return (n + align - 1) &^ (align - 1)
}
// extType is the 16-bit type code of an NTP extension field.
type extType uint16

// Extension field types used by NTS (RFC 8915).
const (
	extUniqueID          = extType(0x0104) // Unique Identifier
	extCookie            = extType(0x0204) // NTS Cookie
	extCookiePlaceholder = extType(0x0304) // NTS Cookie Placeholder
	extAEAD              = extType(0x0404) // NTS Authenticator and Encrypted Extension Fields
)
// writeExtUniqueID writes a Unique Identifier extension field to w: a 4-byte
// header (type, then total length including the header, both big-endian)
// followed by uniqueID as-is, with no padding. Write errors are ignored; the
// caller in this file passes a *bytes.Buffer.
func writeExtUniqueID(w io.Writer, uniqueID []byte) {
	var hdr [4]byte
	binary.BigEndian.PutUint16(hdr[0:2], uint16(extUniqueID))
	binary.BigEndian.PutUint16(hdr[2:4], uint16(len(uniqueID)+4))
	w.Write(hdr[:])
	w.Write(uniqueID)
}
func writeExtCookie(w io.Writer, cookie []byte) {
cookieLenPadded := paddedLen(len(cookie))
totalLen := 4 + cookieLenPadded
binary.Write(w, binary.BigEndian, extCookie)
binary.Write(w, binary.BigEndian, uint16(totalLen))
w.Write(cookie)
w.Write(pad[:cookieLenPadded-len(cookie)])
}
// writeExtCookiePlaceholder writes an NTS Cookie Placeholder extension field
// to w: a 4-byte big-endian header (type, total length) followed by
// placeholder verbatim. The caller supplies an already-padded placeholder
// body. Write errors are ignored.
func writeExtCookiePlaceholder(w io.Writer, placeholder []byte) {
	var hdr [4]byte
	binary.BigEndian.PutUint16(hdr[:2], uint16(extCookiePlaceholder))
	binary.BigEndian.PutUint16(hdr[2:], uint16(len(placeholder)+4))
	w.Write(hdr[:])
	w.Write(placeholder)
}
// writeExtAEAD writes an NTS Authenticator and Encrypted Extension Fields
// extension field to w. The layout, with all integers big-endian uint16, is:
//
//	type | total length | nonce length | ciphertext length |
//	nonce | zero padding | ciphertext | zero padding
//
// Both the nonce and the ciphertext are padded to 4-byte boundaries. The
// length fields record their unpadded sizes, while the total length includes
// the 8 header bytes and all padding. Write errors are ignored.
func writeExtAEAD(w io.Writer, nonce []byte, ciphertext []byte) {
	noncePad := paddedLen(len(nonce)) - len(nonce)
	ctPad := paddedLen(len(ciphertext)) - len(ciphertext)
	total := 8 + len(nonce) + noncePad + len(ciphertext) + ctPad

	var hdr [8]byte
	binary.BigEndian.PutUint16(hdr[0:2], uint16(extAEAD))
	binary.BigEndian.PutUint16(hdr[2:4], uint16(total))
	binary.BigEndian.PutUint16(hdr[4:6], uint16(len(nonce)))
	binary.BigEndian.PutUint16(hdr[6:8], uint16(len(ciphertext)))

	w.Write(hdr[:])
	w.Write(nonce)
	w.Write(pad[:noncePad])
	w.Write(ciphertext)
	w.Write(pad[:ctPad])
}