Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
361 changes: 361 additions & 0 deletions examples/limit.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,361 @@
/// Example how to limit blob requests by hash and node id, and to add
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is very cool overall!

/// throttling or limiting the maximum number of connections.
///
/// Limiting is done via a fn that returns an EventSender and internally
/// makes liberal use of spawn to spawn background tasks.
///
/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol]
/// instance holding the [EventSender] will be dropped. But for production
/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or
/// [n0_future::FuturesUnordered].
mod common;
use std::{
collections::{HashMap, HashSet},
path::PathBuf,
sync::{
atomic::{AtomicUsize, Ordering},
Arc,
},
};

use anyhow::Result;
use clap::Parser;
use common::setup_logging;
use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher};
use iroh_blobs::{
provider::events::{
AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode,
ThrottleMode,
},
store::mem::MemStore,
ticket::BlobTicket,
BlobFormat, BlobsProtocol, Hash,
};
use rand::thread_rng;

use crate::common::get_or_generate_secret_key;

#[derive(Debug, Parser)]
#[command(version, about)]
pub enum Args {
/// Limit requests by node id
ByNodeId {
/// Path for files to add.
paths: Vec<PathBuf>,
#[clap(long("allow"))]
/// Nodes that are allowed to download content.
allowed_nodes: Vec<NodeId>,
/// Number of secrets to generate for allowed node ids.
#[clap(long, default_value_t = 1)]
secrets: usize,
},
/// Limit requests by hash, only first hash is allowed
ByHash {
/// Path for files to add.
paths: Vec<PathBuf>,
},
/// Throttle requests
Throttle {
/// Path for files to add.
paths: Vec<PathBuf>,
/// Delay in milliseconds after sending a chunk group of 16 KiB.
#[clap(long, default_value = "100")]
delay_ms: u64,
},
/// Limit maximum number of connections.
MaxConnections {
/// Path for files to add.
paths: Vec<PathBuf>,
/// Maximum number of concurrent get requests.
#[clap(long, default_value = "1")]
max_connections: usize,
},
/// Get a blob. Just for completeness sake.
Get {
/// Ticket for the blob to download
ticket: BlobTicket,
},
}

fn limit_by_node_id(allowed_nodes: HashSet<NodeId>) -> EventSender {
let mask = EventMask {
// We want a request for each incoming connection so we can accept
// or reject them. We don't need any other events.
connected: ConnectMode::Intercept,
..EventMask::DEFAULT
};
let (tx, mut rx) = EventSender::channel(32, mask);
n0_future::task::spawn(async move {
while let Some(msg) = rx.recv().await {
if let ProviderMessage::ClientConnected(msg) = msg {
let node_id = msg.node_id;
let res = if allowed_nodes.contains(&node_id) {
println!("Client connected: {node_id}");
Ok(())
} else {
println!("Client rejected: {node_id}");
Err(AbortReason::Permission)
};
msg.tx.send(res).await.ok();
}
}
});
tx
}

fn limit_by_hash(allowed_hashes: HashSet<Hash>) -> EventSender {
let mask = EventMask {
// We want to get a request for each get request that we can answer
// with OK or not OK depending on the hash. We do not want detailed
// events once it has been decided to handle a request.
get: RequestMode::Intercept,
..EventMask::DEFAULT
};
let (tx, mut rx) = EventSender::channel(32, mask);
n0_future::task::spawn(async move {
while let Some(msg) = rx.recv().await {
if let ProviderMessage::GetRequestReceived(msg) = msg {
let res = if !msg.request.ranges.is_blob() {
println!("HashSeq request not allowed");
Err(AbortReason::Permission)
} else if !allowed_hashes.contains(&msg.request.hash) {
println!("Request for hash {} not allowed", msg.request.hash);
Err(AbortReason::Permission)
} else {
println!("Request for hash {} allowed", msg.request.hash);
Ok(())
};
msg.tx.send(res).await.ok();
}
}
});
tx
}

fn throttle(delay_ms: u64) -> EventSender {
let mask = EventMask {
// We want to get requests for each sent user data blob, so we can add a delay.
// Other than that, we don't need any events.
throttle: ThrottleMode::Intercept,
..EventMask::DEFAULT
};
let (tx, mut rx) = EventSender::channel(32, mask);
n0_future::task::spawn(async move {
while let Some(msg) = rx.recv().await {
if let ProviderMessage::Throttle(msg) = msg {
n0_future::task::spawn(async move {
println!(
"Throttling {} {}, {}ms",
msg.connection_id, msg.request_id, delay_ms
);
// we could compute the delay from the size of the data to have a fixed rate.
// but the size is almost always 16 KiB (16 chunks).
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
msg.tx.send(Ok(())).await.ok();
});
}
}
});
tx
}

fn limit_max_connections(max_connections: usize) -> EventSender {
#[derive(Default, Debug, Clone)]
struct ConnectionCounter(Arc<(AtomicUsize, usize)>);

impl ConnectionCounter {
fn new(max: usize) -> Self {
Self(Arc::new((Default::default(), max)))
}

fn inc(&self) -> Result<usize, usize> {
let (c, max) = &*self.0;
c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
if n >= *max {
None
} else {
Some(n + 1)
}
})
}

