diff --git a/cmd/gcs-sidecar/main.go b/cmd/gcs-sidecar/main.go index 4cd0d70b34..71e27dbc05 100644 --- a/cmd/gcs-sidecar/main.go +++ b/cmd/gcs-sidecar/main.go @@ -15,6 +15,7 @@ import ( "github.com/Microsoft/hcsshim/internal/gcs/prot" shimlog "github.com/Microsoft/hcsshim/internal/log" "github.com/Microsoft/hcsshim/internal/oc" + "github.com/Microsoft/hcsshim/internal/pspdriver" "github.com/Microsoft/hcsshim/pkg/securitypolicy" "github.com/sirupsen/logrus" "go.opencensus.io/trace" @@ -214,24 +215,17 @@ func main() { return } - // gcs-sidecar can be used for non-confidentail hyperv wcow - // as well. So we do not always want to check for initialPolicyStance - var initialEnforcer securitypolicy.SecurityPolicyEnforcer - // TODO (kiashok/Mahati): The initialPolicyStance is set to allow - // only for dev. This will eventually be set to allow/deny depending on - // on whether SNP is supported or not. - initialPolicyStance := "allow" - switch initialPolicyStance { - case "allow": - initialEnforcer = &securitypolicy.OpenDoorSecurityPolicyEnforcer{} - logrus.Tracef("initial-policy-stance: allow") - case "deny": - initialEnforcer = &securitypolicy.ClosedDoorSecurityPolicyEnforcer{} - logrus.Tracef("initial-policy-stance: deny") - default: - logrus.Error("unknown initial-policy-stance") + if err := pspdriver.StartPSPDriver(ctx); err != nil { + // When error happens, pspdriver.GetPspDriverError() returns true. + // In that case, gcs-sidecar should keep the initial "deny" policy + // and reject all requests from the host. + logrus.WithError(err).Errorf("failed to start PSP driver") } + // Use "deny" policy as initial enforcer. + // This is updated later with user provided policy. + initialEnforcer := &securitypolicy.ClosedDoorSecurityPolicyEnforcer{} + // 3. Create bridge and initializa brdg := sidecar.NewBridge(shimCon, gcsCon, initialEnforcer) brdg.AssignHandlers() diff --git a/internal/gcs-sidecar/handlers.go b/internal/gcs-sidecar/handlers.go index 689e02bb66..9dc0818d31 100644 --- a/internal/gcs-sidecar/handlers.go +++ b/internal/gcs-sidecar/handlers.go @@ -323,7 +323,7 @@ func (b *Bridge) modifySettings(req *request) (err error) { case guestresource.ResourceTypeSecurityPolicy: securityPolicyRequest := modifyGuestSettingsRequest.Settings.(*guestresource.WCOWConfidentialOptions) log.G(ctx).Tracef("WCOWConfidentialOptions: { %v}", securityPolicyRequest) - _ = b.hostState.SetWCOWConfidentialUVMOptions(securityPolicyRequest) + _ = b.hostState.SetWCOWConfidentialUVMOptions(ctx, securityPolicyRequest) // Send response back to shim resp := &prot.ResponseBase{ diff --git a/internal/gcs-sidecar/host.go b/internal/gcs-sidecar/host.go index b17cd38448..9e0602aaa8 100644 --- a/internal/gcs-sidecar/host.go +++ b/internal/gcs-sidecar/host.go @@ -4,12 +4,14 @@ package bridge import ( - "errors" + "context" "fmt" "sync" "github.com/Microsoft/hcsshim/internal/protocol/guestresource" + "github.com/Microsoft/hcsshim/internal/pspdriver" "github.com/Microsoft/hcsshim/pkg/securitypolicy" + "github.com/pkg/errors" ) type Host struct { @@ -32,7 +34,7 @@ func NewHost(initialEnforcer securitypolicy.SecurityPolicyEnforcer) *Host { } } -func (h *Host) SetWCOWConfidentialUVMOptions(securityPolicyRequest *guestresource.WCOWConfidentialOptions) error { +func (h *Host) SetWCOWConfidentialUVMOptions(ctx context.Context, securityPolicyRequest *guestresource.WCOWConfidentialOptions) error { h.policyMutex.Lock() defer h.policyMutex.Unlock() @@ -40,6 +42,22 @@ func (h *Host) SetWCOWConfidentialUVMOptions(securityPolicyRequest *guestresourc return errors.New("security policy has already been set") } + if err := pspdriver.GetPspDriverError(); err != nil { + // For this case gcs-sidecar will keep initial deny policy. + return errors.Wrapf(err, "an error occurred while using PSP driver") + } + + // Fetch report and validate host_data + hostData, err := securitypolicy.NewSecurityPolicyDigest(securityPolicyRequest.EncodedSecurityPolicy) + if err != nil { + return err + } + + if err := pspdriver.ValidateHostData(ctx, hostData[:]); err != nil { + // For this case gcs-sidecar will keep initial deny policy. + return err + } + // This limit ensures messages are below the character truncation limit that // can be imposed by an orchestrator maxErrorMessageLength := 3 * 1024 diff --git a/internal/pspdriver/pspdriver.go b/internal/pspdriver/pspdriver.go new file mode 100644 index 0000000000..db41384853 --- /dev/null +++ b/internal/pspdriver/pspdriver.go @@ -0,0 +1,334 @@ +//go:build windows +// +build windows + +package pspdriver + +import ( + "bytes" + "context" + "encoding/binary" + "encoding/hex" + "fmt" + "time" + + "github.com/Microsoft/hcsshim/internal/log" + "github.com/Microsoft/hcsshim/internal/winapi" + "github.com/pkg/errors" + "golang.org/x/sys/windows/svc" + "golang.org/x/sys/windows/svc/mgr" +) + +const ( + serviceName = "AmdSnpPsp" +) + +const ( + SnpPspAPIStatusSuccess = 0x00000000 + SnpPspAPIStatusUnsuccessful = 0x00000001 + SnpPspAPIStatusDriverUnsuccessful = 0x00000002 + SnpPspAPIStatusPspUnsuccessful = 0x00000003 + SnpPspAPIStatusInvalidParameter = 0x00000004 + SnpPspAPIStatusDeviceNotAvailable = 0x00000005 +) + +// TODO: Fix duplication with pkg/amdsevsnp and merge this into it. + +const ( + SnpPspReportDataSize = 64 + SnpPspReportHostDataSize = 32 + SnpPspAttestationReportSize = 0x4A0 +) + +type SNPPSPGuestRequestResult struct { + DriverStatus uint32 + PspStatus uint64 +} + +// Type used by FetchParsedSNPReport. +// This it converted to the public type `Report` +// by `func (sr *report) report() Report`. +type report struct { + Version uint32 + GuestSVN uint32 + Policy uint64 + FamilyID [16]byte + ImageID [16]byte + VMPL uint32 + SignatureAlgo uint32 + PlatformVersion uint64 + PlatformInfo uint64 + AuthorKeyEn uint32 + Reserved1 uint32 + ReportData [SnpPspReportDataSize]byte + Measurement [48]byte + HostData [SnpPspReportHostDataSize]byte + IDKeyDigest [48]byte + AuthorKeyDigest [48]byte + ReportID [32]byte + ReportIDMA [32]byte + ReportTCB uint64 + Reserved2 [24]byte + ChipID [64]byte + CommittedSVN [8]byte + CommittedVersion [8]byte + LaunchSVN [8]byte + Reserved3 [168]byte + Signature [512]byte +} + +// Report represents parsed attestation report. +// Fields with string type is hex-encoded values of the corresponding byte arrays. +// Based on Table 23 of 'SEV-ES Guest-Hypervisor Communication Block Standardization'. +// +// https://www.amd.com/content/dam/amd/en/documents/epyc-technical-docs/specifications/56421.pdf +type Report struct { + Version uint32 + GuestSVN uint32 + Policy uint64 + FamilyID string + ImageID string + VMPL uint32 + SignatureAlgo uint32 + PlatformVersion uint64 + PlatformInfo uint64 + AuthorKeyEn uint32 + ReportData string + Measurement string + HostData []byte + IDKeyDigest string + AuthorKeyDigest string + ReportID string + ReportIDMA string + ReportTCB uint64 + ChipID string + CommittedSVN string + CommittedVersion string + LaunchSVN string + Signature string +} + +func (sr *report) report() Report { + return Report{ + Version: sr.Version, + GuestSVN: sr.GuestSVN, + Policy: sr.Policy, + FamilyID: hex.EncodeToString(mirrorBytes(sr.FamilyID[:])[:]), + ImageID: hex.EncodeToString(mirrorBytes(sr.ImageID[:])[:]), + VMPL: sr.VMPL, + SignatureAlgo: sr.SignatureAlgo, + PlatformVersion: sr.PlatformVersion, + PlatformInfo: sr.PlatformInfo, + AuthorKeyEn: sr.AuthorKeyEn, + ReportData: hex.EncodeToString(sr.ReportData[:]), + Measurement: hex.EncodeToString(sr.Measurement[:]), + HostData: sr.HostData[:], + IDKeyDigest: hex.EncodeToString(sr.IDKeyDigest[:]), + AuthorKeyDigest: hex.EncodeToString(sr.AuthorKeyDigest[:]), + ReportID: hex.EncodeToString(sr.ReportID[:]), + ReportIDMA: hex.EncodeToString(sr.ReportIDMA[:]), + ReportTCB: sr.ReportTCB, + ChipID: hex.EncodeToString(sr.ChipID[:]), + CommittedSVN: hex.EncodeToString(sr.CommittedSVN[:]), + CommittedVersion: hex.EncodeToString(sr.CommittedVersion[:]), + LaunchSVN: hex.EncodeToString(sr.LaunchSVN[:]), + Signature: hex.EncodeToString(sr.Signature[:]), + } +} + +// mirrorBytes mirrors the byte ordering so that hex-encoding little endian +// ordered bytes come out in the readable order. +func mirrorBytes(b []byte) []byte { + for i := 0; i < len(b)/2; i++ { + mirrorIndex := len(b) - i - 1 + b[i], b[mirrorIndex] = b[mirrorIndex], b[i] + } + return b +} + +var ( + pspDriverStarted = false + // The error needs to be stored to be retrieved later. + // When driver or its dll fails, gcs-sidecar doesn't + // set security policy and keep the initial deny policy. + pspDriverError error = nil +) + +func StartPSPDriver(ctx context.Context) error { + // Connect to the Service Control Manager + m, err := mgr.Connect() + if err != nil { + return errors.Wrap(err, "Failed to connect to service manager") + } + defer func() { + if derr := m.Disconnect(); derr != nil { + // Log the error on disconnect but do not override the returned error. + log.G(ctx).Warnf("Failed to disconnect from service manager: %v", derr) + } + }() + + // Open the service + s, err := m.OpenService(serviceName) + if err != nil { + return errors.Wrapf(err, "Could not access service %q", serviceName) + } + defer s.Close() + + // Start the service + err = s.Start() + if err != nil { + return errors.Wrapf(err, "Could not start service %q", serviceName) + } + + // From the documentation, there is no guarantee that the service will be + // in `Running` state immediately after starting it. + // Wait until the service is in the `Running` state. + timeout := time.After(3 * time.Second) + tick := time.Tick(100 * time.Millisecond) + for { + select { + case <-timeout: + pspDriverError = errors.New("timed out waiting for PSP driver to start") + return pspDriverError + case <-tick: + status, err := s.Query() + if err != nil { + pspDriverError = errors.Wrap(err, "could not query PSP driver status") + return pspDriverError + } + if status.State == svc.Running { + log.G(ctx).Tracef("Service %q started successfully", serviceName) + + pspDriverStarted = true + return nil + } + } + } +} + +func IsPspDriverStarted() bool { + return pspDriverStarted +} + +// Return an error from the PSP driver dll +// when it fails to use the dll at all. +// Otherwise it returns nil. +func GetPspDriverError() error { + return pspDriverError +} + +// IsSNPMode() returns true if it's in SNP mode. +func IsSNPMode(ctx context.Context) (bool, error) { + + if pspDriverError != nil { + return false, pspDriverError + } + + if !pspDriverStarted { + return false, errors.New("PSP driver is not started") + } + + // snpMode is defined as BOOLEAN (= byte) + var snpMode uint8 + ret, err := winapi.SnpPspIsSnpMode(&snpMode) + + if ret != SnpPspAPIStatusSuccess || err != nil { + errMessage := "" + if err != nil { + // err is not nil either when `winapi` didn't find the API or when ret is not success. + // In case of the former, ret is meaningless because ret is returned by the dll. + // In case of the latter, we don't need to print err. + // We can't tell which case it is here, we print all the information we have. + // We could avoid this by loading the dll in this package, but we use `winapi` for consistency with existing code. + errMessage = fmt.Sprintf(", err: %v", err) + } + pspDriverError = errors.Errorf("failed to determine if it's in SNP VM. SNPPSP_API_STATUS: 0x%x%s", ret, errMessage) + return false, pspDriverError + } + + return snpMode == 1, nil +} + +// FetchRawSNPReport returns attestation report bytes. +func FetchRawSNPReport(ctx context.Context, reportData []byte) ([]byte, error) { + if pspDriverError != nil { + return nil, pspDriverError + } + + if !pspDriverStarted { + return nil, errors.New("PSP driver is not started") + } + + var reportDataBuf [SnpPspReportDataSize]uint8 + + if reportData != nil { + if len(reportData) > SnpPspReportDataSize { + return nil, fmt.Errorf("reportData too large: %s", reportData) + } + copy(reportDataBuf[:], reportData) + } + + var report [SnpPspAttestationReportSize]uint8 + var guestRequestResult winapi.SNPPSPGuestRequestResult + + // Fetch attestation report using generated winapi wrapper + ret, err := winapi.SnpPspFetchAttestationReport(&reportDataBuf[0], &guestRequestResult, &report[0]) + if ret != SnpPspAPIStatusSuccess || err != nil { + errMessage := "" + if err != nil { + // err is not nil either when `winapi` didn't find the API or when ret is not success. + // In case of the former, ret and guestRequestResult are meaningless because they are returned by the dll. + // In case of the latter, we don't need to print err. + // We can't tell which case it is here, we print all the information we have. + // We could avoid this by loading the dll in this package, but we use `winapi` for consistency with existing code. + errMessage = fmt.Sprintf(", err: %v", err) + } + pspDriverError = errors.Errorf("failed to fetch attestation report. res: 0x%x, DriverStatus: 0x%x, PspStatus: 0x%x%s", + ret, guestRequestResult.DriverStatus, guestRequestResult.PspStatus, errMessage) + return nil, pspDriverError + } + + return report[:], nil +} + +// FetchParsedSNPReport parses raw attestation response into proper structs. +func FetchParsedSNPReport(ctx context.Context, reportData []byte) (Report, error) { + rawBytes, err := FetchRawSNPReport(ctx, reportData) + if err != nil { + return Report{}, err + } + + var r report + buf := bytes.NewBuffer(rawBytes) + if err := binary.Read(buf, binary.LittleEndian, &r); err != nil { + return Report{}, err + } + return r.report(), nil +} + +// TODO: Based on internal\guest\runtime\hcsv2\hostdata.go and it's duplicated. +// ValidateHostData fetches SNP report (if applicable) and validates `hostData` against +// HostData set at UVM launch. +func ValidateHostData(ctx context.Context, hostData []byte) error { + // If the UVM is not SNP, then don't try to fetch an SNP report. + isSnpMode, err := IsSNPMode(ctx) + if err != nil { + return err + } + if !isSnpMode { + return nil + } + report, err := FetchParsedSNPReport(ctx, nil) + if err != nil { + return err + } + + if !bytes.Equal(hostData, report.HostData[:]) { + return fmt.Errorf( + "security policy digest %q doesn't match HostData provided at launch %q", + hostData, + report.HostData[:], + ) + } + + return nil +} diff --git a/internal/winapi/amdsnp.go b/internal/winapi/amdsnp.go new file mode 100644 index 0000000000..5e0f9cd2c5 --- /dev/null +++ b/internal/winapi/amdsnp.go @@ -0,0 +1,11 @@ +//go:build windows + +package winapi + +type SNPPSPGuestRequestResult struct { + DriverStatus uint32 + PspStatus uint64 +} + +//sys SnpPspIsSnpMode(snpMode *uint8) (ret uint32, err error) [failretval>0] = amdsnppspapi.SnpPspIsSnpMode? +//sys SnpPspFetchAttestationReport(reportData *uint8, guestRequestResult *SNPPSPGuestRequestResult, report *uint8) (ret uint32, err error) [failretval>0] = amdsnppspapi.SnpPspFetchAttestationReport? diff --git a/internal/winapi/zsyscall_windows.go b/internal/winapi/zsyscall_windows.go index a7eea44ec7..761e09bbd4 100644 --- a/internal/winapi/zsyscall_windows.go +++ b/internal/winapi/zsyscall_windows.go @@ -37,17 +37,20 @@ func errnoErr(e syscall.Errno) error { } var ( - modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") - modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") - modcfgmgr32 = windows.NewLazySystemDLL("cfgmgr32.dll") - modcimfs = windows.NewLazySystemDLL("cimfs.dll") - modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") - modkernel32 = windows.NewLazySystemDLL("kernel32.dll") - modnetapi32 = windows.NewLazySystemDLL("netapi32.dll") - modntdll = windows.NewLazySystemDLL("ntdll.dll") - modoffreg = windows.NewLazySystemDLL("offreg.dll") + modadvapi32 = windows.NewLazySystemDLL("advapi32.dll") + modamdsnppspapi = windows.NewLazySystemDLL("amdsnppspapi.dll") + modbindfltapi = windows.NewLazySystemDLL("bindfltapi.dll") + modcfgmgr32 = windows.NewLazySystemDLL("cfgmgr32.dll") + modcimfs = windows.NewLazySystemDLL("cimfs.dll") + modiphlpapi = windows.NewLazySystemDLL("iphlpapi.dll") + modkernel32 = windows.NewLazySystemDLL("kernel32.dll") + modnetapi32 = windows.NewLazySystemDLL("netapi32.dll") + modntdll = windows.NewLazySystemDLL("ntdll.dll") + modoffreg = windows.NewLazySystemDLL("offreg.dll") procLogonUserW = modadvapi32.NewProc("LogonUserW") + procSnpPspFetchAttestationReport = modamdsnppspapi.NewProc("SnpPspFetchAttestationReport") + procSnpPspIsSnpMode = modamdsnppspapi.NewProc("SnpPspIsSnpMode") procBfSetupFilter = modbindfltapi.NewProc("BfSetupFilter") procCM_Get_DevNode_PropertyW = modcfgmgr32.NewProc("CM_Get_DevNode_PropertyW") procCM_Get_Device_ID_ListA = modcfgmgr32.NewProc("CM_Get_Device_ID_ListA") @@ -124,6 +127,32 @@ func LogonUser(username *uint16, domain *uint16, password *uint16, logonType uin return } +func SnpPspFetchAttestationReport(reportData *uint8, guestRequestResult *SNPPSPGuestRequestResult, report *uint8) (ret uint32, err error) { + err = procSnpPspFetchAttestationReport.Find() + if err != nil { + return + } + r0, _, e1 := syscall.SyscallN(procSnpPspFetchAttestationReport.Addr(), uintptr(unsafe.Pointer(reportData)), uintptr(unsafe.Pointer(guestRequestResult)), uintptr(unsafe.Pointer(report))) + ret = uint32(r0) + if ret > 0 { + err = errnoErr(e1) + } + return +} + +func SnpPspIsSnpMode(snpMode *uint8) (ret uint32, err error) { + err = procSnpPspIsSnpMode.Find() + if err != nil { + return + } + r0, _, e1 := syscall.SyscallN(procSnpPspIsSnpMode.Addr(), uintptr(unsafe.Pointer(snpMode))) + ret = uint32(r0) + if ret > 0 { + err = errnoErr(e1) + } + return +} + func BfSetupFilter(jobHandle windows.Handle, flags uint32, virtRootPath *uint16, virtTargetPath *uint16, virtExceptions **uint16, virtExceptionPathCount uint32) (hr error) { hr = procBfSetupFilter.Find() if hr != nil {