@@ -20,13 +20,12 @@ use tracing::{debug, debug_span, warn, Instrument};
20
20
21
21
use crate :: {
22
22
api:: {
23
- self ,
24
23
blobs:: { Bitfield , WriteProgress } ,
25
- Store ,
24
+ ExportBaoResult , Store ,
26
25
} ,
27
26
hashseq:: HashSeq ,
28
27
protocol:: { GetManyRequest , GetRequest , ObserveItem , ObserveRequest , PushRequest , Request } ,
29
- provider:: events:: { ClientConnected , ClientError , ConnectionClosed , RequestTracker } ,
28
+ provider:: events:: { ClientConnected , ConnectionClosed , RequestTracker } ,
30
29
Hash ,
31
30
} ;
32
31
pub mod events;
@@ -94,7 +93,7 @@ impl StreamPair {
94
93
}
95
94
96
95
/// We are done with reading. Return a ProgressWriter that contains the read stats and connection id
97
- pub async fn into_writer (
96
+ async fn into_writer (
98
97
mut self ,
99
98
tracker : RequestTracker ,
100
99
) -> Result < ProgressWriter , ReadToEndError > {
@@ -118,7 +117,7 @@ impl StreamPair {
118
117
) )
119
118
}
120
119
121
- pub async fn into_reader (
120
+ async fn into_reader (
122
121
mut self ,
123
122
tracker : RequestTracker ,
124
123
) -> Result < ProgressReader , ClosedStream > {
@@ -141,39 +140,71 @@ impl StreamPair {
141
140
}
142
141
143
142
pub async fn get_request (
144
- & self ,
143
+ mut self ,
145
144
f : impl FnOnce ( ) -> GetRequest ,
146
- ) -> Result < RequestTracker , ClientError > {
147
- self . events
145
+ ) -> anyhow:: Result < ProgressWriter > {
146
+ let res = self
147
+ . events
148
148
. request ( f, self . connection_id , self . request_id )
149
- . await
149
+ . await ;
150
+ match res {
151
+ Err ( e) => {
152
+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
153
+ Err ( e. into ( ) )
154
+ }
155
+ Ok ( tracker) => Ok ( self . into_writer ( tracker) . await ?) ,
156
+ }
150
157
}
151
158
152
159
pub async fn get_many_request (
153
- & self ,
160
+ mut self ,
154
161
f : impl FnOnce ( ) -> GetManyRequest ,
155
- ) -> Result < RequestTracker , ClientError > {
156
- self . events
162
+ ) -> anyhow:: Result < ProgressWriter > {
163
+ let res = self
164
+ . events
157
165
. request ( f, self . connection_id , self . request_id )
158
- . await
166
+ . await ;
167
+ match res {
168
+ Err ( e) => {
169
+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
170
+ Err ( e. into ( ) )
171
+ }
172
+ Ok ( tracker) => Ok ( self . into_writer ( tracker) . await ?) ,
173
+ }
159
174
}
160
175
161
176
pub async fn push_request (
162
- & self ,
177
+ mut self ,
163
178
f : impl FnOnce ( ) -> PushRequest ,
164
- ) -> Result < RequestTracker , ClientError > {
165
- self . events
179
+ ) -> anyhow:: Result < ProgressReader > {
180
+ let res = self
181
+ . events
166
182
. request ( f, self . connection_id , self . request_id )
167
- . await
183
+ . await ;
184
+ match res {
185
+ Err ( e) => {
186
+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
187
+ Err ( e. into ( ) )
188
+ }
189
+ Ok ( tracker) => Ok ( self . into_reader ( tracker) . await ?) ,
190
+ }
168
191
}
169
192
170
193
pub async fn observe_request (
171
- & self ,
194
+ mut self ,
172
195
f : impl FnOnce ( ) -> ObserveRequest ,
173
- ) -> Result < RequestTracker , ClientError > {
174
- self . events
196
+ ) -> anyhow:: Result < ProgressWriter > {
197
+ let res = self
198
+ . events
175
199
. request ( f, self . connection_id , self . request_id )
176
- . await
200
+ . await ;
201
+ match res {
202
+ Err ( e) => {
203
+ self . writer . reset ( e. code ( ) ) . ok ( ) ;
204
+ Err ( e. into ( ) )
205
+ }
206
+ Ok ( tracker) => Ok ( self . into_writer ( tracker) . await ?) ,
207
+ }
177
208
}
178
209
179
210
fn stats ( & self ) -> TransferStats {
@@ -299,7 +330,8 @@ pub async fn handle_connection(
299
330
} )
300
331
. await
301
332
{
302
- debug ! ( "client not authorized to connect: {cause}" ) ;
333
+ connection. close ( cause. code ( ) , cause. reason ( ) ) ;
334
+ debug ! ( "closing connection: {cause}" ) ;
303
335
return ;
304
336
}
305
337
while let Ok ( context) = StreamPair :: accept ( & connection, & progress) . await {
@@ -323,35 +355,32 @@ async fn handle_stream(store: Store, mut context: StreamPair) -> anyhow::Result<
323
355
324
356
match request {
325
357
Request :: Get ( request) => {
326
- let tracker = context. get_request ( || request. clone ( ) ) . await ?;
327
- let mut writer = context . into_writer ( tracker ) . await ? ;
328
- if handle_get ( store , request , & mut writer ) . await . is_ok ( ) {
358
+ let mut writer = context. get_request ( || request. clone ( ) ) . await ?;
359
+ let res = handle_get ( store , request , & mut writer ) . await ;
360
+ if res . is_ok ( ) {
329
361
writer. transfer_completed ( ) . await ;
330
362
} else {
331
363
writer. transfer_aborted ( ) . await ;
332
364
}
333
365
}
334
366
Request :: GetMany ( request) => {
335
- let tracker = context. get_many_request ( || request. clone ( ) ) . await ?;
336
- let mut writer = context. into_writer ( tracker) . await ?;
367
+ let mut writer = context. get_many_request ( || request. clone ( ) ) . await ?;
337
368
if handle_get_many ( store, request, & mut writer) . await . is_ok ( ) {
338
369
writer. transfer_completed ( ) . await ;
339
370
} else {
340
371
writer. transfer_aborted ( ) . await ;
341
372
}
342
373
}
343
374
Request :: Observe ( request) => {
344
- let tracker = context. observe_request ( || request. clone ( ) ) . await ?;
345
- let mut writer = context. into_writer ( tracker) . await ?;
375
+ let mut writer = context. observe_request ( || request. clone ( ) ) . await ?;
346
376
if handle_observe ( store, request, & mut writer) . await . is_ok ( ) {
347
377
writer. transfer_completed ( ) . await ;
348
378
} else {
349
379
writer. transfer_aborted ( ) . await ;
350
380
}
351
381
}
352
382
Request :: Push ( request) => {
353
- let tracker = context. push_request ( || request. clone ( ) ) . await ?;
354
- let mut reader = context. into_reader ( tracker) . await ?;
383
+ let mut reader = context. push_request ( || request. clone ( ) ) . await ?;
355
384
if handle_push ( store, request, & mut reader) . await . is_ok ( ) {
356
385
reader. transfer_completed ( ) . await ;
357
386
} else {
@@ -464,11 +493,11 @@ pub(crate) async fn send_blob(
464
493
hash : Hash ,
465
494
ranges : ChunkRanges ,
466
495
writer : & mut ProgressWriter ,
467
- ) -> api :: Result < ( ) > {
468
- Ok ( store
496
+ ) -> ExportBaoResult < ( ) > {
497
+ store
469
498
. export_bao ( hash, ranges)
470
499
. write_quinn_with_progress ( & mut writer. inner , & mut writer. context , & hash, index)
471
- . await ? )
500
+ . await
472
501
}
473
502
474
503
/// Handle a single push request.
0 commit comments