Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions src/conn/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1529,6 +1529,30 @@ mod test {
Ok(())
}

#[tokio::test]
async fn should_provide_multiresult_set_metadata() -> super::Result<()> {
let mut c = Conn::new(get_opts()).await?;
c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)")
.await?;

let mut result = c
.query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;")
.await?;
assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);

result.for_each(drop).await?;
assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2);

result.for_each(drop).await?;
assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0);

result.for_each(drop).await?;
assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1);

c.disconnect().await?;
Ok(())
}

#[tokio::test]
async fn should_handle_local_infile() -> super::Result<()> {
use std::fs::write;
Expand Down
2 changes: 1 addition & 1 deletion src/queryable/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ impl Conn {
{
self.write_command_data(Command::COM_QUERY, query.as_ref().as_bytes())
.await?;
self.read_result_set::<TextProtocol>().await?;
self.read_result_set::<TextProtocol>(true).await?;
Ok(())
}
}
Expand Down
25 changes: 19 additions & 6 deletions src/queryable/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ where
if columns.is_empty() {
// Empty, but not yet consumed result set.
self.conn.set_pending_result(None);
return Ok(None);
} else {
// Not yet consumed non-empty result set.
let packet = match self.conn.read_packet().await {
Expand All @@ -106,7 +105,6 @@ where
if P::is_last_result_set_packet(self.conn.capabilities(), &packet) {
// `packet` is a result set terminator.
self.conn.set_pending_result(None);
return Ok(None);
} else {
// `packet` is a result set row.
return Ok(Some(P::read_result_set_row(&packet, columns)?));
Expand All @@ -118,8 +116,8 @@ where
if self.conn.more_results_exists() {
// More data will follow.
self.conn.sync_seq_id();
self.conn.read_result_set::<P>().await?;
continue;
self.conn.read_result_set::<P>(false).await?;
return Ok(None);
} else {
// The end of a query result.
return Ok(None);
Expand Down Expand Up @@ -329,11 +327,26 @@ where

impl crate::Conn {
/// Will read result set and write pending result into `self` (if any).
pub(crate) async fn read_result_set<P>(&mut self) -> Result<()>
pub(crate) async fn read_result_set<P>(&mut self, is_first_result_set: bool) -> Result<()>
where
P: Protocol,
{
let packet = self.read_packet().await?;
let packet = match self.read_packet().await {
Ok(packet) => packet,
Err(err @ Error::Server(_)) if is_first_result_set => {
// shortcut to emit an error right to the caller of a query/execute
return Err(err);
}
Err(Error::Server(error)) => {
// error will be consumed as a part of a multi-result set
self.set_pending_result(Some(ResultSetMeta::Error(error)));
return Ok(());
}
Err(err) => {
// non-server errors are fatal
return Err(err);
}
};

match packet.get(0) {
Some(0x00) => self.set_pending_result(Some(P::result_set_meta(Arc::from(
Expand Down
4 changes: 2 additions & 2 deletions src/queryable/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ impl crate::Conn {
}

self.write_command_raw(body).await?;
self.read_result_set::<BinaryProtocol>().await?;
self.read_result_set::<BinaryProtocol>(true).await?;
break;
}
Params::Named(_) => {
Expand Down Expand Up @@ -303,7 +303,7 @@ impl crate::Conn {

let (body, _) = ComStmtExecuteRequestBuilder::new(statement.id()).build(&[]);
self.write_command_raw(body).await?;
self.read_result_set::<BinaryProtocol>().await?;
self.read_result_set::<BinaryProtocol>(true).await?;
break;
}
}
Expand Down