-
Notifications
You must be signed in to change notification settings - Fork 46
/
measurements.go
597 lines (517 loc) · 18.4 KB
/
measurements.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
/*
Copyright (c) Edgeless Systems GmbH
SPDX-License-Identifier: AGPL-3.0-only
*/
/*
# Measurements
Defines default expected measurements for the current release, as well as functions for comparing, updating and marshalling measurements.
This package should not include TPM specific code.
*/
package measurements
import (
"bytes"
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"sort"
"strconv"
"strings"
"github.com/edgelesssys/constellation/v2/internal/attestation/variant"
"github.com/edgelesssys/constellation/v2/internal/cloud/cloudprovider"
"github.com/google/go-tpm/tpmutil"
"github.com/siderolabs/talos/pkg/machinery/config/encoder"
"gopkg.in/yaml.v3"
"github.com/edgelesssys/constellation/v2/internal/api/versionsapi"
)
//go:generate measurement-generator
const (
// PCRIndexClusterID is the PCR (handle 15) that is extended to mark the node as initialized.
// The value used to extend it is a randomly generated 32-byte value.
PCRIndexClusterID = tpmutil.Handle(15)
// PCRIndexOwnerID is the PCR (handle 16) that is extended to mark the node as initialized.
// The value used to extend it is derived from Constellation's master key.
// TODO(daniel-weisse): move to stable, non-debug PCR before use.
PCRIndexOwnerID = tpmutil.Handle(16)
// TDXIndexClusterID is the index of the TDX measurement used to mark the node as initialized.
// It is the RTMR index + 1 (i.e. 3), since index 0 of the TDX measurements is reserved for MRTD.
// RTMRIndexClusterID is declared further down in this block.
TDXIndexClusterID = RTMRIndexClusterID + 1
// RTMRIndexClusterID is the index of the RTMR that is extended to mark the node as initialized.
RTMRIndexClusterID = 2
// PCRMeasurementLength is the length in bytes of a valid PCR measurement (SHA256).
PCRMeasurementLength = 32
// TDXMeasurementLength is the length in bytes of a valid TDX measurement (SHA384).
TDXMeasurementLength = 48
)
// M maps a measurement index to the Measurement expected at that index.
// For TPM based attestation the index is a Platform Configuration Register (PCR) number.
type M map[uint32]Measurement
// ImageMeasurementsV2 holds the measurements published for a specific image.
// Version, Ref and Stream identify the image; List contains one entry with
// measurements for every variant of that image.
type ImageMeasurementsV2 struct {
// Version is the image version the measurements belong to.
Version string `json:"version" yaml:"version"`
// Ref is the ref the image belongs to.
Ref string `json:"ref" yaml:"ref"`
// Stream is the stream the image belongs to.
Stream string `json:"stream" yaml:"stream"`
// List holds the per-variant measurement entries.
List []ImageMeasurementsV2Entry `json:"list" yaml:"list"`
}
// ImageMeasurementsV2Entry holds the measurements for one variant of a specific image.
// A variant is identified by the combination of CSP and AttestationVariant.
type ImageMeasurementsV2Entry struct {
// CSP is the cloud provider this entry applies to.
CSP cloudprovider.Provider `json:"csp" yaml:"csp"`
// AttestationVariant is the string form of the attestation variant this entry applies to.
AttestationVariant string `json:"attestationVariant" yaml:"attestationVariant"`
// Measurements are the expected measurements for this CSP / attestation variant pair.
Measurements M `json:"measurements" yaml:"measurements"`
}
// MergeImageMeasurementsV2 merges the entries of several ImageMeasurementsV2 objects
// into one object.
//
// All inputs must agree on Version, Ref and Stream, otherwise an error is returned.
// With no input an error is returned; a single input is returned unchanged.
// For two or more inputs the combined list is sorted by CSP first and by
// attestation variant second.
func MergeImageMeasurementsV2(measurements ...ImageMeasurementsV2) (ImageMeasurementsV2, error) {
	switch len(measurements) {
	case 0:
		return ImageMeasurementsV2{}, errors.New("no measurement objects specified")
	case 1:
		return measurements[0], nil
	}

	first := measurements[0]
	merged := ImageMeasurementsV2{
		Version: first.Version,
		Ref:     first.Ref,
		Stream:  first.Stream,
		List:    []ImageMeasurementsV2Entry{},
	}

	for i := range measurements {
		current := measurements[i]
		// The metadata checks are ordered: version, then ref, then stream.
		switch {
		case current.Version != merged.Version:
			return ImageMeasurementsV2{}, errors.New("version mismatch")
		case current.Ref != merged.Ref:
			return ImageMeasurementsV2{}, errors.New("ref mismatch")
		case current.Stream != merged.Stream:
			return ImageMeasurementsV2{}, errors.New("stream mismatch")
		}
		merged.List = append(merged.List, current.List...)
	}

	sort.SliceStable(merged.List, func(a, b int) bool {
		lhs, rhs := merged.List[a], merged.List[b]
		if lhs.CSP == rhs.CSP {
			return lhs.AttestationVariant < rhs.AttestationVariant
		}
		return lhs.CSP < rhs.CSP
	})

	return merged, nil
}
// MarshalYAML encodes m as a YAML node whose keys are sorted numerically.
func (m M) MarshalYAML() (any, error) {
	// Encode the plain map type: handing M itself to the encoder would call
	// this method again and recurse forever.
	plain := map[uint32]Measurement(m)

	node, err := encoder.NewEncoder(plain).Marshal()
	if err != nil {
		return nil, err
	}

	// node.Content is a flat {key, value, key, value, ...} list;
	// mYamlContent sorts it pairwise by the numeric value of the keys.
	pairs := mYamlContent(node.Content)
	sort.Sort(pairs)

	return node, nil
}
// FetchAndVerify downloads the measurements and their signature from the given
// URLs using client, checks the signature with verifier, and stores the
// measurements matching version, csp and attestationVariant in m.
// It returns the hex encoded SHA256 hash of the downloaded measurements file.
func (m *M) FetchAndVerify(
	ctx context.Context, client *http.Client, verifier cosignVerifier,
	measurementsURL, signatureURL *url.URL,
	version versionsapi.Version, csp cloudprovider.Provider, attestationVariant variant.Variant,
) (string, error) {
	// Exported entry point; the work is done by the unexported implementation.
	hash, err := m.fetchAndVerify(
		ctx, client, verifier,
		measurementsURL, signatureURL,
		version, csp, attestationVariant,
	)
	return hash, err
}
// fetchAndVerify downloads the measurements file and its signature via the
// provided URLs, using client for the download. The verifier is used to check
// the signature over the raw measurements before they are parsed.
// On success m is replaced by the measurements matching version, csp and
// attestationVariant, and the hex encoded SHA256 hash of the raw measurements
// file is returned.
func (m *M) fetchAndVerify(
	ctx context.Context, client *http.Client, verifier cosignVerifier,
	measurementsURL, signatureURL *url.URL,
	version versionsapi.Version, csp cloudprovider.Provider, attestationVariant variant.Variant,
) (string, error) {
	rawMeasurements, err := getFromURL(ctx, client, measurementsURL)
	if err != nil {
		return "", fmt.Errorf("failed to fetch measurements: %w", err)
	}

	rawSignature, err := getFromURL(ctx, client, signatureURL)
	if err != nil {
		return "", fmt.Errorf("failed to fetch signature: %w", err)
	}

	// Verify before parsing, so unauthenticated data is never interpreted.
	err = verifier.VerifySignature(rawMeasurements, rawSignature)
	if err != nil {
		return "", err
	}

	var parsed ImageMeasurementsV2
	err = json.Unmarshal(rawMeasurements, &parsed)
	if err != nil {
		return "", err
	}

	err = m.fromImageMeasurementsV2(parsed, version, csp, attestationVariant)
	if err != nil {
		return "", err
	}

	digest := sha256.Sum256(rawMeasurements)
	return hex.EncodeToString(digest[:]), nil
}
// FetchNoVerify downloads the measurements file from measurementsURL using
// client and stores the measurements matching version, csp and
// attestationVariant in m. No signature is checked.
func (m *M) FetchNoVerify(ctx context.Context, client *http.Client, measurementsURL *url.URL,
	version versionsapi.Version, csp cloudprovider.Provider, attestationVariant variant.Variant,
) error {
	raw, fetchErr := getFromURL(ctx, client, measurementsURL)
	if fetchErr != nil {
		return fmt.Errorf("failed to fetch measurements from %s: %w", measurementsURL.String(), fetchErr)
	}

	parsed := ImageMeasurementsV2{}
	if parseErr := json.Unmarshal(raw, &parsed); parseErr != nil {
		return parseErr
	}

	return m.fromImageMeasurementsV2(parsed, version, csp, attestationVariant)
}
// CopyFrom copies all entries of other into m. Entries that already exist in m
// are overwritten, entries of m that are not present in other stay untouched.
//
// If m is a nil map and other has entries, a new map is allocated first, so
// calling CopyFrom on a zero value M does not panic. The copy is shallow: the
// Expected byte slices are shared with other.
func (m *M) CopyFrom(other M) {
	if len(other) == 0 {
		// Nothing to copy; leave m exactly as it is (including nil).
		return
	}
	if *m == nil {
		// Writing to a nil map panics, so allocate the destination on demand.
		*m = make(M, len(other))
	}
	for idx, measurement := range other {
		(*m)[idx] = measurement
	}
}
// Copy returns a new map holding the same entries as m.
// The copy is shallow: the Expected byte slices are shared with the original.
// The result is never nil, even if m is.
func (m *M) Copy() M {
	source := *m
	clone := make(M, len(source))
	for idx, measurement := range source {
		clone[idx] = measurement
	}
	return clone
}
// EqualTo reports whether other holds exactly the same measurements as m:
// the same set of indices, and for every index the same expected value and the
// same validation option.
func (m *M) EqualTo(other M) bool {
	if len(*m) != len(other) {
		return false
	}
	for idx, mine := range *m {
		// An index missing in other must not be treated as a zero value
		// Measurement, otherwise maps with different index sets could
		// compare equal.
		theirs, ok := other[idx]
		if !ok {
			return false
		}
		if mine.ValidationOpt != theirs.ValidationOpt {
			return false
		}
		if !bytes.Equal(mine.Expected, theirs.Expected) {
			return false
		}
	}
	return true
}
// Compare checks the given measurement values against the expected measurements in m.
// Indices are processed in ascending order. For every index whose value in other
// differs from the expected value (or is missing), a message is produced:
// it is returned as an error if the measurement is enforced, and as a warning
// if it is marked WarnOnly. Indices only present in other are ignored.
func (m M) Compare(other map[uint32][]byte) (warnings []string, errs []error) {
	// Collect and sort the expected indices to get a deterministic report order.
	indices := make([]uint32, 0, len(m))
	for idx := range m {
		indices = append(indices, idx)
	}
	sort.Slice(indices, func(a, b int) bool {
		return indices[a] < indices[b]
	})

	for _, idx := range indices {
		want, got := m[idx], other[idx]
		if bytes.Equal(want.Expected, got) {
			continue
		}

		var msg string
		if len(got) == 0 {
			msg = fmt.Sprintf("missing measurement value for index %d", idx)
		} else {
			msg = fmt.Sprintf("untrusted measurement value %x at index %d", got, idx)
		}

		switch want.ValidationOpt {
		case Enforce:
			errs = append(errs, errors.New(msg))
		default:
			warnings = append(warnings, fmt.Sprintf("Encountered %s", msg))
		}
	}

	return warnings, errs
}
// GetEnforced returns the indices of all enforced Measurements,
// i.e. all Measurements that are not marked as WarnOnly.
// The indices are returned in ascending order, so the result does not depend
// on map iteration order. If no measurement is enforced, nil is returned.
func (m *M) GetEnforced() []uint32 {
	var enforced []uint32
	for idx, measurement := range *m {
		if measurement.ValidationOpt == Enforce {
			enforced = append(enforced, idx)
		}
	}
	// Map iteration order is random; sort for a deterministic result.
	sort.Slice(enforced, func(i, j int) bool {
		return enforced[i] < enforced[j]
	})
	return enforced
}
// SetEnforced marks exactly the measurements listed in enforced as Enforce and
// every other measurement of m as WarnOnly.
// An error is returned if enforced names an index that is not part of m;
// in that case m is left unchanged.
func (m *M) SetEnforced(enforced []uint32) error {
	updated := make(M)

	// Start from a copy in which nothing is enforced.
	for idx, current := range *m {
		current.ValidationOpt = WarnOnly
		updated[idx] = current
	}

	// Re-enable enforcement for the requested indices.
	for _, idx := range enforced {
		entry, known := updated[idx]
		if !known {
			return fmt.Errorf("measurement %d not in list, but set to enforced", idx)
		}
		entry.ValidationOpt = Enforce
		updated[idx] = entry
	}

	// Only publish the result once every requested index was found.
	*m = updated
	return nil
}
// UnmarshalJSON decodes measurements from JSON.
// Decoding fails if the measurements are not all of equal length.
// m is only modified if decoding and the length check succeed.
func (m *M) UnmarshalJSON(b []byte) error {
	// Decode into the plain map type so that this method is not invoked recursively.
	decoded := make(map[uint32]Measurement)

	err := json.Unmarshal(b, &decoded)
	if err == nil {
		// All measurements must have the same length.
		err = checkLength(decoded)
	}
	if err != nil {
		return err
	}

	*m = decoded
	return nil
}
// UnmarshalYAML decodes measurements from YAML using the provided unmarshal function.
// Decoding fails if the measurements are not all of equal length.
// m is only modified if decoding and the length check succeed.
func (m *M) UnmarshalYAML(unmarshal func(any) error) error {
	// Decode into the plain map type so that this method is not invoked recursively.
	decoded := make(map[uint32]Measurement)

	err := unmarshal(&decoded)
	if err == nil {
		// All measurements must have the same length.
		err = checkLength(decoded)
	}
	if err != nil {
		return err
	}

	*m = decoded
	return nil
}
// String returns a string representation of the measurements.
// Entries are formatted as "<index>: 0x<hex value>", ordered by ascending index
// and separated by commas. An empty map yields the empty string.
func (m M) String() string {
	// Map iteration order is random; sort the indices for a stable result.
	indices := make([]uint32, 0, len(m))
	for idx := range m {
		indices = append(indices, idx)
	}
	sort.Slice(indices, func(i, j int) bool {
		return indices[i] < indices[j]
	})

	parts := make([]string, 0, len(indices))
	for _, idx := range indices {
		parts = append(parts, fmt.Sprintf("%d: 0x%s", idx, hex.EncodeToString(m[idx].Expected)))
	}
	// Joining once avoids the leading separator that repeated joins onto an
	// initially empty string would produce.
	return strings.Join(parts, ",")
}
// fromImageMeasurementsV2 replaces m with the measurements of the entry in
// measurements that matches csp and attestationVariant.
// The version metadata of measurements must equal wantVersion. An error is
// returned if the metadata is invalid, the version differs, or no entry matches.
// The first matching entry of the list wins.
func (m *M) fromImageMeasurementsV2(
	measurements ImageMeasurementsV2, wantVersion versionsapi.Version,
	csp cloudprovider.Provider, attestationVariant variant.Variant,
) error {
	gotVersion, err := versionsapi.NewVersion(measurements.Ref, measurements.Stream, measurements.Version, versionsapi.VersionKindImage)
	if err != nil {
		return fmt.Errorf("invalid metadata version: %w", err)
	}
	if !wantVersion.Equal(gotVersion) {
		return fmt.Errorf("invalid measurement metadata: version mismatch: expected %s, got %s", wantVersion.ShortPath(), gotVersion.ShortPath())
	}

	// Search the list for the entry of the requested CSP / attestation variant.
	var match *ImageMeasurementsV2Entry
	for i := range measurements.List {
		candidate := &measurements.List[i]
		if candidate.CSP != csp {
			continue
		}
		// Entries with an unparsable or missing variant can never match.
		candidateVariant, err := variant.FromString(candidate.AttestationVariant)
		if err != nil || candidateVariant == nil || attestationVariant == nil {
			continue
		}
		if candidateVariant.Equal(attestationVariant) {
			match = candidate
			break
		}
	}

	if match == nil {
		return fmt.Errorf("invalid measurement metadata: no measurements found for csp %s, attestationVariant %s and image %s", csp.String(), attestationVariant, wantVersion.ShortPath())
	}

	*m = match.Measurements
	return nil
}
// Measurement wraps an expected measurement value together with the policy
// that decides how a mismatch of that value is handled.
type Measurement struct {
// Expected measurement value.
// 32 bytes for vTPM attestation, 48 for TDX.
Expected []byte `json:"expected" yaml:"expected"`
// ValidationOpt indicates how measurement mismatches should be handled.
// It is serialized under the key "warnOnly".
ValidationOpt MeasurementValidationOption `json:"warnOnly" yaml:"warnOnly"`
}
// MeasurementValidationOption indicates how measurement mismatches should be handled.
// It is a bool whose true value means "warn only"; the zero value (false) is Enforce.
type MeasurementValidationOption bool
const (
// WarnOnly will only result in a warning in case of a mismatching measurement.
WarnOnly MeasurementValidationOption = true
// Enforce will result in an error in case of a mismatching measurement, and operation will be aborted.
Enforce MeasurementValidationOption = false
)
// UnmarshalJSON reads a Measurement from JSON. The input is either a JSON object
// with the fields "expected" and "warnOnly", or (legacy format) a plain JSON
// string holding the hex encoded expected value.
func (m *Measurement) UnmarshalJSON(b []byte) error {
	var encoded encodedMeasurement

	objectErr := json.Unmarshal(b, &encoded)
	if objectErr != nil {
		// Not an object: the measurement might be a plain string (legacy format).
		// WarnOnly keeps its zero value then, so such values are always enforced.
		legacyErr := json.Unmarshal(b, &encoded.Expected)
		if legacyErr != nil {
			return errors.Join(
				objectErr,
				fmt.Errorf("trying legacy format: %w", legacyErr),
			)
		}
	}

	if decodeErr := m.unmarshal(encoded); decodeErr != nil {
		return fmt.Errorf("unmarshalling json: %w", decodeErr)
	}
	return nil
}
// MarshalJSON encodes the Measurement as a JSON object in which the expected
// value is written as a hex string.
func (m Measurement) MarshalJSON() ([]byte, error) {
	encoded := encodedMeasurement{
		Expected: hex.EncodeToString(m.Expected),
		WarnOnly: m.ValidationOpt,
	}
	return json.Marshal(encoded)
}
// UnmarshalYAML reads a Measurement from YAML. The input is either a mapping
// with the fields "expected" and "warnOnly", or (legacy format) a plain string
// holding the hex encoded expected value.
func (m *Measurement) UnmarshalYAML(unmarshal func(any) error) error {
	var encoded encodedMeasurement

	objectErr := unmarshal(&encoded)
	if objectErr != nil {
		// Not a mapping: the measurement might be a plain string (legacy format).
		// WarnOnly keeps its zero value then, so such values are always enforced.
		legacyErr := unmarshal(&encoded.Expected)
		if legacyErr != nil {
			return errors.Join(
				objectErr,
				fmt.Errorf("trying legacy format: %w", legacyErr),
			)
		}
	}

	if decodeErr := m.unmarshal(encoded); decodeErr != nil {
		return fmt.Errorf("unmarshalling yaml: %w", decodeErr)
	}
	return nil
}
// MarshalYAML returns the value to encode as YAML for the Measurement:
// a struct in which the expected value is written as a hex string.
func (m Measurement) MarshalYAML() (any, error) {
	var encoded encodedMeasurement
	encoded.Expected = hex.EncodeToString(m.Expected)
	encoded.WarnOnly = m.ValidationOpt
	return encoded, nil
}
// unmarshal fills m from its encoded form. The expected value must be a hex
// string that decodes to either PCRMeasurementLength or TDXMeasurementLength bytes.
// m is left untouched if decoding fails.
func (m *Measurement) unmarshal(eM encodedMeasurement) error {
	decoded, err := hex.DecodeString(eM.Expected)
	if err != nil {
		return fmt.Errorf("decoding measurement: %w", err)
	}

	switch len(decoded) {
	case PCRMeasurementLength, TDXMeasurementLength:
		// valid length
	default:
		return fmt.Errorf("invalid measurement: invalid length: %d", len(decoded))
	}

	m.Expected, m.ValidationOpt = decoded, eM.WarnOnly
	return nil
}
// WithAllBytes returns a Measurement whose expected value consists of len bytes
// that are all set to b, using the given validation option.
// Expected lengths are either 32 bytes (PCRMeasurementLength) or 48 bytes (TDXMeasurementLength).
// Other lengths are possible in this function, but potentially rejected elsewhere.
func WithAllBytes(b byte, validationOpt MeasurementValidationOption, len int) Measurement {
	var result Measurement
	result.Expected = bytes.Repeat([]byte{b}, len)
	result.ValidationOpt = validationOpt
	return result
}
// PlaceHolderMeasurement returns an enforced Measurement whose expected value is
// the byte pattern 0x12 0x34 repeated len/2 times.
func PlaceHolderMeasurement(len int) Measurement {
	pattern := []byte{0x12, 0x34}
	repetitions := len / 2

	return Measurement{
		Expected:      bytes.Repeat(pattern, repetitions),
		ValidationOpt: Enforce,
	}
}
// DefaultsFor returns a copy of the default measurements for the given
// combination of cloud provider and attestation variant.
// It returns nil if the combination is not known.
func DefaultsFor(provider cloudprovider.Provider, attestationVariant variant.Variant) M {
	// Dispatch on the provider first, then on the attestation variant.
	switch provider {
	case cloudprovider.AWS:
		switch attestationVariant {
		case variant.AWSNitroTPM{}:
			return aws_AWSNitroTPM.Copy()
		case variant.AWSSEVSNP{}:
			return aws_AWSSEVSNP.Copy()
		}

	case cloudprovider.Azure:
		switch attestationVariant {
		case variant.AzureSEVSNP{}:
			return azure_AzureSEVSNP.Copy()
		case variant.AzureTDX{}:
			return azure_AzureTDX.Copy()
		case variant.AzureTrustedLaunch{}:
			return azure_AzureTrustedLaunch.Copy()
		}

	case cloudprovider.GCP:
		switch attestationVariant {
		case variant.GCPSEVES{}:
			return gcp_GCPSEVES.Copy()
		}

	case cloudprovider.OpenStack:
		switch attestationVariant {
		case variant.QEMUVTPM{}:
			return openstack_QEMUVTPM.Copy()
		}

	case cloudprovider.QEMU:
		switch attestationVariant {
		case variant.QEMUTDX{}:
			return qemu_QEMUTDX.Copy()
		case variant.QEMUVTPM{}:
			return qemu_QEMUVTPM.Copy()
		}
	}

	// Unknown provider / variant combination.
	return nil
}
// checkLength verifies that the expected values of all measurements in m have
// the same length. The measurement with the lowest index serves as reference;
// the error names the lowest index whose length differs from it.
// An empty map is valid.
func checkLength(m map[uint32]Measurement) error {
	// Iterate in ascending index order so that both the result and the error
	// message are independent of map iteration order.
	indices := make([]uint32, 0, len(m))
	for idx := range m {
		indices = append(indices, idx)
	}
	sort.Slice(indices, func(i, j int) bool {
		return indices[i] < indices[j]
	})

	for i, idx := range indices {
		if i == 0 {
			// The first entry is the reference, even if its length is zero.
			continue
		}
		want := len(m[indices[0]].Expected)
		got := len(m[idx].Expected)
		if got != want {
			return fmt.Errorf("inconsistent measurement length: index %d: expected %d, got %d", idx, want, got)
		}
	}
	return nil
}
// encodedMeasurement is the serialized (JSON / YAML) form of a Measurement.
type encodedMeasurement struct {
// Expected is the expected measurement value as a hex encoded string.
Expected string `json:"expected" yaml:"expected"`
// WarnOnly carries the validation option of the Measurement.
WarnOnly MeasurementValidationOption `json:"warnOnly" yaml:"warnOnly"`
}
// mYamlContent is the Content of a yaml.Node encoding of an M. It implements sort.Interface
// and orders the key/value pairs by the numeric value of their keys.
// The slice is filled like {key1, value1, key2, value2, ...},
// so one sortable element consists of two consecutive nodes.
type mYamlContent []*yaml.Node
// Len returns the number of key/value pairs in c.
func (c mYamlContent) Len() int {
	// Every pair occupies two consecutive nodes.
	pairs := len(c) / 2
	return pairs
}
// Less reports whether the key of pair i is numerically smaller than the key of pair j.
// It panics if a key is not a valid integer.
func (c mYamlContent) Less(i, j int) bool {
	// keyOf parses the key node of the given pair as an integer.
	keyOf := func(pair int) int {
		key, err := strconv.Atoi(c[2*pair].Value)
		if err != nil {
			panic(err)
		}
		return key
	}
	return keyOf(i) < keyOf(j)
}
// Swap exchanges the pairs i and j.
func (c mYamlContent) Swap(i, j int) {
	// A pair is stored as {key, value} at positions 2*n and 2*n+1,
	// so both nodes have to be exchanged.
	keyI, keyJ := 2*i, 2*j
	valI, valJ := keyI+1, keyJ+1

	c[keyI], c[keyJ] = c[keyJ], c[keyI]
	c[valI], c[valJ] = c[valJ], c[valI]
}
// getFromURL issues a GET request for sourceURL using client and returns the
// response body. Any status code other than 200 OK is reported as an error.
func getFromURL(ctx context.Context, client *http.Client, sourceURL *url.URL) ([]byte, error) {
	request, err := http.NewRequestWithContext(ctx, http.MethodGet, sourceURL.String(), http.NoBody)
	if err != nil {
		return nil, err
	}

	response, err := client.Do(request)
	if err != nil {
		return nil, err
	}
	defer response.Body.Close()

	if code := response.StatusCode; code != http.StatusOK {
		return nil, fmt.Errorf("http status code: %d", code)
	}

	return io.ReadAll(response.Body)
}
// cosignVerifier verifies a signature over some content.
// fetchAndVerify uses it to check the downloaded signature against the raw
// measurements file before the measurements are parsed.
type cosignVerifier interface {
// VerifySignature checks signature against content.
// A non-nil error is treated as a failed verification by fetchAndVerify.
VerifySignature(content, signature []byte) error
}