Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 162 additions & 3 deletions codex-rs/exec/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,79 @@ fn load_output_schema(path: Option<PathBuf>) -> Option<Value> {
}
}

#[derive(Debug, Clone, PartialEq, Eq)]
enum PromptDecodeError {
InvalidUtf8 { valid_up_to: usize },
InvalidUtf16 { encoding: &'static str },
UnsupportedBom { encoding: &'static str },
}

impl std::fmt::Display for PromptDecodeError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
PromptDecodeError::InvalidUtf8 { valid_up_to } => write!(
f,
"input is not valid UTF-8 (invalid byte at offset {valid_up_to}). Convert it to UTF-8 and retry (e.g., `iconv -f <ENC> -t UTF-8 prompt.txt`)."
),
PromptDecodeError::InvalidUtf16 { encoding } => write!(
f,
"input looked like {encoding} but could not be decoded. Convert it to UTF-8 and retry."
),
PromptDecodeError::UnsupportedBom { encoding } => write!(
f,
"input appears to be {encoding}. Convert it to UTF-8 and retry."
),
}
}
}

fn decode_prompt_bytes(input: &[u8]) -> Result<String, PromptDecodeError> {
let input = input.strip_prefix(&[0xEF, 0xBB, 0xBF]).unwrap_or(input);

if input.starts_with(&[0xFF, 0xFE, 0x00, 0x00]) {
return Err(PromptDecodeError::UnsupportedBom {
encoding: "UTF-32LE",
});
}

if input.starts_with(&[0x00, 0x00, 0xFE, 0xFF]) {
return Err(PromptDecodeError::UnsupportedBom {
encoding: "UTF-32BE",
});
}

if let Some(rest) = input.strip_prefix(&[0xFF, 0xFE]) {
return decode_utf16(rest, "UTF-16LE", u16::from_le_bytes);
}

if let Some(rest) = input.strip_prefix(&[0xFE, 0xFF]) {
return decode_utf16(rest, "UTF-16BE", u16::from_be_bytes);
}

std::str::from_utf8(input)
.map(str::to_string)
.map_err(|e| PromptDecodeError::InvalidUtf8 {
valid_up_to: e.valid_up_to(),
})
}

fn decode_utf16(
input: &[u8],
encoding: &'static str,
decode_unit: fn([u8; 2]) -> u16,
) -> Result<String, PromptDecodeError> {
if !input.len().is_multiple_of(2) {
return Err(PromptDecodeError::InvalidUtf16 { encoding });
}

let units: Vec<u16> = input
.chunks_exact(2)
.map(|chunk| decode_unit([chunk[0], chunk[1]]))
.collect();

String::from_utf16(&units).map_err(|_| PromptDecodeError::InvalidUtf16 { encoding })
}

fn resolve_prompt(prompt_arg: Option<String>) -> String {
match prompt_arg {
Some(p) if p != "-" => p,
Expand All @@ -571,11 +644,22 @@ fn resolve_prompt(prompt_arg: Option<String>) -> String {
if !force_stdin {
eprintln!("Reading prompt from stdin...");
}
let mut buffer = String::new();
if let Err(e) = std::io::stdin().read_to_string(&mut buffer) {

let mut bytes = Vec::new();
if let Err(e) = std::io::stdin().read_to_end(&mut bytes) {
eprintln!("Failed to read prompt from stdin: {e}");
std::process::exit(1);
} else if buffer.trim().is_empty() {
}

let buffer = match decode_prompt_bytes(&bytes) {
Ok(s) => s,
Err(e) => {
eprintln!("Failed to read prompt from stdin: {e}");
std::process::exit(1);
}
};

if buffer.trim().is_empty() {
eprintln!("No prompt provided via stdin.");
std::process::exit(1);
}
Expand Down Expand Up @@ -680,4 +764,79 @@ mod tests {

assert_eq!(request, expected);
}

#[test]
fn decode_prompt_bytes_strips_utf8_bom() {
let input = [0xEF, 0xBB, 0xBF, b'h', b'i', b'\n'];

let out = decode_prompt_bytes(&input).expect("decode utf-8 with BOM");

assert_eq!(out, "hi\n");
}

#[test]
fn decode_prompt_bytes_decodes_utf16le_bom() {
// UTF-16LE BOM + "hi\n"
let input = [0xFF, 0xFE, b'h', 0x00, b'i', 0x00, b'\n', 0x00];

let out = decode_prompt_bytes(&input).expect("decode utf-16le with BOM");

assert_eq!(out, "hi\n");
}

#[test]
fn decode_prompt_bytes_decodes_utf16be_bom() {
// UTF-16BE BOM + "hi\n"
let input = [0xFE, 0xFF, 0x00, b'h', 0x00, b'i', 0x00, b'\n'];

let out = decode_prompt_bytes(&input).expect("decode utf-16be with BOM");

assert_eq!(out, "hi\n");
}

#[test]
fn decode_prompt_bytes_rejects_utf32le_bom() {
// UTF-32LE BOM + "hi\n"
let input = [
0xFF, 0xFE, 0x00, 0x00, b'h', 0x00, 0x00, 0x00, b'i', 0x00, 0x00, 0x00, b'\n', 0x00,
0x00, 0x00,
];

let err = decode_prompt_bytes(&input).expect_err("utf-32le should be rejected");

assert_eq!(
err,
PromptDecodeError::UnsupportedBom {
encoding: "UTF-32LE"
}
);
}

#[test]
fn decode_prompt_bytes_rejects_utf32be_bom() {
// UTF-32BE BOM + "hi\n"
let input = [
0x00, 0x00, 0xFE, 0xFF, 0x00, 0x00, 0x00, b'h', 0x00, 0x00, 0x00, b'i', 0x00, 0x00,
0x00, b'\n',
];

let err = decode_prompt_bytes(&input).expect_err("utf-32be should be rejected");

assert_eq!(
err,
PromptDecodeError::UnsupportedBom {
encoding: "UTF-32BE"
}
);
}

#[test]
fn decode_prompt_bytes_rejects_invalid_utf8() {
// Invalid UTF-8 sequence: 0xC3 0x28
let input = [0xC3, 0x28];

let err = decode_prompt_bytes(&input).expect_err("invalid utf-8 should fail");

assert_eq!(err, PromptDecodeError::InvalidUtf8 { valid_up_to: 0 });
}
}
Loading