Skip to content

Commit

Permalink
feat: Create a Builder for the Provider
Browse files Browse the repository at this point in the history
This is a nicer API to create a builder. Crucially it can give all
the information to connect to a provider once it is started.
Something that wasn't quite possible before.
  • Loading branch information
flub committed Jan 27, 2023
2 parents 4befae2 + 67e3828 commit bf087b6
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 90 deletions.
28 changes: 8 additions & 20 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,7 @@ mod tests {
let expect_name = Some(filename.to_string());

let (db, hash) = provider::create_db(vec![provider::DataSource::File(path)]).await?;
let mut provider = provider::Provider::builder().database(db).build()?;
let peer_id = provider.peer_id();
let token = provider.auth_token();

tokio::task::spawn(async move {
provider.run(provider::Options { addr }).await.unwrap();
});
let provider = provider::Provider::builder(db).bind_addr(addr).spawn()?;

async fn run_client(
hash: bao::Hash,
Expand Down Expand Up @@ -137,11 +131,11 @@ mod tests {
for _i in 0..3 {
tasks.push(tokio::task::spawn(run_client(
hash,
token,
provider.auth_token(),
expect_hash,
expect_name.clone(),
addr,
peer_id,
provider.peer_id(),
content.to_vec(),
)));
}
Expand Down Expand Up @@ -199,19 +193,13 @@ mod tests {
let (db, collection_hash) = provider::create_db(files).await?;

let addr = format!("127.0.0.1:{port}").parse().unwrap();
let mut provider = provider::Provider::builder().database(db).build()?;
let peer_id = provider.peer_id();
let token = provider.auth_token();

let provider_task = tokio::task::spawn(async move {
provider.run(provider::Options { addr }).await.unwrap();
});
let provider = provider::Provider::builder(db).bind_addr(addr).spawn()?;

let opts = get::Options {
addr,
peer_id: Some(peer_id),
peer_id: Some(provider.peer_id()),
};
let stream = get::run(collection_hash, token, opts);
let stream = get::run(collection_hash, provider.auth_token(), opts);
tokio::pin!(stream);

let mut i = 0;
Expand All @@ -234,8 +222,8 @@ mod tests {
}
}

provider_task.abort();
let _ = provider_task.await;
provider.abort();
let _ = provider.join().await;
Ok(())
}
}
11 changes: 5 additions & 6 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,20 +237,19 @@ async fn main() -> Result<()> {
};

let (db, _) = provider::create_db(sources).await?;
let mut opts = provider::Options::default();
let mut builder = provider::Provider::builder(db).keypair(keypair);
if let Some(addr) = addr {
opts.addr = addr;
builder = builder.bind_addr(addr);
}
let mut provider_builder = provider::Provider::builder().database(db).keypair(keypair);
if let Some(ref hex) = auth_token {
let auth_token = AuthToken::from_str(hex)?;
provider_builder = provider_builder.auth_token(auth_token);
builder = builder.auth_token(auth_token);
}
let mut provider = provider_builder.build()?;
let provider = builder.spawn()?;

out_writer.println(format!("PeerID: {}", provider.peer_id()));
out_writer.println(format!("Auth token: {}", provider.auth_token()));
provider.run(opts).await?;
provider.join().await?;

// Drop tempath to signal it can be destroyed
drop(tmp_path);
Expand Down
167 changes: 103 additions & 64 deletions src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,33 +7,28 @@ use bytes::{Bytes, BytesMut};
use s2n_quic::stream::BidirectionalStream;
use s2n_quic::Server as QuicServer;
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::task::{JoinError, JoinHandle};
use tracing::{debug, warn};

use crate::blobs::{Blob, Collection};
use crate::protocol::{read_lp, write_lp, AuthToken, Handshake, Request, Res, Response, VERSION};
use crate::tls::{self, Keypair, PeerId};

#[derive(Clone, Debug)]
pub struct Options {
/// Address to listen on.
pub addr: SocketAddr,
}

impl Default for Options {
fn default() -> Self {
Options {
addr: "127.0.0.1:4433".parse().unwrap(),
}
}
}

const MAX_CONNECTIONS: u64 = 1024;
const MAX_STREAMS: u64 = 10;

pub type Database = Arc<HashMap<bao::Hash, BlobOrCollection>>;

/// Builder for the [`Provider`].
///
/// You must supply a database which can be created using [`create_db`], everything else is
/// optional. Finally you can create and run the provider by calling [`Builder::spawn`].
///
/// The returned [`Provider`] provides [`Provider::join`] to wait for the spawned task.
/// Currently it needs to be aborted using [`Provider::abort`], graceful shutdown will come.
#[derive(Debug)]
pub struct Provider {
pub struct Builder {
bind_addr: SocketAddr,
keypair: Keypair,
auth_token: AuthToken,
db: Database,
Expand All @@ -45,78 +40,72 @@ pub enum BlobOrCollection {
Collection((Bytes, Bytes)),
}

/// Builder to configure a `Provider`.
#[derive(Debug, Default)]
pub struct ProviderBuilder {
auth_token: Option<AuthToken>,
keypair: Option<Keypair>,
db: Option<Database>,
}
impl Builder {
/// Creates a new builder for [`Provider`] using the given [`Database`].
pub fn with_db(db: Database) -> Self {
Self {
bind_addr: "127.0.0.1:4433".parse().unwrap(),
keypair: Keypair::generate(),
auth_token: AuthToken::generate(),
db,
}
}

impl ProviderBuilder {
/// Set the authentication token, if none is provided a new one is generated.
pub fn auth_token(mut self, auth_token: AuthToken) -> Self {
self.auth_token = Some(auth_token);
/// Binds the provider service to a different socket.
///
/// By default it binds to `127.0.0.1:4433`.
pub fn bind_addr(mut self, addr: SocketAddr) -> Self {
self.bind_addr = addr;
self
}

/// Set the keypair, if none is provided a new one is generated.
/// Uses the given [`Keypair`] for the [`PeerId`] instead of a newly generated one.
pub fn keypair(mut self, keypair: Keypair) -> Self {
self.keypair = Some(keypair);
self.keypair = keypair;
self
}

/// Set the database.
pub fn database(mut self, db: Database) -> Self {
self.db = Some(db);
/// Uses the given [`AuthToken`] instead of a newly generated one.
pub fn auth_token(mut self, auth_token: AuthToken) -> Self {
self.auth_token = auth_token;
self
}

/// Consumes the builder and constructs a `Provider`.
pub fn build(self) -> Result<Provider> {
ensure!(self.db.is_some(), "missing database");

Ok(Provider {
auth_token: self.auth_token.unwrap_or_else(AuthToken::generate),
keypair: self.keypair.unwrap_or_else(Keypair::generate),
db: self.db.unwrap(),
})
}
}

impl Provider {
/// Returns a new `ProviderBuilder`.
pub fn builder() -> ProviderBuilder {
ProviderBuilder::default()
}

pub fn peer_id(&self) -> PeerId {
self.keypair.public().into()
}

pub fn auth_token(&self) -> AuthToken {
self.auth_token
}

pub async fn run(&mut self, opts: Options) -> Result<()> {
/// Spawns the [`Provider`] in a tokio task.
///
/// This will create the underlying network server and spawn a tokio task accepting
/// connections. The returned [`Provider`] can be used to control the task as well as
/// get information about it.
pub fn spawn(self) -> Result<Provider> {
let server_config = tls::make_server_config(&self.keypair)?;
let tls = s2n_quic::provider::tls::rustls::Server::from(server_config);
let limits = s2n_quic::provider::limits::Limits::default()
.with_max_active_connection_ids(MAX_CONNECTIONS)?
.with_max_open_local_bidirectional_streams(MAX_STREAMS)?
.with_max_open_remote_bidirectional_streams(MAX_STREAMS)?;

let mut server = QuicServer::builder()
let server = QuicServer::builder()
.with_tls(tls)?
.with_io(opts.addr)?
.with_io(self.bind_addr)?
.with_limits(limits)?
.start()
.map_err(|e| anyhow!("{:?}", e))?;
let token = self.auth_token;
debug!("\nlistening at: {:#?}", server.local_addr().unwrap());
let listen_addr = server.local_addr().unwrap();
let db2 = self.db.clone();
let task = tokio::spawn(async move { Self::run(server, db2, self.auth_token).await });

Ok(Provider {
listen_addr,
keypair: self.keypair,
auth_token: self.auth_token,
task,
})
}

async fn run(mut server: s2n_quic::server::Server, db: Database, token: AuthToken) {
debug!("\nlistening at: {:#?}", server.local_addr().unwrap());
while let Some(mut connection) = server.accept().await {
let db = self.db.clone();
let db = db.clone();
tokio::spawn(async move {
debug!("connection accepted from {:?}", connection.remote_addr());

Expand All @@ -131,8 +120,58 @@ impl Provider {
}
});
}
}
}

/// A server which implements the sendme provider.
///
/// Clients can connect to this server and requests hashes from it.
///
/// The only way to create this is by using the [`Builder::spawn`]. [`Provider::builder`]
/// is a shorthand to create a suitable [`Builder`].
///
/// This runs a tokio task which can be aborted and joined if desired.
pub struct Provider {
listen_addr: SocketAddr,
keypair: Keypair,
auth_token: AuthToken,
task: JoinHandle<()>,
}

impl Provider {
/// Returns a new builder for the [`Provider`].
///
/// Once the done with the builder call [`Builder::spawn`] to create the provider.
pub fn builder(db: Database) -> Builder {
Builder::with_db(db)
}

/// Returns the address on which the server is listening for connections.
pub fn listen_addr(&self) -> SocketAddr {
self.listen_addr
}

/// Returns the [`PeerId`] of the provider.
pub fn peer_id(&self) -> PeerId {
self.keypair.public().into()
}

/// Returns the [`AuthToken`] needed to connect to the provider.
pub fn auth_token(&self) -> AuthToken {
self.auth_token
}

/// Blocks until the provider task completes.
// TODO: Maybe implement Future directly?
pub async fn join(self) -> Result<(), JoinError> {
self.task.await
}

Ok(())
/// Aborts the provider.
///
/// TODO: temporary, do graceful shutdown instead.
pub fn abort(&self) {
self.task.abort();
}
}

Expand Down

0 comments on commit bf087b6

Please sign in to comment.