Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 5 additions & 4 deletions wslplugins-rs/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,16 @@ version = "0.58"
features = ["Win32_Foundation", "Win32_System", "Win32_Networking_WinSock"]

[dependencies]
wslplugins-sys = { path = "../wslplugins-sys" }
typed-path = ">0.1"
bitflags = { version = ">0.1.0", optional = true }
flagset = { version = ">0.1.0", optional = true }
enumflags2 = { version = ">0.5", optional = true }
flagset = { version = ">0.1.0", optional = true }
log = { version = "*", optional = true }
log-instrument = { version = "*", optional = true }
wslplugins-macro = { path = "../wslplugins-macro", optional = true }
thiserror = "2.0.7"
typed-path = ">0.1"
widestring = { version = "1", features = ["alloc"] }
wslplugins-macro = { path = "../wslplugins-macro", optional = true }
wslplugins-sys = { path = "../wslplugins-sys" }

[dependencies.semver]
version = ">0.1"
Expand Down
33 changes: 19 additions & 14 deletions wslplugins-rs/src/api/api_v1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,20 @@ extern crate wslplugins_sys;
use super::Error;
use super::Result;
use crate::api::errors::require_update_error::Result as UpReqResult;
use crate::utils::{cstring_from_str, encode_wide_null_terminated};
use crate::cstring_ext::CstringExt;
use crate::wsl_session_information::WSLSessionInformation;
#[cfg(feature = "log-instrument")]
use log_instrument::instrument;
use std::ffi::{CString, OsStr, OsString};
use std::ffi::{CString, OsStr};
use std::fmt::Debug;
use std::iter::once;
use std::mem::MaybeUninit;
use std::net::TcpStream;
use std::os::windows::io::FromRawSocket;
use std::os::windows::raw::SOCKET;
use std::path::Path;
use std::str::FromStr;
use typed_path::Utf8UnixPath;
use widestring::U16CString;
use windows::Win32::Networking::WinSock::SOCKET as WinSocket;
use windows::{
core::{Result as WinResult, GUID, PCSTR, PCWSTR},
Expand Down Expand Up @@ -99,13 +99,10 @@ impl ApiV1 {
read_only: bool,
name: &OsStr,
) -> WinResult<()> {
let encoded_windows_path = encode_wide_null_terminated(windows_path.as_ref().as_os_str());
let encoded_linux_path = encode_wide_null_terminated(
OsString::from_str(linux_path.as_ref().as_str())
.unwrap()
.as_os_str(),
);
let encoded_name = encode_wide_null_terminated(name);
let encoded_windows_path =
U16CString::from_os_str_truncate(windows_path.as_ref().as_os_str());
let encoded_linux_path = U16CString::from_str_truncate(linux_path.as_ref().as_str());
let encoded_name = U16CString::from_os_str_truncate(name);
let result = unsafe {
self.0.MountFolder.unwrap_unchecked()(
session.id(),
Expand Down Expand Up @@ -163,7 +160,10 @@ impl ApiV1 {
.copied()
.chain(once(0))
.collect();
let c_args: Vec<CString> = args.iter().map(|&arg| cstring_from_str(arg)).collect();
let c_args: Vec<CString> = args
.iter()
.map(|&arg| CString::from_str_truncate(arg))
.collect();
let mut args_ptrs: Vec<PCSTR> = c_args
.iter()
.map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8))
Expand All @@ -188,8 +188,10 @@ impl ApiV1 {
/// Set the error message to display to the user if the VM or distribution creation fails.
#[cfg_attr(feature = "log-instrument", instrument)]
pub(crate) fn plugin_error(&self, error: &OsStr) -> WinResult<()> {
let error_vec = encode_wide_null_terminated(error);
unsafe { self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_vec.as_ptr())).ok() }
let error_utf16 = U16CString::from_os_str_truncate(error);
unsafe {
self.0.PluginError.unwrap_unchecked()(PCWSTR::from_raw(error_utf16.as_ptr())).ok()
}
}

/// Execute a program in a user distribution
Expand Down Expand Up @@ -242,7 +244,10 @@ impl ApiV1 {
.chain(once(0))
.collect();
let path_ptr = PCSTR::from_raw(c_path.as_ptr());
let c_args: Vec<CString> = args.iter().map(|&arg| cstring_from_str(arg)).collect();
let c_args: Vec<CString> = args
.iter()
.map(|&arg| CString::from_str_truncate(arg))
.collect();
let mut args_ptrs: Vec<PCSTR> = c_args
.iter()
.map(|arg| PCSTR::from_raw(arg.as_ptr() as *const u8))
Expand Down
59 changes: 59 additions & 0 deletions wslplugins-rs/src/cstring_ext.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
use std::ffi::CString;

pub(crate) trait CstringExt {
/// Creates a `CString` from a string slice, truncating at the first null byte if present.
fn from_str_truncate(value: &str) -> Self;
}

impl CstringExt for CString {
fn from_str_truncate(value: &str) -> Self {
let bytes = value.as_bytes();
let truncated_bytes = match bytes.iter().position(|&b| b == 0) {
Some(pos) => &bytes[..pos],
None => bytes,
};
// SAFETY: `truncated_bytes` is guaranteed not to contain null bytes.
unsafe { Self::from_vec_unchecked(truncated_bytes.to_vec()) }
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::ffi::CString;

#[test]
fn test_from_str_truncate_no_null() {
let input = "Hello, world!";
let cstring = CString::from_str_truncate(input);
assert_eq!(cstring.to_str().unwrap(), input);
}

#[test]
fn test_from_str_truncate_with_null() {
let input = "Hello\0world!";
let cstring = CString::from_str_truncate(input);
assert_eq!(cstring.to_str().unwrap(), "Hello");
}

#[test]
fn test_from_str_truncate_empty() {
let input = "";
let cstring = CString::from_str_truncate(input);
assert_eq!(cstring.to_str().unwrap(), "");
}

#[test]
fn test_from_str_truncate_null_only() {
let input = "\0";
let cstring = CString::from_str_truncate(input);
assert_eq!(cstring.to_str().unwrap(), "");
}

#[test]
fn test_from_str_truncate_null_in_middle() {
let input = "Rust\0is awesome!";
let cstring = CString::from_str_truncate(input);
assert_eq!(cstring.to_str().unwrap(), "Rust");
}
}
1 change: 1 addition & 0 deletions wslplugins-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ pub mod api;

// Internal modules for managing specific WSL features.
mod core_distribution_information;
pub(crate) mod cstring_ext;
mod distribution_information;
mod offline_distribution_information;
mod utils;
Expand Down
61 changes: 0 additions & 61 deletions wslplugins-rs/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,66 +1,5 @@
//! # String Encoding Utilities
//!
//! This module provides utility functions to handle string encoding conversions, specifically for:
//! - Encoding `OsStr` as wide, null-terminated UTF-16 strings.
//! - Creating `CString` instances from Rust strings, filtering out null bytes.

use std::ffi::{CString, OsStr};
use std::os::windows::ffi::OsStrExt;

pub fn encode_wide_null_terminated(input: &OsStr) -> Vec<u16> {
input
.encode_wide()
.filter(|&c| c != 0)
.chain(Some(0))
.collect()
}

pub fn cstring_from_str(input: &str) -> CString {
let filtered_input: Vec<u8> = input.bytes().filter(|&c| c != 0).collect();
unsafe { CString::from_vec_unchecked(filtered_input) }
}

#[cfg(test)]
pub(crate) fn test_transparence<T, U>() {
assert_eq!(align_of::<T>(), align_of::<U>());
assert_eq!(size_of::<T>(), size_of::<U>());
}

#[cfg(test)]
mod tests {
use super::*;
use std::ffi::OsString;

/// Tests `encode_wide_null_terminated` with a string containing no null characters.
#[test]
fn test_encode_wide_null_terminated_no_nulls() {
let input = OsString::from("Hello");
let expected: Vec<u16> = "Hello\0".encode_utf16().collect();
assert_eq!(encode_wide_null_terminated(&input), expected);
}

/// Tests `encode_wide_null_terminated` with a string containing null characters.
#[test]
fn test_encode_wide_null_terminated_with_nulls() {
let input = OsString::from("Hel\0lo");
let expected: Vec<u16> = "Hello\0".encode_utf16().collect();
assert_eq!(encode_wide_null_terminated(&input), expected);
}

/// Tests `cstring_from_str` with a string containing no null characters.
#[test]
fn test_cstring_from_str_no_nulls() {
let input = "Hello";
let cstring = cstring_from_str(input);
assert_eq!(cstring.to_str().unwrap(), input);
}

/// Tests `cstring_from_str` with a string containing null characters.
#[test]
fn test_cstring_from_str_with_nulls() {
let input = "Hel\0lo";
let cstring = cstring_from_str(input);
let expected = "Hello".as_bytes();
assert_eq!(cstring.into_bytes(), expected);
}
}