Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

contexts #29

Merged
merged 4 commits into from Mar 14, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion _example/ssh-publickey/public_key.go
Expand Up @@ -16,7 +16,7 @@ func main() {
s.Write(authorizedKey)
})

publicKeyOption := ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
publicKeyOption := ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
return true // allow all keys, or use ssh.KeysEqual() to compare against known keys
})

Expand Down
142 changes: 142 additions & 0 deletions context.go
@@ -0,0 +1,142 @@
package ssh

import (
"context"
"net"

gossh "golang.org/x/crypto/ssh"
)

// contextKey is a value for use with context.WithValue. It's used as
// a pointer so it fits in an interface{} without allocation.
type contextKey struct {
name string
}

var (
// ContextKeyUser is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyUser = &contextKey{"user"}

// ContextKeySessionID is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeySessionID = &contextKey{"session-id"}

// ContextKeyPermissions is a context key for use with Contexts in this package.
// The associated value will be of type *Permissions.
ContextKeyPermissions = &contextKey{"permissions"}

// ContextKeyClientVersion is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyClientVersion = &contextKey{"client-version"}

// ContextKeyServerVersion is a context key for use with Contexts in this package.
// The associated value will be of type string.
ContextKeyServerVersion = &contextKey{"server-version"}

// ContextKeyLocalAddr is a context key for use with Contexts in this package.
// The associated value will be of type net.Addr.
ContextKeyLocalAddr = &contextKey{"local-addr"}

// ContextKeyRemoteAddr is a context key for use with Contexts in this package.
// The associated value will be of type net.Addr.
ContextKeyRemoteAddr = &contextKey{"remote-addr"}

// ContextKeyServer is a context key for use with Contexts in this package.
// The associated value will be of type *Server.
ContextKeyServer = &contextKey{"ssh-server"}

// ContextKeyPublicKey is a context key for use with Contexts in this package.
// The associated value will be of type PublicKey.
ContextKeyPublicKey = &contextKey{"public-key"}
)

// Context is a package specific context interface. It exposes connection
// metadata and allows new values to be easily written to it. It's used in
// authentication handlers and callbacks, and its underlying context.Context is
// exposed on Session in the session Handler.
type Context interface {
context.Context

// User returns the username used when establishing the SSH connection.
User() string

// SessionID returns the session hash.
SessionID() string

// ClientVersion returns the version reported by the client.
ClientVersion() string

// ServerVersion returns the version reported by the server.
ServerVersion() string

// RemoteAddr returns the remote address for this connection.
RemoteAddr() net.Addr

// LocalAddr returns the local address for this connection.
LocalAddr() net.Addr

// Permissions returns the Permissions object used for this connection.
Permissions() *Permissions

// SetValue allows you to easily write new values into the underlying context.
SetValue(key, value interface{})
}

type sshContext struct {
context.Context
}

func newContext(srv *Server) *sshContext {
ctx := &sshContext{context.Background()}
ctx.SetValue(ContextKeyServer, srv)
perms := &Permissions{&gossh.Permissions{}}
ctx.SetValue(ContextKeyPermissions, perms)
return ctx
}

// this is separate from newContext because we will get ConnMetadata
// at different points so it needs to be applied separately
func (ctx *sshContext) applyConnMetadata(conn gossh.ConnMetadata) {
if ctx.Value(ContextKeySessionID) != nil {
return
}
ctx.SetValue(ContextKeySessionID, string(conn.SessionID()))
ctx.SetValue(ContextKeyClientVersion, string(conn.ClientVersion()))
ctx.SetValue(ContextKeyServerVersion, string(conn.ServerVersion()))
ctx.SetValue(ContextKeyUser, conn.User())
ctx.SetValue(ContextKeyLocalAddr, conn.LocalAddr())
ctx.SetValue(ContextKeyRemoteAddr, conn.RemoteAddr())
}

func (ctx *sshContext) SetValue(key, value interface{}) {
ctx.Context = context.WithValue(ctx.Context, key, value)
}

func (ctx *sshContext) User() string {
return ctx.Value(ContextKeyUser).(string)
}

func (ctx *sshContext) SessionID() string {
return ctx.Value(ContextKeySessionID).(string)
}

func (ctx *sshContext) ClientVersion() string {
return ctx.Value(ContextKeyClientVersion).(string)
}

func (ctx *sshContext) ServerVersion() string {
return ctx.Value(ContextKeyServerVersion).(string)
}

