Skip to content

Commit a3e5ef3

Browse files
rklaehn and Frando authored
feat: Provider events refactor (#142)
## Description Refactor provider events into a proper irpc protocol. Also allow configuring for each event type if the event is sent as a notification, a proper request with answer, or not at all. ## Breaking Changes Provider events completely changed. Other than that the changes should be minimal. You can still create a BlobsProtocol by passing None. ## Notes & open questions Note: to review, best to start with looking at the limit example, then look at the docs. ## Change checklist - [ ] Self-review. - [ ] Documentation updates following the [style guide](https://rust-lang.github.io/rfcs/1574-more-api-documentation-conventions.html#appendix-a-full-conventions-text), if relevant. - [ ] Tests if relevant. - [ ] All breaking changes documented. --------- Co-authored-by: Frando <frando@unbiskant.org>
1 parent 55414b9 commit a3e5ef3

File tree

13 files changed

+1579
-678
lines changed

13 files changed

+1579
-678
lines changed

examples/limit.rs

Lines changed: 361 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,361 @@
1+
/// Example how to limit blob requests by hash and node id, and to add
2+
/// throttling or limiting the maximum number of connections.
3+
///
4+
/// Limiting is done via a fn that returns an EventSender and internally
5+
/// makes liberal use of spawn to spawn background tasks.
6+
///
7+
/// This is fine, since the tasks will terminate as soon as the [BlobsProtocol]
8+
/// instance holding the [EventSender] will be dropped. But for production
9+
/// grade code you might nevertheless put the tasks into a [tokio::task::JoinSet] or
10+
/// [n0_future::FuturesUnordered].
11+
mod common;
12+
use std::{
13+
collections::{HashMap, HashSet},
14+
path::PathBuf,
15+
sync::{
16+
atomic::{AtomicUsize, Ordering},
17+
Arc,
18+
},
19+
};
20+
21+
use anyhow::Result;
22+
use clap::Parser;
23+
use common::setup_logging;
24+
use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher};
25+
use iroh_blobs::{
26+
provider::events::{
27+
AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode,
28+
ThrottleMode,
29+
},
30+
store::mem::MemStore,
31+
ticket::BlobTicket,
32+
BlobFormat, BlobsProtocol, Hash,
33+
};
34+
use rand::thread_rng;
35+
36+
use crate::common::get_or_generate_secret_key;
37+
38+
// Command line interface for this example. Each variant demonstrates one way
// of limiting or shaping provider traffic via an EventSender.
//
// NOTE: the `///` comments on variants and fields are rendered by clap as
// `--help` text, so they are part of the CLI's observable behavior and are
// left unchanged here.
#[derive(Debug, Parser)]
#[command(version, about)]
pub enum Args {
    /// Limit requests by node id
    ByNodeId {
        /// Path for files to add.
        paths: Vec<PathBuf>,
        #[clap(long("allow"))]
        /// Nodes that are allowed to download content.
        allowed_nodes: Vec<NodeId>,
        /// Number of secrets to generate for allowed node ids.
        #[clap(long, default_value_t = 1)]
        secrets: usize,
    },
    /// Limit requests by hash, only first hash is allowed
    ByHash {
        /// Path for files to add.
        paths: Vec<PathBuf>,
    },
    /// Throttle requests
    Throttle {
        /// Path for files to add.
        paths: Vec<PathBuf>,
        /// Delay in milliseconds after sending a chunk group of 16 KiB.
        #[clap(long, default_value = "100")]
        delay_ms: u64,
    },
    /// Limit maximum number of connections.
    MaxConnections {
        /// Path for files to add.
        paths: Vec<PathBuf>,
        /// Maximum number of concurrent get requests.
        #[clap(long, default_value = "1")]
        max_connections: usize,
    },
    /// Get a blob. Just for completeness sake.
    Get {
        /// Ticket for the blob to download
        ticket: BlobTicket,
    },
}
79+
80+
fn limit_by_node_id(allowed_nodes: HashSet<NodeId>) -> EventSender {
81+
let mask = EventMask {
82+
// We want a request for each incoming connection so we can accept
83+
// or reject them. We don't need any other events.
84+
connected: ConnectMode::Intercept,
85+
..EventMask::DEFAULT
86+
};
87+
let (tx, mut rx) = EventSender::channel(32, mask);
88+
n0_future::task::spawn(async move {
89+
while let Some(msg) = rx.recv().await {
90+
if let ProviderMessage::ClientConnected(msg) = msg {
91+
let node_id = msg.node_id;
92+
let res = if allowed_nodes.contains(&node_id) {
93+
println!("Client connected: {node_id}");
94+
Ok(())
95+
} else {
96+
println!("Client rejected: {node_id}");
97+
Err(AbortReason::Permission)
98+
};
99+
msg.tx.send(res).await.ok();
100+
}
101+
}
102+
});
103+
tx
104+
}
105+
106+
fn limit_by_hash(allowed_hashes: HashSet<Hash>) -> EventSender {
107+
let mask = EventMask {
108+
// We want to get a request for each get request that we can answer
109+
// with OK or not OK depending on the hash. We do not want detailed
110+
// events once it has been decided to handle a request.
111+
get: RequestMode::Intercept,
112+
..EventMask::DEFAULT
113+
};
114+
let (tx, mut rx) = EventSender::channel(32, mask);
115+
n0_future::task::spawn(async move {
116+
while let Some(msg) = rx.recv().await {
117+
if let ProviderMessage::GetRequestReceived(msg) = msg {
118+
let res = if !msg.request.ranges.is_blob() {
119+
println!("HashSeq request not allowed");
120+
Err(AbortReason::Permission)
121+
} else if !allowed_hashes.contains(&msg.request.hash) {
122+
println!("Request for hash {} not allowed", msg.request.hash);
123+
Err(AbortReason::Permission)
124+
} else {
125+
println!("Request for hash {} allowed", msg.request.hash);
126+
Ok(())
127+
};
128+
msg.tx.send(res).await.ok();
129+
}
130+
}
131+
});
132+
tx
133+
}
134+
135+
fn throttle(delay_ms: u64) -> EventSender {
136+
let mask = EventMask {
137+
// We want to get requests for each sent user data blob, so we can add a delay.
138+
// Other than that, we don't need any events.
139+
throttle: ThrottleMode::Intercept,
140+
..EventMask::DEFAULT
141+
};
142+
let (tx, mut rx) = EventSender::channel(32, mask);
143+
n0_future::task::spawn(async move {
144+
while let Some(msg) = rx.recv().await {
145+
if let ProviderMessage::Throttle(msg) = msg {
146+
n0_future::task::spawn(async move {
147+
println!(
148+
"Throttling {} {}, {}ms",
149+
msg.connection_id, msg.request_id, delay_ms
150+
);
151+
// we could compute the delay from the size of the data to have a fixed rate.
152+
// but the size is almost always 16 KiB (16 chunks).
153+
tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await;
154+
msg.tx.send(Ok(())).await.ok();
155+
});
156+
}
157+
}
158+
});
159+
tx
160+
}
161+
162+
fn limit_max_connections(max_connections: usize) -> EventSender {
163+
#[derive(Default, Debug, Clone)]
164+
struct ConnectionCounter(Arc<(AtomicUsize, usize)>);
165+
166+
impl ConnectionCounter {
167+
fn new(max: usize) -> Self {
168+
Self(Arc::new((Default::default(), max)))
169+
}
170+
171+
fn inc(&self) -> Result<usize, usize> {
172+
let (c, max) = &*self.0;
173+
c.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
174+
if n >= *max {
175+
None
176+
} else {
177+
Some(n + 1)
178+
}
179+
})
180+
}
181+
182+
fn dec(&self) {
183+
let (c, _) = &*self.0;
184+
c.fetch_sub(1, Ordering::SeqCst);
185+
}
186+
}
187+
188+
let mask = EventMask {
189+
// For each get request, we want to get a request so we can decide
190+
// based on the current connection count if we want to accept or reject.
191+
// We also want detailed logging of events for the get request, so we can
192+
// detect when the request is finished one way or another.
193+
connected: ConnectMode::Intercept,
194+
..EventMask::DEFAULT
195+
};
196+
let (tx, mut rx) = EventSender::channel(32, mask);
197+
n0_future::task::spawn(async move {
198+
let requests = ConnectionCounter::new(max_connections);
199+
while let Some(msg) = rx.recv().await {
200+
match msg {
201+
ProviderMessage::ClientConnected(msg) => {
202+
let connection_id = msg.connection_id;
203+
let node_id = msg.node_id;
204+
let res = if let Ok(n) = requests.inc() {
205+
println!("Accepting connection {n}, node_id {node_id}, connection_id {connection_id}");
206+
Ok(())
207+
} else {
208+
Err(AbortReason::RateLimited)
209+
};
210+
msg.tx.send(res).await.ok();
211+
}
212+
ProviderMessage::ConnectionClosed(msg) => {
213+
requests.dec();
214+
println!("Connection closed, connection_id {}", msg.connection_id,);
215+
}
216+
_ => {}
217+
}
218+
}
219+
});
220+
tx
221+
}
222+
223+
/// Entry point: parses the CLI, then either fetches a blob (`Get`) or adds
/// the given files to an in-memory store and serves them with the selected
/// limiting strategy until Ctrl-C.
#[tokio::main]
async fn main() -> Result<()> {
    setup_logging();
    let args = Args::parse();
    // NOTE(review): this endpoint is only used by the `Get` branch below; the
    // serving branches create their own endpoint inside `setup()` using the
    // same secret key. Confirm that binding two endpoints with one key is
    // intended, or move this construction into the `Get` arm.
    let secret = get_or_generate_secret_key()?;
    let endpoint = iroh::Endpoint::builder()
        .secret_key(secret)
        .discovery_n0()
        .bind()
        .await?;
    match args {
        Args::Get { ticket } => {
            // Dial the provider from the ticket and fetch the blob.
            let connection = endpoint
                .connect(ticket.node_addr().clone(), iroh_blobs::ALPN)
                .await?;
            let (data, stats) = iroh_blobs::get::request::get_blob(connection, ticket.hash())
                .bytes_and_stats()
                .await?;
            println!("Downloaded {} bytes", data.len());
            println!("Stats: {stats:?}");
        }
        Args::ByNodeId {
            paths,
            allowed_nodes,
            secrets,
        } => {
            let mut allowed_nodes = allowed_nodes.into_iter().collect::<HashSet<_>>();
            // Optionally mint fresh secret keys whose public halves are
            // whitelisted, so a client can connect using IROH_SECRET.
            if secrets > 0 {
                println!("Generating {secrets} new secret keys for allowed nodes:");
                let mut rand = thread_rng();
                for _ in 0..secrets {
                    let secret = SecretKey::generate(&mut rand);
                    let public = secret.public();
                    allowed_nodes.insert(public);
                    println!("IROH_SECRET={}", hex::encode(secret.to_bytes()));
                }
            }

            let store = MemStore::new();
            let hashes = add_paths(&store, paths).await?;
            let events = limit_by_node_id(allowed_nodes.clone());
            let (router, addr) = setup(store, events).await?;

            // Print one ticket per added file.
            for (path, hash) in hashes {
                let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
                println!("{}: {ticket}", path.display());
            }
            println!();
            println!("Node id: {}\n", router.endpoint().node_id());
            for id in &allowed_nodes {
                println!("Allowed node: {id}");
            }

            // Serve until interrupted, then shut down cleanly.
            tokio::signal::ctrl_c().await?;
            router.shutdown().await?;
        }
        Args::ByHash { paths } => {
            let store = MemStore::new();

            // Only the first added file's hash is whitelisted.
            let mut hashes = HashMap::new();
            let mut allowed_hashes = HashSet::new();
            for (i, path) in paths.into_iter().enumerate() {
                let tag = store.add_path(&path).await?;
                hashes.insert(path, tag.hash);
                if i == 0 {
                    allowed_hashes.insert(tag.hash);
                }
            }

            let events = limit_by_hash(allowed_hashes.clone());
            let (router, addr) = setup(store, events).await?;

            // Print tickets, marking which ones the limiter will serve.
            for (path, hash) in hashes.iter() {
                let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw);
                let permitted = if allowed_hashes.contains(hash) {
                    "allowed"
                } else {
                    "forbidden"
                };
                println!("{}: {ticket} ({permitted})", path.display());
            }
            tokio::signal::ctrl_c().await?;
            router.shutdown().await?;
        }
        Args::Throttle { paths, delay_ms } => {
            let store = MemStore::new();
            let hashes = add_paths(&store, paths).await?;
            let events = throttle(delay_ms);
            let (router, addr) = setup(store, events).await?;
            for (path, hash) in hashes {
                let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
                println!("{}: {ticket}", path.display());
            }
            tokio::signal::ctrl_c().await?;
            router.shutdown().await?;
        }
        Args::MaxConnections {
            paths,
            max_connections,
        } => {
            let store = MemStore::new();
            let hashes = add_paths(&store, paths).await?;
            let events = limit_max_connections(max_connections);
            let (router, addr) = setup(store, events).await?;
            for (path, hash) in hashes {
                let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
                println!("{}: {ticket}", path.display());
            }
            tokio::signal::ctrl_c().await?;
            router.shutdown().await?;
        }
    }
    Ok(())
}
337+
338+
/// Imports each of `paths` into `store` and returns the path → content hash
/// mapping for the added blobs.
async fn add_paths(store: &MemStore, paths: Vec<PathBuf>) -> Result<HashMap<PathBuf, Hash>> {
    let mut path_to_hash = HashMap::with_capacity(paths.len());
    for path in paths {
        let tag = store.add_path(&path).await?;
        path_to_hash.insert(path, tag.hash);
    }
    Ok(path_to_hash)
}
346+
347+
async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> {
348+
let secret = get_or_generate_secret_key()?;
349+
let endpoint = iroh::Endpoint::builder()
350+
.discovery_n0()
351+
.secret_key(secret)
352+
.bind()
353+
.await?;
354+
let _ = endpoint.home_relay().initialized().await;
355+
let addr = endpoint.node_addr().initialized().await;
356+
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
357+
let router = Router::builder(endpoint)
358+
.accept(iroh_blobs::ALPN, blobs)
359+
.spawn();
360+
Ok((router, addr))
361+
}

0 commit comments

Comments
 (0)