Skip to content

Commit

Permalink
ssh: allow server auth callbacks to send additional banners
Browse files Browse the repository at this point in the history
Add a new BannerError error type that auth callbacks can return to send
banner to the client. While the BannerCallback can send the initial
banner message, auth callbacks might want to communicate more
information to the client to help them diagnose failures.

Updates golang/go#64962

Change-Id: I97a26480ff4064b95a0a26042b0a5e19737cfb62
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/558695
LUCI-TryBot-Result: Go LUCI <golang-scoped@luci-project-accounts.iam.gserviceaccount.com>
Reviewed-by: Roland Shoemaker <roland@golang.org>
Reviewed-by: Nicola Murino <nicola.murino@gmail.com>
Auto-Submit: Nicola Murino <nicola.murino@gmail.com>
Reviewed-by: Dmitri Shuralyov <dmitshur@google.com>
  • Loading branch information
awly authored and gopherbot committed May 22, 2024
1 parent 67b1361 commit 44c9b0f
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 0 deletions.
30 changes: 30 additions & 0 deletions ssh/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,24 @@ func (p *PartialSuccessError) Error() string {
// It is returned in ServerAuthError.Errors from NewServerConn.
var ErrNoAuth = errors.New("ssh: no auth passed yet")

// BannerError is an error that can be returned by authentication handlers in
// ServerConfig to send a banner message to the client.
type BannerError struct {
Err error
Message string
}

func (b *BannerError) Unwrap() error {
return b.Err
}

func (b *BannerError) Error() string {
if b.Err == nil {
return b.Message
}
return b.Err.Error()
}

func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, error) {
sessionID := s.transport.getSessionID()
var cache pubKeyCache
Expand Down Expand Up @@ -734,6 +752,18 @@ userAuthLoop:
config.AuthLogCallback(s, userAuthReq.Method, authErr)
}

var bannerErr *BannerError
if errors.As(authErr, &bannerErr) {
if bannerErr.Message != "" {
bannerMsg := &userAuthBannerMsg{
Message: bannerErr.Message,
}
if err := s.transport.writePacket(Marshal(bannerMsg)); err != nil {
return nil, err
}
}
}

if authErr == nil {
break userAuthLoop
}
Expand Down
74 changes: 74 additions & 0 deletions ssh/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@ package ssh

import (
"errors"
"fmt"
"io"
"net"
"slices"
"strings"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -225,6 +227,78 @@ func TestNewServerConnValidationErrors(t *testing.T) {
}
}

func TestBannerError(t *testing.T) {
serverConfig := &ServerConfig{
BannerCallback: func(ConnMetadata) string {
return "banner from BannerCallback"
},
NoClientAuth: true,
NoClientAuthCallback: func(ConnMetadata) (*Permissions, error) {
err := &BannerError{
Err: errors.New("error from NoClientAuthCallback"),
Message: "banner from NoClientAuthCallback",
}
return nil, fmt.Errorf("wrapped: %w", err)
},
PasswordCallback: func(conn ConnMetadata, password []byte) (*Permissions, error) {
return &Permissions{}, nil
},
PublicKeyCallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) {
return nil, &BannerError{
Err: errors.New("error from PublicKeyCallback"),
Message: "banner from PublicKeyCallback",
}
},
KeyboardInteractiveCallback: func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error) {
return nil, &BannerError{
Err: nil, // make sure that a nil inner error is allowed
Message: "banner from KeyboardInteractiveCallback",
}
},
}
serverConfig.AddHostKey(testSigners["rsa"])

var banners []string
clientConfig := &ClientConfig{
User: "test",
Auth: []AuthMethod{
PublicKeys(testSigners["rsa"]),
KeyboardInteractive(func(name, instruction string, questions []string, echos []bool) ([]string, error) {
return []string{"letmein"}, nil
}),
Password(clientPassword),
},
HostKeyCallback: InsecureIgnoreHostKey(),
BannerCallback: func(msg string) error {
banners = append(banners, msg)
return nil
},
}

c1, c2, err := netPipe()
if err != nil {
t.Fatalf("netPipe: %v", err)
}
defer c1.Close()
defer c2.Close()
go newServer(c1, serverConfig)
c, _, _, err := NewClientConn(c2, "", clientConfig)
if err != nil {
t.Fatalf("client connection failed: %v", err)
}
defer c.Close()

wantBanners := []string{
"banner from BannerCallback",
"banner from NoClientAuthCallback",
"banner from PublicKeyCallback",
"banner from KeyboardInteractiveCallback",
}
if !slices.Equal(banners, wantBanners) {
t.Errorf("got banners:\n%q\nwant banners:\n%q", banners, wantBanners)
}
}

type markerConn struct {
closed uint32
used uint32
Expand Down

0 comments on commit 44c9b0f

Please sign in to comment.