Skip to content

Commit

Permalink
Optimize error propagation (#2819)
Browse files Browse the repository at this point in the history
  • Loading branch information
kennykerr committed Jan 27, 2024
1 parent 993bfa0 commit a2fbd10
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 70 deletions.
101 changes: 52 additions & 49 deletions crates/libs/core/src/error.rs
@@ -1,23 +1,27 @@
#![allow(missing_docs)]

use super::*;

/// An error object consists of both an error code as well as detailed error information for debugging.
#[derive(Clone, PartialEq, Eq)]
pub struct Error {
pub(crate) code: HRESULT,
pub(crate) info: Option<crate::imp::IRestrictedErrorInfo>,
pub(crate) info: Option<crate::imp::IErrorInfo>,
}

unsafe impl Send for Error {}
unsafe impl Sync for Error {}

impl Error {
/// An error object without any failure information.
pub const OK: Self = Self { code: HRESULT(0), info: None };

/// This creates a new WinRT error object, capturing the stack and other information about the
/// This creates a new error object, capturing the stack and other information about the
/// point of failure.
pub fn new(code: HRESULT, message: HSTRING) -> Self {
unsafe {
crate::imp::RoOriginateError(code.0, std::mem::transmute_copy(&message));
let info = GetErrorInfo().and_then(|e| e.cast()).ok();
Self { code, info }
Self { code, info: GetErrorInfo() }
}
}

Expand All @@ -31,43 +35,66 @@ impl Error {
self.code
}

/// The error information describing the error.
pub const fn info(&self) -> &Option<crate::imp::IRestrictedErrorInfo> {
&self.info
/// The error object describing the error.
pub fn info<T: Interface>(&self) -> Option<T> {
self.info.as_ref().and_then(|info| info.cast::<T>().ok())
}

/// The error message describing the error.
pub fn message(&self) -> HSTRING {
// First attempt to retrieve the restricted error information.
if let Some(info) = &self.info {
let mut fallback = BSTR::default();
let mut message = BSTR::default();
let mut code = HRESULT(0);

unsafe {
let _ = info.GetErrorDetails(&mut fallback, &mut code, &mut message, &mut BSTR::default());
// First attempt to retrieve the restricted error information.
if let Ok(info) = info.cast::<crate::imp::IRestrictedErrorInfo>() {
let mut fallback = BSTR::default();
let mut code = HRESULT(0);

unsafe {
// The vfptr is called directly to avoid the default error propagation logic.
_ = (info.vtable().GetErrorDetails)(info.as_raw(), &mut fallback as *mut _ as _, &mut code, &mut message as *mut _ as _, &mut BSTR::default() as *mut _ as _);
}

if message.is_empty() {
message = fallback
};
}

if self.code == code {
let message = if !message.is_empty() { message } else { fallback };
return HSTRING::from_wide(crate::imp::wide_trim_end(message.as_wide())).unwrap_or_default();
// Next attempt to retrieve the regular error information.
if message.is_empty() {
unsafe {
// The vfptr is called directly to avoid the default error propagation logic.
_ = (info.vtable().GetDescription)(info.as_raw(), &mut message as *mut _ as _);
}
}

return HSTRING::from_wide(crate::imp::wide_trim_end(message.as_wide())).unwrap_or_default();
}

// Otherwise fallback to a generic error code description.
self.code.message()
}
}

impl From<Error> for HRESULT {
fn from(error: Error) -> Self {
let code = error.code;
let info: Option<crate::imp::IErrorInfo> = error.info.and_then(|info| info.cast().ok());

unsafe {
let _ = crate::imp::SetErrorInfo(0, info.as_ref());
if error.info.is_some() {
unsafe {
crate::imp::SetErrorInfo(0, std::mem::transmute_copy(&error.info));
}
}

code
error.code
}
}

impl From<HRESULT> for Error {
fn from(code: HRESULT) -> Self {
let info = GetErrorInfo();

// Call CapturePropagationContext here if a use case presents itself. Otherwise, we can avoid the overhead for error propagation.

Self { code, info }
}
}

Expand Down Expand Up @@ -105,32 +132,6 @@ impl From<std::convert::Infallible> for Error {
}
}

impl From<HRESULT> for Error {
fn from(code: HRESULT) -> Self {
let info: Option<crate::imp::IRestrictedErrorInfo> = GetErrorInfo().and_then(|e| e.cast()).ok();

if let Some(info) = info {
// If it does (and therefore running on a recent version of Windows)
// then capture_propagation_context adds a breadcrumb to the error
// info to make debugging easier.
if let Ok(capture) = info.cast::<crate::imp::ILanguageExceptionErrorInfo2>() {
unsafe {
let _ = capture.CapturePropagationContext(None);
}
}

return Self { code, info: Some(info) };
}

if let Ok(info) = GetErrorInfo() {
let message = unsafe { info.GetDescription().unwrap_or_default() };
Self::new(code, HSTRING::from_wide(message.as_wide()).unwrap_or_default())
} else {
Self { code, info: None }
}
}
}

impl std::fmt::Debug for Error {
fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut debug = fmt.debug_struct("Error");
Expand All @@ -151,6 +152,8 @@ impl std::fmt::Display for Error {

impl std::error::Error for Error {}

fn GetErrorInfo() -> Result<crate::imp::IErrorInfo> {
unsafe { crate::imp::GetErrorInfo(0) }
fn GetErrorInfo() -> Option<crate::imp::IErrorInfo> {
let mut info = None;
unsafe { crate::imp::GetErrorInfo(0, &mut info as *mut _ as _) };
info
}
2 changes: 2 additions & 0 deletions crates/libs/core/src/imp/bindings.rs
Expand Up @@ -21,6 +21,8 @@
::windows_targets::link!("ole32.dll" "system" fn CoTaskMemFree(pv : *const ::core::ffi::c_void));
::windows_targets::link!("ole32.dll" "system" fn PropVariantClear(pvar : *mut PROPVARIANT) -> HRESULT);
::windows_targets::link!("ole32.dll" "system" fn PropVariantCopy(pvardest : *mut PROPVARIANT, pvarsrc : *const PROPVARIANT) -> HRESULT);
::windows_targets::link!("oleaut32.dll" "system" fn GetErrorInfo(dwreserved : u32, pperrinfo : *mut * mut::core::ffi::c_void) -> HRESULT);
::windows_targets::link!("oleaut32.dll" "system" fn SetErrorInfo(dwreserved : u32, perrinfo : * mut::core::ffi::c_void) -> HRESULT);
::windows_targets::link!("oleaut32.dll" "system" fn SysAllocStringLen(strin : PCWSTR, ui : u32) -> BSTR);
::windows_targets::link!("oleaut32.dll" "system" fn SysFreeString(bstrstring : BSTR));
::windows_targets::link!("oleaut32.dll" "system" fn SysStringLen(pbstr : BSTR) -> u32);
Expand Down
14 changes: 0 additions & 14 deletions crates/libs/core/src/imp/com_bindings.rs
Expand Up @@ -16,20 +16,6 @@ where
let mut result__ = ::std::mem::zeroed();
RoGetAgileReference(options, riid, punk.into_param().abi(), &mut result__).from_abi(result__)
}
#[inline]
pub unsafe fn GetErrorInfo(dwreserved: u32) -> ::windows_core::Result<IErrorInfo> {
::windows_targets::link!("oleaut32.dll" "system" fn GetErrorInfo(dwreserved : u32, pperrinfo : *mut * mut::core::ffi::c_void) -> ::windows_core::HRESULT);
let mut result__ = ::std::mem::zeroed();
GetErrorInfo(dwreserved, &mut result__).from_abi(result__)
}
#[inline]
pub unsafe fn SetErrorInfo<P0>(dwreserved: u32, perrinfo: P0) -> ::windows_core::Result<()>
where
P0: ::windows_core::IntoParam<IErrorInfo>,
{
::windows_targets::link!("oleaut32.dll" "system" fn SetErrorInfo(dwreserved : u32, perrinfo : * mut::core::ffi::c_void) -> ::windows_core::HRESULT);
SetErrorInfo(dwreserved, perrinfo.into_param().abi()).ok()
}
pub const AGILEREFERENCE_DEFAULT: AgileReferenceOptions = AgileReferenceOptions(0i32);
#[repr(transparent)]
#[derive(::core::cmp::PartialEq, ::core::cmp::Eq, ::core::marker::Copy, ::core::clone::Clone, ::core::default::Default)]
Expand Down
5 changes: 4 additions & 1 deletion crates/libs/core/src/interface.rs
Expand Up @@ -79,10 +79,13 @@ pub unsafe trait Interface: Sized + Clone {
/// named cast.
fn cast<T: Interface>(&self) -> Result<T> {
let mut result = None;

// SAFETY: `result` is valid for writing an interface pointer and it is safe
// to cast the `result` pointer as `T` on success because we are using the `IID` tied
// to `T` which the implementor of `Interface` has guaranteed is correct
unsafe { self.query(&T::IID, &mut result as *mut _ as _).and_some(result) }
unsafe { _ = self.query(&T::IID, &mut result as *mut _ as _) };

result.ok_or_else(|| Error { code: crate::imp::E_NOINTERFACE, info: None })
}

/// Attempts to create a [`Weak`] reference to this object.
Expand Down
2 changes: 2 additions & 0 deletions crates/tests/core/Cargo.toml
Expand Up @@ -10,6 +10,8 @@ features = [
"implement",
"Win32_Foundation",
"Win32_System_WinRT",
"Win32_System_Ole",
"Win32_System_Com",
"Win32_Media_Audio",
]

Expand Down
38 changes: 37 additions & 1 deletion crates/tests/core/tests/error.rs
@@ -1,4 +1,7 @@
use windows::{core::*, Win32::Foundation::*, Win32::Media::Audio::*};
use windows::{
core::*, Win32::Foundation::*, Win32::Media::Audio::*, Win32::System::Com::*,
Win32::System::Ole::*, Win32::System::WinRT::*,
};

#[test]
fn display_debug() {
Expand Down Expand Up @@ -29,3 +32,36 @@ fn hresult_last_error() {
assert_eq!(e.code(), CRYPT_E_NOT_FOUND);
}
}

// Checks that non-restricted error info is reported.
#[test]
fn set_error_info() -> Result<()> {
unsafe {
let creator = CreateErrorInfo()?;
creator.SetDescription(w!("message"))?;
SetErrorInfo(0, &creator.cast::<IErrorInfo>()?)?;

assert_eq!(Error::from(E_FAIL).message(), "message");
SetErrorInfo(0, None)?;
assert_eq!(Error::from(E_FAIL).message(), "Unspecified error");

Ok(())
}
}

// https://github.com/microsoft/cppwinrt/pull/1386
#[test]
fn suppressed_error_info() -> Result<()> {
unsafe { RoSetErrorReportingFlags(RO_ERROR_REPORTING_SUPPRESSSETERRORINFO.0 as u32)? };

assert_eq!(
Error::new(E_FAIL, "message".into()).message(),
"Unspecified error"
);

unsafe { RoSetErrorReportingFlags(RO_ERROR_REPORTING_USESETERRORINFO.0 as u32)? };

assert_eq!(Error::new(E_FAIL, "message".into()).message(), "message");

Ok(())
}
4 changes: 2 additions & 2 deletions crates/tests/debugger_visualizer/tests/test.rs
Expand Up @@ -175,11 +175,11 @@ hstring : "This is an HSTRING" [Type: windows_core::strings::hstring::H
out_of_memory_error : 0x8007000e (Not enough memory resources are available to complete this operation.) [Type: windows_core::error::Error]
[<Raw View>] [Type: windows_core::error::Error]
[info] : Some [Type: enum2$<core::option::Option<windows_core::imp::com_bindings::IRestrictedErrorInfo> >]
[info] : Some [Type: enum2$<core::option::Option<windows_core::imp::com_bindings::IErrorInfo> >]
invalid_argument_error : 0x80070057 (The parameter is incorrect.) [Type: windows_core::error::Error]
[<Raw View>] [Type: windows_core::error::Error]
[info] : Some [Type: enum2$<core::option::Option<windows_core::imp::com_bindings::IRestrictedErrorInfo> >]
[info] : Some [Type: enum2$<core::option::Option<windows_core::imp::com_bindings::IErrorInfo> >]
"#
)]
fn test_debugger_visualizer() {
Expand Down
2 changes: 1 addition & 1 deletion crates/tests/implement/tests/data_object.rs
Expand Up @@ -100,7 +100,7 @@ fn test() -> Result<()> {
assert!(r.is_err());
let e = r.unwrap_err();
assert!(e.code() == S_OK);
assert!(e.info().is_none());
assert!(e.info::<IUnknown>().is_none());

d.DAdvise(&Default::default(), 0, None)?;

Expand Down
2 changes: 2 additions & 0 deletions crates/tools/core/bindings.txt
Expand Up @@ -70,3 +70,5 @@
Windows.Win32.System.Variant.VT_UNKNOWN
Windows.Win32.System.WinRT.RoGetActivationFactory
Windows.Win32.System.WinRT.RoOriginateError
Windows.Win32.System.Com.GetErrorInfo
Windows.Win32.System.Com.SetErrorInfo
2 changes: 0 additions & 2 deletions crates/tools/core/com_bindings.txt
Expand Up @@ -17,10 +17,8 @@
Windows.Win32.Foundation.RPC_E_DISCONNECTED
Windows.Win32.Foundation.TYPE_E_TYPEMISMATCH
Windows.Win32.System.Com.CoCreateGuid
Windows.Win32.System.Com.GetErrorInfo
Windows.Win32.System.Com.IAgileObject
Windows.Win32.System.Com.IErrorInfo
Windows.Win32.System.Com.SetErrorInfo
Windows.Win32.System.WinRT.AGILEREFERENCE_DEFAULT
Windows.Win32.System.WinRT.IAgileReference
Windows.Win32.System.WinRT.ILanguageExceptionErrorInfo2
Expand Down

0 comments on commit a2fbd10

Please sign in to comment.