Skip to content

Commit

Permalink
Revert "feat: use DuplexStream instead of UnixStream to communica…
Browse files Browse the repository at this point in the history
…te with workers (supabase#320)"

This reverts commit 4e53e2a.

# Conflicts:
#	crates/base/src/server.rs
#	crates/cli/src/main.rs
  • Loading branch information
nyannyacha committed Apr 22, 2024
1 parent 6bacace commit 04a32f7
Show file tree
Hide file tree
Showing 8 changed files with 113 additions and 167 deletions.
24 changes: 12 additions & 12 deletions crates/base/src/deno_runtime.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::inspector_server::Inspector;
use crate::rt_worker::rt;
use crate::rt_worker::supervisor::{CPUUsage, CPUUsageMetrics};
use crate::rt_worker::worker::DuplexStreamEntry;
use crate::rt_worker::worker::UnixStreamEntry;
use crate::utils::units::{bytes_to_display, mib_to_bytes};

use anyhow::{anyhow, bail, Context, Error};
Expand Down Expand Up @@ -33,6 +33,7 @@ use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::c_void;
use std::fmt;
use std::os::fd::RawFd;
use std::sync::Arc;
use std::task::Poll;
use std::time::Duration;
Expand Down Expand Up @@ -524,7 +525,7 @@ impl DenoRuntime {
}

if conf.is_main_worker() || conf.is_user_worker() {
op_state.put::<HashMap<usize, watch::Receiver<ConnSync>>>(HashMap::new());
op_state.put::<HashMap<RawFd, watch::Receiver<ConnSync>>>(HashMap::new());
}

if conf.is_user_worker() {
Expand Down Expand Up @@ -595,15 +596,14 @@ impl DenoRuntime {

pub async fn run(
&mut self,
duplex_stream_rx: mpsc::UnboundedReceiver<DuplexStreamEntry>,
unix_stream_rx: mpsc::UnboundedReceiver<UnixStreamEntry>,
maybe_cpu_usage_metrics_tx: Option<mpsc::UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
) -> (Result<(), Error>, i64) {
{
let op_state_rc = self.js_runtime.op_state();
let mut op_state = op_state_rc.borrow_mut();

op_state.put::<mpsc::UnboundedReceiver<DuplexStreamEntry>>(duplex_stream_rx);
op_state.put::<mpsc::UnboundedReceiver<UnixStreamEntry>>(unix_stream_rx);

if self.conf.is_main_worker() {
op_state.put::<mpsc::UnboundedSender<UserWorkerMsgs>>(
Expand Down Expand Up @@ -887,7 +887,7 @@ extern "C" fn mem_check_gc_prologue_callback_fn(
#[cfg(test)]
mod test {
use crate::deno_runtime::DenoRuntime;
use crate::rt_worker::worker::DuplexStreamEntry;
use crate::rt_worker::worker::UnixStreamEntry;
use deno_core::{FastString, ModuleCodeString, PollEventLoopOptions};
use sb_graph::emitter::EmitterFactory;
use sb_graph::{generate_binary_eszip, EszipPayloadKind};
Expand Down Expand Up @@ -1467,8 +1467,8 @@ mod test {
let mut user_rt =
create_basic_user_runtime("./test_cases/array_buffers", 20, 1000, &[]).await;

let (_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;
let (_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;

assert!(result.is_ok(), "expected no errors");

Expand All @@ -1482,8 +1482,8 @@ mod test {
let mut user_rt =
create_basic_user_runtime("./test_cases/array_buffers", 15, 1000, &[]).await;

let (_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;
let (_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;

match result {
Err(err) => {
Expand All @@ -1501,7 +1501,7 @@ mod test {
memory_limit_mb: u64,
worker_timeout_ms: u64,
) {
let (_duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (_unix_stream_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (callback_tx, mut callback_rx) = mpsc::unbounded_channel::<()>();
let mut user_rt =
create_basic_user_runtime(path, memory_limit_mb, worker_timeout_ms, static_patterns)
Expand All @@ -1518,7 +1518,7 @@ mod test {
});

let wait_fut = async move {
let (result, _) = user_rt.run(duplex_stream_rx, None, None).await;
let (result, _) = user_rt.run(unix_stream_rx, None, None).await;

assert_eq!(
result.unwrap_err().to_string(),
Expand Down
6 changes: 3 additions & 3 deletions crates/base/src/rt_worker/implementation/default_handler.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::deno_runtime::DenoRuntime;
use crate::rt_worker::supervisor::CPUUsageMetrics;
use crate::rt_worker::worker::{DuplexStreamEntry, HandleCreationType, Worker, WorkerHandler};
use crate::rt_worker::worker::{HandleCreationType, UnixStreamEntry, Worker, WorkerHandler};
use anyhow::Error;
use event_worker::events::{BootFailureEvent, PseudoEvent, UncaughtExceptionEvent, WorkerEvents};
use log::error;
Expand All @@ -19,14 +19,14 @@ impl WorkerHandler for Worker {
fn handle_creation<'r>(
&self,
created_rt: &'r mut DenoRuntime,
duplex_stream_rx: UnboundedReceiver<DuplexStreamEntry>,
unix_stream_rx: UnboundedReceiver<UnixStreamEntry>,
termination_event_rx: Receiver<WorkerEvents>,
maybe_cpu_usage_metrics_tx: Option<UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
) -> HandleCreationType<'r> {
let run_worker_rt = async move {
match created_rt
.run(duplex_stream_rx, maybe_cpu_usage_metrics_tx, name)
.run(unix_stream_rx, maybe_cpu_usage_metrics_tx, name)
.await
{
// if the error is execution terminated, check termination event reason
Expand Down
18 changes: 9 additions & 9 deletions crates/base/src/rt_worker/worker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use sb_workers::context::{UserWorkerMsgs, WorkerContextInitOpts};
use std::any::Any;
use std::future::{pending, Future};
use std::pin::Pin;
use tokio::io;
use tokio::net::UnixStream;
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
use tokio::sync::oneshot::{Receiver, Sender};
use tokio::sync::{oneshot, watch};
Expand All @@ -44,14 +44,14 @@ pub struct Worker {
}

pub type HandleCreationType<'r> = Pin<Box<dyn Future<Output = Result<WorkerEvents, Error>> + 'r>>;
pub type DuplexStreamEntry = (io::DuplexStream, Option<watch::Receiver<ConnSync>>);
pub type UnixStreamEntry = (UnixStream, Option<watch::Receiver<ConnSync>>);

pub trait WorkerHandler: Send {
fn handle_error(&self, error: Error) -> Result<WorkerEvents, Error>;
fn handle_creation<'r>(
&self,
created_rt: &'r mut DenoRuntime,
duplex_stream_rx: UnboundedReceiver<DuplexStreamEntry>,
unix_stream_rx: UnboundedReceiver<UnixStreamEntry>,
termination_event_rx: Receiver<WorkerEvents>,
maybe_cpu_metrics_tx: Option<UnboundedSender<CPUUsageMetrics>>,
name: Option<String>,
Expand Down Expand Up @@ -91,9 +91,9 @@ impl Worker {
pub fn start(
&self,
mut opts: WorkerContextInitOpts,
duplex_stream_pair: (
UnboundedSender<DuplexStreamEntry>,
UnboundedReceiver<DuplexStreamEntry>,
unix_stream_pair: (
UnboundedSender<UnixStreamEntry>,
UnboundedReceiver<UnixStreamEntry>,
),
booter_signal: Sender<Result<MetricSource, Error>>,
termination_token: Option<TerminationToken>,
Expand All @@ -104,7 +104,7 @@ impl Worker {
let event_metadata = self.event_metadata.clone();
let supervisor_policy = self.supervisor_policy;

let (duplex_stream_tx, duplex_stream_rx) = duplex_stream_pair;
let (unix_stream_tx, unix_stream_rx) = unix_stream_pair;
let events_msg_tx = self.events_msg_tx.clone();
let pool_msg_tx = self.pool_msg_tx.clone();

Expand Down Expand Up @@ -244,7 +244,7 @@ impl Worker {
let result = method_cloner
.handle_creation(
&mut runtime,
duplex_stream_rx,
unix_stream_rx,
termination_event_rx,
maybe_cpu_usage_metrics_tx,
Some(worker_name),
Expand Down Expand Up @@ -283,7 +283,7 @@ impl Worker {
}
};

drop(duplex_stream_tx);
drop(unix_stream_tx);

match result {
Ok(event) => {
Expand Down
31 changes: 17 additions & 14 deletions crates/base/src/rt_worker/worker_ctx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ use std::io::ErrorKind;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{self, copy_bidirectional};
use tokio::net::TcpStream;
use tokio::io::copy_bidirectional;
use tokio::net::{TcpStream, UnixStream};
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use tokio::sync::{mpsc, oneshot, watch, Mutex};
use tokio_rustls::server::TlsStream;
Expand All @@ -42,7 +42,7 @@ use uuid::Uuid;

use super::rt;
use super::supervisor::{self, CPUTimerParam, CPUUsageMetrics};
use super::worker::DuplexStreamEntry;
use super::worker::UnixStreamEntry;
use super::worker_pool::{SupervisorPolicy, WorkerPoolPolicy};

#[derive(Clone)]
Expand Down Expand Up @@ -93,25 +93,28 @@ impl TerminationToken {
}

async fn handle_request(
duplex_stream_tx: mpsc::UnboundedSender<DuplexStreamEntry>,
unix_stream_tx: mpsc::UnboundedSender<UnixStreamEntry>,
msg: WorkerRequestMsg,
) -> Result<(), Error> {
let (ours, theirs) = io::duplex(1024);
// create a unix socket pair
let (sender_stream, recv_stream) = UnixStream::pair()?;
let WorkerRequestMsg {
mut req,
res_tx,
conn_watch,
} = msg;

let _ = duplex_stream_tx.send((theirs, conn_watch.clone()));
let _ = unix_stream_tx.send((recv_stream, conn_watch.clone()));
let req_upgrade_type = get_upgrade_type(req.headers());
let req_upgrade = req_upgrade_type
.clone()
.and_then(|it| Some(it).zip(req.extensions_mut().remove::<OnUpgrade>()));

// send the HTTP request to the worker over duplex stream
let (mut request_sender, connection) =
http1::Builder::new().writev(true).handshake(ours).await?;
// send the HTTP request to the worker over Unix stream
let (mut request_sender, connection) = http1::Builder::new()
.writev(true)
.handshake(sender_stream)
.await?;

let (upgrade_tx, upgrade_rx) = oneshot::channel();

Expand Down Expand Up @@ -174,7 +177,7 @@ async fn handle_request(

async fn relay_upgraded_request_and_response(
downstream: OnUpgrade,
parts: http1::Parts<io::DuplexStream>,
parts: http1::Parts<UnixStream>,
) {
let mut upstream = Upgraded2::new(parts.io, parts.read_buf);
let mut downstream = downstream.await.expect("failed to upgrade request");
Expand All @@ -187,7 +190,7 @@ async fn relay_upgraded_request_and_response(
// `close_notify` before shutdown an upstream if downstream is a
// TLS stream.

// INVARIANT: `UnexpectedEof` due to shutdown `DuplexStream` is
// INVARIANT: `UnexpectedEof` due to shutdown `UnixStream` is
// only expected to occur in the context of `TlsStream`.
panic!("unhandleable unexpected eof");
};
Expand Down Expand Up @@ -513,7 +516,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
init_opts: Opt,
inspector: Option<Inspector>,
) -> Result<(MetricSource, mpsc::UnboundedSender<WorkerRequestMsg>), Error> {
let (duplex_stream_tx, duplex_stream_rx) = mpsc::unbounded_channel::<DuplexStreamEntry>();
let (unix_stream_tx, unix_stream_rx) = mpsc::unbounded_channel::<UnixStreamEntry>();
let (worker_boot_result_tx, worker_boot_result_rx) =
oneshot::channel::<Result<MetricSource, Error>>();

Expand All @@ -536,7 +539,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
if let Some(worker_struct_ref) = downcast_reference {
worker_struct_ref.start(
init_opts,
(duplex_stream_tx.clone(), duplex_stream_rx),
(unix_stream_tx.clone(), unix_stream_rx),
worker_boot_result_tx,
maybe_termination_token.clone(),
inspector,
Expand All @@ -546,7 +549,7 @@ pub async fn create_worker<Opt: Into<CreateWorkerArgs>>(
let (worker_req_tx, mut worker_req_rx) = mpsc::unbounded_channel::<WorkerRequestMsg>();

let worker_req_handle: tokio::task::JoinHandle<Result<(), Error>> = tokio::task::spawn({
let stream_tx = duplex_stream_tx;
let stream_tx = unix_stream_tx;
async move {
while let Some(msg) = worker_req_rx.recv().await {
tokio::task::spawn({
Expand Down
49 changes: 12 additions & 37 deletions crates/sb_core/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,14 @@ use futures::Future;
use hyper::upgrade::{OnUpgrade, Parts};
use log::error;
use serde::Serialize;
use tokio::io::{copy_bidirectional, AsyncReadExt, AsyncWriteExt, DuplexStream};
use tokio::{
io::{AsyncRead, AsyncWrite},
net::UnixStream,
sync::{oneshot, watch},
};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::UnixStream,
};

use crate::{
conn_sync::ConnSync,
Expand All @@ -35,27 +37,15 @@ deno_core::extension!(
middleware = sb_http_middleware,
);

pub(crate) struct Stream2<S>(S, Option<watch::Receiver<ConnSync>>);
pub(crate) struct UnixStream2(UnixStream, Option<watch::Receiver<ConnSync>>);

impl<S> Stream2<S>
where
S: AsyncWrite + AsyncRead + Unpin,
{
pub fn new(stream: S, watcher: Option<watch::Receiver<ConnSync>>) -> Self {
impl UnixStream2 {
pub fn new(stream: UnixStream, watcher: Option<watch::Receiver<ConnSync>>) -> Self {
Self(stream, watcher)
}
}

impl<S> Stream2<S> {
fn into_inner(self) -> (S, Option<watch::Receiver<ConnSync>>) {
(self.0, self.1)
}
}

impl<S> AsyncRead for Stream2<S>
where
S: AsyncRead + Unpin,
{
impl AsyncRead for UnixStream2 {
fn poll_read(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
Expand All @@ -65,10 +55,7 @@ where
}
}

impl<S> AsyncWrite for Stream2<S>
where
S: AsyncWrite + Unpin,
{
impl AsyncWrite for UnixStream2 {
fn poll_write(
self: Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
Expand Down Expand Up @@ -119,9 +106,6 @@ where
}
}

pub(crate) type DuplexStream2 = Stream2<DuplexStream>;
pub(crate) type UnixStream2 = Stream2<UnixStream>;

fn http_error(message: &'static str) -> AnyError {
custom_error("Http", message)
}
Expand All @@ -145,19 +129,10 @@ async fn op_http_upgrade_websocket2(
};

let upgraded = hyper::upgrade::on(request).await?;
let Parts { io, read_buf, .. } = upgraded.downcast::<DuplexStream2>().unwrap();
let (mut rw, conn_sync) = io.into_inner();

// NOTE(Nyannyacha): We use `UnixStream` out of necessity here because
// `ws_create_server_stream` only supports network stream types.
let (ours, theirs) = UnixStream::pair()?;

tokio::spawn(async move {
let mut theirs = UnixStream2::new(theirs, conn_sync);
let _ = copy_bidirectional(&mut rw, &mut theirs).await;
});
let Parts { io, read_buf, .. } = upgraded.downcast::<UnixStream2>().unwrap();

ws_create_server_stream(&mut state.borrow_mut(), ours.into(), read_buf)
let ws_rid = ws_create_server_stream(&mut state.borrow_mut(), io.0.into(), read_buf)?;
Ok(ws_rid)
}

#[op2]
Expand Down
Loading

0 comments on commit 04a32f7

Please sign in to comment.