Skip to content

Commit

Permalink
do not downcast
Browse files Browse the repository at this point in the history
  • Loading branch information
conradludgate committed Apr 17, 2024
1 parent 537f062 commit 102143f
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 270 deletions.
107 changes: 41 additions & 66 deletions proxy/src/serverless.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

mod backend;
mod conn_pool;
mod http_auto;
mod http_util;
mod json;
mod sql_over_http;
Expand All @@ -20,7 +19,8 @@ use futures::TryFutureExt;
use http::{Method, Response, StatusCode};
use http_body_util::Full;
use hyper1::body::Incoming;
use hyper_util::rt::TokioIo;
use hyper_util::rt::TokioExecutor;
use hyper_util::server::conn::auto::Builder;
use rand::rngs::StdRng;
use rand::SeedableRng;
pub use reqwest_middleware::{ClientWithMiddleware, Error};
Expand All @@ -37,7 +37,6 @@ use crate::protocol2::WithClientIp;
use crate::proxy::run_until_cancelled;
use crate::rate_limiter::EndpointRateLimiter;
use crate::serverless::backend::PoolingBackend;
use crate::serverless::http_auto::read_version;
use crate::serverless::http_util::{api_error_into_response, json_response};

use std::net::{IpAddr, SocketAddr};
Expand All @@ -48,8 +47,6 @@ use tokio_util::sync::CancellationToken;
use tracing::{error, info, warn, Instrument};
use utils::http::error::ApiError;

use self::http_auto::Builder;

pub const SERVERLESS_DRIVER_SNI: &str = "api";

pub async fn task_main(
Expand Down Expand Up @@ -103,7 +100,7 @@ pub async fn task_main(
let connections = tokio_util::task::task_tracker::TaskTracker::new();
connections.close(); // allows `connections.wait to complete`

let server = Builder::new();
let server = Builder::new(TokioExecutor::new());

while let Some(res) = run_until_cancelled(ws_listener.accept(), &cancellation_token).await {
let (conn, peer_addr) = res.context("could not accept TCP stream")?;
Expand Down Expand Up @@ -152,7 +149,7 @@ async fn connection_handler(
cancellation_handler: Arc<CancellationHandlerMain>,
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
cancellation_token: CancellationToken,
server: Builder,
server: Builder<TokioExecutor>,
tls_acceptor: TlsAcceptor,
conn: TcpStream,
peer_addr: SocketAddr,
Expand Down Expand Up @@ -211,69 +208,47 @@ async fn connection_handler(
let http_cancellation_token = CancellationToken::new();
let _cancel_connection = http_cancellation_token.clone().drop_guard();

let Some(conn) = run_until_cancelled(read_version(conn), &cancellation_token).await else {
return;
};
let (version, rewind) = match conn {
Ok(d) => d,
Err(e) => {
tracing::warn!(%peer_addr, "HTTP connection error {e}");
return;
}
};
let service = hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
// First HTTP request shares the same session ID
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);

// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_request_token = http_cancellation_token.child_token();
let cancel_request = http_request_token.clone().drop_guard();

// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
endpoint_rate_limiter.clone(),
http_request_token,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);

async move {
let res = handler.await;
cancel_request.disarm();
res
}
});
let conn = match version {
http_auto::Version::H1 => Either::Left(
server
.http1
.serve_connection(TokioIo::new(rewind), service)
.with_upgrades(),
),
http_auto::Version::H2 => {
Either::Right(server.http2.serve_connection(TokioIo::new(rewind), service))
}
};
let conn = server.serve_connection_with_upgrades(
hyper_util::rt::TokioIo::new(conn),
hyper1::service::service_fn(move |req: hyper1::Request<Incoming>| {
// First HTTP request shares the same session ID
let session_id = session_id.take().unwrap_or_else(uuid::Uuid::new_v4);

// Cancel the current inflight HTTP request if the requets stream is closed.
// This is slightly different to `_cancel_connection` in that
// h2 can cancel individual requests with a `RST_STREAM`.
let http_request_token = http_cancellation_token.child_token();
let cancel_request = http_request_token.clone().drop_guard();

// `request_handler` is not cancel safe. It expects to be cancelled only at specific times.
// By spawning the future, we ensure it never gets cancelled until it decides to.
let handler = connections.spawn(
request_handler(
req,
config,
backend.clone(),
connections.clone(),
cancellation_handler.clone(),
session_id,
peer_addr,
endpoint_rate_limiter.clone(),
http_request_token,
)
.in_current_span()
.map_ok_or_else(api_error_into_response, |r| r),
);
async move {
let res = handler.await;
cancel_request.disarm();
res
}
}),
);

// On cancellation, trigger the HTTP connection handler to shut down.
let res = match select(pin!(cancellation_token.cancelled()), pin!(conn)).await {
Either::Left((_cancelled, mut conn)) => {
match conn.as_mut().as_pin_mut() {
Either::Left(h1) => h1.graceful_shutdown(),
Either::Right(h2) => h2.graceful_shutdown(),
}
conn.as_mut().graceful_shutdown();
conn.await
}
Either::Right((res, _)) => res,
Expand Down
184 changes: 0 additions & 184 deletions proxy/src/serverless/http_auto.rs

This file was deleted.

23 changes: 3 additions & 20 deletions proxy/src/serverless/websocket.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,24 @@ use crate::{
context::RequestMonitoring,
error::{io_error, ReportableError},
metrics::Metrics,
protocol2::WithClientIp,
proxy::{handle_client, ClientMode},
rate_limiter::EndpointRateLimiter,
};
use bytes::{Buf, BufMut, Bytes, BytesMut};
use bytes::{Buf, Bytes};
use framed_websockets::{Frame, OpCode, WebSocketServer};
use futures::{Sink, Stream};
use hyper1::upgrade::OnUpgrade;
use hyper_util::rt::TokioIo;
use pin_project_lite::pin_project;
use tokio_rustls::server::TlsStream;

use std::{
pin::Pin,
sync::Arc,
task::{ready, Context, Poll},
};
use tokio::{
io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};
use tokio::io::{self, AsyncBufRead, AsyncRead, AsyncWrite, ReadBuf};
use tracing::warn;

use super::http_auto::Rewind;

pin_project! {
/// This is a wrapper around a [`WebSocketStream`] that
/// implements [`AsyncRead`] and [`AsyncWrite`].
Expand Down Expand Up @@ -139,17 +132,7 @@ pub async fn serve_websocket(
endpoint_rate_limiter: Arc<EndpointRateLimiter>,
) -> anyhow::Result<()> {
let websocket = websocket.await?;
let websocket = websocket
.downcast::<TokioIo<Rewind<TlsStream<WithClientIp<TcpStream>>>>>()
.expect("downcast error");

let pre0 = websocket.read_buf;
let (pre1, inner) = websocket.io.into_inner().into_inner();

let mut buf = BytesMut::with_capacity(8192);
buf.put(pre0);
buf.put(pre1);
let websocket = WebSocketServer::after_handshake_with_bytes(inner, buf);
let websocket = WebSocketServer::after_handshake(TokioIo::new(websocket));

let conn_gauge = Metrics::get()
.proxy
Expand Down

0 comments on commit 102143f

Please sign in to comment.