diff --git a/apps/app/src/api/oauth_utils/auth_code_reply.rs b/apps/app/src/api/oauth_utils/auth_code_reply.rs index 4e4a529284..fedffcb09d 100644 --- a/apps/app/src/api/oauth_utils/auth_code_reply.rs +++ b/apps/app/src/api/oauth_utils/auth_code_reply.rs @@ -11,7 +11,7 @@ //! [RFC 8252]: https://datatracker.ietf.org/doc/html/rfc8252 use std::{ - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + net::SocketAddr, sync::{LazyLock, Mutex}, time::Duration, }; @@ -19,10 +19,8 @@ use std::{ use hyper::body::Incoming; use hyper_util::rt::{TokioIo, TokioTimer}; use theseus::ErrorKind; -use tokio::{ - net::TcpListener, - sync::{broadcast, oneshot}, -}; +use theseus::prelude::tcp_listen_any_loopback; +use tokio::sync::{broadcast, oneshot}; static SERVER_SHUTDOWN: LazyLock> = LazyLock::new(|| broadcast::channel(1024).0); @@ -35,17 +33,7 @@ static SERVER_SHUTDOWN: LazyLock> = pub async fn listen( listen_socket_tx: oneshot::Sender>, ) -> Result, theseus::Error> { - // IPv4 is tried first for the best compatibility and performance with most systems. - // IPv6 is also tried in case IPv4 is not available. Resolving "localhost" is avoided - // to prevent failures deriving from improper name resolution setup. Any available - // ephemeral port is used to prevent conflicts with other services. This is all as per - // RFC 8252's recommendations - const ANY_LOOPBACK_SOCKET: &[SocketAddr] = &[ - SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), - SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), - ]; - - let listener = match TcpListener::bind(ANY_LOOPBACK_SOCKET).await { + let listener = match tcp_listen_any_loopback().await { Ok(listener) => { listen_socket_tx .send(listener.local_addr().map_err(|e| { diff --git a/packages/app-lib/build.rs b/packages/app-lib/build.rs index 48da4b4570..10ed29b99f 100644 --- a/packages/app-lib/build.rs +++ b/packages/app-lib/build.rs @@ -53,7 +53,6 @@ fn build_java_jars() { .arg("build") .arg("--no-daemon") .arg("--console=rich") - .arg("--info") .current_dir(dunce::canonicalize("java").unwrap()) .status() .expect("Failed to wait on Gradle build"); diff --git a/packages/app-lib/java/build.gradle.kts b/packages/app-lib/java/build.gradle.kts index a671dd6f95..98c95c8c91 100644 --- a/packages/app-lib/java/build.gradle.kts +++ b/packages/app-lib/java/build.gradle.kts @@ -11,6 +11,7 @@ repositories { dependencies { implementation("org.ow2.asm:asm:9.8") implementation("org.ow2.asm:asm-tree:9.8") + implementation("com.google.code.gson:gson:2.13.1") testImplementation(libs.junit.jupiter) testRuntimeOnly("org.junit.platform:junit-platform-launcher") diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java index 9d61a0c0be..b474ba02c9 100644 --- a/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java +++ b/packages/app-lib/java/src/main/java/com/modrinth/theseus/MinecraftLaunch.java @@ -1,11 +1,13 @@ package com.modrinth.theseus; -import java.io.ByteArrayOutputStream; +import com.modrinth.theseus.rpc.RpcHandlers; +import com.modrinth.theseus.rpc.TheseusRpc; import java.io.IOException; import java.lang.reflect.AccessibleObject; import java.lang.reflect.Method; import java.lang.reflect.Modifier; import java.util.Arrays; +import java.util.concurrent.CompletableFuture; public final class MinecraftLaunch { public static void main(String[] args) throws IOException, ReflectiveOperationException { @@ -13,43 +15,17 @@ public static void main(String[] args) throws IOException, ReflectiveOperationEx final String[] gameArgs = Arrays.copyOfRange(args, 1, args.length); System.setProperty("modrinth.process.args", String.join("\u001f", gameArgs)); - parseInput(); - relaunch(mainClass, gameArgs); - } - - private static void parseInput() throws IOException { - final ByteArrayOutputStream line = new ByteArrayOutputStream(); - while (true) { - final int b = System.in.read(); - if (b < 0) { - throw new IllegalStateException("Stdin terminated while parsing"); - } - if (b != '\n') { - line.write(b); - continue; - } - if (handleLine(line.toString("UTF-8"))) { - break; - } - line.reset(); - } - } - - private static boolean handleLine(String line) { - final String[] parts = line.split("\t", 2); - switch (parts[0]) { - case "property": { - final String[] keyValue = parts[1].split("\t", 2); - System.setProperty(keyValue[0], keyValue[1]); - return false; - } - case "launch": - return true; - } + final CompletableFuture waitForLaunch = new CompletableFuture<>(); + TheseusRpc.connectAndStart( + System.getProperty("modrinth.internal.ipc.host"), + Integer.getInteger("modrinth.internal.ipc.port"), + new RpcHandlers() + .handler("set_system_property", String.class, String.class, System::setProperty) + .handler("launch", () -> waitForLaunch.complete(null))); - System.err.println("Unknown input line " + line); - return false; + waitForLaunch.join(); + relaunch(mainClass, gameArgs); } private static void relaunch(String mainClassName, String[] args) throws ReflectiveOperationException { diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java new file mode 100644 index 0000000000..257148ef51 --- /dev/null +++ b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcHandlers.java @@ -0,0 +1,46 @@ +package com.modrinth.theseus.rpc; + +import com.google.gson.JsonElement; +import com.google.gson.JsonNull; +import java.util.HashMap; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Function; + +public class RpcHandlers { + private final Map> handlers = new HashMap<>(); + private boolean frozen; + + public RpcHandlers handler(String functionName, Runnable handler) { + return addHandler(functionName, args -> { + handler.run(); + return JsonNull.INSTANCE; + }); + } + + public RpcHandlers handler( + String functionName, Class arg1Type, Class arg2Type, BiConsumer handler) { + return addHandler(functionName, args -> { + if (args.length != 2) { + throw new IllegalArgumentException(functionName + " expected 2 arguments"); + } + final A arg1 = TheseusRpc.GSON.fromJson(args[0], arg1Type); + final B arg2 = TheseusRpc.GSON.fromJson(args[1], arg2Type); + handler.accept(arg1, arg2); + return JsonNull.INSTANCE; + }); + } + + private RpcHandlers addHandler(String functionName, Function handler) { + if (frozen) { + throw new IllegalStateException("Cannot add handler to frozen RpcHandlers instance"); + } + handlers.put(functionName, handler); + return this; + } + + Map> build() { + frozen = true; + return handlers; + } +} diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java new file mode 100644 index 0000000000..f9ab75a35e --- /dev/null +++ b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/RpcMethodException.java @@ -0,0 +1,9 @@ +package com.modrinth.theseus.rpc; + +public class RpcMethodException extends RuntimeException { + private static final long serialVersionUID = 1922360184188807964L; + + public RpcMethodException(String message) { + super(message); + } +} diff --git a/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java new file mode 100644 index 0000000000..ff460ff894 --- /dev/null +++ b/packages/app-lib/java/src/main/java/com/modrinth/theseus/rpc/TheseusRpc.java @@ -0,0 +1,183 @@ +package com.modrinth.theseus.rpc; + +import com.google.gson.*; +import com.google.gson.reflect.TypeToken; +import java.io.*; +import java.net.Socket; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; + +public final class TheseusRpc { + static final Gson GSON = new GsonBuilder() + .setStrictness(Strictness.STRICT) + .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + .disableHtmlEscaping() + .create(); + private static final TypeToken MESSAGE_TYPE = TypeToken.get(RpcMessage.class); + + private static final AtomicReference RPC = new AtomicReference<>(); + + private final BlockingQueue mainThreadQueue = new LinkedBlockingQueue<>(); + private final Map> awaitingResponse = new ConcurrentHashMap<>(); + private final Map> handlers; + private final Socket socket; + + private TheseusRpc(Socket socket, RpcHandlers handlers) { + this.socket = socket; + this.handlers = handlers.build(); + } + + public static void connectAndStart(String host, int port, RpcHandlers handlers) throws IOException { + if (RPC.get() != null) { + throw new IllegalStateException("Can only connect to RPC once"); + } + + final Socket socket = new Socket(host, port); + final TheseusRpc rpc = new TheseusRpc(socket, handlers); + final Thread mainThread = new Thread(rpc::mainThread, "Theseus RPC Main"); + final Thread readThread = new Thread(rpc::readThread, "Theseus RPC Read"); + mainThread.setDaemon(true); + readThread.setDaemon(true); + mainThread.start(); + readThread.start(); + RPC.set(rpc); + } + + public static TheseusRpc getRpc() { + final TheseusRpc rpc = RPC.get(); + if (rpc == null) { + throw new IllegalStateException("Called getRpc before RPC initialized"); + } + return rpc; + } + + public CompletableFuture callMethod(TypeToken returnType, String method, Object... args) { + final JsonElement[] jsonArgs = new JsonElement[args.length]; + for (int i = 0; i < args.length; i++) { + jsonArgs[i] = GSON.toJsonTree(args[i]); + } + + final RpcMessage message = new RpcMessage(method, jsonArgs); + final ResponseWaiter responseWaiter = new ResponseWaiter<>(returnType); + awaitingResponse.put(message.id, responseWaiter); + mainThreadQueue.add(message); + return responseWaiter.future; + } + + private void mainThread() { + try { + final Writer writer = new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8); + while (true) { + final RpcMessage message = mainThreadQueue.take(); + final RpcMessage toSend; + if (message.isForSending) { + toSend = message; + } else { + final Function handler = handlers.get(message.method); + if (handler == null) { + System.err.println("Unknown theseus RPC method " + message.method); + continue; + } + RpcMessage response; + try { + response = new RpcMessage(message.id, handler.apply(message.args)); + } catch (Exception e) { + response = new RpcMessage(message.id, e.toString()); + } + toSend = response; + } + GSON.toJson(toSend, writer); + writer.write('\n'); + writer.flush(); + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (InterruptedException ignored) { + } + } + + private void readThread() { + try { + final BufferedReader reader = + new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8)); + while (true) { + final RpcMessage message = GSON.fromJson(reader.readLine(), MESSAGE_TYPE); + if (message.method == null) { + final ResponseWaiter waiter = awaitingResponse.get(message.id); + if (waiter != null) { + handleResponse(waiter, message); + } + } else { + mainThreadQueue.put(message); + } + } + } catch (IOException e) { + throw new UncheckedIOException(e); + } catch (InterruptedException ignored) { + } + } + + private void handleResponse(ResponseWaiter waiter, RpcMessage message) { + if (message.error != null) { + waiter.future.completeExceptionally(new RpcMethodException(message.error)); + return; + } + try { + waiter.future.complete(GSON.fromJson(message.response, waiter.type)); + } catch (JsonSyntaxException e) { + waiter.future.completeExceptionally(e); + } + } + + private static class RpcMessage { + final UUID id; + final String method; // Optional + final JsonElement[] args; // Optional + final JsonElement response; // Optional + final String error; // Optional + final transient boolean isForSending; + + RpcMessage(String method, JsonElement[] args) { + id = UUID.randomUUID(); + this.method = method; + this.args = args; + response = null; + error = null; + isForSending = true; + } + + RpcMessage(UUID id, JsonElement response) { + this.id = id; + method = null; + args = null; + this.response = response; + error = null; + isForSending = true; + } + + RpcMessage(UUID id, String error) { + this.id = id; + method = null; + args = null; + response = null; + this.error = error; + isForSending = true; + } + } + + private static class ResponseWaiter { + final TypeToken type; + final CompletableFuture future = new CompletableFuture<>(); + + ResponseWaiter(TypeToken type) { + this.type = type; + } + } +} diff --git a/packages/app-lib/src/api/mod.rs b/packages/app-lib/src/api/mod.rs index b173d035fd..020afbe49a 100644 --- a/packages/app-lib/src/api/mod.rs +++ b/packages/app-lib/src/api/mod.rs @@ -35,6 +35,9 @@ pub mod prelude { jre, metadata, minecraft_auth, mr_auth, pack, process, profile::{self, Profile, create}, settings, - util::io::{IOError, canonicalize}, + util::{ + io::{IOError, canonicalize}, + network::tcp_listen_any_loopback, + }, }; } diff --git a/packages/app-lib/src/error.rs b/packages/app-lib/src/error.rs index 75c144f554..773d55daa1 100644 --- a/packages/app-lib/src/error.rs +++ b/packages/app-lib/src/error.rs @@ -151,6 +151,9 @@ pub enum ErrorKind { "A skin texture must have a dimension of either 64x64 or 64x32 pixels" )] InvalidSkinTexture, + + #[error("RPC error: {0}")] + RpcError(String), } #[derive(Debug)] diff --git a/packages/app-lib/src/launcher/args.rs b/packages/app-lib/src/launcher/args.rs index 350d67c0b4..e2093f6185 100644 --- a/packages/app-lib/src/launcher/args.rs +++ b/packages/app-lib/src/launcher/args.rs @@ -16,6 +16,7 @@ use daedalus::{ use dunce::canonicalize; use hashlink::LinkedHashSet; use std::io::{BufRead, BufReader}; +use std::net::SocketAddr; use std::{collections::HashMap, path::Path}; use uuid::Uuid; @@ -124,6 +125,7 @@ pub fn get_jvm_arguments( quick_play_type: &QuickPlayType, quick_play_version: QuickPlayVersion, log_config: Option<&LoggingConfiguration>, + ipc_addr: SocketAddr, ) -> crate::Result> { let mut parsed_arguments = Vec::new(); @@ -181,6 +183,11 @@ pub fn get_jvm_arguments( .to_string_lossy() )); + parsed_arguments + .push(format!("-Dmodrinth.internal.ipc.host={}", ipc_addr.ip())); + parsed_arguments + .push(format!("-Dmodrinth.internal.ipc.port={}", ipc_addr.port())); + parsed_arguments.push(format!( "-Dmodrinth.internal.quickPlay.serverVersion={}", serde_json::to_value(quick_play_version.server)? diff --git a/packages/app-lib/src/launcher/mod.rs b/packages/app-lib/src/launcher/mod.rs index 64eb1d90e0..1b7a7d7e0e 100644 --- a/packages/app-lib/src/launcher/mod.rs +++ b/packages/app-lib/src/launcher/mod.rs @@ -12,6 +12,7 @@ use crate::state::{ Credentials, JavaVersion, ProcessMetadata, ProfileInstallStage, }; use crate::util::io; +use crate::util::rpc::RpcServerBuilder; use crate::{State, get_resource_file, process, state as st}; use chrono::Utc; use daedalus as d; @@ -22,7 +23,6 @@ use serde::Deserialize; use st::Profile; use std::fmt::Write; use std::path::PathBuf; -use tokio::io::AsyncWriteExt; use tokio::process::Command; mod args; @@ -608,6 +608,8 @@ pub async fn launch_minecraft( let (main_class_keep_alive, main_class_path) = get_resource_file!(env "JAVA_JARS_DIR" / "theseus.jar")?; + let rpc_server = RpcServerBuilder::new().launch().await?; + command.args( args::get_jvm_arguments( args.get(&d::minecraft::ArgumentType::Jvm) @@ -633,6 +635,7 @@ pub async fn launch_minecraft( .logging .as_ref() .and_then(|x| x.get(&LoggingSide::Client)), + rpc_server.address(), )? .into_iter(), ); @@ -767,7 +770,8 @@ pub async fn launch_minecraft( state.directories.profile_logs_dir(&profile.path), version_info.logging.is_some(), main_class_keep_alive, - async |process: &ProcessMetadata, stdin| { + rpc_server, + async |process: &ProcessMetadata, rpc_server| { let process_start_time = process.start_time.to_rfc3339(); let profile_created_time = profile.created.to_rfc3339(); let profile_modified_time = profile.modified.to_rfc3339(); @@ -790,14 +794,11 @@ pub async fn launch_minecraft( let Some(value) = value else { continue; }; - stdin.write_all(b"property\t").await?; - stdin.write_all(key.as_bytes()).await?; - stdin.write_u8(b'\t').await?; - stdin.write_all(value.as_bytes()).await?; - stdin.write_u8(b'\n').await?; + rpc_server + .call_method_2::<()>("set_system_property", key, value) + .await?; } - stdin.write_all(b"launch\n").await?; - stdin.flush().await?; + rpc_server.call_method::<()>("launch").await?; Ok(()) }, ) diff --git a/packages/app-lib/src/state/process.rs b/packages/app-lib/src/state/process.rs index faf1c9b4f3..4cff0a33e9 100644 --- a/packages/app-lib/src/state/process.rs +++ b/packages/app-lib/src/state/process.rs @@ -2,6 +2,7 @@ use crate::event::emit::{emit_process, emit_profile}; use crate::event::{ProcessPayloadType, ProfilePayloadType}; use crate::profile; use crate::util::io::IOError; +use crate::util::rpc::RpcServer; use chrono::{DateTime, NaiveDateTime, TimeZone, Utc}; use dashmap::DashMap; use quick_xml::Reader; @@ -15,7 +16,7 @@ use std::path::{Path, PathBuf}; use std::process::ExitStatus; use tempfile::TempDir; use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::process::{Child, ChildStdin, Command}; +use tokio::process::{Child, Command}; use uuid::Uuid; const LAUNCHER_LOG_PATH: &str = "launcher_log.txt"; @@ -46,9 +47,10 @@ impl ProcessManager { logs_folder: PathBuf, xml_logging: bool, main_class_keep_alive: TempDir, + rpc_server: RpcServer, post_process_init: impl AsyncFnOnce( &ProcessMetadata, - &mut ChildStdin, + &RpcServer, ) -> crate::Result<()>, ) -> crate::Result { mc_command.stdout(std::process::Stdio::piped()); @@ -67,14 +69,12 @@ impl ProcessManager { profile_path: profile_path.to_string(), }, child: mc_proc, + rpc_server, _main_class_keep_alive: main_class_keep_alive, }; - if let Err(e) = post_process_init( - &process.metadata, - &mut process.child.stdin.as_mut().unwrap(), - ) - .await + if let Err(e) = + post_process_init(&process.metadata, &process.rpc_server).await { tracing::error!("Failed to run post-process init: {e}"); let _ = process.child.kill().await; @@ -165,6 +165,10 @@ impl ProcessManager { self.processes.get(&id).map(|x| x.metadata.clone()) } + pub fn get_rpc(&self, id: Uuid) -> Option { + self.processes.get(&id).map(|x| x.rpc_server.clone()) + } + pub fn get_all(&self) -> Vec { self.processes .iter() @@ -215,6 +219,7 @@ struct Process { metadata: ProcessMetadata, child: Child, _main_class_keep_alive: TempDir, + rpc_server: RpcServer, } #[derive(Debug, Default)] diff --git a/packages/app-lib/src/util/mod.rs b/packages/app-lib/src/util/mod.rs index 67c5ede167..7656b4a033 100644 --- a/packages/app-lib/src/util/mod.rs +++ b/packages/app-lib/src/util/mod.rs @@ -2,6 +2,8 @@ pub mod fetch; pub mod io; pub mod jre; +pub mod network; pub mod platform; pub mod protocol_version; +pub mod rpc; pub mod server_ping; diff --git a/packages/app-lib/src/util/network.rs b/packages/app-lib/src/util/network.rs new file mode 100644 index 0000000000..2837516c56 --- /dev/null +++ b/packages/app-lib/src/util/network.rs @@ -0,0 +1,17 @@ +use std::io; +use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}; +use tokio::net::TcpListener; + +pub async fn tcp_listen_any_loopback() -> io::Result { + // IPv4 is tried first for the best compatibility and performance with most systems. + // IPv6 is also tried in case IPv4 is not available. Resolving "localhost" is avoided + // to prevent failures deriving from improper name resolution setup. Any available + // ephemeral port is used to prevent conflicts with other services. This is all as per + // RFC 8252's recommendations + const ANY_LOOPBACK_SOCKET: &[SocketAddr] = &[ + SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0), + SocketAddr::new(IpAddr::V6(Ipv6Addr::LOCALHOST), 0), + ]; + + TcpListener::bind(ANY_LOOPBACK_SOCKET).await +} diff --git a/packages/app-lib/src/util/rpc.rs b/packages/app-lib/src/util/rpc.rs new file mode 100644 index 0000000000..d6902bd855 --- /dev/null +++ b/packages/app-lib/src/util/rpc.rs @@ -0,0 +1,258 @@ +use crate::prelude::tcp_listen_any_loopback; +use crate::{ErrorKind, Result}; +use futures::{SinkExt, StreamExt}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; +use serde_json::Value; +use std::collections::HashMap; +use std::net::SocketAddr; +use std::pin::Pin; +use std::sync::{Arc, Mutex}; +use tokio::net::TcpListener; +use tokio::sync::{mpsc, oneshot}; +use tokio::task::AbortHandle; +use tokio_util::codec::{Decoder, LinesCodec, LinesCodecError}; +use uuid::Uuid; + +type HandlerFuture = Pin>>>; +type HandlerMethod = Box) -> HandlerFuture>; +type HandlerMap = HashMap<&'static str, HandlerMethod>; +type WaitingResponsesMap = + Arc>>>>; + +pub struct RpcServerBuilder { + handlers: HandlerMap, +} + +impl RpcServerBuilder { + pub fn new() -> Self { + Self { + handlers: HashMap::new(), + } + } + + // We'll use this function in the future. Please remove this #[allow] when we do. + #[allow(dead_code)] + pub fn handler( + mut self, + function_name: &'static str, + handler: HandlerMethod, + ) -> Self { + self.handlers.insert(function_name, Box::new(handler)); + self + } + + pub async fn launch(self) -> Result { + let socket = tcp_listen_any_loopback().await?; + let address = socket.local_addr()?; + let (message_sender, message_receiver) = mpsc::unbounded_channel(); + let waiting_responses = Arc::new(Mutex::new(HashMap::new())); + + let join_handle = { + let waiting_responses = waiting_responses.clone(); + tokio::spawn(async move { + let mut server = RunningRpcServer { + message_receiver, + handlers: self.handlers, + waiting_responses: waiting_responses.clone(), + }; + if let Err(e) = server.run(socket).await { + tracing::error!("Failed to run RPC server: {e}"); + } + waiting_responses.lock().unwrap().clear(); + }) + }; + Ok(RpcServer { + address, + message_sender, + waiting_responses, + abort_handle: join_handle.abort_handle(), + }) + } +} + +#[derive(Debug, Clone)] +pub struct RpcServer { + address: SocketAddr, + message_sender: mpsc::UnboundedSender, + waiting_responses: WaitingResponsesMap, + abort_handle: AbortHandle, +} + +impl RpcServer { + pub fn address(&self) -> SocketAddr { + self.address + } + + pub async fn call_method( + &self, + method: &str, + ) -> Result { + self.call_method_any(method, vec![]).await + } + + pub async fn call_method_2( + &self, + method: &str, + arg1: impl Serialize, + arg2: impl Serialize, + ) -> Result { + self.call_method_any( + method, + vec![serde_json::to_value(arg1)?, serde_json::to_value(arg2)?], + ) + .await + } + + async fn call_method_any( + &self, + method: &str, + args: Vec, + ) -> Result { + if self.message_sender.is_closed() { + return Err(ErrorKind::RpcError( + "RPC connection closed".to_string(), + ) + .into()); + } + + let id = Uuid::new_v4(); + let (send, recv) = oneshot::channel(); + self.waiting_responses.lock().unwrap().insert(id, send); + + let message = RpcMessage { + id, + body: RpcMessageBody::Call { + method: method.to_owned(), + args, + }, + }; + if self.message_sender.send(message).is_err() { + self.waiting_responses.lock().unwrap().remove(&id); + return Err(ErrorKind::RpcError( + "RPC connection closed while sending".to_string(), + ) + .into()); + } + + tracing::debug!("Waiting on result for {id}"); + let Ok(result) = recv.await else { + self.waiting_responses.lock().unwrap().remove(&id); + return Err(ErrorKind::RpcError( + "RPC connection closed while waiting for response".to_string(), + ) + .into()); + }; + result.and_then(|x| Ok(serde_json::from_value(x)?)) + } +} + +impl Drop for RpcServer { + fn drop(&mut self) { + self.abort_handle.abort(); + } +} + +struct RunningRpcServer { + message_receiver: mpsc::UnboundedReceiver, + handlers: HandlerMap, + waiting_responses: WaitingResponsesMap, +} + +impl RunningRpcServer { + async fn run(&mut self, listener: TcpListener) -> Result<()> { + let (socket, _) = listener.accept().await?; + drop(listener); + + let mut socket = LinesCodec::new().framed(socket); + loop { + let to_send = tokio::select! { + message = self.message_receiver.recv() => { + if message.is_none() { + break; + } + message + }, + message = socket.next() => { + let message: RpcMessage = match message { + None => break, + Some(Ok(message)) => serde_json::from_str(&message)?, + Some(Err(LinesCodecError::Io(e))) => Err(e)?, + Some(Err(LinesCodecError::MaxLineLengthExceeded)) => unreachable!(), + }; + self.handle_message(message).await? + }, + }; + if let Some(message) = to_send { + let json = serde_json::to_string(&message)?; + match socket.send(json).await { + Ok(()) => {} + Err(LinesCodecError::Io(e)) => Err(e)?, + Err(LinesCodecError::MaxLineLengthExceeded) => { + unreachable!() + } + }; + } + } + Ok(()) + } + + async fn handle_message( + &self, + message: RpcMessage, + ) -> Result> { + if let RpcMessageBody::Call { method, args } = message.body { + let response = match self.handlers.get(method.as_str()) { + Some(handler) => match handler(args).await { + Ok(result) => RpcMessageBody::Respond { response: result }, + Err(e) => RpcMessageBody::Error { + error: e.to_string(), + }, + }, + None => RpcMessageBody::Error { + error: format!("Unknown theseus RPC method {method}"), + }, + }; + Ok(Some(RpcMessage { + id: message.id, + body: response, + })) + } else if let Some(sender) = + self.waiting_responses.lock().unwrap().remove(&message.id) + { + let _ = sender.send(match message.body { + RpcMessageBody::Respond { response } => Ok(response), + RpcMessageBody::Error { error } => { + Err(ErrorKind::RpcError(error).into()) + } + _ => unreachable!(), + }); + Ok(None) + } else { + Ok(None) + } + } +} + +#[derive(Debug, Serialize, Deserialize)] +struct RpcMessage { + id: Uuid, + #[serde(flatten)] + body: RpcMessageBody, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum RpcMessageBody { + Call { + method: String, + args: Vec, + }, + Respond { + #[serde(default, skip_serializing_if = "Value::is_null")] + response: Value, + }, + Error { + error: String, + }, +}