Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

contexts #29

Merged
merged 4 commits into from Mar 14, 2017
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
102 changes: 102 additions & 0 deletions context.go
@@ -0,0 +1,102 @@
package ssh

import (
"context"
"net"

gossh "golang.org/x/crypto/ssh"
)

// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}

var (
ContextKeyUser = &contextKey{"user"}
ContextKeySessionID = &contextKey{"session-id"}
ContextKeyPermissions = &contextKey{"permissions"}
ContextKeyClientVersion = &contextKey{"client-version"}
ContextKeyServerVersion = &contextKey{"server-version"}
ContextKeyLocalAddr = &contextKey{"local-addr"}
ContextKeyRemoteAddr = &contextKey{"remote-addr"}
ContextKeyServer = &contextKey{"ssh-server"}
ContextKeyPublicKey = &contextKey{"public-key"}
)

type Context interface {
context.Context
User() string
SessionID() string
ClientVersion() string
ServerVersion() string
RemoteAddr() net.Addr
LocalAddr() net.Addr
Permissions() *Permissions
SetValue(key, value interface{})
}

type sshContext struct {
context.Context
}

func newContext(srv *Server) *sshContext {
ctx := &sshContext{context.Background()}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
return ctx
}

// this is separate from newContext because we will get ConnMetadata
// at different points so it needs to be applied separately
func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
if ctx.Value(ContextKeySessionID) != nil {
return
}
// for most of these, instead of converting to strings now, storing the byte
// slices means allocations only happen when accessing, not when contexts
// are being copied around
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Though after I saw how context.WithValue is implemented this might not be an issue.

ctx.SetValue(ContextKeySessionID, conn.SessionID())
ctx.SetValue(ContextKeyClientVersion, conn.ClientVersion())
ctx.SetValue(ContextKeyServerVersion, conn.ServerVersion())
ctx.SetValue(ContextKeyUser, conn.User())
ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr())
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
}

func (ctx *sshContext) User() string {
return ctx.Value(ContextKeyUser).(string)
}

func (ctx *sshContext) SessionID() string {
id, _ := ctx.Value(ContextKeySessionID).([]byte)
return string(id)
}

func (ctx *sshContext) ClientVersion() string {
version, _ := ctx.Value(ContextKeyClientVersion).([]byte)
return string(version)
}

func (ctx *sshContext) ServerVersion() string {
version, _ := ctx.Value(ContextKeyServerVersion).([]byte)
return string(version)
}

func (ctx *sshContext) RemoteAddr() net.Addr {
return ctx.Value(ContextKeyRemoteAddr).(net.Addr)
}

func (ctx *sshContext) LocalAddr() net.Addr {
return ctx.Value(ContextKeyLocalAddr).(net.Addr)
}

func (ctx *sshContext) Permissions() *Permissions {
return ctx.Value(ContextKeyPermissions).(*Permissions)
}
47 changes: 47 additions & 0 deletions context_test.go
@@ -0,0 +1,47 @@
package ssh

import "testing"

