Skip to content

Commit 68d284d

Browse files
committed
Make sure to send a proper reset code when resetting a connection
so the other side can know if reconnecting is OK
1 parent 9a62a58 commit 68d284d

File tree

2 files changed

+97
-47
lines changed

2 files changed

+97
-47
lines changed

src/provider.rs

Lines changed: 63 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,12 @@ use tracing::{debug, debug_span, warn, Instrument};
2020

2121
use crate::{
2222
api::{
23-
self,
2423
blobs::{Bitfield, WriteProgress},
25-
Store,
24+
ExportBaoResult, Store,
2625
},
2726
hashseq::HashSeq,
2827
protocol::{GetManyRequest, GetRequest, ObserveItem, ObserveRequest, PushRequest, Request},
29-
provider::events::{ClientConnected, ClientError, ConnectionClosed, RequestTracker},
28+
provider::events::{ClientConnected, ConnectionClosed, RequestTracker},
3029
Hash,
3130
};
3231
pub mod events;
@@ -94,7 +93,7 @@ impl StreamPair {
9493
}
9594

9695
/// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
97-
pub async fn into_writer(
96+
async fn into_writer(
9897
mut self,
9998
tracker: RequestTracker,
10099
) -> Result<ProgressWriter, ReadToEndError> {
@@ -118,7 +117,7 @@ impl StreamPair {
118117
))
119118
}
120119

