Skip to content

Commit

Permalink
Handle invalid requests in generated_impl::handle_request()
Browse files Browse the repository at this point in the history
This abstraction is a little bit cleaner since it hides the concept of
invalid requests from the user. Additionally, it simplifies the logic in
the `Service::call` trait method implementation on `LspService`.

This commit is an early step towards refactoring the service model:

#177
  • Loading branch information
ebkalderon committed Sep 24, 2020
1 parent 22268bf commit 7af7102
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 41 deletions.
24 changes: 15 additions & 9 deletions src/jsonrpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,15 +115,6 @@ pub enum Incoming {
Request(ServerRequest),
/// Response to a server-to-client request.
Response(Response),
/// An invalid JSON-RPC request.
Invalid {
/// Request ID, if known.
#[serde(default)]
id: Option<Id>,
/// Method name, if known.
#[serde(default)]
method: Option<String>,
},
}

/// A server-to-client LSP request.
Expand Down Expand Up @@ -276,4 +267,19 @@ mod tests {
let from_value: Outgoing = serde_json::from_value(v).unwrap();
assert_eq!(from_str, from_value);
}

#[test]
fn parses_invalid_server_request() {
let unknown_method = json!({"jsonrpc":"2.0","method":"foo"});
let _: Incoming = serde_json::from_value(unknown_method).unwrap();

let unknown_method_with_id = json!({"jsonrpc":"2.0","method":"foo","id":1});
let _: Incoming = serde_json::from_value(unknown_method_with_id).unwrap();

let missing_method = json!({"jsonrpc":"2.0"});
let _: Incoming = serde_json::from_value(missing_method).unwrap();

let missing_method_with_id = json!({"jsonrpc":"2.0","id":1});
let _: Incoming = serde_json::from_value(missing_method_with_id).unwrap();
}
}
16 changes: 2 additions & 14 deletions src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@ use std::task::{Context, Poll};
use futures::channel::mpsc::{self, Receiver};
use futures::stream::FusedStream;
use futures::{future, FutureExt, Stream};
use log::{error, trace};
use log::trace;
use tower_service::Service;

use super::client::Client;
use super::jsonrpc::{self, ClientRequests, Incoming, Outgoing, Response, ServerRequests};
use super::jsonrpc::{ClientRequests, Incoming, Outgoing, ServerRequests};
use super::{generated_impl, LanguageServer, ServerState, State};

/// Error that occurs when attempting to call the language server after it has already exited.
Expand Down Expand Up @@ -130,18 +130,6 @@ impl Service<Incoming> for LspService {
self.pending_client.insert(res);
future::ok(None).boxed()
}
Incoming::Invalid { id, method } => match (id, method) {
(None, Some(method)) if method.starts_with("$/") => future::ok(None).boxed(),
(id, Some(method)) => {
error!("method {:?} not found", method);
let res = Response::error(id, jsonrpc::Error::method_not_found());
future::ok(Some(Outgoing::Response(res))).boxed()
}
(id, None) => {
let res = Response::error(id, jsonrpc::Error::invalid_request());
future::ok(Some(Outgoing::Response(res))).boxed()
}
},
}
}
}
Expand Down
62 changes: 44 additions & 18 deletions tower-lsp-macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -149,17 +149,17 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma

