From 9a5a77f77dba062ec9775f4c6021b651dcce7ee9 Mon Sep 17 00:00:00 2001 From: Willem Van Lint Date: Fri, 2 Feb 2024 17:02:49 -0800 Subject: [PATCH] JWT authorization header based on LNURL Auth --- Cargo.toml | 9 + src/client.rs | 2 +- src/headers/lnurl_auth_jwt.rs | 304 ++++++++++++++++++++++++++++++++++ src/headers/mod.rs | 32 +++- tests/lnurl_auth_jwt_tests.rs | 100 +++++++++++ 5 files changed, 445 insertions(+), 2 deletions(-) create mode 100644 src/headers/lnurl_auth_jwt.rs create mode 100644 tests/lnurl_auth_jwt_tests.rs diff --git a/Cargo.toml b/Cargo.toml index 0d16e40..cffdc83 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,12 +11,21 @@ categories = ["web-programming::http-client", "cryptography::cryptocurrencies"] build = "build.rs" +[features] +default = ["lnurl-auth"] +lnurl-auth = ["dep:bitcoin", "dep:url", "dep:base64", "dep:serde", "dep:serde_json", "reqwest/json"] + [dependencies] prost = "0.11.6" reqwest = { version = "0.11.13", default-features = false, features = ["rustls-tls"] } tokio = { version = "1", default-features = false, features = ["time"] } rand = "0.8.5" async-trait = "0.1.77" +bitcoin = { version = "0.32.2", default-features = false, features = ["std", "rand-std"], optional = true } +url = { version = "2.5.0", default-features = false, optional = true } +base64 = { version = "0.21.7", default-features = false, optional = true } +serde = { version = "1.0.196", default-features = false, features = ["serde_derive"], optional = true } +serde_json = { version = "1.0.113", default-features = false, optional = true } [target.'cfg(genproto)'.build-dependencies] prost-build = { version = "0.11.3" } diff --git a/src/client.rs b/src/client.rs index ffe8b63..2428bb3 100644 --- a/src/client.rs +++ b/src/client.rs @@ -126,7 +126,7 @@ impl> VssClient { .header_provider .get_headers(&request_body) .await - .and_then(get_headermap) + .and_then(|h| get_headermap(&h)) .map_err(|e| VssError::AuthError(e.to_string()))?; let response_raw = self .client diff --git a/src/headers/lnurl_auth_jwt.rs b/src/headers/lnurl_auth_jwt.rs new file mode 100644 index 0000000..ca0415b --- /dev/null +++ b/src/headers/lnurl_auth_jwt.rs @@ -0,0 +1,304 @@ +use crate::headers::{get_headermap, VssHeaderProvider, VssHeaderProviderError}; +use async_trait::async_trait; +use base64::engine::general_purpose::URL_SAFE_NO_PAD; +use base64::Engine; +use bitcoin::bip32::{ChildNumber, DerivationPath, Xpriv}; +use bitcoin::hashes::hex::FromHex; +use bitcoin::hashes::sha256; +use bitcoin::hashes::{Hash, HashEngine, Hmac, HmacEngine}; +use bitcoin::secp256k1::{Message, Secp256k1, SignOnly}; +use bitcoin::Network; +use bitcoin::PrivateKey; +use serde::Deserialize; +use std::collections::HashMap; +use std::sync::RwLock; +use std::time::{Duration, SystemTime}; +use url::Url; + +// Derivation index of the parent extended private key as defined by LUD-05. +const PARENT_DERIVATION_INDEX: u32 = 138; +// Derivation index of the hashing private key as defined by LUD-05. +const HASHING_DERIVATION_INDEX: u32 = 0; +// The JWT token will be refreshed by the given amount before its expiry. +const EXPIRY_BUFFER: Duration = Duration::from_secs(60); +// The key of the LNURL k1 query parameter. +const K1_QUERY_PARAM: &str = "k1"; +// The key of the LNURL sig query parameter. +const SIG_QUERY_PARAM: &str = "sig"; +// The key of the LNURL key query parameter. +const KEY_QUERY_PARAM: &str = "key"; +// The authorization header name. +const AUTHORIZATION: &str = "Authorization"; + +#[derive(Debug, Clone)] +struct JwtToken { + token_str: String, + expiry: Option, +} + +impl JwtToken { + fn is_expired(&self) -> bool { + self.expiry + .and_then(|expiry| { + SystemTime::now() + .checked_add(EXPIRY_BUFFER) + .map(|now_with_buffer| now_with_buffer > expiry) + }) + .unwrap_or(false) + } +} + +/// Provides a JWT token based on LNURL Auth. +pub struct LnurlAuthToJwtProvider { + engine: Secp256k1, + parent_key: Xpriv, + url: String, + default_headers: HashMap, + client: reqwest::Client, + cached_jwt_token: RwLock>, +} + +impl LnurlAuthToJwtProvider { + /// Creates a new JWT provider based on LNURL Auth. + /// + /// The LNURL Auth keys are derived from a seed according to LUD-05. + /// The user is free to choose a consistent seed, such as a hardened derivation from the wallet + /// master key or otherwise for compatibility reasons. + /// The LNURL with the challenge will be retrieved by making a request to the given URL. + /// The JWT token will be returned in response to the signed LNURL request under a token field. + /// The given set of headers will be used for LNURL requests, and will also be returned together + /// with the JWT authorization header for VSS requests. + pub fn new( + seed: &[u8], url: String, default_headers: HashMap, + ) -> Result { + let engine = Secp256k1::signing_only(); + let master = Xpriv::new_master(Network::Testnet, seed).map_err(VssHeaderProviderError::from)?; + let child_number = + ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?; + let parent_key = master + .derive_priv(&engine, &vec![child_number]) + .map_err(VssHeaderProviderError::from)?; + let default_headermap = get_headermap(&default_headers)?; + let client = reqwest::Client::builder() + .default_headers(default_headermap) + .build() + .map_err(VssHeaderProviderError::from)?; + + Ok(LnurlAuthToJwtProvider { + engine, + parent_key, + url, + default_headers, + client, + cached_jwt_token: RwLock::new(None), + }) + } + + async fn fetch_jwt_token(&self) -> Result { + // Fetch the LNURL. + let lnurl_str = self + .client + .get(&self.url) + .send() + .await + .map_err(VssHeaderProviderError::from)? + .text() + .await + .map_err(VssHeaderProviderError::from)?; + + // Sign the LNURL and perform the request. + let signed_lnurl = sign_lnurl(&self.engine, &self.parent_key, &lnurl_str)?; + let lnurl_auth_response: LnurlAuthResponse = self + .client + .get(&signed_lnurl) + .send() + .await + .map_err(VssHeaderProviderError::from)? + .json() + .await + .map_err(VssHeaderProviderError::from)?; + + let untrusted_token = match lnurl_auth_response { + LnurlAuthResponse { token: Some(token), .. } => token, + LnurlAuthResponse { reason: Some(reason), .. } => { + return Err(VssHeaderProviderError::AuthorizationError { + error: format!("LNURL Auth failed, reason is: {}", reason.escape_debug()), + }); + } + _ => { + return Err(VssHeaderProviderError::InvalidData { + error: "LNURL Auth response did not contain a token nor an error".to_string(), + }); + } + }; + parse_jwt_token(untrusted_token) + } + + async fn get_jwt_token(&self, force_refresh: bool) -> Result { + let cached_token_str = if force_refresh { + None + } else { + let jwt_token = self.cached_jwt_token.read().unwrap(); + jwt_token.as_ref().filter(|t| !t.is_expired()).map(|t| t.token_str.clone()) + }; + if let Some(token_str) = cached_token_str { + Ok(token_str) + } else { + let jwt_token = self.fetch_jwt_token().await?; + *self.cached_jwt_token.write().unwrap() = Some(jwt_token.clone()); + Ok(jwt_token.token_str) + } + } +} + +#[async_trait] +impl VssHeaderProvider for LnurlAuthToJwtProvider { + async fn get_headers(&self, _request: &[u8]) -> Result, VssHeaderProviderError> { + let jwt_token = self.get_jwt_token(false).await?; + let mut headers = self.default_headers.clone(); + headers.insert(AUTHORIZATION.to_string(), format!("Bearer {}", jwt_token)); + Ok(headers) + } +} + +fn hashing_key(engine: &Secp256k1, parent_key: &Xpriv) -> Result { + let hashing_child_number = + ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?; + parent_key + .derive_priv(engine, &vec![hashing_child_number]) + .map(|xpriv| xpriv.to_priv()) + .map_err(VssHeaderProviderError::from) +} + +fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result { + let mut engine = HmacEngine::::new(&hashing_key.inner[..]); + engine.input(domain_name.as_bytes()); + let result = Hmac::::from_engine(engine).to_byte_array(); + // unwrap safety: We take 4-byte chunks, so TryInto for [u8; 4] never fails. + let children = result + .chunks_exact(4) + .take(4) + .map(|i| u32::from_be_bytes(i.try_into().unwrap())) + .map(ChildNumber::from); + Ok(DerivationPath::from_iter(children)) +} + +fn sign_lnurl( + engine: &Secp256k1, parent_key: &Xpriv, lnurl_str: &str, +) -> Result { + // Parse k1 parameter to sign. + let invalid_lnurl = + || VssHeaderProviderError::InvalidData { error: format!("invalid lnurl: {}", lnurl_str.escape_debug()) }; + let mut lnurl = Url::parse(lnurl_str).map_err(|_| invalid_lnurl())?; + let domain = lnurl.domain().ok_or(invalid_lnurl())?; + let k1_str = lnurl + .query_pairs() + .find(|(k, _)| k == K1_QUERY_PARAM) + .ok_or(invalid_lnurl())? + .1 + .to_string(); + let k1: [u8; 32] = FromHex::from_hex(&k1_str).map_err(|_| invalid_lnurl())?; + + // Sign k1 parameter with linking private key. + let hashing_private_key = hashing_key(engine, parent_key)?; + let linking_key_path = linking_key_path(&hashing_private_key, domain)?; + let linking_private_key = parent_key + .derive_priv(engine, &linking_key_path) + .map_err(VssHeaderProviderError::from)? + .to_priv(); + let linking_public_key = linking_private_key.public_key(engine); + let message = Message::from_digest_slice(&k1) + .map_err(|_| VssHeaderProviderError::InvalidData { error: format!("invalid k1: {:?}", k1) })?; + let sig = engine.sign_ecdsa(&message, &linking_private_key.inner); + + // Compose LNURL with signature and linking public key. + lnurl + .query_pairs_mut() + .append_pair(SIG_QUERY_PARAM, &sig.serialize_der().to_string()) + .append_pair(KEY_QUERY_PARAM, &linking_public_key.to_string()); + Ok(lnurl.to_string()) +} + +#[derive(Deserialize, Debug, Clone)] +struct LnurlAuthResponse { + reason: Option, + token: Option, +} + +#[derive(Deserialize, Debug, Clone)] +struct ExpiryClaim { + #[serde(rename = "exp")] + expiry_secs: Option, +} + +fn parse_jwt_token(jwt_token: String) -> Result { + let parts: Vec<&str> = jwt_token.split('.').collect(); + let invalid = + || VssHeaderProviderError::InvalidData { error: format!("invalid JWT token: {}", jwt_token.escape_debug()) }; + if parts.len() != 3 { + return Err(invalid()); + } + let _ = URL_SAFE_NO_PAD.decode(parts[0]).map_err(|_| invalid())?; + let bytes = URL_SAFE_NO_PAD.decode(parts[1]).map_err(|_| invalid())?; + let _ = URL_SAFE_NO_PAD.decode(parts[2]).map_err(|_| invalid())?; + let claim: ExpiryClaim = serde_json::from_slice(&bytes).map_err(|_| invalid())?; + let expiry = claim + .expiry_secs + .and_then(|e| SystemTime::UNIX_EPOCH.checked_add(Duration::from_secs(e))); + Ok(JwtToken { token_str: jwt_token, expiry }) +} + +impl From for VssHeaderProviderError { + fn from(e: bitcoin::bip32::Error) -> VssHeaderProviderError { + VssHeaderProviderError::InternalError { error: e.to_string() } + } +} + +impl From for VssHeaderProviderError { + fn from(e: reqwest::Error) -> VssHeaderProviderError { + VssHeaderProviderError::RequestError { error: e.to_string() } + } +} + +#[cfg(test)] +mod test { + use crate::headers::lnurl_auth_jwt::{linking_key_path, sign_lnurl}; + use bitcoin::bip32::Xpriv; + use bitcoin::hashes::hex::FromHex; + use bitcoin::secp256k1::Secp256k1; + use bitcoin::secp256k1::SecretKey; + use bitcoin::Network; + use bitcoin::PrivateKey; + use std::str::FromStr; + + #[test] + fn test_linking_key_path() { + // Test vector from: + // https://github.com/lnurl/luds/blob/43cf7754de2033987a7661afc8b4a3998914a536/05.md + let hashing_key = PrivateKey::new( + SecretKey::from_str("7d417a6a5e9a6a4a879aeaba11a11838764c8fa2b959c242d43dea682b3e409b").unwrap(), + Network::Testnet, // The network only matters for serialization. + ); + let path = linking_key_path(&hashing_key, "site.com").unwrap(); + let numbers: Vec = path.into_iter().map(|c| u32::from(c.clone())).collect(); + assert_eq!(numbers, vec![1588488367, 2659270754, 38110259, 4136336762]); + } + + #[test] + fn test_sign_lnurl() { + let engine = Secp256k1::signing_only(); + let seed: [u8; 32] = + FromHex::from_hex("abababababababababababababababababababababababababababababababab").unwrap(); + let master = Xpriv::new_master(Network::Testnet, &seed).unwrap(); + let signed = sign_lnurl( + &engine, + &master, + "https://example.com/path?tag=login&k1=e2af6254a8df433264fa23f67eb8188635d15ce883e8fc020989d5f82ae6f11e", + ) + .unwrap(); + assert_eq!( + signed, + "https://example.com/path?tag=login&k1=e2af6254a8df433264fa23f67eb8188635d15ce883e8fc020989d5f82ae6f11e&sig=3045022100a75df468de452e618edb8030016eb0894204655c7d93ece1be007fcf36843522022048bc2f00a0a5a30601d274b49cfaf9ef4c76176e5401d0dfb195f5d6ab8ab4c4&key=02d9eb1b467517d685e3b5439082c14bb1a2c9ae672df4d9046d208c193a5846e0", + ); + } +} diff --git a/src/headers/mod.rs b/src/headers/mod.rs index 901480f..e6288ee 100644 --- a/src/headers/mod.rs +++ b/src/headers/mod.rs @@ -6,6 +6,12 @@ use std::fmt::Display; use std::fmt::Formatter; use std::str::FromStr; +#[cfg(feature = "lnurl-auth")] +mod lnurl_auth_jwt; + +#[cfg(feature = "lnurl-auth")] +pub use lnurl_auth_jwt::LnurlAuthToJwtProvider; + /// Defines a trait around how headers are provided for each VSS request. #[async_trait] pub trait VssHeaderProvider { @@ -25,6 +31,21 @@ pub enum VssHeaderProviderError { /// The error message. error: String, }, + /// An external request failed. + RequestError { + /// The error message. + error: String, + }, + /// Authorization was refused. + AuthorizationError { + /// The error message. + error: String, + }, + /// An application-level error occurred specific to the header provider functionality. + InternalError { + /// The error message. + error: String, + }, } impl Display for VssHeaderProviderError { @@ -33,6 +54,15 @@ impl Display for VssHeaderProviderError { Self::InvalidData { error } => { write!(f, "invalid data: {}", error) } + Self::RequestError { error } => { + write!(f, "error performing external request: {}", error) + } + Self::AuthorizationError { error } => { + write!(f, "authorization was refused: {}", error) + } + Self::InternalError { error } => { + write!(f, "internal error: {}", error) + } } } } @@ -58,7 +88,7 @@ impl VssHeaderProvider for FixedHeaders { } } -pub(crate) fn get_headermap(headers: HashMap) -> Result { +pub(crate) fn get_headermap(headers: &HashMap) -> Result { let mut headermap = HeaderMap::new(); for (name, value) in headers { headermap.insert( diff --git a/tests/lnurl_auth_jwt_tests.rs b/tests/lnurl_auth_jwt_tests.rs new file mode 100644 index 0000000..e90c6e6 --- /dev/null +++ b/tests/lnurl_auth_jwt_tests.rs @@ -0,0 +1,100 @@ +#[cfg(feature = "lnurl-auth")] +mod lnurl_auth_jwt_tests { + use base64::engine::general_purpose::URL_SAFE_NO_PAD; + use base64::Engine; + use mockito::Matcher; + use serde_json::json; + use std::collections::HashMap; + use std::time::SystemTime; + use vss_client::headers::LnurlAuthToJwtProvider; + use vss_client::headers::VssHeaderProvider; + + const APPLICATION_JSON: &'static str = "application/json"; + + fn lnurl_auth_response(jwt: &str) -> String { + json!({ + "status": "OK", + "token": jwt, + }) + .to_string() + } + + fn jwt_with_expiry(exp: u64) -> String { + let claims = json!({ + "exp": exp, + }) + .to_string(); + let ignored = URL_SAFE_NO_PAD.encode("ignored"); + let encoded = URL_SAFE_NO_PAD.encode(claims); + format!("{}.{}.{}", ignored, encoded, ignored) + } + + #[tokio::test] + async fn test_lnurl_auth_jwt() { + // Initialize LNURL Auth JWT provider connecting to the mock server. + let addr = mockito::server_address(); + let base_url = format!("http://localhost:{}", addr.port()); + let lnurl_auth_jwt = LnurlAuthToJwtProvider::new(&[0; 32], base_url.clone(), HashMap::new()).unwrap(); + { + // First request will be provided with an expired JWT token. + let k1 = "0000000000000000000000000000000000000000000000000000000000000000"; + let expired_jwt = jwt_with_expiry(0); + let lnurl = mockito::mock("GET", "/") + .expect(1) + .with_status(200) + .with_body(format!("{}/verify?tag=login&k1={}", base_url, k1)) + .create(); + let lnurl_verification = mockito::mock("GET", "/verify") + .match_query(Matcher::AllOf(vec![ + Matcher::UrlEncoded("k1".into(), k1.into()), + Matcher::Regex("sig=".into()), + Matcher::Regex("key=".into()), + ])) + .expect(1) + .with_status(200) + .with_header(reqwest::header::CONTENT_TYPE.as_str(), APPLICATION_JSON) + .with_body(lnurl_auth_response(&expired_jwt)) + .create(); + assert_eq!( + lnurl_auth_jwt.get_headers(&[]).await.unwrap().get("Authorization").unwrap(), + &format!("Bearer {}", expired_jwt), + ); + lnurl.assert(); + lnurl_verification.assert(); + } + { + // Second request will be provided with a non-expired JWT token. + // This will be cached. + let k1 = "1000000000000000000000000000000000000000000000000000000000000000"; + let valid_jwt = jwt_with_expiry( + SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + 60 * 60 * 24 * 365, + ); + let lnurl = mockito::mock("GET", "/") + .expect(1) + .with_status(200) + .with_body(format!("{}/verify?tag=login&k1={}", base_url, k1)) + .create(); + let lnurl_verification = mockito::mock("GET", "/verify") + .match_query(Matcher::AllOf(vec![ + Matcher::UrlEncoded("k1".into(), k1.into()), + Matcher::Regex("sig=".to_string()), + Matcher::Regex("key=".to_string()), + ])) + .expect(1) + .with_status(200) + .with_header(reqwest::header::CONTENT_TYPE.as_str(), APPLICATION_JSON) + .with_body(lnurl_auth_response(&valid_jwt)) + .create(); + assert_eq!( + lnurl_auth_jwt.get_headers(&[]).await.unwrap().get("Authorization").unwrap(), + &format!("Bearer {}", valid_jwt), + ); + assert_eq!( + lnurl_auth_jwt.get_headers(&[]).await.unwrap().get("Authorization").unwrap(), + &format!("Bearer {}", valid_jwt), + ); + lnurl.assert(); + lnurl_verification.assert(); + } + } +}