Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow WebsocketInitFunc to add payload to Ack #4

Merged
merged 3 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions _examples/websocket-initfunc/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,12 @@ import (
"github.com/rs/cors"
)

func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
// Get the token from payload
payload := initPayload["authToken"]
token, ok := payload.(string)
if !ok || token == "" {
return nil, errors.New("authToken not found in transport payload")
return nil, nil, errors.New("authToken not found in transport payload")
}

// Perform token verification and authentication...
Expand All @@ -32,7 +32,7 @@ func webSocketInit(ctx context.Context, initPayload transport.InitPayload) (cont
// put it in context
ctxNew := context.WithValue(ctx, "username", userId)

return ctxNew, nil
return ctxNew, nil, nil
}

const defaultPort = "8080"
Expand Down Expand Up @@ -62,7 +62,7 @@ func main() {
return true
},
},
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (*transport.InitPayload, context.Context, error) {
return webSocketInit(ctx, initPayload)
},
})
Expand Down
16 changes: 13 additions & 3 deletions graphql/handler/transport/websocket.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ type (
initPayload InitPayload
}

WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, error)
WebsocketInitFunc func(ctx context.Context, initPayload InitPayload) (context.Context, *InitPayload, error)
WebsocketErrorFunc func(ctx context.Context, err error)

// Callback called when websocket is closed.
Expand Down Expand Up @@ -179,8 +179,10 @@ func (c *wsConnection) init() bool {
}
}

var initAckPayload *InitPayload = nil
telemenar marked this conversation as resolved.
Show resolved Hide resolved
if c.InitFunc != nil {
ctx, err := c.InitFunc(c.ctx, c.initPayload)
var ctx context.Context
ctx, initAckPayload, err = c.InitFunc(c.ctx, c.initPayload)
if err != nil {
c.sendConnectionError(err.Error())
c.close(websocket.CloseNormalClosure, "terminated")
Expand All @@ -189,7 +191,15 @@ func (c *wsConnection) init() bool {
c.ctx = ctx
}

c.write(&message{t: connectionAckMessageType})
if initAckPayload != nil {
initJsonAckPayload, err := json.Marshal(*initAckPayload)
if err != nil {
panic(err)
telemenar marked this conversation as resolved.
Show resolved Hide resolved
}
c.write(&message{t: connectionAckMessageType, payload: initJsonAckPayload})
} else {
c.write(&message{t: connectionAckMessageType})
}
c.write(&message{t: keepAliveMessageType})
case connectionCloseMessageType:
c.close(websocket.CloseNormalClosure, "terminated")
Expand Down
57 changes: 42 additions & 15 deletions graphql/handler/transport/websocket_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,8 @@ func TestWebsocketInitFunc(t *testing.T) {
t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
},
})
srv := httptest.NewServer(h)
Expand All @@ -226,8 +226,8 @@ func TestWebsocketInitFunc(t *testing.T) {
t.Run("reject connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return ctx, errors.New("invalid init payload")
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return ctx, nil, errors.New("invalid init payload")
},
})
srv := httptest.NewServer(h)
Expand Down Expand Up @@ -261,8 +261,8 @@ func TestWebsocketInitFunc(t *testing.T) {
h := handler.New(es)

h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return context.WithValue(ctx, ckey("newkey"), "newvalue"), nil, nil
},
})

Expand All @@ -282,7 +282,7 @@ func TestWebsocketInitFunc(t *testing.T) {
h := testserver.New()
var cancel func()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
newCtx, cancel = context.WithTimeout(transport.AppendCloseReason(ctx, "beep boop"), time.Millisecond*5)
return
},
Expand All @@ -303,6 +303,33 @@ func TestWebsocketInitFunc(t *testing.T) {
assert.Equal(t, m.Type, connectionErrorMsg)
assert.Equal(t, string(m.Payload), `{"message":"beep boop"}`)
})
t.Run("accept connection if WebsocketInitFunc is provided and is accepting connection", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
initResponsePayload := transport.InitPayload{"trackingId": "123-456"}
return context.WithValue(ctx, ckey("newkey"), "newvalue"), &initResponsePayload, nil
},
})
srv := httptest.NewServer(h)
defer srv.Close()

c := wsConnect(srv.URL)
defer c.Close()

require.NoError(t, c.WriteJSON(&operationMessage{Type: connectionInitMsg}))

connAck := readOp(c)
assert.Equal(t, connectionAckMsg, connAck.Type)

var payload map[string]interface{}
err := json.Unmarshal(connAck.Payload, &payload)
if err != nil {
t.Fatal("Unexpected Error", err)
}
assert.EqualValues(t, "123-456", payload["trackingId"])
assert.Equal(t, connectionKeepAliveMsg, readOp(c).Type)
})
}

func TestWebSocketInitTimeout(t *testing.T) {
Expand Down Expand Up @@ -382,8 +409,8 @@ func TestWebSocketErrorFunc(t *testing.T) {
t.Run("init func errors do not call the error handler", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("this is not what we agreed upon")
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return ctx, nil, errors.New("this is not what we agreed upon")
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand All @@ -400,10 +427,10 @@ func TestWebSocketErrorFunc(t *testing.T) {
t.Run("init func context closes do not call the error handler", func(t *testing.T) {
h := testserver.New()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
newCtx, cancel := context.WithCancel(ctx)
time.AfterFunc(time.Millisecond*5, cancel)
return newCtx, nil
return newCtx, nil, nil
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand All @@ -423,9 +450,9 @@ func TestWebSocketErrorFunc(t *testing.T) {
h := testserver.New()
var cancel func()
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ error) {
InitFunc: func(ctx context.Context, _ transport.InitPayload) (newCtx context.Context, _ *transport.InitPayload, _ error) {
newCtx, cancel = context.WithDeadline(ctx, time.Now().Add(time.Millisecond*5))
return newCtx, nil
return newCtx, nil, nil
},
ErrorFunc: func(_ context.Context, err error) {
assert.Fail(t, "the error handler got called when it shouldn't have", "error: "+err.Error())
Expand Down Expand Up @@ -477,8 +504,8 @@ func TestWebSocketCloseFunc(t *testing.T) {
h := testserver.New()
closeFuncCalled := make(chan bool, 1)
h.AddTransport(transport.Websocket{
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, error) {
return ctx, errors.New("error during init")
InitFunc: func(ctx context.Context, _ transport.InitPayload) (context.Context, *transport.InitPayload, error) {
return ctx, nil, errors.New("error during init")
},
CloseFunc: func(_ context.Context, _closeCode int) {
closeFuncCalled <- true
Expand Down
Loading