Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 26 additions & 19 deletions pkg/autoscaler/statserver/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ package statserver
import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"net"
"net/http"
Expand Down Expand Up @@ -156,27 +155,40 @@ func (s *Server) Handler(w http.ResponseWriter, r *http.Request) {
return
}

// we accept either GOB-encoded or JSON-encoded messages depending on the
// We accept either protobuf-encoded or JSON-encoded messages depending on the
// message type to ensure safe upgrades.
var dec decoder
switch messageType {
case websocket.BinaryMessage:
dec = gob.NewDecoder(bytes.NewBuffer(msg))
var wsms metrics.WireStatMessages
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

kind of a shame we can't reuse this variable between iterations, but I don't know if it even matters since it's on the stack anyway. probably not 🤷.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think unmarshal call escapes it anyway. But there is Reset standard proto method, so it can be reused, but I am not sure if is helpful from performance standpoint.

if err := wsms.Unmarshal(msg); err != nil {
s.logger.Errorw("Failed to unmarshal the object", zap.Error(err))
continue
}

for _, wsm := range wsms.Messages {
if wsm.Stat == nil {
// To allow for future protobuf schema changes.
continue
}

sm := wsm.ToStatMessage()
s.logger.Debugf("Received stat message: %+v", sm)
s.statsCh <- sm
}
case websocket.TextMessage:
dec = json.NewDecoder(bytes.NewBuffer(msg))
dec := json.NewDecoder(bytes.NewBuffer(msg))
var sm metrics.StatMessage
if err = dec.Decode(&sm); err != nil {
s.logger.Errorw("Failed to decode json", zap.Error(err))
continue
}

s.logger.Debugf("Received stat message: %+v", sm)
s.statsCh <- sm
default:
s.logger.Error("Dropping unknown message type.")
continue
}

var sm metrics.StatMessage
if err = dec.Decode(&sm); err != nil {
s.logger.Error(err)
continue
}

s.logger.Debugf("Received stat message: %+v", sm)
s.statsCh <- sm
}
}

Expand Down Expand Up @@ -211,8 +223,3 @@ func (s *Server) Shutdown(timeout time.Duration) {
s.logger.Warn("Shutdown timed out")
}
}

// decoder is the interface implemented by json.Decoder and gob.Decoder
type decoder interface {
Decode(interface{}) error
}
107 changes: 69 additions & 38 deletions pkg/autoscaler/statserver/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package statserver

import (
"bytes"
"encoding/gob"
"encoding/json"
"fmt"
"net"
Expand Down Expand Up @@ -55,6 +54,7 @@ var (
RequestCount: 30,
},
}
both = []metrics.StatMessage{msg1, msg2}
)