func TestSetPermissions(t *testing.T) {
t.Parallel()
permsExt := map[string]string{
"foo": "bar",
}
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
if _, ok := s.Permissions().Extensions["foo"]; !ok {
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.Permissions().Extensions = permsExt
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

func TestSetValue(t *testing.T) {
t.Parallel()
value := map[string]string{
"foo": "bar",
}
key := "testValue"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
v := s.Context().Value(key).(map[string]string)
if v["foo"] != value["foo"] {
t.Fatalf("got %#v; want %#v", v, value)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.SetValue(key, value)
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}
4 changes: 2 additions & 2 deletions example_test.go
Expand Up @@ -15,7 +15,7 @@ func ExampleListenAndServe() {

func ExamplePasswordAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PasswordAuth(func(user, pass string) bool {
ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
return pass == "secret"
}),
)
Expand All @@ -27,7 +27,7 @@ func ExampleNoPty() {

func ExamplePublicKeyAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
data, _ := ioutil.ReadFile("/path/to/allowed/key.pub")
allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data)
return ssh.KeysEqual(key, allowed)
Expand Down
2 changes: 1 addition & 1 deletion options.go
Expand Up @@ -56,7 +56,7 @@ func HostKeyPEM(bytes []byte) Option {
// denying PTY requests.
func NoPty() Option {
return func(srv *Server) error {
srv.PtyCallback = func(user string, permissions *Permissions) bool {
srv.PtyCallback = func(ctx Context, pty Pty) bool {
return false
}
return nil
Expand Down
66 changes: 66 additions & 0 deletions options_test.go
@@ -0,0 +1,66 @@
package ssh

import (
"strings"
"testing"

gossh "golang.org/x/crypto/ssh"
)

func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
for _, option := range options {
if err := srv.SetOption(option); err != nil {
t.Fatal(err)
}
}
return newTestSession(t, srv, cfg)
}

func TestPasswordAuth(t *testing.T) {
t.Parallel()
testUser := "testuser"
testPass := "testpass"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// noop
},
}, &gossh.ClientConfig{
User: testUser,
Auth: []gossh.AuthMethod{
gossh.Password(testPass),
},
}, PasswordAuth(func(ctx Context, password string) bool {
if ctx.User() != testUser {
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
}
if password != testPass {
t.Fatalf("user = %#v; want %#v", password, testPass)
}
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
}))
go srv.serveOnce(l)
_, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
})
if err != nil {
if !strings.Contains(err.Error(), "unable to authenticate") {
t.Fatal(err)
}
}
}
69 changes: 33 additions & 36 deletions server.go
Expand Up @@ -17,21 +17,24 @@ type Server struct {
HostSigners []Signer // private keys for the host key, must have at least one
Version string // server version to be sent before the initial handshake

PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
PermissionsCallback PermissionsCallback // optional callback for setting up permissions
PasswordHandler PasswordHandler // password authentication handler
PublicKeyHandler PublicKeyHandler // public key authentication handler
PtyCallback PtyCallback // callback for allowing PTY sessions, allows all if nil
}

func (srv *Server) makeConfig() (*gossh.ServerConfig, error) {
config := &gossh.ServerConfig{}
func (srv *Server) ensureHostSigner() error {
if len(srv.HostSigners) == 0 {
signer, err := generateSigner()
if err != nil {
return nil, err
return err
}
srv.HostSigners = append(srv.HostSigners, signer)
}
return nil
}

func (srv *Server) config(ctx *sshContext) *gossh.ServerConfig {
config := &gossh.ServerConfig{}
for _, signer := range srv.HostSigners {
config.AddHostKey(signer)
}
Expand All @@ -43,34 +46,24 @@ func (srv *Server) makeConfig() (*gossh.ServerConfig, error) {
}
if srv.PasswordHandler != nil {
config.PasswordCallback = func(conn gossh.ConnMetadata, password []byte) (*gossh.Permissions, error) {
perms := &gossh.Permissions{}
if ok := srv.PasswordHandler(conn.User(), string(password)); !ok {
return perms, fmt.Errorf("permission denied")
ctx.applyConnMetadata(conn)
if ok := srv.PasswordHandler(ctx, string(password)); !ok {
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
}
if srv.PermissionsCallback != nil {
srv.PermissionsCallback(conn.User(), &Permissions{perms})
}
return perms, nil
return ctx.Permissions().Permissions, nil
}
}
if srv.PublicKeyHandler != nil {
config.PublicKeyCallback = func(conn gossh.ConnMetadata, key gossh.PublicKey) (*gossh.Permissions, error) {
perms := &gossh.Permissions{}
if ok := srv.PublicKeyHandler(conn.User(), key); !ok {
return perms, fmt.Errorf("permission denied")
}
// no other way to pass the key from
// auth handler to session handler
perms.Extensions = map[string]string{
"_publickey": string(key.Marshal()),
ctx.applyConnMetadata(conn)
if ok := srv.PublicKeyHandler(ctx, key); !ok {
return ctx.Permissions().Permissions, fmt.Errorf("permission denied")
}
if srv.PermissionsCallback != nil {
srv.PermissionsCallback(conn.User(), &Permissions{perms})
}
return perms, nil
ctx.SetValue(ContextKeyPublicKey, key)
return ctx.Permissions().Permissions, nil
}
}
return config, nil
return config
}

// Handle sets the Handler for the server.
Expand All @@ -85,8 +78,7 @@ func (srv *Server) Handle(fn Handler) {
// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
config, err := srv.makeConfig()
if err != nil {
if err := srv.ensureHostSigner(); err != nil {
return err
}
if srv.Handler == nil {
Expand All @@ -110,41 +102,46 @@ func (srv *Server) Serve(l net.Listener) error {
}
return e
}
go srv.handleConn(conn, config)
go srv.handleConn(conn)
}
}

func (srv *Server) handleConn(conn net.Conn, conf *gossh.ServerConfig) {
func (srv *Server) handleConn(conn net.Conn) {
defer conn.Close()
sshConn, chans, reqs, err := gossh.NewServerConn(conn, conf)
ctx := newContext(srv)
sshConn, chans, reqs, err := gossh.NewServerConn(conn, srv.config(ctx))
if err != nil {
// TODO: trigger event callback
return
}
ctx.applyConnMetadata(sshConn)
go gossh.DiscardRequests(reqs)
for ch := range chans {
if ch.ChannelType() != "session" {
ch.Reject(gossh.UnknownChannelType, "unsupported channel type")
continue
}
go srv.handleChannel(sshConn, ch)
go srv.handleChannel(sshConn, ch, ctx)
}
}

func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel) {
func (srv *Server) handleChannel(conn *gossh.ServerConn, newChan gossh.NewChannel, ctx *sshContext) {
ch, reqs, err := newChan.Accept()
if err != nil {
// TODO: trigger event callback
return
}
sess := srv.newSession(conn, ch)
sess := srv.newSession(conn, ch, ctx)
sess.handleRequests(reqs)
}

func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel) *session {
func (srv *Server) newSession(conn *gossh.ServerConn, ch gossh.Channel, ctx *sshContext) *session {
sess := &session{
Channel: ch,
conn: conn,
handler: srv.Handler,
ptyCb: srv.PtyCallback,
ctx: ctx,
}
return sess
}
Expand Down