Skip to content

Commit

Permalink
feat: Add peerbroker proxy for agent connections (#349)
Browse files Browse the repository at this point in the history
* feat: Add peerbroker proxy for agent connections

Agents will connect using this proxy. Eventually we'll intercept
some of these messages for validation, but that's not necessary right now.

* Add ASCII chart
  • Loading branch information
kylecarbs committed Feb 23, 2022
1 parent a053fe8 commit b58e168
Show file tree
Hide file tree
Showing 2 changed files with 341 additions and 0 deletions.
260 changes: 260 additions & 0 deletions peerbroker/proxy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,260 @@
package peerbroker

import (
"context"
"errors"
"fmt"
"io"
"net"
"sync"

"github.com/google/uuid"
"github.com/hashicorp/yamux"
"golang.org/x/xerrors"
protobuf "google.golang.org/protobuf/proto"
"storj.io/drpc/drpcmux"
"storj.io/drpc/drpcserver"

"cdr.dev/slog"
"github.com/coder/coder/database"
"github.com/coder/coder/peerbroker/proto"
)

var (
// Each NegotiateConnection() function call spawns a new stream.
streamIDLength = len(uuid.NewString())
// We shouldn't PubSub anything larger than this!
maxPayloadSizeBytes = 8192
)

// ProxyOptions provides values to configure a proxy.
type ProxyOptions struct {
ChannelID string
Logger slog.Logger
Pubsub database.Pubsub
}

// ProxyDial writes client negotiation streams over PubSub.
//
// PubSub is used to geodistribute WebRTC handshakes. All negotiation
// messages are small in size (<=8KB), and we don't require delivery
// guarantees because connections can always be renegotiated.
// ┌────────────────────┐ ┌─────────────────────────────┐
// │ coderd │ │ coderd │
// ┌─────────────────────┐ │/<agent-id>/connect │ │ /<agent-id>/listen │
// │ client │ │ │ │ │ ┌─────┐
// │ ├──►│Creates a stream ID │◄─►│Subscribe() to the <agent-id>│◄──┤agent│
// │NegotiateConnection()│ │and Publish() to the│ │channel. Parse the stream ID │ └─────┘
// └─────────────────────┘ │<agent-id> channel: │ │from payloads to create new │
// │ │ │NegotiateConnection() streams│
// │<stream-id><payload>│ │or write to existing ones. │
// └────────────────────┘ └─────────────────────────────┘
func ProxyDial(client proto.DRPCPeerBrokerClient, options ProxyOptions) (io.Closer, error) {
proxyDial := &proxyDial{
channelID: options.ChannelID,
logger: options.Logger,
pubsub: options.Pubsub,
connection: client,
streams: make(map[string]proto.DRPCPeerBroker_NegotiateConnectionClient),
}
return proxyDial, proxyDial.listen()
}

// ProxyListen accepts client negotiation streams over PubSub and writes them to the listener
// as new NegotiateConnection() streams.
func ProxyListen(ctx context.Context, connListener net.Listener, options ProxyOptions) error {
mux := drpcmux.New()
err := proto.DRPCRegisterPeerBroker(mux, &proxyListen{
channelID: options.ChannelID,
pubsub: options.Pubsub,
logger: options.Logger,
})
if err != nil {
return xerrors.Errorf("register peer broker: %w", err)
}
server := drpcserver.New(mux)
err = server.Serve(ctx, connListener)
if err != nil {
if errors.Is(err, yamux.ErrSessionShutdown) {
return nil
}
return xerrors.Errorf("serve: %w", err)
}
return nil
}

type proxyListen struct {
channelID string
pubsub database.Pubsub
logger slog.Logger
}

func (p *proxyListen) NegotiateConnection(stream proto.DRPCPeerBroker_NegotiateConnectionStream) error {
streamID := uuid.NewString()
var err error
closeSubscribe, err := p.pubsub.Subscribe(proxyInID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onServerToClientMessage(streamID, stream, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
}
})
if err != nil {
return xerrors.Errorf("subscribe: %w", err)
}
defer closeSubscribe()
for {
clientToServerMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
return xerrors.Errorf("recv: %w", err)
}
data, err := protobuf.Marshal(clientToServerMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
}
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyOutID(p.channelID), data)
if err != nil {
return xerrors.Errorf("publish: %w", err)
}
}
return nil
}

func (*proxyListen) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionStream, message []byte) error {
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
}
serverStreamID := string(message[0:streamIDLength])
if serverStreamID != streamID {
// It's not trying to communicate with this stream!
return nil
}
var msg proto.NegotiateConnection_ServerToClient
err := protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
}
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("send message: %w", err)
}
return nil
}

type proxyDial struct {
channelID string
pubsub database.Pubsub
logger slog.Logger

connection proto.DRPCPeerBrokerClient
closeSubscribe func()
streamMutex sync.Mutex
streams map[string]proto.DRPCPeerBroker_NegotiateConnectionClient
}

func (p *proxyDial) listen() error {
var err error
p.closeSubscribe, err = p.pubsub.Subscribe(proxyOutID(p.channelID), func(ctx context.Context, message []byte) {
err := p.onClientToServerMessage(ctx, message)
if err != nil {
p.logger.Debug(ctx, "failed to accept client message", slog.Error(err))
}
})
if err != nil {
return err
}
return nil
}

