From f9513f7394f65ed2683ec53b448b68a1ae751cab Mon Sep 17 00:00:00 2001 From: Max Tropets Date: Fri, 13 Mar 2026 15:21:33 +0000 Subject: [PATCH] RSA keys support --- native/rust/cose_openssl/src/cose.rs | 211 ++++++++++++++++-- native/rust/cose_openssl/src/lib.rs | 2 +- native/rust/cose_openssl/src/ossl_wrappers.rs | 136 ++++++++++- native/rust/cose_openssl/src/sign.rs | 21 +- native/rust/cose_openssl/src/verify.rs | 21 +- 5 files changed, 366 insertions(+), 25 deletions(-) diff --git a/native/rust/cose_openssl/src/cose.rs b/native/rust/cose_openssl/src/cose.rs index 1bf00cd6..bfcc4b91 100644 --- a/native/rust/cose_openssl/src/cose.rs +++ b/native/rust/cose_openssl/src/cose.rs @@ -1,6 +1,7 @@ use crate::cbor::CborValue; use crate::ossl_wrappers::{ - EvpKey, KeyType, WhichEC, ecdsa_der_to_fixed, ecdsa_fixed_to_der, + EvpKey, KeyType, WhichEC, WhichRSA, ecdsa_der_to_fixed, ecdsa_fixed_to_der, + rsa_pss_md_for_cose_alg, }; #[cfg(feature = "pqc")] @@ -18,6 +19,9 @@ fn cose_alg(key: &EvpKey) -> Result { KeyType::EC(WhichEC::P256) => Ok(-7), KeyType::EC(WhichEC::P384) => Ok(-35), KeyType::EC(WhichEC::P521) => Ok(-36), + KeyType::RSA(WhichRSA::PS256) => Ok(-37), + KeyType::RSA(WhichRSA::PS384) => Ok(-38), + KeyType::RSA(WhichRSA::PS512) => Ok(-39), #[cfg(feature = "pqc")] KeyType::MLDSA(which) => match which { WhichMLDSA::P44 => Ok(-48), @@ -120,6 +124,7 @@ pub fn cose_sign1( let sig = match &key.typ { KeyType::EC(_) => ecdsa_der_to_fixed(&sig, key.ec_field_size()?)?, + KeyType::RSA(_) => sig, #[cfg(feature = "pqc")] KeyType::MLDSA(_) => sig, }; @@ -156,24 +161,40 @@ pub fn cose_sign1_encoded( cose_sign1(key, phdr_value, uhdr_value, payload, detached) } -/// Check that the algorithm encoded in the phdr matches the key type. -fn check_phdr_alg(key: &EvpKey, phdr_bytes: &[u8]) -> Result<(), String> { +/// Check that the algorithm encoded in the phdr is compatible with the key. +/// For RSA keys, any PS* algorithm is accepted (returns the alg value). +/// For other keys, exact match is required. +fn check_phdr_alg(key: &EvpKey, phdr_bytes: &[u8]) -> Result { let parsed = CborValue::from_bytes(phdr_bytes)?; let alg = parsed .map_at_int(COSE_HEADER_ALG) .map_err(|_| "Algorithm not found in protected header".to_string())?; - let expected = cose_alg(key)?; - match alg { - CborValue::Int(v) if *v == expected => Ok(()), - CborValue::Int(_) => { - Err("Algorithm mismatch between protected header and key" - .to_string()) + let alg_val = match alg { + CborValue::Int(v) => *v, + _ => { + return Err( + "Algorithm value in protected header is not an integer" + .to_string(), + ); + } + }; + + match &key.typ { + KeyType::RSA(_) => { + // Accept any PS* algorithm with any RSA key. + rsa_pss_md_for_cose_alg(alg_val)?; + Ok(alg_val) } _ => { - Err("Algorithm value in protected header is not an integer" - .to_string()) + let expected = cose_alg(key)?; + if alg_val == expected { + Ok(alg_val) + } else { + Err("Algorithm mismatch between protected header and key" + .to_string()) + } } } } @@ -187,7 +208,7 @@ pub fn cose_verify1( ) -> Result { let (phdr_bytes, cose_payload, cose_sig) = parse_cose_sign1(envelope)?; - check_phdr_alg(key, &phdr_bytes)?; + let header_alg = check_phdr_alg(key, &phdr_bytes)?; let actual_payload = match payload { Some(p) => p.to_vec(), @@ -206,12 +227,20 @@ pub fn cose_verify1( let sig = match &key.typ { KeyType::EC(_) => ecdsa_fixed_to_der(&sig, key.ec_field_size()?)?, + KeyType::RSA(_) => sig, #[cfg(feature = "pqc")] KeyType::MLDSA(_) => sig, }; let tbs = sig_structure(&phdr_bytes, &actual_payload)?; - crate::verify::verify(key, &sig, &tbs) + + match &key.typ { + KeyType::RSA(_) => { + let md = rsa_pss_md_for_cose_alg(header_alg)?; + crate::verify::verify_with_md(key, &sig, &tbs, md) + } + _ => crate::verify::verify(key, &sig, &tbs), + } } /// Verify a COSE_Sign1 from pre-parsed components, skipping all CBOR @@ -225,19 +254,37 @@ pub fn cose_verify1_decoded( payload: &[u8], sig: &[u8], ) -> Result { - let expected_alg = cose_alg(key)?; - if alg != expected_alg { - return Err("Algorithm mismatch between supplied alg and key".into()); + match &key.typ { + KeyType::RSA(_) => { + // For RSA, accept any PS* algorithm regardless of key size. + rsa_pss_md_for_cose_alg(alg)?; + } + _ => { + let expected_alg = cose_alg(key)?; + if alg != expected_alg { + return Err( + "Algorithm mismatch between supplied alg and key".into() + ); + } + } } let sig = match &key.typ { KeyType::EC(_) => ecdsa_fixed_to_der(sig, key.ec_field_size()?)?, + KeyType::RSA(_) => sig.to_vec(), #[cfg(feature = "pqc")] KeyType::MLDSA(_) => sig.to_vec(), }; let tbs = sig_structure(phdr, payload)?; - crate::verify::verify(key, &sig, &tbs) + + match &key.typ { + KeyType::RSA(_) => { + let md = rsa_pss_md_for_cose_alg(alg)?; + crate::verify::verify_with_md(key, &sig, &tbs, md) + } + _ => crate::verify::verify(key, &sig, &tbs), + } } #[cfg(test)] @@ -437,6 +484,136 @@ mod tests { assert!(cose_verify1(&verification_key, &envelope, None).unwrap()); } + fn sign_verify_cose_rsa(key_type: KeyType) { + let key = EvpKey::new(key_type).unwrap(); + let phdr_bytes = hex_decode(TEST_PHDR); + let phdr = CborValue::from_bytes(&phdr_bytes).unwrap(); + let uhdr = CborValue::Map(vec![]); + let payload = b"RSA PSS test"; + + let envelope = cose_sign1(&key, phdr, uhdr, payload, false).unwrap(); + assert!(cose_verify1(&key, &envelope, None).unwrap()); + } + + #[test] + fn cose_rsa_ps256() { + sign_verify_cose_rsa(KeyType::RSA(WhichRSA::PS256)); + } + + #[test] + fn cose_rsa_ps384() { + sign_verify_cose_rsa(KeyType::RSA(WhichRSA::PS384)); + } + + #[test] + fn cose_rsa_ps512() { + sign_verify_cose_rsa(KeyType::RSA(WhichRSA::PS512)); + } + + #[test] + fn cose_rsa_with_der_imported_key() { + let original_key = EvpKey::new(KeyType::RSA(WhichRSA::PS256)).unwrap(); + + let priv_der = original_key.to_der_private().unwrap(); + let signing_key = EvpKey::from_der_private(&priv_der).unwrap(); + + let pub_der = original_key.to_der_public().unwrap(); + let verification_key = EvpKey::from_der_public(&pub_der).unwrap(); + + let phdr_bytes = hex_decode(TEST_PHDR); + let phdr = CborValue::from_bytes(&phdr_bytes).unwrap(); + let uhdr = CborValue::Map(vec![]); + let payload = b"RSA with DER-imported key"; + + let envelope = + cose_sign1(&signing_key, phdr, uhdr, payload, false).unwrap(); + assert!(cose_verify1(&verification_key, &envelope, None).unwrap()); + } + + #[test] + fn cose_rsa_detached_payload() { + let key = EvpKey::new(KeyType::RSA(WhichRSA::PS384)).unwrap(); + let phdr_bytes = hex_decode(TEST_PHDR); + let phdr = CborValue::from_bytes(&phdr_bytes).unwrap(); + let uhdr = CborValue::Map(vec![]); + let payload = b"RSA detached"; + + let envelope = cose_sign1(&key, phdr, uhdr, payload, true).unwrap(); + assert!(cose_verify1(&key, &envelope, Some(payload)).unwrap()); + assert!(cose_verify1(&key, &envelope, None).is_err()); + } + + #[test] + fn cose_verify1_decoded_rsa() { + let key = EvpKey::new(KeyType::RSA(WhichRSA::PS256)).unwrap(); + let phdr_bytes = hex_decode(TEST_PHDR); + let phdr = CborValue::from_bytes(&phdr_bytes).unwrap(); + let uhdr = CborValue::Map(vec![]); + let payload = b"RSA decoded verify"; + + let envelope = cose_sign1(&key, phdr, uhdr, payload, false).unwrap(); + + let (phdr_raw, cose_payload, cose_sig) = + parse_cose_sign1(&envelope).unwrap(); + let sig = match cose_sig { + CborValue::ByteString(b) => b, + _ => panic!("sig not bstr"), + }; + let embedded_payload = match cose_payload { + CborValue::ByteString(b) => b, + _ => panic!("payload not bstr"), + }; + + let alg = cose_alg(&key).unwrap(); + assert!( + cose_verify1_decoded(&key, alg, &phdr_raw, &embedded_payload, &sig) + .unwrap() + ); + } + + /// Sign with a PS256 key (2048-bit RSA) but use SHA-384 (PS384 + /// algorithm). Verify must succeed because the header's algorithm + /// drives the digest, not the key's WhichRSA variant. + #[test] + fn cose_rsa_ps256_key_with_sha384() { + use crate::ossl_wrappers::rsa_pss_md_for_cose_alg; + + let key = EvpKey::new(KeyType::RSA(WhichRSA::PS256)).unwrap(); + let payload = b"PS256 key, SHA-384 digest"; + + // Build phdr with alg = -38 (PS384) already set. + let phdr_bytes = hex_decode(TEST_PHDR); + let mut phdr = CborValue::from_bytes(&phdr_bytes).unwrap(); + if let CborValue::Map(ref mut entries) = phdr { + entries.insert( + 0, + (CborValue::Int(COSE_HEADER_ALG), CborValue::Int(-38)), + ); + } + let phdr_ser = phdr.to_bytes().unwrap(); + + // Build TBS and sign with SHA-384. + let tbs = sig_structure(&phdr_ser, payload).unwrap(); + let md = rsa_pss_md_for_cose_alg(-38).unwrap(); + let sig = crate::sign::sign_with_md(&key, &tbs, md).unwrap(); + + // Assemble the COSE_Sign1 envelope. + let envelope = CborValue::Tagged { + tag: COSE_SIGN1_TAG, + payload: Box::new(CborValue::Array(vec![ + CborValue::ByteString(phdr_ser), + CborValue::Map(vec![]), + CborValue::ByteString(payload.to_vec()), + CborValue::ByteString(sig), + ])), + } + .to_bytes() + .unwrap(); + + // Verify — header says PS384 so SHA-384 is used. + assert!(cose_verify1(&key, &envelope, None).unwrap()); + } + #[cfg(feature = "pqc")] mod pqc_tests { use super::*; diff --git a/native/rust/cose_openssl/src/lib.rs b/native/rust/cose_openssl/src/lib.rs index 7dba5d79..740a87bd 100644 --- a/native/rust/cose_openssl/src/lib.rs +++ b/native/rust/cose_openssl/src/lib.rs @@ -8,7 +8,7 @@ pub use cbor::CborValue; pub use cose::{ cose_sign1, cose_sign1_encoded, cose_verify1, cose_verify1_decoded, }; -pub use ossl_wrappers::{EvpKey, KeyType, WhichEC}; +pub use ossl_wrappers::{EvpKey, KeyType, WhichEC, WhichRSA}; #[cfg(feature = "pqc")] pub use ossl_wrappers::WhichMLDSA; diff --git a/native/rust/cose_openssl/src/ossl_wrappers.rs b/native/rust/cose_openssl/src/ossl_wrappers.rs index 419df464..846b51d0 100644 --- a/native/rust/cose_openssl/src/ossl_wrappers.rs +++ b/native/rust/cose_openssl/src/ossl_wrappers.rs @@ -37,6 +37,23 @@ impl WhichMLDSA { } } +#[derive(Debug)] +pub enum WhichRSA { + PS256, + PS384, + PS512, +} + +impl WhichRSA { + fn key_bits(&self) -> u32 { + match self { + WhichRSA::PS256 => 2048, + WhichRSA::PS384 => 3072, + WhichRSA::PS512 => 4096, + } + } +} + #[derive(Debug)] pub enum WhichEC { P256, @@ -65,6 +82,7 @@ impl WhichEC { #[derive(Debug)] pub enum KeyType { EC(WhichEC), + RSA(WhichRSA), #[cfg(feature = "pqc")] MLDSA(WhichMLDSA), @@ -91,6 +109,16 @@ impl EvpKey { ) } + KeyType::RSA(which) => { + let alg = CString::new("RSA").unwrap(); + ossl::EVP_PKEY_Q_keygen( + ptr::null_mut(), + ptr::null_mut(), + alg.as_ptr(), + which.key_bits() as std::ffi::c_uint, + ) + } + #[cfg(feature = "pqc")] KeyType::MLDSA(which) => { let alg = CString::new(which.openssl_str()).unwrap(); @@ -170,6 +198,17 @@ impl EvpKey { pkey: *mut ossl::EVP_PKEY, ) -> Result { unsafe { + let rsa = CString::new("RSA").unwrap(); + if EVP_PKEY_is_a(pkey as *const _, rsa.as_ptr()) == 1 { + let bits = ossl::EVP_PKEY_bits(pkey); + let which = match bits { + ..=2048 => WhichRSA::PS256, + 2049..=3072 => WhichRSA::PS384, + _ => WhichRSA::PS512, + }; + return Ok(KeyType::RSA(which)); + } + let ec = CString::new("EC").unwrap(); if EVP_PKEY_is_a(pkey as *const _, ec.as_ptr()) == 1 { let mut buf = [0u8; 64]; @@ -281,6 +320,9 @@ impl EvpKey { KeyType::EC(WhichEC::P256) => ossl::EVP_sha256(), KeyType::EC(WhichEC::P384) => ossl::EVP_sha384(), KeyType::EC(WhichEC::P521) => ossl::EVP_sha512(), + KeyType::RSA(WhichRSA::PS256) => ossl::EVP_sha256(), + KeyType::RSA(WhichRSA::PS384) => ossl::EVP_sha384(), + KeyType::RSA(WhichRSA::PS512) => ossl::EVP_sha512(), #[cfg(feature = "pqc")] KeyType::MLDSA(_) => ptr::null(), } @@ -432,6 +474,7 @@ pub trait ContextInit { ctx: *mut ossl::EVP_MD_CTX, md: *const ossl::EVP_MD, key: *mut ossl::EVP_PKEY, + pctx_out: *mut *mut ossl::EVP_PKEY_CTX, ) -> Result<(), i32>; fn purpose() -> &'static str; } @@ -441,11 +484,12 @@ impl ContextInit for SignOp { ctx: *mut ossl::EVP_MD_CTX, md: *const ossl::EVP_MD, key: *mut ossl::EVP_PKEY, + pctx_out: *mut *mut ossl::EVP_PKEY_CTX, ) -> Result<(), i32> { unsafe { let rc = ossl::EVP_DigestSignInit( ctx, - ptr::null_mut(), + pctx_out, md, ptr::null_mut(), key, @@ -466,11 +510,12 @@ impl ContextInit for VerifyOp { ctx: *mut ossl::EVP_MD_CTX, md: *const ossl::EVP_MD, key: *mut ossl::EVP_PKEY, + pctx_out: *mut *mut ossl::EVP_PKEY_CTX, ) -> Result<(), i32> { unsafe { let rc = ossl::EVP_DigestVerifyInit( ctx, - ptr::null_mut(), + pctx_out, md, ptr::null_mut(), key, @@ -488,6 +533,15 @@ impl ContextInit for VerifyOp { impl EvpMdContext { pub fn new(key: &EvpKey) -> Result { + Self::new_with_md(key, key.digest()) + } + + /// Create a context with an explicit digest, allowing the caller + /// to override the digest that `key.digest()` would return. + pub fn new_with_md( + key: &EvpKey, + md: *const ossl::EVP_MD, + ) -> Result { unsafe { let ctx = ossl::EVP_MD_CTX_new(); if ctx.is_null() { @@ -496,7 +550,8 @@ impl EvpMdContext { T::purpose() )); } - if let Err(err) = T::init(ctx, key.digest(), key.key) { + let mut pctx: *mut ossl::EVP_PKEY_CTX = ptr::null_mut(); + if let Err(err) = T::init(ctx, md, key.key, &mut pctx) { ossl::EVP_MD_CTX_free(ctx); return Err(format!( "Failed to init context for {} with err {}", @@ -504,6 +559,26 @@ impl EvpMdContext { err )); } + // For RSA keys, configure PSS padding. + if matches!(key.typ, KeyType::RSA(_)) && !pctx.is_null() { + const RSA_PSS_SALTLEN_DIGEST: std::ffi::c_int = -1; + if ossl::EVP_PKEY_CTX_set_rsa_padding( + pctx, + ossl::RSA_PKCS1_PSS_PADDING, + ) != 1 + { + ossl::EVP_MD_CTX_free(ctx); + return Err("Failed to set RSA PSS padding".into()); + } + if ossl::EVP_PKEY_CTX_set_rsa_pss_saltlen( + pctx, + RSA_PSS_SALTLEN_DIGEST, + ) != 1 + { + ossl::EVP_MD_CTX_free(ctx); + return Err("Failed to set RSA PSS salt length".into()); + } + } Ok(EvpMdContext { op: PhantomData, ctx, @@ -512,6 +587,20 @@ impl EvpMdContext { } } +/// Return the OpenSSL digest for the given COSE RSA-PSS algorithm ID. +pub fn rsa_pss_md_for_cose_alg( + alg: i64, +) -> Result<*const ossl::EVP_MD, String> { + unsafe { + match alg { + -37 => Ok(ossl::EVP_sha256()), + -38 => Ok(ossl::EVP_sha384()), + -39 => Ok(ossl::EVP_sha512()), + _ => Err(format!("{alg} is not a COSE RSA-PSS algorithm")), + } + } +} + impl Drop for EvpMdContext { fn drop(&mut self) { unsafe { @@ -540,6 +629,47 @@ mod tests { assert!(EvpKey::new(KeyType::EC(WhichEC::P521)).is_ok()); } + #[test] + fn create_rsa_keys() { + assert!(EvpKey::new(KeyType::RSA(WhichRSA::PS256)).is_ok()); + assert!(EvpKey::new(KeyType::RSA(WhichRSA::PS384)).is_ok()); + assert!(EvpKey::new(KeyType::RSA(WhichRSA::PS512)).is_ok()); + } + + #[test] + fn rsa_key_der_roundtrip() { + for which in [WhichRSA::PS256, WhichRSA::PS384, WhichRSA::PS512] { + let key = EvpKey::new(KeyType::RSA(which)).unwrap(); + let der = key.to_der_public().unwrap(); + let imported = EvpKey::from_der_public(&der).unwrap(); + assert!( + matches!(imported.typ, KeyType::RSA(_)), + "Expected RSA key type" + ); + let der2 = imported.to_der_public().unwrap(); + assert_eq!(der, der2); + } + } + + #[test] + fn rsa_key_private_der_roundtrip() { + for which in [WhichRSA::PS256, WhichRSA::PS384, WhichRSA::PS512] { + let key = EvpKey::new(KeyType::RSA(which)).unwrap(); + let priv_der = key.to_der_private().unwrap(); + let imported = EvpKey::from_der_private(&priv_der).unwrap(); + assert!( + matches!(imported.typ, KeyType::RSA(_)), + "Expected RSA key type" + ); + let priv_der2 = imported.to_der_private().unwrap(); + assert_eq!(priv_der, priv_der2); + + let pub1 = key.to_der_public().unwrap(); + let pub2 = imported.to_der_public().unwrap(); + assert_eq!(pub1, pub2); + } + } + #[test] fn ec_key_from_der_roundtrip() { for which in [WhichEC::P256, WhichEC::P384, WhichEC::P521] { diff --git a/native/rust/cose_openssl/src/sign.rs b/native/rust/cose_openssl/src/sign.rs index 214a1951..69b7ce9f 100644 --- a/native/rust/cose_openssl/src/sign.rs +++ b/native/rust/cose_openssl/src/sign.rs @@ -4,9 +4,26 @@ use openssl_sys as ossl; use std::ptr; pub fn sign(key: &EvpKey, msg: &[u8]) -> Result, String> { - unsafe { - let ctx = EvpMdContext::::new(key)?; + let ctx = EvpMdContext::::new(key)?; + sign_with_ctx(&ctx, msg) +} +// Only used in tests to sign with an explicit digest that differs from the key's default. +#[cfg(test)] +pub fn sign_with_md( + key: &EvpKey, + msg: &[u8], + md: *const ossl::EVP_MD, +) -> Result, String> { + let ctx = EvpMdContext::::new_with_md(key, md)?; + sign_with_ctx(&ctx, msg) +} + +fn sign_with_ctx( + ctx: &EvpMdContext, + msg: &[u8], +) -> Result, String> { + unsafe { let mut sig_size: usize = 0; let res = ossl::EVP_DigestSign( ctx.ctx, diff --git a/native/rust/cose_openssl/src/verify.rs b/native/rust/cose_openssl/src/verify.rs index f89a352e..5cdafcdb 100644 --- a/native/rust/cose_openssl/src/verify.rs +++ b/native/rust/cose_openssl/src/verify.rs @@ -3,9 +3,26 @@ use crate::ossl_wrappers::{EvpKey, EvpMdContext, VerifyOp}; use openssl_sys as ossl; pub fn verify(key: &EvpKey, sig: &[u8], msg: &[u8]) -> Result { - unsafe { - let ctx = EvpMdContext::::new(key)?; + let ctx = EvpMdContext::::new(key)?; + verify_with_ctx(&ctx, sig, msg) +} +pub fn verify_with_md( + key: &EvpKey, + sig: &[u8], + msg: &[u8], + md: *const ossl::EVP_MD, +) -> Result { + let ctx = EvpMdContext::::new_with_md(key, md)?; + verify_with_ctx(&ctx, sig, msg) +} + +fn verify_with_ctx( + ctx: &EvpMdContext, + sig: &[u8], + msg: &[u8], +) -> Result { + unsafe { let res = ossl::EVP_DigestVerify( ctx.ctx, sig.as_ptr(),