Skip to content

Commit

Permalink
Comments
Browse files Browse the repository at this point in the history
  • Loading branch information
wvanlint committed Mar 21, 2024
1 parent 35a9bc9 commit 865e15a
Show file tree
Hide file tree
Showing 6 changed files with 207 additions and 159 deletions.
14 changes: 7 additions & 7 deletions src/client.rs
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
use prost::Message;
use reqwest;
use reqwest::header::HeaderMap;
use reqwest::header::CONTENT_TYPE;
use reqwest::Client;
use std::default::Default;
use std::sync::Arc;

use crate::error::VssError;
use crate::headers::get_headermap;
use crate::headers::FixedHeaders;
use crate::headers::HeaderProvider;
use crate::headers::VssHeaderProvider;
use crate::types::{
DeleteObjectRequest, DeleteObjectResponse, GetObjectRequest, GetObjectResponse, ListKeyVersionsRequest,
ListKeyVersionsResponse, PutObjectRequest, PutObjectResponse,
Expand All @@ -27,7 +26,7 @@ where
base_url: String,
client: Client,
retry_policy: R,
header_provider: Arc<dyn HeaderProvider>,
header_provider: Arc<dyn VssHeaderProvider>,
}

impl<R: RetryPolicy<E = VssError>> VssClient<R> {
Expand All @@ -43,13 +42,13 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
base_url: String::from(base_url),
client,
retry_policy,
header_provider: Arc::new(FixedHeaders::new(HeaderMap::new())),
header_provider: Arc::new(FixedHeaders::new(Vec::new())),
}
}

