Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 33 additions & 1 deletion handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ type request struct {
Meta map[string]string `json:"meta,omitempty"`
}

// Limit request size. Ideally this limit should be specific for each field
// in the JSON request but as a simple defensive measure we just limit the
// entire HTTP body.
// Configured by WithMaxRequestSize.
const DEFAULT_MAX_REQUEST_SIZE = 100 << 20 // 100 MiB

type respError struct {
Code int `json:"code"`
Message string `json:"message"`
Expand Down Expand Up @@ -111,7 +117,33 @@ func (s *RPCServer) handleReader(ctx context.Context, r io.Reader, w io.Writer,
}

var req request
if err := json.NewDecoder(r).Decode(&req); err != nil {
// We read the entire request upfront in a buffer to be able to tell if the
// client sent more than maxRequestSize and report it back as an explicit error,
// instead of just silently truncating it and reporting a more vague parsing
// error.
bufferedRequest := new(bytes.Buffer)
// We use LimitReader to enforce maxRequestSize. Since it won't return an
// EOF we can't actually know if the client sent more than the maximum or
// not, so we read one byte more over the limit to explicitly query that.
// FIXME: Maybe there's a cleaner way to do this.
reqSize, err := bufferedRequest.ReadFrom(io.LimitReader(r, s.maxRequestSize+1))
if err != nil {
// ReadFrom will discard EOF so any error here is unexpected and should
// be reported.
rpcError(wf, &req, rpcParseError, xerrors.Errorf("reading request: %w", err))
return
}
if reqSize > s.maxRequestSize {
rpcError(wf, &req, rpcParseError,
// rpcParseError is the closest we have from the standard errors defined
// in [jsonrpc spec](https://www.jsonrpc.org/specification#error_object)
// to report the maximum limit.
xerrors.Errorf("request bigger than maximum %d allowed",
s.maxRequestSize))
return
}

if err := json.NewDecoder(bufferedRequest).Decode(&req); err != nil {
rpcError(wf, &req, rpcParseError, xerrors.Errorf("unmarshaling request: %w", err))
return
}
Expand Down
12 changes: 10 additions & 2 deletions options_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@ import (
type ParamDecoder func(ctx context.Context, json []byte) (reflect.Value, error)

type ServerConfig struct {
paramDecoders map[reflect.Type]ParamDecoder
paramDecoders map[reflect.Type]ParamDecoder
maxRequestSize int64
}

type ServerOption func(c *ServerConfig)

func defaultServerConfig() ServerConfig {
return ServerConfig{
paramDecoders: map[reflect.Type]ParamDecoder{},
paramDecoders: map[reflect.Type]ParamDecoder{},
maxRequestSize: DEFAULT_MAX_REQUEST_SIZE,
}
}

Expand All @@ -24,3 +26,9 @@ func WithParamDecoder(t interface{}, decoder ParamDecoder) ServerOption {
c.paramDecoders[reflect.TypeOf(t).Elem()] = decoder
}
}

func WithMaxRequestSize(max int64) ServerOption {
return func(c *ServerConfig) {
c.maxRequestSize = max
}
}
7 changes: 5 additions & 2 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ type RPCServer struct {
methods map[string]rpcHandler

paramDecoders map[reflect.Type]ParamDecoder

maxRequestSize int64
}

// NewServer creates new RPCServer instance
Expand All @@ -32,8 +34,9 @@ func NewServer(opts ...ServerOption) *RPCServer {
}

return &RPCServer{
methods: map[string]rpcHandler{},
paramDecoders: config.paramDecoders,
methods: map[string]rpcHandler{},
paramDecoders: config.paramDecoders,
maxRequestSize: config.maxRequestSize,
}
}

Expand Down