Skip to content

Commit a820457

Browse files
authored
Make client requests type safe, unmarshal (#2099)
1 parent 645ba45 commit a820457

File tree

6 files changed

+237
-41
lines changed

6 files changed

+237
-41
lines changed

internal/fourslash/fourslash.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,8 +353,7 @@ func getCapabilitiesWithDefaults(capabilities *lsproto.ClientCapabilities) *lspr
353353

354354
func sendRequest[Params, Resp any](t *testing.T, f *FourslashTest, info lsproto.RequestInfo[Params, Resp], params Params) (*lsproto.Message, Resp, bool) {
355355
id := f.nextID()
356-
req := lsproto.NewRequestMessage(
357-
info.Method,
356+
req := info.NewRequestMessage(
358357
lsproto.NewID(lsproto.IntegerOrString{Integer: &id}),
359358
params,
360359
)
@@ -396,8 +395,7 @@ func sendRequest[Params, Resp any](t *testing.T, f *FourslashTest, info lsproto.
396395
}
397396

398397
func sendNotification[Params any](t *testing.T, f *FourslashTest, info lsproto.NotificationInfo[Params], params Params) {
399-
notification := lsproto.NewNotificationMessage(
400-
info.Method,
398+
notification := info.NewNotificationMessage(
401399
params,
402400
)
403401
f.writeMsg(t, notification.Message())

internal/lsp/lsproto/_generate/generate.mts

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,36 @@ function generateCode() {
784784
writeLine("}");
785785
writeLine("");
786786

787+
// Generate unmarshalResult function
788+
writeLine("func unmarshalResult(method Method, data []byte) (any, error) {");
789+
writeLine("\tswitch method {");
790+
791+
// Only requests have results, not notifications
792+
for (const request of model.requests) {
793+
const methodName = methodNameIdentifier(request.method);
794+
795+
if (!("result" in request)) {
796+
continue;
797+
}
798+
799+
let responseTypeName: string;
800+
if (request.typeName && request.typeName.endsWith("Request")) {
801+
responseTypeName = request.typeName.replace(/Request$/, "Response");
802+
}
803+
else {
804+
responseTypeName = `${methodName}Response`;
805+
}
806+
807+
writeLine(`\tcase Method${methodName}:`);
808+
writeLine(`\t\treturn unmarshalValue[${responseTypeName}](data)`);
809+
}
810+
811+
writeLine("\tdefault:");
812+
writeLine(`\t\treturn unmarshalAny(data)`);
813+
writeLine("\t}");
814+
writeLine("}");
815+
writeLine("");
816+
787817
writeLine("// Methods");
788818
writeLine("const (");
789819
for (const request of requestsAndNotifications) {

internal/lsp/lsproto/jsonrpc.go

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -105,19 +105,20 @@ func (m *Message) UnmarshalJSON(data []byte) error {
105105
Method Method `json:"method"`
106106
ID *ID `json:"id,omitzero"`
107107
Params jsontext.Value `json:"params"`
108-
Result any `json:"result,omitzero"`
109-
Error *ResponseError `json:"error,omitzero"`
108+
// We don't have a method in the response, so we have no idea what to decode.
109+
// Store the raw text and let the caller decode it.
110+
Result jsontext.Value `json:"result,omitzero"`
111+
Error *ResponseError `json:"error,omitzero"`
110112
}
111113
if err := json.Unmarshal(data, &raw); err != nil {
112114
return fmt.Errorf("%w: %w", ErrInvalidRequest, err)
113115
}
114116
if raw.ID != nil && raw.Method == "" {
115117
m.Kind = MessageKindResponse
116118
m.msg = &ResponseMessage{
117-
JSONRPC: raw.JSONRPC,
118-
ID: raw.ID,
119-
Result: raw.Result,
120-
Error: raw.Error,
119+
ID: raw.ID,
120+
Result: raw.Result,
121+
Error: raw.Error,
121122
}
122123
return nil
123124
}
@@ -138,10 +139,9 @@ func (m *Message) UnmarshalJSON(data []byte) error {
138139
}
139140

140141
m.msg = &RequestMessage{
141-
JSONRPC: raw.JSONRPC,
142-
ID: raw.ID,
143-
Method: raw.Method,
144-
Params: params,
142+
ID: raw.ID,
143+
Method: raw.Method,
144+
Params: params,
145145
}
146146

147147
return nil
@@ -151,29 +151,13 @@ func (m *Message) MarshalJSON() ([]byte, error) {
151151
return json.Marshal(m.msg)
152152
}
153153

154-
func NewNotificationMessage(method Method, params any) *RequestMessage {
155-
return &RequestMessage{
156-
JSONRPC: JSONRPCVersion{},
157-
Method: method,
158-
Params: params,
159-
}
160-
}
161-
162154
type RequestMessage struct {
163155
JSONRPC JSONRPCVersion `json:"jsonrpc"`
164156
ID *ID `json:"id,omitzero"`
165157
Method Method `json:"method"`
166158
Params any `json:"params,omitzero"`
167159
}
168160

169-
func NewRequestMessage(method Method, id *ID, params any) *RequestMessage {
170-
return &RequestMessage{
171-
ID: id,
172-
Method: method,
173-
Params: params,
174-
}
175-
}
176-
177161
func (r *RequestMessage) Message() *Message {
178162
return &Message{
179163
Kind: MessageKindRequest,

internal/lsp/lsproto/lsp.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,14 @@ func unmarshalPtrTo[T any](data []byte) (*T, error) {
7171
return &v, nil
7272
}
7373

74+
func unmarshalValue[T any](data []byte) (T, error) {
75+
var v T
76+
if err := json.Unmarshal(data, &v); err != nil {
77+
return *new(T), fmt.Errorf("failed to unmarshal %T: %w", (*T)(nil), err)
78+
}
79+
return v, nil
80+
}
81+
7482
func unmarshalAny(data []byte) (any, error) {
7583
var v any
7684
if err := json.Unmarshal(data, &v); err != nil {
@@ -129,11 +137,43 @@ type RequestInfo[Params, Resp any] struct {
129137
Method Method
130138
}
131139

140+
func (info RequestInfo[Params, Resp]) UnmarshalResult(result any) (Resp, error) {
141+
if r, ok := result.(Resp); ok {
142+
return r, nil
143+
}
144+
145+
raw, ok := result.(jsontext.Value)
146+
if !ok {
147+
return *new(Resp), fmt.Errorf("expected jsontext.Value, got %T", result)
148+
}
149+
150+
r, err := unmarshalResult(info.Method, raw)
151+
if err != nil {
152+
return *new(Resp), err
153+
}
154+
return r.(Resp), nil
155+
}
156+
157+
func (info RequestInfo[Params, Resp]) NewRequestMessage(id *ID, params Params) *RequestMessage {
158+
return &RequestMessage{
159+
ID: id,
160+
Method: info.Method,
161+
Params: params,
162+
}
163+
}
164+
132165
type NotificationInfo[Params any] struct {
133166
_ [0]Params
134167
Method Method
135168
}
136169

170+
func (info NotificationInfo[Params]) NewNotificationMessage(params Params) *RequestMessage {
171+
return &RequestMessage{
172+
Method: info.Method,
173+
Params: params,
174+
}
175+
}
176+
137177
type Null struct{}
138178

139179
func (Null) UnmarshalJSONFrom(dec *jsontext.Decoder) error {

internal/lsp/lsproto/lsp_generated.go

Lines changed: 145 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)