diff --git a/orbit/changes/38405-bitlocker-encryption b/orbit/changes/38405-bitlocker-encryption new file mode 100644 index 000000000000..3a57cccbefaf --- /dev/null +++ b/orbit/changes/38405-bitlocker-encryption @@ -0,0 +1 @@ +* Fixed a COM deadlock on Windows that could cause orbit to become unresponsive during BitLocker encryption enforcement. BitLocker operations now run on a dedicated COM thread instead of sharing the global comshim singleton with other subsystems. diff --git a/orbit/cmd/orbit/orbit.go b/orbit/cmd/orbit/orbit.go index a0902c664817..b7dfde482511 100644 --- a/orbit/cmd/orbit/orbit.go +++ b/orbit/cmd/orbit/orbit.go @@ -33,6 +33,7 @@ import ( httpsigproxy "github.com/fleetdm/fleet/v4/ee/orbit/pkg/httpsigproxy" "github.com/fleetdm/fleet/v4/ee/orbit/pkg/securehw" "github.com/fleetdm/fleet/v4/orbit/pkg/augeas" + "github.com/fleetdm/fleet/v4/orbit/pkg/bitlocker" "github.com/fleetdm/fleet/v4/orbit/pkg/build" "github.com/fleetdm/fleet/v4/orbit/pkg/constant" "github.com/fleetdm/fleet/v4/orbit/pkg/execuser" @@ -1219,7 +1220,13 @@ func main() { case "windows": orbitClient.RegisterConfigReceiver(update.ApplyWindowsMDMEnrollmentFetcherMiddleware(windowsMDMEnrollmentCommandFrequency, orbitHostInfo.HardwareUUID, orbitClient)) - orbitClient.RegisterConfigReceiver(update.ApplyWindowsMDMBitlockerFetcherMiddleware(windowsMDMBitlockerCommandFrequency, orbitClient)) + comWorker, err := bitlocker.NewCOMWorker() + if err != nil { + return fmt.Errorf("create BitLocker COM worker: %w", err) + } + defer comWorker.Close() + orbitClient.RegisterConfigReceiver(update.ApplyWindowsMDMBitlockerFetcherMiddleware( + windowsMDMBitlockerCommandFrequency, orbitClient, comWorker)) case "linux": orbitClient.RegisterConfigReceiver(luks.New(orbitClient)) } diff --git a/orbit/pkg/bitlocker/bitlocker_management_notwindows.go b/orbit/pkg/bitlocker/bitlocker_management_notwindows.go index 4263ba270e81..4b3d76cca741 100644 --- a/orbit/pkg/bitlocker/bitlocker_management_notwindows.go +++ b/orbit/pkg/bitlocker/bitlocker_management_notwindows.go @@ -1,19 +1,3 @@ //go:build !windows package bitlocker - -func GetRecoveryKeys(targetVolume string) (map[string]string, error) { - return nil, nil -} - -func EncryptVolume(targetVolume string) (string, error) { - return "", nil -} - -func DecryptVolume(targetVolume string) error { - return nil -} - -func GetEncryptionStatus() ([]VolumeStatus, error) { - return nil, nil -} diff --git a/orbit/pkg/bitlocker/bitlocker_management_windows.go b/orbit/pkg/bitlocker/bitlocker_management_windows.go index 857d530a0f3f..64fc504077af 100644 --- a/orbit/pkg/bitlocker/bitlocker_management_windows.go +++ b/orbit/pkg/bitlocker/bitlocker_management_windows.go @@ -3,13 +3,11 @@ package bitlocker import ( - "errors" "fmt" "syscall" "github.com/go-ole/go-ole" "github.com/go-ole/go-ole/oleutil" - "github.com/scjalliance/comshim" ) // Encryption Methods @@ -117,7 +115,6 @@ func (v *Volume) bitlockerClose() { v.wmiSvc.Release() } - comshim.Done() } // encrypt encrypts the volume @@ -252,49 +249,22 @@ func (v *Volume) getBitlockerStatus() (*EncryptionStatus, error) { return encStatus, nil } -// getProtectorsKeys returns the recovery keys for the volume -// https://learn.microsoft.com/en-us/windows/win32/secprov/getkeyprotectornumericalpassword-win32-encryptablevolume -func (v *Volume) getProtectorsKeys() (map[string]string, error) { - keys, err := getKeyProtectors(v.handle) - if err != nil { - return nil, fmt.Errorf("getKeyProtectors: %w", err) - } - - recoveryKeys := make(map[string]string) - for _, k := range keys { - var recoveryKey ole.VARIANT - _ = ole.VariantInit(&recoveryKey) - recoveryKeyResultRaw, err := oleutil.CallMethod(v.handle, "GetKeyProtectorNumericalPassword", k, &recoveryKey) - if err != nil { - continue // No recovery key for this protector - } else if val, ok := recoveryKeyResultRaw.Value().(int32); val != 0 || !ok { - continue // No recovery key for this protector - } - recoveryKeys[k] = recoveryKey.ToString() - } - - return recoveryKeys, nil -} - ///////////////////////////////////////////////////// // Helper functions ///////////////////////////////////////////////////// // bitlockerConnect connects to an encryptable volume in order to manage it. func bitlockerConnect(driveLetter string) (Volume, error) { - comshim.Add(1) v := Volume{letter: driveLetter} unknown, err := oleutil.CreateObject("WbemScripting.SWbemLocator") if err != nil { - comshim.Done() return v, fmt.Errorf("createObject: %w", err) } defer unknown.Release() v.wmiIntf, err = unknown.QueryInterface(ole.IID_IDispatch) if err != nil { - comshim.Done() return v, fmt.Errorf("queryInterface: %w", err) } serviceRaw, err := oleutil.CallMethod(v.wmiIntf, "ConnectServer", nil, `\\.\ROOT\CIMV2\Security\MicrosoftVolumeEncryption`) @@ -328,32 +298,6 @@ func intToPercentage(num int32) string { return fmt.Sprintf("%.2f%%", percentage) } -// getKeyProtectors returns the key protectors for the volume -// https://learn.microsoft.com/en-us/windows/win32/secprov/getkeyprotectors-win32-encryptablevolume -func getKeyProtectors(item *ole.IDispatch) ([]string, error) { - kp := []string{} - var keyProtectorResults ole.VARIANT - _ = ole.VariantInit(&keyProtectorResults) - - keyIDResultRaw, err := oleutil.CallMethod(item, "GetKeyProtectors", 3, &keyProtectorResults) - if err != nil { - return nil, fmt.Errorf("unable to get Key Protectors while getting BitLocker info. %s", err.Error()) - } else if val, ok := keyIDResultRaw.Value().(int32); val != 0 || !ok { - return nil, fmt.Errorf("unable to get Key Protectors while getting BitLocker info. Return code %d", val) - } - - keyProtectorValues := keyProtectorResults.ToArray().ToValueArray() - for _, keyIDItemRaw := range keyProtectorValues { - keyIDItem, ok := keyIDItemRaw.(string) - if !ok { - return nil, errors.New("keyProtectorID wasn't a string") - } - kp = append(kp, keyIDItem) - } - - return kp, nil -} - // bitsToDrives converts a bit map to a list of drives func bitsToDrives(bitMap uint32) (drives []string) { availableDrives := []string{"A", "B", "C", "D", "E", "F", "G", "H", "I", "J", "K", "L", "M", "N", "O", "P", "Q", "R", "S", "T", "U", "V", "W", "X", "Y", "Z"} @@ -411,24 +355,7 @@ func getBitlockerStatus(targetVolume string) (*EncryptionStatus, error) { // Bitlocker Management interface implementation ///////////////////////////////////////////////////// -func GetRecoveryKeys(targetVolume string) (map[string]string, error) { - // Connect to the volume - vol, err := bitlockerConnect(targetVolume) - if err != nil { - return nil, fmt.Errorf("connecting to the volume: %w", err) - } - defer vol.bitlockerClose() - - // Get recovery keys - keys, err := vol.getProtectorsKeys() - if err != nil { - return nil, fmt.Errorf("retreving protection keys: %w", err) - } - - return keys, nil -} - -func EncryptVolume(targetVolume string) (string, error) { +func encryptVolumeOnCOMThread(targetVolume string) (string, error) { // Connect to the volume vol, err := bitlockerConnect(targetVolume) if err != nil { @@ -460,7 +387,7 @@ func EncryptVolume(targetVolume string) (string, error) { return recoveryKey, nil } -func DecryptVolume(targetVolume string) error { +func decryptVolumeOnCOMThread(targetVolume string) error { // Connect to the volume vol, err := bitlockerConnect(targetVolume) if err != nil { @@ -476,7 +403,7 @@ func DecryptVolume(targetVolume string) error { return nil } -func GetEncryptionStatus() ([]VolumeStatus, error) { +func getEncryptionStatusOnCOMThread() ([]VolumeStatus, error) { drives, err := getLogicalVolumes() if err != nil { return nil, fmt.Errorf("logical volumen enumeration %w", err) diff --git a/orbit/pkg/bitlocker/bitlocker_worker_notwindows.go b/orbit/pkg/bitlocker/bitlocker_worker_notwindows.go new file mode 100644 index 000000000000..0e7e2ced884f --- /dev/null +++ b/orbit/pkg/bitlocker/bitlocker_worker_notwindows.go @@ -0,0 +1,21 @@ +//go:build !windows + +package bitlocker + +// COMWorker is a no-op on non-Windows platforms. +type COMWorker struct{} + +// NewCOMWorker returns a no-op COMWorker on non-Windows platforms. +func NewCOMWorker() (*COMWorker, error) { return &COMWorker{}, nil } + +// Close is a no-op on non-Windows platforms. +func (w *COMWorker) Close() {} + +// GetEncryptionStatus is a no-op on non-Windows platforms. +func (w *COMWorker) GetEncryptionStatus() ([]VolumeStatus, error) { return nil, nil } + +// EncryptVolume is a no-op on non-Windows platforms. +func (w *COMWorker) EncryptVolume(string) (string, error) { return "", nil } + +// DecryptVolume is a no-op on non-Windows platforms. +func (w *COMWorker) DecryptVolume(string) error { return nil } diff --git a/orbit/pkg/bitlocker/bitlocker_worker_windows.go b/orbit/pkg/bitlocker/bitlocker_worker_windows.go new file mode 100644 index 000000000000..ece46b50a294 --- /dev/null +++ b/orbit/pkg/bitlocker/bitlocker_worker_windows.go @@ -0,0 +1,105 @@ +//go:build windows + +package bitlocker + +import ( + "errors" + "runtime" + "sync" + + "github.com/go-ole/go-ole" +) + +// ErrWorkerClosed is returned when an operation is attempted on a closed COMWorker. +var ErrWorkerClosed = errors.New("COM worker is closed") + +type comWorkItem struct { + fn func() (any, error) + result chan comWorkResult +} + +type comWorkResult struct { + val any + err error +} + +// COMWorker runs all BitLocker COM/WMI operations on a single dedicated OS +// thread. This prevents deadlocks with other COM callers (MDM Bridge, Windows +// Update) that share the global comshim singleton. +type COMWorker struct { + workCh chan comWorkItem + done chan struct{} + closeOnce sync.Once +} + +// NewCOMWorker creates a new COMWorker that initializes COM on a dedicated OS +// thread and processes all BitLocker operations serially on that thread. +func NewCOMWorker() (*COMWorker, error) { + w := &COMWorker{ + workCh: make(chan comWorkItem), + done: make(chan struct{}), + } + initErr := make(chan error, 1) + go w.loop(initErr) + if err := <-initErr; err != nil { + return nil, err + } + return w, nil +} + +func (w *COMWorker) loop(initErr chan<- error) { + runtime.LockOSThread() + defer runtime.UnlockOSThread() + + if err := ole.CoInitializeEx(0, ole.COINIT_MULTITHREADED); err != nil { + initErr <- err + close(w.done) + return + } + defer ole.CoUninitialize() + initErr <- nil + + for item := range w.workCh { + val, err := item.fn() + item.result <- comWorkResult{val, err} + } + close(w.done) +} + +// Close shuts down the COM worker goroutine and waits for it to finish. +func (w *COMWorker) Close() { + w.closeOnce.Do(func() { + close(w.workCh) + }) + <-w.done +} + +func (w *COMWorker) exec(fn func() (any, error)) comWorkResult { + ch := make(chan comWorkResult, 1) + select { + case w.workCh <- comWorkItem{fn: fn, result: ch}: + return <-ch + case <-w.done: + return comWorkResult{err: ErrWorkerClosed} + } +} + +// GetEncryptionStatus returns the BitLocker encryption status for all logical volumes. +func (w *COMWorker) GetEncryptionStatus() ([]VolumeStatus, error) { + r := w.exec(func() (any, error) { return getEncryptionStatusOnCOMThread() }) + status, _ := r.val.([]VolumeStatus) + return status, r.err +} + +// EncryptVolume encrypts the specified volume and returns the recovery key. +func (w *COMWorker) EncryptVolume(targetVolume string) (string, error) { + r := w.exec(func() (any, error) { return encryptVolumeOnCOMThread(targetVolume) }) + key, _ := r.val.(string) + return key, r.err +} + +// DecryptVolume decrypts the specified volume. +func (w *COMWorker) DecryptVolume(targetVolume string) error { + r := w.exec(func() (any, error) { return nil, decryptVolumeOnCOMThread(targetVolume) }) + return r.err +} diff --git a/orbit/pkg/update/notifications.go b/orbit/pkg/update/notifications.go index 4cfc87724802..0f7f7451f43d 100644 --- a/orbit/pkg/update/notifications.go +++ b/orbit/pkg/update/notifications.go @@ -446,26 +446,27 @@ type windowsMDMBitlockerConfigReceiver struct { // ensures only one script execution runs at a time mu sync.Mutex - // for tests, to be able to mock API commands. If nil, will use - // bitlocker.EncryptVolume + // execEncryptVolumeFn handles volume encryption. Set by the middleware from the COMWorker, or overridden in tests. execEncryptVolumeFn execEncryptVolumeFunc - // for tests, to be able to mock API commands. If nil, will use - // bitlocker.GetEncryptionStatus + // execGetEncryptionStatusFn retrieves encryption status. Set by the middleware from the COMWorker, or overridden in tests. execGetEncryptionStatusFn execGetEncryptionStatusFunc - // for tests, to be able to mock the decryption process. If nil, will use - // bitlocker.DecryptVolume + // execDecryptVolumeFn handles volume decryption. Set by the middleware from the COMWorker, or overridden in tests. execDecryptVolumeFn execDecryptVolumeFunc } func ApplyWindowsMDMBitlockerFetcherMiddleware( frequency time.Duration, encryptionResult DiskEncryptionKeySetter, + comWorker *bitlocker.COMWorker, ) fleet.OrbitConfigReceiver { return &windowsMDMBitlockerConfigReceiver{ - Frequency: frequency, - EncryptionResult: encryptionResult, + Frequency: frequency, + EncryptionResult: encryptionResult, + execEncryptVolumeFn: comWorker.EncryptVolume, + execGetEncryptionStatusFn: comWorker.GetEncryptionStatus, + execDecryptVolumeFn: comWorker.DecryptVolume, } } @@ -559,11 +560,7 @@ func (w *windowsMDMBitlockerConfigReceiver) attemptBitlockerEncryption(notifs fl // getEncryptionStatusForVolume retrieves the encryption status for a specific volume. func (w *windowsMDMBitlockerConfigReceiver) getEncryptionStatusForVolume(volume string) (*bitlocker.EncryptionStatus, error) { - fn := w.execGetEncryptionStatusFn - if fn == nil { - fn = bitlocker.GetEncryptionStatus - } - status, err := fn() + status, err := w.execGetEncryptionStatusFn() if err != nil { return nil, err } @@ -593,12 +590,7 @@ func (w *windowsMDMBitlockerConfigReceiver) bitLockerActionInProgress(status *bi // performEncryption executes the encryption process. func (w *windowsMDMBitlockerConfigReceiver) performEncryption(volume string) (string, error) { - fn := w.execEncryptVolumeFn - if fn == nil { - fn = bitlocker.EncryptVolume - } - - recoveryKey, err := fn(volume) + recoveryKey, err := w.execEncryptVolumeFn(volume) if err != nil { return "", err } @@ -607,12 +599,7 @@ func (w *windowsMDMBitlockerConfigReceiver) performEncryption(volume string) (st } func (w *windowsMDMBitlockerConfigReceiver) decryptVolume(targetVolume string) error { - fn := w.execDecryptVolumeFn - if fn == nil { - fn = bitlocker.DecryptVolume - } - - return fn(targetVolume) + return w.execDecryptVolumeFn(targetVolume) } // isMisreportedDecryptionError checks whether the given error is a potentially