diff --git a/internal/events/websockets/websocket_connection.go b/internal/events/websockets/websocket_connection.go index 9f69a13f41..79dc29687b 100644 --- a/internal/events/websockets/websocket_connection.go +++ b/internal/events/websockets/websocket_connection.go @@ -223,6 +223,9 @@ func (wc *websocketConnection) protocolError(err error) { } func (wc *websocketConnection) send(msg interface{}) error { + if wc.closed { + return i18n.NewError(wc.ctx, i18n.MsgWSClosed) + } select { case wc.sendMessages <- msg: return nil diff --git a/internal/events/websockets/websockets_test.go b/internal/events/websockets/websockets_test.go index 0ddbd6dd13..59eb97f9a0 100644 --- a/internal/events/websockets/websockets_test.go +++ b/internal/events/websockets/websockets_test.go @@ -606,3 +606,27 @@ func TestDispatchAutoAck(t *testing.T) { assert.NoError(t, err) cbs.AssertExpectations(t) } + +func TestWebsocketSendAfterClose(t *testing.T) { + cbs := &eventsmocks.Callbacks{} + ws, wsc, cancel := newTestWebsockets(t, cbs) + defer cancel() + + subscribedConn := make(chan string, 1) + cbs.On("EphemeralSubscription", + mock.MatchedBy(func(s string) bool { + subscribedConn <- s + return true + }), + "ns1", mock.Anything, mock.Anything).Return(nil) + + err := wsc.Send(context.Background(), []byte(`{"type":"start","namespace":"ns1","ephemeral":true}`)) + assert.NoError(t, err) + + connID := <-subscribedConn + connection := ws.connections[connID] + connection.wsConn.Close() + <-connection.senderDone + err = connection.send(map[string]string{"foo": "bar"}) + assert.Regexp(t, "FF10290", err) +} diff --git a/internal/i18n/en_translations.go b/internal/i18n/en_translations.go index da8818bb20..faae6f8f9f 100644 --- a/internal/i18n/en_translations.go +++ b/internal/i18n/en_translations.go @@ -207,4 +207,5 @@ var ( MsgInvalidMessageType = ffm("FF10287", "Invalid message type - allowed types are %s", 400) MsgNoUUID = ffm("FF10288", "Field '%s' must not be a UUID", 400) MsgFetchDataDesc = ffm("FF10289", "Fetch the data and include it in the messages returned", 400) + MsgWSClosed = ffm("FF10290", "Websocket closed") )