Skip to content

Commit

Permalink
Allow WebsocketInitFunc to add payload to Ack
Browse files Browse the repository at this point in the history
The connection ACK message in the protocol for both
graphql-ws and graphql-transport-ws allows for a payload in the
connection ack message.

We really wanted to use this to establish better telemetry in our use of
websockets in graphql.
  • Loading branch information
Chris Pride committed Sep 8, 2023
1 parent ccae370 commit 12afb29
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 22 deletions.
8 changes: 4 additions & 4 deletions _examples/websocket-initfunc/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import (
"github.com/rs/cors"
)

func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
// Get the token from payload
payload := initPayload["authToken"]
token, ok := payload.(string)
if !ok || token == "" {
return nil, errors.New("authToken not found in transport payload")
return nil, nil, errors.New("authToken not found in transport payload")
}

// Perform token verification and authentication...
Expand All @@ -32,7 +32,7 @@ func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (cont
// put it in context
ctxNew := context.WithValue(ctx, "username", userId)

return ctxNew, nil
return nil, ctxNew, nil
}

const defaultPort = "8080"
Expand Down Expand Up @@ -62,7 +62,7 @@ func main() {
return true
},
},
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return webSocketInit(ctx, initPayload)
},
})
Expand Down
16 changes: 13 additions & 3 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type (
initPayload InitPayload
}

WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (*InitPayload, context.Context, error)
WebsocketErrorFunc func(ctx context.Context, err error)

// Callback called when websocket is closed.
Expand Down Expand Up @@ -179,8 +179,10 @@ func (c *wsConnection) init() bool {
}
}

var initAckPayload *InitPayload = nil
if c.InitFunc != nil {
ctx, err := c.InitFunc(c.ctx, c.initPayload)
var ctx context.Context
initAckPayload, ctx, err = c.InitFunc(c.ctx, c.initPayload)
if err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
Expand All @@ -189,7 +191,15 @@ func (c *wsConnection) init() bool {
c.ctx = ctx
}

c.write(&message{t: connectionAckMessageType})
if initAckPayload != nil {
initJsonAckPayload, err := json.Marshal(*initAckPayload)
if err != nil {
panic(err)
}
c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload})
} else {
c.write(&message{t: connectionAckMessageType})
}
c.write(&message{t: keepAliveMessageType})
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
Expand Down
54 changes: 39 additions & 15 deletions graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ func TestWebsocketInitFunc(t *testing.T) {
t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return nil, context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
},
})
srv := httptest.NewServer(h)
Expand All @@ -226,8 +226,8 @@ func TestWebsocketInitFunc(t *testing.T) {
t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return ctx, errors.New("invalid init payload")
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return nil, ctx, errors.New("invalid init payload")
},
})
srv := httptest.NewServer(h)
Expand Down Expand Up @@ -261,8 +261,8 @@ func TestWebsocketInitFunc(t *testing.T) {
h := handler.New(es)

h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return nil, context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
},
})

Expand All @@ -282,7 +282,7 @@ func TestWebsocketInitFunc(t *testing.T) {
h := testserver.New()
var cancel func()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (_ *transport.InitPayload, newCtx context.Context, _ error) {
newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
return
},
Expand All @@ -303,6 +303,30 @@ func TestWebsocketInitFunc(t *testing.T) {
assert.Equal(t, m.Type, connectionErrorMsg)
assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
})
t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
initResponsePayload := transport.InitPayload{"trackingId": "123-456"}
return &initResponsePayload, context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
},
})
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))

connAck := readOp(c)
assert.Equal(t, connectionAckMsg, connAck.Type)

var payload map[string]interface{}
json.Unmarshal(connAck.Payload, &payload)

Check failure on line 326 in graphql/handler/transport/websocket_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint (1.19)

Error return value of `json.Unmarshal` is not checked (errcheck)

Check failure on line 326 in graphql/handler/transport/websocket_test.go

View workflow job for this annotation

GitHub Actions / golangci-lint (1.19)

Error return value of `json.Unmarshal` is not checked (errcheck)
assert.EqualValues(t, "123-456", payload["trackingId"])
assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
})
}

func TestWebSocketInitTimeout(t *testing.T) {
Expand Down Expand Up @@ -382,8 +406,8 @@ func TestWebSocketErrorFunc(t *testing.T) {
t.Run("init func errors do not call the error handler", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("this is not what we agreed upon")
InitFunc: func(ctx context.Context, _ transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return nil, ctx, errors.New("this is not what we agreed upon")
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand All @@ -400,10 +424,10 @@ func TestWebSocketErrorFunc(t *testing.T) {
t.Run("init func context closes do not call the error handler", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (*transport.InitPayload, context.Context, error) {
newCtx, cancel := context.WithCancel(ctx)
time.AfterFunc(time.Millisecond*5, cancel)
return newCtx, nil
return nil, newCtx, nil
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand All @@ -423,9 +447,9 @@ func TestWebSocketErrorFunc(t *testing.T) {
h := testserver.New()
var cancel func()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (_ *transport.InitPayload, newCtx context.Context, _ error) {
newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
return newCtx, nil
return nil, newCtx, nil
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand Down Expand Up @@ -477,8 +501,8 @@ func TestWebSocketCloseFunc(t *testing.T) {
h := testserver.New()
closeFuncCalled := make(chan bool, 1)
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("error during init")
InitFunc: func(ctx context.Context, _ transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return nil, ctx, errors.New("error during init")
},
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
Expand Down

0 comments on commit 12afb29

Please sign in to comment.