From 3fa88903d34269d97551d8473d7f058e6791189a Mon Sep 17 00:00:00 2001 From: Raunak Raj <71929976+bajrangCoder@users.noreply.github.com> Date: Thu, 9 Oct 2025 22:00:25 +0530 Subject: [PATCH 1/3] Add LSP bridge WebSocket server --- src/lsp.rs | 312 ++++++++++++++++++++++++++++++++++++++++++++++++++++ src/main.rs | 60 ++++++++-- 2 files changed, 363 insertions(+), 9 deletions(-) create mode 100644 src/lsp.rs diff --git a/src/lsp.rs b/src/lsp.rs new file mode 100644 index 0000000..290a807 --- /dev/null +++ b/src/lsp.rs @@ -0,0 +1,312 @@ +use axum::extract::{ + ws::{Message, WebSocket, WebSocketUpgrade}, + State, +}; +use axum::http::HeaderValue; +use axum::response::IntoResponse; +use axum::routing::get; +use axum::Router; +use futures::{SinkExt, StreamExt}; +use std::net::Ipv4Addr; +use std::process::ExitStatus; +use std::sync::Arc; +use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader}; +use tokio::process::{Child, ChildStderr, Command}; +use tokio::sync::{mpsc, oneshot}; +use tokio::time::{sleep, Duration, Instant}; +use tower_http::cors::{Any, CorsLayer}; +use tower_http::trace::{DefaultMakeSpan, TraceLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const EXIT_POLL_INTERVAL: Duration = Duration::from_millis(200); +const GRACEFUL_SHUTDOWN: Duration = Duration::from_secs(2); + +#[derive(Clone)] +pub struct LspBridgeConfig { + pub program: String, + pub args: Vec, +} + +#[derive(Clone)] +struct LspState { + config: Arc, +} + +pub async fn start_lsp_server( + host: Ipv4Addr, + port: u16, + allow_any_origin: bool, + config: LspBridgeConfig, +) { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=info", env!("CARGO_CRATE_NAME")).into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + tracing::info!( + program = %config.program, + args = ?config.args, + "Starting LSP bridge server", + ); + + let cors = if allow_any_origin { + CorsLayer::new() + .allow_origin(Any) + .allow_methods(Any) + .allow_headers(Any) + } else { + let localhost = "https://localhost" + .parse::() + .expect("valid origin"); + CorsLayer::new() + .allow_origin(localhost) + .allow_methods(Any) + .allow_headers(Any) + }; + + let state = LspState { + config: Arc::new(config), + }; + + let app = Router::new() + .route("/", get(upgrade_lsp_bridge)) + .with_state(state) + .layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::default().include_headers(true)), + ) + .layer(cors); + + let addr: std::net::SocketAddr = (host, port).into(); + + match tokio::net::TcpListener::bind(addr).await { + Ok(listener) => { + tracing::info!("listening on {}", listener.local_addr().unwrap()); + + if let Err(e) = axum::serve(listener, app).await { + tracing::error!("Server error: {}", e); + } + } + Err(e) => { + if e.kind() == std::io::ErrorKind::AddrInUse { + tracing::error!("Port is already in use please kill all other instances of axs server or stop any other process or app that maybe be using port {}", port); + } else { + tracing::error!("Failed to bind: {}", e); + } + } + } +} + +async fn upgrade_lsp_bridge( + ws: WebSocketUpgrade, + State(state): State, +) -> impl IntoResponse { + let config = state.config.clone(); + ws.on_upgrade(move |socket| async move { + if let Err(err) = run_bridge(socket, config).await { + tracing::error!(error = %err, "LSP bridge session ended with error"); + } + }) +} + +async fn run_bridge(socket: WebSocket, config: Arc) -> Result<(), String> { + let mut command = Command::new(&config.program); + command.args(&config.args); + command.stdin(std::process::Stdio::piped()); + command.stdout(std::process::Stdio::piped()); + command.stderr(std::process::Stdio::piped()); + + let mut child = command + .spawn() + .map_err(|e| format!("Failed to spawn LSP command '{}': {e}", config.program))?; + + let stdout = child + .stdout + .take() + .ok_or_else(|| "Failed to capture LSP stdout".to_string())?; + let stdin = child + .stdin + .take() + .ok_or_else(|| "Failed to capture LSP stdin".to_string())?; + let stderr = child.stderr.take(); + + if let Some(stderr) = stderr { + tokio::spawn(async move { + if let Err(err) = forward_stderr(stderr).await { + tracing::error!(error = %err, "Failed to read LSP stderr"); + } + }); + } + + let (mut ws_sender, ws_receiver) = socket.split(); + let (ws_send_tx, mut ws_send_rx) = mpsc::channel::(32); + let (client_closed_tx, client_closed_rx) = oneshot::channel::<()>(); + + let stdout_task = { + let tx = ws_send_tx.clone(); + tokio::spawn(async move { forward_stdout(stdout, tx).await }) + }; + + let ws_sender_task = tokio::spawn(async move { + while let Some(msg) = ws_send_rx.recv().await { + if ws_sender.send(msg).await.is_err() { + break; + } + } + let _ = ws_sender.close().await; + }); + + let ws_to_child_task = { + let tx = ws_send_tx.clone(); + tokio::spawn(async move { + forward_client_messages(ws_receiver, stdin, tx, client_closed_tx).await + }) + }; + + let exit_status = monitor_child(&mut child, client_closed_rx).await?; + + let _ = ws_send_tx.send(Message::Close(None)).await; + drop(ws_send_tx); + + let _ = ws_to_child_task.await; + let _ = stdout_task.await; + let _ = ws_sender_task.await; + + if exit_status.success() { + tracing::info!("LSP command exited cleanly"); + } else { + tracing::warn!(?exit_status, "LSP command exited with non-zero status"); + } + + Ok(()) +} + +async fn forward_stdout(mut stdout: tokio::process::ChildStdout, tx: mpsc::Sender) { + let mut buffer = vec![0u8; 8192]; + loop { + match stdout.read(&mut buffer).await { + Ok(0) => break, + Ok(n) => { + if tx + .send(Message::Binary(buffer[..n].to_vec().into())) + .await + .is_err() + { + break; + } + } + Err(err) => { + tracing::error!(error = %err, "Failed to read from LSP stdout"); + break; + } + } + } +} + +async fn forward_client_messages( + mut receiver: futures::stream::SplitStream, + mut stdin: tokio::process::ChildStdin, + tx: mpsc::Sender, + shutdown_tx: oneshot::Sender<()>, +) { + while let Some(msg) = receiver.next().await { + match msg { + Ok(Message::Binary(data)) => { + if let Err(err) = stdin.write_all(&data).await { + tracing::error!(error = %err, "Failed to write binary frame to LSP"); + break; + } + } + Ok(Message::Text(text)) => { + if let Err(err) = stdin.write_all(text.as_bytes()).await { + tracing::error!(error = %err, "Failed to write text frame to LSP"); + break; + } + } + Ok(Message::Ping(payload)) => { + let _ = tx.send(Message::Pong(payload)).await; + continue; + } + Ok(Message::Pong(_)) => { + continue; + } + Ok(Message::Close(frame)) => { + let _ = tx.send(Message::Close(frame.clone())).await; + break; + } + Err(err) => { + tracing::error!(error = %err, "WebSocket receive error"); + break; + } + } + } + + let _ = stdin.shutdown().await; + let _ = shutdown_tx.send(()); +} + +async fn forward_stderr(stderr: ChildStderr) -> Result<(), std::io::Error> { + let mut reader = BufReader::new(stderr); + let mut line = String::new(); + + loop { + line.clear(); + let read = reader.read_line(&mut line).await?; + if read == 0 { + break; + } + + tracing::warn!(target: "lsp_stderr", message = %line.trim_end()); + } + + Ok(()) +} + +async fn monitor_child( + child: &mut Child, + mut client_closed: oneshot::Receiver<()>, +) -> Result { + loop { + tokio::select! { + res = &mut client_closed => { + if res.is_err() { + tracing::debug!("LSP client channel dropped without close signal"); + } + + let deadline = Instant::now() + GRACEFUL_SHUTDOWN; + loop { + match child.try_wait() { + Ok(Some(status)) => return Ok(status), + Ok(None) => { + if Instant::now() >= deadline { + break; + } + sleep(EXIT_POLL_INTERVAL).await; + } + Err(err) => return Err(format!("Failed to poll LSP process: {err}")), + } + } + + child + .kill() + .await + .map_err(|e| format!("Failed to terminate LSP process: {e}"))?; + return child + .wait() + .await + .map_err(|e| format!("Failed to await LSP process exit: {e}")); + } + _ = sleep(EXIT_POLL_INTERVAL) => { + match child.try_wait() { + Ok(Some(status)) => return Ok(status), + Ok(None) => {} + Err(err) => return Err(format!("Failed to poll LSP process: {err}")), + } + } + } + } +} diff --git a/src/main.rs b/src/main.rs index a46e3f0..cde6064 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,12 +1,13 @@ +mod lsp; mod terminal; mod updates; mod utils; use clap::{Parser, Subcommand}; use colored::Colorize; +use lsp::{start_lsp_server, LspBridgeConfig}; use std::net::Ipv4Addr; -use terminal::set_default_command; -use terminal::start_server; +use terminal::{set_default_command, start_server}; use updates::UpdateChecker; use utils::get_ip_address; @@ -17,16 +18,16 @@ const LOCAL_IP: Ipv4Addr = Ipv4Addr::new(127, 0, 0, 1); #[command(name = "acodex_server(axs)",version, author = "Raunak Raj ", about = "CLI/Server backend to serve pty over socket", long_about = None)] struct Cli { /// Port to start the server - #[arg(short, long, default_value_t = DEFAULT_PORT, value_parser = clap::value_parser!(u16).range(1..))] + #[arg(short, long, default_value_t = DEFAULT_PORT, value_parser = clap::value_parser!(u16).range(1..), global = true)] port: u16, /// Start the server on local network (ip) - #[arg(short, long)] + #[arg(short, long, global = true)] ip: bool, /// Custom command or shell for interactive PTY (e.g. "/usr/bin/bash") #[arg(short = 'c', long = "command")] command_override: Option, /// Allow all origins for CORS (dangerous). By default only https://localhost is allowed. - #[arg(long = "allow-any-origin")] + #[arg(long = "allow-any-origin", global = true)] allow_any_origin: bool, #[command(subcommand)] command: Option, @@ -36,6 +37,14 @@ struct Cli { enum Commands { /// Update axs server Update, + /// Start a WebSocket LSP bridge for a stdio language server + Lsp { + /// The language server binary to run (e.g. "rust-analyzer") + server: String, + /// Additional arguments to forward to the language server + #[arg(trailing_var_arg = true)] + server_args: Vec, + }, } fn print_update_available(current_version: &str, new_version: &str) { @@ -66,7 +75,15 @@ async fn check_updates_in_background() { async fn main() { let cli: Cli = Cli::parse(); - match cli.command { + let Cli { + port, + ip, + command_override, + allow_any_origin, + command, + } = cli; + + match command { Some(Commands::Update) => { println!("{} {}", "⟳".blue().bold(), "Checking for updates...".blue()); @@ -122,15 +139,40 @@ async fn main() { } } } + Some(Commands::Lsp { + server, + server_args, + }) => { + let host = if ip { + get_ip_address().unwrap_or_else(|| { + println!( + "{} localhost.", + "Error: IP address not found. Starting server on" + .red() + .bold() + ); + LOCAL_IP + }) + } else { + LOCAL_IP + }; + + let config = LspBridgeConfig { + program: server, + args: server_args, + }; + + start_lsp_server(host, port, allow_any_origin, config).await; + } None => { tokio::task::spawn(check_updates_in_background()); - if let Some(cmd) = cli.command_override { + if let Some(cmd) = command_override { // Set custom default command for interactive terminals set_default_command(cmd); } - let ip = if cli.ip { + let ip = if ip { get_ip_address().unwrap_or_else(|| { println!( "{} localhost.", @@ -144,7 +186,7 @@ async fn main() { LOCAL_IP }; - start_server(ip, cli.port, cli.allow_any_origin).await; + start_server(ip, port, allow_any_origin).await; } } } From daba58a771195f27771deee302114f5bdcb7a8c8 Mon Sep 17 00:00:00 2001 From: Raunak Raj <71929976+bajrangCoder@users.noreply.github.com> Date: Fri, 10 Oct 2025 17:17:35 +0530 Subject: [PATCH 2/3] fix --- src/lsp.rs | 100 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 93 insertions(+), 7 deletions(-) diff --git a/src/lsp.rs b/src/lsp.rs index 290a807..88c4076 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -7,6 +7,7 @@ use axum::response::IntoResponse; use axum::routing::get; use axum::Router; use futures::{SinkExt, StreamExt}; +use std::collections::VecDeque; use std::net::Ipv4Addr; use std::process::ExitStatus; use std::sync::Arc; @@ -124,6 +125,12 @@ async fn run_bridge(socket: WebSocket, config: Arc) -> Result<( .spawn() .map_err(|e| format!("Failed to spawn LSP command '{}': {e}", config.program))?; + tracing::info!( + program = %config.program, + args = ?config.args, + "WebSocket client connected; LSP process spawned", + ); + let stdout = child .stdout .take() @@ -186,19 +193,30 @@ async fn run_bridge(socket: WebSocket, config: Arc) -> Result<( } async fn forward_stdout(mut stdout: tokio::process::ChildStdout, tx: mpsc::Sender) { - let mut buffer = vec![0u8; 8192]; + let mut buf = vec![0u8; 8192]; + let mut decoder = LspMessageFramer::default(); + loop { - match stdout.read(&mut buffer).await { + match stdout.read(&mut buf).await { Ok(0) => break, Ok(n) => { - if tx - .send(Message::Binary(buffer[..n].to_vec().into())) - .await - .is_err() - { + if let Err(err) = decoder.push(&buf[..n]) { + tracing::error!(error = %err, "Failed to decode LSP stdout stream"); break; } + + while let Some(frame) = decoder.next_message() { + let message = match String::from_utf8(frame.clone()) { + Ok(text) => Message::Text(text.into()), + Err(_) => Message::Binary(frame.into()), + }; + + if tx.send(message).await.is_err() { + return; + } + } } + Err(err) if err.kind() == std::io::ErrorKind::Interrupted => continue, Err(err) => { tracing::error!(error = %err, "Failed to read from LSP stdout"); break; @@ -220,12 +238,20 @@ async fn forward_client_messages( tracing::error!(error = %err, "Failed to write binary frame to LSP"); break; } + if let Err(err) = stdin.flush().await { + tracing::error!(error = %err, "Failed to flush LSP stdin"); + break; + } } Ok(Message::Text(text)) => { if let Err(err) = stdin.write_all(text.as_bytes()).await { tracing::error!(error = %err, "Failed to write text frame to LSP"); break; } + if let Err(err) = stdin.flush().await { + tracing::error!(error = %err, "Failed to flush LSP stdin"); + break; + } } Ok(Message::Ping(payload)) => { let _ = tx.send(Message::Pong(payload)).await; @@ -249,6 +275,66 @@ async fn forward_client_messages( let _ = shutdown_tx.send(()); } +#[derive(Default)] +struct LspMessageFramer { + buffer: Vec, + messages: VecDeque>, +} + +impl LspMessageFramer { + fn push(&mut self, chunk: &[u8]) -> Result<(), String> { + self.buffer.extend_from_slice(chunk); + + loop { + let Some(header_end) = find_header_terminator(&self.buffer) else { + break; + }; + + let header = &self.buffer[..header_end]; + let content_length = parse_content_length(header)?; + let frame_len = header_end + 4 + content_length; // include delimiter + + if self.buffer.len() < frame_len { + break; + } + + let frame = self.buffer.drain(..frame_len).collect::>(); + self.messages.push_back(frame); + } + + Ok(()) + } + + fn next_message(&mut self) -> Option> { + self.messages.pop_front() + } +} + +fn find_header_terminator(buffer: &[u8]) -> Option { + buffer.windows(4).position(|window| window == b"\r\n\r\n") +} + +fn parse_content_length(header: &[u8]) -> Result { + let header_str = + std::str::from_utf8(header).map_err(|_| "Invalid UTF-8 in LSP header".to_string())?; + + for line in header_str.split("\r\n") { + let mut parts = line.splitn(2, ':'); + let key = parts.next().map(str::trim); + let value = parts.next().map(str::trim); + + if let (Some(key), Some(value)) = (key, value) { + if key.eq_ignore_ascii_case("content-length") { + return value + .parse::() + .map_err(|_| format!("Invalid Content-Length header: {value}")); + } + } + } + + Err("Missing Content-Length header".to_string()) +} + async fn forward_stderr(stderr: ChildStderr) -> Result<(), std::io::Error> { let mut reader = BufReader::new(stderr); let mut line = String::new(); From 8df5e95f8aeafc707f053e55bae30d442f34ee5a Mon Sep 17 00:00:00 2001 From: Raunak Raj <71929976+bajrangCoder@users.noreply.github.com> Date: Fri, 10 Oct 2025 17:32:19 +0530 Subject: [PATCH 3/3] Fix LSP message framing and header handling --- src/lsp.rs | 23 ++++++++++++++++------- 1 file changed, 16 insertions(+), 7 deletions(-) diff --git a/src/lsp.rs b/src/lsp.rs index 88c4076..a8cbcf0 100644 --- a/src/lsp.rs +++ b/src/lsp.rs @@ -244,8 +244,15 @@ async fn forward_client_messages( } } Ok(Message::Text(text)) => { - if let Err(err) = stdin.write_all(text.as_bytes()).await { - tracing::error!(error = %err, "Failed to write text frame to LSP"); + let body = text.as_bytes(); + let header = format!("Content-Length: {}\r\n\r\n", body.len()); + + if let Err(err) = stdin.write_all(header.as_bytes()).await { + tracing::error!(error = %err, "Failed to send LSP header"); + break; + } + if let Err(err) = stdin.write_all(body).await { + tracing::error!(error = %err, "Failed to write LSP payload"); break; } if let Err(err) = stdin.flush().await { @@ -290,16 +297,18 @@ impl LspMessageFramer { break; }; - let header = &self.buffer[..header_end]; - let content_length = parse_content_length(header)?; - let frame_len = header_end + 4 + content_length; // include delimiter + let headers = &self.buffer[..header_end]; + let content_length = parse_content_length(headers)?; + let body_start = header_end + 4; + let frame_len = body_start + content_length; if self.buffer.len() < frame_len { break; } - let frame = self.buffer.drain(..frame_len).collect::>(); - self.messages.push_back(frame); + let body = self.buffer[body_start..frame_len].to_vec(); + self.buffer.drain(..frame_len); + self.messages.push_back(body); } Ok(())