Skip to content

Commit

Permalink
feat: Add authentication token to protocol handshake
Browse files Browse the repository at this point in the history
This ensures that only clients which have the token can connect to the
server and retrieve the file.
  • Loading branch information
flub committed Jan 25, 2023
2 parents 4007360 + c4a110f commit bf08478
Show file tree
Hide file tree
Showing 8 changed files with 176 additions and 20 deletions.
3 changes: 3 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ authors = ["dignifiedquire <me@dignifiedquire.com>"]
repository = "https://github.com/n0-computer/sendme"

[dependencies]
anyhow = "1.0.68"
anyhow = { version = "1.0.68", features = ["backtrace"] }
async-stream = "0.3.3"
bao = "0.12.1"
blake3 = "1.3.3"
Expand Down
26 changes: 23 additions & 3 deletions src/get.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::fmt::Debug;
use std::time::Duration;
use std::{io::Read, net::SocketAddr, time::Instant};

Expand All @@ -10,7 +11,7 @@ use s2n_quic::{client::Connect, Client};
use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tracing::debug;

use crate::protocol::{read_lp_data, write_lp, Handshake, Request, Res, Response};
use crate::protocol::{read_lp_data, write_lp, AuthToken, Handshake, Request, Res, Response};
use crate::tls::{self, Keypair, PeerId};

const MAX_DATA_SIZE: usize = 1024 * 1024 * 1024;
Expand Down Expand Up @@ -52,6 +53,7 @@ async fn setup(opts: Options) -> Result<(Client, Connection)> {
}

/// Stats about the transfer.
#[derive(Debug)]
pub struct Stats {
pub data_len: usize,
pub elapsed: Duration,
Expand All @@ -78,7 +80,25 @@ pub enum Event {
Done(Stats),
}

pub fn run(hash: bao::Hash, opts: Options) -> impl Stream<Item = Result<Event>> {
impl Debug for Event {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Connected => write!(f, "Connected"),
Self::Requested { size } => f.debug_struct("Requested").field("size", size).finish(),
Self::Receiving { hash, reader: _ } => f
.debug_struct("Receiving")
.field("hash", hash)
.field(
"reader",
&"Box<dyn AsyncRead + Unpin + Sync + Send + 'static>",
)
.finish(),
Self::Done(arg0) => f.debug_tuple("Done").field(arg0).finish(),
}
}
}

