fix: ensure we emit TransferAborted event if anything goes wrong during the transfer (#150)

This refactor splits the `handle_stream` function up into smaller
(logical) functions. If these functions error, we can emit a
`TransferAborted` event before returning the error.

This has the added benefit of making the `handle_stream` function a bit
easier to take in. The extracted functions contain no code changes; the
only differences are how errors are handled and when we return.
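
In outline, every fallible step now follows the same shape. A minimal, self-contained sketch of the pattern (illustrative only; `try_step` and this pared-down `Event` are invented for the example, not code from this commit):

use tokio::sync::broadcast;

#[derive(Clone, Debug)]
enum Event {
    TransferAborted { connection_id: u64, request_id: Option<u64> },
}

// Illustrative only: run one fallible step of a transfer and, on error,
// notify subscribers before propagating the error to the caller.
async fn try_step<T>(
    step: impl std::future::Future<Output = anyhow::Result<T>>,
    events: &broadcast::Sender<Event>,
    connection_id: u64,
    request_id: Option<u64>,
) -> anyhow::Result<T> {
    match step.await {
        Ok(value) => Ok(value),
        Err(err) => {
            // A failed send only means there are no subscribers right now.
            let _ = events.send(Event::TransferAborted { connection_id, request_id });
            Err(err)
        }
    }
}

The commit writes the match out at each call site instead of using a combinator like this, which keeps every error path explicit in `handle_stream`.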
ramfox committed Feb 17, 2023
1 parent 153073a commit 19e2b05
Showing 3 changed files with 185 additions and 87 deletions.
11 changes: 9 additions & 2 deletions src/lib.rs
@@ -211,7 +211,7 @@ mod tests {
let mut events = Vec::new();
while let Ok(event) = provider_events.recv().await {
match event {
Event::TransferCompleted { .. } => {
Event::TransferCompleted { .. } | Event::TransferAborted { .. } => {
events.push(event);
break;
}
@@ -265,11 +265,18 @@ mod tests {
provider.shutdown();
provider.await?;

assert_eq!(events.len(), 3);
assert_events(events);

Ok(())
}

fn assert_events(events: Vec<Event>) {
assert_eq!(events.len(), 3);
assert!(matches!(events[0], Event::ClientConnected { .. }));
assert!(matches!(events[1], Event::RequestReceived { .. }));
assert!(matches!(events[2], Event::TransferCompleted { .. }));
}

fn setup_logging() {
tracing_subscriber::registry()
.with(tracing_subscriber::fmt::layer().with_writer(std::io::stderr))
2 changes: 1 addition & 1 deletion src/progress.rs
@@ -80,7 +80,7 @@ impl ProgressEmitter {
///
/// This exists so it can be Arc'd into [`ProgressEmitter`] and we can easily have multiple
/// `Send + Sync` copies of it. This is used by the
/// [`ProgressEmitter::ProgressAsyncReader`] to update the progress without intertwining
/// [`ProgressAsyncReader`] to update the progress without intertwining
/// lifetimes.
#[derive(Debug)]
struct InnerProgressEmitter {
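
The split described in the comment above is the usual pattern of a cheap, cloneable handle around `Arc`-shared state. Roughly (type and field names here are illustrative, not the actual definitions):

use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;

// Illustrative shape only: the outer type is a cheap handle, the inner
// type owns the shared state, and every clone is `Send + Sync`.
#[derive(Debug, Clone)]
struct Emitter {
    inner: Arc<EmitterState>,
}

#[derive(Debug)]
struct EmitterState {
    current: AtomicU64,
}

impl Emitter {
    fn inc(&self, n: u64) {
        // Clones handed to other tasks (e.g. a progress-reporting reader)
        // can update the count without intertwining lifetimes.
        self.inner.current.fetch_add(n, Ordering::Relaxed);
    }
}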
259 changes: 175 additions & 84 deletions src/provider.rs
@@ -22,7 +22,7 @@ use anyhow::{bail, ensure, Context, Result};
use bytes::{Bytes, BytesMut};
use futures::future;
use serde::{Deserialize, Serialize};
use tokio::io::{AsyncWrite, AsyncWriteExt};
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use tokio::sync::broadcast;
use tokio::task::{JoinError, JoinHandle};
use tokio_util::io::SyncIoBridge;
@@ -232,8 +232,9 @@ pub enum Event {
TransferAborted {
/// The quic connection id.
connection_id: u64,
/// The request id.
request_id: u64,
/// The request id. When `None`, the transfer was aborted before or during reading and decoding
/// the transfer request.
request_id: Option<u64>,
},
}
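
For subscribers, a `request_id` of `None` now distinguishes aborts that happen before a request is decoded from per-request aborts. A consumer sketch (assuming `provider_events` is the `broadcast::Receiver<Event>` that the tests in `src/lib.rs` subscribe to; `log_aborts` is invented for the example):

async fn log_aborts(mut provider_events: tokio::sync::broadcast::Receiver<Event>) {
    while let Ok(event) = provider_events.recv().await {
        if let Event::TransferAborted { connection_id, request_id } = event {
            match request_id {
                // Aborted before (or while) reading and decoding the request.
                None => tracing::warn!("connection {connection_id}: aborted before a request was decoded"),
                Some(id) => tracing::warn!("connection {connection_id}: request {id} aborted"),
            }
        }
    }
}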

@@ -338,111 +339,201 @@ async fn handle_connection(
.await
}

async fn handle_stream(
db: Database,
/// Read and decode the handshake.
///
/// Will fail if there is an error while reading, there is a token
/// mismatch, or no valid handshake was received.
///
/// When successful, the reader is still usable after this function and the buffer will be drained of any handshake
/// data.
async fn read_handshake<R: AsyncRead + Unpin>(
mut reader: R,
buffer: &mut BytesMut,
token: AuthToken,
connection_id: u64,
(mut writer, mut reader): (quinn::SendStream, quinn::RecvStream),
events: broadcast::Sender<Event>,
) -> Result<()> {
let mut out_buffer = BytesMut::with_capacity(1024);
let mut in_buffer = BytesMut::with_capacity(1024);

// 1. Read Handshake
debug!("reading handshake");
if let Some((handshake, size)) = read_lp::<_, Handshake>(&mut reader, &mut in_buffer).await? {
if let Some((handshake, size)) = read_lp::<_, Handshake>(&mut reader, buffer).await? {
ensure!(
handshake.version == VERSION,
"expected version {} but got {}",
VERSION,
handshake.version
);
ensure!(handshake.token == token, "AuthToken mismatch");
let _ = in_buffer.split_to(size);
let _ = buffer.split_to(size);
} else {
bail!("no valid handshake received");
}
Ok(())
}

// 2. Decode the request.
debug!("reading request");
let request = read_lp::<_, Request>(&mut reader, &mut in_buffer).await?;
/// Read the request from the getter.
///
/// Will fail if there is an error while reading, if the reader
/// contains more data than the Request, or if no valid request is sent.
///
/// When successful, the buffer is empty after this function call.
async fn read_request(mut reader: quinn::RecvStream, buffer: &mut BytesMut) -> Result<Request> {
let request = read_lp::<_, Request>(&mut reader, buffer).await?;
ensure!(
reader.read_chunk(8, false).await?.is_none(),
"Extra data past request"
);
if let Some((request, _size)) = request {
let hash = request.name;
debug!("got request({})", request.id);
let _ = events.send(Event::RequestReceived {
connection_id,
request_id: request.id,
hash,
});
Ok(request)
} else {
bail!("No request received");
}
}

match db.get(&hash) {
// We only respond to requests for collections, not individual blobs
Some(BlobOrCollection::Collection((outboard, data))) => {
debug!("found collection {}", hash);
/// Transfers the collection & blob data.
///
/// First, it transfers the collection data & its associated outboard encoding data. Then it sequentially transfers each blob's data & its associated outboard
/// encoding data.
///
/// Will fail if there is an error writing to the getter or reading from
/// the database.
///
/// If a blob from the collection cannot be found in the database, the transfer will gracefully
/// close the writer, and return with `Ok(SentStatus::NotFound)`.
///
/// If the transfer does _not_ end in error, the buffer will be empty and the writer is gracefully closed.
async fn transfer_collection(
// Database from which to fetch blobs.
db: &Database,
// Quinn stream.
mut writer: quinn::SendStream,
// Buffer used when writing to writer.
buffer: &mut BytesMut,
// The id of the transfer request.
request_id: u64,
// The bao outboard encoded data.
outboard: &Bytes,
// The actual blob data.
data: &Bytes,
) -> Result<SentStatus> {
// We only respond to requests for collections, not individual blobs
let mut extractor = SliceExtractor::new_outboard(
std::io::Cursor::new(&data[..]),
std::io::Cursor::new(&outboard[..]),
0,
data.len() as u64,
);
let encoded_size: usize = abao::encode::encoded_size(data.len() as u64)
.try_into()
.unwrap();
let mut encoded = Vec::with_capacity(encoded_size);
extractor.read_to_end(&mut encoded)?;

let c: Collection = postcard::from_bytes(data)?;

// TODO: we should check if the blobs referenced in this container
// actually exist in this provider before returning `FoundCollection`
write_response(
&mut writer,
buffer,
request_id,
Res::FoundCollection {
total_blobs_size: c.total_blobs_size,
},
)
.await?;

let mut data = BytesMut::from(&encoded[..]);
writer.write_buf(&mut data).await?;
for (i, blob) in c.blobs.iter().enumerate() {
debug!("writing blob {}/{}", i, c.blobs.len());
let (status, writer1) =
send_blob(db.clone(), blob.hash, writer, buffer, request_id).await?;
writer = writer1;
if SentStatus::NotFound == status {
write_response(&mut writer, buffer, request_id, Res::NotFound).await?;
writer.finish().await?;
return Ok(status);
}
}

let mut extractor = SliceExtractor::new_outboard(
std::io::Cursor::new(&data[..]),
std::io::Cursor::new(&outboard[..]),
0,
data.len() as u64,
);
let encoded_size: usize = abao::encode::encoded_size(data.len() as u64)
.try_into()
.unwrap();
let mut encoded = Vec::with_capacity(encoded_size);
extractor.read_to_end(&mut encoded)?;

let c: Collection = postcard::from_bytes(data)?;

// TODO: we should check if the blobs referenced in this container
// actually exist in this provider before returning `FoundCollection`
write_response(
&mut writer,
&mut out_buffer,
request.id,
Res::FoundCollection {
total_blobs_size: c.total_blobs_size,
},
)
.await?;

let mut data = BytesMut::from(&encoded[..]);
writer.write_buf(&mut data).await?;
for (i, blob) in c.blobs.iter().enumerate() {
debug!("writing blob {}/{}", i, c.blobs.len());
let (status, writer1) =
send_blob(db.clone(), blob.hash, writer, &mut out_buffer, request.id)
.await?;
writer = writer1;
if SentStatus::NotFound == status {
break;
}
}
writer.finish().await?;
Ok(SentStatus::Sent)
}

writer.finish().await?;
let _ = events.send(Event::TransferCompleted {
connection_id,
request_id: request.id,
});
debug!("finished response");
}
_ => {
debug!("not found {}", hash);
write_response(&mut writer, &mut out_buffer, request.id, Res::NotFound).await?;
writer.finish().await?;
// TODO: If the connection drops mid-way we also need to emit this!
let _ = events.send(Event::TransferAborted {
connection_id,
request_id: request.id,
});
}
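
/// Notify subscribers that a transfer was aborted.
///
/// The result of `send` is ignored on purpose: with tokio's broadcast
/// channel, `send` fails only when there are no active receivers, e.g.
///
/// ```
/// let (tx, rx) = tokio::sync::broadcast::channel::<u64>(16);
/// drop(rx);
/// assert!(tx.send(1).is_err()); // no subscribers -> SendError
/// ```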
fn notify_transfer_aborted(
events: broadcast::Sender<Event>,
connection_id: u64,
request_id: Option<u64>,
) {
let _ = events.send(Event::TransferAborted {
connection_id,
request_id,
});
}

async fn handle_stream(
db: Database,
token: AuthToken,
connection_id: u64,
(mut writer, mut reader): (quinn::SendStream, quinn::RecvStream),
events: broadcast::Sender<Event>,
) -> Result<()> {
let mut out_buffer = BytesMut::with_capacity(1024);
let mut in_buffer = BytesMut::with_capacity(1024);

// 1. Read Handshake
debug!("reading handshake");
if let Err(e) = read_handshake(&mut reader, &mut in_buffer, token).await {
notify_transfer_aborted(events, connection_id, None);
return Err(e);
}

// 2. Decode the request.
debug!("reading request");
let request = match read_request(reader, &mut in_buffer).await {
Ok(r) => r,
Err(e) => {
notify_transfer_aborted(events, connection_id, None);
return Err(e);
}
};

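// 3. Emit a RequestReceived event.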
let hash = request.name;
debug!("got request({})", request.id);
let _ = events.send(Event::RequestReceived {
connection_id,
request_id: request.id,
hash,
});

// 4. Attempt to find hash
let (outboard, data) = match db.get(&hash) {
// We only respond to requests for collections, not individual blobs
Some(BlobOrCollection::Collection(d)) => d,
_ => {
debug!("not found {}", hash);
notify_transfer_aborted(events, connection_id, Some(request.id));
write_response(&mut writer, &mut out_buffer, request.id, Res::NotFound).await?;
writer.finish().await?;

return Ok(());
}
};

// 5. Transfer data!
match transfer_collection(&db, writer, &mut out_buffer, request.id, outboard, data).await {
Ok(SentStatus::Sent) => {
let _ = events.send(Event::TransferCompleted {
connection_id,
request_id: request.id,
});
}
Ok(SentStatus::NotFound) => {
notify_transfer_aborted(events, connection_id, Some(request.id));
}
Err(e) => {
notify_transfer_aborted(events, connection_id, Some(request.id));
return Err(e);
}
}

debug!("finished response");
Ok(())
}