/// Constructs a [`VssClient`] using `base_url` as the VSS server endpoint.
/// HTTP headers will be provided by the given `header_provider`.
pub fn new_with_headers(base_url: &str, retry_policy: R, header_provider: Arc<dyn HeaderProvider>) -> Self {
pub fn new_with_headers(base_url: &str, retry_policy: R, header_provider: Arc<dyn VssHeaderProvider>) -> Self {
let client = Client::new();
Self { base_url: String::from(base_url), client, retry_policy, header_provider }
}
Expand Down Expand Up @@ -133,11 +132,12 @@ impl<R: RetryPolicy<E = VssError>> VssClient<R> {
.get_headers()
.await
.map_err(|e| VssError::InternalError(e.to_string()))?;
let headermap = get_headermap(&headers).map_err(VssError::InternalError)?;
let response_raw = self
.client
.post(url)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.headers(headers)
.headers(headermap)
.body(request_body)
.send()
.await?;
Expand Down
201 changes: 96 additions & 105 deletions src/headers/lnurl_auth_jwt.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use crate::headers::HeaderProvider;
use crate::headers::HeaderProviderError;
use crate::headers::get_headermap;
use crate::headers::VssHeaderProvider;
use crate::headers::VssHeaderProviderError;
use crate::util::string::UntrustedString;
use async_trait::async_trait;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
Expand All @@ -10,11 +12,9 @@ use bitcoin::hashes::{Hash, HashEngine, Hmac, HmacEngine};
use bitcoin::secp256k1::{All, Message, Secp256k1};
use bitcoin::Network;
use bitcoin::PrivateKey;
use reqwest::header::HeaderMap;
use reqwest::header::AUTHORIZATION;
use serde::Deserialize;
use std::str::FromStr;
use std::sync::Mutex;
use std::ops::Deref;
use std::sync::RwLock;
use std::time::SystemTime;
use url::Url;

Expand All @@ -30,17 +30,24 @@ const K1_QUERY_PARAM: &str = "k1";
const SIG_QUERY_PARAM: &str = "sig";
// The key of the LNURL key query parameter.
const KEY_QUERY_PARAM: &str = "key";
// The authorization header name.
const AUTHORIZATION: &str = "authorization";

#[derive(Debug, Clone)]
struct JwtToken {
token_str: String,
expiry: Option<u64>,
}

/// Provides a JWT token based on LNURL Auth.
/// The LNURL and JWT token are exchanged over a Websocket connection.
pub struct LnurlAuthJwt {
engine: Secp256k1<All>,
parent_key: ExtendedPrivKey,
url: String,
headers: HeaderMap,
default_headers: Vec<(String, String)>,
client: reqwest::Client,
jwt_token: Mutex<Option<String>>,
expiry: Mutex<Option<u64>>,
jwt_token: RwLock<Option<JwtToken>>,
}

impl LnurlAuthJwt {
Expand All @@ -51,48 +58,38 @@ impl LnurlAuthJwt {
/// The JWT token will be returned in response to the signed LNURL request under a token field.
/// The given set of headers will be used for LNURL requests, and will also be returned together
/// with the JWT authorization header for VSS requests.
pub fn new(seed: &[u8], url: String, headers: Vec<(String, String)>) -> Result<LnurlAuthJwt, HeaderProviderError> {
pub fn new(
seed: &[u8], url: String, default_headers: Vec<(String, String)>,
) -> Result<LnurlAuthJwt, VssHeaderProviderError> {
let engine = Secp256k1::new();
let master = ExtendedPrivKey::new_master(Network::Testnet, seed).map_err(HeaderProviderError::from)?;
let master = ExtendedPrivKey::new_master(Network::Testnet, seed).map_err(VssHeaderProviderError::from)?;
let child_number =
ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(HeaderProviderError::from)?;
ChildNumber::from_hardened_idx(PARENT_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?;
let parent_key = master
.derive_priv(&engine, &vec![child_number])
.map_err(HeaderProviderError::from)?;
let mut headermap = HeaderMap::new();
for (name, value) in headers {
headermap.insert(
reqwest::header::HeaderName::from_str(&name).map_err(HeaderProviderError::from)?,
reqwest::header::HeaderValue::from_str(&value).map_err(HeaderProviderError::from)?,
);
}
.map_err(VssHeaderProviderError::from)?;
let default_headermap =
get_headermap(&default_headers).map_err(|error| VssHeaderProviderError::InvalidData { error })?;
let client = reqwest::Client::builder()
.default_headers(headermap.clone())
.default_headers(default_headermap)
.build()
.map_err(HeaderProviderError::from)?;
.map_err(VssHeaderProviderError::from)?;

Ok(LnurlAuthJwt {
engine,
parent_key,
url,
headers: headermap,
client,
jwt_token: Mutex::new(None),
expiry: Mutex::new(None),
})
Ok(LnurlAuthJwt { engine, parent_key, url, default_headers, client, jwt_token: RwLock::new(None) })
}

async fn fetch_jwt_token(&self) -> Result<String, HeaderProviderError> {
async fn fetch_jwt_token(&self) -> Result<JwtToken, VssHeaderProviderError> {
// Fetch the LNURL.
let lnurl_str = self
.client
.get(&self.url)
.send()
.await
.map_err(HeaderProviderError::from)?
.text()
.await
.map_err(HeaderProviderError::from)?;
let lnurl_str = UntrustedString::new(
self.client
.get(&self.url)
.send()
.await
.map_err(VssHeaderProviderError::from)?
.text()
.await
.map_err(VssHeaderProviderError::from)?,
);

// Sign the LNURL and perform the request.
let signed_lnurl = sign_lnurl(&self.engine, &self.parent_key, &lnurl_str)?;
Expand All @@ -101,40 +98,45 @@ impl LnurlAuthJwt {
.get(&signed_lnurl)
.send()
.await
.map_err(HeaderProviderError::from)?
.map_err(VssHeaderProviderError::from)?
.json()
.await
.map_err(HeaderProviderError::from)?;
.map_err(VssHeaderProviderError::from)?;

match lnurl_auth_response {
LnurlAuthResponse { token: Some(token), .. } => Ok(token),
let untrusted_token = match lnurl_auth_response {
LnurlAuthResponse { token: Some(token), .. } => token,
LnurlAuthResponse { reason: Some(reason), .. } => {
Err(HeaderProviderError::ApplicationError(format!("LNURL Auth failed, reason is: {}", reason)))
return Err(VssHeaderProviderError::ApplicationError {
error: format!("LNURL Auth failed, reason is: {}", reason),
});
}
_ => Err(HeaderProviderError::InvalidData(
"LNURL Auth response did not contain a token nor an error".to_string(),
)),
}
_ => {
return Err(VssHeaderProviderError::InvalidData {
error: "LNURL Auth response did not contain a token nor an error".to_string(),
});
}
};
parse_jwt_token(untrusted_token)
}

async fn get_jwt_token(&self, force_refresh: bool) -> Result<String, HeaderProviderError> {
async fn get_jwt_token(&self, force_refresh: bool) -> Result<String, VssHeaderProviderError> {
if !self.is_expired() && !force_refresh {
let jwt_token = self.jwt_token.lock().unwrap();
if let Some(jwt_token) = jwt_token.as_deref() {
return Ok(jwt_token.to_string());
let jwt_token = self.jwt_token.read().unwrap();
if let Some(jwt_token) = jwt_token.deref() {
return Ok(jwt_token.token_str.clone());
}
}
let jwt_token = self.fetch_jwt_token().await?;
let expiry = parse_expiry(&jwt_token)?;
*self.jwt_token.lock().unwrap() = Some(jwt_token.clone());
*self.expiry.lock().unwrap() = expiry;
Ok(jwt_token)
*self.jwt_token.write().unwrap() = Some(jwt_token.clone());
Ok(jwt_token.token_str)
}

fn is_expired(&self) -> bool {
self.expiry
.lock()
self.jwt_token
.read()
.unwrap()
.as_ref()
.and_then(|token| token.expiry)
.map(|expiry| {
SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs() + EXPIRY_BUFFER_SECS
> expiry
Expand All @@ -144,41 +146,39 @@ impl LnurlAuthJwt {
}

#[async_trait]
impl HeaderProvider for LnurlAuthJwt {
async fn get_headers(&self) -> Result<HeaderMap, HeaderProviderError> {
impl VssHeaderProvider for LnurlAuthJwt {
async fn get_headers(&self) -> Result<Vec<(String, String)>, VssHeaderProviderError> {
let jwt_token = self.get_jwt_token(false).await?;
let mut headers = self.headers.clone();
let value = format!("Bearer {}", jwt_token).parse().map_err(HeaderProviderError::from)?;
headers.insert(AUTHORIZATION, value);
let mut headers = self.default_headers.clone();
headers.push((AUTHORIZATION.to_string(), format!("Bearer {}", jwt_token)));
Ok(headers)
}
}

fn hashing_key(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey) -> Result<PrivateKey, HeaderProviderError> {
fn hashing_key(engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey) -> Result<PrivateKey, VssHeaderProviderError> {
let hashing_child_number =
ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(HeaderProviderError::from)?;
ChildNumber::from_normal_idx(HASHING_DERIVATION_INDEX).map_err(VssHeaderProviderError::from)?;
parent_key
.derive_priv(engine, &vec![hashing_child_number])
.map(|xpriv| xpriv.to_priv())
.map_err(HeaderProviderError::from)
.map_err(VssHeaderProviderError::from)
}

fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<DerivationPath, HeaderProviderError> {
fn linking_key_path(hashing_key: &PrivateKey, domain_name: &str) -> Result<DerivationPath, VssHeaderProviderError> {
let mut engine = HmacEngine::<sha256::Hash>::new(&hashing_key.inner[..]);
engine.input(domain_name.as_bytes());
let result = Hmac::<sha256::Hash>::from_engine(engine).to_byte_array();
let children: Vec<ChildNumber> = (0..4)
let children = (0..4)
.map(|i| u32::from_be_bytes(result[(i * 4)..((i + 1) * 4)].try_into().unwrap()))
.map(ChildNumber::from)
.collect::<Vec<_>>();
Ok(DerivationPath::from(children))
.map(ChildNumber::from);
Ok(DerivationPath::from_iter(children))
}

fn sign_lnurl(
engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey, lnurl_str: &str,
) -> Result<String, HeaderProviderError> {
engine: &Secp256k1<All>, parent_key: &ExtendedPrivKey, lnurl_str: &UntrustedString,
) -> Result<String, VssHeaderProviderError> {
// Parse k1 parameter to sign.
let invalid_lnurl = || HeaderProviderError::InvalidData(format!("invalid lnurl: {}", lnurl_str));
let invalid_lnurl = || VssHeaderProviderError::InvalidData { error: format!("invalid lnurl: {}", lnurl_str) };
let mut lnurl = Url::parse(lnurl_str).map_err(|_| invalid_lnurl())?;
let domain = lnurl.domain().ok_or(invalid_lnurl())?;
let k1_str = lnurl
Expand All @@ -194,11 +194,11 @@ fn sign_lnurl(
let linking_key_path = linking_key_path(&hashing_key, domain)?;
let private_key = parent_key
.derive_priv(engine, &linking_key_path)
.map_err(HeaderProviderError::from)?
.map_err(VssHeaderProviderError::from)?
.to_priv();
let public_key = private_key.public_key(engine);
let message =
Message::from_slice(&k1).map_err(|_| HeaderProviderError::InvalidData(format!("invalid k1: {:?}", k1)))?;
let message = Message::from_slice(&k1)
.map_err(|_| VssHeaderProviderError::InvalidData { error: format!("invalid k1: {:?}", k1) })?;
let sig = engine.sign_ecdsa(&message, &private_key.inner);

// Compose LNURL with signature and linking key.
Expand All @@ -209,55 +209,46 @@ fn sign_lnurl(
Ok(lnurl.to_string())
}

#[derive(Deserialize)]
#[derive(Deserialize, Debug, Clone)]
struct LnurlAuthResponse {
reason: Option<String>,
token: Option<String>,
reason: Option<UntrustedString>,
token: Option<UntrustedString>,
}

#[derive(Deserialize)]
#[derive(Deserialize, Debug, Clone)]
struct ExpiryClaim {
exp: Option<u64>,
}

fn parse_expiry(jwt_token: &str) -> Result<Option<u64>, HeaderProviderError> {
fn parse_jwt_token(jwt_token: UntrustedString) -> Result<JwtToken, VssHeaderProviderError> {
let parts: Vec<&str> = jwt_token.split('.').collect();
let invalid = || HeaderProviderError::InvalidData(format!("invalid JWT token: {}", jwt_token));
let invalid = || VssHeaderProviderError::InvalidData { error: format!("invalid JWT token: {}", jwt_token) };
if parts.len() != 3 {
return Err(invalid());
}
let _ = URL_SAFE_NO_PAD.decode(parts[0]).map_err(|_| invalid())?;
let bytes = URL_SAFE_NO_PAD.decode(parts[1]).map_err(|_| invalid())?;
let _ = URL_SAFE_NO_PAD.decode(parts[2]).map_err(|_| invalid())?;
let claim: ExpiryClaim = serde_json::from_slice(&bytes).map_err(|_| invalid())?;
Ok(claim.exp)
}

impl From<bitcoin::bip32::Error> for HeaderProviderError {
fn from(e: bitcoin::bip32::Error) -> HeaderProviderError {
HeaderProviderError::InvalidData(e.to_string())
}
}

impl From<reqwest::header::InvalidHeaderName> for HeaderProviderError {
fn from(e: reqwest::header::InvalidHeaderName) -> HeaderProviderError {
HeaderProviderError::InvalidData(e.to_string())
}
Ok(JwtToken { token_str: jwt_token.into_inner(), expiry: claim.exp })
}

impl From<reqwest::header::InvalidHeaderValue> for HeaderProviderError {
fn from(e: reqwest::header::InvalidHeaderValue) -> HeaderProviderError {
HeaderProviderError::InvalidData(e.to_string())
impl From<bitcoin::bip32::Error> for VssHeaderProviderError {
fn from(e: bitcoin::bip32::Error) -> VssHeaderProviderError {
VssHeaderProviderError::InvalidData { error: e.to_string() }
}
}

impl From<reqwest::Error> for HeaderProviderError {
fn from(e: reqwest::Error) -> HeaderProviderError {
HeaderProviderError::RequestError(e.to_string())
impl From<reqwest::Error> for VssHeaderProviderError {
fn from(e: reqwest::Error) -> VssHeaderProviderError {
VssHeaderProviderError::RequestError { error: e.to_string() }
}
}

#[cfg(test)]
mod test {
use crate::headers::lnurl_auth_jwt::{linking_key_path, sign_lnurl};
use crate::util::string::UntrustedString;
use bitcoin::bip32::ExtendedPrivKey;
use bitcoin::hashes::hex::FromHex;
use bitcoin::secp256k1::Secp256k1;
Expand Down Expand Up @@ -288,7 +279,7 @@ mod test {
let signed = sign_lnurl(
&engine,
&master,
"https://example.com/path?tag=login&k1=e2af6254a8df433264fa23f67eb8188635d15ce883e8fc020989d5f82ae6f11e",
&UntrustedString::new("https://example.com/path?tag=login&k1=e2af6254a8df433264fa23f67eb8188635d15ce883e8fc020989d5f82ae6f11e".to_string()),
)
.unwrap();
assert_eq!(
Expand Down
Loading

0 comments on commit 865e15a

Please sign in to comment.