diff --git a/src/conn/mod.rs b/src/conn/mod.rs index e6040b33..e3e7b0ab 100644 --- a/src/conn/mod.rs +++ b/src/conn/mod.rs @@ -1529,6 +1529,30 @@ mod test { Ok(()) } + #[tokio::test] + async fn should_provide_multiresult_set_metadata() -> super::Result<()> { + let mut c = Conn::new(get_opts()).await?; + c.query_drop("CREATE TEMPORARY TABLE tmp (id INT, foo TEXT)") + .await?; + + let mut result = c + .query_iter("SELECT 1; SELECT id, foo FROM tmp WHERE 1 = 2; DO 42; SELECT 2;") + .await?; + assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1); + + result.for_each(drop).await?; + assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 2); + + result.for_each(drop).await?; + assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 0); + + result.for_each(drop).await?; + assert_eq!(result.columns().map(|x| x.len()).unwrap_or_default(), 1); + + c.disconnect().await?; + Ok(()) + } + #[tokio::test] async fn should_handle_local_infile() -> super::Result<()> { use std::fs::write; diff --git a/src/queryable/mod.rs b/src/queryable/mod.rs index 19108f1b..c7251891 100644 --- a/src/queryable/mod.rs +++ b/src/queryable/mod.rs @@ -96,7 +96,7 @@ impl Conn { { self.write_command_data(Command::COM_QUERY, query.as_ref().as_bytes()) .await?; - self.read_result_set::().await?; + self.read_result_set::(true).await?; Ok(()) } } diff --git a/src/queryable/query_result.rs b/src/queryable/query_result.rs index 4c37fb0f..4893a667 100644 --- a/src/queryable/query_result.rs +++ b/src/queryable/query_result.rs @@ -91,7 +91,6 @@ where if columns.is_empty() { // Empty, but not yet consumed result set. self.conn.set_pending_result(None); - return Ok(None); } else { // Not yet consumed non-empty result set. let packet = match self.conn.read_packet().await { @@ -106,7 +105,6 @@ where if P::is_last_result_set_packet(self.conn.capabilities(), &packet) { // `packet` is a result set terminator. self.conn.set_pending_result(None); - return Ok(None); } else { // `packet` is a result set row. return Ok(Some(P::read_result_set_row(&packet, columns)?)); @@ -118,8 +116,8 @@ where if self.conn.more_results_exists() { // More data will follow. self.conn.sync_seq_id(); - self.conn.read_result_set::

().await?; - continue; + self.conn.read_result_set::

(false).await?; + return Ok(None); } else { // The end of a query result. return Ok(None); @@ -329,11 +327,26 @@ where impl crate::Conn { /// Will read result set and write pending result into `self` (if any). - pub(crate) async fn read_result_set

(&mut self) -> Result<()> + pub(crate) async fn read_result_set

(&mut self, is_first_result_set: bool) -> Result<()> where P: Protocol, { - let packet = self.read_packet().await?; + let packet = match self.read_packet().await { + Ok(packet) => packet, + Err(err @ Error::Server(_)) if is_first_result_set => { + // shortcut to emit an error right to the caller of a query/execute + return Err(err); + } + Err(Error::Server(error)) => { + // error will be consumed as a part of a multi-result set + self.set_pending_result(Some(ResultSetMeta::Error(error))); + return Ok(()); + } + Err(err) => { + // non-server errors are fatal + return Err(err); + } + }; match packet.get(0) { Some(0x00) => self.set_pending_result(Some(P::result_set_meta(Arc::from( diff --git a/src/queryable/stmt.rs b/src/queryable/stmt.rs index 743a98b1..7c6a8025 100644 --- a/src/queryable/stmt.rs +++ b/src/queryable/stmt.rs @@ -274,7 +274,7 @@ impl crate::Conn { } self.write_command_raw(body).await?; - self.read_result_set::().await?; + self.read_result_set::(true).await?; break; } Params::Named(_) => { @@ -303,7 +303,7 @@ impl crate::Conn { let (body, _) = ComStmtExecuteRequestBuilder::new(statement.id()).build(&[]); self.write_command_raw(body).await?; - self.read_result_set::().await?; + self.read_result_set::(true).await?; break; } }