func TestServerLifecycle(t *testing.T) {
Expand Down Expand Up @@ -103,13 +103,12 @@ func TestStatsReceived(t *testing.T) {

statSink := dialOK(t, server.listenAddr())

// gob encoding
assertReceivedOK(t, msg1, statSink, statsCh, false)
assertReceivedOK(t, msg2, statSink, statsCh, false)
// protobuf
assertReceivedProto(t, both, statSink, statsCh)

// json encoding
assertReceivedOK(t, msg1, statSink, statsCh, true)
assertReceivedOK(t, msg2, statSink, statsCh, true)
assertReceivedJSON(t, msg1, statSink, statsCh)
assertReceivedJSON(t, msg2, statSink, statsCh)

closeSink(t, statSink)
}
Expand All @@ -123,14 +122,14 @@ func TestServerShutdown(t *testing.T) {
listenAddr := server.listenAddr()
statSink := dialOK(t, listenAddr)

assertReceivedOK(t, msg1, statSink, statsCh, false)
assertReceivedProto(t, both, statSink, statsCh)

server.Shutdown(time.Second)
// We own the channel.
close(statsCh)

// Send a statistic to the server
if err := send(statSink, msg2, false); err != nil {
if err := sendProto(statSink, both); err != nil {
t.Fatal("Expected send to succeed, got:", err)
}

Expand Down Expand Up @@ -173,7 +172,7 @@ func TestServerDoesNotLeakGoroutines(t *testing.T) {
listenAddr := server.listenAddr()
statSink := dialOK(t, listenAddr)

assertReceivedOK(t, msg1, statSink, statsCh, false)
assertReceivedProto(t, both, statSink, statsCh)

closeSink(t, statSink)

Expand Down Expand Up @@ -209,26 +208,38 @@ func BenchmarkStatServer(b *testing.B) {
msgs = append(msgs, msg1)
}

for encoding, jsonEncoding := range map[string]bool{"json": true, "gob": false} {
b.Run(fmt.Sprintf("%s-encoding-%d-msgs", encoding, len(msgs)), func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, msg := range msgs {
if err := send(statSink, msg, jsonEncoding); err != nil {
b.Fatal("Expected send to succeed, but got:", err)
}
b.Run(fmt.Sprintf("json-encoding-%d-msgs", len(msgs)), func(b *testing.B) {
for i := 0; i < b.N; i++ {
for _, msg := range msgs {
if err := sendJSON(statSink, msg); err != nil {
b.Fatal("Expected send to succeed, but got:", err)
}
}

for range msgs {
<-statsCh
}
for range msgs {
<-statsCh
}
})
}
}
})

b.Run(fmt.Sprintf("proto-encoding-%d-msgs", len(msgs)), func(b *testing.B) {
for i := 0; i < b.N; i++ {
if err := sendProto(statSink, msgs); err != nil {
b.Fatal("Expected send to succeed, but got:", err)
}

for range msgs {
<-statsCh
}
}
})
}
}

func assertReceivedOK(t *testing.T, sm metrics.StatMessage, statSink *websocket.Conn, statsCh <-chan metrics.StatMessage, jsonEncoding bool) {
if err := send(statSink, sm, jsonEncoding); err != nil {
func assertReceivedJSON(t *testing.T, sm metrics.StatMessage, statSink *websocket.Conn, statsCh <-chan metrics.StatMessage) {
t.Helper()

if err := sendJSON(statSink, sm); err != nil {
t.Fatal("Expected send to succeed, got:", err)
}

Expand All @@ -238,7 +249,25 @@ func assertReceivedOK(t *testing.T, sm metrics.StatMessage, statSink *websocket.
}
}

func assertReceivedProto(t *testing.T, sms []metrics.StatMessage, statSink *websocket.Conn, statsCh <-chan metrics.StatMessage) {
t.Helper()

if err := sendProto(statSink, sms); err != nil {
t.Fatal("Expected send to succeed, got:", err)
}

got := make([]metrics.StatMessage, 0, len(sms))
for range sms {
got = append(got, <-statsCh)
}
if !cmp.Equal(sms, got) {
t.Fatalf("StatMessage mismatch: diff (-got, +want) %s", cmp.Diff(got, sms))
}
}

func dialOK(t *testing.T, serverURL string) *websocket.Conn {
t.Helper()

statSink, err := dial(serverURL)
if err != nil {
t.Fatal("Dial failed:", err)
Expand All @@ -260,29 +289,36 @@ func dial(serverURL string) (*websocket.Conn, error) {
return statSink, err
}

func send(statSink *websocket.Conn, sm metrics.StatMessage, jsonEncoding bool) error {
func sendJSON(statSink *websocket.Conn, sm metrics.StatMessage) error {
var b bytes.Buffer
enc := json.NewEncoder(&b)
if err := enc.Encode(sm); err != nil {
return fmt.Errorf("failed to encode StatMessage: %w", err)
}

var enc encoder = gob.NewEncoder(&b)
messageType := websocket.BinaryMessage

if jsonEncoding {
enc = json.NewEncoder(&b)
messageType = websocket.TextMessage
if err := statSink.WriteMessage(websocket.TextMessage, b.Bytes()); err != nil {
return fmt.Errorf("failed to write to stat sink: %w", err)
}
return nil
}

if err := enc.Encode(sm); err != nil {
return fmt.Errorf("failed to encode data from stats channel: %w", err)
func sendProto(statSink *websocket.Conn, sms []metrics.StatMessage) error {
wsms := metrics.ToWireStatMessages(sms)
msg, err := wsms.Marshal()
if err != nil {
return fmt.Errorf("failed to marshal StatMessage: %w", err)
}

if err := statSink.WriteMessage(messageType, b.Bytes()); err != nil {
if err := statSink.WriteMessage(websocket.BinaryMessage, msg); err != nil {
return fmt.Errorf("failed to write to stat sink: %w", err)
}

return nil
}

func closeSink(t *testing.T, statSink *websocket.Conn) {
t.Helper()

if err := statSink.Close(); err != nil {
t.Fatal("Failed to close", err)
}
Expand Down Expand Up @@ -324,8 +360,3 @@ func (t *testListener) Accept() (net.Conn, error) {
t.listenAddr <- "http://" + t.Listener.Addr().String()
return t.Listener.Accept()
}

// encoder is the interface implemented by gob.Encoder and json.Encoder
type encoder interface {
Encode(interface{}) error
}