diff --git a/handler.go b/handler.go index 7243a0d..ef5f53a 100644 --- a/handler.go +++ b/handler.go @@ -43,6 +43,12 @@ type request struct { Meta map[string]string `json:"meta,omitempty"` } +// Limit request size. Ideally this limit should be specific for each field +// in the JSON request but as a simple defensive measure we just limit the +// entire HTTP body. +// Configured by WithMaxRequestSize. +const DEFAULT_MAX_REQUEST_SIZE = 100 << 20 // 100 MiB + type respError struct { Code int `json:"code"` Message string `json:"message"` @@ -111,7 +117,33 @@ func (s *RPCServer) handleReader(ctx context.Context, r io.Reader, w io.Writer, } var req request - if err := json.NewDecoder(r).Decode(&req); err != nil { + // We read the entire request upfront in a buffer to be able to tell if the + // client sent more than maxRequestSize and report it back as an explicit error, + // instead of just silently truncating it and reporting a more vague parsing + // error. + bufferedRequest := new(bytes.Buffer) + // We use LimitReader to enforce maxRequestSize. Since it won't return an + // EOF we can't actually know if the client sent more than the maximum or + // not, so we read one byte more over the limit to explicitly query that. + // FIXME: Maybe there's a cleaner way to do this. + reqSize, err := bufferedRequest.ReadFrom(io.LimitReader(r, s.maxRequestSize+1)) + if err != nil { + // ReadFrom will discard EOF so any error here is unexpected and should + // be reported. + rpcError(wf, &req, rpcParseError, xerrors.Errorf("reading request: %w", err)) + return + } + if reqSize > s.maxRequestSize { + rpcError(wf, &req, rpcParseError, + // rpcParseError is the closest we have from the standard errors defined + // in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object) + // to report the maximum limit. + xerrors.Errorf("request bigger than maximum %d allowed", + s.maxRequestSize)) + return + } + + if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil { rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err)) return } diff --git a/options_server.go b/options_server.go index 6d5eb64..2da0f1c 100644 --- a/options_server.go +++ b/options_server.go @@ -8,14 +8,16 @@ import ( type ParamDecoder func(ctx context.Context, json []byte) (reflect.Value, error) type ServerConfig struct { - paramDecoders map[reflect.Type]ParamDecoder + paramDecoders map[reflect.Type]ParamDecoder + maxRequestSize int64 } type ServerOption func(c *ServerConfig) func defaultServerConfig() ServerConfig { return ServerConfig{ - paramDecoders: map[reflect.Type]ParamDecoder{}, + paramDecoders: map[reflect.Type]ParamDecoder{}, + maxRequestSize: DEFAULT_MAX_REQUEST_SIZE, } } @@ -24,3 +26,9 @@ func WithParamDecoder(t interface{}, decoder ParamDecoder) ServerOption { c.paramDecoders[reflect.TypeOf(t).Elem()] = decoder } } + +func WithMaxRequestSize(max int64) ServerOption { + return func(c *ServerConfig) { + c.maxRequestSize = max + } +} diff --git a/server.go b/server.go index d35d0ea..6d0b634 100644 --- a/server.go +++ b/server.go @@ -22,6 +22,8 @@ type RPCServer struct { methods map[string]rpcHandler paramDecoders map[reflect.Type]ParamDecoder + + maxRequestSize int64 } // NewServer creates new RPCServer instance @@ -32,8 +34,9 @@ func NewServer(opts ...ServerOption) *RPCServer { } return &RPCServer{ - methods: map[string]rpcHandler{}, - paramDecoders: config.paramDecoders, + methods: map[string]rpcHandler{}, + paramDecoders: config.paramDecoders, + maxRequestSize: config.maxRequestSize, } }