121-
pub async fn into_reader(
120+
async fn into_reader(
122121
mut self,
123122
tracker: RequestTracker,
124123
) -> Result<ProgressReader, ClosedStream> {
@@ -141,39 +140,71 @@ impl StreamPair {
141140
}
142141

143142
pub async fn get_request(
144-
&self,
143+
mut self,
145144
f: impl FnOnce() -> GetRequest,
146-
) -> Result<RequestTracker, ClientError> {
147-
self.events
145+
) -> anyhow::Result<ProgressWriter> {
146+
let res = self
147+
.events
148148
.request(f, self.connection_id, self.request_id)
149-
.await
149+
.await;
150+
match res {
151+
Err(e) => {
152+
self.writer.reset(e.code()).ok();
153+
Err(e.into())
154+
}
155+
Ok(tracker) => Ok(self.into_writer(tracker).await?),
156+
}
150157
}
151158

152159
pub async fn get_many_request(
153-
&self,
160+
mut self,
154161
f: impl FnOnce() -> GetManyRequest,
155-
) -> Result<RequestTracker, ClientError> {
156-
self.events
162+
) -> anyhow::Result<ProgressWriter> {
163+
let res = self
164+
.events
157165
.request(f, self.connection_id, self.request_id)
158-
.await
166+
.await;
167+
match res {
168+
Err(e) => {
169+
self.writer.reset(e.code()).ok();
170+
Err(e.into())
171+
}
172+
Ok(tracker) => Ok(self.into_writer(tracker).await?),
173+
}
159174
}
160175

161176
pub async fn push_request(
162-
&self,
177+
mut self,
163178
f: impl FnOnce() -> PushRequest,
164-
) -> Result<RequestTracker, ClientError> {
165-
self.events
179+
) -> anyhow::Result<ProgressReader> {
180+
let res = self
181+
.events
166182
.request(f, self.connection_id, self.request_id)
167-
.await
183+
.await;
184+
match res {
185+
Err(e) => {
186+
self.writer.reset(e.code()).ok();
187+
Err(e.into())
188+
}
189+
Ok(tracker) => Ok(self.into_reader(tracker).await?),
190+
}
168191
}
169192

170193
pub async fn observe_request(
171-
&self,
194+
mut self,
172195
f: impl FnOnce() -> ObserveRequest,
173-
) -> Result<RequestTracker, ClientError> {
174-
self.events
196+
) -> anyhow::Result<ProgressWriter> {
197+
let res = self
198+
.events
175199
.request(f, self.connection_id, self.request_id)
176-
.await
200+
.await;
201+
match res {
202+
Err(e) => {
203+
self.writer.reset(e.code()).ok();
204+
Err(e.into())
205+
}
206+
Ok(tracker) => Ok(self.into_writer(tracker).await?),
207+
}
177208
}
178209

179210
fn stats(&self) -> TransferStats {
@@ -299,7 +330,8 @@ pub async fn handle_connection(
299330
})
300331
.await
301332
{
302-
debug!("client not authorized to connect: {cause}");
333+
connection.close(cause.code(), cause.reason());
334+
debug!("closing connection: {cause}");
303335
return;
304336
}
305337
while let Ok(context) = StreamPair::accept(&connection, &progress).await {
@@ -323,35 +355,32 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<
323355

324356
match request {
325357
Request::Get(request) => {
326-
let tracker = context.get_request(|| request.clone()).await?;
327-
let mut writer = context.into_writer(tracker).await?;
328-
if handle_get(store, request, &mut writer).await.is_ok() {
358+
let mut writer = context.get_request(|| request.clone()).await?;
359+
let res = handle_get(store, request, &mut writer).await;
360+
if res.is_ok() {
329361
writer.transfer_completed().await;
330362
} else {
331363
writer.transfer_aborted().await;
332364
}
333365
}
334366
Request::GetMany(request) => {
335-
let tracker = context.get_many_request(|| request.clone()).await?;
336-
let mut writer = context.into_writer(tracker).await?;
367+
let mut writer = context.get_many_request(|| request.clone()).await?;
337368
if handle_get_many(store, request, &mut writer).await.is_ok() {
338369
writer.transfer_completed().await;
339370
} else {
340371
writer.transfer_aborted().await;
341372
}
342373
}
343374
Request::Observe(request) => {
344-
let tracker = context.observe_request(|| request.clone()).await?;
345-
let mut writer = context.into_writer(tracker).await?;
375+
let mut writer = context.observe_request(|| request.clone()).await?;
346376
if handle_observe(store, request, &mut writer).await.is_ok() {
347377
writer.transfer_completed().await;
348378
} else {
349379
writer.transfer_aborted().await;
350380
}
351381
}
352382
Request::Push(request) => {
353-
let tracker = context.push_request(|| request.clone()).await?;
354-
let mut reader = context.into_reader(tracker).await?;
383+
let mut reader = context.push_request(|| request.clone()).await?;
355384
if handle_push(store, request, &mut reader).await.is_ok() {
356385
reader.transfer_completed().await;
357386
} else {
@@ -464,11 +493,11 @@ pub(crate) async fn send_blob(
464493
hash: Hash,
465494
ranges: ChunkRanges,
466495
writer: &mut ProgressWriter,
467-
) -> api::Result<()> {
468-
Ok(store
496+
) -> ExportBaoResult<()> {
497+
store
469498
.export_bao(hash, ranges)
470499
.write_quinn_with_progress(&mut writer.inner, &mut writer.context, &hash, index)
471-
.await?)
500+
.await
472501
}
473502

474503
/// Handle a single push request.

src/provider/events.rs

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ use serde::{Deserialize, Serialize};
88
use snafu::Snafu;
99

1010
use crate::{
11-
protocol::{GetManyRequest, GetRequest, ObserveRequest, PushRequest},
11+
protocol::{
12+
GetManyRequest, GetRequest, ObserveRequest, PushRequest, ERR_INTERNAL, ERR_LIMIT,
13+
ERR_PERMISSION,
14+
},
1215
provider::{events::irpc_ext::IrpcClientExt, TransferStats},
1316
Hash,
1417
};
@@ -82,6 +85,24 @@ pub enum ClientError {
8285
},
8386
}
8487

88+
impl ClientError {
89+
pub fn code(&self) -> quinn::VarInt {
90+
match self {
91+
ClientError::RateLimited => ERR_LIMIT.into(),
92+
ClientError::Permission => ERR_PERMISSION.into(),
93+
ClientError::Irpc { .. } => ERR_INTERNAL.into(),
94+
}
95+
}
96+
97+
pub fn reason(&self) -> &'static [u8] {
98+
match self {
99+
ClientError::RateLimited => b"limit",
100+
ClientError::Permission => b"permission",
101+
ClientError::Irpc { .. } => b"internal",
102+
}
103+
}
104+
}
105+
85106
impl From<AbortReason> for ClientError {
86107
fn from(value: AbortReason) -> Self {
87108
match value {
@@ -211,11 +232,14 @@ impl RequestTracker {
211232
/// Transfer for index `index` started, size `size`
212233
pub async fn transfer_started(&self, index: u64, hash: &Hash, size: u64) -> irpc::Result<()> {
213234
if let RequestUpdates::Active(tx) = &self.updates {
214-
tx.send(RequestUpdate::Started(TransferStarted {
215-
index,
216-
hash: *hash,
217-
size,
218-
}))
235+
tx.send(
236+
TransferStarted {
237+
index,
238+
hash: *hash,
239+
size,
240+
}
241+
.into(),
242+
)
219243
.await?;
220244
}
221245
Ok(())
@@ -224,8 +248,7 @@ impl RequestTracker {
224248
/// Transfer progress for the previously reported blob, end_offset is the new end offset in bytes.
225249
pub async fn transfer_progress(&mut self, len: u64, end_offset: u64) -> ClientResult {
226250
if let RequestUpdates::Active(tx) = &mut self.updates {
227-
tx.try_send(RequestUpdate::Progress(TransferProgress { end_offset }))
228-
.await?;
251+
tx.try_send(TransferProgress { end_offset }.into()).await?;
229252
}
230253
if let Some((throttle, connection_id, request_id)) = &self.throttle {
231254
throttle
@@ -242,17 +265,15 @@ impl RequestTracker {
242265
/// Transfer completed for the previously reported blob.
243266
pub async fn transfer_completed(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
244267
if let RequestUpdates::Active(tx) = &self.updates {
245-
tx.send(RequestUpdate::Completed(TransferCompleted { stats: f() }))
246-
.await?;
268+
tx.send(TransferCompleted { stats: f() }.into()).await?;
247269
}
248270
Ok(())
249271
}
250272

251273
/// Transfer aborted for the previously reported blob.
252274
pub async fn transfer_aborted(&self, f: impl Fn() -> Box<TransferStats>) -> irpc::Result<()> {
253275
if let RequestUpdates::Active(tx) = &self.updates {
254-
tx.send(RequestUpdate::Aborted(TransferAborted { stats: f() }))
255-
.await?;
276+
tx.send(TransferAborted { stats: f() }.into()).await?;
256277
}
257278
Ok(())
258279
}
@@ -583,7 +604,7 @@ mod proto {
583604
}
584605

585606
/// Stream of updates for a single request
586-
#[derive(Debug, Serialize, Deserialize)]
607+
#[derive(Debug, Serialize, Deserialize, derive_more::From)]
587608
pub enum RequestUpdate {
588609
/// Start of transfer for a blob, mandatory event
589610
Started(TransferStarted),

0 commit comments

Comments
 (0)