diff --git a/Project.toml b/Project.toml index b6d9943a..994a57a5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,12 @@ name = "JSONRPC" uuid = "b9b8584e-8fd3-41f9-ad0c-7255d428e418" authors = ["David Anthoff "] -version = "1.4.3-DEV" +version = "2.0.0-DEV" [deps] JSON = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" +CancellationTokens = "2e8d271d-f2e2-407b-a864-17eb2156783e" [extras] TestItemRunner = "f8b46487-2199-4994-9208-9a1283c18c0a" @@ -13,6 +14,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Sockets = "6462fe0b-24de-5631-8697-dd941f90decc" [compat] +CancellationTokens = "1" JSON = "0.20, 0.21" julia = "1" diff --git a/src/JSONRPC.jl b/src/JSONRPC.jl index cd7204ab..fcf702e6 100644 --- a/src/JSONRPC.jl +++ b/src/JSONRPC.jl @@ -1,6 +1,6 @@ module JSONRPC -import JSON, UUIDs +import JSON, UUIDs, CancellationTokens include("packagedef.jl") diff --git a/src/core.jl b/src/core.jl index 81152ecb..5d5f3e19 100644 --- a/src/core.jl +++ b/src/core.jl @@ -105,14 +105,23 @@ function Base.showerror(io::IO, ex::JSONRPCError) end end +struct Request + method::String + params::Union{Nothing,Dict{String,Any},Vector{Any}} + id::Union{Nothing,String} + token::Union{CancellationTokens.CancellationToken,Nothing} +end + mutable struct JSONRPCEndpoint{IOIn <: IO,IOOut <: IO} pipe_in::IOIn pipe_out::IOOut out_msg_queue::Channel{Any} - in_msg_queue::Channel{Any} + in_msg_queue::Channel{Request} - outstanding_requests::Dict{String,Channel{Any}} + outstanding_requests::Dict{String,Channel{Any}} # These are requests sent where we are waiting for a response + cancellation_sources::Dict{String,CancellationTokens.CancellationTokenSource} # These are the cancellation sources for requests that are not finished processing + no_longer_needed_cancellation_sources::Channel{String} err_handler::Union{Nothing,Function} @@ -123,7 +132,18 @@ mutable struct JSONRPCEndpoint{IOIn <: IO,IOOut <: IO} end JSONRPCEndpoint(pipe_in, pipe_out, err_handler = nothing) = - JSONRPCEndpoint(pipe_in, pipe_out, Channel{Any}(Inf), Channel{Any}(Inf), Dict{String,Channel{Any}}(), err_handler, :idle, nothing, nothing) + JSONRPCEndpoint( + pipe_in, + pipe_out, + Channel{Any}(Inf), + Channel{Request}(Inf), + Dict{String,Channel{Any}}(), + Dict{String,CancellationTokens.CancellationTokenSource}(), + Channel{String}(Inf), + err_handler, + :idle, + nothing, + nothing) function write_transport_layer(stream, response) response_utf8 = transcode(UInt8, response) @@ -187,6 +207,13 @@ function Base.run(x::JSONRPCEndpoint) x.read_task = @async try while true + # First we delete any cancellation sources that are no longer needed. We do it this way to avoid a lock + while isready(x.no_longer_needed_cancellation_sources) + no_longer_needed_cs_id = take!(x.no_longer_needed_cancellation_sources) + delete!(x.cancellation_sources, no_longer_needed_cs_id) + end + + # Now handle new messages message = read_transport_layer(x.pipe_in) if message === nothing || x.status == :closed @@ -196,13 +223,38 @@ function Base.run(x::JSONRPCEndpoint) message_dict = JSON.parse(message) if haskey(message_dict, "method") - try - put!(x.in_msg_queue, message_dict) - catch err - if err isa InvalidStateException - break - else - rethrow(err) + method_name = message_dict["method"] + params = get(message_dict, "params", nothing) + id = get(message_dict, "id", nothing) + cancel_source = id === nothing ? nothing : CancellationTokens.CancellationTokenSource() + cancel_token = cancel_source === nothing ? nothing : CancellationTokens.get_token(cancel_source) + + if method_name == "\$/cancelRequest" + id_of_cancelled_request = params["id"] + cs = get(x.cancellation_sources, id_of_cancelled_request, nothing) # We might have sent the response already + if cs !== nothing + CancellationTokens.cancel(cs) + end + else + if id !== nothing + x.cancellation_sources[id] = cancel_source + end + + request = Request( + method_name, + params, + id, + cancel_token + ) + + try + put!(x.in_msg_queue, request) + catch err + if err isa InvalidStateException + break + else + rethrow(err) + end end end else @@ -294,20 +346,28 @@ function Base.iterate(endpoint::JSONRPCEndpoint, state = nothing) end end -function send_success_response(endpoint, original_request, result) +function send_success_response(endpoint, original_request::Request, result) check_dead_endpoint!(endpoint) - response = Dict("jsonrpc" => "2.0", "id" => original_request["id"], "result" => result) + original_request.id === nothing && error("Cannot send a response to a notification.") + + put!(endpoint.no_longer_needed_cancellation_sources, original_request.id) + + response = Dict("jsonrpc" => "2.0", "id" => original_request.id, "result" => result) response_json = JSON.json(response) put!(endpoint.out_msg_queue, response_json) end -function send_error_response(endpoint, original_request, code, message, data) +function send_error_response(endpoint, original_request::Request, code, message, data) check_dead_endpoint!(endpoint) - response = Dict("jsonrpc" => "2.0", "id" => original_request["id"], "error" => Dict("code" => code, "message" => message, "data" => data)) + original_request.id === nothing && error("Cannot send a response to a notification.") + + put!(endpoint.no_longer_needed_cancellation_sources, original_request.id) + + response = Dict("jsonrpc" => "2.0", "id" => original_request.id, "error" => Dict("code" => code, "message" => message, "data" => data)) response_json = JSON.json(response) diff --git a/src/typed.jl b/src/typed.jl index 72e22bc3..e080d81e 100644 --- a/src/typed.jl +++ b/src/typed.jl @@ -55,16 +55,20 @@ function Base.setindex!(dispatcher::MsgDispatcher, func::Function, message_type: dispatcher._handlers[message_type.method] = Handler(message_type, func) end -function dispatch_msg(x::JSONRPCEndpoint, dispatcher::MsgDispatcher, msg) +function dispatch_msg(x::JSONRPCEndpoint, dispatcher::MsgDispatcher, msg::Request) dispatcher._currentlyHandlingMsg = true try - method_name = msg["method"] + method_name = msg.method handler = get(dispatcher._handlers, method_name, nothing) if handler !== nothing param_type = get_param_type(handler.message_type) - params = param_type === Nothing ? nothing : param_type <: NamedTuple ? convert(param_type,(;(Symbol(i[1])=>i[2] for i in msg["params"])...)) : param_type(msg["params"]) + params = param_type === Nothing ? nothing : param_type <: NamedTuple ? convert(param_type,(;(Symbol(i[1])=>i[2] for i in msg.params)...)) : param_type(msg.params) - res = handler.func(x, params) + if handler.message_type isa RequestType + res = handler.func(x, params, msg.token) + else + res = handler.func(x, params) + end if handler.message_type isa RequestType if res isa JSONRPCError @@ -89,20 +93,28 @@ is_currently_handling_msg(d::MsgDispatcher) = d._currentlyHandlingMsg macro message_dispatcher(name, body) quote - function $(esc(name))(x, msg::Dict{String,Any}, context=nothing) - method_name = msg["method"]::String + function $(esc(name))(x, msg::Request, context=nothing) + method_name = msg.method $( ( :( if method_name == $(esc(i.args[2])).method param_type = get_param_type($(esc(i.args[2]))) - params = param_type === Nothing ? nothing : param_type <: NamedTuple ? convert(param_type,(;(Symbol(i[1])=>i[2] for i in msg["params"])...)) : param_type(msg["params"]) + params = param_type === Nothing ? nothing : param_type <: NamedTuple ? convert(param_type,(;(Symbol(i[1])=>i[2] for i in msg.params)...)) : param_type(msg.params) if context===nothing - res = $(esc(i.args[3]))(x, params) + if $(esc(i.args[2])) isa RequestType + res = $(esc(i.args[3]))(params, msg.token) + else + res = $(esc(i.args[3]))(params) + end else - res = $(esc(i.args[3]))(x, params, context) + if $(esc(i.args[2])) isa RequestType + res = $(esc(i.args[3]))(params, context, msg.token) + else + res = $(esc(i.args[3]))(params, context) + end end if $(esc(i.args[2])) isa RequestType diff --git a/test/test_typed.jl b/test/test_typed.jl index d2c88bd2..a9b6098e 100644 --- a/test/test_typed.jl +++ b/test/test_typed.jl @@ -6,7 +6,7 @@ request1_type = JSONRPC.RequestType("request1", Foo, String) request2_type = JSONRPC.RequestType("request2", Nothing, String) - notify1_type = JSONRPC.NotificationType("notify1", String) + notify1_type = JSONRPC.NotificationType("notify1", Vector{String}) global g_var = "" @@ -19,20 +19,23 @@ global conn = JSONRPC.JSONRPCEndpoint(sock, sock) global msg_dispatcher = JSONRPC.MsgDispatcher() - msg_dispatcher[request1_type] = (conn, params) -> begin + msg_dispatcher[request1_type] = (conn, params, token) -> begin @test JSONRPC.is_currently_handling_msg(msg_dispatcher) params.fieldA == 1 ? "YES" : "NO" end - msg_dispatcher[request2_type] = (conn, params) -> JSONRPC.JSONRPCError(-32600, "Our message", nothing) - msg_dispatcher[notify1_type] = (conn, params) -> global g_var = params + msg_dispatcher[request2_type] = (conn, params, token) -> JSONRPC.JSONRPCError(-32600, "Our message", nothing) + msg_dispatcher[notify1_type] = (conn, params) -> global g_var = params[1] run(conn) for msg in conn + @info "Got a message, now dispatching" msg JSONRPC.dispatch_msg(conn, msg_dispatcher, msg) + @info "Finished dispatching" end catch err Base.display_error(stderr, err, catch_backtrace()) + Base.flush(stderr) end wait(server_is_up) @@ -42,7 +45,7 @@ run(conn2) - JSONRPC.send(conn2, notify1_type, "TEST") + JSONRPC.send(conn2, notify1_type, ["TEST"]) res = JSONRPC.send(conn2, request1_type, Foo(fieldA=1, fieldB="FOO")) @@ -70,7 +73,7 @@ global conn = JSONRPC.JSONRPCEndpoint(sock, sock) global msg_dispatcher = JSONRPC.MsgDispatcher() - msg_dispatcher[request2_type] = (conn, params)->34 # The request type requires a `String` return, so this tests whether we get an error. + msg_dispatcher[request2_type] = (conn, params, token)->34 # The request type requires a `String` return, so this tests whether we get an error. run(conn) @@ -79,6 +82,7 @@ end catch err Base.display_error(stderr, err, catch_backtrace()) + Base.flush(stderr) end wait(server_is_up) @@ -117,18 +121,18 @@ end request1_type = JSONRPC.RequestType("request1", Foo, String) request2_type = JSONRPC.RequestType("request2", Nothing, String) - notify1_type = JSONRPC.NotificationType("notify1", String) + notify1_type = JSONRPC.NotificationType("notify1", Vector{String}) global g_var = "" server_is_up = Base.Condition() JSONRPC.@message_dispatcher my_dispatcher begin - request1_type => (conn, params) -> begin + request1_type => (params, token) -> begin params.fieldA == 1 ? "YES" : "NO" end - request2_type => (conn, params) -> JSONRPC.JSONRPCError(-32600, "Our message", nothing) - notify1_type => (conn, params) -> global g_var = params + request2_type => (params, token) -> JSONRPC.JSONRPCError(-32600, "Our message", nothing) + notify1_type => (params) -> global g_var = params[1] end server_task = @async try @@ -154,7 +158,7 @@ end run(conn2) - JSONRPC.send(conn2, notify1_type, "TEST") + JSONRPC.send(conn2, notify1_type, ["TEST"]) res = JSONRPC.send(conn2, request1_type, Foo(fieldA=1, fieldB="FOO")) @@ -176,7 +180,7 @@ end server_is_up = Base.Condition() JSONRPC.@message_dispatcher my_dispatcher2 begin - request2_type => (conn, params) -> 34 # The request type requires a `String` return, so this tests whether we get an error. + request2_type => (params, token) -> 34 # The request type requires a `String` return, so this tests whether we get an error. end server_task2 = @async try