func (ctx *sshContext) RemoteAddr() net.Addr {
return ctx.Value(ContextKeyRemoteAddr).(net.Addr)
}

func (ctx *sshContext) LocalAddr() net.Addr {
return ctx.Value(ContextKeyLocalAddr).(net.Addr)
}

func (ctx *sshContext) Permissions() *Permissions {
return ctx.Value(ContextKeyPermissions).(*Permissions)
}
47 changes: 47 additions & 0 deletions context_test.go
@@ -0,0 +1,47 @@
package ssh

import "testing"

func TestSetPermissions(t *testing.T) {
t.Parallel()
permsExt := map[string]string{
"foo": "bar",
}
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
if _, ok := s.Permissions().Extensions["foo"]; !ok {
t.Fatalf("got %#v; want %#v", s.Permissions().Extensions, permsExt)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.Permissions().Extensions = permsExt
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

func TestSetValue(t *testing.T) {
t.Parallel()
value := map[string]string{
"foo": "bar",
}
key := "testValue"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
v := s.Context().Value(key).(map[string]string)
if v["foo"] != value["foo"] {
t.Fatalf("got %#v; want %#v", v, value)
}
},
}, nil, PasswordAuth(func(ctx Context, password string) bool {
ctx.SetValue(key, value)
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}
4 changes: 2 additions & 2 deletions example_test.go
Expand Up @@ -15,7 +15,7 @@ func ExampleListenAndServe() {

func ExamplePasswordAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PasswordAuth(func(user, pass string) bool {
ssh.PasswordAuth(func(ctx ssh.Context, pass string) bool {
return pass == "secret"
}),
)
Expand All @@ -27,7 +27,7 @@ func ExampleNoPty() {

func ExamplePublicKeyAuth() {
ssh.ListenAndServe(":2222", nil,
ssh.PublicKeyAuth(func(user string, key ssh.PublicKey) bool {
ssh.PublicKeyAuth(func(ctx ssh.Context, key ssh.PublicKey) bool {
data, _ := ioutil.ReadFile("/path/to/allowed/key.pub")
allowed, _, _, _, _ := ssh.ParseAuthorizedKey(data)
return ssh.KeysEqual(key, allowed)
Expand Down
2 changes: 1 addition & 1 deletion options.go
Expand Up @@ -56,7 +56,7 @@ func HostKeyPEM(bytes []byte) Option {
// denying PTY requests.
func NoPty() Option {
return func(srv *Server) error {
srv.PtyCallback = func(user string, permissions *Permissions) bool {
srv.PtyCallback = func(ctx Context, pty Pty) bool {
return false
}
return nil
Expand Down
66 changes: 66 additions & 0 deletions options_test.go
@@ -0,0 +1,66 @@
package ssh

import (
"strings"
"testing"

gossh "golang.org/x/crypto/ssh"
)

func newTestSessionWithOptions(t *testing.T, srv *Server, cfg *gossh.ClientConfig, options ...Option) (*gossh.Session, func()) {
for _, option := range options {
if err := srv.SetOption(option); err != nil {
t.Fatal(err)
}
}
return newTestSession(t, srv, cfg)
}

func TestPasswordAuth(t *testing.T) {
t.Parallel()
testUser := "testuser"
testPass := "testpass"
session, cleanup := newTestSessionWithOptions(t, &Server{
Handler: func(s Session) {
// noop
},
}, &gossh.ClientConfig{
User: testUser,
Auth: []gossh.AuthMethod{
gossh.Password(testPass),
},
}, PasswordAuth(func(ctx Context, password string) bool {
if ctx.User() != testUser {
t.Fatalf("user = %#v; want %#v", ctx.User(), testUser)
}
if password != testPass {
t.Fatalf("user = %#v; want %#v", password, testPass)
}
return true
}))
defer cleanup()
if err := session.Run(""); err != nil {
t.Fatal(err)
}
}

func TestPasswordAuthBadPass(t *testing.T) {
t.Parallel()
l := newLocalListener()
srv := &Server{Handler: func(s Session) {}}
srv.SetOption(PasswordAuth(func(ctx Context, password string) bool {
return false
}))
go srv.serveOnce(l)
_, err := gossh.Dial("tcp", l.Addr().String(), &gossh.ClientConfig{
User: "testuser",
Auth: []gossh.AuthMethod{
gossh.Password("testpass"),
},
})
if err != nil {
if !strings.Contains(err.Error(), "unable to authenticate") {
t.Fatal(err)
}
}
}