pub fn run(hash: bao::Hash, token: AuthToken, opts: Options) -> impl Stream<Item = Result<Event>> {
async_stream::try_stream! {
let now = Instant::now();
let (_client, mut connection) = setup(opts).await?;
Expand All @@ -97,7 +117,7 @@ pub fn run(hash: bao::Hash, opts: Options) -> impl Stream<Item = Result<Event>>
// 1. Send Handshake
{
debug!("sending handshake");
let handshake = Handshake::default();
let handshake = Handshake::new(token);
let used = postcard::to_slice(&handshake, &mut out_buffer)?;
write_lp(&mut writer, used).await?;
}
Expand Down
12 changes: 9 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ mod tests {
use std::{net::SocketAddr, path::PathBuf};

use crate::get::Event;
use crate::protocol::AuthToken;
use crate::tls::PeerId;

use super::*;
Expand All @@ -30,6 +31,7 @@ mod tests {
let addr = "127.0.0.1:4443".parse().unwrap();
let mut provider = provider::Provider::new(db);
let peer_id = provider.peer_id();
let token = provider.auth_token();

tokio::task::spawn(async move {
provider.run(provider::Options { addr }).await.unwrap();
Expand All @@ -39,7 +41,7 @@ mod tests {
addr,
peer_id: Some(peer_id),
};
let stream = get::run(hash, opts);
let stream = get::run(hash, token, opts);
tokio::pin!(stream);
while let Some(event) = stream.next().await {
let event = event?;
Expand Down Expand Up @@ -88,6 +90,7 @@ mod tests {
let hash = *db.iter().next().unwrap().0;
let mut provider = provider::Provider::new(db);
let peer_id = provider.peer_id();
let token = provider.auth_token();

let provider_task = tokio::task::spawn(async move {
provider.run(provider::Options { addr }).await.unwrap();
Expand All @@ -97,7 +100,7 @@ mod tests {
addr,
peer_id: Some(peer_id),
};
let stream = get::run(hash, opts);
let stream = get::run(hash, token, opts);
tokio::pin!(stream);
while let Some(event) = stream.next().await {
let event = event?;
Expand Down Expand Up @@ -132,13 +135,15 @@ mod tests {
let hash = *db.iter().next().unwrap().0;
let mut provider = provider::Provider::new(db);
let peer_id = provider.peer_id();
let token = provider.auth_token();

tokio::task::spawn(async move {
provider.run(provider::Options { addr }).await.unwrap();
});

async fn run_client(
hash: bao::Hash,
token: AuthToken,
addr: SocketAddr,
peer_id: PeerId,
content: Vec<u8>,
Expand All @@ -147,7 +152,7 @@ mod tests {
addr,
peer_id: Some(peer_id),
};
let stream = get::run(hash, opts);
let stream = get::run(hash, token, opts);
tokio::pin!(stream);
while let Some(event) = stream.next().await {
let event = event?;
Expand All @@ -169,6 +174,7 @@ mod tests {
for _i in 0..3 {
tasks.push(tokio::task::spawn(run_client(
hash,
token,
addr,
peer_id,
content.to_vec(),
Expand Down
29 changes: 25 additions & 4 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
use std::{net::SocketAddr, path::PathBuf};
use std::{net::SocketAddr, path::PathBuf, str::FromStr};

use anyhow::{anyhow, ensure, Context, Result};
use clap::{Parser, Subcommand};
use console::style;
use futures::StreamExt;
use indicatif::{HumanDuration, ProgressBar, ProgressDrawTarget, ProgressState, ProgressStyle};
use sendme::protocol::AuthToken;
use tracing::trace;
use tracing_subscriber::{fmt, prelude::*, EnvFilter};

use sendme::{get, provider, PeerId};
Expand All @@ -24,13 +26,19 @@ enum Commands {
#[clap(about = "Serve the data from the given path")]
Provide {
path: Option<PathBuf>,
#[clap(long, short)]
/// Optional port, defaults to 127.0.01:4433.
#[clap(long, short)]
addr: Option<SocketAddr>,
/// Auth token, defaults to random generated.
#[clap(long)]
auth_token: Option<String>,
},
/// Fetch some data
#[clap(about = "Fetch the data from the hash")]
Get {
/// The authentication token to present to the server.
token: String,
/// The root hash to retrieve.
hash: bao::Hash,
#[clap(long)]
/// PeerId of the provider.
Expand All @@ -55,6 +63,7 @@ async fn main() -> Result<()> {
match cli.command {
Commands::Get {
hash,
token,
peer_id,
addr,
out,
Expand All @@ -67,12 +76,15 @@ async fn main() -> Result<()> {
if let Some(addr) = addr {
opts.addr = addr;
}
let token =
AuthToken::from_str(&token).context("Wrong format for authentication token")?;

println!("{} Connecting ...", style("[1/3]").bold().dim());
let pb = ProgressBar::hidden();
let stream = get::run(hash, opts);
let stream = get::run(hash, token, opts);
tokio::pin!(stream);
while let Some(event) = stream.next().await {
trace!("client event: {:?}", event);
match event? {
get::Event::Connected => {
println!("{} Requesting ...", style("[2/3]").bold().dim());
Expand Down Expand Up @@ -130,7 +142,11 @@ async fn main() -> Result<()> {
}
}
}
Commands::Provide { path, addr } => {
Commands::Provide {
path,
addr,
auth_token,
} => {
let mut tmp_path = None;

let sources = if let Some(path) = path {
Expand All @@ -151,8 +167,13 @@ async fn main() -> Result<()> {
opts.addr = addr;
}
let mut provider = provider::Provider::new(db);
if let Some(ref hex) = auth_token {
let auth_token = AuthToken::from_str(hex)?;
provider.set_auth_token(auth_token);
}

println!("PeerID: {}", provider.peer_id());
println!("Auth token: {}", provider.auth_token());
provider.run(opts).await?;

// Drop tempath to signal it can be destroyed
Expand Down
85 changes: 80 additions & 5 deletions src/protocol.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
use std::fmt::Display;
use std::str::FromStr;

use anyhow::{bail, ensure, Result};
use bytes::BytesMut;
use postcard::experimental::max_size::MaxSize;
Expand All @@ -13,11 +16,15 @@ pub const VERSION: u64 = 1;
#[derive(Deserialize, Serialize, Debug, PartialEq, Clone, MaxSize)]
pub struct Handshake {
pub version: u64,
pub token: AuthToken,
}

impl Default for Handshake {
fn default() -> Self {
Handshake { version: VERSION }
impl Handshake {
pub fn new(token: AuthToken) -> Self {
Self {
version: VERSION,
token,
}
}
}

Expand Down Expand Up @@ -83,9 +90,7 @@ pub async fn read_lp<'a, R: AsyncRead + futures::io::AsyncRead + Unpin, T: Deser

while buffer.len() < size {
debug!("reading message {} {}", buffer.len(), size);
reader.read_buf(buffer).await?;
}

let response: T = postcard::from_bytes(&buffer[..size])?;
debug!("read message of size {}", size);

Expand Down Expand Up @@ -128,3 +133,73 @@ async fn read_prefix<R: AsyncRead + futures::io::AsyncRead + Unpin>(

Ok(size)
}

/// A token used to authenticate a handshake.
///
/// The token has a printable representation which can be serialised using [`Display`] and
/// deserialised using [`FromStr`].
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Deserialize, Serialize, MaxSize)]
pub struct AuthToken {
bytes: [u8; 32],
}

impl AuthToken {
/// Generates a new random token.
pub fn generate() -> Self {
Self {
bytes: rand::random(),
}
}
}

/// Serialises the [`AuthToken`] to hex.
impl Display for AuthToken {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{}", hex::encode(self.bytes))
}
}

/// Error for parsing [`AuthToken`] using [`FromStr`].
#[derive(thiserror::Error, Debug)]
pub enum AuthTokenPraseError {
#[error("invalid encoding: {0}")]
Hex(#[from] hex::FromHexError),
#[error("invalid length: {0}")]
Length(usize),
}

/// Deserialises the [`AuthToken`] from hex.
impl FromStr for AuthToken {
type Err = AuthTokenPraseError;

fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
let decoded = hex::decode(s)?;
let bytes = decoded
.try_into()
.map_err(|v: Vec<u8>| AuthTokenPraseError::Length(v.len()))?;
Ok(AuthToken { bytes })
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_auth_token_hex() {
let token = AuthToken::generate();
println!("token: {token}");
let hex = token.to_string();
println!("token: {hex}");
let decoded = AuthToken::from_str(&hex).unwrap();
assert_eq!(decoded, token);

let err = AuthToken::from_str("not-hex").err().unwrap();
println!("err {err:#}");
assert!(matches!(err, AuthTokenPraseError::Hex(_)));

let err = AuthToken::from_str("abcd").err().unwrap();
println!("err {err:#}");
assert!(matches!(err, AuthTokenPraseError::Length(2)));
}
}
Loading

0 comments on commit bf08478

Please sign in to comment.