Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions server/datastore/mysql/host_identity_scep.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@ package mysql

import (
"context"
"crypto/sha256"
"crypto/x509"
"database/sql"
"encoding/hex"
"encoding/pem"
"errors"
"fmt"

Expand Down Expand Up @@ -60,3 +64,54 @@ func (ds *Datastore) GetHostIdentityCertByName(ctx context.Context, name string)
}
return &hostIdentityCert, nil
}

// GetMDMSCEPCertBySerial looks up an MDM SCEP certificate by serial number
// and returns the device UUID it's associated with. This is used for iOS/iPadOS
// certificate-based authentication on the My Device page.
//
// This query uses the nano_cert_auth_associations table which maps device IDs to
// certificate hashes. The serial number lookup in scep_certificates provides
// the raw certificate data, but we need the nanomdm association to get the device UUID.
func (ds *Datastore) GetMDMSCEPCertBySerial(ctx context.Context, serialNumber uint64) (deviceUUID string, err error) {
// First get the certificate by serial
var certPEM string
err = sqlx.GetContext(ctx, ds.reader(ctx), &certPEM, `
SELECT certificate_pem
FROM scep_certificates
WHERE serial = ?
AND not_valid_after > NOW()
AND revoked = 0`, serialNumber)
switch {
case errors.Is(err, sql.ErrNoRows):
return "", notFound("MDM SCEP certificate")
case err != nil:
return "", err
}

// Calculate the SHA256 hash of the certificate the same way nanomdm does
// (see server/mdm/nanomdm/service/certauth/certauth.go HashCert function)
// The hash is calculated from cert.Raw (DER-encoded bytes), not the PEM string
block, _ := pem.Decode([]byte(certPEM))
if block == nil {
return "", errors.New("failed to decode PEM certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return "", fmt.Errorf("failed to parse certificate: %w", err)
}
hashed := sha256.Sum256(cert.Raw)
hash := hex.EncodeToString(hashed[:])

// Look up the device UUID by certificate hash
err = sqlx.GetContext(ctx, ds.reader(ctx), &deviceUUID, `
SELECT id
FROM nano_cert_auth_associations
WHERE sha256 = ?`, hash)
switch {
case errors.Is(err, sql.ErrNoRows):
return "", notFound("MDM certificate association")
case err != nil:
return "", err
}
return deviceUUID, nil
}
3 changes: 3 additions & 0 deletions server/fleet/datastore.go
Original file line number Diff line number Diff line change
Expand Up @@ -2429,6 +2429,9 @@ type Datastore interface {
GetHostIdentityCertByName(ctx context.Context, name string) (*types.HostIdentityCertificate, error)
// UpdateHostIdentityCertHostIDBySerial updates the host ID associated with a certificate using its serial number.
UpdateHostIdentityCertHostIDBySerial(ctx context.Context, serialNumber uint64, hostID uint) error
// GetMDMSCEPCertBySerial looks up an MDM SCEP certificate by serial number and returns the device UUID.
// This is used for iOS/iPadOS certificate-based authentication.
GetMDMSCEPCertBySerial(ctx context.Context, serialNumber uint64) (deviceUUID string, err error)

// /////////////////////////////////////////////////////////////////////////////
// Certificate Authorities
Expand Down
12 changes: 12 additions & 0 deletions server/mock/datastore_mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -1569,6 +1569,8 @@ type GetHostIdentityCertByNameFunc func(ctx context.Context, name string) (*type

type UpdateHostIdentityCertHostIDBySerialFunc func(ctx context.Context, serialNumber uint64, hostID uint) error

type GetMDMSCEPCertBySerialFunc func(ctx context.Context, serialNumber uint64) (string, error)

type NewCertificateAuthorityFunc func(ctx context.Context, ca *fleet.CertificateAuthority) (*fleet.CertificateAuthority, error)

type GetCertificateAuthorityByIDFunc func(ctx context.Context, id uint, includeSecrets bool) (*fleet.CertificateAuthority, error)
Expand Down Expand Up @@ -3909,6 +3911,9 @@ type DataStore struct {
UpdateHostIdentityCertHostIDBySerialFunc UpdateHostIdentityCertHostIDBySerialFunc
UpdateHostIdentityCertHostIDBySerialFuncInvoked bool

GetMDMSCEPCertBySerialFunc GetMDMSCEPCertBySerialFunc
GetMDMSCEPCertBySerialFuncInvoked bool

NewCertificateAuthorityFunc NewCertificateAuthorityFunc
NewCertificateAuthorityFuncInvoked bool

Expand Down Expand Up @@ -9353,6 +9358,13 @@ func (s *DataStore) UpdateHostIdentityCertHostIDBySerial(ctx context.Context, se
return s.UpdateHostIdentityCertHostIDBySerialFunc(ctx, serialNumber, hostID)
}

func (s *DataStore) GetMDMSCEPCertBySerial(ctx context.Context, serialNumber uint64) (string, error) {
s.mu.Lock()
s.GetMDMSCEPCertBySerialFuncInvoked = true
s.mu.Unlock()
return s.GetMDMSCEPCertBySerialFunc(ctx, serialNumber)
}

func (s *DataStore) NewCertificateAuthority(ctx context.Context, ca *fleet.CertificateAuthority) (*fleet.CertificateAuthority, error) {
s.mu.Lock()
s.NewCertificateAuthorityFuncInvoked = true
Expand Down
8 changes: 4 additions & 4 deletions server/service/devices.go
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ func (svc *Service) AuthenticateDeviceByCertificate(ctx context.Context, certSer
return nil, false, ctxerr.Wrap(ctx, fleet.NewAuthRequiredError("authentication error: missing host UUID"))
}

// Look up the certificate by serial number
cert, err := svc.ds.GetHostIdentityCertBySerialNumber(ctx, certSerial)
// Look up the MDM SCEP certificate by serial number to get the device UUID
certDeviceUUID, err := svc.ds.GetMDMSCEPCertBySerial(ctx, certSerial)
switch {
case err == nil:
// OK
Expand All @@ -284,8 +284,8 @@ func (svc *Service) AuthenticateDeviceByCertificate(ctx context.Context, certSer
return nil, false, ctxerr.Wrap(ctx, err, "lookup certificate by serial")
}

// Verify certificate matches the host UUID (CN should match UUID)
if cert.CommonName != hostUUID {
// Verify certificate's device UUID matches the requested host UUID
if certDeviceUUID != hostUUID {
return nil, false, ctxerr.Wrap(ctx, fleet.NewAuthRequiredError("authentication error: certificate does not match host"))
}

Expand Down
68 changes: 20 additions & 48 deletions server/service/devices_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"testing"
"time"

"github.com/fleetdm/fleet/v4/ee/server/service/hostidentity/types"
"github.com/fleetdm/fleet/v4/pkg/optjson"
"github.com/fleetdm/fleet/v4/server/fleet"
"github.com/fleetdm/fleet/v4/server/mock"
Expand Down Expand Up @@ -625,13 +624,9 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
return &fleet.AppConfig{}, nil
}

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
require.Equal(t, certSerial, serialNumber)
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: hostUUID,
HostID: ptr.Uint(1),
}, nil
return hostUUID, nil
}

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
Expand Down Expand Up @@ -662,12 +657,9 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
return &fleet.AppConfig{}, nil
}

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: hostUUID,
HostID: ptr.Uint(2),
}, nil
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
require.Equal(t, certSerial, serialNumber)
return hostUUID, nil
}

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
Expand Down Expand Up @@ -715,8 +707,8 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{SkipCreateTestUsers: true})

