Skip to content
Open
14 changes: 11 additions & 3 deletions cmd/attested-get/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ var flags []cli.Flag = []cli.Flag{
Value: false,
Usage: "log debug messages",
},
&cli.BoolFlag{
Name: "verify-ak-certificate",
Value: false,
EnvVars: []string{"VERIFY_AK_CERTIFICATE"},
Usage: "verify Azure TDX vTPM attestation key certificate chain",
},
}

func main() {
Expand All @@ -108,10 +114,11 @@ func main() {
}

// createAzureTDXValidator creates an Azure TDX validator without required measurements
func createAzureTDXValidator(log *slog.Logger, overrideAzurev6Tcbinfo bool) atls.Validator {
func createAzureTDXValidator(log *slog.Logger, overrideAzurev6Tcbinfo bool, verifyAKCertificate bool) atls.Validator {
attConfig := config.DefaultForAzureTDX()
attConfig.SetMeasurements(measurements.M{})
validator := azure_tdx.NewValidator(attConfig, proxy.AttestationLogger{Log: log})
validator.SetVerifyAKCertificate(verifyAKCertificate)
if overrideAzurev6Tcbinfo {
azure_tcbinfo_override.OverrideAzureValidatorsForV6SEAMLoader(log, []atls.Validator{validator})
}
Expand All @@ -132,6 +139,7 @@ func runClient(cCtx *cli.Context) (err error) {
attestationTypeStr := cCtx.String("attestation-type")
expectedMeasurementsPath := cCtx.String("expected-measurements")
overrideAzurev6Tcbinfo := cCtx.Bool("override-azurev6-tcbinfo")
verifyAKCertificate := cCtx.Bool("verify-ak-certificate")

// Setup logging
log := common.SetupLogger(&common.LoggingOpts{
Expand All @@ -156,13 +164,13 @@ func runClient(cCtx *cli.Context) (err error) {
var validators []atls.Validator
switch attestationType {
case proxy.AttestationAzureTDX:
validators = append(validators, createAzureTDXValidator(log, overrideAzurev6Tcbinfo))
validators = append(validators, createAzureTDXValidator(log, overrideAzurev6Tcbinfo, verifyAKCertificate))
case proxy.AttestationDCAPTDX:
validators = append(validators, createDCAPTDXValidator(log))
case proxy.AttestationAuto:
// In auto mode, add all validators to support any attestation type
log.Info("Auto mode: creating validators for all supported attestation types")
validators = append(validators, createAzureTDXValidator(log, overrideAzurev6Tcbinfo))
validators = append(validators, createAzureTDXValidator(log, overrideAzurev6Tcbinfo, verifyAKCertificate))
validators = append(validators, createDCAPTDXValidator(log))
default:
log.Error("unsupported attestation type, see --help for available options")
Expand Down
9 changes: 8 additions & 1 deletion cmd/proxy-client/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,12 @@ var flags []cli.Flag = []cli.Flag{
EnvVars: []string{"DEV_DUMMY_DCAP"},
Usage: "URL of the remote dummy DCAP service. Only with --client-attestation-type dummy.",
},
&cli.BoolFlag{
Name: "verify-ak-certificate",
EnvVars: []string{"VERIFY_AK_CERTIFICATE"},
Value: false,
Usage: "verify Azure TDX vTPM attestation key certificate chain",
},
}

func main() {
Expand Down Expand Up @@ -103,6 +109,7 @@ func runClient(cCtx *cli.Context) error {
devDummyDcapURL := cCtx.String("dev-dummy-dcap")

verifyTLS := cCtx.Bool("verify-tls")
verifyAKCertificate := cCtx.Bool("verify-ak-certificate")

log := common.SetupLogger(&common.LoggingOpts{
Debug: logDebug,
Expand Down Expand Up @@ -145,7 +152,7 @@ func runClient(cCtx *cli.Context) error {
}
}

validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements)
validators, err := proxy.CreateAttestationValidatorsFromFile(log, serverMeasurements, verifyAKCertificate)
if err != nil {
log.Error("could not create attestation validators from file", "err", err)
return err
Expand Down
9 changes: 8 additions & 1 deletion cmd/proxy-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,12 @@ var flags []cli.Flag = []cli.Flag{
EnvVars: []string{"DEV_DUMMY_DCAP"},
Usage: "URL of the remote dummy DCAP service. Only with --server-attestation-type dummy.",
},
&cli.BoolFlag{
Name: "verify-ak-certificate",
EnvVars: []string{"VERIFY_AK_CERTIFICATE"},
Value: false,
Usage: "verify Azure TDX vTPM attestation key certificate chain",
},
}

var log *slog.Logger
Expand Down Expand Up @@ -120,6 +126,7 @@ func runServer(cCtx *cli.Context) error {
overrideAzurev6Tcbinfo := cCtx.Bool("override-azurev6-tcbinfo")
logJSON := cCtx.Bool("log-json")
logDebug := cCtx.Bool("log-debug")
verifyAKCertificate := cCtx.Bool("verify-ak-certificate")
tdx.SetLogDcapQuote(cCtx.Bool("log-dcap-quote"))

serverAttestationTypeFlag := cCtx.String("server-attestation-type")
Expand Down Expand Up @@ -148,7 +155,7 @@ func runServer(cCtx *cli.Context) error {
return errors.New("not all of --tls-certificate-path and --tls-private-key-path specified")
}

validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements)
validators, err := proxy.CreateAttestationValidatorsFromFile(log, clientMeasurements, verifyAKCertificate)
if err != nil {
log.Error("could not create attestation validators from file", "err", err)
return err
Expand Down
97 changes: 97 additions & 0 deletions internal/attestation/azure/tdx/issuer.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package tdx
import (
"bytes"
"context"
"crypto/x509"
"encoding/base64"
"encoding/binary"
"encoding/json"
Expand All @@ -27,6 +28,7 @@ import (
const (
imdsURL = "http://169.254.169.254/acc/tdquote"
indexHCLReport = 0x1400001
tpmAkCertIdx = 0x1C101D0
hclDataOffset = 1216
hclReportTypeOffset = 8
hclReportTypeOffsetStart = hclDataOffset + hclReportTypeOffset
Expand All @@ -53,6 +55,7 @@ type Issuer struct {
*vtpm.Issuer

quoteGetter quoteGetter
log attestation.Logger
}

// NewIssuer initializes a new Azure Issuer.
Expand All @@ -61,6 +64,7 @@ func NewIssuer(log attestation.Logger) *Issuer {
quoteGetter: imdsQuoteGetter{
client: &http.Client{Transport: &http.Transport{Proxy: nil}},
},
log: log,
}

i.Issuer = vtpm.NewIssuer(
Expand Down Expand Up @@ -91,9 +95,17 @@ func (i *Issuer) getInstanceInfo(ctx context.Context, tpm io.ReadWriteCloser, _
return nil, fmt.Errorf("getting quote: %w", err)
}

// Read and extract the vTPM AK certificate. If this fails, we log a warning and continue without it
akCert, err := i.readAKCertificateFromTPM(tpm)
if err != nil {
i.log.Warn(fmt.Sprintf("Failed to read AK certificate: %v", err))
akCert = nil
}

Comment on lines +99 to +104
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont return an error here on purpose to allow backward compatibility as well as leave it up to the validator to handle this.
If the validator wants to only check the attestation quote, then it would still pass. If the validator wants to also verify certificates, then this will be checked accordingly

instanceInfo := InstanceInfo{
AttestationReport: quote,
RuntimeData: runtimeData,
AkCert: akCert, // Use the clean certificate
}
instanceInfoJSON, err := json.Marshal(instanceInfo)
if err != nil {
Expand All @@ -102,6 +114,91 @@ func (i *Issuer) getInstanceInfo(ctx context.Context, tpm io.ReadWriteCloser, _
return instanceInfoJSON, nil
}

// readAKCertificateFromTPM reads and extracts the attestation key certificate from TPM.
// Returns the clean DER-encoded certificate or an error if reading/extraction fails.
func (i *Issuer) readAKCertificateFromTPM(tpm io.ReadWriteCloser) ([]byte, error) {
certDERRaw, err := tpm2.NVReadEx(tpm, tpmAkCertIdx, tpm2.HandleOwner, "", 0)
if err != nil {
return nil, fmt.Errorf("reading attestation key certificate from TPM: %w", err)
}

i.log.Debug(fmt.Sprintf("Read %d bytes from TPM AK cert index", len(certDERRaw)))

// The TPM NV index contains trailing data. We need to extract just the certificate.
// X.509 DER certificates start with 0x30 (SEQUENCE) followed by length encoding
cleanCertDER, err := extractDERCertificate(certDERRaw)
if err != nil {
return nil, fmt.Errorf("extracting certificate from TPM data: %w", err)
}

i.log.Debug(fmt.Sprintf("Extracted %d bytes certificate from %d bytes TPM data", len(cleanCertDER), len(certDERRaw)))

// Verify we can parse the extracted certificate
_, err = x509.ParseCertificate(cleanCertDER)
if err != nil {
return nil, fmt.Errorf("parsing extracted attestation key certificate: %w", err)
}

return cleanCertDER, nil
}

// extractDERCertificate extracts a clean X.509 DER certificate from raw TPM data.
// The TPM NV index may contain trailing data, so this function parses the DER
// structure to extract exactly the certificate bytes.
//
// X.509 DER certificates use ASN.1 encoding and start with:
// - Tag: 0x30 (SEQUENCE)
// - Length: encoded in one of three forms (short, long-1byte, long-2byte)
// - Content: the certificate data
func extractDERCertificate(certDERRaw []byte) ([]byte, error) {
if len(certDERRaw) < 4 {
return nil, fmt.Errorf("certificate data too short: %d bytes", len(certDERRaw))
}

// Verify it starts with DER SEQUENCE tag (0x30)
if certDERRaw[0] != 0x30 {
return nil, fmt.Errorf("invalid certificate format: does not start with DER SEQUENCE tag (0x30), got 0x%02x", certDERRaw[0])
}

// Parse the DER length encoding to determine certificate size
var certLen int
lengthByte := certDERRaw[1]

if lengthByte < 0x80 {
// Short form: length fits in 7 bits (0-127 bytes)
// Format: 0x30 <length> <data...>
certLen = int(lengthByte) + 2 // +2 for tag and length bytes
} else if lengthByte == 0x81 {
// Long form with 1 length byte (128-255 bytes)
// Format: 0x30 0x81 <length> <data...>
if len(certDERRaw) < 3 {
return nil, fmt.Errorf("truncated DER encoding: expected length byte")
}
certLen = int(certDERRaw[2]) + 3 // +3 for tag, 0x81, and length byte
} else if lengthByte == 0x82 {
// Long form with 2 length bytes (256-65535 bytes)
// Format: 0x30 0x82 <high-byte> <low-byte> <data...>
if len(certDERRaw) < 4 {
return nil, fmt.Errorf("truncated DER encoding: expected 2 length bytes")
}
certLen = (int(certDERRaw[2]) << 8) | int(certDERRaw[3])
certLen += 4 // +4 for tag, 0x82, and two length bytes
} else {
return nil, fmt.Errorf("unsupported DER length encoding: 0x%02x", lengthByte)
}

// Validate the calculated length
if certLen <= 0 {
return nil, fmt.Errorf("invalid certificate length: %d", certLen)
}
if certLen > len(certDERRaw) {
return nil, fmt.Errorf("invalid certificate length: %d exceeds available data (%d bytes)", certLen, len(certDERRaw))
}

// Extract the exact certificate bytes
return certDERRaw[:certLen], nil
}

func parseHCLReport(report []byte) (hwReport, runtimeData []byte, err error) {
// First, ensure the extracted report is actually for TDX
if len(report) < hclReportTypeOffsetStart+4 {
Expand Down
1 change: 1 addition & 0 deletions internal/attestation/azure/tdx/tdx.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ package tdx
type InstanceInfo struct {
AttestationReport []byte
RuntimeData []byte
AkCert []byte `json:"akCert,omitempty"`
}
Loading