func (p *proxyDial) onClientToServerMessage(ctx context.Context, message []byte) error {
if len(message) < streamIDLength {
return xerrors.Errorf("got message length %d < %d", len(message), streamIDLength)
}
var err error
streamID := string(message[0:streamIDLength])
p.streamMutex.Lock()
stream, ok := p.streams[streamID]
if !ok {
stream, err = p.connection.NegotiateConnection(ctx)
if err != nil {
p.streamMutex.Unlock()
return xerrors.Errorf("negotiate connection: %w", err)
}
p.streams[streamID] = stream
go func() {
defer stream.Close()

err = p.onServerToClientMessage(streamID, stream)
if err != nil {
p.logger.Debug(ctx, "failed to accept server message", slog.Error(err))
}
}()
go func() {
<-stream.Context().Done()
p.streamMutex.Lock()
delete(p.streams, streamID)
p.streamMutex.Unlock()
}()
}
p.streamMutex.Unlock()

var msg proto.NegotiateConnection_ClientToServer
err = protobuf.Unmarshal(message[streamIDLength:], &msg)
if err != nil {
return xerrors.Errorf("unmarshal message: %w", err)
}
err = stream.Send(&msg)
if err != nil {
return xerrors.Errorf("write message: %w", err)
}
return nil
}

func (p *proxyDial) onServerToClientMessage(streamID string, stream proto.DRPCPeerBroker_NegotiateConnectionClient) error {
for {
serverToClientMessage, err := stream.Recv()
if err != nil {
if errors.Is(err, io.EOF) {
break
}
if errors.Is(err, context.Canceled) {
break
}
return xerrors.Errorf("recv: %w", err)
}
data, err := protobuf.Marshal(serverToClientMessage)
if err != nil {
return xerrors.Errorf("marshal: %w", err)
}
if len(data) > maxPayloadSizeBytes {
return xerrors.Errorf("maximum payload size %d exceeded", maxPayloadSizeBytes)
}
data = append([]byte(streamID), data...)
err = p.pubsub.Publish(proxyInID(p.channelID), data)
if err != nil {
return xerrors.Errorf("publish: %w", err)
}
}
return nil
}

func (p *proxyDial) Close() error {
p.streamMutex.Lock()
defer p.streamMutex.Unlock()
p.closeSubscribe()
return nil
}

func proxyOutID(channelID string) string {
return fmt.Sprintf("%s-out", channelID)
}

func proxyInID(channelID string) string {
return fmt.Sprintf("%s-in", channelID)
}
81 changes: 81 additions & 0 deletions peerbroker/proxy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
package peerbroker_test

import (
"context"
"sync"
"testing"

"github.com/pion/webrtc/v3"
"github.com/stretchr/testify/require"

"cdr.dev/slog"
"cdr.dev/slog/sloggers/slogtest"
"github.com/coder/coder/database"
"github.com/coder/coder/peer"
"github.com/coder/coder/peerbroker"
"github.com/coder/coder/peerbroker/proto"
"github.com/coder/coder/provisionersdk"
)

func TestProxy(t *testing.T) {
t.Parallel()
ctx := context.Background()
channelID := "hello"
pubsub := database.NewPubsubInMemory()
dialerClient, dialerServer := provisionersdk.TransportPipe()
defer dialerClient.Close()
defer dialerServer.Close()
listenerClient, listenerServer := provisionersdk.TransportPipe()
defer listenerClient.Close()
defer listenerServer.Close()

listener, err := peerbroker.Listen(listenerServer, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("server").Leveled(slog.LevelDebug),
})
require.NoError(t, err)

proxyCloser, err := peerbroker.ProxyDial(proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(listenerClient)), peerbroker.ProxyOptions{
ChannelID: channelID,
Logger: slogtest.Make(t, nil).Named("proxy-listen").Leveled(slog.LevelDebug),
Pubsub: pubsub,
})
require.NoError(t, err)
t.Cleanup(func() {
_ = proxyCloser.Close()
})

var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
err = peerbroker.ProxyListen(ctx, dialerServer, peerbroker.ProxyOptions{
ChannelID: channelID,
Logger: slogtest.Make(t, nil).Named("proxy-dial").Leveled(slog.LevelDebug),
Pubsub: pubsub,
})
require.NoError(t, err)
}()

api := proto.NewDRPCPeerBrokerClient(provisionersdk.Conn(dialerClient))
stream, err := api.NegotiateConnection(ctx)
require.NoError(t, err)
clientConn, err := peerbroker.Dial(stream, []webrtc.ICEServer{{
URLs: []string{"stun:stun.l.google.com:19302"},
}}, &peer.ConnOptions{
Logger: slogtest.Make(t, nil).Named("client").Leveled(slog.LevelDebug),
})
require.NoError(t, err)
defer clientConn.Close()

serverConn, err := listener.Accept()
require.NoError(t, err)
defer serverConn.Close()
_, err = serverConn.Ping()
require.NoError(t, err)

_, err = clientConn.Ping()
require.NoError(t, err)

_ = dialerServer.Close()
wg.Wait()
}

0 comments on commit b58e168

Please sign in to comment.