Skip to content

Commit

Permalink
termite: add SSH based connections.
Browse files Browse the repository at this point in the history
  • Loading branch information
hanwen committed Aug 26, 2014
1 parent c6335be commit 01d2dec
Show file tree
Hide file tree
Showing 5 changed files with 306 additions and 60 deletions.
2 changes: 0 additions & 2 deletions termite/connection.go
Expand Up @@ -15,8 +15,6 @@ import (

const _SOCKET = ".termite-socket"

const challengeLength = 20

var Hostname string

func init() {
Expand Down
78 changes: 77 additions & 1 deletion termite/connection_test.go
@@ -1,12 +1,15 @@
package termite

import (
"crypto/rand"
"crypto/rsa"
"fmt"
"io"
// "log"
"net"
"os"
"testing"

"code.google.com/p/go.crypto/ssh"
)

func TestAuthenticate(t *testing.T) {
Expand Down Expand Up @@ -53,3 +56,76 @@ func TestAuthenticate(t *testing.T) {
}
l.Close()
}

func testDialerMux(t *testing.T, dialer connDialer, listener connListener) {
found := make(chan bool, 10)
go func() {
for c := range listener.RPCChan() {
go func() {
var b [HEADER_LEN]byte
n, _ := c.Read(b[:])
conn := listener.Accept(string(b[:n]))
found <- conn != nil
}()
}
}()

mux, err := dialer.Dial(listener.Addr().String())
if err != nil {
t.Fatalf("Dial: %v", err)
}

id := ConnectionId()
if ch, err := mux.Open(RPC_CHANNEL); err != nil {
t.Fatalf("Open(%q): %v", RPC_CHANNEL, err)
} else {
ch.Write([]byte(id))
}

if ch, err := mux.Open(id); err != nil {
t.Fatalf("Open(%q): %v", id, err)
} else {
ch.Close()
}

mux.Close()

if !<-found {
t.Fatal("Did not accept requested channel.")
}
}

func TestTCPMux(t *testing.T) {
secret := make([]byte, 20)
dialer := newTCPDialer(secret)
l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal("net.Listen", err)
}

listener := newTCPListener(l, secret)

testDialerMux(t, dialer, listener)
}

func TestSSHMux(t *testing.T) {
key, err := rsa.GenerateKey(rand.Reader, 512)
if err != nil {
t.Fatal("GenerateKey", err)
}

l, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal("net.Listen", err)
}

id, err := ssh.NewSignerFromKey(key)
if err != nil {
t.Fatal("NewSignerFromKey(%T)", key, err)
}

listener := newSSHListener(l, id)
dialer := newSSHDialer(id)

testDialerMux(t, dialer, listener)
}
164 changes: 164 additions & 0 deletions termite/ssh-conn.go
@@ -0,0 +1,164 @@
package termite

import (
"bytes"
"fmt"
"io"
"net"

"code.google.com/p/go.crypto/ssh"
)

type sshDialer struct {
identity ssh.Signer
}

func newSSHDialer(id ssh.Signer) connDialer {
return &sshDialer{id}
}

func (d *sshDialer) checkHost(hostname string, remote net.Addr, key ssh.PublicKey) error {
if bytes.Equal(key.Marshal(), d.identity.PublicKey().Marshal()) {
return nil
}
return fmt.Errorf("key mismatch")
}

func (d *sshDialer) Dial(addr string) (connMuxer, error) {
conf := ssh.ClientConfig{
User: "termite",
Auth: []ssh.AuthMethod{ssh.PublicKeys(d.identity)},
HostKeyCallback: d.checkHost,
}

c, err := net.Dial("tcp", addr)
if err != nil {
return nil, err
}

defer func() {
if c != nil {
c.Close()
}
}()
conn, chans, reqs, err := ssh.NewClientConn(c, addr, &conf)
if err != nil {
return nil, err
}
go ssh.DiscardRequests(reqs)
go func() {
for c := range chans {
go c.Reject(ssh.Prohibited, "")
}
}()

c = nil
return &sshMuxer{conn}, nil
}

type sshMuxer struct {
conn ssh.Conn
}

func (m *sshMuxer) Close() error {
return m.conn.Close()
}

func (m *sshMuxer) Open(id string) (io.ReadWriteCloser, error) {
channel, reqs, err := m.conn.OpenChannel(id, nil)
if err != nil {
return nil, err
}
go ssh.DiscardRequests(reqs)

return channel, nil
}