fn dec(&self) {
let (c, _) = &*self.0;
c.fetch_sub(1, Ordering::SeqCst);
}
}

let mask = EventMask {
// For each get request, we want to get a request so we can decide
// based on the current connection count if we want to accept or reject.
// We also want detailed logging of events for the get request, so we can
// detect when the request is finished one way or another.
connected: ConnectMode::Intercept,
..EventMask::DEFAULT
};
let (tx, mut rx) = EventSender::channel(32, mask);
n0_future::task::spawn(async move {
let requests = ConnectionCounter::new(max_connections);
while let Some(msg) = rx.recv().await {
match msg {
ProviderMessage::ClientConnected(msg) => {
let connection_id = msg.connection_id;
let node_id = msg.node_id;
let res = if let Ok(n) = requests.inc() {
println!("Accepting connection {n}, node_id {node_id}, connection_id {connection_id}");
Ok(())
} else {
Err(AbortReason::RateLimited)
};
msg.tx.send(res).await.ok();
}
ProviderMessage::ConnectionClosed(msg) => {
requests.dec();
println!("Connection closed, connection_id {}", msg.connection_id,);
}
_ => {}
}
}
});
tx
}

#[tokio::main]
async fn main() -> Result<()> {
setup_logging();
let args = Args::parse();
let secret = get_or_generate_secret_key()?;
let endpoint = iroh::Endpoint::builder()
.secret_key(secret)
.discovery_n0()
.bind()
.await?;
match args {
Args::Get { ticket } => {
let connection = endpoint
.connect(ticket.node_addr().clone(), iroh_blobs::ALPN)
.await?;
let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash())
.bytes_and_stats()
.await?;
println!("Downloaded {} bytes", data.len());
println!("Stats: {stats:?}");
}
Args::ByNodeId {
paths,
allowed_nodes,
secrets,
} => {
let mut allowed_nodes = allowed_nodes.into_iter().collect::<HashSet<_>>();
if secrets > 0 {
println!("Generating {secrets} new secret keys for allowed nodes:");
let mut rand = thread_rng();
for _ in 0..secrets {
let secret = SecretKey::generate(&mut rand);
let public = secret.public();
allowed_nodes.insert(public);
println!("IROH_SECRET={}", hex::encode(secret.to_bytes()));
}
}

let store = MemStore::new();
let hashes = add_paths(&store, paths).await?;
let events = limit_by_node_id(allowed_nodes.clone());
let (router, addr) = setup(store, events).await?;

for (path, hash) in hashes {
let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
println!("{}: {ticket}", path.display());
}
println!();
println!("Node id: {}\n", router.endpoint().node_id());
for id in &allowed_nodes {
println!("Allowed node: {id}");
}

tokio::signal::ctrl_c().await?;
router.shutdown().await?;
}
Args::ByHash { paths } => {
let store = MemStore::new();

let mut hashes = HashMap::new();
let mut allowed_hashes = HashSet::new();
for (i, path) in paths.into_iter().enumerate() {
let tag = store.add_path(&path).await?;
hashes.insert(path, tag.hash);
if i == 0 {
allowed_hashes.insert(tag.hash);
}
}

let events = limit_by_hash(allowed_hashes.clone());
let (router, addr) = setup(store, events).await?;

for (path, hash) in hashes.iter() {
let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw);
let permitted = if allowed_hashes.contains(hash) {
"allowed"
} else {
"forbidden"
};
println!("{}: {ticket} ({permitted})", path.display());
}
tokio::signal::ctrl_c().await?;
router.shutdown().await?;
}
Args::Throttle { paths, delay_ms } => {
let store = MemStore::new();
let hashes = add_paths(&store, paths).await?;
let events = throttle(delay_ms);
let (router, addr) = setup(store, events).await?;
for (path, hash) in hashes {
let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
println!("{}: {ticket}", path.display());
}
tokio::signal::ctrl_c().await?;
router.shutdown().await?;
}
Args::MaxConnections {
paths,
max_connections,
} => {
let store = MemStore::new();
let hashes = add_paths(&store, paths).await?;
let events = limit_max_connections(max_connections);
let (router, addr) = setup(store, events).await?;
for (path, hash) in hashes {
let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
println!("{}: {ticket}", path.display());
}
tokio::signal::ctrl_c().await?;
router.shutdown().await?;
}
}
Ok(())
}

async fn add_paths(store: &MemStore, paths: Vec<PathBuf>) -> Result<HashMap<PathBuf, Hash>> {
let mut hashes = HashMap::new();
for path in paths {
let tag = store.add_path(&path).await?;
hashes.insert(path, tag.hash);
}
Ok(hashes)
}

async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> {
let secret = get_or_generate_secret_key()?;
let endpoint = iroh::Endpoint::builder()
.discovery_n0()
.secret_key(secret)
.bind()
.await?;
let _ = endpoint.home_relay().initialized().await;
let addr = endpoint.node_addr().initialized().await;
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
let router = Router::builder(endpoint)
.accept(iroh_blobs::ALPN, blobs)
.spawn();
Ok((router, addr))
}
Loading
Loading