Skip to content

Commit

Permalink
[proposal] ConnCallback (#36)
Browse files Browse the repository at this point in the history
ConnCallback lets you wrap connection objects for timeouts and limiting
  • Loading branch information
progrium committed Jul 12, 2017
1 parent bf30736 commit 33ad2fe
Show file tree
Hide file tree
Showing 11 changed files with 122 additions and 1 deletion.
2 changes: 1 addition & 1 deletion README.md
Expand Up @@ -30,7 +30,7 @@ This package was built after working on nearly a dozen projects at Glider Labs u

## Examples

A bunch of great examples are in the `_example` directory.
A bunch of great examples are in the `_examples` directory.

## Usage

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
58 changes: 58 additions & 0 deletions _examples/ssh-timeouts/timeouts.go
@@ -0,0 +1,58 @@
package main

import (
"fmt"
"log"
"net"
"time"

"github.com/gliderlabs/ssh"
)

var (
MaxLifeTimeout = 30 * time.Second
IdleTimeout = 5 * time.Second
)

type timeoutConn struct {
net.Conn
maxlife time.Time
idle time.Time
}

func (c *timeoutConn) Write(p []byte) (n int, err error) {
c.updateDeadline()
return c.Conn.Write(p)
}

func (c *timeoutConn) Read(b []byte) (n int, err error) {
c.idle = time.Now().Add(IdleTimeout)
c.updateDeadline()
return c.Conn.Read(b)
}

func (c *timeoutConn) updateDeadline() {
if c.idle.Unix() < c.maxlife.Unix() {
c.Conn.SetDeadline(c.idle)
} else {
c.Conn.SetDeadline(c.maxlife)
}
}

func main() {
ssh.Handle(func(s ssh.Session) {
i := 0
for {
i += 1
fmt.Fprintln(s, i)
time.Sleep(time.Second)
}
})

log.Println("starting ssh server on port 2222...")
log.Printf("connections will only last %s\n", MaxLifeTimeout)
log.Printf("and timeout after %s of no client activity\n", IdleTimeout)
log.Fatal(ssh.ListenAndServe(":2222", nil, ssh.WrapConn(func(conn net.Conn) net.Conn {
return &timeoutConn{conn, time.Now().Add(MaxLifeTimeout), time.Now().Add(IdleTimeout)}
})))
}
8 changes: 8 additions & 0 deletions options.go
Expand Up @@ -62,3 +62,11 @@ func NoPty() Option {
return nil
}
}

// WrapConn returns a functional option that sets ConnCallback on the server.
func WrapConn(fn ConnCallback) Option {
return func(srv *Server) error {
srv.ConnCallback = fn
return nil
}
}
41 changes: 41 additions & 0 deletions options_test.go
@@ -1,7 +1,9 @@
package ssh

import (
"net"
"strings"
"sync/atomic"
"testing"

gossh "golang.org/x/crypto/ssh"
Expand Down Expand Up @@ -66,3 +68,42 @@ func TestPasswordAuthBadPass(t *testing.T) {
}
}
}

type wrappedConn struct {
net.Conn
written int32
}

func (c *wrappedConn) Write(p []byte) (n int, err error) {
n, err = c.Conn.Write(p)
atomic.AddInt32(&(c.written), int32(n))
return
}

func TestConnWrapping(t *testing.T) {
t.Parallel()
var wrapped *wrappedConn
session, _, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// nothing
},
}, &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
HostKeyCallback: gossh.InsecureIgnoreHostKey(),
}, PasswordAuth(func(ctx Context, password string) bool {
return true
}), WrapConn(func(conn net.Conn) net.Conn {
wrapped = &wrappedConn{conn, 0}
return wrapped
}))
defer cleanup()
if err := session.Shell(); err != nil {
t.Fatal(err)
}
if atomic.LoadInt32(&(wrapped.written)) == 0 {
t.Fatal("wrapped conn not written to")
}
}
9 changes: 9 additions & 0 deletions server.go
Expand Up @@ -27,6 +27,7 @@ type Server struct {
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
ConnCallback ConnCallback // optional callback for wrapping net.Conn before handling
LocalPortForwardingCallback LocalPortForwardingCallback // callback for allowing local port forwarding, denies all if nil

channelHandlers map[string]channelHandler
Expand Down Expand Up @@ -191,6 +192,14 @@ func (srv *Server) Serve(l net.Listener) error {
}

func (srv *Server) handleConn(conn net.Conn) {
if srv.ConnCallback != nil {
cbConn := srv.ConnCallback(conn)
if cbConn == nil {
conn.Close()
return
}
conn = cbConn
}
defer conn.Close()
ctx := newContext(srv)
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
Expand Down
5 changes: 5 additions & 0 deletions ssh.go
Expand Up @@ -42,6 +42,11 @@ type PasswordHandler func(ctx Context, password string) bool
// PtyCallback is a hook for allowing PTY sessions.
type PtyCallback func(ctx Context, pty Pty) bool

// ConnCallback is a hook for new connections before handling.
// It allows wrapping for timeouts and limiting by returning
// the net.Conn that will be used as the underlying connection.
type ConnCallback func(conn net.Conn) net.Conn

// LocalPortForwardingCallback is a hook for allowing port forwarding
type LocalPortForwardingCallback func(ctx Context, destinationHost string, destinationPort uint32) bool

Expand Down

0 comments on commit 33ad2fe

Please sign in to comment.