Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion server/datastore/mysql/microsoft_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ func (ds *Datastore) MDMWindowsGetEnrolledDeviceWithDeviceID(ctx context.Context
enroll_proto_version,
enroll_client_version,
not_in_oobe,
awaiting_configuration,
awaiting_configuration_at,
credentials_hash,
credentials_acknowledged,
created_at,
Expand Down Expand Up @@ -99,6 +101,8 @@ func (ds *Datastore) MDMWindowsGetEnrolledDeviceWithHostUUID(ctx context.Context
enroll_proto_version,
enroll_client_version,
not_in_oobe,
awaiting_configuration,
awaiting_configuration_at,
credentials_hash,
credentials_acknowledged,
created_at,
Expand Down Expand Up @@ -132,11 +136,13 @@ func (ds *Datastore) MDMWindowsInsertEnrolledDevice(ctx context.Context, device
enroll_proto_version,
enroll_client_version,
not_in_oobe,
awaiting_configuration,
awaiting_configuration_at,
host_uuid,
credentials_hash,
credentials_acknowledged)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
ON DUPLICATE KEY UPDATE
mdm_device_id = VALUES(mdm_device_id),
device_state = VALUES(device_state),
Expand All @@ -147,6 +153,8 @@ func (ds *Datastore) MDMWindowsInsertEnrolledDevice(ctx context.Context, device
enroll_proto_version = VALUES(enroll_proto_version),
enroll_client_version = VALUES(enroll_client_version),
not_in_oobe = VALUES(not_in_oobe),
awaiting_configuration = VALUES(awaiting_configuration),
awaiting_configuration_at = VALUES(awaiting_configuration_at),
host_uuid = VALUES(host_uuid),
credentials_hash = VALUES(credentials_hash),
credentials_acknowledged = VALUES(credentials_acknowledged)
Expand All @@ -164,6 +172,8 @@ func (ds *Datastore) MDMWindowsInsertEnrolledDevice(ctx context.Context, device
device.MDMEnrollProtoVersion,
device.MDMEnrollClientVersion,
device.MDMNotInOOBE,
device.AwaitingConfiguration,
device.AwaitingConfigurationAt,
device.HostUUID,
device.CredentialsHash,
device.CredentialsAcknowledged)
Expand Down
25 changes: 25 additions & 0 deletions server/datastore/mysql/microsoft_mdm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ func testMDMWindowsEnrolledDevice(t *testing.T, ds *Datastore) {
require.NotZero(t, gotEnrolledDevice.CreatedAt)
require.Equal(t, enrolledDevice.MDMDeviceID, gotEnrolledDevice.MDMDeviceID)
require.Equal(t, enrolledDevice.MDMHardwareID, gotEnrolledDevice.MDMHardwareID)
require.Equal(t, fleet.WindowsMDMAwaitingConfigurationNone, gotEnrolledDevice.AwaitingConfiguration)
require.Nil(t, gotEnrolledDevice.AwaitingConfigurationAt)

err = ds.MDMWindowsDeleteEnrolledDeviceOnReenrollment(ctx, enrolledDevice.MDMHardwareID)
require.NoError(t, err)
Expand Down Expand Up @@ -127,6 +129,29 @@ func testMDMWindowsEnrolledDevice(t *testing.T, ds *Datastore) {

err = ds.MDMWindowsDeleteEnrolledDeviceOnReenrollment(ctx, enrolledDevice.MDMHardwareID)
require.ErrorAs(t, err, &nfe)

// Test that awaiting configuration is persisted and updated on upsert.
now := time.Now().UTC()
enrolledDevice.AwaitingConfiguration = fleet.WindowsMDMAwaitingConfigurationPending
enrolledDevice.AwaitingConfigurationAt = &now
err = ds.MDMWindowsInsertEnrolledDevice(ctx, enrolledDevice)
require.NoError(t, err)

gotEnrolledDevice, err = ds.MDMWindowsGetEnrolledDeviceWithDeviceID(ctx, enrolledDevice.MDMDeviceID)
require.NoError(t, err)
require.Equal(t, fleet.WindowsMDMAwaitingConfigurationPending, gotEnrolledDevice.AwaitingConfiguration)
require.NotNil(t, gotEnrolledDevice.AwaitingConfigurationAt)

// Re-enroll clears awaiting configuration via upsert.
enrolledDevice.AwaitingConfiguration = fleet.WindowsMDMAwaitingConfigurationNone
enrolledDevice.AwaitingConfigurationAt = nil
err = ds.MDMWindowsInsertEnrolledDevice(ctx, enrolledDevice)
require.NoError(t, err)

gotEnrolledDevice, err = ds.MDMWindowsGetEnrolledDeviceWithDeviceID(ctx, enrolledDevice.MDMDeviceID)
require.NoError(t, err)
require.Equal(t, fleet.WindowsMDMAwaitingConfigurationNone, gotEnrolledDevice.AwaitingConfiguration)
require.Nil(t, gotEnrolledDevice.AwaitingConfigurationAt)
}

func testMDMWindowsDiskEncryption(t *testing.T, ds *Datastore) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package tables

import (
"database/sql"
"fmt"
)

func init() {
MigrationClient.AddMigration(Up_20260407214038, Down_20260407214038)
}

func Up_20260407214038(tx *sql.Tx) error {
if columnExists(tx, "mdm_windows_enrollments", "awaiting_configuration") {
return nil
}
_, err := tx.Exec(`
ALTER TABLE mdm_windows_enrollments
ADD COLUMN awaiting_configuration TINYINT(1) NOT NULL DEFAULT 0,
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For whatever it's worth my initial thinking for this was this was almost a state based thing. I made it a tinyint to not make the windows enrollments rows too big but in Magnus's POC he did the same where essentially it was 0=not waiting, 1=initial waiting state, 2=later state which is basically once the full initialization/enrollment has completed and things have actually started

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks. Updated to 3 states

ADD COLUMN awaiting_configuration_at DATETIME(6) DEFAULT NULL
`)
if err != nil {
return fmt.Errorf("failed to add awaiting_configuration columns to mdm_windows_enrollments: %w", err)
}
return nil
}

func Down_20260407214038(tx *sql.Tx) error {
return nil
}
6 changes: 4 additions & 2 deletions server/datastore/mysql/schema.sql

Large diffs are not rendered by default.

58 changes: 42 additions & 16 deletions server/fleet/microsoft_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -820,23 +820,49 @@ func (msg WapProvisioningDoc) GetEncodedB64Representation() (string, error) {
/// MDMWindowsEnrolledDevice type
/// Contains the information of the enrolled Windows host

// WindowsMDMEnrollType represents how a Windows device enrolled in MDM.
type WindowsMDMEnrollType int

const (
// WindowsMDMEnrollTypeProgrammatic is enrollment via fleetd/orbit using an orbit node key.
WindowsMDMEnrollTypeProgrammatic WindowsMDMEnrollType = iota
// WindowsMDMEnrollTypeAutomatic is enrollment via Azure JWT or WSTEP STS auth token (Autopilot, Entra join,
// Settings app).
WindowsMDMEnrollTypeAutomatic
)

// WindowsMDMAwaitingConfiguration represents the state of a Windows device's setup experience.
type WindowsMDMAwaitingConfiguration uint

const (
// WindowsMDMAwaitingConfigurationNone means the device is not awaiting configuration (default, or setup complete/failed).
WindowsMDMAwaitingConfigurationNone WindowsMDMAwaitingConfiguration = 0
// WindowsMDMAwaitingConfigurationPending means the device enrolled via autopilot in OOBE and is waiting for orbit
// to register and setup experience items to be enqueued.
WindowsMDMAwaitingConfigurationPending WindowsMDMAwaitingConfiguration = 1
// WindowsMDMAwaitingConfigurationActive means ESP commands have been enqueued and setup progress is being tracked.
WindowsMDMAwaitingConfigurationActive WindowsMDMAwaitingConfiguration = 2
)

type MDMWindowsEnrolledDevice struct {
ID uint `db:"id"`
HostUUID string `db:"host_uuid"`
MDMDeviceID string `db:"mdm_device_id"`
MDMHardwareID string `db:"mdm_hardware_id"`
MDMDeviceState string `db:"device_state"`
MDMDeviceType string `db:"device_type"`
MDMDeviceName string `db:"device_name"`
MDMEnrollType string `db:"enroll_type"`
MDMEnrollUserID string `db:"enroll_user_id"`
MDMEnrollProtoVersion string `db:"enroll_proto_version"`
MDMEnrollClientVersion string `db:"enroll_client_version"`
MDMNotInOOBE bool `db:"not_in_oobe"`
CredentialsHash *[]byte `db:"credentials_hash"`
CredentialsAcknowledged bool `db:"credentials_acknowledged"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
ID uint `db:"id"`
HostUUID string `db:"host_uuid"`
MDMDeviceID string `db:"mdm_device_id"`
MDMHardwareID string `db:"mdm_hardware_id"`
MDMDeviceState string `db:"device_state"`
MDMDeviceType string `db:"device_type"`
MDMDeviceName string `db:"device_name"`
MDMEnrollType string `db:"enroll_type"`
MDMEnrollUserID string `db:"enroll_user_id"`
MDMEnrollProtoVersion string `db:"enroll_proto_version"`
MDMEnrollClientVersion string `db:"enroll_client_version"`
MDMNotInOOBE bool `db:"not_in_oobe"`
AwaitingConfiguration WindowsMDMAwaitingConfiguration `db:"awaiting_configuration"`
AwaitingConfigurationAt *time.Time `db:"awaiting_configuration_at"`
CredentialsHash *[]byte `db:"credentials_hash"`
CredentialsAcknowledged bool `db:"credentials_acknowledged"`
CreatedAt time.Time `db:"created_at"`
UpdatedAt time.Time `db:"updated_at"`
}

func (e MDMWindowsEnrolledDevice) AuthzType() string {
Expand Down
64 changes: 41 additions & 23 deletions server/service/microsoft_mdm.go
Original file line number Diff line number Diff line change
Expand Up @@ -959,14 +959,14 @@ func mdmMicrosoftTOSEndpoint(ctx context.Context, request interface{}, svc fleet
// authBinarySecurityToken checks if the provided token is valid. For programmatic enrollment, it
// returns the orbit node key and host uuid. For automatic enrollment, it returns only the UPN (the
// host uuid will be an empty string).
func (svc *Service) authBinarySecurityToken(ctx context.Context, authToken *fleet.HeaderBinarySecurityToken) (claim string, hostUUID string, err error) {
func (svc *Service) authBinarySecurityToken(ctx context.Context, authToken *fleet.HeaderBinarySecurityToken) (claim string, hostUUID string, enrollType fleet.WindowsMDMEnrollType, err error) {
if authToken == nil {
return "", "", errors.New("authToken is empty")
return "", "", 0, errors.New("authToken is empty")
}

err = authToken.IsValidToken()
if err != nil {
return "", "", errors.New("authToken is not valid")
return "", "", 0, errors.New("authToken is not valid")
}

// Tokens that were generated by enrollment client
Expand All @@ -975,69 +975,72 @@ func (svc *Service) authBinarySecurityToken(ctx context.Context, authToken *flee
// Getting the Binary Security Token Payload
binSecToken, err := NewBinarySecurityTokenPayload(authToken.Content)
if err != nil {
return "", "", fmt.Errorf("token creation error %v", err)
return "", "", 0, fmt.Errorf("token creation error %v", err)
}

// Validating the Binary Security Token Payload
err = binSecToken.IsValidToken()
if err != nil {
return "", "", fmt.Errorf("invalid token data %v", err)
return "", "", 0, fmt.Errorf("invalid token data %v", err)
}

// Validating the Binary Security Token Type used on Programmatic Enrollments
if binSecToken.Type == mdm_types.WindowsMDMProgrammaticEnrollmentType {
host, err := svc.ds.LoadHostByOrbitNodeKey(ctx, binSecToken.Payload.OrbitNodeKey)
if err != nil {
return "", "", fmt.Errorf("host data cannot be found %v", err)
return "", "", 0, fmt.Errorf("host data cannot be found %v", err)
}
if host == nil {
return "", "", 0, errors.New("host not found for orbit node key")
}

mdmInfo, err := svc.ds.GetHostMDM(ctx, host.ID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return "", "", errors.New("unable to retrieve host mdm info")
return "", "", 0, errors.New("unable to retrieve host mdm info")
}

// This ensures that only hosts that are eligible for Windows enrollment can be enrolled
if !isEligibleForWindowsMDMEnrollment(host, mdmInfo) {
return "", "", errors.New("host is not elegible for Windows MDM enrollment")
return "", "", 0, errors.New("host is not elegible for Windows MDM enrollment")
}

// No errors, token is authorized
return binSecToken.Payload.OrbitNodeKey, host.UUID, nil
return binSecToken.Payload.OrbitNodeKey, host.UUID, fleet.WindowsMDMEnrollTypeProgrammatic, nil
}

// Validating the Binary Security Token Type used on Automatic Enrollments (returned by STS Auth Endpoint)
if binSecToken.Type == mdm_types.WindowsMDMAutomaticEnrollmentType {
upnToken, err := svc.wstepCertManager.GetSTSAuthTokenUPNClaim(binSecToken.Payload.AuthToken)
if err != nil {
return "", "", ctxerr.Wrap(ctx, err, "issue retrieving UPN from Auth token")
return "", "", 0, ctxerr.Wrap(ctx, err, "issue retrieving UPN from Auth token")
}

// No errors, token is authorized
return upnToken, "", nil
return upnToken, "", fleet.WindowsMDMEnrollTypeAutomatic, nil
}
}

// Validating the Binary Security Token Type used on Automatic Enrollments
if authToken.IsAzureJWTToken() {
appConfig, err := svc.ds.AppConfig(ctx)
if err != nil {
return "", "", ctxerr.Wrap(ctx, err, "retrieving app config for auth token validation")
return "", "", 0, ctxerr.Wrap(ctx, err, "retrieving app config for auth token validation")
}

entraTenantIDs := appConfig.MDM.WindowsEntraTenantIDs.Value
if len(entraTenantIDs) == 0 {
return "", "", ctxerr.New(ctx, "no entra tenant IDs configured for automatic enrollment")
return "", "", 0, ctxerr.New(ctx, "no entra tenant IDs configured for automatic enrollment")
}
expectedURL := appConfig.ServerSettings.ServerURL
expectedURLParsed, err := url.Parse(expectedURL)
if err != nil {
return "", "", ctxerr.Wrap(ctx, err, "parsing server URL for auth token validation")
return "", "", 0, ctxerr.Wrap(ctx, err, "parsing server URL for auth token validation")
}

// Validate the JWT Auth token by retreving its claims
tokenData, err := microsoft_mdm.GetAzureAuthTokenClaims(ctx, authToken.Content)
if err != nil {
return "", "", fmt.Errorf("binary security token claim failed: %v", err)
return "", "", 0, fmt.Errorf("binary security token claim failed: %v", err)
}

hasExpectedAudience := false
Expand All @@ -1058,20 +1061,20 @@ func (svc *Service) authBinarySecurityToken(ctx context.Context, authToken *flee
"expected_host", expectedURLParsed.Host,
"token_audiences", strings.Join(tokenData.Audience, ","),
)
return "", "", ctxerr.Errorf(ctx, "token audience is not authorized")
return "", "", 0, ctxerr.Errorf(ctx, "token audience is not authorized")
}
if !slices.Contains(entraTenantIDs, tokenData.TenantID) {
svc.logger.ErrorContext(ctx, "unexpected token tenant in AzureAD Binary Security Token",
"token_tenant", tokenData.TenantID,
)
return "", "", ctxerr.New(ctx, "token tenant is not authorized")
return "", "", 0, ctxerr.New(ctx, "token tenant is not authorized")
}

// No errors, token is authorized
return tokenData.UPN, "", nil
return tokenData.UPN, "", fleet.WindowsMDMEnrollTypeAutomatic, nil
}

return "", "", ctxerr.New(ctx, "token is not authorized")
return "", "", 0, ctxerr.New(ctx, "token is not authorized")
}

// ProcessMDMMicrosoftDiscovery handles the Discovery message validation and response
Expand Down Expand Up @@ -1192,7 +1195,7 @@ func (svc *Service) GetMDMWindowsPolicyResponse(ctx context.Context, authToken *
}

// Validate the binary security token
_, _, err := svc.authBinarySecurityToken(ctx, authToken)
_, _, _, err := svc.authBinarySecurityToken(ctx, authToken)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "validate binary security token")
}
Expand All @@ -1218,7 +1221,7 @@ func (svc *Service) GetMDMWindowsEnrollResponse(ctx context.Context, secTokenMsg
}

// Auth the binary security token
userID, hostUUID, err := svc.authBinarySecurityToken(ctx, authToken)
userID, hostUUID, enrollType, err := svc.authBinarySecurityToken(ctx, authToken)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "validate binary security token")
}
Expand Down Expand Up @@ -1253,7 +1256,7 @@ func (svc *Service) GetMDMWindowsEnrollResponse(ctx context.Context, secTokenMsg
//
// This method also creates the relevant enrollment activity as it has
// access to the device information.
err = svc.storeWindowsMDMEnrolledDevice(ctx, userID, hostUUID, secTokenMsg, credentialsHash)
err = svc.storeWindowsMDMEnrolledDevice(ctx, userID, hostUUID, enrollType, secTokenMsg, credentialsHash)
if err != nil {
return nil, ctxerr.Wrap(ctx, err, "enrolled device information cannot be stored")
}
Expand Down Expand Up @@ -2040,7 +2043,7 @@ func (svc *Service) getDeviceProvisioningInformation(ctx context.Context, secTok
}

// storeWindowsMDMEnrolledDevice stores the device information to the list of MDM enrolled devices
func (svc *Service) storeWindowsMDMEnrolledDevice(ctx context.Context, userID string, hostUUID string, secTokenMsg *fleet.RequestSecurityToken, credentialsHash []byte) error {
func (svc *Service) storeWindowsMDMEnrolledDevice(ctx context.Context, userID string, hostUUID string, enrollType fleet.WindowsMDMEnrollType, secTokenMsg *fleet.RequestSecurityToken, credentialsHash []byte) error {
const (
error_tag = "windows MDM enrolled storage: "
)
Expand Down Expand Up @@ -2096,6 +2099,19 @@ func (svc *Service) storeWindowsMDMEnrolledDevice(ctx context.Context, userID st
reqNotInOOBE = true
}

// Determine if the device is awaiting configuration. Set to Pending when the enrollment is
// automatic (Autopilot via JWT/WSTEP, not orbit node key) AND the device is in OOBE
// (NotInOobe is false, since the field name is inverted). Later phases transition to Active
// (ESP commands enqueued) and back to None (setup complete/failed).
awaitingConfiguration := fleet.WindowsMDMAwaitingConfigurationNone
var awaitingConfigurationAt *time.Time
isInOOBE := !reqNotInOOBE
if enrollType == fleet.WindowsMDMEnrollTypeAutomatic && isInOOBE {
awaitingConfiguration = fleet.WindowsMDMAwaitingConfigurationPending
now := time.Now().UTC()
awaitingConfigurationAt = &now
}

// Getting the Windows Enrolled Device Information
enrolledDevice := &fleet.MDMWindowsEnrolledDevice{
MDMDeviceID: reqDeviceID,
Expand All @@ -2108,6 +2124,8 @@ func (svc *Service) storeWindowsMDMEnrolledDevice(ctx context.Context, userID st
MDMEnrollProtoVersion: reqEnrollVersion,
MDMEnrollClientVersion: reqAppVersion,
MDMNotInOOBE: reqNotInOOBE,
AwaitingConfiguration: awaitingConfiguration,
AwaitingConfigurationAt: awaitingConfigurationAt,
HostUUID: hostUUID,
CredentialsHash: &credentialsHash,
CredentialsAcknowledged: true,
Expand Down
Loading