Ok(Some(Outgoing::Response(res)))
})
},
}
(ServerMethod::#var_name { params: Invalid(e), id }, State::Uninitialized) => {
error!("invalid parameters for {:?} request", #rpc_name);
let res = Response::error(Some(id), Error::invalid_params(e));
future::ok(Some(Outgoing::Response(res))).boxed()
},
}
(ServerMethod::#var_name { id, .. }, State::Initializing) => {
warn!("received duplicate `initialize` request, ignoring");
let res = Response::error(Some(id), Error::invalid_request());
future::ok(Some(Outgoing::Response(res))).boxed()
},
}
},
(true, false) if rpc_name == "shutdown" => quote! {
(ServerMethod::#var_name { id }, State::Initialized) => {
Expand All @@ -169,42 +169,42 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma
.execute(id, async move { server.#handler().await })
.map(|v| Ok(Some(Outgoing::Response(v))))
.boxed()
},
}
},
(true, true) => quote! {
(ServerMethod::#var_name { params: Valid(p), id }, State::Initialized) => {
pending
.execute(id, async move { server.#handler(p).await })
.map(|v| Ok(Some(Outgoing::Response(v))))
.boxed()
},
}
(ServerMethod::#var_name { params: Invalid(e), id }, State::Initialized) => {
error!("invalid parameters for {:?} request", #rpc_name);
let res = Response::error(Some(id), Error::invalid_params(e));
future::ok(Some(Outgoing::Response(res))).boxed()
},
}
},
(true, false) => quote! {
(ServerMethod::#var_name { id }, State::Initialized) => {
pending
.execute(id, async move { server.#handler().await })
.map(|v| Ok(Some(Outgoing::Response(v))))
.boxed()
},
}
},
(false, true) => quote! {
(ServerMethod::#var_name { params: Valid(p) }, State::Initialized) => {
Box::pin(async move { server.#handler(p).await; Ok(None) })
},
}
(ServerMethod::#var_name { .. }, State::Initialized) => {
warn!("invalid parameters for {:?} notification", #rpc_name);
future::ok(None).boxed()
},
}
},
(false, false) => quote! {
(ServerMethod::#var_name, State::Initialized) => {
Box::pin(async move { server.#handler().await; Ok(None) })
},
}
},
}
})
Expand Down Expand Up @@ -236,7 +236,15 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma
pub struct ServerRequest {
jsonrpc: Version,
#[serde(flatten)]
inner: ServerMethod,
kind: RequestKind,
}

#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
#[cfg_attr(test, derive(serde::Serialize))]
#[serde(untagged)]
enum RequestKind {
Valid(ServerMethod),
Invalid { id: Option<Id>, method: Option<String> },
}

#[derive(Clone, Debug, PartialEq, serde::Deserialize)]
Expand All @@ -251,7 +259,6 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma
}

impl ServerMethod {
#[inline]
fn id(&self) -> Option<&Id> {
match *self {
#id_match_arms
Expand Down Expand Up @@ -285,34 +292,53 @@ fn gen_server_router(trait_name: &syn::Ident, methods: &[MethodCall]) -> proc_ma
server: T,
state: &Arc<ServerState>,
pending: &ServerRequests,
incoming: ServerRequest,
request: ServerRequest,
) -> Pin<Box<dyn Future<Output = Result<Option<Outgoing>, ExitedError>> + Send>> {
use Params::*;
match (incoming.inner, state.get()) {

let method = match request.kind {
RequestKind::Valid(method) => method,
RequestKind::Invalid { id: Some(id), method: Some(m) } => {
error!("method {:?} not found", m);
let res = Response::error(Some(id), Error::method_not_found());
return future::ok(Some(Outgoing::Response(res))).boxed();
}
RequestKind::Invalid { id: Some(id), .. } => {
let res = Response::error(Some(id), Error::invalid_request());
return future::ok(Some(Outgoing::Response(res))).boxed();
}
RequestKind::Invalid { id: None, method: Some(m) } if !m.starts_with("$/") => {
error!("method {:?} not found", m);
return future::ok(None).boxed();
}
RequestKind::Invalid { id: None, .. } => return future::ok(None).boxed(),
};

match (method, state.get()) {
#route_match_arms
(ServerMethod::CancelRequest { id }, State::Initialized) => {
pending.cancel(&id);
future::ok(None).boxed()
},
}
(ServerMethod::Exit, _) => {
info!("exit notification received, stopping");
state.set(State::Exited);
pending.cancel_all();
future::ok(None).boxed()
},
}
(other, State::Uninitialized) => Box::pin(match other.id().cloned() {
None => future::ok(None),
Some(id) => {
let res = Response::error(Some(id), not_initialized_error());
future::ok(Some(Outgoing::Response(res)))
},
}
}),
(other, _) => Box::pin(match other.id().cloned() {
None => future::ok(None),
Some(id) => {
let res = Response::error(Some(id), Error::invalid_request());
future::ok(Some(Outgoing::Response(res)))
},
}
}),
}
}
Expand Down

0 comments on commit 7af7102

Please sign in to comment.