diff --git a/Cargo.lock b/Cargo.lock index 37ed636..c2f7643 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -685,6 +685,12 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "widestring" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7219d36b6eac893fa81e84ebe06485e7dcbb616177469b142df14f1f4deb1311" + [[package]] name = "winapi-util" version = "0.1.9" @@ -894,6 +900,7 @@ dependencies = [ "semver", "thiserror 2.0.7", "typed-path", + "widestring", "windows", "wslplugins-macro", "wslplugins-sys", diff --git a/wslplugins-rs/Cargo.toml b/wslplugins-rs/Cargo.toml index 8609aef..f81d165 100644 --- a/wslplugins-rs/Cargo.toml +++ b/wslplugins-rs/Cargo.toml @@ -9,15 +9,16 @@ version = "0.58" features = ["Win32_Foundation", "Win32_System", "Win32_Networking_WinSock"] [dependencies] -wslplugins-sys = { path = "../wslplugins-sys" } -typed-path = ">0.1" bitflags = { version = ">0.1.0", optional = true } -flagset = { version = ">0.1.0", optional = true } enumflags2 = { version = ">0.5", optional = true } +flagset = { version = ">0.1.0", optional = true } log = { version = "*", optional = true } log-instrument = { version = "*", optional = true } -wslplugins-macro = { path = "../wslplugins-macro", optional = true } thiserror = "2.0.7" +typed-path = ">0.1" +widestring = { version = "1", features = ["alloc"] } +wslplugins-macro = { path = "../wslplugins-macro", optional = true } +wslplugins-sys = { path = "../wslplugins-sys" } [dependencies.semver] version = ">0.1" diff --git a/wslplugins-rs/src/api/api_v1.rs b/wslplugins-rs/src/api/api_v1.rs index 2be24bd..544f038 100644 --- a/wslplugins-rs/src/api/api_v1.rs +++ b/wslplugins-rs/src/api/api_v1.rs @@ -3,11 +3,11 @@ extern crate wslplugins_sys; use super::Error; use super::Result; use crate::api::errors::require_update_error::Result as UpReqResult; -use crate::utils::{cstring_from_str, encode_wide_null_terminated}; +use crate::cstring_ext::CstringExt; use crate::wsl_session_information::WSLSessionInformation; #[cfg(feature = "log-instrument")] use log_instrument::instrument; -use std::ffi::{CString, OsStr, OsString}; +use std::ffi::{CString, OsStr}; use std::fmt::Debug; use std::iter::once; use std::mem::MaybeUninit; @@ -15,8 +15,8 @@ use std::net::TcpStream; use std::os::windows::io::FromRawSocket; use std::os::windows::raw::SOCKET; use std::path::Path; -use std::str::FromStr; use typed_path::Utf8UnixPath; +use widestring::U16CString; use windows::Win32::Networking::WinSock::SOCKET as WinSocket; use windows::{ core::{Result as WinResult, GUID, PCSTR, PCWSTR}, @@ -99,13 +99,10 @@ impl ApiV1 { read_only: bool, name: &OsStr, ) -> WinResult<()> { - let encoded_windows_path = encode_wide_null_terminated(windows_path.as_ref().as_os_str()); - let encoded_linux_path = encode_wide_null_terminated( - OsString::from_str(linux_path.as_ref().as_str()) - .unwrap() - .as_os_str(), - ); - let encoded_name = encode_wide_null_terminated(name); + let encoded_windows_path = + U16CString::from_os_str_truncate(windows_path.as_ref().as_os_str()); + let encoded_linux_path = U16CString::from_str_truncate(linux_path.as_ref().as_str()); + let encoded_name = U16CString::from_os_str_truncate(name); let result = unsafe { self.0.MountFolder.unwrap_unchecked()( session.id(), @@ -163,7 +160,10 @@ impl ApiV1 { .copied() .chain(once(0)) .collect(); - let c_args: Vec = args.iter().map(|&arg| cstring_from_str(arg)).collect(); + let c_args: Vec = args + .iter() + .map(|&arg| CString::from_str_truncate(arg)) + .collect(); let mut args_ptrs: Vec = c_args .iter() .map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8)) @@ -188,8 +188,10 @@ impl ApiV1 { /// Set the error message to display to the user if the VM or distribution creation fails. #[cfg_attr(feature = "log-instrument", instrument)] pub(crate) fn plugin_error(&self, error: &OsStr) -> WinResult<()> { - let error_vec = encode_wide_null_terminated(error); - unsafe { self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_vec.as_ptr())).ok() } + let error_utf16 = U16CString::from_os_str_truncate(error); + unsafe { + self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_utf16.as_ptr())).ok() + } } /// Execute a program in a user distribution @@ -242,7 +244,10 @@ impl ApiV1 { .chain(once(0)) .collect(); let path_ptr = PCSTR::from_raw(c_path.as_ptr()); - let c_args: Vec = args.iter().map(|&arg| cstring_from_str(arg)).collect(); + let c_args: Vec = args + .iter() + .map(|&arg| CString::from_str_truncate(arg)) + .collect(); let mut args_ptrs: Vec = c_args .iter() .map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8)) diff --git a/wslplugins-rs/src/cstring_ext.rs b/wslplugins-rs/src/cstring_ext.rs new file mode 100644 index 0000000..ed6fa56 --- /dev/null +++ b/wslplugins-rs/src/cstring_ext.rs @@ -0,0 +1,59 @@ +use std::ffi::CString; + +pub(crate) trait CstringExt { + /// Creates a `CString` from a string slice, truncating at the first null byte if present. + fn from_str_truncate(value: &str) -> Self; +} + +impl CstringExt for CString { + fn from_str_truncate(value: &str) -> Self { + let bytes = value.as_bytes(); + let truncated_bytes = match bytes.iter().position(|&b| b == 0) { + Some(pos) => &bytes[..pos], + None => bytes, + }; + // SAFETY: `truncated_bytes` is guaranteed not to contain null bytes. + unsafe { Self::from_vec_unchecked(truncated_bytes.to_vec()) } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::ffi::CString; + + #[test] + fn test_from_str_truncate_no_null() { + let input = "Hello, world!"; + let cstring = CString::from_str_truncate(input); + assert_eq!(cstring.to_str().unwrap(), input); + } + + #[test] + fn test_from_str_truncate_with_null() { + let input = "Hello\0world!"; + let cstring = CString::from_str_truncate(input); + assert_eq!(cstring.to_str().unwrap(), "Hello"); + } + + #[test] + fn test_from_str_truncate_empty() { + let input = ""; + let cstring = CString::from_str_truncate(input); + assert_eq!(cstring.to_str().unwrap(), ""); + } + + #[test] + fn test_from_str_truncate_null_only() { + let input = "\0"; + let cstring = CString::from_str_truncate(input); + assert_eq!(cstring.to_str().unwrap(), ""); + } + + #[test] + fn test_from_str_truncate_null_in_middle() { + let input = "Rust\0is awesome!"; + let cstring = CString::from_str_truncate(input); + assert_eq!(cstring.to_str().unwrap(), "Rust"); + } +} diff --git a/wslplugins-rs/src/lib.rs b/wslplugins-rs/src/lib.rs index 6dc3da0..5b63b3d 100644 --- a/wslplugins-rs/src/lib.rs +++ b/wslplugins-rs/src/lib.rs @@ -41,6 +41,7 @@ pub mod api; // Internal modules for managing specific WSL features. mod core_distribution_information; +pub(crate) mod cstring_ext; mod distribution_information; mod offline_distribution_information; mod utils; diff --git a/wslplugins-rs/src/utils.rs b/wslplugins-rs/src/utils.rs index 96ffa4a..feb4eff 100644 --- a/wslplugins-rs/src/utils.rs +++ b/wslplugins-rs/src/utils.rs @@ -1,66 +1,5 @@ -//! # String Encoding Utilities -//! -//! This module provides utility functions to handle string encoding conversions, specifically for: -//! - Encoding `OsStr` as wide, null-terminated UTF-16 strings. -//! - Creating `CString` instances from Rust strings, filtering out null bytes. - -use std::ffi::{CString, OsStr}; -use std::os::windows::ffi::OsStrExt; - -pub fn encode_wide_null_terminated(input: &OsStr) -> Vec { - input - .encode_wide() - .filter(|&c| c != 0) - .chain(Some(0)) - .collect() -} - -pub fn cstring_from_str(input: &str) -> CString { - let filtered_input: Vec = input.bytes().filter(|&c| c != 0).collect(); - unsafe { CString::from_vec_unchecked(filtered_input) } -} - #[cfg(test)] pub(crate) fn test_transparence() { assert_eq!(align_of::(), align_of::()); assert_eq!(size_of::(), size_of::()); } - -#[cfg(test)] -mod tests { - use super::*; - use std::ffi::OsString; - - /// Tests `encode_wide_null_terminated` with a string containing no null characters. - #[test] - fn test_encode_wide_null_terminated_no_nulls() { - let input = OsString::from("Hello"); - let expected: Vec = "Hello\0".encode_utf16().collect(); - assert_eq!(encode_wide_null_terminated(&input), expected); - } - - /// Tests `encode_wide_null_terminated` with a string containing null characters. - #[test] - fn test_encode_wide_null_terminated_with_nulls() { - let input = OsString::from("Hel\0lo"); - let expected: Vec = "Hello\0".encode_utf16().collect(); - assert_eq!(encode_wide_null_terminated(&input), expected); - } - - /// Tests `cstring_from_str` with a string containing no null characters. - #[test] - fn test_cstring_from_str_no_nulls() { - let input = "Hello"; - let cstring = cstring_from_str(input); - assert_eq!(cstring.to_str().unwrap(), input); - } - - /// Tests `cstring_from_str` with a string containing null characters. - #[test] - fn test_cstring_from_str_with_nulls() { - let input = "Hel\0lo"; - let cstring = cstring_from_str(input); - let expected = "Hello".as_bytes(); - assert_eq!(cstring.into_bytes(), expected); - } -}