diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 4e44c962..8009e359 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -9,7 +9,7 @@ env: RUST_BACKTRACE: 1 toolchain_style: stable toolchain_msrv: 1.63 - toolchain_h3_quinn_msrv: 1.63 + toolchain_h3_quinn_msrv: 1.66 toolchain_doc: nightly-2023-10-21 toolchain_lint: stable toolchain_fuzz: nightly-2023-10-21 diff --git a/examples/Cargo.toml b/examples/Cargo.toml index d957d7e6..6cdf5feb 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -14,14 +14,14 @@ h3 = { path = "../h3" } h3-quinn = { path = "../h3-quinn" } h3-webtransport = { path = "../h3-webtransport" } http = "1" -quinn = { version = "0.10", default-features = false, features = [ +quinn = { version = "0.11", default-features = false, features = [ "runtime-tokio", - "tls-rustls", + "rustls", "ring", ] } -rcgen = { version = "0.12" } -rustls = { version = "0.21", features = ["dangerous_configuration"] } -rustls-native-certs = "0.6" +rcgen = { version = "0.13" } +rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std"] } +rustls-native-certs = "0.7" structopt = "0.3" tokio = { version = "1.27", features = ["full"] } tracing = "0.1.37" diff --git a/examples/client.rs b/examples/client.rs index 3f8d692b..f389e7f1 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,6 +1,7 @@ use std::{path::PathBuf, sync::Arc}; use futures::future; +use rustls::pki_types::CertificateDer; use structopt::StructOpt; use tokio::io::AsyncWriteExt; use tracing::{error, info}; @@ -64,7 +65,7 @@ async fn main() -> Result<(), Box> { match rustls_native_certs::load_native_certs() { Ok(certs) => { for cert in certs { - if let Err(e) = roots.add(&rustls::Certificate(cert.0)) { + if let Err(e) = roots.add(cert) { error!("failed to parse trust anchor: {}", e); } } @@ -76,14 +77,11 @@ async fn main() -> Result<(), Box> { // load certificate of CA who issues the server certificate // NOTE that this should be used for dev only - if let Err(e) = roots.add(&rustls::Certificate(std::fs::read(opt.ca)?)) { + if let Err(e) = roots.add(CertificateDer::from(std::fs::read(opt.ca)?)) { error!("failed to parse trust anchor: {}", e); } let mut tls_config = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13])? .with_root_certificates(roots) .with_no_client_auth(); @@ -99,7 +97,9 @@ async fn main() -> Result<(), Box> { let mut client_endpoint = h3_quinn::quinn::Endpoint::client("[::]:0".parse().unwrap())?; - let client_config = quinn::ClientConfig::new(Arc::new(tls_config)); + let client_config = quinn::ClientConfig::new(Arc::new( + quinn::crypto::rustls::QuicClientConfig::try_from(tls_config)?, + )); client_endpoint.set_default_client_config(client_config); let conn = client_endpoint.connect(addr, auth.host())?.await?; diff --git a/examples/server.rs b/examples/server.rs index 339e7e25..6a917ca1 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -2,13 +2,13 @@ use std::{net::SocketAddr, path::PathBuf, sync::Arc}; use bytes::{Bytes, BytesMut}; use http::{Request, StatusCode}; -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use structopt::StructOpt; use tokio::{fs::File, io::AsyncReadExt}; use tracing::{error, info, trace_span}; use h3::{error::ErrorLevel, quic::BidiStream, server::RequestStream}; -use h3_quinn::quinn; +use h3_quinn::quinn::{self, crypto::rustls::QuicServerConfig}; #[derive(StructOpt, Debug)] #[structopt(name = "server")] @@ -84,21 +84,18 @@ async fn main() -> Result<(), Box> { // create quinn server endpoint and bind UDP socket // both cert and key must be DER-encoded - let cert = Certificate(std::fs::read(cert)?); - let key = PrivateKey(std::fs::read(key)?); + let cert = CertificateDer::from(std::fs::read(cert)?); + let key = PrivateKeyDer::try_from(std::fs::read(key)?)?; let mut tls_config = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() .with_no_client_auth() .with_single_cert(vec![cert], key)?; tls_config.max_early_data_size = u32::MAX; tls_config.alpn_protocols = vec![ALPN.into()]; - let server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_config)); + let server_config = + quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?)); let endpoint = quinn::Endpoint::server(server_config, opt.listen)?; info!("listening on {}", opt.listen); diff --git a/examples/webtransport_server.rs b/examples/webtransport_server.rs index 58d4ba43..9ecc7d15 100644 --- a/examples/webtransport_server.rs +++ b/examples/webtransport_server.rs @@ -6,13 +6,13 @@ use h3::{ quic::{self, RecvDatagramExt, SendDatagramExt, SendStreamUnframed}, server::Connection, }; -use h3_quinn::quinn; +use h3_quinn::quinn::{self, crypto::rustls::QuicServerConfig}; use h3_webtransport::{ server::{self, WebTransportSession}, stream, }; use http::Method; -use rustls::{Certificate, PrivateKey}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use structopt::StructOpt; use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; @@ -81,14 +81,10 @@ async fn main() -> Result<(), Box> { // create quinn server endpoint and bind UDP socket // both cert and key must be DER-encoded - let cert = Certificate(std::fs::read(cert)?); - let key = PrivateKey(std::fs::read(key)?); + let cert = CertificateDer::from(std::fs::read(cert)?); + let key = PrivateKeyDer::try_from(std::fs::read(key)?)?; let mut tls_config = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() .with_no_client_auth() .with_single_cert(vec![cert], key)?; @@ -102,7 +98,8 @@ async fn main() -> Result<(), Box> { ]; tls_config.alpn_protocols = alpn; - let mut server_config = quinn::ServerConfig::with_crypto(Arc::new(tls_config)); + let mut server_config = + quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?)); let mut transport_config = quinn::TransportConfig::default(); transport_config.keep_alive_interval(Some(Duration::from_secs(2))); server_config.transport = Arc::new(transport_config); diff --git a/h3-quinn/Cargo.toml b/h3-quinn/Cargo.toml index 94a9e0dc..b634e024 100644 --- a/h3-quinn/Cargo.toml +++ b/h3-quinn/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "h3-quinn" version = "0.0.5" -rust-version = "1.63" +rust-version = "1.66" authors = ["Jean-Christophe BEGUE "] edition = "2021" documentation = "https://docs.rs/h3-quinn" @@ -15,10 +15,9 @@ license = "MIT" [dependencies] h3 = { version = "0.0.4", path = "../h3" } bytes = "1" -quinn = { version = "0.10", default-features = false, features = [ +quinn = { version = "0.11", default-features = false, features = [ "futures-io", ] } -quinn-proto = { version = "0.10", default-features = false } tokio-util = { version = "0.7.9" } futures = { version = "0.3.28" } tokio = { version = "1", features = ["io-util"], default-features = false } diff --git a/h3-quinn/src/lib.rs b/h3-quinn/src/lib.rs index 573fa823..daec2057 100644 --- a/h3-quinn/src/lib.rs +++ b/h3-quinn/src/lib.rs @@ -19,8 +19,8 @@ use futures::{ stream::{self, BoxStream}, StreamExt, }; -use quinn::ReadDatagram; pub use quinn::{self, AcceptBi, AcceptUni, Endpoint, OpenBi, OpenUni, VarInt, WriteError}; +use quinn::{ApplicationClose, ClosedStream, ReadDatagram}; use h3::{ ext::Datagram, @@ -81,10 +81,9 @@ impl Error for ConnectionError { fn err_code(&self) -> Option { match self.0 { - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }) => Some(error_code.into_inner()), + quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }) => { + Some(error_code.into_inner()) + } _ => None, } } @@ -529,7 +528,7 @@ impl Error for ReadError { fn err_code(&self) -> Option { match self.0 { quinn::ReadError::ConnectionLost(quinn::ConnectionError::ApplicationClosed( - quinn_proto::ApplicationClose { error_code, .. }, + ApplicationClose { error_code, .. }, )) => Some(error_code.into_inner()), quinn::ReadError::Reset(error_code) => Some(error_code.into_inner()), _ => None, @@ -593,12 +592,8 @@ where Poll::Ready(Ok(())) } - fn poll_finish(&mut self, cx: &mut task::Context<'_>) -> Poll> { - self.stream - .as_mut() - .unwrap() - .poll_finish(cx) - .map_err(Into::into) + fn poll_finish(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + Poll::Ready(self.stream.as_mut().unwrap().finish().map_err(|e| e.into())) } fn reset(&mut self, reset_code: u64) { @@ -680,6 +675,8 @@ pub enum SendStreamError { /// Error when the stream is not ready, because it is still sending /// data from a previous call NotReady, + /// Error when the stream is closed + StreamClosed(ClosedStream), } impl From for std::io::Error { @@ -689,6 +686,7 @@ impl From for std::io::Error { SendStreamError::NotReady => { std::io::Error::new(std::io::ErrorKind::Other, "send stream is not ready") } + SendStreamError::StreamClosed(err) => err.into(), } } } @@ -707,6 +705,12 @@ impl From for SendStreamError { } } +impl From for SendStreamError { + fn from(value: ClosedStream) -> Self { + Self::StreamClosed(value) + } +} + impl Error for SendStreamError { fn is_timeout(&self) -> bool { matches!( @@ -721,10 +725,7 @@ impl Error for SendStreamError { match self { Self::Write(quinn::WriteError::Stopped(error_code)) => Some(error_code.into_inner()), Self::Write(quinn::WriteError::ConnectionLost( - quinn::ConnectionError::ApplicationClosed(quinn_proto::ApplicationClose { - error_code, - .. - }), + quinn::ConnectionError::ApplicationClosed(ApplicationClose { error_code, .. }), )) => Some(error_code.into_inner()), _ => None, } diff --git a/h3/Cargo.toml b/h3/Cargo.toml index 1756f705..5d620e6b 100644 --- a/h3/Cargo.toml +++ b/h3/Cargo.toml @@ -35,14 +35,14 @@ fastrand = "2.0.1" assert_matches = "1.5.0" futures-util = { version = "0.3", default-features = false, features = ["io"] } proptest = "1" -quinn = { version = "0.10", default-features = false, features = [ +quinn = { version = "0.11", default-features = false, features = [ "runtime-tokio", - "tls-rustls", + "rustls", "ring", ] } -quinn-proto = { version = "0.10", default-features = false } -rcgen = "0.12" -rustls = "0.21" +quinn-proto = { version = "0.11", default-features = false } +rcgen = "0.13" +rustls = { version = "0.23", default-features = false, features = ["logging", "ring", "std"] } tokio = { version = "1", features = ["rt", "macros", "io-util", "io-std"] } tracing-subscriber = { version = "0.3", default-features = false, features = [ "fmt", diff --git a/h3/src/tests/connection.rs b/h3/src/tests/connection.rs index d88fa600..7d65b604 100644 --- a/h3/src/tests/connection.rs +++ b/h3/src/tests/connection.rs @@ -7,6 +7,7 @@ use assert_matches::assert_matches; use bytes::{Buf, Bytes, BytesMut}; use futures_util::future; use http::{Request, Response, StatusCode}; +use tokio::sync::oneshot::{self}; use crate::client::SendRequest; use crate::{client, server}; @@ -32,15 +33,16 @@ async fn connect() { let mut server = pair.server(); let client_fut = async { - let _ = client::new(pair.client().await).await.expect("client init"); + let (mut drive, _client) = client::new(pair.client().await).await.expect("client init"); + future::poll_fn(|cx| drive.poll_close(cx)).await.unwrap(); }; let server_fut = async { let conn = server.next().await; - let _ = server::Connection::new(conn).await.unwrap(); + let _server = server::Connection::new(conn).await.unwrap(); }; - tokio::join!(server_fut, client_fut); + tokio::select!(() = server_fut => (), () = client_fut => panic!("client resolved first")); } #[tokio::test] @@ -48,14 +50,21 @@ async fn accept_request_end_on_client_close() { let mut pair = Pair::default(); let mut server = pair.server(); + let (tx, rx) = oneshot::channel::<()>(); + let client_fut = async { - let _ = client::new(pair.client().await).await.expect("client init"); + let client = pair.client().await; + let client = client::new(client).await.expect("client init"); + // wait for the server to accept the connection + rx.await.unwrap(); // client is dropped, it will send H3_NO_ERROR + drop(client); }; let server_fut = async { let conn = server.next().await; let mut incoming = server::Connection::new(conn).await.unwrap(); + tx.send(()).unwrap(); // Accept returns Ok(None) assert!(incoming.accept().await.unwrap().is_none()); }; @@ -65,6 +74,7 @@ async fn accept_request_end_on_client_close() { #[tokio::test] async fn server_drop_close() { + init_tracing(); let mut pair = Pair::default(); let mut server = pair.server(); @@ -73,8 +83,8 @@ async fn server_drop_close() { let _ = server::Connection::new(conn).await.unwrap(); }; - let (mut conn, mut send) = client::new(pair.client().await).await.expect("client init"); let client_fut = async { + let (mut conn, mut send) = client::new(pair.client().await).await.expect("client init"); let request_fut = async move { let mut request_stream = send .send_request(Request::get("http://no.way").body(()).unwrap()) @@ -131,6 +141,7 @@ async fn server_send_data_without_finish() { #[tokio::test] async fn client_close_only_on_last_sender_drop() { + init_tracing(); let mut pair = Pair::default(); let mut server = pair.server(); @@ -145,18 +156,24 @@ async fn client_close_only_on_last_sender_drop() { let client_fut = async { let (mut conn, mut send1) = client::new(pair.client().await).await.expect("client init"); let mut send2 = send1.clone(); - let _ = send1 + let mut request_stream_1 = send1 .send_request(Request::get("http://no.way").body(()).unwrap()) .await - .unwrap() - .finish() - .await; - let _ = send2 + .unwrap(); + + let _ = request_stream_1.recv_response().await; + + let _ = request_stream_1.finish().await; + + let mut request_stream_2 = send2 .send_request(Request::get("http://no.way").body(()).unwrap()) .await - .unwrap() - .finish() - .await; + .unwrap(); + + let _ = request_stream_2.recv_response().await; + + let _ = request_stream_2.finish().await; + drop(send1); drop(send2); @@ -366,13 +383,27 @@ async fn control_close_send_error() { //# If either control //# stream is closed at any point, this MUST be treated as a connection //# error of type H3_CLOSED_CRITICAL_STREAM. - control_stream.finish().await.unwrap(); // close the client control stream immediately + control_stream.finish().unwrap(); // close the client control stream immediately - let (mut driver, _send) = client::new(h3_quinn::Connection::new(connection)) - .await - .unwrap(); + // create the Connection manually so it does not open a second Control stream + + let connection_error = loop { + let accepted = connection.accept_bi().await; + match accepted { + // do nothing with the stream + Ok(_) => continue, + Err(err) => break err, + } + }; - future::poll_fn(|cx| driver.poll_close(cx)).await + let err_code = match connection_error { + quinn::ConnectionError::ApplicationClosed(quinn::ApplicationClose { + error_code, + .. + }) => error_code.into_inner(), + e => panic!("unexpected error: {:?}", e), + }; + assert_eq!(err_code, Code::H3_CLOSED_CRITICAL_STREAM.value()); }; let server_fut = async { @@ -390,7 +421,7 @@ async fn control_close_send_error() { if *reason == *"control stream closed"); }; - tokio::select! { _ = server_fut => (), _ = client_fut => panic!("client resolved first") }; + tokio::join!(server_fut, client_fut); } #[tokio::test] @@ -508,7 +539,7 @@ async fn goaway_from_server_not_request_id() { let mut buf = BytesMut::new(); StreamType::CONTROL.encode(&mut buf); control_stream.write_all(&buf[..]).await.unwrap(); - control_stream.finish().await.unwrap(); // close the client control stream immediately + control_stream.finish().unwrap(); // close the client control stream immediately let (mut driver, _send) = client::new(h3_quinn::Connection::new(connection)) .await diff --git a/h3/src/tests/mod.rs b/h3/src/tests/mod.rs index 655ab16a..d34d234c 100644 --- a/h3/src/tests/mod.rs +++ b/h3/src/tests/mod.rs @@ -19,7 +19,8 @@ use std::{ }; use bytes::Bytes; -use rustls::{Certificate, PrivateKey}; +use quinn::crypto::rustls::{QuicClientConfig, QuicServerConfig}; +use rustls::pki_types::{CertificateDer, PrivateKeyDer}; use crate::quic; use h3_quinn::{quinn::TransportConfig, Connection}; @@ -32,11 +33,10 @@ pub fn init_tracing() { .try_init(); } -#[derive(Clone)] pub struct Pair { port: u16, - cert: Certificate, - key: PrivateKey, + cert: CertificateDer<'static>, + key: PrivateKeyDer<'static>, config: Arc, } @@ -63,18 +63,20 @@ impl Pair { } pub fn server_inner(&mut self) -> h3_quinn::Endpoint { - let mut crypto = rustls::ServerConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .with_no_client_auth() - .with_single_cert(vec![self.cert.clone()], self.key.clone()) - .unwrap(); + let mut crypto = rustls::ServerConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_no_client_auth() + .with_single_cert(vec![self.cert.clone()], self.key.clone_key()) + .unwrap(); crypto.max_early_data_size = u32::MAX; crypto.alpn_protocols = vec![b"h3".to_vec()]; - let mut server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new(crypto)); + let mut server_config = h3_quinn::quinn::ServerConfig::with_crypto(Arc::new( + QuicServerConfig::try_from(crypto).unwrap(), + )); server_config.transport = self.config.clone(); let endpoint = h3_quinn::quinn::Endpoint::server(server_config, "[::]:0".parse().unwrap()).unwrap(); @@ -97,18 +99,20 @@ impl Pair { .unwrap(); let mut root_cert_store = rustls::RootCertStore::empty(); - root_cert_store.add(&self.cert).unwrap(); - let mut crypto = rustls::ClientConfig::builder() - .with_safe_default_cipher_suites() - .with_safe_default_kx_groups() - .with_protocol_versions(&[&rustls::version::TLS13]) - .unwrap() - .with_root_certificates(root_cert_store) - .with_no_client_auth(); + root_cert_store.add(self.cert.clone()).unwrap(); + let mut crypto = rustls::ClientConfig::builder_with_provider(Arc::new( + rustls::crypto::ring::default_provider(), + )) + .with_protocol_versions(&[&rustls::version::TLS13]) + .unwrap() + .with_root_certificates(root_cert_store) + .with_no_client_auth(); crypto.enable_early_data = true; crypto.alpn_protocols = vec![b"h3".to_vec()]; - let client_config = h3_quinn::quinn::ClientConfig::new(Arc::new(crypto)); + let client_config = h3_quinn::quinn::ClientConfig::new(Arc::new( + QuicClientConfig::try_from(crypto).unwrap(), + )); let mut client_endpoint = h3_quinn::quinn::Endpoint::client("[::]:0".parse().unwrap()).unwrap(); @@ -135,9 +139,10 @@ impl Server { } } -pub fn build_certs() -> (Certificate, PrivateKey) { +pub fn build_certs() -> (CertificateDer<'static>, PrivateKeyDer<'static>) { let cert = rcgen::generate_simple_self_signed(vec!["localhost".into()]).unwrap(); - let key = PrivateKey(cert.serialize_private_key_der()); - let cert = Certificate(cert.serialize_der().unwrap()); - (cert, key) + ( + cert.cert.into(), + PrivateKeyDer::Pkcs8(cert.key_pair.serialize_der().into()), + ) } diff --git a/h3/src/tests/request.rs b/h3/src/tests/request.rs index 76fbb545..2001b0b3 100644 --- a/h3/src/tests/request.rs +++ b/h3/src/tests/request.rs @@ -69,6 +69,8 @@ async fn get() { .await .expect("send_data"); request_stream.finish().await.expect("finish"); + + let _ = incoming_req.accept().await.unwrap(); }; tokio::join!(server_fut, client_fut); @@ -131,6 +133,8 @@ async fn get_with_trailers_unknown_content_type() { .await .expect("send_trailers"); request_stream.finish().await.expect("finish"); + + let _ = incoming_req.accept().await.unwrap(); }; tokio::join!(server_fut, client_fut); @@ -193,6 +197,8 @@ async fn get_with_trailers_known_content_type() { .await .expect("send_trailers"); request_stream.finish().await.expect("finish"); + + let _ = incoming_req.accept().await.unwrap(); }; tokio::join!(server_fut, client_fut); @@ -246,6 +252,9 @@ async fn post() { .expect("server recv body"); assert_eq!(request_body.chunk(), b"wonderful json"); request_stream.finish().await.expect("client finish"); + + // keep connection until client is finished + let _ = incoming_req.accept().await.expect("accept"); }; tokio::join!(server_fut, client_fut); @@ -328,6 +337,7 @@ async fn header_too_big_response_from_server_trailers() { .await .expect("send trailers"); request_stream.finish().await.expect("client finish"); + let _ = request_stream.recv_response().await; }; tokio::select! {biased; _ = req_fut => (), _ = drive_fut => () } }; @@ -373,7 +383,18 @@ async fn header_too_big_client_error() { let client_fut = async { let (mut driver, mut client) = client::new(pair.client().await).await.expect("client init"); - let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await }; + let drive_fut = async { + let err = future::poll_fn(|cx| driver.poll_close(cx)) + .await + .unwrap_err(); + match err.kind() { + // The client never sends a data on the request stream + Kind::Application { code, .. } => { + assert_eq!(code, Code::H3_REQUEST_INCOMPLETE) + } + _ => panic!("unexpected error: {:?}", err), + } + }; let req_fut = async { // pretend client already received server's settings client @@ -398,20 +419,19 @@ async fn header_too_big_client_error() { } ); }; - tokio::select! {biased; _ = req_fut => (),_ = drive_fut => () } + tokio::join! {req_fut, drive_fut } }; let server_fut = async { let conn = server.next().await; - //= https://www.rfc-editor.org/rfc/rfc9114#section-4.2.2 - //= type=test - //# An HTTP/3 implementation MAY impose a limit on the maximum size of - //# the message header it will accept on an individual HTTP message. - server::builder() + + let mut incoming_req = server::builder() .max_field_section_size(12) .build(conn) .await .unwrap(); + + let _ = incoming_req.accept().await; }; tokio::join!(server_fut, client_fut); @@ -425,7 +445,15 @@ async fn header_too_big_client_error_trailer() { let client_fut = async { let (mut driver, mut client) = client::new(pair.client().await).await.expect("client init"); - let drive_fut = async { future::poll_fn(|cx| driver.poll_close(cx)).await }; + let drive_fut = async { + let err = future::poll_fn(|cx| driver.poll_close(cx)) + .await + .unwrap_err(); + match err.kind() { + Kind::Timeout => (), + _ => panic!("unexpected error: {:?}", err), + } + }; let req_fut = async { client .shared_state() @@ -461,7 +489,7 @@ async fn header_too_big_client_error_trailer() { request_stream.finish().await.expect("client finish"); }; - tokio::select! {biased; _ = req_fut => (), _ = drive_fut => () } + tokio::join! {req_fut,drive_fut}; }; let server_fut = async { @@ -1406,7 +1434,7 @@ where let mut buf = BytesMut::new(); request(&mut buf); req_send.write_all(&buf[..]).await.unwrap(); - req_send.finish().await.unwrap(); + req_send.finish().unwrap(); let res = req_recv .read(&mut buf)