Skip to content

Commit

Permalink
termite: introduce connMuxer to prepare for SSH connections.
Browse files Browse the repository at this point in the history
  • Loading branch information
hanwen committed Aug 25, 2014
1 parent a6bfd2b commit c6335be
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 23 deletions.
7 changes: 5 additions & 2 deletions termite/connection_test.go
Expand Up @@ -34,14 +34,17 @@ func TestAuthenticate(t *testing.T) {
addr := fmt.Sprintf("%s:%d", hostname, l.Addr().(*net.TCPAddr).Port)
dialer := newTCPDialer(secret)

c, _ := dialer.Open(addr, RPC_CHANNEL)
m, _ := dialer.Dial(addr)

c, _ := m.Open(RPC_CHANNEL)
c.Close()
if <-ch == nil {
t.Fatal("unexpected failure")
}

dialer = newTCPDialer([]byte("foobar"))
c, _ = dialer.Open(addr, RPC_CHANNEL)
m, _ = dialer.Dial(addr)
c, _ = m.Open(RPC_CHANNEL)
if c != nil {
c.Close()
}
Expand Down
28 changes: 18 additions & 10 deletions termite/coordinator.go
@@ -1,8 +1,8 @@
package termite

import (
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -91,15 +91,14 @@ func NewCoordinator(opts *CoordinatorOptions) *Coordinator {
}

func (c *Coordinator) Register(req *RegistrationRequest, rep *Empty) error {
conn, err := c.dialer.Open(req.Address, RPC_CHANNEL)
if conn != nil {
conn.Close()
}
rwc, err := c.dialWorker(req.Address)
if err != nil {
return errors.New(fmt.Sprintf(
"error contacting address: %v", err))
return fmt.Errorf(
"Dial(%v).Open(%q): %v", req.Address, RPC_CHANNEL, err)
}

rwc.Close()

c.mutex.Lock()
defer c.mutex.Unlock()

Expand Down Expand Up @@ -137,14 +136,23 @@ func (c *Coordinator) List(req *ListRequest, rep *ListResponse) error {
return nil
}

func (c *Coordinator) dialWorker(address string) (io.ReadWriteCloser, error) {
mux, err := c.dialer.Dial(address)
if err != nil {
return nil, err
}

return mux.Open(RPC_CHANNEL)
}

func (c *Coordinator) checkReachable() {
now := time.Now()

addrs := c.workerAddresses()

var toDelete []string
for _, a := range addrs {
conn, err := c.dialer.Open(a, RPC_CHANNEL)
conn, err := c.dialWorker(a)
if err != nil {
toDelete = append(toDelete, a)
} else {
Expand Down Expand Up @@ -178,7 +186,7 @@ func (c *Coordinator) PeriodicCheck() {
}

func (c *Coordinator) killWorker(addr string, restart bool) error {
conn, err := c.dialer.Open(addr, RPC_CHANNEL)
conn, err := c.dialWorker(addr)
if err == nil {
killReq := ShutdownRequest{Restart: restart}
rep := ShutdownResponse{}
Expand All @@ -190,7 +198,7 @@ func (c *Coordinator) killWorker(addr string, restart bool) error {
}

func (c *Coordinator) shutdownWorker(addr string, restart bool) error {
conn, err := c.dialer.Open(addr, RPC_CHANNEL)
conn, err := c.dialWorker(addr)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion termite/coordinatorhttp.go
Expand Up @@ -154,7 +154,7 @@ func (c *Coordinator) killHandler(w http.ResponseWriter, req *http.Request) {
addr, err := c.getHost(req)
var conn io.ReadWriteCloser
if err == nil {
conn, err = c.dialer.Open(addr, RPC_CHANNEL)
conn, err = c.dialWorker(addr)
}
if err != nil {
w.WriteHeader(http.StatusNotFound)
Expand Down
23 changes: 17 additions & 6 deletions termite/master.go
Expand Up @@ -285,35 +285,41 @@ func (m *Master) createMirror(addr string, jobs int) (*mirrorConnection, error)
c.Close()
}
}()
conn, err := m.dialer.Open(addr, RPC_CHANNEL)
mux, err := m.dialer.Dial(addr)
if err != nil {
return nil, err
}

conn, err := mux.Open(RPC_CHANNEL)
if err != nil {
return nil, err
}

defer conn.Close()

rpcId := ConnectionId()
rpcConn, err := m.dialer.Open(addr, rpcId)
rpcConn, err := mux.Open(rpcId)
if err != nil {
return nil, err
}
closeMe = append(closeMe, rpcConn)

revId := ConnectionId()
revConn, err := m.dialer.Open(addr, revId)
revConn, err := mux.Open(revId)
if err != nil {
return nil, err
}
closeMe = append(closeMe, revConn)

contentId := ConnectionId()
contentConn, err := m.dialer.Open(addr, contentId)
contentConn, err := mux.Open(contentId)
if err != nil {
return nil, err
}
closeMe = append(closeMe, contentConn)

revContentId := ConnectionId()
revContentConn, err := m.dialer.Open(addr, revContentId)
revContentConn, err := mux.Open(revContentId)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -370,7 +376,12 @@ func (m *Master) runOnMirror(mirror *mirrorConnection, req *WorkRequest, rep *Wo

// Tunnel stdin.
if req.StdinConn != nil {
destInputConn, err := m.dialer.Open(mirror.workerAddr, req.StdinId)
mux, err := m.dialer.Dial(mirror.workerAddr)
if err != nil {
return err
}

destInputConn, err := mux.Open(req.StdinId)
if err != nil {
return err
}
Expand Down
15 changes: 12 additions & 3 deletions termite/tcp-connection.go
Expand Up @@ -16,24 +16,33 @@ type tcpDialer struct {
secret []byte
}

type tcpMux struct {
dial *tcpDialer
addr string
}

// newTCPDialer returns a connDialer that uses plaintext TCP/IP
// connections, and HMAC-SHA1 for authentication. It should not
// be used in hostile environments.
func newTCPDialer(secret []byte) connDialer {
return &tcpDialer{secret}
}

func (c *tcpDialer) Open(addr string, id string) (io.ReadWriteCloser, error) {
func (c *tcpDialer) Dial(addr string) (connMuxer, error) {
return &tcpMux{c, addr}, nil
}

func (m *tcpMux) Open(id string) (io.ReadWriteCloser, error) {
if len(id) != HEADER_LEN {
return nil, fmt.Errorf("len(%q) != %d", id, HEADER_LEN)
}

conn, err := net.Dial("tcp", addr)
conn, err := net.Dial("tcp", m.addr)
if err != nil {
return nil, err
}

if err := authenticate(conn, c.secret); err != nil {
if err := authenticate(conn, m.dial.secret); err != nil {
return nil, err
}

Expand Down
6 changes: 5 additions & 1 deletion termite/workerconn.go
Expand Up @@ -7,7 +7,11 @@ import (

// connDialer dials connections that have IDs beyond address.
type connDialer interface {
Open(addr string, id string) (io.ReadWriteCloser, error)
Dial(addr string) (connMuxer, error)
}

type connMuxer interface {
Open(id string) (io.ReadWriteCloser, error)
}

// connListener accepts connections that have string IDs.
Expand Down

0 comments on commit c6335be

Please sign in to comment.