certSerial := uint64(99999)
ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return nil, &mock.Error{Message: "certificate not found"}
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return "", &mock.Error{Message: "certificate not found"}
}

host, debug, err := svc.AuthenticateDeviceByCertificate(ctx, certSerial, "test-uuid")
Expand All @@ -727,19 +719,15 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
require.ErrorAs(t, err, &authErr)
})

t.Run("error - certificate CN does not match UUID", func(t *testing.T) {
t.Run("error - device UUID mismatch", func(t *testing.T) {
ds := new(mock.Store)
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{SkipCreateTestUsers: true})

certSerial := uint64(12345)
hostUUID := "test-uuid"

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: "different-uuid",
HostID: ptr.Uint(1),
}, nil
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return "different-uuid", nil
}

host, debug, err := svc.AuthenticateDeviceByCertificate(ctx, certSerial, hostUUID)
Expand All @@ -757,12 +745,8 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
certSerial := uint64(12345)
hostUUID := "nonexistent-uuid"

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: hostUUID,
HostID: ptr.Uint(1),
}, nil
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return hostUUID, nil
}

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
Expand All @@ -784,12 +768,8 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
certSerial := uint64(12345)
hostUUID := "test-uuid-macos"

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: hostUUID,
HostID: ptr.Uint(1),
}, nil
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return hostUUID, nil
}

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
Expand All @@ -815,12 +795,8 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
certSerial := uint64(12345)
hostUUID := "test-uuid-windows"

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: hostUUID,
HostID: ptr.Uint(1),
}, nil
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return hostUUID, nil
}

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
Expand All @@ -844,8 +820,8 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
svc, ctx := newTestService(t, ds, nil, nil, &TestServerOpts{SkipCreateTestUsers: true})

certSerial := uint64(12345)
ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return nil, errors.New("database connection error")
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return "", errors.New("database connection error")
}

host, debug, err := svc.AuthenticateDeviceByCertificate(ctx, certSerial, "test-uuid")
Expand All @@ -862,12 +838,8 @@ func TestAuthenticateDeviceByCertificate(t *testing.T) {
certSerial := uint64(12345)
hostUUID := "test-uuid"

ds.GetHostIdentityCertBySerialNumberFunc = func(ctx context.Context, serialNumber uint64) (*types.HostIdentityCertificate, error) {
return &types.HostIdentityCertificate{
SerialNumber: certSerial,
CommonName: hostUUID,
HostID: ptr.Uint(1),
}, nil
ds.GetMDMSCEPCertBySerialFunc = func(ctx context.Context, serialNumber uint64) (string, error) {
return hostUUID, nil
}

ds.HostByIdentifierFunc = func(ctx context.Context, identifier string) (*fleet.Host, error) {
Expand Down
Loading
Loading