type sshListener struct {
id ssh.Signer
listener net.Listener
pending *pendingConns
rpcChans chan io.ReadWriteCloser
}

func (l *sshListener) Addr() net.Addr {
return l.listener.Addr()
}

func (l *sshListener) Accept(id string) io.ReadWriteCloser {
return l.pending.accept(id)
}

func (l *sshListener) Close() error {
err := l.listener.Close()
l.pending.fail()
return err
}

func (l *sshListener) RPCChan() <-chan io.ReadWriteCloser {
return l.rpcChans
}

func (l *sshListener) checkLogin(conn ssh.ConnMetadata, key ssh.PublicKey) (*ssh.Permissions, error) {
if bytes.Equal(key.Marshal(), l.id.PublicKey().Marshal()) && conn.User() == "termite" {
return nil, nil
}

return nil, fmt.Errorf("denied")
}

func newSSHListener(listener net.Listener, id ssh.Signer) connListener {
l := sshListener{
id: id,
pending: newPendingConns(),
listener: listener,
rpcChans: make(chan io.ReadWriteCloser, 1),
}
go l.loop()

return &l
}

func (l *sshListener) loop() {
for {
conn, err := l.listener.Accept()
if err != nil {
break
}

go l.handle(conn)
}
l.pending.fail()
}

func (l *sshListener) handle(c net.Conn) error {
conf := ssh.ServerConfig{
PublicKeyCallback: l.checkLogin,
}
conf.AddHostKey(l.id)
_, chans, reqs, err := ssh.NewServerConn(c, &conf)
if err != nil {
return err
}

go ssh.DiscardRequests(reqs)
for newCh := range chans {
id := newCh.ChannelType()
if len(id) != len(RPC_CHANNEL) {
newCh.Reject(ssh.Prohibited, "wrong ID length")
}

ch, reqs, err := newCh.Accept()
if err != nil {
continue
}
go ssh.DiscardRequests(reqs)

if id == RPC_CHANNEL {
l.rpcChans <- ch
} else {
l.pending.add(id, ch)
}
}
return nil
}
63 changes: 6 additions & 57 deletions termite/tcp-connection.go
Expand Up @@ -9,9 +9,10 @@ import (
"io"
"log"
"net"
"sync"
)

const challengeLength = 20

type tcpDialer struct {
secret []byte
}
Expand All @@ -32,6 +33,10 @@ func (c *tcpDialer) Dial(addr string) (connMuxer, error) {
return &tcpMux{c, addr}, nil
}

func (m *tcpMux) Close() error {
return nil
}

func (m *tcpMux) Open(id string) (io.ReadWriteCloser, error) {
if len(id) != HEADER_LEN {
return nil, fmt.Errorf("len(%q) != %d", id, HEADER_LEN)
Expand All @@ -52,62 +57,6 @@ func (m *tcpMux) Open(id string) (io.ReadWriteCloser, error) {
return conn, nil
}

type pendingConns struct {
conns map[string]io.ReadWriteCloser
cond sync.Cond
}

func newPendingConns() *pendingConns {
p := &pendingConns{
conns: map[string]io.ReadWriteCloser{},
}
p.cond.L = new(sync.Mutex)
return p
}

func (p *pendingConns) fail() {
p.cond.L.Lock()
defer p.cond.L.Unlock()
p.conns = nil
p.cond.Broadcast()
}

func (p *pendingConns) wait() {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for p.conns != nil {
p.cond.Wait()
}
}

func (p *pendingConns) add(key string, conn io.ReadWriteCloser) {
p.cond.L.Lock()
defer p.cond.L.Unlock()
if p.conns == nil {
panic("shut down")
}
if p.conns[key] != nil {
panic("collision")
}
p.conns[key] = conn
p.cond.Broadcast()
}

func (p *pendingConns) accept(key string) io.ReadWriteCloser {
p.cond.L.Lock()
defer p.cond.L.Unlock()
for p.conns != nil && p.conns[key] == nil {
p.cond.Wait()
}
if p.conns == nil {
return nil
}

ch := p.conns[key]
delete(p.conns, key)
return ch
}

type tcpListener struct {
net.Listener
incoming chan io.ReadWriteCloser
Expand Down

0 comments on commit 01d2dec

Please sign in to comment.