Skip to content

Commit b26b408

Browse files
committed
refactor: make limits example more DRY
1 parent 33333a9 commit b26b408

File tree

2 files changed

+94
-98
lines changed

2 files changed

+94
-98
lines changed

examples/limit.rs

Lines changed: 86 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -18,17 +18,18 @@ use std::{
1818
},
1919
};
2020

21+
use anyhow::Result;
2122
use clap::Parser;
2223
use common::setup_logging;
23-
use iroh::{NodeId, SecretKey, Watcher};
24+
use iroh::{protocol::Router, NodeAddr, NodeId, SecretKey, Watcher};
2425
use iroh_blobs::{
2526
provider::events::{
2627
AbortReason, ConnectMode, EventMask, EventSender, ProviderMessage, RequestMode,
2728
ThrottleMode,
2829
},
2930
store::mem::MemStore,
3031
ticket::BlobTicket,
31-
BlobsProtocol, Hash,
32+
BlobFormat, BlobsProtocol, Hash,
3233
};
3334
use rand::thread_rng;
3435

@@ -77,7 +78,13 @@ pub enum Args {
7778
}
7879

7980
fn limit_by_node_id(allowed_nodes: HashSet<NodeId>) -> EventSender {
80-
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
81+
let mask = EventMask {
82+
// We want a request for each incoming connection so we can accept
83+
// or reject them. We don't need any other events.
84+
connected: ConnectMode::Request,
85+
..EventMask::DEFAULT
86+
};
87+
let (tx, mut rx) = EventSender::channel(32, mask);
8188
n0_future::task::spawn(async move {
8289
while let Some(msg) = rx.recv().await {
8390
if let ProviderMessage::ClientConnected(msg) = msg {
@@ -93,19 +100,18 @@ fn limit_by_node_id(allowed_nodes: HashSet<NodeId>) -> EventSender {
93100
}
94101
}
95102
});
96-
EventSender::new(
97-
tx,
98-
EventMask {
99-
// We want a request for each incoming connection so we can accept
100-
// or reject them. We don't need any other events.
101-
connected: ConnectMode::Request,
102-
..EventMask::DEFAULT
103-
},
104-
)
103+
tx
105104
}
106105

107106
fn limit_by_hash(allowed_hashes: HashSet<Hash>) -> EventSender {
108-
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
107+
let mask = EventMask {
108+
// We want to get a request for each get request that we can answer
109+
// with OK or not OK depending on the hash. We do not want detailed
110+
// events once it has been decided to handle a request.
111+
get: RequestMode::Request,
112+
..EventMask::DEFAULT
113+
};
114+
let (tx, mut rx) = EventSender::channel(32, mask);
109115
n0_future::task::spawn(async move {
110116
while let Some(msg) = rx.recv().await {
111117
if let ProviderMessage::GetRequestReceived(msg) = msg {
@@ -123,20 +129,17 @@ fn limit_by_hash(allowed_hashes: HashSet<Hash>) -> EventSender {
123129
}
124130
}
125131
});
126-
EventSender::new(
127-
tx,
128-
EventMask {
129-
// We want to get a request for each get request that we can answer
130-
// with OK or not OK depending on the hash. We do not want detailed
131-
// events once it has been decided to handle a request.
132-
get: RequestMode::Request,
133-
..EventMask::DEFAULT
134-
},
135-
)
132+
tx
136133
}
137134

138135
fn throttle(delay_ms: u64) -> EventSender {
139-
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
136+
let mask = EventMask {
137+
// We want to get requests for each sent user data blob, so we can add a delay.
138+
// Other than that, we don't need any events.
139+
throttle: ThrottleMode::Throttle,
140+
..EventMask::DEFAULT
141+
};
142+
let (tx, mut rx) = EventSender::channel(32, mask);
140143
n0_future::task::spawn(async move {
141144
while let Some(msg) = rx.recv().await {
142145
if let ProviderMessage::Throttle(msg) = msg {
@@ -153,15 +156,7 @@ fn throttle(delay_ms: u64) -> EventSender {
153156
}
154157
}
155158
});
156-
EventSender::new(
157-
tx,
158-
EventMask {
159-
// We want to get requests for each sent user data blob, so we can add a delay.
160-
// Other than that, we don't need any events.
161-
throttle: ThrottleMode::Throttle,
162-
..EventMask::DEFAULT
163-
},
164-
)
159+
tx
165160
}
166161

167162
fn limit_max_connections(max_connections: usize) -> EventSender {
@@ -190,7 +185,15 @@ fn limit_max_connections(max_connections: usize) -> EventSender {
190185
}
191186
}
192187

193-
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
188+
let mask = EventMask {
189+
// For each get request, we want to get a request so we can decide
190+
// based on the current connection count if we want to accept or reject.
191+
// We also want detailed logging of events for the get request, so we can
192+
// detect when the request is finished one way or another.
193+
get: RequestMode::RequestLog,
194+
..EventMask::DEFAULT
195+
};
196+
let (tx, mut rx) = EventSender::channel(32, mask);
194197
n0_future::task::spawn(async move {
195198
let requests = ConnectionCounter::new(max_connections);
196199
while let Some(msg) = rx.recv().await {
@@ -224,21 +227,11 @@ fn limit_max_connections(max_connections: usize) -> EventSender {
224227
}
225228
}
226229
});
227-
EventSender::new(
228-
tx,
229-
EventMask {
230-
// For each get request, we want to get a request so we can decide
231-
// based on the current connection count if we want to accept or reject.
232-
// We also want detailed logging of events for the get request, so we can
233-
// detect when the request is finished one way or another.
234-
get: RequestMode::RequestLog,
235-
..EventMask::DEFAULT
236-
},
237-
)
230+
tx
238231
}
239232

240233
#[tokio::main]
241-
async fn main() -> anyhow::Result<()> {
234+
async fn main() -> Result<()> {
242235
setup_logging();
243236
let args = Args::parse();
244237
match args {
@@ -274,35 +267,28 @@ async fn main() -> anyhow::Result<()> {
274267
println!("IROH_SECRET={}", hex::encode(secret.to_bytes()));
275268
}
276269
}
277-
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
270+
278271
let store = MemStore::new();
279-
let mut hashes = HashMap::new();
280-
for path in paths {
281-
let tag = store.add_path(&path).await?;
282-
hashes.insert(path, tag.hash);
283-
}
284-
let _ = endpoint.home_relay().initialized().await;
285-
let addr = endpoint.node_addr().initialized().await;
272+
let hashes = add_paths(&store, paths).await?;
286273
let events = limit_by_node_id(allowed_nodes.clone());
287-
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
288-
let router = iroh::protocol::Router::builder(endpoint)
289-
.accept(iroh_blobs::ALPN, blobs)
290-
.spawn();
274+
let (router, addr) = setup(MemStore::new(), events).await?;
275+
276+
for (path, hash) in hashes {
277+
let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
278+
println!("{}: {ticket}", path.display());
279+
}
280+
println!();
291281
println!("Node id: {}\n", router.endpoint().node_id());
292282
for id in &allowed_nodes {
293283
println!("Allowed node: {id}");
294284
}
295-
println!();
296-
for (path, hash) in &hashes {
297-
let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw);
298-
println!("{}: {ticket}", path.display());
299-
}
285+
300286
tokio::signal::ctrl_c().await?;
301287
router.shutdown().await?;
302288
}
303289
Args::ByHash { paths } => {
304-
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
305290
let store = MemStore::new();
291+
306292
let mut hashes = HashMap::new();
307293
let mut allowed_hashes = HashSet::new();
308294
for (i, path) in paths.into_iter().enumerate() {
@@ -312,38 +298,25 @@ async fn main() -> anyhow::Result<()> {
312298
allowed_hashes.insert(tag.hash);
313299
}
314300
}
315-
let _ = endpoint.home_relay().initialized().await;
316-
let addr = endpoint.node_addr().initialized().await;
317-
let events = limit_by_hash(allowed_hashes.clone());
318-
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
319-
let router = iroh::protocol::Router::builder(endpoint)
320-
.accept(iroh_blobs::ALPN, blobs)
321-
.spawn();
301+
302+
let events = limit_by_hash(allowed_hashes);
303+
let (router, addr) = setup(MemStore::new(), events).await?;
304+
322305
for (i, (path, hash)) in hashes.iter().enumerate() {
323-
let ticket = BlobTicket::new(addr.clone(), *hash, iroh_blobs::BlobFormat::Raw);
306+
let ticket = BlobTicket::new(addr.clone(), *hash, BlobFormat::Raw);
324307
let permitted = if i == 0 { "" } else { "limited" };
325308
println!("{}: {ticket} ({permitted})", path.display());
326309
}
327310
tokio::signal::ctrl_c().await?;
328311
router.shutdown().await?;
329312
}
330313
Args::Throttle { paths, delay_ms } => {
331-
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
332314
let store = MemStore::new();
333-
let mut hashes = HashMap::new();
334-
for path in paths {
335-
let tag = store.add_path(&path).await?;
336-
hashes.insert(path, tag.hash);
337-
}
338-
let _ = endpoint.home_relay().initialized().await;
339-
let addr = endpoint.node_addr().initialized().await;
315+
let hashes = add_paths(&store, paths).await?;
340316
let events = throttle(delay_ms);
341-
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
342-
let router = iroh::protocol::Router::builder(endpoint)
343-
.accept(iroh_blobs::ALPN, blobs)
344-
.spawn();
317+
let (router, addr) = setup(MemStore::new(), events).await?;
345318
for (path, hash) in hashes {
346-
let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw);
319+
let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
347320
println!("{}: {ticket}", path.display());
348321
}
349322
tokio::signal::ctrl_c().await?;
@@ -353,22 +326,12 @@ async fn main() -> anyhow::Result<()> {
353326
paths,
354327
max_connections,
355328
} => {
356-
let endpoint = iroh::Endpoint::builder().discovery_n0().bind().await?;
357329
let store = MemStore::new();
358-
let mut hashes = HashMap::new();
359-
for path in paths {
360-
let tag = store.add_path(&path).await?;
361-
hashes.insert(path, tag.hash);
362-
}
363-
let _ = endpoint.home_relay().initialized().await;
364-
let addr = endpoint.node_addr().initialized().await;
330+
let hashes = add_paths(&store, paths).await?;
365331
let events = limit_max_connections(max_connections);
366-
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
367-
let router = iroh::protocol::Router::builder(endpoint)
368-
.accept(iroh_blobs::ALPN, blobs)
369-
.spawn();
332+
let (router, addr) = setup(MemStore::new(), events).await?;
370333
for (path, hash) in hashes {
371-
let ticket = BlobTicket::new(addr.clone(), hash, iroh_blobs::BlobFormat::Raw);
334+
let ticket = BlobTicket::new(addr.clone(), hash, BlobFormat::Raw);
372335
println!("{}: {ticket}", path.display());
373336
}
374337
tokio::signal::ctrl_c().await?;
@@ -377,3 +340,28 @@ async fn main() -> anyhow::Result<()> {
377340
}
378341
Ok(())
379342
}
343+
344+
async fn add_paths(store: &MemStore, paths: Vec<PathBuf>) -> Result<HashMap<PathBuf, Hash>> {
345+
let mut hashes = HashMap::new();
346+
for path in paths {
347+
let tag = store.add_path(&path).await?;
348+
hashes.insert(path, tag.hash);
349+
}
350+
Ok(hashes)
351+
}
352+
353+
async fn setup(store: MemStore, events: EventSender) -> Result<(Router, NodeAddr)> {
354+
let secret = get_or_generate_secret_key()?;
355+
let endpoint = iroh::Endpoint::builder()
356+
.discovery_n0()
357+
.secret_key(secret)
358+
.bind()
359+
.await?;
360+
let _ = endpoint.home_relay().initialized().await;
361+
let addr = endpoint.node_addr().initialized().await;
362+
let blobs = BlobsProtocol::new(&store, endpoint.clone(), Some(events));
363+
let router = Router::builder(endpoint)
364+
.accept(iroh_blobs::ALPN, blobs)
365+
.spawn();
366+
Ok((router, addr))
367+
}

src/provider/events.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,14 @@ impl EventSender {
276276
}
277277
}
278278

279+
pub fn channel(
280+
capacity: usize,
281+
mask: EventMask,
282+
) -> (Self, tokio::sync::mpsc::Receiver<ProviderMessage>) {
283+
let (tx, rx) = tokio::sync::mpsc::channel(capacity);
284+
(Self::new(tx, mask), rx)
285+
}
286+
279287
/// Log request events at trace level.
280288
pub fn tracing(&self, mask: EventMask) -> Self {
281289
use tracing::trace;

0 commit comments

Comments
 (0)