-
Notifications
You must be signed in to change notification settings - Fork 447
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
6 changed files
with
487 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,176 @@ | ||
package plugin | ||
|
||
import ( | ||
"encoding/binary" | ||
"fmt" | ||
"net" | ||
"sync" | ||
"sync/atomic" | ||
"time" | ||
|
||
"github.com/hashicorp/yamux" | ||
) | ||
|
||
// MuxBroker is responsible for brokering multiplexed connections by unique ID. | ||
// | ||
// It is used by plugins to multi-plex multiple RPC connections and data | ||
// streams on top of a single connection between the plugin process and the | ||
// host process. | ||
// | ||
// This allows a plugin to request a channel with a specific ID to connect to | ||
// or accept a connection from, and the broker handles the details of | ||
// holding these channels open while they're being negotiated. | ||
type MuxBroker struct { | ||
nextId uint32 | ||
session *yamux.Session | ||
streams map[uint32]*muxBrokerPending | ||
|
||
sync.Mutex | ||
} | ||
|
||
type muxBrokerPending struct { | ||
ch chan net.Conn | ||
doneCh chan struct{} | ||
} | ||
|
||
func newMuxBroker(s *yamux.Session) *MuxBroker { | ||
return &MuxBroker{ | ||
session: s, | ||
streams: make(map[uint32]*muxBrokerPending), | ||
} | ||
} | ||
|
||
// Accept accepts a connection by ID. | ||
// | ||
// This should not be called multiple times with the same ID at one time. | ||
func (m *MuxBroker) Accept(id uint32) (net.Conn, error) { | ||
var c net.Conn | ||
p := m.getStream(id) | ||
select { | ||
case c = <-p.ch: | ||
close(p.doneCh) | ||
case <-time.After(5 * time.Second): | ||
m.Lock() | ||
defer m.Unlock() | ||
delete(m.streams, id) | ||
|
||
return nil, fmt.Errorf("timeout waiting for accept") | ||
} | ||
|
||
// Ack our connection | ||
if err := binary.Write(c, binary.LittleEndian, id); err != nil { | ||
c.Close() | ||
return nil, err | ||
} | ||
|
||
return c, nil | ||
} | ||
|
||
// Close closes the connection and all sub-connections. | ||
func (m *MuxBroker) Close() error { | ||
return m.session.Close() | ||
} | ||
|
||
// Dial opens a connection by ID. | ||
func (m *MuxBroker) Dial(id uint32) (net.Conn, error) { | ||
// Open the stream | ||
stream, err := m.session.OpenStream() | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
// Write the stream ID onto the wire. | ||
if err := binary.Write(stream, binary.LittleEndian, id); err != nil { | ||
stream.Close() | ||
return nil, err | ||
} | ||
|
||
// Read the ack that we connected. Then we're off! | ||
var ack uint32 | ||
if err := binary.Read(stream, binary.LittleEndian, &ack); err != nil { | ||
stream.Close() | ||
return nil, err | ||
} | ||
if ack != id { | ||
stream.Close() | ||
return nil, fmt.Errorf("bad ack: %d (expected %d)", ack, id) | ||
} | ||
|
||
return stream, nil | ||
} | ||
|
||
// NextId returns a unique ID to use next. | ||
func (m *MuxBroker) NextId() uint32 { | ||
return atomic.AddUint32(&m.nextId, 1) | ||
} | ||
|
||
// Run starts the brokering and should be executed in a goroutine, since it | ||
// blocks forever, or until the session closes. | ||
func (m *MuxBroker) Run() { | ||
for { | ||
stream, err := m.session.AcceptStream() | ||
if err != nil { | ||
// Once we receive an error, just exit | ||
break | ||
} | ||
|
||
// Read the stream ID from the stream | ||
var id uint32 | ||
if err := binary.Read(stream, binary.LittleEndian, &id); err != nil { | ||
stream.Close() | ||
continue | ||
} | ||
|
||
// Initialize the waiter | ||
p := m.getStream(id) | ||
select { | ||
case p.ch <- stream: | ||
default: | ||
} | ||
|
||
// Wait for a timeout | ||
go m.timeoutWait(id, p) | ||
} | ||
} | ||
|
||
func (m *MuxBroker) getStream(id uint32) *muxBrokerPending { | ||
m.Lock() | ||
defer m.Unlock() | ||
|
||
p, ok := m.streams[id] | ||
if ok { | ||
return p | ||
} | ||
|
||
m.streams[id] = &muxBrokerPending{ | ||
ch: make(chan net.Conn, 1), | ||
doneCh: make(chan struct{}), | ||
} | ||
return m.streams[id] | ||
} | ||
|
||
func (m *MuxBroker) timeoutWait(id uint32, p *muxBrokerPending) { | ||
// Wait for the stream to either be picked up and connected, or | ||
// for a timeout. | ||
timeout := false | ||
select { | ||
case <-p.doneCh: | ||
case <-time.After(5 * time.Second): | ||
timeout = true | ||
} | ||
|
||
m.Lock() | ||
defer m.Unlock() | ||
|
||
// Delete the stream so no one else can grab it | ||
delete(m.streams, id) | ||
|
||
// If we timed out, then check if we have a channel in the buffer, | ||
// and if so, close it. | ||
if timeout { | ||
select { | ||
case s := <-p.ch: | ||
s.Close() | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
package plugin | ||
|
||
import ( | ||
"fmt" | ||
"io" | ||
"net" | ||
"net/rpc" | ||
|
||
"github.com/hashicorp/yamux" | ||
) | ||
|
||
// RPCClient connects to an RPCServer over net/rpc to dispense plugin types. | ||
type RPCClient struct { | ||
broker *MuxBroker | ||
control *rpc.Client | ||
plugins map[string]Plugin | ||
|
||
// These are the streams used for the various stdout/err overrides | ||
stdout, stderr net.Conn | ||
} | ||
|
||
// Dial opens a connection to an RPC server and returns a client. | ||
func Dial(network, address string, plugins map[string]Plugin) (*RPCClient, error) { | ||
conn, err := net.Dial(network, address) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if tcpConn, ok := conn.(*net.TCPConn); ok { | ||
// Make sure to set keep alive so that the connection doesn't die | ||
tcpConn.SetKeepAlive(true) | ||
} | ||
|
||
return NewRPCClient(conn, plugins) | ||
} | ||
|
||
// NewRPCClient creates a client from an already-open connection-like value. | ||
// Dial is typically used instead. | ||
func NewRPCClient(conn io.ReadWriteCloser, plugins map[string]Plugin) (*RPCClient, error) { | ||
// Create the yamux client so we can multiplex | ||
mux, err := yamux.Client(conn, nil) | ||
if err != nil { | ||
conn.Close() | ||
return nil, err | ||
} | ||
|
||
// Connect to the control stream. | ||
control, err := mux.Open() | ||
if err != nil { | ||
mux.Close() | ||
return nil, err | ||
} | ||
|
||
// Connect stdout, stderr streams | ||
stdstream := make([]net.Conn, 2) | ||
for i, _ := range stdstream { | ||
stdstream[i], err = mux.Open() | ||
if err != nil { | ||
mux.Close() | ||
return nil, err | ||
} | ||
} | ||
|
||
// Create the broker and start it up | ||
broker := newMuxBroker(mux) | ||
go broker.Run() | ||
|
||
// Build the client using our broker and control channel. | ||
return &RPCClient{ | ||
broker: broker, | ||
control: rpc.NewClient(control), | ||
plugins: plugins, | ||
stdout: stdstream[0], | ||
stderr: stdstream[1], | ||
}, nil | ||
} | ||
|
||
// SyncStreams should be called to enable syncing of stdout, | ||
// stderr with the plugin. | ||
// | ||
// This will return immediately and the syncing will continue to happen | ||
// in the background. You do not need to launch this in a goroutine itself. | ||
// | ||
// This should never be called multiple times. | ||
func (c *RPCClient) SyncStreams(stdout io.Writer, stderr io.Writer) error { | ||
go copyStream("stdout", stdout, c.stdout) | ||
go copyStream("stderr", stderr, c.stderr) | ||
return nil | ||
} | ||
|
||
// Close closes the connection. The client is no longer usable after this | ||
// is called. | ||
func (c *RPCClient) Close() error { | ||
if err := c.control.Close(); err != nil { | ||
return err | ||
} | ||
if err := c.stdout.Close(); err != nil { | ||
return err | ||
} | ||
if err := c.stderr.Close(); err != nil { | ||
return err | ||
} | ||
|
||
return c.broker.Close() | ||
} | ||
|
||
func (c *RPCClient) Dispense(name string) (interface{}, error) { | ||
p, ok := c.plugins[name] | ||
if !ok { | ||
return nil, fmt.Errorf("unknown plugin type: %s", name) | ||
} | ||
|
||
var id uint32 | ||
if err := c.control.Call( | ||
"Dispenser.Dispense", name, &id); err != nil { | ||
return nil, err | ||
} | ||
|
||
conn, err := c.broker.Dial(id) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
return p.Client(c.broker, rpc.NewClient(conn)) | ||
} |
Oops, something went wrong.